Skip to content

Commit

Permalink
Gabor args consistent with wave args, and added function to create pl…
Browse files Browse the repository at this point in the history
…aids from any kind of waves or gabors

Closes #31, Closes #21
  • Loading branch information
LynnSchmittwilken authored and JorisVincent committed Mar 23, 2023
1 parent b1dfd47 commit 71f5348
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 62 deletions.
22 changes: 10 additions & 12 deletions stimupy/papers/modelfest.py
Expand Up @@ -43,6 +43,7 @@
from stimupy.components.waves import bessel
from stimupy.noises.binaries import binary as binary_noise
from stimupy.stimuli.gabors import gabor
from stimupy.stimuli.plaids import gabors as plaid
from stimupy.utils import pad_dict_to_shape, resize_dict, roll_dict, stack_dicts

__all__ = [
Expand Down Expand Up @@ -2045,14 +2046,13 @@ def Plaids38(ppd=PPD):
"sigma": 0.14,
"phase_shift": 90,
"origin": "center",
"round_phase_width": False,
}

stim = gabor(**params, rotation=0)
stim2 = gabor(**params, rotation=90)

stim["img"] = stim["img"] / 2 + stim2["img"] / 2
stim["rotation"] = stim2["rotation"]
stim["grating_mask2"] = stim2["grating_mask"]
stim = plaid(
gabor_parameters1={**params, "rotation": 0},
gabor_parameters2={**params, "rotation": 90},
)

v = 149
experimental_data = {
Expand Down Expand Up @@ -2101,12 +2101,10 @@ def Plaids39(ppd=PPD):
"origin": "center",
}

stim = gabor(**params, rotation=45)
stim2 = gabor(**params, rotation=90)

stim["img"] = stim["img"] / 2 + stim2["img"] / 2
stim["rotation"] = stim2["rotation"]
stim["grating_mask2"] = stim2["grating_mask"]
stim = plaid(
gabor_parameters1={**params, "rotation": 45},
gabor_parameters2={**params, "rotation": 90},
)

