In [None]:
# Interactive matplotlib figures (ipympl backend). Figures created in one cell are
# drawn into again by later cells (e.g. fig2's axes), so an interactive backend is used.
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
import flammkuchen as fl
import pandas as pd

import skimage as sk
import json

from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from sklearn.cluster import AgglomerativeClustering
from fimpylab.core.lightsheet_experiment import LightsheetExperiment

from bouter.utilities import reliability 
from skimage.filters import threshold_otsu
import xarray as xr

from matplotlib import  pyplot as plt
import seaborn as sns
from tqdm import tqdm
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()
import ipywidgets as widgets

from lotr.pca import pca_and_phase, fictive_heading_and_fit, fit_phase_neurons
from circle_fit import hyper_fit
from lotr import LotrExperiment, A_FISH

import lotr.plotting as pltltr
COLS = pltltr.COLS

from lotr.utils import roll_columns_jit

from scipy.optimize import quadratic_assignment
from lotr.pca import qap_sorting_and_phase

In [None]:
def nan_phase_jumps(phase_array, axis=-1):
    """Return a copy of a phase signal with large jumps replaced by NaN.

    Wherever two consecutive samples along ``axis`` differ by more than pi in
    absolute value, the later of the two samples is set to NaN (e.g. so that a
    line plot of a wrapped phase does not draw a stroke across the wrap).

    Parameters
    ----------
    phase_array : array_like
        Phase values (radians). Non-floating input (e.g. integers) is converted
        to float so that it can hold NaN.
    axis : int, optional
        Axis along which consecutive samples are compared. Defaults to the last
        axis, which for 1-D input is the only axis (the original behaviour).

    Returns
    -------
    numpy.ndarray
        New array with the same shape as ``phase_array``; the input is not modified.
    """
    out_array = np.array(phase_array, copy=True, subok=True)
    # Integer arrays cannot store NaN; promote them to float.
    if not np.issubdtype(out_array.dtype, np.inexact):
        out_array = out_array.astype(float)

    # The mask is computed in full before any assignment, so a NaN written at one
    # position never influences the jump test at the next one.
    jumps = np.abs(np.diff(out_array, axis=axis)) > np.pi

    # diff result index k along `axis` compares samples k and k+1, so the mask must
    # be applied to samples 1..end along that SAME axis (not always axis 0).
    later_samples = [slice(None)] * out_array.ndim
    later_samples[axis] = slice(1, None)
    out_array[tuple(later_samples)][jumps] = np.nan  # basic slicing -> view, writes through
    return out_array

In [None]:
# Root folder holding one sub-folder per fish (hard-coded network share).
master = Path(r"\\funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\ls_fixed\spont_plus_v13\huc")
# Path.glob yields entries in an arbitrary, filesystem-dependent order; sort so that
# FISH_IDX selects the same fish on every machine and every run.
fish_list = sorted(master.glob("*_f*"))
FISH_IDX = 1
path = fish_list[FISH_IDX]

# Cropped data for this fish: ROI coordinates, traces and the indices of the kept ROIs
# (presumably the anterior-hindbrain subset, judging by the file name — confirm).
cropped_data = fl.load(path / "ahb_cropped.h5")
coords = cropped_data['coords']
traces = cropped_data['traces']
ahb_idx = cropped_data['ahb_idx']

In [None]:
# Anatomy stack from the suite2p export (summed over axis 0 for the figure below).
suite2p_data = fl.load(path / "data_from_suite2p_cells.h5")
anatomy = suite2p_data['anatomy_stack']

exp = LotrExperiment(path)
# Integer sampling rate: round to the nearest integer instead of truncating
# (int(2.99) would give 2).
fs = int(round(exp.fn))
beh_df = exp.behavior_log

