In [None]:
import numpy as np

import adaptive_latents.input_sources.datasets as datasets
import adaptive_latents.plotting_functions as pf
import matplotlib.pyplot as plt
import prediction_regression_pipeline as prp
from adaptive_latents.plotting_functions import PredictionVideo
from adaptive_latents import (
    AnimationManager
)
import adaptive_latents
import pandas as pd
from importlib import reload
from IPython import display
from adaptive_latents.plotting_functions import plot_history_with_tail
adaptive_latents.plotting_functions = reload(adaptive_latents.plotting_functions)
# reload() re-executes the module in place, but names already copied into this
# namespace by `from adaptive_latents.plotting_functions import ...` keep pointing
# at the pre-reload objects. Re-bind them so cells calling the bare names use the
# freshly reloaded code.
from adaptive_latents.plotting_functions import PredictionVideo, plot_history_with_tail



In [None]:
# Pipeline configuration: the stored 'naumann24u' defaults with a few options overridden.
zong_overrides = dict(
    drop_behavior=True,
    sub_dataset_identifier=2,
    neural_smoothing_tau=.688,
)
zong_params = prp.PipelineRun.default_parameter_values['naumann24u'] | zong_overrides
zong_run = prp.PipelineRun(**zong_params)


In [None]:
%matplotlib qt
plt.plot(zong_run.pro_latents.t, zong_run.pro_latents)

for s in run.d.opto_stimulations.loc[30:31,'sample']:
    plt.axvline(s, color='k')

## Dimensionality reduction comparison

In [None]:
# Show this run's `sub_dataset` next to `sub_datasets` as a quick sanity check.
(zong_run.d.sub_dataset, zong_run.d.sub_datasets)

In [None]:
# Look up the target neuron(s) of the stimulation(s) at one sample time.
# The old cell relied on `run` and `current_t` leaking from later cells and failed
# on a fresh kernel; the query time is now explicit (the earliest stimulation sample).
stims = zong_run.d.opto_stimulations
query_t = stims.loc[:, 'sample'].min()
stims.loc[stims.loc[:, 'sample'] == query_t, 'target_neuron'].astype('int')

In [None]:
from tqdm.notebook import tqdm

In [None]:
%matplotlib inline

run = zong_run

start_t = run.d.opto_stimulations.loc[:,'sample'].min()
end_t = run.d.neural_data.t[-1]

marks = run.d.opto_stimulations.loc[:,'sample']


tail_length = 10

for l, name in [(run.ica_latents, 'ICA'), (run.jpca_latents, 'jPCA'), (run.pro_latents, 'proSVD')]:
    fig, axs = plt.subplots(ncols=1, nrows=1, squeeze=False, figsize=(5,4), layout='tight')
    with AnimationManager(fig=fig, make_axs=False, fps=10, filetype='mp4', dpi=200) as am:
        for current_t in tqdm(np.arange(start_t, end_t)):
                
            adaptive_latents.plotting_functions.plot_history_with_tail(axs[0,0], l, current_t, scatter_all=False, tail_length=tail_length, alpha=.2)
            axs[0,0].set_title(name)
            if current_t in list(marks):
                neuron = run.d.opto_stimulations[run.d.opto_stimulations.loc[:,'sample'] == current_t].target_neuron.iloc[0].astype(int)
                axs[0,0].scatter(.05 ,.95, transform=axs[0,0].transAxes, s=20, color='red')
            am.grab_frame()

    plt.close()

In [None]:
%matplotlib inline

run = zong_run
fig, axs = plt.subplots(ncols=3, nrows=1, squeeze=False, figsize=(10,4), layout='tight')

start_t = run.d.opto_stimulations.loc[29,'sample']
end_t = run.d.opto_stimulations.loc[31,'sample']

marks = run.d.opto_stimulations.loc[:,'sample']

tail_length = 5
with AnimationManager(fig=fig, make_axs=False, fps=10, filetype='mp4') as am:
    for current_t in np.arange(start_t, end_t):
            
        adaptive_latents.plotting_functions.plot_history_with_tail(axs[0,0], run.pro_latents, current_t, scatter_all=False, tail_length=tail_length, alpha=.2)
        adaptive_latents.plotting_functions.plot_history_with_tail(axs[0,1], run.jpca_latents, current_t,scatter_all=False, tail_length=tail_length, alpha=.2)
        adaptive_latents.plotting_functions.plot_history_with_tail(axs[0,2], run.ica_latents, current_t, scatter_all=False, tail_length=tail_length, alpha=.2)
        if current_t == marks:
            axs[0,0].scatter(.05 ,.95, transform=axs[0,0].transAxes, s=20, color='red')
        am.grab_frame()

