In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
import flammkuchen as fl
import pandas as pd
import tifffile as tiff

from fimpylab import LightsheetExperiment

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()
import ipywidgets as widgets

from lotr.utils import zscore
from lotr.pca import pca_and_phase, get_fictive_heading, fictive_heading_and_fit, \
        fit_phase_neurons,qap_sorting_and_phase
from circle_fit import hyper_fit
from lotr.experiment_class import LotrExperiment
import json

from lotr.plotting.color_utils import get_n_colors
from scipy.interpolate import interp1d
from scipy import signal

In [None]:
from lotr.default_vals import REGRESSOR_TAU_S, TURN_BIAS
master = Path(r"\\Funes\Shared\experiments\E0071_lotr\full_ring")
files_clol = list(master.glob("*/*_f*_clol"))

# Per-fish arrays below are zero-padded to num_fish entries because later cells
# iterate over range(num_fish).
num_fish = 7

# Keep only fish folders that contain a selected.h5 file. For each kept fish store
# its detrended traces (transposed so that rows are ROIs) and its index in files_clol.
trace_blocks = []
kept_file_inds = []
for file_ind, path in enumerate(files_clol):
    if not (path / "selected.h5").exists():
        continue
    trace_blocks.append(fl.load(path / "filtered_traces.h5", "/detr").T)
    kept_file_inds.append(file_ind)

num_fish_clol = len(trace_blocks)
if num_fish_clol == 0:
    raise FileNotFoundError(f"No *_clol folder with selected.h5 found under {master}")
if num_fish_clol > num_fish:
    # The fixed-size per-fish arrays below cannot hold more than num_fish entries.
    raise ValueError(f"Found {num_fish_clol} fish but num_fish is {num_fish}; increase num_fish.")

# One concatenation instead of np.append inside the loop (which copies every time).
all_rois_clol = np.concatenate(trace_blocks, axis=0)

fish_inds = np.zeros(num_fish)  # index into files_clol of each kept fish
num_rois = np.zeros(num_fish)   # number of ROIs (rows of all_rois_clol) per kept fish
fish_inds[:num_fish_clol] = kept_file_inds
num_rois[:num_fish_clol] = [block.shape[0] for block in trace_blocks]

print(num_fish_clol)
print(num_rois)
print(fish_inds)

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from sklearn.cluster import AgglomerativeClustering

In [None]:
def cluster_id_search(tree):
    """Return the ids of all leaves below ``tree`` in left-to-right order.

    Uses an explicit stack instead of recursion: a linkage tree can be as deep
    as the number of observations, which overflows Python's recursion limit
    for chained trees with more than ~1000 ROIs.

    Parameters
    ----------
    tree : scipy.cluster.hierarchy.ClusterNode
        Root of the (sub)tree to search.

    Returns
    -------
    list of int
        Leaf ids, left subtree before right subtree.
    """
    leaf_ids = []
    stack = [tree]
    while stack:
        node = stack.pop()
        if node.is_leaf():
            leaf_ids.append(node.get_id())
        else:
            # Push the right child first so the left child is visited first.
            stack.append(node.get_right())
            stack.append(node.get_left())
    return leaf_ids

def find_trunc_dendro_clusters(linkage_mat, dendro):
    """Label every observation with the position of its leaf in a truncated dendrogram.

    Parameters
    ----------
    linkage_mat : ndarray
        Linkage matrix (one row per merge, so ``shape[0] + 1`` observations).
    dendro : dict
        Output of ``scipy.cluster.hierarchy.dendrogram`` (e.g. with
        ``truncate_mode="lastp"``); its ``"leaves"`` entries are node ids.

    Returns
    -------
    ndarray of int, shape (n_observations,)
        ``ids[k] == i`` when observation ``k`` lies under ``dendro["leaves"][i]``.
        Observations not covered by any leaf are marked ``-1``.
    """
    _, branches = to_tree(linkage_mat, rd=True)
    # -1 instead of np.empty so uncovered observations cannot carry garbage labels.
    ids = np.full(linkage_mat.shape[0] + 1, -1, dtype=int)
    for leaf_pos, node_id in enumerate(dendro["leaves"]):
        # pre_order() returns the observation ids under this node (iteratively).
        ids[branches[node_id].pre_order()] = leaf_pos
    return ids

