In [None]:
import numpy as np

import adaptive_latents.input_sources.datasets as datasets
import adaptive_latents.plotting_functions as pf
import matplotlib.pyplot as plt
import prediction_regression_pipeline as prp
from adaptive_latents.plotting_functions import PredictionVideo
from adaptive_latents import (
    AnimationManager
)
import adaptive_latents
import pandas as pd
from importlib import reload
from IPython import display
from tqdm.notebook import tqdm
from adaptive_latents.plotting_functions import plot_history_with_tail
import matplotlib.cm as cm



In [None]:
# Reload the plotting module so edits to it take effect without restarting the kernel.
# `reload` re-executes the module in place, so the `pf` alias and attribute access through
# `adaptive_latents.plotting_functions` see the new code. Names bound earlier with
# `from ... import` still refer to the old objects, so they are re-bound here as well.
adaptive_latents.plotting_functions = reload(adaptive_latents.plotting_functions)
plot_history_with_tail = adaptive_latents.plotting_functions.plot_history_with_tail
PredictionVideo = adaptive_latents.plotting_functions.PredictionVideo

# Intro Video

In [None]:
def intro_plot_fish(d, tail_length=5, duration=10, start_time=10, fps=20, n_neurons=200):
    """Render a GIF of neural activity next to recent tail shapes and embed it in the notebook.

    Left panel: the first `n_neurons` columns of `d.neural_data.a[:, 0, :]` over the
    trailing `tail_length` window, on a colour scale fixed across all frames.
    Right panel: every behavioral sample in the same window, drawn in gray with opacity
    increasing toward the present. The newest sample is over-drawn in white and then in
    the matplotlib cycle colour `C{pose_class}` of that sample.

    Parameters
    ----------
    d : dataset object
        Must expose `neural_data.t`/`.a`, `behavioral_data.t`/`.a`, `tail_position`
        and `pose_class`, as used below.
    tail_length : float
        Length of the trailing window (same units as `neural_data.t`).
    duration : float
        Length of the animated span; `int(duration * fps)` frames are rendered.
    start_time : float
        Time of the first frame.
    fps : int
        Frame rate of the output file.
    n_neurons : int
        Maximum number of neurons (leading columns) shown in the raster.
    """
    neural_t = d.neural_data.t
    neural_a = d.neural_data.a
    beh_t = d.behavioral_data.t
    beh_a = d.behavioral_data.a

    # Frame-independent quantities, computed once instead of once per frame.
    n_columns = int(np.floor(tail_length / np.median(np.diff(neural_t))))
    vmin, vmax = neural_a.min(), neural_a.max()
    # The raster only shows this many rows, so the y extent must match it.
    n_shown = min(n_neurons, neural_a.shape[2])

    with AnimationManager(n_cols=2, n_rows=1, figsize=(10, 5), filetype='gif', fps=fps) as am:
        raster_ax, tail_ax = am.axs[0, 0], am.axs[0, 1]
        # Seed the tail axis with all tail-tip positions. Its limits are captured each
        # frame and handed to `pf.use_bigger_lims` (presumably to keep the view from shrinking).
        tail_ax.scatter(d.tail_position[:, -1, 0], d.tail_position[:, -1, 1])

        for current_t in tqdm(np.linspace(start_time, start_time + duration, int(duration * fps))):
            raster_ax.cla()

            # First neural sample at or after `current_t`; past the end of the data, use the end.
            at_or_after = np.nonzero(~(neural_t < current_t))[0]
            idx = at_or_after[0] if at_or_after.size else len(neural_t)
            # Clamp so early frames show a shorter window instead of a wrapped (empty) slice.
            first = max(idx - n_columns, 0)
            raster_ax.imshow(
                neural_a[first:idx, 0, :n_shown].T,
                aspect='auto',
                interpolation='none',
                extent=[current_t - tail_length, current_t, n_shown, 0],
                vmin=vmin,
                vmax=vmax,
            )
            raster_ax.set_xticklabels([])

            old_lims = tail_ax.axis()
            tail_ax.cla()
            tail_ax.axis('off')

            s = np.nonzero(((current_t - tail_length) < beh_t) & (beh_t < current_t))[0]
            for i, sample in enumerate(beh_a[s, ...]):
                tail_ax.plot(sample[:, 0], sample[:, 1], color='gray', alpha=(i / s.size) * .5)
            if s.size:
                # Highlight the newest sample: white underlay, then its pose-class colour.
                newest = beh_a[s[-1]]
                tail_ax.plot(newest[:, 0], newest[:, 1], color='w', linewidth=3)
                tail_ax.plot(newest[:, 0], newest[:, 1], color=f'C{d.pose_class[s[-1]]}', alpha=1)
            pf.use_bigger_lims(tail_ax, old_lims)

            am.grab_frame()

    display.display(display.Video(am.outfile, embed=True))
    plt.close()

