diff --git a/stimupy/stimuli/__init__.py b/stimupy/stimuli/__init__.py index 2278746..7a12161 100644 --- a/stimupy/stimuli/__init__.py +++ b/stimupy/stimuli/__init__.py @@ -1,6 +1,8 @@ +import numpy as np from stimupy.stimuli import * __all__ = [ + "mask_targets", "overview", "plot_overview", "benarys", @@ -27,6 +29,52 @@ ] +def mask_targets(element_mask, target_indices): + """Indicate elements as targets + + Creates a new target_mask from a mask of elements (e.g., grating bars, rings, frames, etc.), + by indexing these elements. + + Parameters + ---------- + element_mask : numpy.ndarray + mask with integer values for different elements / regions in a stimulus + target_indices : Sequence[int] or int + index or indices of elements to be designated as targets. + Index 0 should always refer to background region. + Indices can be negative, which results in "counting backwards" + from the highest index in element_mask. + + Returns + ------- + numpy.ndarray + target mask, with integer values indicating target regions, + in order that they appear in target_indices + + Raises + ------ + ValueError + if a target_idx is greater than any value in the element_mask + """ + if target_indices is None: + target_indices = () + if isinstance(target_indices, (int, float)): + target_indices = [ + target_indices, + ] + + target_mask = np.zeros_like(element_mask) + for target_idx, element_idx in enumerate(target_indices): + if element_idx < 0: + element_idx = int(element_mask.max()) + element_idx + + if element_idx > element_mask.max(): + raise ValueError("target_idx is outside stimulus") + target_mask = np.where(element_mask == element_idx, target_idx + 1, target_mask) + + return target_mask + + def overview(skip=False): """Generate example stimuli from this module @@ -37,7 +85,7 @@ def overview(skip=False): """ stimuli = {} for stimmodule_name in __all__: - if stimmodule_name in ["overview", "plot_overview"]: + if stimmodule_name in ["overview", "plot_overview", "mask_targets"]: continue print(f"Generating stimuli from {stimmodule_name}")