Skip to content

Commit

Permalink
whites.anderson draws stripes as rectangles
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Mar 24, 2023
1 parent 9ae1b74 commit d7b05c9
Showing 1 changed file with 44 additions and 43 deletions.
87 changes: 44 additions & 43 deletions stimupy/stimuli/whites.py
Expand Up @@ -10,6 +10,7 @@
from stimupy.stimuli.waves import square_radial as radial
from stimupy.stimuli.wedding_cakes import wedding_cake
from stimupy.utils import resolution
from stimupy.utils.utils import combine_masks

__all__ = [
"generalized",
Expand Down Expand Up @@ -471,7 +472,10 @@ def anderson(
raise ValueError("anderson() missing argument 'target_height' which is not 'None'")
if stripe_height is None:
raise ValueError("anderson() missing argument 'stripe_height' which is not 'None'")
if isinstance(stripe_height, (int, float)):
stripe_height = (stripe_height, stripe_height)

# Generate White's stimulus with two rows of targets
stim = white_two_rows(
visual_size=visual_size,
ppd=ppd,
Expand All @@ -491,56 +495,53 @@ def anderson(
round_phase_width=round_phase_width,
)

img = stim["img"]
mask = stim["target_mask"]
soffset = resolution.lengths_from_visual_angles_ppd(stripe_center_offset, np.unique(ppd)[0])
sheight = resolution.lengths_from_visual_angles_ppd(stripe_height, np.unique(ppd)[0])
height, width = img.shape

if isinstance(target_indices_top, (float, int)):
target_indices_top = (target_indices_top,)
if isinstance(target_indices_bottom, (float, int)):
target_indices_bottom = (target_indices_bottom,)

if sheight / 2.0 > soffset:
raise ValueError("Stripes overlap! Increase stripe offset or decrease stripe size.")
if (target_height / 2 - target_center_offset + stripe_height / 2 - stripe_center_offset) > 0:
raise ValueError(
"Stripes overlap with targets! Increase stripe or target offsets or"
"decrease stripe or target size"
)
if stripe_center_offset * ppd % 1 != 0:
offsets_new = soffset / ppd
warnings.warn(
f"Stripe offsets rounded because of ppd; {stripe_center_offset} -> {offsets_new}"
# Masks for stripes (as rectangles)
stim_center = stim["visual_size"].height / 2
stripe_top = rectangle(
visual_size=stim["visual_size"],
ppd=stim["ppd"],
shape=stim["shape"],
rectangle_size=(stripe_height[0], stim["visual_size"].width),
rectangle_position=(stim_center - stripe_center_offset - (stripe_height[0] / 2), 0),
)
for bar_idx in stim["target_indices_top"]:
if bar_idx < 0:
bar_idx = int(stim["n_bars"]) + bar_idx
stripe_top["shape_mask"] = np.where(
stim["grating_mask"] == bar_idx + 1, 0, stripe_top["shape_mask"]
)

# Add stripe at top
ystart = height // 2 - soffset - sheight // 2
stripe_mask_top = np.zeros(stim["shape"])
stripe_mask_top[ystart : ystart + sheight, :] = 1
for t in target_indices_top:
if t < 0:
t = int(stim["n_bars"] + t)
stripe_mask_top = np.where(stim["grating_mask"] == t + 1, 0, stripe_mask_top)

# Add stripes at bottom
ystart = height // 2 + soffset - sheight // 2
stripe_mask_bot = np.zeros(stim["shape"])
stripe_mask_bot[ystart : ystart + sheight, :] = 1
for t in target_indices_bottom:
if t < 0:
t = int(stim["n_bars"] + t)
stripe_mask_bot = np.where(stim["grating_mask"] == t + 1, 0, stripe_mask_bot)
stripe_bottom = rectangle(
visual_size=stim["visual_size"],
ppd=stim["ppd"],
shape=stim["shape"],
rectangle_size=(stripe_height[1], stim["visual_size"].width),
rectangle_position=(stim_center + stripe_center_offset - (stripe_height[1] / 2), 0),
)
for bar_idx in stim["target_indices_bottom"]:
if bar_idx < 0:
bar_idx = int(stim["n_bars"]) + bar_idx
stripe_bottom["shape_mask"] = np.where(
stim["grating_mask"] == bar_idx + 1, 0, stripe_bottom["shape_mask"]
)

img = np.where(stripe_mask_top, intensity_stripes[0], img)
img = np.where(stripe_mask_bot, intensity_stripes[1], img)
try:
stripes_mask = combine_masks(stripe_top["shape_mask"], stripe_bottom["shape_mask"])
except ValueError:
raise ValueError("Stripes overlap. Increase stripe offset or decrease stripe size.")

# Combine images
stripes_img = draw_regions(stripes_mask, intensities=intensity_stripes)
img = np.where(stripes_mask, stripes_img, stim["img"])
img = np.where(stim["target_mask"], stim["img"], img)
stim["img"] = img
stim["target_mask"] = mask
stim["intensity_stripes"] = intensity_stripes

# Output
stim["stripe_center_offset"] = stripe_center_offset
stim["stripe_height"] = stripe_height
stim["stripes_mask"] = stripes_mask
stim["intensity_stripes"] = intensity_stripes

return stim


Expand Down

0 comments on commit d7b05c9

Please sign in to comment.