In [None]:
# Second Naumann24u sub-dataset, loaded with the whole-tail behavioral representation.
d = datasets.Naumann24uDataset(
    datasets.Naumann24uDataset.sub_datasets[1],
    beh_type='whole tail',
)
intro_plot_fish(d, tail_length=1.1, start_time=315, duration=19, fps=50)

# Stimulation

In [None]:
# Default 'naumann24u' pipeline parameters, overriding exit_time and using bout-level behavior.
fish_run = prp.PipelineRun(
    **{
        **prp.PipelineRun.default_parameter_values['naumann24u'],
        'exit_time': -1,
        'beh_type': 'bout',
    }
)

## Dimensionality reduction comparison

In [None]:
run = fish_run
fig, axs = plt.subplots(ncols=3, nrows=1, squeeze=False, figsize=(10, 4), layout='tight')

# Trailing window drawn in every panel of every frame.
tail_length = 2
with AnimationManager(fig=fig, make_axs=False, fps=30, filetype='webm') as am:
    for current_t in np.linspace(1500, 1550, 100):
        # NOTE(review): the axes are not cleared between frames here; confirm that
        # plot_history_with_tail redraws the panel itself.
        for ax, latents in zip(axs[0], [run.pro_latents, run.jpca_latents, run.ica_latents]):
            plot_history_with_tail(ax, latents, current_t, tail_length=tail_length)
        am.grab_frame()

# Close this figure so the static last frame is not rendered below the video.
plt.close(fig)
am.display_video()

In [None]:
run = fish_run
tail_length = 4
current_t = 1580

# One panel per latent representation: pro, jPCA, ICA.
fig, axs = plt.subplots(ncols=3, nrows=1, squeeze=False, figsize=(10, 4), layout='tight')
for ax, latents in zip(axs[0], (run.pro_latents, run.jpca_latents, run.ica_latents)):
    plot_history_with_tail(ax, latents, current_t, tail_length=tail_length)

fig.savefig(adaptive_latents.CONFIG['plot_save_path'] / 'fish_clouds.svg')


In [None]:
# run = zong_run
run = fish_run
grid_n = 13           # number of grid edges per axis (grid_n - 1 cells per axis)
square_radius = None  # if set, bin on a fixed [-r, r] square instead of the panel limits
arrow_alpha = 0       # NOTE(review): not used below
scatter_alpha = 0     # NOTE(review): not used below (not passed to plot_history_with_tail)
current_t = 1580
min_points = 15       # cells with at most this many steps get a fully transparent arrow


