Skip to content

Commit

Permalink
Better mask in waves
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Mar 21, 2023
1 parent 281acc4 commit 8f958bf
Showing 1 changed file with 33 additions and 29 deletions.
62 changes: 33 additions & 29 deletions stimupy/waves.py
Expand Up @@ -118,24 +118,26 @@ def sine_linear(
stim["bar_width"] = stim.pop("phase_width")
stim.pop("base_type")

# Resolve target parameters
# Resolve targets
if target_indices is not None and target_indices != ():
# Create target-mask
if isinstance(target_indices, (int)):
target_indices = [
target_indices,
]
if isinstance(intensity_target, (int, float)):
intensity_target = [
intensity_target,
]
intensity_target = itertools.cycle(intensity_target)
target_indices = [target_indices]

# Place target(s)
targets_mask = np.zeros_like(stim["grating_mask"])
for target_idx, (bar_idx, intensity) in enumerate(zip(target_indices, intensity_target)):
targets_mask = np.where(stim["grating_mask"] == bar_idx, target_idx + 1, targets_mask)
stim["img"] = np.where(targets_mask == target_idx + 1, intensity, stim["img"])
stim["target_mask"] = targets_mask.astype(int)
for target_idx, bar_idx in enumerate(target_indices):
targets_mask = np.where(
stim["grating_mask"] == (bar_idx + 1), target_idx + 1, targets_mask
)
targets_mask = targets_mask.astype(int)
stim["targets_mask"] = targets_mask

# Place target(s)
if isinstance(intensity_target, (int, float)):
intensities = [intensity_target]
intensities = itertools.cycle(intensities)
for target_idx, intensity in zip(np.unique(targets_mask[targets_mask > 0]), intensities):
stim["img"] = np.where(targets_mask == target_idx, intensity, stim["img"])

return stim

Expand Down Expand Up @@ -236,24 +238,26 @@ def square_linear(
stim["intensity_bars"] = stim.pop("intensities")
stim.pop("base_type")

# Resolve target parameters
# Resolve targets
if target_indices is not None and target_indices != ():
# Create target-mask
if isinstance(target_indices, (int)):
target_indices = [
target_indices,
]
if isinstance(intensity_target, (int, float)):
intensity_target = [
intensity_target,
]
intensity_target = itertools.cycle(intensity_target)
target_indices = [target_indices]

# Place target(s)
targets_mask = np.zeros_like(stim["grating_mask"])
for target_idx, (bar_idx, intensity) in enumerate(zip(target_indices, intensity_target)):
targets_mask = np.where(stim["grating_mask"] == bar_idx, target_idx + 1, targets_mask)
stim["img"] = np.where(targets_mask == target_idx + 1, intensity, stim["img"])
stim["target_mask"] = targets_mask.astype(int)
for target_idx, bar_idx in enumerate(target_indices):
targets_mask = np.where(
stim["grating_mask"] == (bar_idx + 1), target_idx + 1, targets_mask
)
targets_mask = targets_mask.astype(int)
stim["targets_mask"] = targets_mask

# Place target(s)
if isinstance(intensity_target, (int, float)):
intensities = [intensity_target]
intensities = itertools.cycle(intensities)
for target_idx, intensity in zip(np.unique(targets_mask[targets_mask > 0]), intensities):
stim["img"] = np.where(targets_mask == target_idx, intensity, stim["img"])

return stim

Expand All @@ -270,7 +274,7 @@ def overview(**kwargs):
default_params.update(kwargs)

grating_params = {
"period": "odd",
"period": "ignore",
"phase_shift": 90,
"origin": "center",
"round_phase_width": False,
Expand Down

0 comments on commit 8f958bf

Please sign in to comment.