In [None]:
import matplotlib
%matplotlib widget
import matplotlib.pyplot as plt
import h5py
import numpy as np
from pathlib import Path
import skimage as sk
from skimage import io as skio
import json
import pandas as pd

import flammkuchen as fl
from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from sklearn.cluster import AgglomerativeClustering
from fimpylab.core.lightsheet_experiment import LightsheetExperiment


In [None]:
# Experiment folder; each fish lives in a sub-folder whose name contains "f<digit>".
master = Path(r"//Funes/Shared/experiments/E0040_motions_cardinal/v13_cw_ccw/ls_fixed")
# Sorted so that fish_list[0] refers to the same fish on every run (glob order is filesystem-dependent).
fish_list = sorted(master.glob("*f[0-9]*"))
if not fish_list:
    raise FileNotFoundError(f"No fish folders matching '*f[0-9]*' found in {master}")
fish_dir = fish_list[0]

# The fish id is optional: fall back to "" when the metadata file is missing, unreadable or lacks
# the key, without swallowing unrelated errors (e.g. KeyboardInterrupt).
try:
    with open(next(fish_dir.glob("*metadata.json"))) as metadata_file:
        metadata = json.load(metadata_file)
    fish_id = metadata['general']['fish_id']
except (StopIteration, OSError, KeyError, TypeError, json.JSONDecodeError):
    fish_id = ""
print(fish_dir)
print(fish_id)

## Cluster all traces:

In [None]:
# Load the per-ROI fluorescence traces exported from suite2p (cells only).
# (Alternative source: suite2p/combined/F.npy rows filtered with iscell.npy[:, 0] == 1.)
suite2p_data = fl.load(fish_dir / "data_from_suite2p_cells.h5")
traces_all = suite2p_data["traces"]
print(np.shape(traces_all))

# Z-score every trace (row) over time. Flat traces have zero std; dividing by it would give
# NaN rows (and a RuntimeWarning), so those rows are set to 0 instead.
trace_mean = traces_all.mean(axis=1, keepdims=True)
trace_std = traces_all.std(axis=1, keepdims=True)
centered = traces_all - trace_mean
norm_traces = np.divide(centered, trace_std, out=np.zeros_like(centered), where=trace_std > 0)
num_rois = np.shape(norm_traces)[0]

In [None]:
# Report how many entries of the normalized traces are NaN (e.g. from flat or NaN-containing traces).
# Only the last expression of a cell is displayed, so both values are printed explicitly.
print("NaN entries:", np.count_nonzero(np.isnan(norm_traces)))
print(np.shape(norm_traces))

In [None]:
# Remove bad traces from the last fish: zero out NaN samples (e.g. traces whose z-score divided by a
# zero std) because linkage() below rejects non-finite input.
norm_traces[np.isnan(norm_traces)] = 0

In [None]:
# Sanity check: list any NaNs still present (expected: none after zeroing them).
nan_positions = np.where(np.isnan(norm_traces))
print(np.shape(nan_positions))
print(nan_positions)

In [None]:
# Ward linkage of the normalized traces and the full dendrogram for visual inspection.
fig1, ax_full_dendro = plt.subplots(figsize=(10, 7))
linked = linkage(norm_traces, method='ward')
dend = dendrogram(linked, ax=ax_full_dendro)
plt.show()
# fig1.savefig(str(master / "dendrogrm_210104.jpg"))

In [None]:
# Reference flat clustering with scikit-learn; `her_clustering` is not used by the plots below.
# Ward linkage always works on Euclidean distances, so the metric is left at its default: the
# `affinity` keyword was deprecated in scikit-learn 1.2 and removed in 1.4.
cluster = AgglomerativeClustering(n_clusters=12, linkage='ward')
her_clustering = cluster.fit_predict(norm_traces)

### Helpers to recover cluster labels from a truncated dendrogram (from Ot & Luigi)

