diff --git a/stimupy/stimuli/whites.py b/stimupy/stimuli/whites.py index daca172..da4320a 100644 --- a/stimupy/stimuli/whites.py +++ b/stimupy/stimuli/whites.py @@ -4,6 +4,7 @@ from stimupy.components import combine_masks, draw_regions from stimupy.components.shapes import rectangle +from stimupy.stimuli import mask_targets from stimupy.stimuli.gratings import squarewave from stimupy.stimuli.pinwheels import pinwheel as angular from stimupy.stimuli.waves import square_radial as radial @@ -113,34 +114,28 @@ def generalized( round_phase_width=round_phase_width, ) - # Resolve target parameters - if isinstance(target_indices, (int)): - target_indices = (target_indices,) - if isinstance(intensity_target, (int, float)): - intensity_target = (intensity_target,) + + # Mask target bars + target_bar_mask = mask_targets(element_mask=stim["grating_mask"], target_indices=target_indices) + stim["target_indices"] = target_indices + + # Mask rectangular regions if isinstance(target_heights, (int, float)): target_heights = (target_heights,) if isinstance(target_center_offsets, (int, float)): target_center_offsets = (target_center_offsets,) - if len(target_indices) != 0 and target_heights is None: raise ValueError("generalized() missing argument 'target_heights' which is not 'None'") if len(target_indices) == 0 and target_heights is None: target_heights = (0,) - - intensity_target = tuple( - itertools.islice(itertools.cycle(intensity_target), len(target_indices)) - ) target_heights = tuple(itertools.islice(itertools.cycle(target_heights), len(target_indices))) target_center_offsets = tuple( itertools.islice(itertools.cycle(target_center_offsets), len(target_indices)) ) - # Place target(s) stim_center = stim["visual_size"].height / 2 - target_zip = zip(target_indices, intensity_target, target_heights, target_center_offsets) - targets_mask = np.zeros_like(stim["grating_mask"]) - for target_idx, (bar_idx, intensity, height, offset) in enumerate(target_zip): + target_rect_masks = [] + for target_idx, (height, offset) in enumerate(zip(target_heights, target_center_offsets)): # Draw a stripe of target_height x stim_width, at center + offset rect = rectangle( visual_size=stim["visual_size"], @@ -148,24 +143,36 @@ def generalized( shape=stim["shape"], rectangle_size=(height, stim["visual_size"].width), rectangle_position=(stim_center + offset - (height / 2), 0), - intensity_rectangle=intensity, + intensity_rectangle=target_idx, ) + target_rect_masks.append(rect["rectangle_mask"]) + stim["target_heights"] = target_heights + stim["target_center_offsets"] = target_center_offsets + # Combine rect & bar masks + target_masks = [] + for target_idx, rect_mask in enumerate(target_rect_masks): # Find where strip intersects with the target bar - if bar_idx < 0: - bar_idx = int(stim["n_bars"]) + bar_idx - mask = (stim["grating_mask"] == bar_idx + 1) & rect["rectangle_mask"] - targets_mask = np.where(mask, target_idx + 1, targets_mask) - - # Draw target - stim["img"] = np.where(targets_mask == target_idx + 1, intensity, stim["img"]) - - # Update and return stimulus - stim["target_mask"] = targets_mask.astype(int) - stim["target_indices"] = target_indices + target_mask = (target_bar_mask == target_idx + 1) & rect_mask + target_masks.append(target_mask) + + # Combine masks + if len(target_masks) > 0: + target_mask = combine_masks(*target_masks) + else: + target_mask = np.zeros_like(stim["img"]) + stim["target_mask"] = target_mask.astype(int) + + # Draw targets + stim["img"] = np.where( + target_mask, + draw_regions( + mask=target_mask, intensities=intensity_target, intensity_background=0.0 + ), + stim["img"], + ) stim["intensity_target"] = intensity_target - stim["target_heights"] = target_heights - stim["target_center_offsets"] = target_center_offsets + return stim @@ -492,7 +499,7 @@ def anderson( if bar_idx < 0: bar_idx = int(stim["n_bars"]) + bar_idx stripe_top["rectangle_mask"] = np.where( - stim["grating_mask"] == bar_idx + 1, 0, stripe_top["rectangle_mask"] + stim["grating_mask"] == bar_idx, 0, stripe_top["rectangle_mask"] ) stripe_bottom = rectangle( @@ -506,7 +513,7 @@ def anderson( if bar_idx < 0: bar_idx = int(stim["n_bars"]) + bar_idx stripe_bottom["rectangle_mask"] = np.where( - stim["grating_mask"] == bar_idx + 1, 0, stripe_bottom["rectangle_mask"] + stim["grating_mask"] == bar_idx, 0, stripe_bottom["rectangle_mask"] ) try: @@ -734,7 +741,7 @@ def yazdanbakhsh( # Reduce to just this bar: intersection between rect mask and bar mask if bar_idx < 0: bar_idx = int(stim["n_bars"]) + bar_idx - stim["gap_mask"] = (stim["grating_mask"] == bar_idx + 1) & rect["rectangle_mask"] + stim["gap_mask"] = (stim["grating_mask"] == bar_idx) & rect["rectangle_mask"] # Remove everywhere it intersects with target mask stim["gap_mask"] = np.where(stim["target_mask"] == t_idx + 1, 0, stim["gap_mask"]) @@ -758,7 +765,7 @@ def yazdanbakhsh( # Reduce to just this bar: intersection between rect mask and bar mask if bar_idx < 0: bar_idx = int(stim["n_bars"]) + bar_idx - stim["gap_mask"] = (stim["grating_mask"] == bar_idx + 1) & rect["rectangle_mask"] + stim["gap_mask"] = (stim["grating_mask"] == bar_idx) & rect["rectangle_mask"] # Remove everywhere it intersects with target mask stim["gap_mask"] = np.where(stim["target_mask"] == t_idx + 1, 0, stim["gap_mask"])