Skip to content

Commit

Permalink
refactor(stimuli): place_targets()
Browse files Browse the repository at this point in the history
actually places targets in stimulus-dict
  • Loading branch information
JorisVincent committed Aug 28, 2023
1 parent 9232891 commit fff6f14
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion stimupy/stimuli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import itertools

import numpy as np

from stimupy.components import draw_regions
from stimupy.stimuli import *

__all__ = [
"mask_targets",
"place_targets",
"overview",
"plot_overview",
"benarys",
Expand Down Expand Up @@ -75,6 +80,57 @@ def mask_targets(element_mask, target_indices):
return target_mask


def place_targets(stim, element_mask_key, target_indices, intensity_target=0.5):
"""Place targets in stimulus
Turns image regions/elements defined by element_mask_key
and indicated by target_indices, into targets.
Targets are defined in a new target_mask, and drawn into image with intensity_target.
Parameters
----------
stim : dict[str, Any]
stimulus dictionary, with at least an "img" key, and mask indicated by element_mask_key
element_mask_key : str
key of the mask in stim-dict indicating image "elements"/regions
target_indices : Sequence[int] or int
index or indices of elements to be designated as targets
intensity_target : float, optional
intensity value for target, by default 0.5
Returns
-------
dict[str, Any]
dict with the stimulus (key: "img") with targets placed,
mask with integer index for the target (key: "target_mask")
See also
--------
mask_targets, draw_regions
"""
stim["target_mask"] = mask_targets(
element_mask=stim[element_mask_key], target_indices=target_indices
)

if isinstance(intensity_target, (int, float)):
intensity_target = [
intensity_target,
]
intensity_target = itertools.cycle(intensity_target)

stim["img"] = np.where(
stim["target_mask"],
draw_regions(
mask=stim["target_mask"], intensities=intensity_target, intensity_background=0.0
),
stim["img"],
)
stim["target_indices"] = target_indices
stim["intensity_target"] = intensity_target

return stim


def overview(skip=False):
"""Generate example stimuli from this module
Expand All @@ -85,7 +141,7 @@ def overview(skip=False):
"""
stimuli = {}
for stimmodule_name in __all__:
if stimmodule_name in ["overview", "plot_overview", "mask_targets"]:
if stimmodule_name in ["overview", "plot_overview", "mask_targets", "place_targets"]:
continue

print(f"Generating stimuli from {stimmodule_name}")
Expand Down

0 comments on commit fff6f14

Please sign in to comment.