In [None]:
# Cluster the traces to see whether responses are more motor (fish-specific)
# or visual (similar across fish).

n_clust = 14
fig1 = plt.figure(figsize=(10, 7))
linked = linkage(all_rois_clol, method='ward')
dend = dendrogram(linked)

# Ward linkage is always Euclidean, so no metric argument is passed: `affinity`
# was deprecated in scikit-learn 1.2 and removed in 1.4, and `metric` does not
# exist before 1.2. Omitting it works on every version.
cluster = AgglomerativeClustering(n_clusters=n_clust, linkage='ward')
her_clustering = cluster.fit_predict(all_rois_clol)


In [None]:
# Truncate the tree to its last n_clust clusters to get per-ROI cluster ids
# whose order matches the left-to-right leaf order of the cut.
# no_plot=True computes the dendrogram layout without drawing a figure.
dendro = dendrogram(linked, n_clust, truncate_mode="lastp", no_plot=True)
cluster_ids = dendro["leaves"]
labels = find_trunc_dendro_clusters(linked, dendro)

In [None]:
# Remove each ROI's NaN-ignoring temporal mean from its trace (row-wise centring).
meanresps = all_rois_clol
roi_means = np.nanmean(meanresps, axis=1, keepdims=True)
base_sub_mean = meanresps - roi_means
X = base_sub_mean

In [None]:
from skimage import color
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
#import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm

In [None]:
def shade_plot(stim, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):
    """Shade the background of ``ax`` according to stimulus luminance transitions.

    Parameters
    ----------
    stim : list, numpy.ndarray, tuple, or data object
        - list: precomputed transitions, passed directly to ``_shade_plot``.
        - ndarray: stimulus array with time in column 0 and luminance in column 1.
        - tuple: ``(time, lum)`` pair.
        - any object exposing ``resampled_stim`` and ``time_im_rep`` attributes
          (fish data object).
    ax : matplotlib Axes, optional
        Target axes; ``_shade_plot`` falls back to the current axes.
    gamma, shade_range :
        Forwarded to ``_shade_plot``.

    Raises
    ------
    TypeError
        If ``stim`` is none of the supported kinds.

    NOTE(review): ``find_transitions`` is not defined in this notebook; every
    branch except the ``list`` one needs it to be in scope.
    """
    if isinstance(stim, list):  # already transitions
        transitions = stim
    elif isinstance(stim, np.ndarray):  # stimulus array
        transitions = find_transitions(stim[:, 1], stim[:, 0])
    elif isinstance(stim, tuple):  # (time, lum) tuple
        transitions = find_transitions(stim[1], stim[0])
    elif hasattr(stim, "resampled_stim") and hasattr(stim, "time_im_rep"):
        # Fish data object: read the attributes of the instance that was passed in
        # (the previous check referenced an undefined `Data` class).
        transitions = find_transitions(stim.resampled_stim, stim.time_im_rep)
    else:
        raise TypeError(f"Unsupported stim type for shade_plot: {type(stim).__name__}")

    _shade_plot(transitions, ax=ax, gamma=gamma, shade_range=shade_range)


def _shade_plot(lum_transitions, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):

    if ax is None:
        ax = plt.gca()
    shade = lum_transitions[0][1]
    for i in range(len(lum_transitions)-1):
        shade = shade + lum_transitions[i][1]
        new_shade = shade_range[0] + np.power(np.abs(shade), gamma) * (shade_range[1] - shade_range[0])
        ax.axvspan(lum_transitions[i][0], lum_transitions[i+1][0], color=(new_shade, )*3)
        

def _find_thr(linked, n_clust):
    interval = [0, 2000]
    new_height = np.mean(interval)
    clust = 0
    n_clust = n_clust
    while clust != n_clust:
        new_height = np.mean(interval)
        clust = cut_tree(linked, height=new_height).max()
        if clust > n_clust:
            interval[0] = new_height
        elif clust < n_clust:
            interval[1] = new_height


    return new_height