plt.close()

In [None]:
# `outfile` of the most recently used AnimationManager.
latest_outfile = am.outfile
latest_outfile

In [None]:
# Embed the latest animation's bytes in the notebook output (embed=True).
video = display.Video(am.outfile, embed=True)
display.display(video)

In [None]:
run = zong_run
grid_n = 13
square_radius = None  # if set, bin on a fixed [-r, r] grid instead of the axis limits
arrow_alpha = 0
scatter_alpha = 0
min_points = 15  # grid cells with this many samples or fewer get transparent arrows

fig, axs = plt.subplots(nrows=1, ncols=3, squeeze=False, layout='tight', figsize=(10, 4))
axs = axs.T

# Flow-field view: bin the projection of each trajectory onto latent dims 0 and 1
# on a grid, and draw the mean (re-normalized) unit step direction per occupied cell.
methods = [(run.pro_latents, 'pro'), (run.jpca_latents, 'jpca'), (run.ica_latents, 'ica')]
for idx, (latents, label) in enumerate(methods):
    ax: plt.Axes = axs[idx, 0]

    # Unit vectors selecting latent dims 0 and 1, sized from the data rather than a
    # hard-coded 6 so any latent dimensionality works.
    n_dims = latents.shape[1]
    e1, e2 = np.zeros(n_dims), np.zeros(n_dims)
    e1[0] = 1
    e2[1] = 1

    # Invisible draw whose axis limits define the grid below. Uses the final time
    # point explicitly; the old code used `current_t` leaked from an earlier cell.
    # NOTE(review): confirm plot_history_with_tail sets limits from history up to current_t.
    plot_history_with_tail(ax, latents, latents.t[-1], tail_length=1, invisible=True)

    if idx == 2:
        # Hand-picked limits for the ICA panel.
        ax.set_xlim([-6, 6])
        ax.set_ylim([-3, 5])

    # Unit-length step from each sample to the next.
    steps = np.diff(latents, axis=0)
    steps = steps / np.linalg.norm(steps, axis=1)[:, np.newaxis]

    x1, x2, y1, y2 = ax.axis()
    if square_radius is not None:
        x_points = np.linspace(-square_radius, square_radius, grid_n)
        y_points = np.linspace(-square_radius, square_radius, grid_n)
    else:
        x_points = np.linspace(x1, x2, grid_n)
        y_points = np.linspace(y1, y2, grid_n)

    # Position of each step's starting sample; loop-invariant, so computed once.
    proj_1 = latents[:-1] @ e1
    proj_2 = latents[:-1] @ e2

    origins, arrows, n_points = [], [], []
    for i in range(len(x_points) - 1):
        in_x = (x_points[i] <= proj_1) & (proj_1 < x_points[i + 1])
        for j in range(len(y_points) - 1):
            s = in_x & (y_points[j] <= proj_2) & (proj_2 < y_points[j + 1])
            if s.sum():
                arrow = steps[s].mean(axis=0)
                arrows.append(arrow / np.linalg.norm(arrow))
                origins.append([x_points[i:i + 2].mean(), y_points[j:j + 2].mean()])
                n_points.append(s.sum())

    # Guard: with no occupied cells, indexing an empty array as origins[:, 0] would fail.
    if origins:
        origins, arrows, n_points = np.array(origins), np.array(arrows), np.array(n_points)
        ax.quiver(origins[:, 0], origins[:, 1], arrows @ e1, arrows @ e2, scale=1 / 20, alpha=n_points > min_points, units='dots', color='red')

    ax.axis('off')
    # A title, not a y-label: axis('off') hides y-labels, so the old labels never rendered.
    ax.set_title(label)

## Stimulation figure

In [None]:
run = zong_run
latents = run.jpca_latents

fig, axs = plt.subplots(ncols=2, figsize=(8, 4), layout='tight')

dim_1, dim_2 = 0, 1
fps = 20              # video frame rate; also sets frames rendered per unit of `latents.t`
start_time = 500      # first time rendered (units of `latents.t`)
duration = 10.5       # length of the rendered window (units of `latents.t`)
stim_threshold = .1   # neurons with |desired_stim| below this are not drawn
marker_scale = 30     # scatter marker size per unit of |desired_stim|
tail_length = 1

# Background image: mean over raw-image seek positions 0..499.
planes = []
for i in range(500):
    run.d.raw_images.seek(i)
    planes.append(np.array(run.d.raw_images))
im = np.mean(planes, axis=0)

# Per-cell `med` positions, split into two coordinate lists; loop-invariant.
xs, ys = list(zip(*[cell['med'] for cell in run.d.stat]))