In [None]:
def cluster_id_search(tree):
    """Return the ids of all leaves below a scipy ``ClusterNode``, from left to right.

    Uses an explicit stack instead of recursion, so deep (unbalanced) linkage trees do not hit
    Python's recursion limit, and leaves are appended once instead of re-concatenating lists
    at every level.

    Parameters
    ----------
    tree : scipy.cluster.hierarchy.ClusterNode
        Root of the (sub)tree to traverse.

    Returns
    -------
    list of int
        Leaf ids (original observation indices) in left-to-right order.
    """
    leaf_ids = []
    stack = [tree]
    while stack:
        node = stack.pop()
        if node.is_leaf():
            leaf_ids.append(node.get_id())
        else:
            # Push right first so the left subtree is popped (visited) first.
            stack.append(node.get_right())
            stack.append(node.get_left())
    return leaf_ids

def find_trunc_dendro_clusters(linkage_mat, dendro):
    """Assign a cluster label to every observation from a truncated dendrogram.

    Parameters
    ----------
    linkage_mat : ndarray
        Linkage matrix (as returned by ``scipy.cluster.hierarchy.linkage``).
    dendro : dict
        Output of ``dendrogram(linkage_mat, p, truncate_mode="lastp")``; its ``"leaves"`` entry
        holds the ids of the (possibly non-singleton) clusters drawn as leaves, in plotting order.

    Returns
    -------
    ndarray of int
        ``ids[k]`` is the position in ``dendro["leaves"]`` of the cluster containing observation ``k``.

    Raises
    ------
    ValueError
        If the dendrogram leaves do not cover every observation (otherwise those observations
        would get meaningless labels).
    """
    _, branches = to_tree(linkage_mat, rd=True)
    n_obs = linkage_mat.shape[0] + 1
    ids = np.full(n_obs, -1, dtype=int)

    for label, node_id in enumerate(dendro["leaves"]):
        # pre_order() returns the observation ids under this node, iteratively (no recursion limit).
        ids[branches[node_id].pre_order()] = label

    if (ids < 0).any():
        raise ValueError("dendro['leaves'] does not cover every observation of linkage_mat")
    return ids

In [None]:
# Cut the tree into n_clust clusters. A truncated ("lastp") dendrogram is used rather than cut_tree
# so that cluster labels follow the left-to-right order of the truncated dendrogram's leaves.
# no_plot=True computes the leaf layout without drawing, so no throw-away figure is needed.
n_clust = 8
dendro = dendrogram(linked, p=n_clust, truncate_mode="lastp", no_plot=True)
cluster_ids = dendro["leaves"]
labels = find_trunc_dendro_clusters(linked, dendro)

In [None]:
# Baseline-subtract each trace with the mean of its first 8 samples (NaNs ignored).
meanresps = norm_traces
baseline = np.nanmean(meanresps[:, :8], axis=1, keepdims=True)
base_sub_mean = meanresps - baseline
X = base_sub_mean


In [None]:
from skimage import color
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
#import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm

In [None]:
def shade_plot(stim, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):
    """Shade an axis background according to a luminance stimulus.

    Parameters
    ----------
    stim : list, ndarray, tuple, or fish-data object
        * list: luminance transitions, passed to ``_shade_plot`` unchanged;
        * ndarray: column 0 is time, column 1 is luminance;
        * tuple: ``(time, luminance)``;
        * object exposing ``resampled_stim`` and ``time_im_rep`` attributes (fish data).
    ax : matplotlib Axes, optional
        Axis to shade; ``_shade_plot`` falls back to the current axis.
    gamma, shade_range :
        Forwarded to ``_shade_plot``.

    Raises
    ------
    TypeError
        If ``stim`` is none of the supported kinds.

    NOTE(review): ``find_transitions`` is not defined in the visible notebook code; the non-list
    branches need it in scope.
    """
    if isinstance(stim, list):  # already a list of transitions
        transitions = stim
    elif isinstance(stim, np.ndarray):  # stimulus array: (time, luminance) columns
        transitions = find_transitions(stim[:, 1], stim[:, 0])
    elif isinstance(stim, tuple):  # (time, luminance) tuple
        transitions = find_transitions(stim[1], stim[0])
    elif hasattr(stim, "resampled_stim") and hasattr(stim, "time_im_rep"):
        # Fish data object: duck-typed, reading the attributes from the instance that was passed in.
        transitions = find_transitions(stim.resampled_stim, stim.time_im_rep)
    else:
        raise TypeError(f"Unsupported stimulus type: {type(stim).__name__}")
    _shade_plot(transitions, ax=ax, gamma=gamma, shade_range=shade_range)