def _binned_flow_field(latents, x_edges, y_edges, dim_1=0, dim_2=1):
    """Mean unit step direction of a trajectory within each cell of a 2D grid.

    Each step `latents[k+1] - latents[k]` is normalised to unit length and assigned to the
    cell containing `latents[k]`, projected onto dimensions `dim_1`/`dim_2`. A cell covers
    `[edge_i, edge_{i+1})` on each axis. Only non-empty cells are returned.

    Returns
    -------
    origins : (k, 2) array of cell centres.
    arrows : (k, n_dims) array of unit mean step directions.
    counts : (k,) int array with the number of steps in each cell.
    """
    points = np.asarray(latents)
    steps = np.diff(points, axis=0)
    norms = np.linalg.norm(steps, axis=1)
    # A repeated point gives a zero step; keep it as a zero vector instead of NaN.
    steps = steps / np.where(norms == 0, 1, norms)[:, np.newaxis]
    proj_1 = points[:-1, dim_1]
    proj_2 = points[:-1, dim_2]

    origins, arrows, counts = [], [], []
    for i in range(len(x_edges) - 1):
        in_column = (x_edges[i] <= proj_1) & (proj_1 < x_edges[i + 1])
        for j in range(len(y_edges) - 1):
            s = in_column & (y_edges[j] <= proj_2) & (proj_2 < y_edges[j + 1])
            n = int(s.sum())
            if n:
                arrow = steps[s].mean(axis=0)
                arrows.append(arrow / np.linalg.norm(arrow))
                origins.append([x_edges[i:i + 2].mean(), y_edges[j:j + 2].mean()])
                counts.append(n)
    return (
        np.array(origins).reshape(-1, 2),
        np.array(arrows).reshape(-1, points.shape[1]),
        np.array(counts, dtype=int),
    )


fig, axs = plt.subplots(nrows=1, ncols=3, squeeze=False, layout='tight', figsize=(10, 4))
axs = axs.T  # index panels as axs[idx, 0]

for idx, latents in enumerate([run.pro_latents, run.jpca_latents, run.ica_latents]):
    ax: plt.Axes = axs[idx, 0]

    plot_history_with_tail(ax, latents, current_t, tail_length=4, invisible=False)

    if idx == 2:
        ax.set_xlim([-6, 6])
        ax.set_ylim([-3, 5])

    # Bin over the current panel limits unless a fixed square is requested.
    x1, x2, y1, y2 = ax.axis()
    if square_radius is not None:
        x1, x2, y1, y2 = -square_radius, square_radius, -square_radius, square_radius
    x_points = np.linspace(x1, x2, grid_n)
    y_points = np.linspace(y1, y2, grid_n)

    origins, arrows, n_points = _binned_flow_field(latents, x_points, y_points)
    if len(origins):
        ax.quiver(
            origins[:, 0], origins[:, 1], arrows[:, 0], arrows[:, 1],
            scale=1 / 20, alpha=(n_points > min_points).astype(float), units='dots', color='red',
        )

    ax.axis('off')

# `ax.axis('off')` hides axis labels (including y-labels), so the panels are labelled with titles.
for ax, label in zip(axs[:, 0], ['pro', 'jpca', 'ica']):
    ax.set_title(label)
fig.savefig(adaptive_latents.CONFIG['plot_save_path'] / 'fish_clouds_with_arrows.svg')

In [None]:
# Contour coordinates stored next to the sub-dataset.
locations = np.loadtxt(
    d.dataset_base_path / d.sub_dataset / 'contours.txt'
)

In [None]:
plt.imshow(
    np.load(d.dataset_base_path / d.sub_dataset / 'red_channel_image.npy'),
    cmap='gray',
    vmax=350,
)
# Overlay the contour locations (column 1 as x, column 0 as y) on the red-channel image.
plt.plot(locations[:, 1], locations[:, 0], 'r.')


In [None]:
plt.matshow(
    np.loadtxt(d.dataset_base_path / d.sub_dataset / 'image.txt')
)

## Stimulation figure

In [None]:
run = fish_run
d = fish_run.d
latents = run.jpca_latents
dim_1, dim_2 = 0, 1
current_t = 1464

fig, axs = plt.subplots(ncols=2, figsize=(8, 4), layout='tight')

im = np.loadtxt(d.dataset_base_path / d.sub_dataset / 'image.txt')
means = np.mean(latents, axis=0)

