In [1]:
import matplotlib
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import skimage as sk
from skimage import io as skio
import json
import pandas as pd

import flammkuchen as fl
from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from scipy.cluster.hierarchy import dendrogram, linkage, cut_tree, to_tree, set_link_color_palette
from sklearn.cluster import AgglomerativeClustering

In [25]:
# Root folder of the experiment; each fish has its own sub-directory whose name
# contains "f<digit>".
master = Path(r"\\Funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\2p\rf\fixed")
# glob() yields entries in an arbitrary, filesystem-dependent order. Sort them so
# the index below selects the same fish on every run and on every machine.
all_fish = sorted(master.glob("*f[0-9]*"))
fish_dir = all_fish[1]
print(fish_dir)

\\Funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\2p\rf\fixed\210602_f1


## Load ROIs:

In [26]:
# ROI label stack for the selected fish. The number of planes is taken from the
# first axis of the stack (the earlier count of fish directories was immediately
# overwritten and has been dropped).
rois = fl.load(fish_dir / 'merged_rois.h5')
roi_map = rois['stack']
num_planes = np.shape(roi_map)[0]
print("num planes:", num_planes)

# Fluorescence traces: one row per ROI, one column per time point.
traces = fl.load(fish_dir / "traces.h5")['traces']
num_rois, len_rec = np.shape(traces)
print("num ROIs:", num_rois)

num planes: 10
num ROIs: 508


In [27]:
# Overview figure: one panel per imaging plane, ROIs coloured by their id.
fig0, ax0 = plt.subplots(3, 4, figsize=(12, 12))

# Running ROI count derived from the label stack. It is kept in its own variable
# so that `num_rois` (taken from the traces array) is not overwritten here.
roi_count = 0
for i in range(num_planes):
    r, c = divmod(i, 4)

    plane = roi_map[i]
    # Count distinct labels among ROI pixels only (labels below 1 are background).
    roi_count += np.unique(plane[plane >= 1]).size
    print(roi_count)

    roi_layer = np.ma.masked_where(plane < 1, plane)
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
    ax0[r, c].imshow(roi_layer, cmap="rainbow")

# Hide the panels of the 3x4 grid that have no plane to show.
for spare_ax in ax0.flat[num_planes:]:
    spare_ax.axis("off")

plt.show()
#fig0.savefig(str(master/'all_rois.jpg'))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

74
147
220
291
371
414
450
474
493
508


## Cluster all traces:

In [28]:
# Z-score every ROI trace along time (NaN-aware): subtract the per-ROI mean and
# divide by the per-ROI standard deviation.
mean = np.nanmean(traces, axis=1)
sd = np.nanstd(traces, axis=1)
norm_traces = (traces - mean[:, np.newaxis]) / sd[:, np.newaxis]
print(np.shape(norm_traces))

(508, 2297)


In [29]:
# Visual check: z-scored traces on the left panel, raw traces on the right.
fig_xx, ax_xx = plt.subplots(1, 2, figsize=(10,7))
for panel, matrix in zip(ax_xx, (norm_traces, traces)):
    panel.imshow(matrix)
plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [30]:
##### Optional - average across repetitions before clustering:

trial_len = 720 # 480
n_reps = 3  # number of consecutive repetitions averaged together

# This cell overwrites `norm_traces`, so running it twice would feed already
# averaged data back in. Fail with an explicit message instead of a broadcast error.
if norm_traces.shape[1] < n_reps * trial_len:
    raise ValueError(
        "norm_traces has %d samples but %d repetitions of %d samples are needed; "
        "re-run the normalisation cell before averaging again"
        % (norm_traces.shape[1], n_reps, trial_len)
    )

# View the first n_reps * trial_len samples as (roi, repetition, time) and average
# over the repetition axis, ignoring NaNs. Samples after the last repetition are dropped.
repetitions = norm_traces[:, :n_reps * trial_len].reshape(norm_traces.shape[0], n_reps, trial_len)
norm_traces_avg = np.nanmean(repetitions, axis=1)
norm_traces = norm_traces_avg

