In [None]:
import matplotlib
%matplotlib widget
import matplotlib.pyplot as plt
import h5py
import numpy as np
from pathlib import Path
import skimage as sk
from skimage import io as skio
import json
import pandas as pd

import flammkuchen as fl
from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from sklearn.cluster import KMeans
from fimpylab.core.twop_experiment import TwoPExperiment


In [None]:
# Experiment root folder; each fish's data lives in a sub-folder named in fish_list.
# NOTE(review): absolute machine-specific path — consider making it configurable.
master = Path(r"J:\_Shared\experiments\E0040_motions_cardinal\v14_cw_ccw")

# Fish folders to pool, in the order their ROIs are stacked below.
fish_list = [
    '200826_f1',
    '200827_f0',
    '200917_f0',
    '200918_f0',
]

## Load all traces

In [None]:
# Load every fish's traces (rows = ROIs, columns = time samples), keep only the
# first len_segment samples, and stack all fish vertically in fish_list order.
len_segment = 2160
trace_blocks = []
for fish in fish_list:
    f = master / fish
    print(f)
    dir_traces = f / "traces.h5"
    traces_tmp = fl.load(dir_traces)['traces'][:, 0:len_segment]
    trace_blocks.append(traces_tmp)
    print(np.shape(traces_tmp))

# Concatenate once at the end: growing the array inside the loop is quadratic, and the
# old `traces_all is 0` sentinel relied on int identity (SyntaxWarning on Python >= 3.8).
if not trace_blocks:
    raise ValueError("fish_list is empty - no traces to load")
traces_all = np.concatenate(trace_blocks, axis=0)
print(np.shape(traces_all))
    

In [None]:
# Normalize: z-score each ROI trace (row) over time.
# A constant trace has std == 0 and would become all-NaN, which KMeans rejects;
# its std is replaced by 1 so such a trace becomes all zeros instead.
trace_mean = traces_all.mean(axis=1, keepdims=True)
trace_std = traces_all.std(axis=1, keepdims=True)
trace_std[trace_std == 0] = 1
traces_all = (traces_all - trace_mean) / trace_std

In [None]:
# Average N_REPEATS consecutive, equal-length segments of every trace
# (presumably repetitions of the stimulus protocol — TODO confirm).
# Samples that do not fill a whole segment at the end are dropped.
N_REPEATS = 3
print(np.shape(traces_all))
n_rois_all, n_samples = np.shape(traces_all)
seg_len = n_samples // N_REPEATS
# Segment j of row r is traces_all[r, j*seg_len:(j+1)*seg_len]. Computing seg_len
# first avoids the `i * T // 3` precedence bug, which mis-aligned segments when
# T % 3 != 0. dtype=float64 keeps the original float64 accumulator.
avg_traces = (
    traces_all[:, :seg_len * N_REPEATS]
    .reshape(n_rois_all, N_REPEATS, seg_len)
    .mean(axis=1, dtype=np.float64)
)
print(np.shape(avg_traces))


In [None]:
# Heat-map of the normalized traces for the first 720 samples (one third of len_segment=2160).
# aspect="auto" lets the image fill the axes regardless of how many ROIs there are.
fig0, ax0 = plt.subplots(1, 1, figsize=(15, 8))
ax0.imshow(traces_all[:, 0:720], aspect="auto")
ax0.set_xlabel("Sample")
ax0.set_ylabel("ROI")
ax0.set_title("Normalized traces (samples 0-720)")
plt.show()

In [None]:
# Heat-map of the segment-averaged traces (rows = ROIs of all fish).
# aspect="auto" lets the image fill the axes regardless of how many ROIs there are.
fig0, ax0 = plt.subplots(1, 1, figsize=(15, 8))
ax0.imshow(avg_traces, aspect="auto")
ax0.set_xlabel("Sample")
ax0.set_ylabel("ROI")
ax0.set_title("Segment-averaged traces")
plt.show()

