In [None]:
# Select the interactive widget (ipympl) backend for matplotlib figures.
%matplotlib widget

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import json
import pandas as pd

import flammkuchen as fl
from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from sklearn.cluster import AgglomerativeClustering

from skimage import color
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
#import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm

from sklearn.cluster import AgglomerativeClustering
from fimpylab.core.twop_experiment import TwoPExperiment

In [None]:
def cluster_id_search(tree):
    """Collect the ids of all leaves below ``tree``, left subtree first.

    The traversal uses an explicit stack instead of recursion, so deep or
    strongly unbalanced trees do not hit Python's recursion limit, and no
    intermediate lists are concatenated.

    Parameters
    ----------
    tree : node object
        Must provide ``is_leaf()``, ``get_id()``, ``get_left()`` and
        ``get_right()`` (the methods this function calls).

    Returns
    -------
    list
        Leaf ids in left-to-right order.
    """
    nodes_list = []
    pending = [tree]
    while pending:
        node = pending.pop()
        if node.is_leaf():
            nodes_list.append(node.get_id())
        else:
            # Push right first so the left child is popped (visited) first.
            pending.append(node.get_right())
            pending.append(node.get_left())

    return nodes_list

def find_trunc_dendro_clusters(linkage_mat, dendro):
    """Label every observation with the position of its truncated-dendrogram leaf.

    Parameters
    ----------
    linkage_mat : ndarray
        Linkage matrix; the number of observations is taken to be
        ``linkage_mat.shape[0] + 1``.
    dendro : dict
        Dendrogram output whose ``"leaves"`` entry lists node ids that index
        into the node list returned by ``to_tree(..., rd=True)``.

    Returns
    -------
    ndarray of int
        For each observation, the index (within ``dendro["leaves"]``) of the
        node it sits under.
    """
    _root, node_by_id = to_tree(linkage_mat, rd=True)
    n_observations = linkage_mat.shape[0] + 1
    ids = np.empty(n_observations, dtype=int)

    for position, node_id in enumerate(dendro["leaves"]):
        members = cluster_id_search(node_by_id[node_id])
        ids[members] = position

    return ids

def shade_plot(stim, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):
    """Shade an axes according to a stimulus given in one of several forms.

    Parameters
    ----------
    stim : list, Data, ndarray or tuple
        * list: used directly as the transitions passed to ``_shade_plot``;
        * ``Data``: transitions are computed from ``stim.resampled_stim`` and
          ``stim.time_im_rep``;
        * ndarray: column 1 and column 0 are passed to ``find_transitions``;
        * tuple: element 1 and element 0 are passed to ``find_transitions``.
        Any other type is ignored and nothing is drawn.
    ax, gamma, shade_range
        Forwarded unchanged to ``_shade_plot``.

    Notes
    -----
    ``Data`` and ``find_transitions`` are not defined in the visible part of
    this notebook; they must be available in the session for the non-list
    branches to work (NOTE(review): confirm where they come from).
    """
    if isinstance(stim, list):  # these would be transitions
        transitions = stim
    elif isinstance(stim, Data):  # fish data
        # Read the attributes from the instance that was passed in, not from
        # the Data class itself.
        transitions = find_transitions(stim.resampled_stim, stim.time_im_rep)
    elif isinstance(stim, np.ndarray):  # stimulus array
        transitions = find_transitions(stim[:, 1], stim[:, 0])
    elif isinstance(stim, tuple):  # time, lum tuple
        transitions = find_transitions(stim[1], stim[0])
    else:
        # Unsupported input: keep the historical behaviour of doing nothing.
        return

    _shade_plot(transitions, ax=ax, gamma=gamma, shade_range=shade_range)


def _shade_plot(lum_transitions, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):

    if ax is None:
        ax = plt.gca()
    shade = lum_transitions[0][1]
    for i in range(len(lum_transitions)-1):
        shade = shade + lum_transitions[i][1]
        new_shade = shade_range[0] + np.power(np.abs(shade), gamma) * (shade_range[1] - shade_range[0])
        ax.axvspan(lum_transitions[i][0], lum_transitions[i+1][0], color=(new_shade, )*3)
        

def _find_thr(linked, n_clust):
    interval = [0, 2000]
    new_height = np.mean(interval)
    clust = 0
    n_clust = n_clust
    while clust != n_clust:
        new_height = np.mean(interval)
        clust = cut_tree(linked, height=new_height).max()
        if clust > n_clust:
            interval[0] = new_height
        elif clust < n_clust:
            interval[1] = new_height


    return new_height