# Analysis window in samples: from t_start_s seconds up to exp.n_pts // 2.
t_start_s = 50
# Slice bounds must be integers; t_start_s * exp.fn is a float whenever fn is a float.
t_lims = (int(round(t_start_s * exp.fn)), exp.n_pts // 2)
t_slice = slice(*t_lims)

In [None]:
# `traces` is laid out as (time points, ROIs): later cells split time along axis 0
# and select ROIs along axis 1.
len_rec, num_traces = np.shape(traces)
# Time axis in seconds, built from the time dimension (axis 0) — not from the
# ROI dimension (axis 1), which would give an array of the wrong length.
t = np.arange(len_rec) / fs
print("num_traces: ", num_traces)
print("len_rec: ", len_rec)
print("sampling rate: ", fs)

In [None]:
######################### Part 2 - looking for neurons that reliably respond to the visual stimulus
# Selecting reliable neurons: the second half of the recording is cut into n_blocks
# equal-length blocks, each treated as one repetition of the stimulus.
n_blocks = 4
norm_traces = traces.T  # (ROIs, time); transposed only, no normalisation applied here
stim_traces = norm_traces[:, len_rec // 2:]
new_len_rec = len_rec // (2 * n_blocks)  # samples per block
print(np.shape(stim_traces))

# (ROIs, blocks, samples per block). Block i holds stim_traces[:, i*L:(i+1)*L];
# any leftover samples past n_blocks * L are dropped. astype(float) returns a fresh
# float64 array, matching the previous zeros-then-fill approach.
trial_traces = (
    stim_traces[:, : n_blocks * new_len_rec]
    .reshape(num_traces, n_blocks, new_len_rec)
    .astype(float)
)
trial_traces_corrected = np.zeros((num_traces, n_blocks, new_len_rec))  # not filled in this notebook

In [None]:
dt = 1 / fs
# Labelled view of the block-split traces: dims (roi, block, t), with t in seconds.
traces_xr = xr.DataArray(
    data=trial_traces,
    dims=['roi', 'block', 't'],
    coords={
        'roi': np.arange(trial_traces.shape[0]),
        'block': np.arange(n_blocks),
        't': np.arange(trial_traces.shape[2]) * dt,
    },
)
# reliability() receives a (t, block, roi) array, the order the previous
# np.swapaxes(traces_xr, 0, 2) produced. Use xarray's own transpose so the reorder
# does not depend on NumPy falling back to DataArray.__array_wrap__.
reliability_arr = reliability(traces_xr.transpose('t', 'block', 'roi').values)
rel_thresh = threshold_otsu(reliability_arr)
print("Reliability threshold: ", rel_thresh)

# Otsu threshold rounded to 3 decimals (used for selection, titles and file names).
rel_thresh_3 = np.round(rel_thresh, 3)

# Manual override of the threshold:
#rel_thresh_3 = 0.33

In [None]:
# Split ROIs at the rounded reliability threshold. Use `<=` for the non-visual group
# so that it is the complement of `>`: with `<`, ROIs whose reliability equals the
# threshold exactly would end up in neither group.
selected_vis = np.where(reliability_arr > rel_thresh_3)[0]
selected_non_vis = np.where(reliability_arr <= rel_thresh_3)[0]
print(np.shape(selected_non_vis))
print(np.shape(selected_vis))
print(np.shape(reliability_arr))

In [None]:
### Pairwise correlations between the non-visual neurons; negatively correlated
### ones are picked out in the next cells.
# NOTE(review): this cell used to be described as looking only at the first half of
# the experiment (darkness), but the full recording is used. Set
# corr_window = slice(0, traces.shape[0] // 2) to restrict it to the first half —
# TODO confirm which is intended.
corr_window = slice(None)  # full recording (current behaviour)

traces_non_vis = traces[corr_window, selected_non_vis]  # (time, non-visual ROIs)

norm_traces = traces_non_vis.T  # (non-visual ROIs, time)
corrmat = np.corrcoef(norm_traces)  # ROI x ROI correlation matrix

In [None]:
# 2x2 figure with a narrow left column and a wide right column; the top-left panel
# shows the correlation matrix of all non-visual ROIs.
fig2, ax = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(8, 6),
    gridspec_kw=dict(width_ratios=[1, 3]),
)
corr_ax = ax[0, 0]
corr_ax.imshow(corrmat, vmin=-1, vmax=1, cmap='coolwarm')

In [None]:
# Non-visual ROIs taking part in at least one strongly negative (< -0.5) correlation.
ng_corr_ind = np.unique(np.where(corrmat < -0.5)[0])
corrmat_neg = corrmat[ng_corr_ind]  # rows of the full matrix for those ROIs
norm_traces_neg = traces_non_vis[:, ng_corr_ind]  # (time, selected ROIs)
num_neurons = len(ng_corr_ind)
# Recording duration in seconds, under its own name so that `len_rec` keeps its
# earlier meaning (number of samples) instead of silently switching units.
rec_duration_s = np.shape(traces)[0] / fs

perm, com_phase = qap_sorting_and_phase(norm_traces_neg)
sorted_traces = norm_traces_neg[:, perm]  # ROI columns reordered by perm
# np.corrcoef treats rows as variables; sorted_traces has time along the rows, so
# rowvar=False is required to get the ROI x ROI matrix (otherwise it would build a
# time x time matrix).
sorted_corrmat = np.corrcoef(sorted_traces, rowvar=False)

In [None]:
# Shape of the negatively-correlated subset: (time points, selected ROIs).
norm_traces_neg.shape

In [None]:
# Right column: trace rasters drawn as (ROIs, time) images, x in seconds.
# Top row: all non-visual ROIs. Bottom row: negatively correlated ROIs, sorted by `perm`.
ax[0, 1].imshow(
    norm_traces,  # (non-visual ROIs, time)
    cmap='gray_r',
    extent=[0, norm_traces.shape[1] / fs, norm_traces.shape[0], 0],
    aspect='auto',
    vmin=-5,
    vmax=5,
)

corrmat_neg = np.corrcoef(norm_traces_neg.T)  # ROI x ROI among the selected ROIs
# ROI x ROI correlation in the sorted order (sorted_traces is (time, ROIs), so
# rowvar=False makes the columns the variables).
ax[1, 0].imshow(np.corrcoef(sorted_traces, rowvar=False), cmap='coolwarm', vmin=-1, vmax=1)
# Transpose to (ROIs, time) so this panel has the same orientation as the one above.
ax[1, 1].imshow(
    sorted_traces.T,
    cmap='gray_r',
    extent=[0, sorted_traces.shape[0] / fs, num_neurons, 0],
    aspect='auto',
    vmin=-5,
    vmax=5,
)


In [None]:
# figure 2 - visual responses
# The stimulus bar (top right) and the trace raster (bottom right) share one x range,
# given by exp.time_arr[-1]. NOTE(review): assumes exp.time_arr is in seconds — confirm.
xlim1 = 0
xlim2 = exp.time_arr[-1]
fig_vis, ax_vis = plt.subplots(
    2, 2, figsize=(8, 4),
    gridspec_kw={'width_ratios': [1, 2], 'height_ratios': [1, 6]},
)

extent_new_ls = (0, exp.plane_ext_um[3], 0, exp.plane_ext_um[1])

# `exp`, `selected_vis` and `selected_non_vis` are reused from the earlier cells rather
# than reloading the experiment from disk and recomputing the split here.
coords_aHB = exp.coords_um[ahb_idx, 1:]
coords_vis = coords_aHB[selected_vis]
coords_non_vis = coords_aHB[selected_non_vis]

############ Anatomy plot of visually responsive neurons:

for panel in ax_vis.flat:
    panel.axis('off')

# Sum the anatomy stack over axis 0, then rotate it by 180 degrees (k=2) and flip columns.
anatomy_layer = np.sum(anatomy, axis=0)
anatomy_layer = np.rot90(anatomy_layer, k=2, axes=(1, 0))
ax_vis[1, 0].imshow(anatomy_layer[:, ::-1], extent=extent_new_ls, origin="upper", cmap='gray_r')

# Every aHB ROI, coloured by its reliability score.
ax_vis[1, 0].scatter(
    coords_aHB[:, 1], coords_aHB[:, 0],
    s=5, c=reliability_arr, cmap='coolwarm', vmin=-1, vmax=1, alpha=0.5,
)
pltltr.add_anatomy_scalebar(ax_vis[1, 0], pos=(-10, -10))

############ Visual stimulus:
ax_vis[0, 1].set_xlim(xlim1, xlim2)
ax_vis[0, 0].set_title("Reliability thresh: " + str(rel_thresh_3))

stimulus_log = exp.stimulus_log
stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
stim_value = stim_value / 255  # 0-255 colour values -> 0-1 range for matplotlib
num_stim = np.shape(stim_value)[0]

# Span times are used unscaled so they are in the same units as the x limits.
# Multiplying them by fs converted them to samples while the limits were in seconds,
# pushing the spans out of view. NOTE(review): assumes get_paint_function returns
# start/end times in seconds — confirm.
for i in range(num_stim):
    ax_vis[0, 1].axvspan(
        t_values[i, 0],
        t_values[i, 1],
        facecolor=[stim_value[i, 0], stim_value[i, 1], stim_value[i, 2]],
        alpha=0.7,
    )

############ Traces (only visually responsive traces), one row per ROI:
vis_traces = traces.T[selected_vis, :]
im = ax_vis[1, 1].imshow(
    vis_traces,
    extent=[0, exp.time_arr[-1], len(selected_vis), 0],
    aspect="auto",
    cmap=COLS["dff_plot"],
    vmin=-1.7,
    vmax=2.0,
)

pltltr.add_dff_cbar(
    im,
    ax_vis[1, 1],
    (1.07, 0.04, 0.03, 0.5),
    title="ΔF (Z.)",
    titlesize=6,
    labelsize=5,
    ticklabels=None,
)

In [None]:
# Save the visual-response figure into the fish folder; the rounded reliability
# threshold is embedded in the file name (via str(), hence the !s conversion).
file_name = f"visually responsive neurons in aHB (thresh {rel_thresh_3!s}).pdf"
fig_vis.savefig(path / file_name, dpi=300)