In [None]:
%matplotlib widget

from pathlib import Path
import numpy as np
import flammkuchen as fl
import pandas as pd

from fimpylab import LightsheetExperiment

from matplotlib import  pyplot as plt
import seaborn as sns
from tqdm import tqdm
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()
import ipywidgets as widgets

from lotr.utils import zscore
from lotr.pca import pca_and_phase, get_fictive_heading, fictive_heading_and_fit, \
        fit_phase_neurons
from circle_fit import hyper_fit
from lotr import LotrExperiment, A_FISH

import lotr.plotting as pltltr
COLS = pltltr.COLS

from lotr.utils import interpolate, roll_columns_jit, zscore

In [None]:
import skimage as sk
from skimage import io as skio
import json
import pandas as pd

from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from sklearn.cluster import AgglomerativeClustering
from fimpylab.core.lightsheet_experiment import LightsheetExperiment

In [None]:
def nan_phase_jumps(phase_array):
    """Return a copy of ``phase_array`` with NaN after every jump larger than pi.

    Wherever two consecutive samples differ by more than pi in absolute value,
    the later sample of the pair is replaced by NaN, so a line plot shows a gap
    instead of a stroke across the whole axis. The input array is not modified.
    """
    is_jump = np.abs(np.diff(phase_array)) > np.pi
    blanked = phase_array.copy()
    blanked[1:][is_jump] = np.nan
    return blanked

In [None]:
plt.close("all")
# NOTE(review): absolute network-share path; consider moving it to a config cell.
master = Path(r"\\Funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\ls_fixed\spont_plus_v13\new")
# Path.glob yields entries in directory order, which is filesystem-dependent;
# sort so that fish_list[4] is the same fish on every machine.
fish_list = sorted(master.glob("*_f*"))
path = fish_list[4]

# Detrended traces (time x ROI). Loaded once: `traces` and `traces_full` are the
# same array (no cell below modifies it in place), which avoids a second read.
traces_full = fl.load(path / "filtered_traces.h5", "/detr")
traces = traces_full
selected = fl.load(path / "selected.h5")
traces_hdn = traces_full[:, selected]

reg_df = fl.load(path / "motor_regressors.h5")
cc_motor = reg_df["all_bias_abs"].values
cc_motor_integr = reg_df["all_bias_abs_dfdt"].values
coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")
anat = fl.load(path / "data_from_suite2p_unfiltered.h5", "/anatomy_stack")

df = fl.load(path / "bouts_df.h5")  # bout table (alternative: exp.get_bout_properties())
exp = LotrExperiment(path)
fn = int(exp.fn)
beh_df = exp.behavior_log

