diff --git a/stimupy/stimuli/pinwheels.py b/stimupy/stimuli/pinwheels.py index e4958fc..7f58c02 100644 --- a/stimupy/stimuli/pinwheels.py +++ b/stimupy/stimuli/pinwheels.py @@ -2,9 +2,10 @@ import numpy as np +from stimupy.components import combine_masks, draw_regions from stimupy.components.shapes import circle from stimupy.components.shapes import ring as ring_shape -from stimupy.stimuli import waves +from stimupy.stimuli import mask_targets, waves __all__ = [ "pinwheel", @@ -87,6 +88,7 @@ def pinwheel( https://doi.org/10.1016/j.visres.2007.02.017 """ + # Draw angular grating stim = waves.square_angular( visual_size=visual_size, ppd=ppd, @@ -99,76 +101,80 @@ def pinwheel( origin=origin, intensity_segments=intensity_segments, ) - radius = min(stim["visual_size"]) / 2 - stim["target_indices"] = target_indices - stim["target_center"] = target_center - stim["target_width"] = target_width - stim["intensity_target"] = intensity_target - - circle_mask = circle( + # Mask to circular aperture + radius = min(stim["visual_size"]) / 2 + circle_aperture = circle( visual_size=visual_size, ppd=ppd, shape=shape, radius=radius, origin=origin, )["circle_mask"] + stim["img"] = np.where(circle_aperture, stim["img"], intensity_background) + stim["intensity_background"] = intensity_background - stim["img"] = np.where(circle_mask, stim["img"], intensity_background) + # Target segment mask + if isinstance(target_indices, (int, float)): + target_indices = (target_indices,) + target_segment_mask = mask_targets(element_mask=stim["grating_mask"], target_indices=target_indices) + stim["target_indices"] = target_indices + # Mask ring regions if target_center is None: target_center = radius / 2 - - # Place target(s) - if isinstance(target_indices, (int)): - target_indices = [ - target_indices, - ] if isinstance(target_center, (int, float)): - target_center = [ - target_center, - ] + target_center = (target_center,) + stim["target_center"] = target_center + target_center = tuple(itertools.islice(itertools.cycle(target_center), len(target_indices))) + + if target_width is None: + raise ValueError("pinwheel() missing argument 'target_width' which is not 'None'") if isinstance(target_width, (int, float)): - target_width = [ - target_width, - ] - if isinstance(intensity_target, (int, float)): - intensity_target = [ - intensity_target, - ] - - # Initiate target mask - target_mask = np.zeros_like(stim["grating_mask"]) - - if target_indices is not None: - if target_width is None: - raise ValueError("pinwheel() missing argument 'target_width' which is not 'None'") - - target_center = itertools.cycle(target_center) - target_width = itertools.cycle(target_width) - intensity_target = itertools.cycle(intensity_target) - - for target_idx, (segment_idx, center, width, intensity) in enumerate( - zip(target_indices, target_center, target_width, intensity_target) - ): - # Draw ring - inner_radius = center - (width / 2) - outer_radius = center + (width / 2) - if inner_radius < 0 or outer_radius > np.min(visual_size) / 2: - raise ValueError("target does not fully fit into pinwheel") - ring_stim = ring_shape( - radii=[inner_radius, outer_radius], - intensity_ring=intensity, - visual_size=stim["visual_size"], - ppd=stim["ppd"], - shape=stim["shape"], - ) - condition1 = stim["grating_mask"] == segment_idx - condition2 = ring_stim["ring_mask"] == 1 - target_mask = np.where(condition1 & condition2, target_idx + 1, target_mask) - stim["img"] = np.where(target_mask == (target_idx + 1), intensity, stim["img"]) - stim["target_mask"] = target_mask - stim["intensity_background"] = intensity_background + target_width = (target_width,) + stim["target_width"] = target_width + target_width = tuple(itertools.islice(itertools.cycle(target_width), len(target_indices))) + + target_ring_masks = [] + for target_idx, (center, width) in enumerate(zip(target_center, target_width)): + # Draw ring + inner_radius = center - (width / 2) + outer_radius = center + (width / 2) + if inner_radius < 0 or outer_radius > np.min(visual_size) / 2: + raise ValueError("target does not fully fit into pinwheel") + ring = ring_shape( + radii=[inner_radius, outer_radius], + intensity_ring=target_idx, + visual_size=stim["visual_size"], + ppd=stim["ppd"], + shape=stim["shape"], + ) + target_ring_masks.append(ring["ring_mask"]) + + # Combine segment & ring masks + target_masks = [] + for target_idx, ring_mask in enumerate(target_ring_masks): + # Find where ring intesects with target segment + target_mask = (target_segment_mask == target_idx + 1) & ring_mask + target_masks.append(target_mask) + + # Combine target masks + if len(target_masks) > 0: + target_mask = combine_masks(*target_masks) + else: + target_mask = np.zeros_like(stim["img"]) + stim["target_mask"] = target_mask.astype(int) + + # Draw target(s) + stim["img"] = np.where( + target_mask, + draw_regions( + mask=target_mask, intensities=intensity_target + ), + stim["img"] + ) + stim["intensity_target"] = intensity_target + return stim