def _shade_plot(lum_transitions, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):

    if ax is None:
        ax = plt.gca()
    shade = lum_transitions[0][1]
    for i in range(len(lum_transitions)-1):
        shade = shade + lum_transitions[i][1]
        new_shade = shade_range[0] + np.power(np.abs(shade), gamma) * (shade_range[1] - shade_range[0])
        ax.axvspan(lum_transitions[i][0], lum_transitions[i+1][0], color=(new_shade, )*3)
        

def _find_thr(linked, n_clust):
    interval = [0, 2000]
    new_height = np.mean(interval)
    clust = 0
    n_clust = n_clust
    while clust != n_clust:
        new_height = np.mean(interval)
        clust = cut_tree(linked, height=new_height).max()
        if clust > n_clust:
            interval[0] = new_height
        elif clust < n_clust:
            interval[1] = new_height


    return new_height


def find_plot_thr(linked, n_clust):
    """Return a dendrogram colour threshold that separates ``n_clust`` clusters.

    ``_find_thr`` targets the largest ``cut_tree`` label, i.e. the cluster count minus one.
    """
    return _find_thr(linked, n_clust - 1)


In [None]:
def plot_clusters_dendro(traces, stim, linkage_mat, labels, dendrolims=(900, 30),
                         thr=None, f_lim=1.5, gamma=1, fish_id=""):
    """Plot sorted traces, per-cluster mean traces and the dendrogram in one figure.

    Layout on a 2 x 2 grid: top-left the traces matrix sorted by dendrogram leaf order,
    top-right the cluster means stacked vertically, bottom (full width) the dendrogram.

    Parameters
    ----------
    traces : ndarray
        One row per observation of ``linkage_mat``; columns are plotted along the x axis.
    stim : ndarray
        Not used by the plot; kept for backward compatibility.
    linkage_mat : ndarray
        Linkage matrix of ``traces``.
    labels : ndarray of int
        Cluster label (0 .. n_clust - 1) of every row of ``traces``.
    dendrolims, gamma, fish_id :
        Not used; kept for backward compatibility.
    thr : float, optional
        Dendrogram colour threshold; computed with ``find_plot_thr`` if omitted.
    f_lim : float
        Symmetric colour limit of the traces matrix.

    Returns
    -------
    (Figure, Axes)
        The figure and the cluster-means axis (callers draw stimulus spans on it).
    """
    hexac = cluster_cols()
    n_clust = labels.max() + 1

    if thr is None:
        thr = find_plot_thr(linkage_mat, n_clust)

    # Create only the three axes that are drawn on.
    fig_clust = plt.figure(figsize=(15, 15))
    grid = fig_clust.add_gridspec(2, 2)
    ax_traces = fig_clust.add_subplot(grid[0, 0])
    ax_clusters = fig_clust.add_subplot(grid[0, 1])
    ax_dendro = fig_clust.add_subplot(grid[1, :])

    # Dendrogram: links below thr use the cluster palette, links above it are grey.
    # Note: set_link_color_palette changes scipy's global palette for later dendrogram calls too.
    set_link_color_palette(hexac)
    panel_dendro = dendrogram(linkage_mat,
                              ax=ax_dendro,
                              color_threshold=thr,
                              distance_sort='descending',
                              show_leaf_counts=False,
                              no_labels=True,
                              above_threshold_color='#%02x%02x%02x' % (120, 120, 120))
    ax_dendro.axhline(thr, linewidth=0.7, color="k")
    ax_dendro.axis("off")

    # Traces matrix with rows in dendrogram leaf order.
    ax_traces.imshow(traces[panel_dendro["leaves"], :],
                     aspect='auto', origin='lower', cmap=cm.RdBu_r,
                     vmin=-f_lim, vmax=f_lim)
    ax_traces.axis("off")

    # Mean trace of each cluster, offset by 5 per cluster so they stack vertically.
    # Colours cycle if there are more clusters than palette entries.
    # NOTE(review): colours are chosen by label index, while the dendrogram assigns palette
    # colours in leaf order — confirm both orders agree for the labels passed in.
    for i in range(n_clust):
        ax_clusters.plot(np.nanmean(traces[labels == i, :], 0) + i * 5,
                         label=i, color=hexac[i % len(hexac)])
    ax_clusters.axis("off")

    return fig_clust, ax_clusters

