Skip to content

Commit

Permalink
Unify structure for overviews
Browse files Browse the repository at this point in the history
For each module, `overview()` generates a dict of stim_dicts, with some example stimuli.
For "parent"-modules, they run the `overview()` for each of their submodules, and concatenate.
`plot_overview()` simply passes this to `plot_stimuli`, as a nice shorthand.
  • Loading branch information
JorisVincent committed Mar 22, 2023
1 parent 937c832 commit f6b115a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 37 deletions.
50 changes: 36 additions & 14 deletions stimupy/components/__init__.py
Expand Up @@ -5,6 +5,8 @@
from stimupy.utils import resolution

__all__ = [
"overview",
"plot_overview",
"image_base",
"draw_regions",
"mask_elements",
Expand Down Expand Up @@ -208,17 +210,16 @@ def draw_regions(mask, intensities, intensity_background=0.5):
return img


from . import angulars, edges, frames, gaussians, lines, radials, shapes, waves
from stimupy.components import angulars, edges, frames, gaussians, lines, radials, shapes, waves


def create_overview():
"""
Create dictionary with examples from all stimulus-components
def overview(skip=False):
"""Generate example stimuli from this module
Returns
-------
stims : dict
dict with all stimuli containing individual stimulus dicts.
dict[str, dict]
Dict mapping names to individual stimulus dicts
"""

p = {
Expand All @@ -227,7 +228,7 @@ def create_overview():
}

# fmt: off
stims = {
stimuli = {
# angulars
"wedge": angulars.wedge(**p, width=30, radius=3),
"angular_grating": angulars.grating(**p, n_segments=8),
Expand Down Expand Up @@ -262,12 +263,31 @@ def create_overview():
}
# fmt: on

return stims
# stimuli = {}
# for stimmodule_name in __all__:
# if stimmodule_name in ["overview", "plot_overview"]:
# pass

# print(f"Generating stimuli from {stimmodule_name}")
# # Get a reference to the actual module
# stimmodule = globals()[stimmodule_name]
# try:
# stims = stimmodule.overview()

def overview(mask=False, save=None, extent_key="shape"):
"""
Plot overview with examples from all stimulus-components
# # Accumulate
# stimuli.update(stims)
# except NotImplementedError as e:
# if not skip:
# raise e
# # Skip stimuli that aren't implemented
# print("-- not implemented")
# pass

return stimuli


def plot_overview(mask=False, save=None, extent_key="shape"):
"""Plot overview of examples in this module (and submodules)
Parameters
----------
Expand All @@ -284,7 +304,9 @@ def overview(mask=False, save=None, extent_key="shape"):
"""
from stimupy.utils import plot_stimuli

stims = create_overview()
stims = overview(skip=True)
plot_stimuli(stims, mask=mask, extent_key=extent_key, save=save)


# Plotting
plot_stimuli(stims, mask=mask, save=save, extent_key=extent_key)
if __name__ == "__main__":
plot_overview()
57 changes: 38 additions & 19 deletions stimupy/noises/__init__.py
@@ -1,28 +1,26 @@
from .binaries import *
from .narrowbands import *
from .naturals import *
from .utils import *
from .whites import *
from stimupy.noises.binaries import *
from stimupy.noises.narrowbands import *
from stimupy.noises.naturals import *
from stimupy.noises.utils import *
from stimupy.noises.whites import *


def create_overview():
"""
Create dictionary with examples from all stimulus-noises
def overview(skip=False):
"""Generate example stimuli from this module
Returns
-------
stims : dict
dict with all stimuli containing individual stimulus dicts.
dict[str, dict]
Dict mapping names to individual stimulus dicts
"""

params = {
"visual_size": 10,
"ppd": 10,
"pseudo_noise": True,
}

# fmt: off
stims = {
stimuli = {
# Binary
"binary_noise": binary(visual_size=10, ppd=10),
# White
Expand All @@ -37,12 +35,31 @@ def create_overview():
}
# fmt: on

return stims
# stimuli = {}
# for stimmodule_name in __all__:
# if stimmodule_name in ["overview", "plot_overview"]:
# pass

# print(f"Generating stimuli from {stimmodule_name}")
# # Get a reference to the actual module
# stimmodule = globals()[stimmodule_name]
# try:
# stims = stimmodule.overview()

def overview(mask=False, save=None, extent_key="shape"):
"""
Plot overview with examples from all stimulus-noises
# # Accumulate
# stimuli.update(stims)
# except NotImplementedError as e:
# if not skip:
# raise e
# # Skip stimuli that aren't implemented
# print("-- not implemented")
# pass

return stimuli


def plot_overview(mask=False, save=None, extent_key="shape"):
"""Plot overview of examples in this module (and submodules)
Parameters
----------
Expand All @@ -59,7 +76,9 @@ def overview(mask=False, save=None, extent_key="shape"):
"""
from stimupy.utils import plot_stimuli

stims = create_overview()
stims = overview(skip=True)
plot_stimuli(stims, mask=mask, extent_key=extent_key, save=save)


# Plotting
plot_stimuli(stims, mask=mask, save=save, extent_key=extent_key)
if __name__ == "__main__":
plot_overview()
30 changes: 26 additions & 4 deletions stimupy/stimuli/__init__.py
Expand Up @@ -2,6 +2,7 @@

__all__ = [
"overview",
"plot_overview",
"benarys",
"bullseyes",
"checkerboards",
Expand Down Expand Up @@ -33,9 +34,11 @@ def overview(skip=False):
dict[str, dict]
Dict mapping names to individual stimulus dicts
"""

stimuli = {}
for stimmodule_name in __all__:
if stimmodule_name in ["overview", "plot_overview"]:
continue

print(f"Generating stimuli from {stimmodule_name}")
# Get a reference to the actual module
stimmodule = globals()[stimmodule_name]
Expand All @@ -54,8 +57,27 @@ def overview(skip=False):
return stimuli


if __name__ == "__main__":
def plot_overview(mask=False, save=None, extent_key="shape"):
"""Plot overview of examples in this module (and submodules)
Parameters
----------
mask : bool or str, optional
If True, plot mask on top of stimulus image (default: False).
If string is provided, plot this key from stimulus dictionary as mask
save : None or str, optional
If None (default), do not save the plot.
If string is provided, save plot under this name.
extent_key : str, optional
Key to extent which will be used for plotting.
Default is "shape", using the image size in pixels as extent.
"""
from stimupy.utils import plot_stimuli

stims = overview()
plot_stimuli(stims, mask=False, save=None)
stims = overview(skip=True)
plot_stimuli(stims, mask=mask, extent_key=extent_key, save=save)


if __name__ == "__main__":
plot_overview()

0 comments on commit f6b115a

Please sign in to comment.