In [None]:
# Cluster the averaged traces.
# random_state fixes the result across runs. n_init=10 is set explicitly because the
# scikit-learn default changed between versions.
k = 20
kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
# fit_transform returns each trace's distance to every cluster centre (n_traces x k).
clusters = kmeans.fit_transform(avg_traces)
labels_k = kmeans.predict(avg_traces)
# Stable sort keeps a deterministic order inside each cluster; reversing puts the
# highest cluster label on top of the heat-map.
labelsinds = np.argsort(labels_k, kind="stable")
clustered_traces = avg_traces[labelsinds[::-1]]

In [None]:
# Heat-map of all averaged traces with rows grouped by cluster label.
# The x extent is real time: samples / fs_hz. The old fixed extent [0, 1000] did not
# match the "Time (sec)" label.
fs_hz = 3  # samples per second; same value as `fs` in the cluster-centre plot — TODO confirm
n_sorted_rois, n_avg_samples = np.shape(clustered_traces)
fig1, ax1 = plt.subplots(1, 1, figsize=(15, 8))
ax1.imshow(
    clustered_traces,
    extent=[0, n_avg_samples / fs_hz, n_sorted_rois, 0],
    aspect="auto",
    vmin=-2,
    vmax=10,
)
ax1.set_xlabel("Time (sec)")
ax1.set_ylabel("ROI (sorted by cluster)")
ax1.set_title("Clustered traces")
# Name the file after the fitted model: `k` can be rebound by the elbow cell's loop.
file_name = 'clusters_k_of_' + str(kmeans.n_clusters) + '.jpg'
fig1.savefig(str(master / file_name))
plt.show()

In [None]:
# Elbow curve: KMeans inertia (sum of squared distances to the closest centre) for k = 1..14.
# The loop variable must not be called `k`. That would overwrite the global `k` used by
# the main clustering, so later cells would plot and name files with the wrong cluster count.
sum_of_sqr_d = []
k_opt = range(1, 15)
for k_try in k_opt:
    km = KMeans(n_clusters=k_try, n_init=10, random_state=0).fit(avg_traces)
    sum_of_sqr_d.append(km.inertia_)


In [None]:
# Plot the elbow curve on the explicit axes and save it before displaying.
fig3, ax3 = plt.subplots(1, 1, figsize=(12, 5))
ax3.scatter(k_opt, sum_of_sqr_d)
ax3.set_ylabel("Sum of squared distances")
ax3.set_xlabel("k")
ax3.set_title("KMeans elbow curve")
fig3.savefig(str(master / 'kmeans_error.jpg'))
plt.show()

In [None]:
# Cluster-centre traces stacked vertically and labelled with cluster sizes,
# drawn over the shaded stimulus epochs.
fs = 3  # samples per second of the averaged traces (builds the time axis t)
trace_offset = 7  # vertical spacing between consecutive cluster traces
# Use the fitted model's cluster count, not `k`, which the elbow cell's loop may rebind.
n_clusters = kmeans.n_clusters
clusters_centers = kmeans.cluster_centers_
# NOTE: tab10 has only 10 distinct colours, so with more than 10 clusters colours repeat.
color_list = plt.cm.tab10(np.linspace(0, 1, n_clusters))

# NOTE(review): the stimulus log is read from the last fish only (previously via the
# loop-leaked `f`). This assumes all fish share the same protocol — confirm.
exp = TwoPExperiment(path=master / fish_list[-1])

fig4, ax4 = plt.subplots(1, 1, figsize=(8, 8))
t = np.arange(np.shape(avg_traces)[1]) / fs
for i, center in enumerate(clusters_centers):
    ax4.plot(t, center + i * trace_offset, color=color_list[i])
    # Number of ROIs (across all fish) assigned to this cluster.
    ax4.text(-10, i * trace_offset, str(np.count_nonzero(labels_k == i)))

stimulus_log = exp.load_session_log(log_name='stimulus_log', session_idx=0)
stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
stim_value = stim_value / 255  # 0-255 RGB -> 0-1 for matplotlib
# Shade only the first third of the epochs (presumably one repetition, matching the
# averaged traces — TODO confirm). t_values is assumed to be in the same time base as t.
num_stim = np.shape(stim_value)[0] // 3
for i in range(num_stim):
    ax4.axvspan(
        t_values[i, 0],
        t_values[i, 1],
        facecolor=stim_value[i, :3],
        alpha=0.4,
    )