def cluster_cols():
    """Return the 8 hex colours used for clusters, as a fresh list on every call."""
    palette = (
        "#cc566a",
        "#cd6c39",
        "#a39440",
        "#64ac48",
        "#4aac8d",
        "#688bcd",
        "#8562cc",
        "#c361aa",
    )
    return list(palette)

In [None]:
# Placeholder stimulus array: plot_clusters_dendro does not draw anything from it.
stim = np.asarray([[1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]])
fig_clust, ax_clust = plot_clusters_dendro(norm_traces, stim, linked, labels, fish_id=fish_id)

# Overlay the stimulus epochs on the cluster-mean axis.
f = fish_dir
print(f)
exp = LightsheetExperiment(path=f)
fs = 3  # scale from stimulus-log time to trace samples — presumably the imaging rate in Hz; confirm
stimulus_log = exp.stimulus_log
stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
stim_value = stim_value / 255  # presumably 8-bit RGB -> 0..1 for matplotlib
num_stim = np.shape(stim_value)[0]

# Scale into a new array rather than modifying the one returned by get_paint_function in place.
t_values = t_values * fs
for i in range(num_stim):
    ax_clust.axvspan(t_values[i, 0], t_values[i, 1],
                     facecolor=stim_value[i, :3], alpha=0.3)

fig_clust.suptitle(fish_id)
plt.show()
file_name = 'individual_hrc_210716_k_spont_plus_v13' + str(n_clust) + fish_id + '_full.jpg'
fig_clust.savefig(str(fish_dir / file_name))

### Create a figure of all ROIs colored by cluster:

In [None]:
f = fish_dir / "suite2p/combined"

# allow_pickle is required because stat.npy / ops.npy hold Python objects (ROI dicts, ops dict);
# only load files from trusted sources this way.
stat = np.load(f / 'stat.npy', allow_pickle=True)
ops = np.load(f / 'ops.npy', allow_pickle=True).item()
iscell = np.load(f / 'iscell.npy')
# Separate handle name so `f` still refers to the suite2p directory afterwards.
with open(fish_dir / "original" / "stack_metadata.json") as meta_file:
    stack_meta = json.load(meta_file)

ny = ops["Ly"]
nx = ops["Lx"]
nz = ops['nplanes']
# ny = stack_meta["shape_full"][2]
# nx = stack_meta["shape_full"][3]
rois = np.zeros((ny, nx))
ncells = len(stat)
pix_value = 1

coords = suite2p_data["coords"]
rois_stack = suite2p_data["rois_stack"]

# Quick look at one plane of the ROI stack (plane 7, or the last plane if there are fewer).
preview_plane = min(7, len(rois_stack) - 1)
fig44, ax44 = plt.subplots()
ax44.imshow(rois_stack[preview_plane])
ax44.set_title(f"rois_stack plane {preview_plane}")