In [31]:
# Ward hierarchical clustering of the (normalised) traces; `linked` is reused by
# the cells below. The full dendrogram is drawn on a fresh figure.
linked = linkage(norm_traces, method='ward')
fig1 = plt.figure(figsize=(10,7))
dend = dendrogram(linked)
plt.show()
# fig1.savefig(str(master / "dendrogrm_210104.jpg"))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
# Ward linkage only works with Euclidean distances, which is also the estimator's
# default, so the distance argument is simply omitted. The former `affinity`
# keyword was deprecated in scikit-learn 1.2 and later removed (renamed `metric`),
# so passing it breaks on current releases.
cluster = AgglomerativeClustering(n_clusters=12, linkage='ward')
her_clustering = cluster.fit_predict(norm_traces)

### From Ot & Luigi

In [33]:
def cluster_id_search(tree):
    """Collect the ids of all leaves below ``tree``, ordered left to right.

    Parameters
    ----------
    tree : node object
        Any node exposing ``is_leaf()``, ``get_id()``, ``get_left()`` and
        ``get_right()`` (e.g. the nodes returned by ``to_tree``).

    Returns
    -------
    list
        Leaf ids in left-to-right order; a single-element list if ``tree``
        is itself a leaf.
    """
    nodes_list = []
    # Explicit stack instead of recursion: very unbalanced trees can be deeper
    # than the interpreter's recursion limit, and repeated list concatenation
    # on the way back up is quadratic.
    pending = [tree]
    while pending:
        node = pending.pop()
        if node.is_leaf():
            nodes_list.append(node.get_id())
        else:
            # Push the right child first so the left subtree is visited first.
            pending.append(node.get_right())
            pending.append(node.get_left())

    return nodes_list

def find_trunc_dendro_clusters(linkage_mat, dendro):
    """Assign every observation the position of its truncated-dendrogram leaf.

    Parameters
    ----------
    linkage_mat : ndarray
        Linkage matrix; the result has ``linkage_mat.shape[0] + 1`` entries.
    dendro : dict
        Output of ``dendrogram``; only ``dendro["leaves"]`` is read, and each
        entry is used as an index into the node list returned by ``to_tree``.

    Returns
    -------
    ndarray of int
        For each observation, the index (within ``dendro["leaves"]``) of the
        leaf whose subtree contains it.
    """
    _, node_lookup = to_tree(linkage_mat, rd=True)
    n_observations = linkage_mat.shape[0] + 1
    ids = np.empty(n_observations, dtype=int)

    for leaf_position, node_index in enumerate(dendro["leaves"]):
        members = cluster_id_search(node_lookup[node_index])
        ids[members] = leaf_position

    return ids

In [34]:
# Build the truncated tree only to get the cluster ids in the same leaf sequence
# as the cut. `no_plot=True` computes the dendrogram structure without drawing,
# so no throw-away figure has to be created and closed.
n_clust = 12
dendro = dendrogram(linked, n_clust, truncate_mode="lastp", no_plot=True)
cluster_ids = dendro["leaves"]
labels = find_trunc_dendro_clusters(linked, dendro)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [35]:
# `meanresps` is the matrix handed to the cluster plot further down; it is the
# same object as `norm_traces` (no copy is made).
meanresps = norm_traces
# Per-row baseline subtraction: the baseline is the NaN-aware mean of the first
# 8 columns of each row.
base_sub_mean = (meanresps.T - np.nanmean(meanresps[:,:8], 1)).T
X = base_sub_mean

#smooth_mean_resps = pd.DataFrame(meanresps.T).rolling(4, center=True).mean().as_matrix().T

num_traces = np.shape(norm_traces)[0]
traces_fixed = np.copy(norm_traces)
# np.where(...)[0] holds ROW indices here, so every row containing at least one
# exact zero is set to NaN in its entirety (assigning None to a float array stores NaN).
# NOTE(review): confirm whole-row blanking is intended rather than per-element.
traces_fixed[np.where(traces_fixed == 0)[0]] = None
for i in range(num_traces):
    # Row view: the assignment below modifies `traces_fixed` in place.
    tmp_cluster = traces_fixed[i]
    # Blank samples within 0.1 of the row minimum. np.min is not NaN-aware, so a
    # row containing NaN makes the comparison False everywhere and nothing changes.
    tmp_cluster[np.where(tmp_cluster <= (np.min(tmp_cluster)+0.1))[0]] = None
    #tmp_cluster = tmp_cluster + (i*5)
    