def find_plot_thr(linked, n_clust):
    """Return the height found by ``_find_thr`` for a target of ``n_clust - 1``.

    ``_find_thr`` matches against the maximum label returned by ``cut_tree``,
    hence the offset of one relative to ``n_clust``.
    """
    return _find_thr(linked, n_clust - 1)


In [None]:
def plot_clusters_dendro(traces, stim, linkage_mat, labels, dendrolims=(900, 30),
                         thr=None, f_lim=1.5, gamma=1, fish_id=""):
    """Plot sorted traces, per-cluster mean traces and the dendrogram.

    Layout (2 x 2 grid): top-left the traces matrix ordered by the dendrogram
    leaves, top-right the mean trace of every cluster (stacked with an offset
    of 5), bottom row the dendrogram with a horizontal line at ``thr``.

    Parameters
    ----------
    traces : ndarray
        2D array; rows are reordered with the dendrogram leaves and averaged
        per cluster along axis 0.
    stim : any
        Accepted for backward compatibility; currently unused.
    linkage_mat : ndarray
        Linkage matrix passed to ``dendrogram``.
    labels : ndarray of int
        Cluster label per row of ``traces``; ``labels.max() + 1`` clusters
        are plotted.
    dendrolims : tuple
        Accepted for backward compatibility; currently unused.
    thr : float, optional
        Colour threshold for the dendrogram. Computed with ``find_plot_thr``
        when omitted.
    f_lim : float
        Symmetric colour limit (``vmin=-f_lim``, ``vmax=f_lim``) of the
        traces image.
    gamma : float
        Accepted for backward compatibility; currently unused.
    fish_id : str
        Accepted for backward compatibility; currently unused.

    Returns
    -------
    fig_clust : Figure
        The created figure.
    ax_clusters : Axes
        The axes holding the cluster mean traces.

    Notes
    -----
    ``set_link_color_palette`` changes scipy's global dendrogram palette and
    the palette is left set after this function returns.
    """
    # Build the figure without a default axes grid: the panels below are
    # the only axes it should contain.
    fig_clust = plt.figure(figsize=(15, 15))
    hexac = cluster_cols()

    n_clust = labels.max() + 1

    if thr is None:
        thr = find_plot_thr(linkage_mat, n_clust)

    ax_traces = plt.subplot2grid((2, 2), (0, 0), fig=fig_clust)
    ax_clusters = plt.subplot2grid((2, 2), (0, 1), fig=fig_clust)
    ax_dendro = plt.subplot2grid((2, 2), (1, 0), colspan=2, fig=fig_clust)

    ##################
    ### Dendrogram ###
    set_link_color_palette(hexac)
    panel_dendro = dendrogram(linkage_mat,
                              color_threshold=thr,
                              distance_sort='descending',
                              show_leaf_counts=False,
                              no_labels=True,
                              above_threshold_color='#%02x%02x%02x' % (
                              120, 120, 120),
                              ax=ax_dendro)  # draw on the intended panel explicitly

    ax_dendro.axhline(thr, linewidth=0.7, color="k")
    ax_dendro.axis("off")

    ##################
    # Traces matrix ##
    # Rows follow the leaf order of the dendrogram drawn above.
    ax_traces.imshow(traces[panel_dendro["leaves"], :],
                     aspect='auto', origin='lower', cmap="gray_r",
                     vmin=-f_lim, vmax=f_lim)
    ax_traces.axis("off")

    ##################
    # Cluster means ##
    for i in range(n_clust):
        # Cycle through the palette so more clusters than colours still plot.
        ax_clusters.plot(np.nanmean(traces[labels == i, :], 0) + i * 5,
                         label=i, color=hexac[i % len(hexac)])
    ax_clusters.axis("off")

    return fig_clust, ax_clusters

def cluster_cols():
    """Return a new list with the eight hex colours used for clusters."""
    # Alternative palettes kept for reference:
    # ["lightblue", "lightcoral", "orange", "springgreen", "deepskyblue", "mediumpurple","gold", "cyan", "crimson", "deeppink", "lawngreen", "darkviolet", "orchid", "limegreen", "seagreen", "chocolate", "blue", "navy"]
    # ["#ff5c67", "#af0006", "#ffa468", "#8c5f00", "#e4a400", "#d5c86f", "#939400", "#a7d380", "#138b00", "#42e087", "#00a86d", "#81c7a8", "#019a82", "#1eaaff", "#0268bb", "#5951d7", "#6b4570", "#ad20aa", "#ffa1e2", "#ff4a94"]
    palette = (
        "#cc566a",
        "#cd6c39",
        "#a39440",
        "#64ac48",
        "#4aac8d",
        "#688bcd",
        "#8562cc",
        "#c361aa",
    )
    return list(palette)

