In [None]:
import matplotlib
%matplotlib widget

import numpy as np
from split_dataset import SplitDataset
from pathlib import Path
import flammkuchen as fl
from tifffile import imread
import matplotlib.pyplot as plt 
from fimpylab.core.lightsheet_experiment import LightsheetExperiment
from bouterin.plots.stimulus_log_plot import get_paint_function

from bouter.utilities import reliability 
from skimage.filters import threshold_otsu
import xarray as xr
from scipy.signal import detrend 

from motions.utilities import stim_vel_dir_dataframe, quantize_directions
import tifffile as tiff
from scipy.signal import argrelextrema
from scipy.signal import find_peaks

In [None]:
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm

In [None]:
from bouterin.plots.stimulus_log_plot import get_paint_function
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from sklearn.cluster import AgglomerativeClustering

In [None]:
def cluster_id_search(tree):
    """Collect the ids of all leaves below ``tree``, left subtree before right.

    ``tree`` is a node exposing ``is_leaf()``, ``get_id()``, ``get_left()`` and
    ``get_right()`` (e.g. a scipy ``ClusterNode``). A leaf returns ``[its id]``.
    """
    leaf_ids = []
    pending = [tree]
    while pending:
        node = pending.pop()
        if node.is_leaf():
            leaf_ids.append(node.get_id())
            continue
        # LIFO stack: push the right child first so the left one is visited first.
        pending.append(node.get_right())
        pending.append(node.get_left())
    return leaf_ids

def find_trunc_dendro_clusters(linkage_mat, dendro):
    """Map every original observation to the dendrogram leaf that contains it.

    Parameters
    ----------
    linkage_mat : ndarray
        SciPy linkage matrix (``n - 1`` rows for ``n`` observations).
    dendro : dict
        Output of ``scipy.cluster.hierarchy.dendrogram`` (typically truncated).
        Only ``dendro["leaves"]`` is read; each entry is a node id of the
        linkage tree.

    Returns
    -------
    ndarray of int, shape (n,)
        ``ids[obs]`` is the left-to-right position of the leaf whose subtree
        contains observation ``obs``. Observations not covered by any leaf are
        -1 rather than uninitialised memory (the previous ``np.empty``).
    """
    _, nodes = to_tree(linkage_mat, rd=True)  # nodes[node_id] -> ClusterNode
    n_obs = linkage_mat.shape[0] + 1
    ids = np.full(n_obs, -1, dtype=int)

    for position, node_id in enumerate(dendro["leaves"]):
        ids[cluster_id_search(nodes[node_id])] = position

    return ids

In [None]:
def shade_plot(stim, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):
    """Shade the background of ``ax`` according to a stimulus luminance profile.

    ``stim`` may be:
      * a list of transitions, passed straight to ``_shade_plot``;
      * a tuple ``(time, lum)``;
      * an ndarray whose column 0 is time and column 1 is luminance;
      * an object exposing ``resampled_stim`` and ``time_im_rep`` (fish data).

    ``gamma`` and ``shade_range`` are forwarded to ``_shade_plot``.

    Raises
    ------
    TypeError
        If ``stim`` is none of the supported forms (previously ignored silently).
    """
    # NOTE(review): find_transitions is not defined in the visible part of this
    # notebook; it must be importable from elsewhere for non-list inputs.
    if isinstance(stim, list):  # these would be transitions
        transitions = stim
    elif isinstance(stim, tuple):  # time, lum tuple
        transitions = find_transitions(stim[1], stim[0])
    elif isinstance(stim, np.ndarray):  # stimulus array
        transitions = find_transitions(stim[:, 1], stim[:, 0])
    elif hasattr(stim, "resampled_stim") and hasattr(stim, "time_im_rep"):
        # Fish data object. Read from the instance passed in. The old code
        # compared against an undefined ``Data`` class, which raised NameError
        # for every non-list input, and read the attributes off the class.
        transitions = find_transitions(stim.resampled_stim, stim.time_im_rep)
    else:
        raise TypeError(f"Unsupported stimulus type for shade_plot: {type(stim)!r}")

    _shade_plot(transitions, ax=ax, gamma=gamma, shade_range=shade_range)