# With the line below commented out, `traces_fixed` is not used by the visible cells.
#meanresps = traces_fixed


In [36]:
from skimage import color
import numpy as np
from matplotlib import pyplot as plt
# from luminance_analysis import Data
#from luminance_analysis.utilities import find_transitions, nanzscore
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
#from luminance_analysis.clustering import find_trunc_dendro_clusters
#import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm

In [37]:
def shade_plot(stim, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):
    """Shade an axis according to a stimulus supplied in one of several forms.

    The stimulus is converted to a list of transitions and passed on to
    ``_shade_plot`` together with the remaining arguments.

    Parameters
    ----------
    stim : list, Data, numpy.ndarray or tuple
        * list: used as the transitions directly;
        * ``Data``: transitions are computed from the instance's
          ``resampled_stim`` and ``time_im_rep`` attributes;
        * ndarray: column 1 and column 0 are passed to ``find_transitions``;
        * tuple: element 1 and element 0 are passed to ``find_transitions``.
        Any other type is ignored and nothing is drawn.
    ax, gamma, shade_range
        Forwarded unchanged to ``_shade_plot``.

    Notes
    -----
    ``Data`` and ``find_transitions`` must be available in the namespace; in
    this notebook their imports are commented out, so only list input works
    until they are restored.
    """
    if isinstance(stim, list):  # these would be transitions
        transitions = stim
    elif isinstance(stim, Data):  # fish data
        # Read the attributes from the instance that was passed in; the class
        # object itself was used here before, which ignored `stim` entirely.
        transitions = find_transitions(stim.resampled_stim, stim.time_im_rep)
    elif isinstance(stim, np.ndarray):  # stimulus array
        transitions = find_transitions(stim[:, 1], stim[:, 0])
    elif isinstance(stim, tuple):  # time, lum tuple
        transitions = find_transitions(stim[1], stim[0])
    else:
        # Unsupported stimulus types are silently ignored.
        return

    _shade_plot(transitions, ax=ax, gamma=gamma, shade_range=shade_range)


def _shade_plot(lum_transitions, ax=None, gamma=1/6, shade_range=(0.6, 0.9)):

    if ax is None:
        ax = plt.gca()
    shade = lum_transitions[0][1]
    for i in range(len(lum_transitions)-1):
        shade = shade + lum_transitions[i][1]
        new_shade = shade_range[0] + np.power(np.abs(shade), gamma) * (shade_range[1] - shade_range[0])
        ax.axvspan(lum_transitions[i][0], lum_transitions[i+1][0], color=(new_shade, )*3)
        

def _find_thr(linked, n_clust):
    interval = [0, 2000]
    new_height = np.mean(interval)
    clust = 0
    n_clust = n_clust
    while clust != n_clust:
        new_height = np.mean(interval)
        clust = cut_tree(linked, height=new_height).max()
        if clust > n_clust:
            interval[0] = new_height
        elif clust < n_clust:
            interval[1] = new_height


    return new_height


def find_plot_thr(linked, n_clust):
    """Return a dendrogram height that separates ``linked`` into ``n_clust`` clusters.

    Delegates to ``_find_thr`` with ``n_clust - 1``, because that helper is
    expressed in terms of the largest (0-based) cluster label.
    """
    return _find_thr(linked, n_clust - 1)