v = 153
experimental_data = {
Expand Down
16 changes: 9 additions & 7 deletions stimupy/stimuli/gabors.py
Expand Up @@ -9,12 +9,12 @@ def gabor(
ppd=None,
shape=None,
frequency=None,
n_phases=None,
phase_width=None,
n_bars=None,
bar_width=None,
period="ignore",
rotation=0.0,
phase_shift=None,
intensities=(0.0, 1.0),
intensity_bars=(0.0, 1.0),
origin=None,
round_phase_width=False,
sigma=None,
Expand All @@ -31,6 +31,8 @@ def gabor(
shape [height, width] of image, in pixels
frequency : Number, or None (default)
spatial frequency of grating, in cycles per degree visual angle
n_bars : Number, or None (default)
number of bars in the grating
bar_width : Number, or None (default)
width of a single bar, in degrees visual angle
sigma : float or (float, float)
Expand Down Expand Up @@ -65,12 +67,12 @@ def gabor(
ppd=ppd,
shape=shape,
frequency=frequency,
n_phases=n_phases,
phase_width=phase_width,
n_phases=n_bars,
phase_width=bar_width,
period=period,
rotation=rotation,
phase_shift=phase_shift,
intensities=intensities,
intensities=intensity_bars,
origin=origin,
distance_metric="oblique",
round_phase_width=round_phase_width,
Expand All @@ -82,7 +84,7 @@ def gabor(
sigma=sigma,
origin=origin,
)
mean_int = (intensities[0] + intensities[1]) / 2
mean_int = (intensity_bars[0] + intensity_bars[1]) / 2
stim["img"] = (stim["img"] - mean_int) * gaussian_window["img"] + mean_int

return {
Expand Down
172 changes: 129 additions & 43 deletions stimupy/stimuli/plaids.py
@@ -1,17 +1,98 @@
from stimupy.components import waves
from stimupy.components.gaussians import gaussian
from stimupy.stimuli import gabors as gabors_stim
from stimupy.stimuli import waves

__all__ = [
"plaid",
"sine_waves",
"square_waves",
]


def plaid(
def add_waves(wave_dict1, wave_dict2, weight1=1, weight2=1):
"""
Create plaid-like stimulus by adding two waves
Parameters
----------
wave_dict1 : dict
dictionary which contains the first wave-array (key: "img"), as well as
keys "shape" and "ppd"
wave_dict2 : dict
dictionary which contains the second wave-array (key: "img"), as well as
keys "shape" and "ppd"
weight1 : float, optional
Factor with which the first wave is multiplied. The default is 1.
weight2 : float, optional
Factor with which the second wave is multiplied. The default is 1.
Returns
-------
wave_dict1 : dict
dictionary with plaid-like stimulus and additional keys if specified.
"""
if wave_dict1["shape"] != wave_dict2["shape"]:
raise ValueError(
f"Waves have different shapes; 1: {wave_dict1['shape']}, 2: {wave_dict2['shape']}"
)
if wave_dict1["ppd"] != wave_dict2["ppd"]:
raise ValueError(
"Waves have different ppds; 1: {wave_dict1['ppd']}, 2: {wave_dict2['ppd']}"
)

img = weight1 * wave_dict1["img"] + weight2 * wave_dict2["img"]
img = img / (weight1 + weight2)

# Update parameters
wave_dict1["img"] = img
try:
wave_dict1["grating_mask2"] = wave_dict2["grating_mask"]
wave_dict1["frequency2"] = wave_dict2["frequency"]
wave_dict1["phase_width2"] = wave_dict2["phase_width"]
wave_dict1["n_phases2"] = wave_dict2["n_phases"]
except Exception:
pass
return wave_dict1


def gabors(
gabor_parameters1,
gabor_parameters2,
weight1=1,
weight2=1,
):
"""Draw plaid consisting of two gabors
Parameters
----------
gabor_parameters1 : dict
kwargs to generate first Gabor
gabor_parameters2 : dict
kwargs to generate second Gabor
weight1 : float
weight of first Gabor (default: 1)
weight2 : float
weight of second Gabor (default: 1)
Returns
-------
dict[str, Any]
dict with the stimulus (key: "img"),
mask with integer index for each phase (key: "grating_mask"),
and additional keys containing stimulus parameters
"""

# Create sine-wave gratings
grating1 = gabors_stim.gabor(**gabor_parameters1)
grating2 = gabors_stim.gabor(**gabor_parameters2)
plaid = add_waves(grating1, grating2, weight1, weight2)
return plaid


def sine_waves(
grating_parameters1,
grating_parameters2,
weight1=1,
weight2=1,
sigma=None,
):
"""Draw plaid consisting of two sine-wave gratings
Expand All @@ -25,8 +106,6 @@ def plaid(
weight of first grating (default: 1)
weight2 : float
weight of second grating (default: 1)
sigma : float or (float, float)
sigma of Gaussian window in degree visual angle (y, x)
Returns
-------
Expand All @@ -35,40 +114,46 @@ def plaid(
mask with integer index for each phase (key: "grating_mask"),
and additional keys containing stimulus parameters
"""
if sigma is None:
raise ValueError("plaid() missing argument 'sigma' which is not 'None'")

# Create sine-wave gratings
grating1 = waves.sine(**grating_parameters1)
grating2 = waves.sine(**grating_parameters2)

if grating1["shape"] != grating2["shape"]:
raise ValueError("Gratings must have the same shape")
if grating1["ppd"] != grating2["ppd"]:
raise ValueError("Gratings must have same ppd")
if grating1["origin"] != grating2["origin"]:
raise ValueError("Grating origins must be the same")

# Create Gaussian window
window = gaussian(
visual_size=grating1["visual_size"],
ppd=grating1["ppd"],
sigma=sigma,
origin=grating1["origin"],
)

img = (weight1 * grating1["img"] + weight2 * grating2["img"]) * window["img"]
img = img / (weight1 + weight2)
grating1 = waves.sine_linear(**grating_parameters1)
grating2 = waves.sine_linear(**grating_parameters2)
plaid = add_waves(grating1, grating2, weight1, weight2)
return plaid

# Update parameters
grating1["img"] = img
grating1["sigma"] = sigma
grating1["grating_mask2"] = grating2["grating_mask"]
grating1["frequency2"] = grating2["frequency"]
grating1["phase_width2"] = grating2["phase_width"]
grating1["n_phases2"] = grating2["n_phases"]
grating1["gaussian_mask"] = window["gaussian_mask"]
return grating1

def square_waves(
grating_parameters1,
grating_parameters2,
weight1=1,
weight2=1,
):
"""Draw plaid consisting of two square-wave gratings
Parameters
----------
grating_parameters1 : dict
kwargs to generate first sine-wave grating
grating_parameters2 : dict
kwargs to generate second sine-wave grating
weight1 : float
weight of first grating (default: 1)
weight2 : float
weight of second grating (default: 1)
Returns
-------
dict[str, Any]
dict with the stimulus (key: "img"),
mask with integer index for each phase (key: "grating_mask"),
and additional keys containing stimulus parameters
"""

# Create sine-wave gratings
grating1 = waves.square_linear(**grating_parameters1)
grating2 = waves.square_linear(**grating_parameters2)
plaid = add_waves(grating1, grating2, weight1, weight2)
return plaid


def overview(**kwargs):
Expand All @@ -84,29 +169,30 @@ def overview(**kwargs):
"ppd": 10,
"origin": "center",
"phase_shift": 30,
"distance_metric": "oblique",
}
default_params.update(kwargs)

grating1 = {
**default_params,
"phase_width": 3.5,
"period": "odd",
"bar_width": 1,
"period": "ignore",
"rotation": 0,
"round_phase_width": False,
}

grating2 = {
**default_params,
"phase_width": 3.5,
"bar_width": 0.5,
"period": "ignore",
"rotation": 90,
"round_phase_width": False,
}

# fmt: off
stimuli = {
"plaid": plaid(grating1, grating2, sigma=4.0),
"plaid_gabors": gabors({**grating1, "sigma": 3}, {**grating2, "sigma": 3}),
"plaid_sine_waves": sine_waves(grating1, grating2),
"plaid_square_waves": square_waves(grating1, grating2),
}
# fmt: on

Expand Down

0 comments on commit 71f5348

Please sign in to comment.