def find_plot_thr(linked, n_clust):
    """Return a cut height splitting ``linked`` into ``n_clust`` clusters.

    ``_find_thr`` targets the maximum ``cut_tree`` label, which is one less than
    the number of clusters, hence the ``- 1``.
    """
    max_label = n_clust - 1
    return _find_thr(linked, max_label)


In [None]:
def plot_clusters_dendro(traces, stim, linkage_mat, labels, dendrolims=(900, 30),
                         thr=None, f_lim=1.5, gamma=1):
    """Plot a dendrogram, the leaf-sorted trace matrix and per-cluster mean traces.

    Layout on a 2 x 2 grid:
    - top-left: traces heatmap sorted by dendrogram leaf order,
    - top-right: mean trace of each cluster, stacked with a vertical offset of 5,
    - bottom (both columns): dendrogram with a horizontal line at ``thr``.

    Parameters
    ----------
    traces : ndarray
        Rows are indexed by dendrogram leaf ids and by ``labels``; columns are
        plotted along the x axis.
    stim : array-like
        Not used for drawing; kept for backward compatibility.
    linkage_mat : ndarray
        Linkage matrix computed from ``traces``.
    labels : ndarray of int
        Cluster label per row of ``traces`` (``0 .. labels.max()``).
    dendrolims, gamma :
        Unused; kept for backward compatibility.
    thr : float, optional
        Dendrogram colour threshold; computed with ``find_plot_thr`` when None.
    f_lim : float
        Symmetric colour limit of the heatmap.

    Returns
    -------
    fig_clust : matplotlib.figure.Figure
    ax_clusters : matplotlib.axes.Axes
        Axes holding the cluster mean traces.
    """
    hexac = cluster_cols()
    n_clust = labels.max() + 1

    if thr is None:
        thr = find_plot_thr(linkage_mat, n_clust)

    # Create exactly the three axes that are drawn on. The former
    # plt.subplots(3, 1) + plt.subplot2grid combination left three unused axes
    # underneath these on matplotlib >= 3.6 (older versions deleted overlapping axes).
    fig_clust = plt.figure(figsize=(10, 10))
    grid = fig_clust.add_gridspec(2, 2)
    ax_traces = fig_clust.add_subplot(grid[0, 0])
    ax_clusters = fig_clust.add_subplot(grid[0, 1])
    ax_dendro = fig_clust.add_subplot(grid[1, :])

    ### Dendrogram ###
    set_link_color_palette(hexac)
    panel_dendro = dendrogram(linkage_mat,
                              ax=ax_dendro,
                              color_threshold=thr,
                              distance_sort='descending',
                              show_leaf_counts=False,
                              no_labels=True,
                              above_threshold_color='#%02x%02x%02x' % (
                              120, 120, 120))
    ax_dendro.axhline(thr, linewidth=0.7, color="k")
    ax_dendro.axis("off")

    ### Traces matrix, rows in dendrogram leaf order ###
    ax_traces.imshow(traces[panel_dendro["leaves"], :],
                     aspect='auto', origin='lower', cmap=cm.RdBu_r,
                     vmin=-f_lim, vmax=f_lim)
    ax_traces.axis("off")

    ### Cluster means ###
    # NOTE(review): cluster i is drawn with hexac[i], while the dendrogram assigns
    # palette colours in its own left-to-right order (distance_sort='descending');
    # confirm that these orders match the one used to build `labels`.
    for i in range(n_clust):
        ax_clusters.plot(np.nanmean(traces[labels == i, :], 0) + i * 5,
                         label=i, color=hexac[i % len(hexac)])
    ax_clusters.axis("off")

    return fig_clust, ax_clusters