In [38]:
def plot_clusters_dendro(traces, stim, linkage_mat, labels, dendrolims=(900, 30),
                         thr=None, f_lim=1.5, gamma=1):
    """Plot cluster mean traces, the sorted trace matrix and the dendrogram.

    The figure has three stacked panels: cluster means (top), all traces
    ordered by dendrogram leaf (middle) and the dendrogram itself (bottom).

    Parameters
    ----------
    traces : ndarray
        Indexed as ``traces[rows, :]``; one row per clustered item.
    stim : any
        Accepted for interface compatibility; not used for drawing.
    linkage_mat : ndarray
        Linkage matrix passed to ``dendrogram``.
    labels : ndarray of int
        Cluster label of every row of ``traces``; ``labels.max() + 1`` is taken
        as the number of clusters.
    dendrolims : tuple
        Accepted for interface compatibility; not used.
    thr : float, optional
        Colour threshold of the dendrogram. Derived with ``find_plot_thr``
        when omitted.
    f_lim : float
        Symmetric colour limit (``-f_lim`` to ``f_lim``) of the trace matrix.
    gamma : float
        Accepted for interface compatibility; not used.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Notes
    -----
    The global dendrogram link colour palette is set to the colours returned
    by ``cluster_cols()`` and is left that way after the call.
    """
    fig_clust, ax = plt.subplots(3, 1, figsize=(15, 15))
    ax_clusters, ax_traces, ax_dendro = ax[0], ax[1], ax[2]
    hexac = cluster_cols()

    n_clust = labels.max() + 1

    ##################
    ### Dendrogram ###
    if thr is None:
        thr = find_plot_thr(linkage_mat, n_clust)

    set_link_color_palette(hexac)

    # Draw on the dendrogram panel explicitly rather than relying on it being
    # the current pyplot axes.
    panel_dendro = dendrogram(linkage_mat,
                              ax=ax_dendro,
                              color_threshold=thr,
                              #orientation='left',
                              distance_sort='descending',
                              show_leaf_counts=False,
                              no_labels=True,
                              above_threshold_color='#%02x%02x%02x' % (
                              120, 120, 120))

    ax_dendro.axhline(thr, linewidth=0.7, color="k")
    ax_dendro.axis("off")

    ##################
    # Traces matrix ##
    # Rows are reordered to follow the leaf order of the dendrogram.
    ax_traces.imshow(traces[panel_dendro["leaves"], :],
                     aspect='auto', origin='lower', cmap="gray_r",
                     vmin=-f_lim, vmax=f_lim)
    ax_traces.axis("off")

    ##################
    # Cluster means ##
    # One mean trace per cluster, stacked with a vertical offset of 2 per
    # cluster. Colours wrap around when there are more clusters than colours.
    for i in range(n_clust):
        ax_clusters.plot(np.nanmean(traces[labels == i, :], 0) + i * 2,
                         label=i, color=hexac[i % len(hexac)])
    ax_clusters.axis("off")

    return fig_clust

def cluster_cols():
    """Return a new list with the twelve named colours used for the clusters."""
    return [
        "lightblue",
        "lightcoral",
        "orange",
        "springgreen",
        "deepskyblue",
        "mediumpurple",
        "gold",
        "cyan",
        "crimson",
        "deeppink",
        "lawngreen",
        "darkviolet",
    ]

In [39]:
# Placeholder stimulus array (two identical rows holding 1..7) passed to the
# plotting function as its `stim` argument.
stim = np.tile(np.arange(1, 8), (2, 1))
fig_clust = plot_clusters_dendro(meanresps, stim, linked, labels)#, dendrolims=(940, 0))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [40]:
# Display the open figures, then write the clustering figure into the fish folder.
plt.show()
fig_clust.savefig(str(fish_dir / "hierarchical_clustering_221219_avg.jpg"), dpi=300)

In [21]:
# Sanity checks: per-ROI cluster labels, the distinct labels present, the ROI
# count, and the largest id in the ROI stack for comparison with that count.
print(labels)
print(np.unique(labels))
#fig, ax = plt.subplots(1,1)
#ax.plot(labels)
#plt.show()
print(num_rois)
print(np.max(roi_map))

[ 3 11 11  7 10  7 11  7  8  7  5  6 11  5 11  3 10 11 11 10  3  3  8  8
  5 11  1  1  1  9  8  5  3  5  7  5  7  5  1  1  5  3  1  1  1  5  9  1
 11  1  2  2  2  9  1 11  9  9  2 11  5 10  5  2 10 11  8  1  7  7  0  1
  9  1 11  5  9  7  0  0 11 11  9  4  4  1  1  4  2  1  3  1  1  1  0  5
  9  8  8  9  2  8  8  3 11  2  4  3  4 10  3  3  3  1  2  9 11  5  5  9
 11 10 11  9  1 10  5  1 10  3  3  1  3  7  7  9 11  4  1  2  9  2  5 11
  9  4  2  7  9  9 11 11  3  4  5  2  6  6  1 11  5  6  5  8  5  8  6  6
  5 11  6  7  9 11  6  6  0  1 10  7  2  8  7  7 11  3 11  6  6  6  1  6
  0  6  8  6  6  7  5  9  8  8  1  6  6  9  9  6  6  3  2  6  8  9  9  9
  2  6  6  6  6  6]