# Alternative: build a single-plane ROI map directly from suite2p's stat/iscell
# (ROI pixels numbered from 1, background 0):
# print(np.max(rois_stack))
# for n in range(ncells):
#     if iscell[n][0] == 1:
#         ypix = stat[n]['ypix']
#         xpix = stat[n]['xpix']
#         rois[ypix, xpix] = pix_value
#         pix_value += 1
# print(pix_value)

In [None]:
print(num_rois)
# Map every voxel of the ROI stack to (cluster label + 1) of the ROI whose index it holds; voxels
# whose value is not an ROI index in [0, num_rois) stay 0 (masked out in the plots below).
# Done as one vectorised lookup instead of one full-stack comparison per ROI.
# NOTE(review): voxel value k is treated as ROI k (0-based). If rois_stack uses 0 for background,
# background voxels receive ROI 0's label — confirm the background value.
roi_map = np.copy(rois_stack)
roi_map_clustered = np.zeros_like(roi_map)
is_roi = (roi_map >= 0) & (roi_map < num_rois)
if not np.issubdtype(roi_map.dtype, np.integer):
    is_roi &= roi_map == np.floor(roi_map)  # only whole-number values are ROI indices
roi_map_clustered[is_roi] = np.asarray(labels)[roi_map[is_roi].astype(int)] + 1


In [None]:
# One panel per imaging plane, ROIs coloured by cluster.
num_rows = 4
num_cols = 5
fig1, ax1 = plt.subplots(num_rows, num_cols, figsize=(12, 12))

color_list = ["#cc566a", "#cd6c39", "#a39440", "#64ac48", "#4aac8d", "#688bcd", "#8562cc", "#c361aa"]
cm_roi = LinearSegmentedColormap.from_list("my_list", color_list, N=n_clust)

for i, panel_ax in enumerate(ax1.flat):
    panel_ax.axis('off')
    if i >= len(roi_map_clustered):
        continue  # more panels than planes: leave the panel empty
    plane = roi_map_clustered[i]
    # Mask non-ROI voxels (< 1) so only clustered ROIs are drawn, then rotate for display.
    roi_layer = np.ma.masked_where(plane < 1, plane)
    roi_layer = np.rot90(roi_layer, k=-1, axes=(1, 0))
    # Fix the colour range to the labels 1..n_clust explicitly instead of writing marker pixels
    # into roi_map_clustered (which modified the shared array in place).
    panel_ax.imshow(roi_layer, cmap=cm_roi, vmin=1, vmax=n_clust)
    panel_ax.set_title('z' + str(i))

fig1.suptitle(fish_id)
plt.show()
file_name = 'individual_clusters_hrc_rois_210716_k' + str(n_clust) + fish_id + '_full_spont_plus_v13.jpg'
fig1.savefig(str(fish_dir/file_name))

In [None]:
# All planes overlaid in one panel; each plane is drawn on top of the previous ones and its
# masked (non-ROI) voxels are left undrawn.
num_rows = 1
num_cols = 1
fig1, ax1 = plt.subplots(num_rows, num_cols, figsize=(6, 6))

color_list = ["#cc566a", "#cd6c39", "#a39440", "#64ac48", "#4aac8d", "#688bcd", "#8562cc", "#c361aa"]
cm_roi = LinearSegmentedColormap.from_list("my_list", color_list, N=n_clust)

# Slicing with [:nz] avoids an IndexError if nz exceeds the number of planes in the stack.
for plane in roi_map_clustered[:nz]:
    roi_layer = np.ma.masked_where(plane < 1, plane)
    roi_layer = np.rot90(roi_layer, k=-1, axes=(1, 0))
    # Explicit colour range 1..n_clust, without writing marker pixels into roi_map_clustered.
    ax1.imshow(roi_layer, cmap=cm_roi, vmin=1, vmax=n_clust)
ax1.set_title(fish_id)
ax1.axis('off')

plt.show()
file_name = 'individual_clusters_hrc_rois_210716_k' + str(n_clust) + fish_id + '_full_spont_plus_v13_overlay.jpg'
fig1.savefig(str(fish_dir/file_name))