Skip to content

Commit

Permalink
refactor(stimuli): mask_targets()
Browse files Browse the repository at this point in the history
to standardize logic of indicating elements of gratings etc. as targets, which is a big par of the `stimuli`
Also allows for negative element indices, counting backwards from the highest element index

Closes #12
  • Loading branch information
JorisVincent committed Aug 28, 2023
1 parent 37b70c3 commit 9232891
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion stimupy/stimuli/__init__.py
@@ -1,6 +1,8 @@
import numpy as np
from stimupy.stimuli import *

__all__ = [
"mask_targets",
"overview",
"plot_overview",
"benarys",
Expand All @@ -27,6 +29,52 @@
]


def mask_targets(element_mask, target_indices):
"""Indicate elements as targets
Creates a new target_mask from a mask of elements (e.g., grating bars, rings, frames, etc.),
by indexing these elements.
Parameters
----------
element_mask : numpy.ndarray
mask with integer values for different elements / regions in a stimulus
target_indices : Sequence[int] or int
index or indices of elements to be designated as targets.
Index 0 should always refer to background region.
Indices can be negative, which results in "counting backwards"
from the highest index in element_mask.
Returns
-------
numpy.ndarray
target mask, with integer values indicating target regions,
in order that they appear in target_indices
Raises
------
ValueError
if a target_idx is greater than any value in the element_mask
"""
if target_indices is None:
target_indices = ()
if isinstance(target_indices, (int, float)):
target_indices = [
target_indices,
]

target_mask = np.zeros_like(element_mask)
for target_idx, element_idx in enumerate(target_indices):
if element_idx < 0:
element_idx = int(element_mask.max()) + element_idx

if element_idx > element_mask.max():
raise ValueError("target_idx is outside stimulus")
target_mask = np.where(element_mask == element_idx, target_idx + 1, target_mask)

return target_mask


def overview(skip=False):
"""Generate example stimuli from this module
Expand All @@ -37,7 +85,7 @@ def overview(skip=False):
"""
stimuli = {}
for stimmodule_name in __all__:
if stimmodule_name in ["overview", "plot_overview"]:
if stimmodule_name in ["overview", "plot_overview", "mask_targets"]:
continue

print(f"Generating stimuli from {stimmodule_name}")
Expand Down

0 comments on commit 9232891

Please sign in to comment.