[ 0  1  2  3  4  5  6  7  8  9 10 11]
222
222.0


### Create a figure of all ROIs colored by cluster:

In [22]:
# Repaint the ROI stack: every pixel of ROI id k (ids start at 1) receives the
# cluster label of ROI k shifted by one, so that 0 keeps marking "no ROI".
roi_map_clustered = np.zeros_like(roi_map)
for roi_id in range(1, num_rois + 1):
    roi_map_clustered[roi_map == roi_id] = labels[roi_id - 1] + 1
print(np.unique(roi_map_clustered))
print(roi_map_clustered[roi_map_clustered == 2])

[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12.]
[2. 2. 2. ... 2. 2. 2.]


In [23]:
# One panel per plane, ROIs coloured by cluster (values 1..12 in roi_map_clustered).
fig2, ax2 = plt.subplots(3, 4, figsize=(12, 12))
color_list = ["lightblue", "lightcoral", "orange", "springgreen", "deepskyblue", "mediumpurple","gold", "cyan", "crimson", "deeppink", "lawngreen", "darkviolet"]
n_colors = len(color_list)
cm_roi = LinearSegmentedColormap.from_list("my_list", color_list, N=n_colors)

for i in range(num_planes):
    r, c = divmod(i, 4)

    plane = roi_map_clustered[i]
    roi_layer = np.ma.masked_where(plane < 1, plane)
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))

    # Fix the colour scale to the full cluster range with vmin/vmax. Previously
    # twelve "legend" pixels were written into the data (in place, through a
    # view of roi_map_clustered) so that every plane spanned the same range.
    im = ax2[r, c].imshow(roi_layer, cmap=cm_roi, vmin=1, vmax=n_colors)
    ax2[r, c].axis('off')

fig2.colorbar(im, ax=ax2[2,3])
plt.show()
fig2.savefig(str(fish_dir/'clusters_rois_hrc_221219_avg.jpg'), dpi=300)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
# Writes the same figure to the same file as the end of the previous cell.
fig2.savefig(str(fish_dir/'clusters_rois_hrc_221219_avg.jpg'), dpi=300)

In [None]:

### Adapt to hrc
'''
clusters_centers = kmeans.cluster_centers_
fig3, ax3 = plt.subplots(1, 1, figsize=(12, 12))
fs = 3
trial_len = 30
num_stim = 18

color_list = plt.cm.tab20(np.linspace(0, 1, k))
clusters_centers_fixed = np.copy(clusters_centers)
clusters_centers_fixed[np.where(clusters_centers_fixed == 0)[0]] = None
for i in range(k):
    tmp_cluster = clusters_centers_fixed[i]
    print(np.where(tmp_cluster <= (np.min(tmp_cluster)))[0])
    tmp_cluster[np.where(tmp_cluster <= (np.min(tmp_cluster)+0.1))[0]] = None
    ax3.plot(tmp_cluster + (i * 7), c=color_list[i])
    
    num_traces_in_cluster = np.shape(np.where(labels_k == i)[0])[0]
    plt.text(-500,(i * 7),str(num_traces_in_cluster))

    
ax3.axvspan(0, (9 * trial_len * fs), facecolor=[0, 0.7, 0.9], alpha=0.2)

for i in range(num_stim):
    t1 = (9 * 3 * 30) + (i * 30 * 10 * 3)
    t2 = t1 + (30 * 3)
    ax3.axvspan(t1, t2, facecolor=[0, 0.7, 0.9], alpha=0.2)

ax3.axvspan((10 * 30 * 3 * 9), (10 * 30 * 3 * 10) - 90, facecolor=[0, 0.7, 0.9], alpha=0.2)

plt.show()
#fig.savefig(str(master/'mean_traces_norm_for_rois.jpg'))
'''