Skip to content

Commit

Permalink
Clean up RHS2007 plotting code a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisVincent committed Jun 7, 2022
1 parent ab5d943 commit ea466cb
Showing 1 changed file with 27 additions and 49 deletions.
76 changes: 27 additions & 49 deletions stimuli/papers/RHS2007.py
Expand Up @@ -161,53 +161,31 @@ def bullseye_thick():

if __name__ == "__main__":
import matplotlib.pyplot as plt
plot_all = True
if plot_all:
stims = {
"WE_thick": WE_thick,
"WE_thin_wide": WE_thin_wide,
"WE_dual": WE_dual,
"WE_anderson": WE_anderson,
"WE_howe": WE_howe,
"WE_radial_thick_small": WE_radial_thick_small,
"WE_radial_thick": WE_radial_thick,
"WE_radial_thin_small": WE_radial_thin_small,
"WE_radial_thin": WE_radial_thin,
"WE_circular1": WE_circular1,
"WE_circular05": WE_circular05,
"WE_circular025": WE_circular025,
"grating_induction": grating_induction,
"sbc_large": sbc_large,
"sbc_small": sbc_small,
"todorovic_equal": todorovic_equal,
"todorovic_in_large": todorovic_in_large,
"todorovic_in_small": todorovic_in_small,
"checkerboard_0.16": checkerboard_016,
"checkerboard_0.938": checkerboard_0938,
"checherboard_2.09": checkerboard209
}

a = math.ceil(math.sqrt(len(stims)))
plt.figure(figsize=(a*3, a*3))
for i, (stim_name, stim) in enumerate(stims.items()):
print("Generating", stim_name+"")
st = stim()
img, mask = st["img"], st["mask"]
img = np.dstack([img, img, img])

mask = np.insert(np.expand_dims(mask, 2), 1, 0, axis=2)
mask = np.insert(mask, 2, 0, axis=2)
final = mask + img
final /= np.max(final)

plt.subplot(a, a, i + 1)
plt.title(stim_name + " - img")
plt.imshow(final)

plt.tight_layout()

else:
plt.imshow(img, cmap='gray')

plt.savefig("overview_RHS2007.png")

stims = {}
for stimname in __all__:
print("Generating " + stimname)
try:
stims[stimname] = globals()[stimname]()
except NotImplementedError:
print("-- not implemented")

# Plot each stimulus+mask
n_stim = math.ceil(math.sqrt(len(stims)))
plt.figure(figsize=(n_stim * 3, n_stim * 3))
for i, (stim_name, stim) in enumerate(stims.items()):
img, mask = stim["img"], stim["mask"]
img = np.dstack([img, img, img])

mask = np.insert(np.expand_dims(mask, 2), 1, 0, axis=2)
mask = np.insert(mask, 2, 0, axis=2)
final = mask + img
final /= np.max(final)

plt.subplot(n_stim, n_stim, i + 1)
plt.title(stim_name)
plt.imshow(final)

plt.tight_layout()

plt.show()

0 comments on commit ea466cb

Please sign in to comment.