In [None]:
import matplotlib
%matplotlib widget
import matplotlib.pyplot as plt
import h5py
import numpy as np
from pathlib import Path
import skimage as sk
from skimage import io as skio
import json
import pandas as pd

import flammkuchen as fl
from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from sklearn.decomposition import PCA
from fimpylab.core.twop_experiment import TwoPExperiment


In [None]:
# Root folder of the experiment and the recordings (one folder per fish) used below.
# NOTE: absolute, machine-specific path — adjust it to your own data location.
master = Path("J:\\_Shared\\experiments\\E0040_motions_cardinal\\v14_cw_ccw")
fish_list = [
    "200826_f1",
    "200827_f0",
    "200917_f0",
    "200918_f0",
]

## Load all traces and create the transformation matrix

In [None]:
# Load the traces of every fish, keep the first `len_segment` frames so all
# recordings have the same length, and stack the ROIs of all fish row-wise
# into `traces_all` (one row per ROI, one column per frame).
len_segment = 2160  # common number of frames kept per fish
trace_list = []
for fish in fish_list:
    # `f` (ends up as the last fish) is still read by a later cell.
    f = master / fish
    traces_tmp = fl.load(f / "traces.h5")['traces'][:, :len_segment]
    if traces_tmp.shape[1] < len_segment:
        # Fail clearly here instead of with a shape error in np.concatenate.
        raise ValueError(
            f"{f} has only {traces_tmp.shape[1]} frames, expected at least {len_segment}"
        )
    print(f, np.shape(traces_tmp))
    trace_list.append(traces_tmp)

# Concatenate once at the end instead of re-allocating inside the loop.
traces_all = np.concatenate(trace_list, axis=0)
print(np.shape(traces_all))
    

In [None]:
# Z-score every ROI (row) over time so all traces share a common scale.
roi_mean = traces_all.mean(axis=1, keepdims=True)
roi_std = traces_all.std(axis=1, keepdims=True)
traces_all = (traces_all - roi_mean) / roi_std

In [None]:
# Average the three stimulus repetitions. Split the frames into three equal,
# consecutive chunks and take their element-wise mean; leftover frames that do
# not fill a whole chunk are ignored. Accumulate in float64, as the original
# zero-initialised accumulator did.
print(np.shape(traces_all))
n_rois_all, n_frames_all = np.shape(traces_all)
chunk = n_frames_all // 3
avg_traces = (
    traces_all[:, :3 * chunk]
    .reshape(n_rois_all, 3, chunk)
    .mean(axis=1, dtype=np.float64)
)
print(np.shape(avg_traces))

In [None]:
# Heatmap of the trial-averaged traces: one row per ROI, one column per frame.
fig0, ax0 = plt.subplots(1, 1)
# aspect='auto': the numbers of rows and columns differ widely, so square
# pixels would squash the image into an unreadable sliver.
heatmap = ax0.imshow(avg_traces, aspect='auto')
fig0.colorbar(heatmap, ax=ax0, label="mean z-score")
ax0.set_xlabel("Frame")
ax0.set_ylabel("ROI");

In [None]:
# Temporal PCA. Each ROI's trial-averaged trace is one sample and each time point
# is one feature, so `components_` holds temporal profiles, shape (n_comp, n_time).
n_comp = 5
# Fixed seed: the default 'auto' solver may pick the randomized SVD, which
# would otherwise make the components differ slightly between runs.
time_pca = PCA(n_components=n_comp, random_state=0)
# PCA ignores `y`; the previous positional `0` was a misleading no-op.
transformed_data2 = time_pca.fit_transform(avg_traces)
pcs2 = time_pca.components_

print(np.shape(pcs2))

In [None]:
# Left panel: temporal PCs, vertically offset for readability, with the stimulus
# epochs shaded behind them. Right panel: explained-variance ratio of each PC.
# The stimulus log comes from the last fish in `fish_list`. This is named
# explicitly instead of relying on the loop variable `f` leaking out of the
# loading cell.
exp = TwoPExperiment(path=master / fish_list[-1])
color_list = plt.cm.rainbow(np.linspace(0, 1, n_comp))

frames_per_sec = 3  # NOTE(review): taken from the original `/ 3` and "Time (sec)" label; confirm
t = np.arange(pcs2.shape[1]) / frames_per_sec  # previously hard-coded to 720 frames

fig3, ax3 = plt.subplots(1, 2, figsize=(8, 8))
ax3[1].bar(np.arange(n_comp), time_pca.explained_variance_ratio_, color=color_list)

for i in range(n_comp):
    ax3[0].plot(t, pcs2[i] + i * 0.25, color=color_list[i])
ax3[0].set_title("PCs")
ax3[0].set_xlabel("Time (sec)")
ax3[1].set_title("explained variance")
ax3[1].set_xlabel("PC")
ax3[1].set_ylabel("explained variance ratio")

stimulus_log = exp.load_session_log(log_name='stimulus_log', session_idx=0)
stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
stim_value = stim_value / 255  # presumably 0-255 RGB -> 0-1 for matplotlib
# Shade only the first third of the stimuli. Presumably this is one repetition,
# matching the 3-fold trial average the PCs were fit on.
num_stim = np.shape(stim_value)[0] // 3

for i in range(num_stim):
    ax3[0].axvspan(
        t_values[i, 0],
        t_values[i, 1],
        facecolor=stim_value[i, :3],
        alpha=0.5,
    )

In [None]:
# Save the combined temporal-PCA figure in the experiment folder.
pca_fig_path = master / 'pca_time_avg_210304_combined.jpg'
fig3.savefig(str(pca_fig_path))

