Skip to content

Commit

Permalink
refactor(whites): use mask_targets()
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Aug 28, 2023
1 parent daf6acd commit 1e04f07
Showing 1 changed file with 39 additions and 32 deletions.
71 changes: 39 additions & 32 deletions stimupy/stimuli/whites.py
Expand Up @@ -4,6 +4,7 @@

from stimupy.components import combine_masks, draw_regions
from stimupy.components.shapes import rectangle
from stimupy.stimuli import mask_targets
from stimupy.stimuli.gratings import squarewave
from stimupy.stimuli.pinwheels import pinwheel as angular
from stimupy.stimuli.waves import square_radial as radial
Expand Down Expand Up @@ -113,59 +114,65 @@ def generalized(
round_phase_width=round_phase_width,
)

# Resolve target parameters
if isinstance(target_indices, (int)):
target_indices = (target_indices,)
if isinstance(intensity_target, (int, float)):
intensity_target = (intensity_target,)

# Mask target bars
target_bar_mask = mask_targets(element_mask=stim["grating_mask"], target_indices=target_indices)
stim["target_indices"] = target_indices

# Mask rectangular regions
if isinstance(target_heights, (int, float)):
target_heights = (target_heights,)
if isinstance(target_center_offsets, (int, float)):
target_center_offsets = (target_center_offsets,)

if len(target_indices) != 0 and target_heights is None:
raise ValueError("generalized() missing argument 'target_heights' which is not 'None'")
if len(target_indices) == 0 and target_heights is None:
target_heights = (0,)

intensity_target = tuple(
itertools.islice(itertools.cycle(intensity_target), len(target_indices))
)
target_heights = tuple(itertools.islice(itertools.cycle(target_heights), len(target_indices)))
target_center_offsets = tuple(
itertools.islice(itertools.cycle(target_center_offsets), len(target_indices))
)

# Place target(s)
stim_center = stim["visual_size"].height / 2
target_zip = zip(target_indices, intensity_target, target_heights, target_center_offsets)
targets_mask = np.zeros_like(stim["grating_mask"])
for target_idx, (bar_idx, intensity, height, offset) in enumerate(target_zip):
target_rect_masks = []
for target_idx, (height, offset) in enumerate(zip(target_heights, target_center_offsets)):
# Draw a stripe of target_height x stim_width, at center + offset
rect = rectangle(
visual_size=stim["visual_size"],
ppd=stim["ppd"],
shape=stim["shape"],
rectangle_size=(height, stim["visual_size"].width),
rectangle_position=(stim_center + offset - (height / 2), 0),
intensity_rectangle=intensity,
intensity_rectangle=target_idx,
)
target_rect_masks.append(rect["rectangle_mask"])
stim["target_heights"] = target_heights
stim["target_center_offsets"] = target_center_offsets

# Combine rect & bar masks
target_masks = []
for target_idx, rect_mask in enumerate(target_rect_masks):
# Find where strip intersects with the target bar
if bar_idx < 0:
bar_idx = int(stim["n_bars"]) + bar_idx
mask = (stim["grating_mask"] == bar_idx + 1) & rect["rectangle_mask"]
targets_mask = np.where(mask, target_idx + 1, targets_mask)

# Draw target
stim["img"] = np.where(targets_mask == target_idx + 1, intensity, stim["img"])

# Update and return stimulus
stim["target_mask"] = targets_mask.astype(int)
stim["target_indices"] = target_indices
target_mask = (target_bar_mask == target_idx + 1) & rect_mask
target_masks.append(target_mask)

# Combine masks
if len(target_masks) > 0:
target_mask = combine_masks(*target_masks)
else:
target_mask = np.zeros_like(stim["img"])
stim["target_mask"] = target_mask.astype(int)

# Draw targets
stim["img"] = np.where(
target_mask,
draw_regions(
mask=target_mask, intensities=intensity_target, intensity_background=0.0
),
stim["img"],
)
stim["intensity_target"] = intensity_target
stim["target_heights"] = target_heights
stim["target_center_offsets"] = target_center_offsets

return stim


Expand Down Expand Up @@ -492,7 +499,7 @@ def anderson(
if bar_idx < 0:
bar_idx = int(stim["n_bars"]) + bar_idx
stripe_top["rectangle_mask"] = np.where(
stim["grating_mask"] == bar_idx + 1, 0, stripe_top["rectangle_mask"]
stim["grating_mask"] == bar_idx, 0, stripe_top["rectangle_mask"]
)

stripe_bottom = rectangle(
Expand All @@ -506,7 +513,7 @@ def anderson(
if bar_idx < 0:
bar_idx = int(stim["n_bars"]) + bar_idx
stripe_bottom["rectangle_mask"] = np.where(
stim["grating_mask"] == bar_idx + 1, 0, stripe_bottom["rectangle_mask"]
stim["grating_mask"] == bar_idx, 0, stripe_bottom["rectangle_mask"]
)

try:
Expand Down Expand Up @@ -734,7 +741,7 @@ def yazdanbakhsh(
# Reduce to just this bar: intersection between rect mask and bar mask
if bar_idx < 0:
bar_idx = int(stim["n_bars"]) + bar_idx
stim["gap_mask"] = (stim["grating_mask"] == bar_idx + 1) & rect["rectangle_mask"]
stim["gap_mask"] = (stim["grating_mask"] == bar_idx) & rect["rectangle_mask"]

# Remove everywhere it intersects with target mask
stim["gap_mask"] = np.where(stim["target_mask"] == t_idx + 1, 0, stim["gap_mask"])
Expand All @@ -758,7 +765,7 @@ def yazdanbakhsh(
# Reduce to just this bar: intersection between rect mask and bar mask
if bar_idx < 0:
bar_idx = int(stim["n_bars"]) + bar_idx
stim["gap_mask"] = (stim["grating_mask"] == bar_idx + 1) & rect["rectangle_mask"]
stim["gap_mask"] = (stim["grating_mask"] == bar_idx) & rect["rectangle_mask"]

# Remove everywhere it intersects with target mask
stim["gap_mask"] = np.where(stim["target_mask"] == t_idx + 1, 0, stim["gap_mask"])
Expand Down

0 comments on commit 1e04f07

Please sign in to comment.