# Left panel: latent trajectory, plus an arrow from the state at `current_t` to the mean state.
ax = axs[0]
adaptive_latents.plotting_functions.plot_history_with_tail(
    ax, latents, current_t, tail_length=12, invisible=False, scatter_alpha=.05, scatter_s=20,
)
s = latents.t <= current_t
current_state = latents[s][-1]
ax.arrow(
    current_state[dim_1], current_state[dim_2],
    means[dim_1] - current_state[dim_1], means[dim_2] - current_state[dim_2],
    zorder=5, color='k', length_includes_head=True, width=0.001, head_width=.5,
)
# Square [-10, 10] window on both axes.
ax.set_xlim([-10, 10])
ax.set_ylim([-10, 10])

# Right panel: anatomy image with neurons sized by the magnitude of the stimulation that
# would move the state toward the mean.
ax = axs[1]
desired_state = means - current_state
# NOTE(review): `latents` are jPCA latents but are mapped to neurons through `run.pro.Q`
# only; confirm that the jPCA transform does not need to be undone first.
desired_stim = np.abs((desired_state @ run.pro.Q.T)[:run.d.neural_data.a.shape[-1]])
desired_stim[desired_stim < .5] = np.nan  # hide neurons below the display threshold
ax.imshow(im, cmap=cm.gray, vmin=15, vmax=50)
ax.scatter(d.neuron_locations[:, 1], d.neuron_locations[:, 0], s=desired_stim * 50, color='red')
ax.axis('off')
ax.set_ylim([350, 0])
fig.savefig(adaptive_latents.CONFIG['plot_save_path'] / 'fish_stim.svg')

In [None]:
# Intensity distribution of the anatomy image (guides the vmin/vmax choice above).
plt.hist(im.ravel(), bins=100);

In [None]:
latents = run.jpca_latents

fig, axs = plt.subplots(ncols=2, figsize=(8, 4), layout='tight')

dim_1, dim_2 = 0, 1

# Background image: mean of the first 500 frames of the raw image stack.
planes = []
for i in range(500):
    run.d.raw_images.seek(i)
    planes.append(np.array(run.d.raw_images))

im = np.mean(planes, axis=0)

means = np.mean(latents, axis=0)

tail_length = 1
fps = 20  # also sets frames per unit time below, so playback runs at data speed
with AnimationManager(fig=fig, make_axs=False, fps=fps, filetype='mp4', dpi=400) as am:
    a = 500 + 10.5  # clip start time
    l = 3           # clip length
    frame_times = np.linspace(a, a + l, int(round(l * fps)))
    for frame_idx, current_t in enumerate(frame_times):
        # Identify the final frame by position rather than by float equality with `a + l`.
        last_frame = frame_idx == len(frame_times) - 1

        # Left panel: latent trajectory; on the last frame, an arrow toward the mean state.
        ax = axs[0]
        ax.cla()
        plot_history_with_tail(ax, latents, current_t, tail_length=tail_length, invisible=False)

        s = latents.t <= current_t
        current_state = latents[s][-1]

        if last_frame:
            ax.arrow(
                current_state[dim_1], current_state[dim_2],
                means[dim_1] - current_state[dim_1], means[dim_2] - current_state[dim_2],
                zorder=5, color='k', length_includes_head=True, width=0.001, head_width=.04,
            )
        ax.axis('equal')

        # Right panel: on the last frame, cells sized by the desired stimulation magnitude.
        ax = axs[1]
        ax.cla()
        if last_frame:
            desired_state = means - current_state
            # NOTE(review): jPCA latents are mapped to neurons through `run.pro.Q` only;
            # confirm that the jPCA transform does not need to be undone first.
            desired_stim = np.abs((desired_state @ run.pro.Q.T)[:run.d.neural_data.a.shape[-1]])
            desired_stim[desired_stim < .1] = np.nan  # hide neurons below the display threshold

            ax.matshow(-im, cmap='Grays')
            xs, ys = list(zip(*[cell['med'] for cell in run.d.stat]))

            ax.scatter(ys, xs, s=desired_stim * 30, color='red')
        ax.axis('off')

        am.grab_frame()