In [None]:
import json
from pathlib import Path

import h5py
import matplotlib
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import numpy as nps  # legacy alias (typo); the notebook uses `np` throughout
import pandas as pd
import skimage as sk
from skimage import io as skio

import flammkuchen as fl
from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from fimpylab.core.twop_experiment import TwoPExperiment


In [None]:
# Root folder of the recordings; every entry matching "*f<digit>*" is treated
# as one fish folder.
master = Path(r"\\Funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\2p\rf")
# Path.glob yields entries in arbitrary (filesystem-dependent) order; sort so
# that an index into fish_list always refers to the same fish.
fish_list = sorted(master.glob("*f[0-9]*"))
num_fish = len(fish_list)
FISH_IDX = 2  # index (into the sorted fish_list) of the fish analysed below
fish_dir = fish_list[FISH_IDX]
print(fish_dir)

In [None]:
# Load the traces of the selected fish: one row per ROI, one column per frame.
# The last two frames are dropped.
traces = fl.load(fish_dir / "traces.h5")["traces"][:, :-2]
fs = 3  # sampling rate used to turn frame indices into the time axis `t`
num_traces, len_rec = traces.shape
t = np.arange(len_rec) / fs
print("num_traces: ", num_traces)
print("len_rec: ", len_rec)
# Length of one third of the recording (presumably one of the 3 repetitions).
new_len_rec = len_rec // 3

In [None]:
# Z-score the traces. By default (ZSCORE_AXIS = None) a single global mean and
# SD are taken over all ROIs and all frames. np.nanmean/np.nanstd without an
# axis reduce over the whole array, so transposing `traces` makes no difference.
# NaNs are ignored in the statistics.
# Set ZSCORE_AXIS = 1 to z-score each ROI (row of `traces`) independently;
# `mean`/`sd` then become (n_rois, 1) columns instead of scalars.
# NOTE: a recording/ROI with zero SD yields inf/NaN after the division.
ZSCORE_AXIS = None
keep_dims = ZSCORE_AXIS is not None
mean = np.nanmean(traces, axis=ZSCORE_AXIS, keepdims=keep_dims)
sd = np.nanstd(traces, axis=ZSCORE_AXIS, keepdims=keep_dims)
norm_traces = (traces - mean) / sd  # new array; `traces` itself is untouched

In [None]:
# Cut the recording into 3 consecutive chunks of new_len_rec frames and stack
# them as (chunk, ROI, frame) in float64; leftover frames at the end are ignored.
# Then average the chunks, ignoring NaNs.
trial_traces = (
    traces[:, :3 * new_len_rec]
    .reshape(num_traces, 3, new_len_rec)
    .transpose(1, 0, 2)
    .astype(np.float64, order="C")
)
avg_traces = np.nanmean(trial_traces, 0)

In [None]:
# Protocol constants.
num_subtrials = 16
num_rep = 3
num_trials = num_subtrials * num_rep

# Build the stimulus trace from the first session's stimulus log and rescale it
# by 1/255 (presumably 8-bit values -> 0..1; confirm).
exp = TwoPExperiment(path=fish_dir)
stimulus_log = exp.load_session_log(log_name='stimulus_log', session_idx=0)
stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
stim_value = stim_value / 255
num_stim = np.shape(stim_value)[0]

In [None]:
# ROI label stack of the selected fish (the [:, :, :] indexing requires >= 3 dims).
rois = fl.load(Path(fish_dir, "merged_rois.h5"))["stack"][:, :, :]

In [None]:
from skimage import color
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
#import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import cm

### Create a figure of all ROIs colored by cluster:

In [None]:
### Count the ROIs of every fish (= number of rows in its traces array).
# Loop-local names are used so this cell does not overwrite `traces` /
# `num_traces` of the fish selected above. Entries of fish_list already include
# `master`, so they are used as-is.
print(num_fish)
num_rois = np.zeros(num_fish, dtype=int)
for i_fish, fish_path in enumerate(fish_list):
    fish_traces = fl.load(fish_path / "traces.h5")['traces']
    num_rois[i_fish] = np.shape(fish_traces)[0]
print(num_rois)

In [None]:
num_rois = np.array(num_rois, dtype=int)  # integer copy so counts can be used as indices/ranges

In [None]:
roi_map = np.copy(rois)

# Paint every voxel of ROI label k (k = 1..num_rois[current_fish]) with
# cluster id labels_fish[k-1] + 1, so 0 stays background. A single mask plus a
# label lookup replaces one full-stack np.where scan per ROI.
# NOTE(review): `current_fish`, `labels_fish` (and `n_clust` below) are not
# defined in this part of the notebook; confirm they come from the clustering
# step and that current_fish indexes the same fish as `fish_dir`.
# NOTE(review): stack label k is mapped to traces row k-1. If the stack labels
# ROIs from 0 (with a negative background), ROI 0 is never painted and all
# other ROIs are shifted by one; confirm the labelling convention.
n_rois_fish = num_rois[current_fish]
# one cluster label per ROI; ravel also accepts an (n, 1) column
roi_labels = np.asarray(labels_fish).ravel()
roi_map_clustered = np.zeros_like(roi_map)
in_roi = np.isin(roi_map, np.arange(1, n_rois_fish + 1))
roi_ids = roi_map[in_roi].astype(int)
roi_map_clustered[in_roi] = roi_labels[roi_ids - 1] + 1


In [None]:
# Colormap with n_clust discrete bins built from an 8-color palette; for
# n_clust != 8 the bin colors are interpolated between these anchors.
color_list = ["#cc566a", "#cd6c39", "#a39440", "#64ac48", "#4aac8d", "#688bcd", "#8562cc", "#c361aa"]
cm_roi = LinearSegmentedColormap.from_list("my_list", color_list, N=n_clust)

# One panel per z-plane, N_COLS panels per row (enough rows for all planes).
N_COLS = 4
num_planes = np.shape(roi_map)[0]
n_rows = max(1, int(np.ceil(num_planes / N_COLS)))
fig1, ax1 = plt.subplots(n_rows, N_COLS, figsize=(12, 4 * n_rows), squeeze=False)

for i in range(num_planes):
    r, c = divmod(i, N_COLS)

    plane = roi_map_clustered[i]
    # Mask background (< 1) so it is not drawn, then rotate by 90 degrees for display.
    roi_layer = np.ma.masked_where(plane < 1, plane)
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))

    # Pin the color range to the cluster ids 1..n_clust so every plane uses the
    # same label -> color mapping. This is done via vmin/vmax rather than by
    # writing marker pixels into the data, which would mutate
    # roi_map_clustered in place.
    ax1[r, c].imshow(roi_layer, cmap=cm_roi, vmin=1, vmax=n_clust)
    ax1[r, c].axis('off')
    ax1[r, c].set_title('z' + str(i))

# Hide panels that have no plane.
for ax in ax1.flat[num_planes:]:
    ax.axis('off')

plt.show()
file_name = 'clusters_hrc_rois_210611_k' + str(n_clust) + '_2.jpg'
fig1.savefig(str(fish_dir/file_name))