Skip to content

Commit

Permalink
refactor(waves): integrate place_targets()
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Aug 28, 2023
1 parent 9e8d136 commit daf6acd
Showing 1 changed file with 92 additions and 143 deletions.
235 changes: 92 additions & 143 deletions stimupy/stimuli/waves.py
@@ -1,9 +1,8 @@
import itertools

import numpy as np

from stimupy.components import draw_regions, waves
from stimupy.components.shapes import disc, rectangle
from stimupy.stimuli import place_targets

__all__ = [
"sine_linear",
Expand All @@ -21,29 +20,6 @@
]


def add_targets(wave_stim, target_indices, intensity_target):
# Create target-mask
if isinstance(target_indices, (int)):
target_indices = [target_indices]

targets_mask = np.zeros_like(wave_stim["grating_mask"])
for target_idx, bar_idx in enumerate(target_indices):
targets_mask = np.where(
wave_stim["grating_mask"] == (bar_idx + 1), target_idx + 1, targets_mask
)
targets_mask = targets_mask.astype(int)
wave_stim["target_mask"] = targets_mask

# Place target(s)
if isinstance(intensity_target, (int, float)):
intensities = [intensity_target]
intensities = itertools.cycle(intensities)
for target_idx, intensity in zip(np.unique(targets_mask[targets_mask > 0]), intensities):
wave_stim["img"] = np.where(targets_mask == target_idx, intensity, wave_stim["img"])

return wave_stim


def sine_linear(
visual_size=None,
ppd=None,
Expand Down Expand Up @@ -145,12 +121,13 @@ def sine_linear(
stim["bar_width"] = stim.pop("phase_width")
stim.pop("distance_metric")

# Add targets(?)
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -250,12 +227,13 @@ def square_linear(
stim["intensity_bars"] = stim.pop("intensities")
stim.pop("distance_metric")

# Add targets(?)
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -358,12 +336,13 @@ def staircase_linear(
stim["intensity_bars"] = stim.pop("intensities")
stim.pop("distance_metric")

# Add targets(?)
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -478,15 +457,16 @@ def sine_radial(
)
stim["img"] = np.where(circle["ring_mask"], stim["img"], intensity_background)
stim["grating_mask"] = np.where(circle["ring_mask"], stim["grating_mask"], 0)

# Resolve target parameters
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
stim["clip"] = clip
stim["intensity_background"] = intensity_background

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -601,15 +581,16 @@ def square_radial(
)
stim["img"] = np.where(circle["ring_mask"], stim["img"], intensity_background)
stim["grating_mask"] = np.where(circle["ring_mask"], stim["grating_mask"], 0)

# Resolve target parameters
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
stim["clip"] = clip
stim["intensity_background"] = intensity_background

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -725,15 +706,16 @@ def staircase_radial(
)
stim["img"] = np.where(circle["ring_mask"], stim["img"], intensity_background)
stim["grating_mask"] = np.where(circle["ring_mask"], stim["grating_mask"], 0)

# Resolve target parameters
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
stim["clip"] = clip
stim["intensity_background"] = intensity_background

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -856,15 +838,16 @@ def sine_rectilinear(
)
stim["img"] = np.where(rect["rectangle_mask"], stim["img"], intensity_background)
stim["grating_mask"] = np.where(rect["rectangle_mask"], stim["grating_mask"], 0)

# Add targets(?)
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
stim["clip"] = clip
stim["intensity_background"] = intensity_background

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -988,15 +971,16 @@ def square_rectilinear(
)
stim["img"] = np.where(rect["rectangle_mask"], stim["img"], intensity_background)
stim["grating_mask"] = np.where(rect["rectangle_mask"], stim["grating_mask"], 0)

# Add targets(?)
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
stim["clip"] = clip
stim["intensity_background"] = intensity_background

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand All @@ -1020,8 +1004,6 @@ def staircase_rectilinear(
):
"""Rectiinear staircase, with some frame(s) as target(s)
Parameters
----------
Parameters
----------
visual_size : Sequence[Number, Number], Number, or None (default)
Expand Down Expand Up @@ -1122,15 +1104,16 @@ def staircase_rectilinear(
)
stim["img"] = np.where(rect["rectangle_mask"], stim["img"], intensity_background)
stim["grating_mask"] = np.where(rect["rectangle_mask"], stim["grating_mask"], 0)

# Add targets(?)
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
stim["clip"] = clip
stim["intensity_background"] = intensity_background

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -1227,32 +1210,14 @@ def sine_angular(
stim["n_segments"] = stim.pop("n_phases")
stim["segment_width"] = stim.pop("phase_width")
stim.pop("distance_metric")
stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target

# Resolve target parameters
if target_indices is not None and target_indices != ():
if isinstance(target_indices, (int)):
target_indices = [
target_indices,
]
if isinstance(intensity_target, (int, float)):
intensity_target = [
intensity_target,
]
intensity_target = itertools.cycle(intensity_target)

# Place target(s)
targets_mask = np.zeros_like(stim["grating_mask"])
for target_idx, (segment_idx, intensity) in enumerate(
zip(target_indices, intensity_target)
):
targets_mask = np.where(
stim["grating_mask"] == (segment_idx + 1), target_idx + 1, targets_mask
)
stim["img"] = np.where(targets_mask == (target_idx + 1), intensity, stim["img"])
stim["target_mask"] = targets_mask.astype(int)

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -1351,32 +1316,14 @@ def square_angular(
stim["segment_width"] = stim.pop("phase_width")
stim["intensity_segments"] = stim.pop("intensities")
stim.pop("distance_metric")
stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target

# Resolve target parameters
if target_indices is not None and target_indices != ():
if isinstance(target_indices, (int)):
target_indices = [
target_indices,
]
if isinstance(intensity_target, (int, float)):
intensity_target = [
intensity_target,
]
intensity_target = itertools.cycle(intensity_target)

# Place target(s)
targets_mask = np.zeros_like(stim["grating_mask"])
for target_idx, (segment_idx, intensity) in enumerate(
zip(target_indices, intensity_target)
):
targets_mask = np.where(
stim["grating_mask"] == (segment_idx + 1), target_idx + 1, targets_mask
)
stim["img"] = np.where(targets_mask == (target_idx + 1), intensity, stim["img"])
stim["target_mask"] = targets_mask.astype(int)

# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down Expand Up @@ -1466,6 +1413,7 @@ def staircase_angular(
origin=origin,
round_phase_width=round_phase_width,
distance_metric="angular",
intensities=intensity_segments,
)

# Repackage output
Expand All @@ -1474,12 +1422,13 @@ def staircase_angular(
stim["intensity_segments"] = stim.pop("intensities")
stim.pop("distance_metric")

# Add targets(?)
if target_indices is not None and target_indices != ():
stim = add_targets(stim, target_indices=target_indices, intensity_target=intensity_target)

stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target
# Add targets
stim = place_targets(
stim=stim,
element_mask_key="grating_mask",
target_indices=target_indices,
intensity_target=intensity_target,
)
return stim


Expand Down

0 comments on commit daf6acd

Please sign in to comment.