In [None]:
# Experiment root, the fish to analyse (first match of the glob) and the
# number of clusters used throughout the notebook.
master = Path(r"\\Funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\2p\rf\fixed")
all_fish = list(master.glob("*f[0-9]*"))
if not all_fish:
    raise FileNotFoundError("No fish folders matching '*f[0-9]*' found in %s" % master)
fish_dir = all_fish[0]
path = fish_dir / 'suite2p' / '0001'
n_clust = 8

# The fish id is optional: fall back to an empty string when the metadata
# file is missing, unreadable or lacks the expected keys. Only those failure
# modes are caught, so interrupts and unrelated errors still surface.
try:
    with open(next(fish_dir.glob("*metadata.json"))) as metadata_file:
        metadata = json.load(metadata_file)
    fish_id = metadata['general']['fish_id']
except (StopIteration, OSError, ValueError, KeyError, TypeError):
    fish_id = ""
print(fish_dir)
print(fish_id)

In [None]:
# Load the traces stored under the "traces" key and report their shape.
suite2p_data = fl.load(path / "data_from_suite2p_unfiltered.h5")
print(np.shape(suite2p_data["traces"]))

# normalizing traces: transpose, subtract the mean and divide by the standard
# deviation along axis 0, then transpose back so norm_traces has the
# orientation of the loaded array.
traces_all = suite2p_data["traces"].T
traces_all = (traces_all - traces_all.mean(0)) / traces_all.std(0)
norm_traces = traces_all.T
num_rois = norm_traces.shape[0]

In [None]:
# Ward linkage over the normalized traces, drawn as a full dendrogram.
linked = linkage(norm_traces, method='ward')
fig1 = plt.figure(figsize=(2, 2))
dend = dendrogram(linked)
plt.show()

In [None]:
# Ward linkage only supports the Euclidean metric, which is also the default,
# so the metric is not passed explicitly. The former `affinity` keyword is
# deprecated and no longer accepted by newer scikit-learn releases.
cluster = AgglomerativeClustering(n_clusters=n_clust, linkage='ward')
her_clustering = cluster.fit_predict(norm_traces)

In [None]:
# Only the dendrogram's data is needed here (the leaves of the tree truncated
# to its last n_clust merged clusters), so skip drawing altogether instead of
# rendering into a throwaway figure.
dendro = dendrogram(linked, p=n_clust, truncate_mode="lastp", no_plot=True)
cluster_ids = dendro["leaves"]
labels = find_trunc_dendro_clusters(linked, dendro)

In [None]:
# Subtract from every row its NaN-ignoring mean over the first 8 columns.
meanresps = norm_traces
base_sub_mean = np.transpose(meanresps.T - np.nanmean(meanresps[:, :8], axis=1))
X = base_sub_mean

In [None]:
# Placeholder stimulus array (two identical rows 1..7) handed to the plotting
# function.
stim = np.asarray([list(range(1, 8))] * 2)
fig_clust, ax_clust = plot_clusters_dendro(
    norm_traces, stim, linked, labels, fish_id=fish_id,
)  # optionally: dendrolims=(940, 0)

f = fish_dir
fs = 3
print(f)

# Stimulus shading is currently disabled: the experiment loading is commented
# out and the drawing loop is kept below as an inert string literal.
#exp = TwoPExperiment(path=f)
#stimulus_log = exp.stimulus_log(0)
#stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
#stim_value = stim_value / 255
#num_stim = np.shape(stim_value)[0]
'''
t_values *= fs
for i in range(num_stim):
    ax_clust.axvspan(
        t_values[i, 0],
        t_values[i, 1],
        facecolor=[
            stim_value[i, 0],
            stim_value[i, 1],
            stim_value[i, 2],
        ],
        alpha=0.3,
    )
'''

# Title the current figure, show it, then save it next to the suite2p output.
plt.suptitle(fish_id)
plt.show()
file_name = 'individual_hrc_k{}_'.format(n_clust) + fish_id + '_full.jpg'
fig_clust.savefig(str(path / file_name))