def _shade_plot(lum_transitions, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):

    if ax is None:
        ax = plt.gca()
    shade = lum_transitions[0][1]
    for i in range(len(lum_transitions)-1):
        shade = shade + lum_transitions[i][1]
        new_shade = shade_range[0] + np.power(np.abs(shade), gamma) * (shade_range[1] - shade_range[0])
        ax.axvspan(lum_transitions[i][0], lum_transitions[i+1][0], color=(new_shade, )*3)
        

def _find_thr(linked, n_clust):
    interval = [0, 2000]
    new_height = np.mean(interval)
    clust = 0
    n_clust = n_clust
    while clust != n_clust:
        new_height = np.mean(interval)
        clust = cut_tree(linked, height=new_height).max()
        if clust > n_clust:
            interval[0] = new_height
        elif clust < n_clust:
            interval[1] = new_height


    return new_height


def find_plot_thr(linked, n_clust):
    """Return a dendrogram colour threshold that splits ``linked`` into ``n_clust`` clusters.

    ``_find_thr`` targets the largest cluster label (clusters - 1), hence the
    ``n_clust - 1`` passed through.
    """
    target_max_label = n_clust - 1
    return _find_thr(linked, target_max_label)


In [None]:
def plot_clusters_dendro(traces, stim, linkage_mat, labels, dendrolims=(900, 30),
                         thr=None, f_lim=2, gamma=1):
    """Three-panel summary of a hierarchical clustering.

    The figure has three panels:
      * top: mean trace of each cluster, offset vertically by 5 per cluster;
      * middle: all traces as a heat map, rows in dendrogram leaf order;
      * bottom: the dendrogram, coloured below ``thr``, with a line at ``thr``.

    Parameters
    ----------
    traces : ndarray, (n_traces, n_timepoints)
        Rows are the observations that were clustered.
    stim : any
        Not used. Kept for call compatibility; it was only read for a time-bar
        computation that was never drawn.
    linkage_mat : ndarray
        SciPy linkage matrix computed from ``traces``.
    labels : ndarray of int, (n_traces,)
        Cluster label per row, in ``0 .. n_clust - 1``.
    dendrolims, gamma : any
        Not used. Kept for call compatibility.
    thr : float, optional
        Colour threshold; defaults to ``find_plot_thr(linkage_mat, n_clust)``.
    f_lim : float
        Symmetric colour limit of the heat map.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig_clust, ax = plt.subplots(3, 1, figsize=(10, 10))
    ax_clusters, ax_traces, ax_dendro = ax
    hexac = cluster_cols()

    n_clust = labels.max() + 1

    # --- Dendrogram ---
    if thr is None:
        thr = find_plot_thr(linkage_mat, n_clust)

    set_link_color_palette(hexac)
    panel_dendro = dendrogram(linkage_mat,
                              ax=ax_dendro,  # draw on the bottom panel, not the implicit current axes
                              color_threshold=thr,
                              distance_sort='descending',
                              show_leaf_counts=False,
                              no_labels=True,
                              above_threshold_color='#787878')
    ax_dendro.axhline(thr, linewidth=0.7, color="k")
    ax_dendro.axis("off")

    # --- Traces matrix, rows in dendrogram leaf order ---
    ax_traces.imshow(traces[panel_dendro["leaves"], :],
                     aspect='auto', origin='lower', cmap=cm.RdBu_r,
                     vmin=-f_lim, vmax=f_lim)
    ax_traces.axis("off")

    # --- Cluster means ---
    for i in range(n_clust):
        # Cycle the palette so more clusters than colours does not raise IndexError.
        ax_clusters.plot(np.nanmean(traces[labels == i, :], 0) + i * 5,
                         label=i, color=hexac[i % len(hexac)])
    ax_clusters.axis("off")

    return fig_clust

def cluster_cols(n_colors=None):
    """Return the named matplotlib colours used to paint clusters.

    Parameters
    ----------
    n_colors : int, optional
        Number of colours to return. ``None`` returns the whole 20-colour
        palette. The old version sliced with a notebook-global ``k`` that is
        only defined in a later cell, so it raised NameError on a fresh kernel.
    """
    palette = ["lightblue", "lightcoral", "orange", "springgreen", "deepskyblue", "mediumpurple", "gold", "cyan", "crimson",
               "deeppink", "lawngreen", "darkviolet", "Darkgreen", "blue", "brown", "dodgerblue", "hotpink", "OliveDrab", "gray", "seagreen"]
    return palette if n_colors is None else palette[:n_colors]

In [None]:
# Root folder of the recordings (machine-specific path).
master_path = Path(r"Z:\Hagar\E0040\v31\pre ablation")
# Path.glob yields entries in filesystem order, which is not guaranteed.
# Sort so that fish_idx always selects the same folder.
fish_list = sorted(master_path.glob("*f*"))
fish_idx = 3  # position of the fish folder to analyse, in sorted order
path = fish_list[fish_idx]
print(path)

In [None]:
# Load the suite2p "brain" file. Its coords_idx is used below to select
# columns of the trace matrix and rows of coords (presumably the ROIs inside
# the brain mask - confirm).
suite2p_brain = fl.load(path / "data_from_suite2p_cells_brain.h5")
in_brain_idx = suite2p_brain['coords_idx']

In [None]:
# Detrended traces restricted to the in-brain ROIs. Later cells index them as
# traces[:, roi] against a full-length regressor, i.e. time runs along axis 0.
traces = fl.load(path / "filtered_traces.h5", "/detr")[:, in_brain_idx]
suite2p_data = fl.load(path / "data_from_suite2p_cells.h5")
coords = suite2p_data['coords'][in_brain_idx]
anatomy = suite2p_data['anatomy_stack']

#df = fl.load(path / "bouts_df.h5")
exp = LightsheetExperiment(path)
# NOTE(review): exp.fn is presumably the imaging rate. int() truncates it, so a
# non-integer rate (e.g. 2.5) would skew every fs-based sample count below.
fs = int(exp.fn)
beh_df = exp.behavior_log

In [None]:
# traces has time on axis 0 and ROIs on axis 1. Build the time vector from the
# number of time points; the old code used shape[1] (the ROI count).
len_rec, num_traces = np.shape(traces)
t = np.arange(len_rec) / fs  # seconds, if fs is in Hz
print("num_traces: ", num_traces)
print("len_rec: ", len_rec)
print("sampling rate: ", fs)

In [None]:
# Sensory regressors. Column 0 is taken as "right" and column 4 as "left",
# matching the direction order of `titles` further down (0=right, 4=left).
regs = fl.load(path / "sensory_regressors.h5", "/regressors")#[0]
right = np.asarray(regs.iloc[:, 0])
left = np.asarray(regs.iloc[:, 4])

num_traces = np.shape(traces)[1]

# Pearson correlation of every ROI's trace with each regressor.
# ROIs with a constant trace give NaN here.
right_corr = np.zeros((num_traces))
left_corr = np.zeros((num_traces))
for i in range(num_traces):
    right_corr[i] = np.corrcoef(right, traces[:, i])[0,1]
    left_corr[i] = np.corrcoef(left, traces[:, i])[0,1]

In [None]:
# ROIs whose |correlation| with a regressor exceeds thresh count as "tuned".
# NOTE(review): np.abs also admits strongly anti-correlated ROIs - confirm
# that this is intended.
thresh = 0.5
right_tuned = np.where(np.abs(right_corr) > thresh)[0]
print(np.shape(right_tuned))
n_right_tuned = np.shape(right_tuned)[0]

left_tuned = np.where(np.abs(left_corr) > thresh)[0]
print(np.shape(left_tuned))
n_left_tuned = np.shape(left_tuned)[0]

In [None]:
# Traces of the tuned ROIs, transposed so rows are ROIs and columns are time.
left_traces = traces[:, left_tuned].T
right_traces = traces[:, right_tuned].T
print(np.shape(left_traces))

In [None]:
#### Getting a list of stimuli order: 0=right, 7=right-up
# Protocol durations converted to samples. int() truncates the protocol value
# before it is scaled by fs.
protocol = exp['stimulus']['protocol']['E0040_motions_cardinal']['v31_8dir_plus_hd']
pause_duration = int(protocol['pause_duration']) * fs
stim_duration = int(protocol['moving_duration']) * fs

left_diff = np.diff(left)
right_diff = np.diff(right)

# Peaks of the regressor derivative (presumably stimulus onsets). Each window
# starts one full (moving + pause) period before a peak and ends one period
# after it. find_peaks is computed once per regressor instead of twice.
peak_height = 0.1
left_onsets = find_peaks(left_diff, height=peak_height)[0]
right_onsets = find_peaks(right_diff, height=peak_height)[0]
left_start = left_onsets - stim_duration - pause_duration
left_end = left_onsets + stim_duration + pause_duration
right_start = right_onsets - stim_duration - pause_duration
right_end = right_onsets + stim_duration + pause_duration

# Visual check of the detected windows on top of the regressor derivatives.
fig, ax = plt.subplots(1, 1)
ax.plot(left_diff, label="diff(left)")
ax.scatter(left_start, np.full(len(left_start), 0.15))
ax.scatter(left_end, np.full(len(left_end), 0.17))

ax.plot(right_diff, label="diff(right)")
ax.scatter(right_start, np.full(len(right_start), 0.15))
# y values sized from right_end itself; the old code used right_start's shape.
ax.scatter(right_end, np.full(len(right_end), 0.17))
ax.legend()

In [None]:
n_dir=8  # number of motion directions (one panel per entry of `titles` below)
n_sessions = 4  # repetitions stored per direction; any further repetitions are dropped
num_left_trials = np.shape(left_start)[0]
num_right_trials = np.shape(right_start)[0]
# Window length in samples: two (pause + moving) periods.
len_segment = (pause_duration + stim_duration) * 2
print(len_segment)

# Trial arrays indexed (direction, ROI, repetition, time). They are
# zero-initialised; the extraction cells below set unfilled slots to NaN.
left_trials = np.zeros((n_dir, n_left_tuned, n_sessions, len_segment))
right_trials = np.zeros((n_dir, n_right_tuned, n_sessions, len_segment))

In [None]:
# For each leftward window, label it with the first regressor column whose mean
# over the window's first moving period exceeds 0.1. That is the stimulus one
# period before the leftward onset, presumably the point of the
# history-dependence analysis. Then store the window for every left-tuned ROI.
regs_array = np.asarray(regs)
curr_session = np.zeros((n_dir), dtype=int)
for i, seg_start in enumerate(left_start):
    seg_stop = seg_start + len_segment
    # Reject windows outside the recording explicitly. A negative start would
    # otherwise wrap around to the end of the array instead of failing.
    if seg_start < 0 or seg_stop > left_traces.shape[1]:
        print(f"Stupid trial {i}: window [{seg_start}, {seg_stop}) outside recording")
        continue

    curr_seg = np.nanmean(regs_array[seg_start:seg_start + stim_duration], axis=0)
    active = np.flatnonzero(curr_seg > 0.1)
    if active.size == 0 or active[0] >= n_dir:
        print(f"Stupid trial {i}: no direction regressor above 0.1 in its first moving period")
        continue
    curr_dir = active[0]

    if curr_session[curr_dir] < n_sessions:
        left_trials[curr_dir, :, curr_session[curr_dir], :] = left_traces[:, seg_start:seg_stop]
        curr_session[curr_dir] += 1

# Mark repetitions that were never filled as missing. The old `== 0 -> 'nan'`
# replacement also turned genuine zero samples into NaN.
for d in range(n_dir):
    left_trials[d, :, curr_session[d]:, :] = np.nan

In [None]:
# One heat map per direction. Rows are left-tuned ROIs, columns are time within
# the window, and values are the mean over repetitions (NaN repetitions ignored).
# The extent values appear to be only a square display box, not a physical
# scale - confirm.
n_col=4
titles = ['right', 'backward right', 'backward', 'backward left', 'left', 'forward left', 'forward', 'forward right']
fig1, ax1 = plt.subplots(2,n_col, figsize=(10,4))
for i in range(n_dir):  # was range(8); keep in step with n_dir like the right-tuned figure
    r, c = divmod(i, n_col)
    ax1[r,c].imshow(np.nanmean(left_trials[i], axis=1), cmap='coolwarm', vmin=-1, vmax=2, extent=[0,50,0,50])
    ax1[r,c].set_title(titles[i])
    ax1[r,c].axis('off')

fig1.suptitle(f'Leftward tuned (n={n_left_tuned})')

In [None]:
# Save the leftward-tuned heat maps next to the fish data.
# NOTE(review): there is no separator between "v31" and the threshold, so
# thresh=0.5 produces "...v310.5.jpg".
file_name = "leftward tuned history dependence v31" + str(thresh) + ".jpg"
fig1.savefig(path / file_name, dpi=300)

In [None]:
# Same extraction as for the leftward-tuned ROIs, applied to rightward windows
# and right-tuned ROIs. Uses regs_array from the leftward extraction cell.
curr_session = np.zeros((n_dir), dtype=int)
for i, seg_start in enumerate(right_start):
    seg_stop = seg_start + len_segment
    # Reject windows outside the recording explicitly. A negative start would
    # otherwise wrap around to the end of the array instead of failing.
    if seg_start < 0 or seg_stop > right_traces.shape[1]:
        print(f"Stupid trial {i}: window [{seg_start}, {seg_stop}) outside recording")
        continue

    curr_seg = np.nanmean(regs_array[seg_start:seg_start + stim_duration], axis=0)
    active = np.flatnonzero(curr_seg > 0.1)
    if active.size == 0 or active[0] >= n_dir:
        print(f"Stupid trial {i}: no direction regressor above 0.1 in its first moving period")
        continue
    curr_dir = active[0]

    if curr_session[curr_dir] < n_sessions:
        right_trials[curr_dir, :, curr_session[curr_dir], :] = right_traces[:, seg_start:seg_stop]
        curr_session[curr_dir] += 1

# Mark repetitions that were never filled as missing. The old `== 0 -> 'nan'`
# replacement also turned genuine zero samples into NaN.
for d in range(n_dir):
    right_trials[d, :, curr_session[d]:, :] = np.nan

In [None]:
# Per-direction heat maps for the right-tuned ROIs. Each panel shows ROI x
# time, averaged over repetitions while ignoring NaN repetitions.
fig2, ax2 = plt.subplots(2,n_col, figsize=(10,4))
for i in range(n_dir):
    r = i // n_col
    c = np.mod(i, n_col)
    ax2[r,c].imshow(np.nanmean(right_trials[i], axis=1), cmap='coolwarm', vmin=-1, vmax=2, extent=[0,50,0,50])
    ax2[r,c].set_title(titles[i])
    ax2[r,c].axis('off')

fig2.suptitle('Rightward tuned (n=' + str(n_right_tuned) + ')')

In [None]:
# Save the rightward-tuned heat maps. This must run before the anatomy cell
# further down, which reassigns fig2.
file_name = "rightward tuned history dependence v31" + str(thresh) + ".jpg"
fig2.savefig(path / file_name, dpi=300)

In [None]:
####### Concatenate average responses and cluster
def _concat_direction_means(trials):
    """Average over repetitions and lay the per-direction means side by side.

    ``trials`` is (n_dir, n_rois, n_sessions, len_segment). The result is
    (n_rois, n_dir * len_segment), where columns
    ``d * len_segment : (d + 1) * len_segment`` hold
    ``np.nanmean(trials[d], axis=1)``.
    """
    per_dir_means = np.nanmean(trials, axis=2)  # (n_dir, n_rois, len_segment)
    return np.hstack(list(per_dir_means))


# Replaces two copy-pasted loops (one with the typo `lef_trials_avg`).
left_trials_concat = _concat_direction_means(left_trials)
right_trials_concat = _concat_direction_means(right_trials)

In [None]:
# Concatenated direction averages: left-tuned ROIs on top, right-tuned below.
# NOTE(review): the extent values look like display scaling, not time units.
fig3, ax3 = plt.subplots(2,1, figsize=(10,4))
ax3[0].imshow(left_trials_concat, cmap='coolwarm', vmin=-1, vmax=2,extent=[0,50,0,10])
ax3[1].imshow(right_trials_concat, cmap='coolwarm', vmin=-1, vmax=2,extent=[0,50,0,10])

In [None]:
# Save the concatenated-average figure. Fixes the filename typo
# ("right anf lest" -> "right and left").
file_name = "right and left history dependence v31" + str(thresh) + ".jpg"
fig3.savefig(path / file_name, dpi=300)

In [None]:
# Persist the per-direction trial arrays and their concatenated averages.
d = {
    'concat_reordered_left_tuned_avg': left_trials_concat,
     'concat_reordered_right_tuned_avg': right_trials_concat,
    'reordered_trials_left_tuned': left_trials,
    'reordered_trials_right_tuned': right_trials,
}
fl.save(path / 'reordered_traces.h5', d)

In [None]:
# Input to clustering: right-tuned concatenated averages, used as-is.
# NOTE(review): despite the name, no normalisation is applied here. Any NaN
# (e.g. a direction with no filled repetition) will make linkage fail - confirm.
norm_traces = right_trials_concat

In [None]:
# Hierarchical clustering: Ward linkage on the rows of norm_traces. The full
# dendrogram is drawn for inspection.
linked = linkage(norm_traces, method='ward')
dend = dendrogram(linked)

In [None]:
k = 8  # number of clusters
# Ward linkage always uses Euclidean distance, so no distance argument is
# passed. `affinity=` was deprecated in scikit-learn 1.2 and removed in 1.4
# (renamed `metric=`); omitting it works on old and new versions alike.
cluster = AgglomerativeClustering(n_clusters=k, linkage='ward')
# NOTE(review): her_clustering is not used by the cells below; the plotted
# labels come from the scipy dendrogram cut.
her_clustering = cluster.fit_predict(norm_traces)

In [None]:
# Truncate the tree to its last k merged clusters ("lastp") to get leaf order
# and labels. no_plot=True skips drawing, replacing the old workaround of
# creating a 0.1 x 0.1 inch figure and closing it again.
n_clust = k
dendro = dendrogram(linked, n_clust, truncate_mode="lastp", no_plot=True)
cluster_ids = dendro["leaves"]
# labels[row] = left-to-right index of the truncated leaf containing that row.
labels = find_trunc_dendro_clusters(linked, dendro)

In [None]:
# Baseline-subtracted copy: each row minus the nanmean of its first 8 samples.
# NOTE(review): X / base_sub_mean are not used by the cells below; the cluster
# plot is fed meanresps (not baseline-subtracted).
meanresps = norm_traces
base_sub_mean = (meanresps.T - np.nanmean(meanresps[:,:8], 1)).T
X = base_sub_mean

In [None]:
# stim is a placeholder array; the plotting function does not draw a stimulus
# trace from it.
stim = np.asarray([[1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]])
fig_clust = plot_clusters_dendro(meanresps, stim, linked, labels)#, dendrolims=(940, 0))

In [None]:
# Save the cluster summary figure; the filename encodes k and the tuning threshold.
file_name = 'hierarchical clustering right tuned neurons concat k' + str(k) + 'thresh ' + str(thresh) +  '.jpg'
fig_clust.savefig(str(path / file_name), dpi=300)

In [None]:
# Coordinates of the right-tuned ROIs, row-aligned with norm_traces and labels.
coords_tuned = coords[right_tuned]

In [None]:
# Three orthogonal projections of ROI positions. All in-brain ROIs are light
# grey; right-tuned ROIs are coloured by cluster label. coords column 0 is
# scaled by z_res and columns 1-2 by xy_res.
z_res = 10    # scale for coords[:, 0] (plane index) - units presumably um, confirm
xy_res = 0.6  # scale for coords[:, 1:] (in-plane pixels) - units presumably um, confirm
# fig2/ax2 are reused because the save cell below writes fig2. The old
# chained `= fig, axs =` assignment also clobbered `fig` from the onset cell.
fig2, ax2 = plt.subplots(2, 2, figsize=(8, 5), gridspec_kw={'width_ratios': [3, 1], 'height_ratios': [1, 3]})
color_list = ["lightblue", "lightcoral", "orange", "springgreen", "deepskyblue", "mediumpurple","gold", "cyan", "crimson",
              "deeppink", "lawngreen", "darkviolet", "Darkgreen", "blue", "brown", "dodgerblue", "hotpink", "OliveDrab", "gray", "seagreen"][0:k]

ax2[1,0].scatter(coords[:,2]*xy_res, coords[:,1]*xy_res, c='lightgray', s=2)
ax2[1,1].scatter(coords[:,0]*z_res, coords[:,1]*xy_res, c='lightgray')
ax2[0,0].scatter(coords[:,2]*xy_res, coords[:,0]*z_res, c='lightgray')

for i in range(k):
    in_cluster = np.where(labels == i)[0]
    ax2[1,0].scatter(coords_tuned[in_cluster, 2]*xy_res, coords_tuned[in_cluster, 1]*xy_res, c=color_list[i], s=4)
    ax2[1,1].scatter(coords_tuned[in_cluster, 0]*z_res, coords_tuned[in_cluster, 1]*xy_res, c=color_list[i])
    ax2[0,0].scatter(coords_tuned[in_cluster, 2]*xy_res, coords_tuned[in_cluster, 0]*z_res, c=color_list[i])

for i in range(2):
    ax2[i,0].spines['right'].set_visible(False)
    ax2[i,0].spines['top'].set_visible(False)
    ax2[0,i].spines['right'].set_visible(False)
    ax2[0,i].spines['top'].set_visible(False)

ax2[0,1].axis('off')

In [None]:
# Save the anatomical cluster map (fig2 as reassigned by the anatomy cell).
file_name = 'hrc right tuned k' + str(k) + 'thresh ' + str(thresh) + '.jpg'
fig2.savefig(str(path / file_name), dpi=300)

In [None]:
# getting stimulus information
interp_theta = fl.load(path / "sensory_regressors.h5", "/individual_theta_interp")
# Durations in seconds. The _s suffix keeps this cell from overwriting
# pause_duration, which earlier cells hold in samples (protocol value * fs).
trial_duration_s = 10  # sec
pause_duration_s = 10  # sec

In [None]:
# start by choosing only left/ right tuned neurons
# TODO: unfinished. The original lines were incomplete assignments
# ("left reg =" / "left_right =") that made this cell a SyntaxError.
# left_reg = ...
# left_right = ...

In [None]:
# get the trial start times for each of the 8 directions
# get the number of trials

In [None]:
######################### Part 2 - looking for neurons that reliably respond to the visual stimulus
# selecting reliable neurons

# NOTE(review): n_blocks is not defined in any cell above; set it (number of
# stimulus blocks) before running this cell.
stim_traces = np.copy(traces)  # time on axis 0, ROIs on axis 1
# Each block is len_rec // (2 * n_blocks) samples long. Only the first n_blocks
# blocks (the first half of the recording) are collected - TODO confirm intent.
new_len_rec = len_rec // (2 * n_blocks)
print(np.shape(stim_traces))

# (ROI, block, time). A dead duplicate allocation with a different axis order
# was removed.
trial_traces = np.zeros((num_traces, n_blocks, new_len_rec))

for i in range(n_blocks):
    t1 = i * new_len_rec
    t2 = t1 + new_len_rec
    # Slice along time (axis 0) and transpose to (ROI, time). The old code
    # sliced the ROI axis with time indices.
    trial_traces[:, i] = stim_traces[t1:t2, :].T