diff --git a/stimuli/utils/plotting.py b/stimuli/utils/plotting.py index 549bbad4..a145bb9f 100644 --- a/stimuli/utils/plotting.py +++ b/stimuli/utils/plotting.py @@ -52,13 +52,15 @@ def plot_stim( img = np.dstack([img, img, img]) mask = np.dstack([mask, mask, mask]) - if np.unique(mask).size > 10: + if np.unique(mask).size >= 20: + colormap = plt.cm.colors.ListedColormap(np.random.rand(mask.max()+1,3)) + elif np.unique(mask).size > 10 and np.unique(mask).size < 20: colormap = plt.cm.tab20 else: colormap = plt.cm.tab10 for idx in np.unique(mask)[np.unique(mask) > 0]: - color = colormap.colors[idx % 19] + color = colormap.colors[idx] color = np.reshape(color, (1, 1, 3)) img = np.where(mask == idx, color, img) ax.imshow(img)