Skip to content

Commit

Permalink
refactor(pinwheels): integrate mask_targets()
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Aug 28, 2023
1 parent fff6f14 commit c748706
Showing 1 changed file with 64 additions and 58 deletions.
122 changes: 64 additions & 58 deletions stimupy/stimuli/pinwheels.py
Expand Up @@ -2,9 +2,10 @@

import numpy as np

from stimupy.components import combine_masks, draw_regions
from stimupy.components.shapes import circle
from stimupy.components.shapes import ring as ring_shape
from stimupy.stimuli import waves
from stimupy.stimuli import mask_targets, waves

__all__ = [
"pinwheel",
Expand Down Expand Up @@ -87,6 +88,7 @@ def pinwheel(
https://doi.org/10.1016/j.visres.2007.02.017
"""

# Draw angular grating
stim = waves.square_angular(
visual_size=visual_size,
ppd=ppd,
Expand All @@ -99,76 +101,80 @@ def pinwheel(
origin=origin,
intensity_segments=intensity_segments,
)
radius = min(stim["visual_size"]) / 2

stim["target_indices"] = target_indices
stim["target_center"] = target_center
stim["target_width"] = target_width
stim["intensity_target"] = intensity_target

circle_mask = circle(
# Mask to circular aperture
radius = min(stim["visual_size"]) / 2
circle_aperture = circle(
visual_size=visual_size,
ppd=ppd,
shape=shape,
radius=radius,
origin=origin,
)["circle_mask"]
stim["img"] = np.where(circle_aperture, stim["img"], intensity_background)
stim["intensity_background"] = intensity_background

stim["img"] = np.where(circle_mask, stim["img"], intensity_background)
# Target segment mask
if isinstance(target_indices, (int, float)):
target_indices = (target_indices,)
target_segment_mask = mask_targets(element_mask=stim["grating_mask"], target_indices=target_indices)
stim["target_indices"] = target_indices

# Mask ring regions
if target_center is None:
target_center = radius / 2

# Place target(s)
if isinstance(target_indices, (int)):
target_indices = [
target_indices,
]
if isinstance(target_center, (int, float)):
target_center = [
target_center,
]
target_center = (target_center,)
stim["target_center"] = target_center
target_center = tuple(itertools.islice(itertools.cycle(target_center), len(target_indices)))

if target_width is None:
raise ValueError("pinwheel() missing argument 'target_width' which is not 'None'")
if isinstance(target_width, (int, float)):
target_width = [
target_width,
]
if isinstance(intensity_target, (int, float)):
intensity_target = [
intensity_target,
]

# Initiate target mask
target_mask = np.zeros_like(stim["grating_mask"])

if target_indices is not None:
if target_width is None:
raise ValueError("pinwheel() missing argument 'target_width' which is not 'None'")

target_center = itertools.cycle(target_center)
target_width = itertools.cycle(target_width)
intensity_target = itertools.cycle(intensity_target)

for target_idx, (segment_idx, center, width, intensity) in enumerate(
zip(target_indices, target_center, target_width, intensity_target)
):
# Draw ring
inner_radius = center - (width / 2)
outer_radius = center + (width / 2)
if inner_radius < 0 or outer_radius > np.min(visual_size) / 2:
raise ValueError("target does not fully fit into pinwheel")
ring_stim = ring_shape(
radii=[inner_radius, outer_radius],
intensity_ring=intensity,
visual_size=stim["visual_size"],
ppd=stim["ppd"],
shape=stim["shape"],
)
condition1 = stim["grating_mask"] == segment_idx
condition2 = ring_stim["ring_mask"] == 1
target_mask = np.where(condition1 & condition2, target_idx + 1, target_mask)
stim["img"] = np.where(target_mask == (target_idx + 1), intensity, stim["img"])
stim["target_mask"] = target_mask
stim["intensity_background"] = intensity_background
target_width = (target_width,)
stim["target_width"] = target_width
target_width = tuple(itertools.islice(itertools.cycle(target_width), len(target_indices)))

target_ring_masks = []
for target_idx, (center, width) in enumerate(zip(target_center, target_width)):
# Draw ring
inner_radius = center - (width / 2)
outer_radius = center + (width / 2)
if inner_radius < 0 or outer_radius > np.min(visual_size) / 2:
raise ValueError("target does not fully fit into pinwheel")
ring = ring_shape(
radii=[inner_radius, outer_radius],
intensity_ring=target_idx,
visual_size=stim["visual_size"],
ppd=stim["ppd"],
shape=stim["shape"],
)
target_ring_masks.append(ring["ring_mask"])

# Combine segment & ring masks
target_masks = []
for target_idx, ring_mask in enumerate(target_ring_masks):
# Find where ring intesects with target segment
target_mask = (target_segment_mask == target_idx + 1) & ring_mask
target_masks.append(target_mask)

# Combine target masks
if len(target_masks) > 0:
target_mask = combine_masks(*target_masks)
else:
target_mask = np.zeros_like(stim["img"])
stim["target_mask"] = target_mask.astype(int)

# Draw target(s)
stim["img"] = np.where(
target_mask,
draw_regions(
mask=target_mask, intensities=intensity_target
),
stim["img"]
)
stim["intensity_target"] = intensity_target

return stim


Expand Down

0 comments on commit c748706

Please sign in to comment.