# Analysis window: from t_start_s seconds (converted with exp.fn) to exp.n_pts // 2.
t_start_s = 50
# Slice bounds must be integers; exp.fn may be a float rate (hence int() above).
t_lims = (int(t_start_s * exp.fn), exp.n_pts // 2)
t_slice = slice(*t_lims)

In [None]:
traces.shape  # (time points, ROIs)

In [None]:
# Two PCA passes over the selected ROIs, restricted to the t_slice window.
# 1) On the transposed (ROI x time) window, so `angles` has one entry per ROI
#    (NOTE(review): inferred from the transpose; confirm pca_and_phase's return values).
pca_scores, angles, _, _ = pca_and_phase(traces[t_slice, selected].T, traces[t_slice, selected].T)
# 2) On the (time x ROI) window, with the full-length selected traces as the second
#    argument. Presumably this fits on the window and projects the whole recording,
#    giving `phase` for every time point (TODO confirm).
pcaed, phase, _, _ = pca_and_phase(traces[t_slice, selected], traces[:, selected])

In [None]:

# Unwrap the network phase and fit the fictive heading using the bout table `df`.
unwrapped_phase = np.unwrap(phase)
traj, params = fictive_heading_and_fit(unwrapped_phase, df, min_bias=0.1)  # min bias adjusted in some fish to compensate tail noise

# `exp` is reused from the loading cell (built from the same `path`);
# constructing a second LotrExperiment here only repeated the disk I/O.
coords = exp.coords_um[selected, 1:]

# ROI order by angle, and the per-frame network phase expressed in ROI-index units.
sort_idxs = np.argsort(exp.rpc_angles)
phase_shifts = (exp.network_phase / (2 * np.pi)) * (exp.n_hdns - 1)


In [None]:
suite2p_data = fl.load(path / "data_from_suite2p_unfiltered.h5")
# ROI label stack, ROI coordinates and anatomy stack from the suite2p export.
roi_map, coords, anatomy = (
    suite2p_data[key] for key in ("rois_stack", "coords", "anatomy_stack")
)

In [None]:
fs = exp.fs
len_rec, num_traces = np.shape(traces)
# traces is (time, ROI): the time axis has len_rec samples (axis 0). The old code
# used axis 1, i.e. the number of ROIs.
t = np.arange(len_rec) / fs
print("num_traces: ", num_traces)
print("len_rec: ", len_rec)
print("sampling rate: ", fs)

In [None]:
# figure 1 - HDN neurons: anatomy (left), tail trace (top right) and
# angle-sorted HDN activity with the network phase overlaid (bottom right).
xlim1 = 0
xlim2 = len_rec // exp.fs  # recording length in seconds
print(xlim2)
fig_hdn, ax_hdn = plt.subplots(2, 2, figsize=(6, 3), gridspec_kw={'width_ratios': [1, 2], 'height_ratios': [1, 6]})
for panel in ax_hdn.flat:
    panel.axis('off')

############ Anatomy plot of HDNs:
anatomy_layer = np.mean(anatomy, axis=0)
anatomy_layer = np.rot90(anatomy_layer, k=3, axes=(1, 0))
# Same colour for every selected ROI (alternative: colour by `angles` with COLS["phase"]).
colored_rois = exp.color_rois_by(np.ones(len(selected)) * 5, indexes=selected, color_scheme='rainbow')
ax_hdn[1, 0].imshow(anatomy_layer, extent=exp.plane_ext_um, origin="upper", cmap='gray_r')
ax_hdn[1, 0].imshow(np.rot90(colored_rois.max(0), 2), extent=exp.plane_ext_um, origin="upper", alpha=1)
pltltr.add_anatomy_scalebar(ax_hdn[1, 0], pos=(-10, -10))

############ Tail:
t_beh = np.asarray(beh_df["t"])
ax_hdn[0, 1].plot(t_beh, beh_df["tail_sum"], color=cols[7], label='Tail', rasterized=True)
ax_hdn[0, 1].legend(loc=2, bbox_to_anchor=(0.8, 2), fontsize=7)
ax_hdn[0, 1].set_aspect('auto')
ax_hdn[0, 1].set_xlim(xlim1, xlim2)

############ Sorted traces:
N_BINS_RESAMPLED = 100
resampling_base = np.linspace(-np.pi, np.pi, N_BINS_RESAMPLED)
# np.interp needs increasing sample points, hence the angle-sorted ROIs.
# Both arrays are time-independent, so they are built once outside the loop.
sorted_angles = exp.rpc_angles[sort_idxs]
sorted_hdn_traces = exp.traces[:, exp.hdn_indexes[sort_idxs]]
angle_resampled_traces = np.zeros((exp.n_pts, N_BINS_RESAMPLED))
for i in range(exp.n_pts):
    angle_resampled_traces[i, :] = np.interp(resampling_base, sorted_angles, sorted_hdn_traces[i])

phase_shifts_resamp = (exp.network_phase / (2 * np.pi)) * (N_BINS_RESAMPLED - 1)
shifted_traces_resamp = roll_columns_jit(
    angle_resampled_traces, -np.round(phase_shifts_resamp)
)

ax = ax_hdn[1, 1]
im = ax.imshow(
    angle_resampled_traces.T,
    extent=[0, exp.time_arr[-1], -np.pi, np.pi],
    aspect="auto",
    cmap=COLS["dff_plot"],
    vmin=-1.7,
    vmax=2.0,
)

pltltr.add_dff_cbar(
    im,
    ax,
    (1.07, 0.04, 0.03, 0.5),
    title="ΔF (Z.)",
    titlesize=6,
    labelsize=5,
    ticklabels=None,
)

pltltr.despine(ax, ["left", "right", "top", "bottom"])
ax.set(ylabel="ROI angle", **pltltr.get_pi_labels(0.5, ax="y"))

# Network phase, min-max rescaled onto the [-pi, pi] ROI-angle axis, with NaN at
# jumps larger than pi so no stroke is drawn across the panel.
# Jumps are detected on exp.network_phase itself (presumably radians). The old code
# used phase_shifts, which is in ROI-index units, so its pi threshold also cut
# ordinary fast phase changes. The min-max rescale cancels the linear
# phase -> ROI-index mapping and constant offset used before.
line = nan_phase_jumps(exp.network_phase)
line = line - np.nanmin(line)
line /= np.nanmax(line)
line *= np.pi * 2
line -= np.pi

# One network-phase sample per imaging frame; use the real rate (was a hard-coded 3 Hz).
t_line = np.arange(np.shape(line)[0]) / exp.fs
ax.plot(t_line, line, lw=1, c=cols[2], label="Network phase")
ax.legend(loc=2, bbox_to_anchor=(0.65, 1.15), fontsize=7)
ax.set_xlim(xlim1, xlim2)


In [None]:
# Export figure 1 into the fish's data folder at print resolution.
file_name = path.joinpath("HDN_sorted_e0040_v13.jpg")
fig_hdn.savefig(file_name, dpi=300)

In [None]:
# Hierarchical (ward) clustering of all ROIs on their traces (ROI x time),
# skipping the first int(exp.fn) // 2 samples.
norm_traces = traces.T
# int(): exp.fn may be a float, and a float slice bound raises TypeError.
norm_traces = norm_traces[:, int(exp.fn) // 2:]
fig1 = plt.figure(figsize=(2, 2))
linked = linkage(norm_traces, method='ward')
dend = dendrogram(linked)
# `affinity` was deprecated in scikit-learn 1.2 and removed in 1.4. Ward linkage
# only supports the Euclidean metric, which is the default under both the old
# and new parameter names, so the argument is omitted.
cluster = AgglomerativeClustering(n_clusters=12, linkage='ward')
her_clustering = cluster.fit_predict(norm_traces)

In [None]:
def cluster_id_search(tree):
    """Return the ids of all leaves below ``tree``, ordered left to right.

    ``tree`` is a node exposing ``is_leaf()``, ``get_id()``, ``get_left()`` and
    ``get_right()`` (e.g. a ``scipy.cluster.hierarchy.ClusterNode``).

    Uses an explicit stack instead of recursion. The recursive version raised
    RecursionError on unbalanced trees deeper than Python's recursion limit, and
    its repeated list concatenation was quadratic in the number of leaves.
    """
    leaf_ids = []
    stack = [tree]
    while stack:
        node = stack.pop()
        if node.is_leaf():
            leaf_ids.append(node.get_id())
        else:
            # Push the right child first so the left subtree is visited first.
            stack.append(node.get_right())
            stack.append(node.get_left())
    return leaf_ids

def find_trunc_dendro_clusters(linkage_mat, dendro):
    """Assign each observation the index of the dendrogram leaf it falls under.

    Parameters
    ----------
    linkage_mat : ndarray
        Linkage matrix (as returned by ``scipy.cluster.hierarchy.linkage``).
    dendro : dict
        Output of ``dendrogram`` on the same linkage, possibly truncated.
        ``dendro["leaves"]`` holds node ids in left-to-right plotting order.

    Returns
    -------
    ndarray of int, length ``linkage_mat.shape[0] + 1``
        ``ids[k]`` is the position in ``dendro["leaves"]`` of the leaf containing
        observation ``k``. Observations not covered by any leaf are -1 (the old
        ``np.empty`` left arbitrary values there).
    """
    _, branches = to_tree(linkage_mat, rd=True)
    ids = np.full(linkage_mat.shape[0] + 1, -1, dtype=int)
    for position, node_id in enumerate(dendro["leaves"]):
        # pre_order() lists the original observation ids under this node,
        # iteratively and in left-to-right order.
        ids[branches[node_id].pre_order()] = position
    return ids

In [None]:
# Build the truncated ("lastp") dendrogram only to read the left-to-right order
# of the n_clust top clusters, so the labels follow the same leaf sequence as the
# cut. no_plot=True returns the same dict without creating and closing a
# throwaway figure.
n_clust = 8
dendro = dendrogram(linked, n_clust, truncate_mode="lastp", no_plot=True)
cluster_ids = dendro["leaves"]
labels = find_trunc_dendro_clusters(linked, dendro)

In [None]:
meanresps = norm_traces
# Subtract from each ROI (row) the nan-mean of its first 8 samples.
row_baseline = np.nanmean(meanresps[:, :8], axis=1, keepdims=True)
base_sub_mean = meanresps - row_baseline
X = base_sub_mean


In [None]:
from skimage import color
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
#import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm

In [None]:
def shade_plot(stim, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):
    """Shade the background of ``ax`` according to a stimulus.

    Parameters
    ----------
    stim : list, np.ndarray, tuple or fish-data object
        * list: already-computed transitions, passed straight to ``_shade_plot``;
        * ndarray: column 0 is time and column 1 the stimulus value;
        * tuple: ``(time, value)``;
        * fish-data object: anything exposing ``resampled_stim`` and ``time_im_rep``.
    ax, gamma, shade_range
        Forwarded to ``_shade_plot``.

    Raises
    ------
    TypeError
        If ``stim`` is none of the supported kinds. The old version silently did nothing.
    """
    if isinstance(stim, list):  # these would be transitions
        transitions = stim
    elif isinstance(stim, np.ndarray):  # stimulus array
        transitions = find_transitions(stim[:, 1], stim[:, 0])
    elif isinstance(stim, tuple):  # time, lum tuple
        transitions = find_transitions(stim[1], stim[0])
    elif hasattr(stim, "resampled_stim") and hasattr(stim, "time_im_rep"):  # fish data
        # Read from the object passed in. The old branch compared against an
        # undefined name `Data` (NameError for every non-list input) and read
        # class attributes instead of the instance's.
        transitions = find_transitions(stim.resampled_stim, stim.time_im_rep)
    else:
        raise TypeError(f"Unsupported stimulus type: {type(stim).__name__}")
    _shade_plot(transitions, ax=ax, gamma=gamma, shade_range=shade_range)


def _shade_plot(lum_transitions, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):

    if ax is None:
        ax = plt.gca()
    shade = lum_transitions[0][1]
    for i in range(len(lum_transitions)-1):
        shade = shade + lum_transitions[i][1]
        new_shade = shade_range[0] + np.power(np.abs(shade), gamma) * (shade_range[1] - shade_range[0])
        ax.axvspan(lum_transitions[i][0], lum_transitions[i+1][0], color=(new_shade, )*3)
        

def _find_thr(linked, n_clust):
    interval = [0, 2000]
    new_height = np.mean(interval)
    clust = 0
    n_clust = n_clust
    while clust != n_clust:
        new_height = np.mean(interval)
        clust = cut_tree(linked, height=new_height).max()
        if clust > n_clust:
            interval[0] = new_height
        elif clust < n_clust:
            interval[1] = new_height


    return new_height


def find_plot_thr(linked, n_clust):
    """Return a dendrogram ``color_threshold`` that splits ``linked`` into ``n_clust`` clusters.

    ``_find_thr`` targets the largest ``cut_tree`` label (cluster count minus one),
    hence the ``- 1`` on the requested count.
    """
    target_label = n_clust - 1
    return _find_thr(linked, target_label)


In [None]:

# Colour threshold giving n_clust coloured groups in the full dendrogram.
thr = find_plot_thr(linked, n_clust)

panel_dendro = dendrogram(
    linked,
    color_threshold=thr,
    distance_sort='descending',
    show_leaf_counts=False,
    no_labels=True,
    above_threshold_color='#787878',  # rgb(120, 120, 120): links above the threshold
)

In [None]:
# Inspect exp.fn. The loading cell multiplies seconds by it to get a sample index,
# so it is presumably the imaging rate in frames per second (TODO confirm).
exp.fn

In [None]:
# getting stimulus regressors:
# Layout: anatomy (bottom left), stimulus epochs (top right), and all ROI traces
# ordered by the ward dendrogram leaves (bottom right). Both right-hand panels use
# imaging frames on the x axis: xlim2 is len_rec frames and the stimulus epochs
# are converted to frames below.
xlim1 = 0
xlim2 = len_rec
fig_vis, ax_vis = plt.subplots(2, 2, figsize=(6, 3), gridspec_kw={'width_ratios': [1, 2], 'height_ratios': [1, 6]})

############ Anatomy:
anatomy_layer = np.sum(anatomy, axis=0)
anatomy_layer = np.rot90(anatomy_layer, k=3, axes=(1, 0))
ax_vis[1, 0].imshow(anatomy_layer, extent=exp.plane_ext_um, origin="upper", cmap='gray_r')
# Scalebar on this figure's anatomy panel (it was previously drawn on ax_hdn).
pltltr.add_anatomy_scalebar(ax_vis[1, 0], pos=(-10, -10))

############ Stimulus epochs:
t_beh = np.asarray(beh_df["t"])
ax_vis[0, 1].set_xlim(xlim1, xlim2)

stimulus_log = exp.stimulus_log
stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
stim_value = stim_value / 255  # 0-255 colour values -> matplotlib's 0-1 range
num_stim = np.shape(stim_value)[0]

# Epoch boundaries scaled by the imaging rate, i.e. converted to frames
# (assumes get_paint_function returns seconds; TODO confirm).
# Not done in place, so an integer t_values array cannot hit a float->int cast error.
t_values = t_values * fs
for i in range(num_stim):
    ax_vis[0, 1].axvspan(
        t_values[i, 0],
        t_values[i, 1],
        facecolor=[
            stim_value[i, 0],
            stim_value[i, 1],
            stim_value[i, 2],
        ],
        alpha=0.7,
    )

############ All traces, ordered by dendrogram leaves:
tmp_traces = traces.T
im = ax_vis[1, 1].imshow(
    tmp_traces[panel_dendro["leaves"], :],
    # x in frames to match the stimulus panel; y counts ROIs top to bottom.
    # The old -pi..pi extent was copied from the HDN angle plot.
    extent=[0, len_rec, num_traces, 0],
    aspect="auto",
    cmap=COLS["dff_plot"],
    vmin=-1.7,
    vmax=2.0,
)

In [None]:
# Stimulus-log background signals with the stimulus epochs shaded behind them.
# NOTE(review): bg_y / bg_x are plotted against their sample index, while
# t_values was scaled by the imaging rate in the figure cell above. Confirm the
# x units agree.
fig_reg, ax_reg = plt.subplots(1, 1)
for column in ("bg_y", "bg_x"):
    ax_reg.plot(getattr(stimulus_log, column))
for k in range(num_stim):
    rgb = [stim_value[k, channel] for channel in range(3)]
    ax_reg.axvspan(t_values[k, 0], t_values[k, 1], facecolor=rgb, alpha=0.7)

In [None]:
# Show the most recently assigned x-limits (xlim1, xlim2 are set in both figure cells).
print(xlim1, xlim2)

In [None]:
# Export the stimulus / clustered-traces figure next to the fish's data.
file_name = path.joinpath("all_traces_clusetered_e0040_v13.jpg")
fig_vis.savefig(file_name, dpi=300)

In [None]:
traces.shape  # (time points, ROIs)

In [None]:
from bouter.utilities import reliability 
from skimage.filters import threshold_otsu
import xarray as xr

In [None]:
# figure 2 - Visually tuned neurons
# Selecting reliable neurons: the second half of the recording is split into
# n_blocks equal blocks (presumably stimulus repetitions) of new_len_rec frames.
n_blocks = 4
norm_traces = traces.T  # (ROI, time)
stim_traces = norm_traces[:, len_rec // 2:]
new_len_rec = len_rec // (2 * n_blocks)
print(np.shape(stim_traces))

# (ROI, block, time): block i holds stim_traces[:, i*new_len_rec:(i+1)*new_len_rec].
# A C-order reshape yields exactly that layout; astype(float) copies, matching the
# float64 array the old per-block loop filled.
trial_traces = (
    stim_traces[:, : n_blocks * new_len_rec]
    .reshape(num_traces, n_blocks, new_len_rec)
    .astype(float)
)
# FIXME: trial_traces_corrected is never filled (all zeros), so everything
# computed from it downstream is degenerate. Define the intended correction.
trial_traces_corrected = np.zeros((num_traces, n_blocks, new_len_rec))

avg_traces = np.nanmean(trial_traces, 1)
print(np.shape(trial_traces))


In [None]:
# Per-ROI reliability across blocks, thresholded with Otsu's method, for the raw
# (trial_traces) and corrected (trial_traces_corrected) block arrays.
dt = 1 / fs  # sampling period; only used for the 't' coordinate (was a hard-coded 0.33)


def _reliability_and_threshold(trials):
    """Compute reliability and its Otsu threshold for a (roi, block, t) array.

    Returns ``(DataArray, reliability, threshold, threshold rounded to 3 decimals)``.
    ``reliability`` receives the data as (t, block, roi), the layout the previous
    ``np.swapaxes(..., 0, 2)`` produced. Non-finite reliabilities are dropped
    before Otsu, which cannot histogram NaN. If none are finite, the threshold is NaN.
    """
    trials_xr = xr.DataArray(
        data=trials,
        dims=['roi', 'block', 't'],
        coords={
            'roi': np.arange(trials.shape[0]),
            'block': np.arange(trials.shape[1]),
            't': np.arange(trials.shape[2]) * dt,
        },
    )
    rel = reliability(trials_xr.transpose('t', 'block', 'roi').values)
    finite = np.asarray(rel)[np.isfinite(rel)]
    thresh = threshold_otsu(finite) if finite.size else np.nan
    return trials_xr, rel, thresh, np.round(thresh, 3)


traces_xr, reliability_arr, rel_thresh, rel_thresh_3 = _reliability_and_threshold(trial_traces)
print("Reliability threshold: ", rel_thresh)

# FIXME: trial_traces_corrected is all zeros (never filled), so these values are degenerate.
traces_xr_det, reliability_arr_det, rel_thresh_det, rel_thresh_3_det = _reliability_and_threshold(
    trial_traces_corrected
)
print("Reliability threshold: ", rel_thresh_det)
print(np.shape(reliability_arr_det))




In [None]:
# Visualize the reliability distribution and the responses of reliable ROIs.
fig, ax = plt.subplots(1, 2, figsize=(8, 5))
ax[0].hist(reliability_arr, bins=50, density=True)
ax[0].axvline(rel_thresh, c='red', ls='--')
ax[0].set_xlim([-1, 1])
ax[0].set_xlabel('Average correlation between reps')
ax[0].set_ylabel('Density')
ax[0].set_title("Reliability threshold: " + str(rel_thresh_3))

# ROIs above the 3-decimal Otsu threshold, shown over the stimulus half.
rel_ind = np.where(reliability_arr > rel_thresh_3)[0]
# aspect="auto" replaces a hard-coded 500x500 extent that put fake units on both axes.
ax[1].imshow(stim_traces[rel_ind, :], aspect="auto")
ax[1].set_xlabel("Frame (stimulus half)")
ax[1].set_ylabel("Reliable ROI")
# Lay out only after both panels are drawn.
plt.tight_layout()



In [None]:
# Paint every ROI's pixels with its reliability score; unlabelled pixels stay 0.
# Vectorised: the old loop re-scanned the whole stack once per ROI (O(ROIs x voxels)).
# NOTE(review): ROI i is looked up as label i + 1 here, while the plotting cell
# below masks labels < 0 as background. If rois_stack labels ROIs from 0 with a
# -1 background, this mapping is off by one. Confirm the label convention.
roi_map_rel = np.zeros_like(roi_map, dtype=float)
roi_map_rel_test = np.copy(roi_map)
in_roi = (roi_map >= 1) & (roi_map <= num_traces) & (roi_map == np.round(roi_map))
roi_scores = np.asarray(reliability_arr)[roi_map[in_roi].astype(int) - 1]
roi_map_rel[in_roi] = roi_scores
# Same scores written into a copy of roi_map, so they take roi_map's dtype.
roi_map_rel_test[in_roi] = roi_scores
print(np.unique(roi_map_rel))
#print(np.unique(reliability_arr_det))

In [None]:
# Overlay each plane's ROI reliability map (coolwarm, clipped at +/-0.5),
# masking pixels whose label in roi_map is negative.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))

# Iterate over every plane of the stack instead of a hard-coded 10, which raised
# IndexError on shorter stacks and silently dropped planes on deeper ones.
for i in range(roi_map_rel.shape[0]):
    roi_layer_orig = roi_map[i]
    roi_layer = np.ma.masked_where(roi_layer_orig < 0, roi_map_rel[i])
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
    ax.imshow(roi_layer, cmap="coolwarm", vmin=-0.5, vmax=0.5, alpha=0.5)
ax.axis('off')

plt.show()