ax4.axis('off')
plt.show()

In [None]:
# Name the file after the fitted model's cluster count; `k` may have been rebound
# (e.g. by the elbow cell's loop) since the model was fitted.
file_name = 'kmeans' + str(kmeans.n_clusters) + '_avg_210309_combined.jpg'
fig4.savefig(str(master / file_name))

In [None]:
### getting the indices for each fish:
num_rois = np.zeros(4)
for i in range(4):
    f = master / fish_list[i]
    dir_traces = f / "traces.h5"
    traces = fl.load(dir_traces)['traces']
    num_traces = np.shape(traces)[0]
    num_rois[i] = num_traces // 1
    print(num_rois)
    

In [None]:
# Exclusive end row of each fish's ROIs within traces_all / labels_k.
ind2 = np.cumsum(num_rois).astype(int)
# Display the computed indices (previously referenced the undefined name `ind_fish`).
ind2

In [None]:
# Integer ROI counts, needed for range() and slicing further down.
num_rois = np.array(num_rois, dtype=int)

In [None]:
# Start row of each fish's ROIs: 0 for the first fish, then the previous fish's end row.
# Derived from ind2, so it works for any number of fish.
ind1 = np.concatenate(([0], ind2[:-1]))
ind1

#### Choosing one fish and coloring ROIs by clusters

In [None]:
current_fish = 3
fish_dir = master / fish_list[current_fish]

# Rows of this fish inside the pooled label vector.
roi_start, roi_stop = ind1[current_fish], ind2[current_fish]

# ROI label stack of this fish.
roi_stack_file = fish_dir / "merged_rois.h5"
rois = fl.load(roi_stack_file)["stack"][:, :, :]

labels_fish = labels_k[roi_start:roi_stop]

In [None]:
# Paint every ROI voxel with (its cluster label + 1); everything else stays 0.
# ROI ids in the stack are taken to be 1..n_rois_fish, with id i+1 <-> labels_fish[i],
# the same mapping as the former per-ROI loop. A lookup table scans the volume once
# instead of once per ROI.
roi_map = np.copy(rois)
n_rois_fish = int(num_rois[current_fish])
roi_map_clustered = np.zeros_like(roi_map)

label_lut = np.zeros(n_rois_fish + 1, dtype=roi_map_clustered.dtype)
label_lut[1:] = np.asarray(labels_fish[:n_rois_fish]) + 1
# Only integer-valued voxels in 1..n_rois_fish matched an `== i + 1` test before.
is_roi = (roi_map >= 1) & (roi_map <= n_rois_fish) & (roi_map == np.round(roi_map))
roi_map_clustered[is_roi] = label_lut[roi_map[is_roi].astype(int)]


In [None]:
# One panel per z-plane of the clustered ROI map, 4 panels per row.
# vmin/vmax are pinned to the label range 1..n_clusters, so a cluster gets the same
# colour in every panel. Per-panel autoscaling recoloured clusters plane by plane.
# Label j+1 maps to the same tab10 position as np.linspace(0, 1, n_clusters)[j].
n_clusters_map = kmeans.n_clusters
n_planes = np.shape(roi_map_clustered)[0]
n_cols = 4
n_rows = max(1, int(np.ceil(n_planes / n_cols)))
# Own figure name, so the clustered-traces figure `fig1` is not overwritten.
fig5, ax5 = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)

for i, ax in enumerate(ax5.flat):
    ax.axis('off')
    if i >= n_planes:
        continue  # unused panel in the last row
    roi_layer = np.ma.masked_where(roi_map_clustered[i] < 1, roi_map_clustered[i])
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
    ax.imshow(roi_layer, cmap="tab10", vmin=1, vmax=n_clusters_map)
    ax.set_title('z' + str(i))

fig5.savefig(str(fish_dir / 'clusters_rois_210305.jpg'))
plt.show()