#### Choose one fish and apply the transformation to its data

In [None]:
# Load one fish and prepare it exactly like the combined data the PCs were fit on:
# keep the first `len_segment` frames, z-score every ROI over that window, then
# average the three stimulus repetitions.
# (Previously the fish was z-scored over the whole recording and only then
# trimmed by 2 frames, so the statistics and time window could differ from the fit.)
fish_id = fish_list[3]
fish_dir = master / fish_id
n_repeats = 3

# Loaded array is (n_rois, n_frames), as in the loading cell above.
traces = fl.load(fish_dir / "traces.h5")["traces"][:, :len_segment]
traces = (traces - traces.mean(axis=1, keepdims=True)) / traces.std(axis=1, keepdims=True)
print(np.shape(traces))

n_rois, n_frames = np.shape(traces)
seg_len = n_frames // n_repeats
# Split the frames into `n_repeats` consecutive chunks and average them element-wise.
avg_traces = (
    traces[:, :n_repeats * seg_len]
    .reshape(n_rois, n_repeats, seg_len)
    .mean(axis=1, dtype=np.float64)
)
print(np.shape(avg_traces))
assert avg_traces.shape[1] == pcs2.shape[1], "averaged traces do not match the PC length"

# Rois
rois = fl.load(fish_dir / "merged_rois.h5")["stack"]


In [None]:
# Project every ROI's trial-averaged trace onto the temporal PCs.
# NOTE: this is a plain projection. Unlike `time_pca.transform(avg_traces)`, it
# does not subtract the PCA mean (`time_pca.mean_`) first.
transformed_data = pcs2 @ avg_traces.T  # shape (n_comp, n_rois)

In [None]:
# Sanity-check the shapes involved in the projection, then keep the scores
# with one row per ROI and one column per PC.
for array in (avg_traces, pcs2, transformed_data):
    print(np.shape(array))
labels = transformed_data.T

In [None]:
# Paint every ROI's first three PC scores into a volume shaped like the ROI stack,
# with a trailing channel axis (PC1, PC2, PC3).
# ROI label k (1-based) corresponds to row k-1 of `labels`. Voxels with label 0,
# or with a label outside 1..num_rois, stay 0.
# NOTE(review): assumes integer ROI labels; non-integer float labels would be truncated.
roi_map = np.copy(rois)
num_rois = np.shape(traces)[0]
roi_map_pc123 = np.zeros(np.shape(roi_map) + (3,))

# One vectorized lookup instead of a full-volume np.where per ROI.
in_roi = (roi_map >= 1) & (roi_map <= num_rois)
roi_map_pc123[in_roi] = labels[roi_map[in_roi].astype(int) - 1, :3]


In [None]:
# Per-plane maps of the PC1-PC3 scores for the first `n_rows` planes of the stack
# (one row per plane, one column per PC).
# All planes of a PC share one zero-centered color range, so the diverging
# 'coolwarm' colormap reads the same in every panel (white = 0).
# Previously each panel was autoscaled, so its colors were not comparable.
n_rows = 9
fig6, ax6 = plt.subplots(n_rows, 3, figsize=(12, 20))

title_list = ['PC1', 'PC2', 'PC3']
v_lim = np.abs(roi_map_pc123).max(axis=(0, 1, 2))  # one symmetric limit per PC
# Stacks with fewer planes than rows would otherwise raise IndexError.
n_shown = min(n_rows, np.shape(roi_map_pc123)[0])

for i in range(n_rows):
    for j in range(3):
        ax6[i, j].axis('off')
        if i >= n_shown:
            continue
        roi_layer = roi_map_pc123[i, :, :, j]
        # Mask zero voxels (outside every ROI) so they render transparent.
        roi_layer = np.ma.masked_where(roi_layer == 0, roi_layer)
        roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
        ax6[i, j].imshow(roi_layer, cmap='coolwarm', vmin=-v_lim[j], vmax=v_lim[j])
for j, title in enumerate(title_list):
    ax6[0, j].set_title(title)

fig6.suptitle(fish_id)
file_name = 'rois_group_pc123_' + fish_id + '.jpg'
fig6.savefig(str(fish_dir / file_name), dpi=300)  # save before showing
plt.show()

In [None]:
# Overlay up to 16 planes of the stack into one map per PC. Planes are drawn in
# order and zero (non-ROI) voxels are masked, so a later plane only paints over
# earlier ones where it contains ROIs.
fig7, ax7 = plt.subplots(1, 3, figsize=(12, 5))

# Was hard-coded to 16 (IndexError on shorter stacks); still capped at 16.
num_planes = min(16, np.shape(roi_map_pc123)[0])
# Symmetric, zero-centered color range per PC. The old limits (-max, +max)
# clipped negative scores whose magnitude exceeded the positive maximum.
v_lim = np.abs(roi_map_pc123).max(axis=(0, 1, 2))
title_list = ['PC1', 'PC2', 'PC3']

for j in range(3):
    for i in range(num_planes):
        roi_layer = roi_map_pc123[i, :, :, j]
        roi_layer = np.ma.masked_where(roi_layer == 0, roi_layer)
        roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
        ax7[j].imshow(roi_layer, cmap='coolwarm', alpha=1, vmin=-v_lim[j], vmax=v_lim[j])
    ax7[j].set_title(title_list[j])
    ax7[j].axis('off')

fig7.suptitle(fish_id)
file_name = 'rois_overlay_pc123_' + fish_id + '.jpg'
fig7.savefig(str(fish_dir / file_name))  # save before showing
plt.show()