def cluster_cols():
    """Return a new list of the 18 matplotlib colour names used to tint clusters."""
    palette = (
        "lightblue", "lightcoral", "orange", "springgreen", "deepskyblue",
        "mediumpurple", "gold", "cyan", "crimson", "deeppink", "lawngreen",
        "darkviolet", "orchid", "limegreen", "seagreen", "chocolate", "blue",
        "navy",
    )
    # A fresh list per call, so callers may mutate the result safely.
    return list(palette)

In [None]:
# Placeholder stimulus array; the clustering figure does not depend on it.
stim = np.asarray([[1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]])
fig_clust, ax_clust = plot_clusters_dendro(all_rois_clol, stim, linked, labels)

# Save before plt.show(): depending on the backend the figure may be closed once
# shown, and saving afterwards can write an empty image.
file_name = 'hierarchical_clustering_k' + str(n_clust) + '_220531.jpg'
fig_clust.savefig(str(master/file_name), dpi=300)
plt.show()

In [None]:
# Exclusive end index of each fish's ROI block in all_rois_clol.
ind2 = num_rois.cumsum().astype(int)
ind2

In [None]:
num_rois = np.asarray(num_rois, dtype=int)  # ROI counts as integers

In [None]:
# Start index of each fish's ROI block: ind1[k]:ind2[k] selects fish k.
ind1 = np.zeros(num_fish + 1, dtype=int)
ind1[1:] = ind2[:num_fish]
ind1

In [None]:
labels.shape


In [None]:
# ROI positions of each fish, coloured by hierarchical-cluster label.
col_list = cluster_cols()

num_row = 3
num_col = 3
fig1, ax1 = plt.subplots(num_row, num_col, figsize=(8, 8))

# ax1.flat walks the grid row by row, so panel i sits at (i // num_col, i % num_col).
for i, panel in enumerate(ax1.flat):
    panel.axis('off')
    # Panels beyond the loaded fish stay empty; fish_inds/ind1/ind2 are zero-padded
    # there and would otherwise point back at fish 0.
    if i >= num_fish_clol:
        continue
    path = files_clol[int(fish_inds[i])]
    coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")
    selected = fl.load(path / "selected.h5")
    # Labels index the selected ROIs, so take the selected coordinates first
    # (the same mapping as the cluster-colours cell).
    # NOTE(review): assumes filtered_traces rows correspond to coords[selected].
    sel_coords = coords[selected]
    col_hrc = labels[ind1[i]:ind2[i]]

    for j in range(n_clust):
        ind = np.flatnonzero(col_hrc == j)
        panel.scatter(sel_coords[ind, 1], sel_coords[ind, 2],
                      c=col_list[j % len(col_list)], s=3)

file_name = 'clusters_hrc_rois_k' + str(n_clust) + '_220531.jpg'
fig1.savefig(str(master/file_name), dpi=300)  # save before show (backend-safe)
plt.show()

In [None]:
# Colour ROIs by cluster with the same palette as the traces plot; all ROIs of
# the fish are drawn in grey underneath.
col_list = cluster_cols()

num_row = 3
num_col = 3
fig1, ax1 = plt.subplots(num_row, num_col, figsize=(8, 8))

# ax1.flat walks the grid row by row, so panel i sits at (i // num_col, i % num_col).
for i, panel in enumerate(ax1.flat):
    panel.axis('off')
    # Panels beyond the loaded fish stay empty; fish_inds/ind1/ind2 are zero-padded
    # there and would otherwise point back at fish 0.
    if i >= num_fish_clol:
        continue
    path = files_clol[int(fish_inds[i])]
    coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")
    selected = fl.load(path / "selected.h5")
    col_hrc = labels[ind1[i]:ind2[i]]

    panel.scatter(coords[:, 1], coords[:, 2], c='gray', s=3)
    sel_coords = coords[selected]
    for j in range(n_clust):
        ind = np.flatnonzero(col_hrc == j)
        panel.scatter(sel_coords[ind, 1], sel_coords[ind, 2],
                      c=col_list[j % len(col_list)], s=3)

file_name = 'clusters_hrc_rois_220531_k' + str(n_clust) + '_clustcolors.jpg'
fig1.savefig(str(master/file_name), dpi=300)  # save before show (backend-safe)
plt.show()