diff --git a/stimuli/illusions/grating.py b/stimuli/illusions/grating.py index 9ef05458..8da6099a 100644 --- a/stimuli/illusions/grating.py +++ b/stimuli/illusions/grating.py @@ -93,7 +93,7 @@ def square_wave( # Update and return stimulus stim["bars_mask"] = stim["mask"] - stim["mask"] = targets_mask + stim["mask"] = targets_mask.astype(int) return stim @@ -174,7 +174,7 @@ def grating_uniform( ) stim["mask"] = pad_to_visual_size( img=stim["mask"], visual_size=image_size, ppd=ppd, pad_value=0 - ) + ).astype(int) # Repack stim.update( @@ -218,9 +218,9 @@ def grating_grating( # Superimpose small_grating_mask = rectangle( - rectangle_size=small_grating["visual_size"], + rectangle_size=np.array(small_grating["shape"]) / ppd, ppd=ppd, - visual_size=large_grating["visual_size"], + visual_size=np.array(large_grating["shape"]) / ppd, intensity_background=0, intensity_rectangle=1, rectangle_position=( @@ -228,8 +228,10 @@ def grating_grating( ) / 2, ) + small_grating["img"] = pad_to_shape(small_grating["img"], shape=large_grating["img"].shape) small_grating["mask"] = pad_to_shape(small_grating["mask"], shape=large_grating["img"].shape) + img = np.where(small_grating_mask["mask"], small_grating["img"], large_grating["img"]) mask = np.where(small_grating_mask["mask"], small_grating["mask"], large_grating["img"])