In [None]:
# %load_ext autoreload
# %autoreload 2

# Readme

Plot the results of the encoder forward pass.

In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from alphacnn.model.cnn_model import prepare_gpu
prepare_gpu()

In [None]:
!nvidia-smi

In [None]:
from alphacnn.model.cnn_model import evaluate_layers, evaluate_rgc_layers
from alphacnn.visualize import plot_stimulus
from alphacnn.visualize.plot_model import plot_simulation

In [None]:
from alphacnn.database.encoder_schema import *
from alphacnn.utils.data_utils import load_config
from alphacnn import paths

connect_to_database(
    dj_config_file=paths.CONFIG_FILE,
    create_tables=True, create_schema=True, schema_name=paths.SCHEMA_PREFIX + 'encoder')
encoder_schema

In [None]:
from alphaanalysis import plot as plota

plota.set_default_params(kind='paper')

# Parameters for text

In [None]:
StimulusConfig().fetch1('stimulus_dict')

In [None]:
(Stimulus() & Stimulus().fetch('KEY')[-1]).fetch1('video').shape

In [None]:
(BCSpatialRFOutput() & BCSpatialRFOutput().fetch('KEY')[0]).fetch1('bc_srf_output').shape

In [None]:
(RGCSynapticInputs() & (RGCSynapticInputs & "rgc_id='nsl'").fetch('KEY')[0]).fetch1('rgc_synaptic_inputs').shape

In [None]:
(RGCSynapticInputs() & (RGCSynapticInputs & "rgc_id='tmp'").fetch('KEY')[0]).fetch1('rgc_synaptic_inputs').shape

In [None]:
(BCsRfConfig & dict(bc_srf_config_id='ss')).fetch1('bc_cdist')

In [None]:
bc_srf_ss = (BCsRfConfig & dict(bc_srf_config_id='ss')).fetch1('bc_srf')
bc_srf_ss.shape

In [None]:
bc_srf_ws = (BCsRfConfig & dict(bc_srf_config_id='ws')).fetch1('bc_srf')
bc_srf_ws.shape

In [None]:
RGCSynapticWeights()

In [None]:
150 / 15

In [None]:
(RGCSynapticWeights & dict(rgc_id='nsl')).fetch1('rgc_synaptic_weights_1').shape

In [None]:
75	/ 15

In [None]:
(RGCSynapticWeights & dict(rgc_id='tmp')).fetch1('rgc_synaptic_weights_1').shape

# Plot BCs

## Plot RFs

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from alphaanalysis import plot as plota
import seaborn as sns
#cmap = sns.diverging_palette(h_neg=250, h_pos=10, s=100, l=50, sep=1, n=10, center="light", as_cmap=True)



def plot_srf_1d(ax, srf, orientation='horizontal'):
    """Plot the central horizontal cross-section of a 2D spatial RF as a filled profile.

    The middle row of ``srf`` is linearly interpolated onto 1001 points. The region
    between the first and last positive sample is filled with the upper end of the
    'bwr' colormap, and the flanks on either side with the lower end.

    Parameters
    ----------
    ax : matplotlib Axes to draw into (its axis is switched off and the x axis inverted).
    srf : 2D array; row ``srf.shape[0] // 2`` is the profile that gets plotted.
    orientation : only 'horizontal' is implemented; 'vertical' raises NotImplementedError.
    """
    if orientation == 'vertical':
        raise NotImplementedError()

    cmap = sns.color_palette('bwr', as_cmap=True)
    neg_color, pos_color = cmap(0.), cmap(1.)

    srf = np.asarray(srf)
    center_row = srf[srf.shape[0] // 2]
    # The profile runs along axis 1, so centre the x positions on the middle column.
    x = np.arange(srf.shape[1]) - srf.shape[1] // 2

    xi = np.linspace(x[0], x[-1], 1001, endpoint=True)
    yi = np.interp(xi, xp=x, fp=center_row)

    ypos = np.flatnonzero(yi > 0)
    if ypos.size == 0:
        # No positive lobe: the whole profile is drawn in the "negative" color.
        ax.fill_between(xi, yi, color=neg_color, lw=0)
    else:
        i1, i2 = ypos[0], ypos[-1]
        # The segments overlap by one sample so that the fills join without gaps.
        ax.fill_between(xi[:i1 + 1], yi[:i1 + 1], color=neg_color, lw=0)
        ax.fill_between(xi[i2:], yi[i2:], color=neg_color, lw=0)
        ax.fill_between(xi[i1:i2 + 1], yi[i1:i2 + 1], color=pos_color, lw=0)
    ax.axis('off')
    ax.invert_xaxis()

    
def plot_srf_2d(ax, srf, cax=None, pixel_size_um=None):
    """Show a 2D spatial RF normalized to its maximum on a symmetric 'bwr' color scale.

    Parameters
    ----------
    ax : matplotlib Axes to draw the RF into (ticks are removed).
    srf : 2D array; it is divided by ``srf.max()`` before plotting.
    cax : optional Axes that receives a colorbar with ticks at -1, 0, 1.
    pixel_size_um : if given, a 100 µm scale bar is drawn, sized in pixel units.
    """
    cmap = sns.color_palette('bwr', as_cmap=True)

    srf = srf / srf.max()
    vabsmax = np.max(np.abs(srf))

    n_rows, n_cols = srf.shape
    # imshow extent is (left, right, bottom, top): x spans columns and y spans rows.
    # The floor division is parenthesized on purpose. Python parses -n//2 as (-n)//2,
    # which adds an extra unit for odd n and breaks the one-unit-per-pixel mapping
    # that the scale bar relies on.
    extent = (-(n_cols // 2), n_cols - n_cols // 2, -(n_rows // 2), n_rows - n_rows // 2)
    im = ax.imshow(srf, vmin=-vabsmax, vmax=vabsmax, extent=extent, cmap=cmap)
    ax.set(xticks=[], yticks=[])
    if cax is not None:
        plt.colorbar(im, cax=cax, shrink=2, ticks=(-1, 0, 1))

    if pixel_size_um is not None:
        size = 100
        plota.plot_scale_bar(ax=ax, x0=np.mean(extent[:2]), y0=extent[2]+10, size=size/pixel_size_um, text=f'{size:d} µm', pad=-3)


# One small figure per BC type: 1D cross-section on top, 2D sRF plus colorbar below.
os.makedirs('figures', exist_ok=True)  # savefig does not create missing directories
for name, srf in dict(ss=bc_srf_ss, ws=bc_srf_ws).items():
    fig, axs = plt.subplots(2, 2, height_ratios=(1, 3), width_ratios=(10, 1), sharex='col', figsize=(0.7, 0.7))

    # Top row: the cross-section; the top-right cell is left empty.
    plot_srf_1d(axs[0, 0], srf, orientation='horizontal')
    axs[0, 1].axis('off')

    # Bottom row: the 2D RF with its colorbar in the narrow right column.
    plot_srf_2d(axs[1, 0], srf, cax=axs[1, 1])  # , pixel_size_um=5)

    plt.savefig(f'figures/sRf_{name}.pdf', dpi=300, bbox_inches='tight')
    plt.show()

## Plot nonlinearities (NLs)

In [None]:
BCRectConfig()

In [None]:
bc_nl_ss = (BCRectConfig & dict(bc_rect_config_id='ss')).fetch1('bc_nl')
bc_nl_ss.shape

In [None]:
bc_nl_ws = (BCRectConfig & dict(bc_rect_config_id='ws')).fetch1('bc_nl')
bc_nl_ws.shape

In [None]:
from alphacnn.model.cnn_model import parametrized_sigmoid


def plot_nl(ax, nl):
    """Draw the BC nonlinearity ``parametrized_sigmoid(x, *nl)`` for x in [-0.1, 0.4].

    The curve is drawn as a thin black line on 101 evenly spaced inputs, and the
    axis is switched off.
    """
    inputs = np.linspace(-0.1, +0.4, 101)
    ax.plot(inputs, parametrized_sigmoid(inputs, *nl), color='k', lw=1)
    ax.axis('off')

    
# One small figure per BC type showing its nonlinearity.
for name, nl_params in {'ss': bc_nl_ss, 'ws': bc_nl_ws}.items():
    fig, ax = plt.subplots(1, 1, figsize=(0.7, 0.7))
    plot_nl(ax, nl_params)
    plt.savefig(f'figures/bc_nl_{name}.pdf', dpi=300, bbox_inches='tight')
    plt.show()

## Plot synaptic weights

In [None]:
syn12_nsl = (RGCSynapticWeights & dict(rgc_id='nsl')).fetch1('rgc_synaptic_weights_1', 'rgc_synaptic_weights_2')
syn12_nsl[0].shape

In [None]:
syn12_tmp = (RGCSynapticWeights & dict(rgc_id='tmp')).fetch1('rgc_synaptic_weights_1', 'rgc_synaptic_weights_2')
syn12_tmp[0].shape

In [None]:
syn12_tmp_ss = (RGCSynapticWeights & dict(rgc_id='tmp_ss')).fetch1('rgc_synaptic_weights_1', 'rgc_synaptic_weights_2')
syn12_tmp_ws = (RGCSynapticWeights & dict(rgc_id='tmp_ws')).fetch1('rgc_synaptic_weights_1', 'rgc_synaptic_weights_2')

In [None]:
from alphaanalysis import plot as plota
import seaborn as sns

In [None]:
name_dict = dict(
    nsl=r'n$_\mathrm{wi}$',
    tmp_ws=r't$_\mathrm{wi}$',
    tmp=r't$_\mathrm{mi}$',
    tmp_ss=r't$_\mathrm{si}$',
)

In [None]:
def plot_syn_2d(axs, syn12, cax=None, pixel_size_um=None, vmax=None, cmap='viridis'):
    """Show the two synaptic-weight maps of one RGC side by side on a shared color scale.

    Parameters
    ----------
    axs : sequence of Axes; the first two receive ``syn1`` and ``syn2``.
    syn12 : pair ``(syn1, syn2)`` of 2D arrays; the extent is taken from ``syn1``'s shape.
    cax : optional Axes for a colorbar built from the last map drawn.
    pixel_size_um : if given, a 100 µm scale bar is drawn on the last map, in pixel units.
    vmax : upper color limit. Defaults to the maximum over both maps; vmin is fixed at 0.
    cmap : colormap for both maps.
    """
    syn1, syn2 = syn12

    if vmax is None:
        vmax = np.max([np.max(syn1), np.max(syn2)])

    n_rows, n_cols = syn1.shape
    # imshow extent is (left, right, bottom, top). The floor division is parenthesized
    # so that each pixel spans exactly one unit: -n//2 would be (-n)//2 in Python.
    extent = (-(n_cols // 2), n_cols - n_cols // 2, -(n_rows // 2), n_rows - n_rows // 2)

    im = None
    for ax, syn in zip(axs, (syn1, syn2)):
        im = ax.imshow(syn, vmin=0, vmax=vmax, extent=extent, cmap=cmap, interpolation=None)
        ax.set(xticks=[], yticks=[])

    if cax is not None and im is not None:
        plt.colorbar(im, cax=cax, shrink=2)

    if pixel_size_um is not None:
        size = 100
        # The scale bar goes on the last map that was drawn.
        plota.plot_scale_bar(ax=ax, x0=np.mean(extent[:2]), y0=extent[2]+10, size=size/pixel_size_um, text=f'{size:d} µm', pad=-3)


def plot_synaptic_weights(axs, data_dict, yticks_list=None, share_vmax=False):
    """Plot the synaptic-weight maps of several RGC types, one row per type.

    Parameters
    ----------
    axs : 2D array of Axes with one row per entry of ``data_dict`` and at least three
        columns: two weight maps followed by a colorbar axis.
    data_dict : dict mapping an RGC id (a key of ``name_dict``) to its ``(syn1, syn2)`` pair.
    yticks_list : optional list of colorbar tick lists; entry i is applied to row i.
    share_vmax : if True, every row uses one common color maximum (the max over all maps).
        By default each row is scaled to its own maximum, which matches the per-row
        colorbar ticks used by the callers.
    """
    vmax = None
    if share_vmax:
        vmax = np.max([np.max([np.max(syn1), np.max(syn2)]) for syn1, syn2 in data_dict.values()])

    for ax_row, (name, syn12) in zip(axs, data_dict.items()):
        plot_syn_2d(ax_row[:2], syn12, vmax=vmax, cax=ax_row[2])
        plota.row_title(ax_row[0], name_dict[name], pad=10)

    if yticks_list is not None:
        for i, yticks in enumerate(yticks_list):
            axs[i, 2].set_yticks(yticks)


# One row per RGC type (in dict order below); columns are two weight maps plus a colorbar.
fig, axs = plt.subplots(4, 3, figsize=(1.3, 1.5), width_ratios=(8, 8, 1))

# yticks_list gives the colorbar ticks per row, in the same order as the dict.
plot_synaptic_weights(axs,
                      dict(tmp=syn12_tmp, tmp_ss=syn12_tmp_ss, tmp_ws=syn12_tmp_ws, nsl=syn12_nsl),
                     yticks_list=[[0, 0.01], [0, 0.02],[0, 0.02], [0, 0.003]])

# The left 5% of the figure is excluded from the layout, presumably to leave room
# for the row titles.
plt.tight_layout(rect=(0.05, 0, 1.0, 1), h_pad=0.5, w_pad=0.8, pad=0.2)

plt.savefig(f'figures/all_syns.pdf', dpi=300)
plota.show_saved_figure(fig)
plt.show()

## Plot all encoder parameters in one figure

In [None]:
# Row layout: 0 = BC sRFs (w, s), 1 = spacer, 2 = BC nonlinearities (w, s), 3 = spacer,
# 4-7 = synaptic weights (one RGC type per row). Columns: two panels plus a narrow colorbar column.
fig, axs = plt.subplots(8, 3, figsize=(1.7, 3.5), width_ratios=(8, 8, 1), height_ratios=(1, 0.01, 1, 0.4, 1, 1, 1, 1))

# zip stops after the two sRFs, so the third column stays free for the colorbar.
# NOTE(review): both sRF panels draw a colorbar into the same cax (axs[0, -1]); presumably
# the second one ends up on top — confirm this is intended.
for ax, (name, srf) in zip(axs[0, :], dict(w=bc_srf_ws, s=bc_srf_ss).items()):
    ax.set_title(name)
    plot_srf_2d(ax, srf, cax=axs[0, -1])

for ax in axs[1, :]:
    ax.axis('off')

# Nonlinearities have no colorbar, so the last column of this row is hidden.
# Here `srf` holds nonlinearity parameters, not an RF.
axs[2, -1].axis('off')
for ax, (name, srf) in zip(axs[2, :], dict(w=bc_nl_ws, s=bc_nl_ss).items()):
    plot_nl(ax, srf)

for ax in axs[3, :]:
    ax.axis('off')

# The last four rows hold the synaptic weights, with the same order and ticks as the all_syns figure.
plot_synaptic_weights(axs[-4:, :], dict(tmp=syn12_tmp, tmp_ss=syn12_tmp_ss, tmp_ws=syn12_tmp_ws, nsl=syn12_nsl), 
                      yticks_list=[[0, 0.01], [0, 0.02],[0, 0.02], [0, 0.003]])

plt.tight_layout(rect=(0.05, 0, 1.0, 1), h_pad=0.3, w_pad=0.5, pad=0.2)

plt.savefig(f'figures/all_encoder_params.pdf', dpi=300)
plota.show_saved_figure(fig)
plt.show()

# Plot encoding

In [None]:
StimulusIDs()

In [None]:
stim_key1 = dict(stimulus_id="f002_368415-368640-hr_right.mp4", wo_cricket=0)
stim_key2 = dict(stimulus_id="f002_177568-178166-hr_right.mp4", wo_cricket=1)

video1 = (Stimulus() & stim_key1).fetch1('video')
video2 = (Stimulus() & stim_key2).fetch1('video')

In [None]:
plot_stimulus.plot_video_frames(video1, n_rows=4, n_cols=8);

In [None]:
plot_stimulus.plot_video_frames(video2, n_rows=4, n_cols=8);

In [None]:
idx1 = 74
idx1b = 40
idx2 = 160

In [None]:
bc_key = dict(bc_noise_id='med')

bc_srf_outputs_1 = (BCSpatialRFOutput() & stim_key1 & bc_key).fetch('bc_srf_output')
assert len(bc_srf_outputs_1) == 2

In [None]:
bc_rect_key = [dict(bc_srf_config_id='ss', bc_rect_config_id='ss'), dict(bc_srf_config_id='ws', bc_rect_config_id='ws')]

bc_rect_outputs_1 = (BCRectOutput() & stim_key1 & bc_key & bc_rect_key).fetch('bc_rect_output')
assert len(bc_rect_outputs_1) == 2
#bc_srf_outputs_2 = (BCSpatialRFOutput() & stim_key2 & bc_key).fetch('bc_rect_output')

In [None]:
# Noisy BC outputs (noise sample 0, 'med' noise) for both example stimuli;
# the asserts check that there is exactly one entry per BC config in bc_rect_key.
bc_noise_key = dict(bc_noise_sample=0, bc_noise_id='med')

# NOTE(review): later cells treat index 0 as 'ss' and index 1 as 'ws'. fetch() is called
# without order_by, so this presumably relies on the database's default row order —
# consider passing an explicit order_by on the BC config attribute; confirm.
bc_noise_outputs_1 = (BCNoiseOutput() & stim_key1 & bc_key & bc_rect_key & bc_noise_key).fetch('bc_noise_output')
assert len(bc_noise_outputs_1) == 2
bc_noise_outputs_2 = (BCNoiseOutput() & stim_key2 & bc_key & bc_rect_key & bc_noise_key).fetch('bc_noise_output')
assert len(bc_noise_outputs_2) == 2

In [None]:
nsl_rgc_key = dict(rgc_id='nsl')

nsl_input1 = (RGCSynapticInputs() & stim_key1 & nsl_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')
nsl_input2 = (RGCSynapticInputs() & stim_key2 & nsl_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')

In [None]:
tmp_rgc_key = dict(rgc_id='tmp')

tmp_input1 = (RGCSynapticInputs() & stim_key1 & tmp_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')
tmp_input2 = (RGCSynapticInputs() & stim_key2 & tmp_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')

In [None]:
tmp_ws_rgc_key = dict(rgc_id='tmp_ws')

tmp_ws_input1 = (RGCSynapticInputs() & stim_key1 & tmp_ws_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')
tmp_ws_input2 = (RGCSynapticInputs() & stim_key2 & tmp_ws_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')

In [None]:
tmp_ss_rgc_key = dict(rgc_id='tmp_ss')

tmp_ss_input1 = (RGCSynapticInputs() & stim_key1 & tmp_ss_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')
tmp_ss_input2 = (RGCSynapticInputs() & stim_key2 & tmp_ss_rgc_key & bc_key & bc_rect_key & bc_noise_key).fetch1('rgc_synaptic_inputs')

In [None]:
def plot_foward_pass_title(title, ax):
    """Use ``ax`` as a text-only panel: write ``title`` at its centre and hide the axes."""
    centered = dict(ha='center', va='center', transform=ax.transAxes)
    ax.text(0.5, 0.5, title, **centered)
    ax.axis('off')

In [None]:
def get_vrng(images):
    """Return ``(vmin, vmax)``, the value range spanned by all arrays in ``images``.

    Raises ValueError if ``images`` is empty.
    """
    lows, highs = zip(*[(np.min(img), np.max(img)) for img in images])
    return np.min(lows), np.max(highs)

In [None]:
def plot_images(images, axs, cax=None, cmap='gray', vmin=None, vmax=None, cbar_kwargs=None):
    """Draw each image on its axis with a common color range and an optional colorbar.

    Parameters
    ----------
    images : sequence of 2D arrays, paired with ``axs`` via ``zip``.
    axs : sequence of Axes; ticks are removed from each one that receives an image.
    cax : optional Axes for a colorbar built from the last image drawn.
    cmap : colormap applied to every image (default 'gray').
    vmin, vmax : color limits. A limit left as None is taken from ``get_vrng(images)``;
        a limit the caller supplies is kept.
    cbar_kwargs : optional dict of extra keyword arguments for ``plt.colorbar``.
    """
    if vmin is None or vmax is None:
        auto_vmin, auto_vmax = get_vrng(images)
        # Only fill in the limits the caller did not supply.
        vmin = auto_vmin if vmin is None else vmin
        vmax = auto_vmax if vmax is None else vmax

    im = None
    for image, ax in zip(images, axs):
        # Use the caller's cmap; it was previously ignored in favour of a hard-coded 'gray'.
        im = ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none')
        ax.set(xticks=[], yticks=[])

    if cax is not None and im is not None:
        plt.colorbar(im, cax=cax, **(cbar_kwargs or {}))

In [None]:
# Layout: each stage has a title row ('t_*') above its image row; '.' cells are empty spacers.
# Columns: two image panels plus a narrow colorbar column ('*_cb').
mosaic = [
    ["t_Stim"] * 2 + ['.'],
    ["Stim", "Stim", "Stim_cb"],

    ['.'] * 3,
    
    ["t_BCsRF"] * 2 + ['.'],
    ["BCsRF2", "BCsRF1", "BCsRF_cb"],
    
    ["t_BCnl"] * 2 + ['.'],
    ["BCnl2", "BCnl1", "BCnl_cb"],
    
    ["t_BCnoise"] * 2 + ['.'],
    ["BCnoise2", "BCnoise1", "BCnoise_cb"],

    ['.'] * 3,
    
    ["t_RGC"] * 2 + ['.'],

    ['.'] * 3,
    
    ["RGC", "RGC", "RGC_cb"],
]

# height_ratios needs exactly one entry per mosaic row (13 rows).
fig, axs = plt.subplot_mosaic(mosaic=mosaic, figsize=(2., 0.3*len(mosaic)), height_ratios=[1, 4, 0.9, 1, 4, 1, 4, 1, 4, 0.9, 1, 0.9, 4], width_ratios=(1, 1, 1/10))

# Only frame idx1 of stimulus 1 is shown at every stage.
plot_foward_pass_title(title='Stimulus', ax=axs["t_Stim"])
plot_images(images=[video1[idx1, :, :]], axs=[axs["Stim"]], cax=axs["Stim_cb"])

# For each BC stage, output index 0 goes to the right panel ('*1') and index 1 to the left ('*2').
# NOTE(review): which BC type each index holds depends on the fetch order — confirm.
plot_foward_pass_title(title='BC sRF', ax=axs["t_BCsRF"])
plot_images(images=[bc_srf_outputs_1[0][idx1], bc_srf_outputs_1[1][idx1]], axs=[axs["BCsRF1"], axs["BCsRF2"]], cax=axs["BCsRF_cb"])

plot_foward_pass_title(title='BC nl', ax=axs["t_BCnl"])
plot_images(images=[bc_rect_outputs_1[0][idx1], bc_rect_outputs_1[1][idx1]], axs=[axs["BCnl1"], axs["BCnl2"]], cax=axs["BCnl_cb"])

plot_foward_pass_title(title='BC noise', ax=axs["t_BCnoise"])
plot_images(images=[bc_noise_outputs_1[0][idx1], bc_noise_outputs_1[1][idx1]], axs=[axs["BCnoise1"], axs["BCnoise2"]], cax=axs["BCnoise_cb"])

# The RGC stage shows the 'tmp' synaptic inputs.
plot_foward_pass_title(title='RGC dendrites', ax=axs["t_RGC"])
plot_images(images=[tmp_input1[idx1]], axs=[axs["RGC"]], cax=axs["RGC_cb"])

plt.tight_layout(h_pad=0)
plt.savefig('figures/forward_pass.pdf', dpi=300, bbox_inches='tight')

In [None]:
mosaic = [
    ["Stim1", "Stim2"],
    ["RGC1", "RGC2"],
]

fig, axs = plt.subplot_mosaic(mosaic=mosaic, figsize=(3, 3/2*len(mosaic)))

plot_images(images=[video1[idx1], video2[idx2]], axs=[axs["Stim1"], axs["Stim2"]], cax=None)
plot_images(images=[tmp_input1[idx1], tmp_input2[idx2]], axs=[axs["RGC1"], axs["RGC2"]], cax=None)

plt.tight_layout(h_pad=0)
plt.savefig('figures/forward_pass_input_output.pdf', dpi=300, bbox_inches='tight')

# Plot encodings for different encoders

In [None]:
# Three example frames (rows): stimulus 1 at idx1 and idx1b, and stimulus 2 at idx2.
stim_images = [video1[idx1, :, :], video1[idx1b, :, :], video2[idx2, :, :]]
bc_ws_images = [bc_noise_outputs_1[1][idx1], bc_noise_outputs_1[1][idx1b], bc_noise_outputs_2[1][idx2]]
bc_ss_images = [bc_noise_outputs_1[0][idx1], bc_noise_outputs_1[0][idx1b], bc_noise_outputs_2[0][idx2]]
t_images = [tmp_input1[idx1], tmp_input1[idx1b], tmp_input2[idx2]]
n_images = [nsl_input1[idx1], nsl_input1[idx1b], nsl_input2[idx2]]
t_ws_images = [tmp_ws_input1[idx1], tmp_ws_input1[idx1b], tmp_ws_input2[idx2]]
t_ss_images = [tmp_ss_input1[idx1], tmp_ss_input1[idx1b], tmp_ss_input2[idx2]]

# Columns: stimulus | BCs (ws, ss) | RGCs (tmp, tmp_ss, tmp_ws, nsl); '.' columns are spacers.
cols = ['Stim', '.', 'BCws', 'BCss', '.', 'tmp', 'tmp_ss', 'tmp_ws', 'nsl']
width_ratios = [1, 0.2, 1, 1, 0.2, 1, 1, 1, 1]

n_rows = 3
mosaic = [[f"{col}{i}" if col not in ['.'] else col for col in cols] for i in range(n_rows)]
# The last row holds one horizontal colorbar per group.
mosaic += [['Stim_cb', '.', 'BC_cb', '.', '.', 'RGC_cb', '.', '.', '.']]

height_ratios = [1] * n_rows + [0.1]

fig, axs = plt.subplot_mosaic(mosaic=mosaic, figsize=(5, 0.5*len(mosaic)), width_ratios=width_ratios, height_ratios=height_ratios)
for ax in axs.values():
    ax.set(xticks=[], yticks=[])
    
# Stim
stim_vmin, stim_vmax = get_vrng(images=stim_images)

axs['Stim0'].set(title='Stimulus\n')
plot_images(images=stim_images, axs=[axs[f'Stim{i}'] for i in range(n_rows)], cax=axs["Stim_cb"], vmin=stim_vmin, vmax=stim_vmax, cbar_kwargs=dict(orientation='horizontal'))

# BC
bc_vmin, bc_vmax = get_vrng(images=bc_ws_images + bc_ss_images)
axs['BCws0'].set(title='BCs\nws')
plot_images(images=bc_ws_images, axs=[axs[f'BCws{i}'] for i in range(n_rows)], cax=axs["BC_cb"], vmin=bc_vmin, vmax=bc_vmax, cbar_kwargs=dict(orientation='horizontal'))
axs['BCss0'].set(title='ss')
plot_images(images=bc_ss_images, axs=[axs[f'BCss{i}'] for i in range(n_rows)], vmin=bc_vmin, vmax=bc_vmax)

# RGC
rgc_vmin, rgc_vmax = get_vrng(images=t_images + n_images + t_ws_images + t_ss_images)
# All four RGC columns share the single colorbar drawn from the tmp column, so they must
# use the same color range. The range is anchored at 0 as the tmp column was before;
# taking min(..., 0) keeps any negative inputs visible instead of clipping them.
rgc_vmin = min(rgc_vmin, 0)

axs['tmp0'].set(title='RGCs\n' + r't$_\mathrm{mi}$')
plot_images(images=t_images, axs=[axs[f'tmp{i}'] for i in range(n_rows)], cax=axs["RGC_cb"], vmin=rgc_vmin, vmax=rgc_vmax, cbar_kwargs=dict(orientation='horizontal'))
axs['tmp_ws0'].set(title=r't$_\mathrm{wi}$')
plot_images(images=t_ws_images, axs=[axs[f'tmp_ws{i}'] for i in range(n_rows)], vmin=rgc_vmin, vmax=rgc_vmax)
axs['tmp_ss0'].set(title=r't$_\mathrm{si}$')
plot_images(images=t_ss_images, axs=[axs[f'tmp_ss{i}'] for i in range(n_rows)], vmin=rgc_vmin, vmax=rgc_vmax)
axs['nsl0'].set(title=r'n$_\mathrm{wi}$')
plot_images(images=n_images, axs=[axs[f'nsl{i}'] for i in range(n_rows)], vmin=rgc_vmin, vmax=rgc_vmax)

plt.savefig('figures/compare_encodings.pdf', dpi=300, bbox_inches='tight')

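# Appendix

The cells below are not run: the next cell raises `NotImplementedError` on purpose, and the video cells reference `results` and `video`, which are not defined in this notebook.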

In [None]:
raise NotImplementedError()

## Videos

In [None]:
from matplotlib import animation
from alphacnn.visualize import plot_video
from ipywidgets import HTML

fps = 60
FFwriter = animation.FFMpegWriter(fps=fps)

In [None]:
results['BC-rect'][:, :, :, 0].shape

In [None]:
HTML(plot_video.array_to_anim(video[-results['BC-rect'].shape[0]:], fps=fps, cbar=True).to_html5_video())

In [None]:
HTML(plot_video.array_to_anim(results['BC-rect'][:, :, :, 0], fps=fps, cbar=True).to_html5_video())

In [None]:
HTML(plot_video.array_to_anim(results['BC-rect'][:, :, :, 1], fps=fps, cbar=True).to_html5_video())

In [None]:
HTML(plot_video.array_to_anim(results['RGC-nsl-input'][:, :, :], fps=fps, cbar=True).to_html5_video())

In [None]:
HTML(plot_video.array_to_anim(results['RGC-tmp-input'][:, :, :], fps=fps, cbar=True).to_html5_video())

In [None]:
HTML(plot_video.array_to_anim(results['RGC-nsl_alt-input'][:, :, :], fps=fps, cbar=True).to_html5_video())

In [None]:
HTML(plot_video.array_to_anim(results['RGC-tmp_alt-input'][:, :, :], fps=fps, cbar=True).to_html5_video())

# Save to video

In [None]:
# n_clip_end = 1
# data_sets = {
#     'stimulus': video[-results['BC-rect'].shape[0]:-n_clip_end],
#     'RGCs_tmp': results['RGC-tmp-input'][:-n_clip_end, :, :],
#     'RGCs_nsl': results['RGC-nsl-input'][:-n_clip_end, :, :],
#     'BCs_ws': results['BC-rect'][:-n_clip_end, :, :, 0],
#     'BCs_ss': results['BC-rect'][:-n_clip_end, :, :, 1],
# }
# 
# for name, data in data_sets.items():
#     anim = plot_video.array_to_anim(data, fps=fps, cbar=False, axis_off=True, xy_upsample=5 if data.shape[1] < 30 else 0)
#     anim.save(os.path.join(paths.VIDEO_OUT_PATH, f"example_{name}.mp4"), writer = FFwriter)