# Mean latent state; on the last frame an arrow is drawn from the current state to it.
means = np.mean(latents, axis=0)

frame_times = np.linspace(start_time, start_time + duration, int(round(duration * fps)))
with AnimationManager(fig=fig, make_axs=False, fps=fps, filetype='mp4', dpi=400) as am:
    for frame_idx, current_t in enumerate(frame_times):
        # Detect the final frame by index rather than float equality with the endpoint.
        last_frame = frame_idx == len(frame_times) - 1

        ax = axs[0]
        ax.cla()
        plot_history_with_tail(ax, latents, current_t, tail_length=tail_length, invisible=False)

        # Latest latent sample at or before current_t.
        current_state = latents[latents.t <= current_t][-1]

        if last_frame:
            ax.arrow(
                current_state[dim_1], current_state[dim_2],
                means[dim_1] - current_state[dim_1], means[dim_2] - current_state[dim_2],
                zorder=5, color='k', length_includes_head=True, width=0.001, head_width=.04,
            )
        ax.axis('equal')

        ax = axs[1]
        ax.cla()
        if last_frame:
            desired_state = means - current_state
            # NOTE(review): projects a jPCA-space displacement through `run.pro.Q`;
            # confirm that basis is appropriate for `latents`.
            desired_stim = (desired_state @ run.pro.Q.T)[:run.d.neural_data.a.shape[-1]]
            desired_stim = np.abs(desired_stim)
            desired_stim[desired_stim < stim_threshold] = np.nan

            # 'Greys' is matplotlib's long-standing name; 'Grays' is not available in all versions.
            ax.matshow(-im, cmap='Greys')
            ax.scatter(ys, xs, s=desired_stim * marker_scale, color='red')
        ax.axis('off')

        am.grab_frame()

In [None]:
run = zong_run
latents = run.jpca_latents

fig, axs = plt.subplots(ncols=2, figsize=(8, 4), layout='tight')

dim_1, dim_2 = 0, 1
fps = 20              # video frame rate; also sets frames rendered per unit of `latents.t`
start_time = 500 + 10.5  # first time rendered (units of `latents.t`)
duration = 3          # length of the rendered window (units of `latents.t`)
stim_threshold = .1   # neurons with |desired_stim| below this are not drawn
marker_scale = 30     # scatter marker size per unit of |desired_stim|
tail_length = 1
# TODO(review): this cell duplicates the 10.5-unit window cell except for
# start_time/duration; consider factoring both into one function.

# Background image: mean over raw-image seek positions 0..499.
planes = []
for i in range(500):
    run.d.raw_images.seek(i)
    planes.append(np.array(run.d.raw_images))
im = np.mean(planes, axis=0)

# Per-cell `med` positions, split into two coordinate lists; loop-invariant.
xs, ys = list(zip(*[cell['med'] for cell in run.d.stat]))

# Mean latent state; on the last frame an arrow is drawn from the current state to it.
means = np.mean(latents, axis=0)

frame_times = np.linspace(start_time, start_time + duration, int(round(duration * fps)))
with AnimationManager(fig=fig, make_axs=False, fps=fps, filetype='mp4', dpi=400) as am:
    for frame_idx, current_t in enumerate(frame_times):
        # Detect the final frame by index rather than float equality with the endpoint.
        last_frame = frame_idx == len(frame_times) - 1

        ax = axs[0]
        ax.cla()
        plot_history_with_tail(ax, latents, current_t, tail_length=tail_length, invisible=False)

        # Latest latent sample at or before current_t.
        current_state = latents[latents.t <= current_t][-1]

        if last_frame:
            ax.arrow(
                current_state[dim_1], current_state[dim_2],
                means[dim_1] - current_state[dim_1], means[dim_2] - current_state[dim_2],
                zorder=5, color='k', length_includes_head=True, width=0.001, head_width=.04,
            )
        ax.axis('equal')

        ax = axs[1]
        ax.cla()
        if last_frame:
            desired_state = means - current_state
            # NOTE(review): projects a jPCA-space displacement through `run.pro.Q`;
            # confirm that basis is appropriate for `latents`.
            desired_stim = (desired_state @ run.pro.Q.T)[:run.d.neural_data.a.shape[-1]]
            desired_stim = np.abs(desired_stim)
            desired_stim[desired_stim < stim_threshold] = np.nan

            # 'Greys' is matplotlib's long-standing name; 'Grays' is not available in all versions.
            ax.matshow(-im, cmap='Greys')
            ax.scatter(ys, xs, s=desired_stim * marker_scale, color='red')
        ax.axis('off')

        am.grab_frame()