In [None]:
import matplotlib
%matplotlib widget
import matplotlib.pyplot as plt
import h5py
import numpy as np
from pathlib import Path
import skimage as sk
from skimage import io as skio
import json
import pandas as pd

import flammkuchen as fl
from split_dataset import SplitDataset
from bouterin.plots.stimulus_log_plot import get_paint_function
from sklearn.decomposition import PCA
from fimpylab.core.twop_experiment import TwoPExperiment


In [None]:
# Root folder of the experiment and the fish recorded in it; `fish_dir`
# selects the single fish analysed in the rest of this notebook.
master = Path(r"J:\_Shared\experiments\E0040_motions_cardinal\v14_cw_ccw")
fish_list = [
    "200826_f1",
    "200827_f0",
    "200917_f0",
    "200918_f0",
]
fish_dir = master.joinpath(fish_list[3])  # -> 200918_f0

In [None]:
# Load the fluorescence traces and z-score every ROI over time.
# NOTE(review): the stored array is presumably (n_rois, n_frames) — later cells
# treat rows of `traces` as ROIs; confirm against the file.
raw_traces = fl.load(fish_dir / "traces.h5")["traces"]
trace_mean = raw_traces.mean(axis=1, keepdims=True)
trace_std = raw_traces.std(axis=1, keepdims=True)
# A constant (e.g. all-zero) trace has std 0; dividing by it would put NaNs
# into the data and make the PCA below fail, so such ROIs are left at 0.
trace_std[trace_std == 0] = 1
traces = (raw_traces - trace_mean) / trace_std

# ROI label volume: voxels labelled i + 1 belong to ROI i (row i of `traces`).
rois = fl.load(fish_dir / "merged_rois.h5")["stack"]

# Drop the last 2 frames — presumably so the frame count splits evenly into
# the 3 repetitions averaged in the next cell (TODO confirm).
traces = traces[:, :-2]

In [None]:

# Average the z-scored traces over 3 consecutive, equally long blocks of the
# recording (presumably 3 repetitions of the stimulus protocol).
n_reps = 3
print(np.shape(traces))
n_rois, n_frames = traces.shape
rep_len = n_frames // n_reps
# Block k spans frames [k * rep_len, (k + 1) * rep_len); leftover frames at the
# end are ignored. Computing the start as k * n_frames // n_reps instead would
# shift later blocks by a frame whenever n_frames is not a multiple of n_reps.
avg_traces = (
    traces[:, : n_reps * rep_len]
    .reshape(n_rois, n_reps, rep_len)
    .mean(axis=1, dtype=np.float64)
)
print(np.shape(avg_traces))


In [None]:
# Heat map of the repetition-averaged traces: one row per ROI, one column per frame.
fig0, ax0 = plt.subplots(1, 1)
# aspect="auto": with the default square pixels an (n_rois x n_frames) image
# whose sides differ a lot renders as a thin strip.
im0 = ax0.imshow(avg_traces, aspect="auto")
ax0.set_xlabel("Frame (within repetition)")
ax0.set_ylabel("ROI")
ax0.set_title("Repetition-averaged z-scored traces")
fig0.colorbar(im0, ax=ax0);

In [None]:
# Population PCA: samples are time points and features are ROIs, so each
# component in `pcs` is a weighting over ROIs.
n_comp = 5
# Fixed seed: with the default 'auto' solver sklearn may choose a randomized
# SVD for data of this size, whose result would otherwise vary between runs.
pop_pca = PCA(n_components=n_comp, random_state=0)
# fit_transform's second positional argument is the (ignored) `y`, so none is passed.
transformed_data = pop_pca.fit_transform(avg_traces.T)
pcs = pop_pca.components_
print(np.shape(pcs))

In [None]:
# Population PCs (ROI loadings), stacked with a vertical offset of 0.5, next to
# their explained-variance ratios.
fig1, ax1 = plt.subplots(1, 2)
for i, pc in enumerate(pcs):
    ax1[0].plot(pc + i * 0.5)
ax1[0].set_title("PCs")
ax1[0].set_xlabel("ROI")

explained = pop_pca.explained_variance_ratio_
ax1[1].bar(np.arange(len(explained)), explained)
ax1[1].set_title("explained variance")
ax1[1].set_xlabel("component")

In [None]:
fig1.savefig(fish_dir / "pca_pop_avg_210304.jpg")  # population-PCA figure, saved next to the fish data

In [None]:
# Temporal PCA: samples are ROIs and features are time points, so each
# component in `pcs2` is a time course and `transformed_data2` holds one
# score vector per ROI.
n_comp = 5
# Fixed seed: with the default 'auto' solver sklearn may choose a randomized
# SVD for data of this size, whose result would otherwise vary between runs.
time_pca = PCA(n_components=n_comp, random_state=0)
# fit_transform's second positional argument is the (ignored) `y`, so none is passed.
transformed_data2 = time_pca.fit_transform(avg_traces)
pcs2 = time_pca.components_

print(np.shape(pcs2))

In [None]:
# Temporal PCs over one (averaged) repetition, with the stimulus epochs shaded
# behind them, plus the explained-variance ratio of each component.
exp = TwoPExperiment(path=fish_dir)

# NOTE(review): 3 is assumed to be the imaging rate in Hz (frames -> seconds);
# confirm. The axis length follows the PCs instead of a hardcoded 720 frames.
frame_rate = 3
t = np.arange(pcs2.shape[1]) / frame_rate

fig3, ax3 = plt.subplots(1, 2)
explained = time_pca.explained_variance_ratio_
ax3[1].bar(np.arange(len(explained)), explained)

for i, pc in enumerate(pcs2):
    ax3[0].plot(t, pc + i * 0.25)
ax3[0].set_title("PCs")
ax3[0].set_xlabel("Time (sec)")
ax3[1].set_title("explained variance")

stimulus_log = exp.load_session_log(log_name='stimulus_log', session_idx=0)
stim_value, t_values = get_paint_function(stimulus_log, 'E0040_motions_cardinal')
stim_value = stim_value / 255
# Only the first third of the stimulus entries is shaded — presumably one
# repetition, matching the 3-fold averaging of the traces.
num_stim = np.shape(stim_value)[0] // 3

for i in range(num_stim):
    ax3[0].axvspan(
        t_values[i, 0],
        t_values[i, 1],
        facecolor=stim_value[i, :3],  # first three columns used as an RGB colour
        alpha=0.5,
    )

In [None]:
fig3.savefig(fish_dir / "pca_time_avg_210304.jpg")  # temporal-PCA figure, saved next to the fish data

In [None]:
### Coloring ROIs by their scores on the first three temporal PCs (PC1, PC2 and PC3):

In [None]:
# Paint each ROI's scores on the first three temporal PCs back into the ROI volume.
# Voxels labelled i + 1 in `rois` belong to ROI i (row i of `transformed_data2`);
# every other voxel stays 0.
roi_map = np.copy(rois)
num_rois = np.shape(traces)[0]
n_painted = 3

# Lookup table: row 0 is background, row i + 1 holds the scores of ROI i. It is
# float, so scores are not truncated to the (possibly integer) label dtype that
# np.zeros_like(roi_map) would inherit.
score_lut = np.zeros((num_rois + 1, n_painted))
score_lut[1:] = transformed_data2[:, :n_painted]

# Send every voxel to its LUT row in one pass (instead of one full-volume scan
# per ROI); labels outside 1..num_rois go to the background row.
# NOTE(review): assumes integer-valued labels.
in_range = (roi_map >= 1) & (roi_map <= num_rois)
lut_rows = np.where(in_range, roi_map, 0).astype(np.intp)

roi_map_pc123 = score_lut[lut_rows]
# Independent copies, so editing one single-PC map never changes roi_map_pc123.
roi_map_pc1 = roi_map_pc123[..., 0].copy()
roi_map_pc2 = roi_map_pc123[..., 1].copy()
roi_map_pc3 = roi_map_pc123[..., 2].copy()

In [None]:
# PC1 score of every ROI, one panel per imaging plane (planes rotated for display).
n_rows, n_cols = 4, 4
fig4, ax4 = plt.subplots(n_rows, n_cols, figsize=(12, 12))
cm_roi = 'coolwarm'
# Never index past the last plane if the stack has fewer planes than panels.
num_planes = min(n_rows * n_cols, np.shape(roi_map_pc1)[0])

# One symmetric colour range for all panels: the diverging colormap is centred
# on 0, and the single colorbar is valid for every plane, not just the last one.
pc_lim = np.max(np.abs(roi_map_pc1))

for i, ax in enumerate(ax4.flat):
    ax.axis('off')
    if i >= num_planes:
        continue
    # Hide background voxels (never assigned a score, so exactly 0). Masking
    # `< 1` would also hide every ROI with a negative or small positive score.
    roi_layer = np.ma.masked_equal(roi_map_pc1[i], 0)
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
    im = ax.imshow(roi_layer, cmap=cm_roi, vmin=-pc_lim, vmax=pc_lim)

fig4.colorbar(im, ax=ax4[2, 3])
plt.show()
fig4.savefig(str(fish_dir/'rois_pc1_210304.jpg'))

In [None]:
# PC2 score of every ROI, one panel per imaging plane (planes rotated for display).
n_rows, n_cols = 4, 4
fig5, ax5 = plt.subplots(n_rows, n_cols, figsize=(12, 12))
cm_roi = 'coolwarm'
# Never index past the last plane if the stack has fewer planes than panels.
num_planes = min(n_rows * n_cols, np.shape(roi_map_pc2)[0])

# One symmetric colour range for all panels: the diverging colormap is centred
# on 0, and the single colorbar is valid for every plane, not just the last one.
pc_lim = np.max(np.abs(roi_map_pc2))

for i, ax in enumerate(ax5.flat):
    ax.axis('off')
    if i >= num_planes:
        continue
    # Hide background voxels (never assigned a score, so exactly 0). Masking
    # `< 1` would also hide every ROI with a negative or small positive score.
    roi_layer = np.ma.masked_equal(roi_map_pc2[i], 0)
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
    im = ax.imshow(roi_layer, cmap=cm_roi, vmin=-pc_lim, vmax=pc_lim)

fig5.colorbar(im, ax=ax5[2, 3])
plt.show()
fig5.savefig(str(fish_dir/'rois_pc2_210304.jpg'))

In [None]:
# Overlay of the PC1 (Reds), PC2 (Greens) and PC3 (Blues) score maps for every
# imaging plane. Each layer only shows voxels scoring at least `score_threshold`
# on that PC; later layers are drawn on top of earlier ones.
n_rows, n_cols = 4, 4
fig6, ax6 = plt.subplots(n_rows, n_cols, figsize=(12, 12))
# Never index past the last plane if the stack has fewer planes than panels.
num_planes = min(n_rows * n_cols, np.shape(roi_map_pc1)[0])
score_threshold = 1

# Whole-volume limits per PC, so every plane shares one colour scale per PC.
min_val1 = np.min(roi_map_pc1)
max_val1 = np.max(roi_map_pc1)
min_val2 = np.min(roi_map_pc2)
max_val2 = np.max(roi_map_pc2)
min_val3 = np.min(roi_map_pc3)
max_val3 = np.max(roi_map_pc3)
pc_layers = [
    (roi_map_pc1, 'Reds', min_val1, max_val1),
    (roi_map_pc2, 'Greens', min_val2, max_val2),
    (roi_map_pc3, 'Blues', min_val3, max_val3),
]

layer_ims = [None] * len(pc_layers)
for i, ax in enumerate(ax6.flat):
    ax.axis('off')
    if i >= num_planes:
        continue
    for k, (pc_map, cmap, vmin, vmax) in enumerate(pc_layers):
        # Limits are passed as vmin/vmax instead of sentinel pixels: writing into
        # pc_map[i] (a view) would permanently alter roi_map_pc1..3.
        roi_layer = np.ma.masked_where(pc_map[i] < score_threshold, pc_map[i])
        roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
        layer_ims[k] = ax.imshow(roi_layer, cmap=cmap, vmin=vmin, vmax=vmax, alpha=1)

im1, im2, im3 = layer_ims
fig6.colorbar(im1, ax=ax6[0, 3])
fig6.colorbar(im2, ax=ax6[1, 3])
fig6.colorbar(im3, ax=ax6[2, 3])
plt.show()
fig6.savefig(str(fish_dir/'rois_pc123_210304.jpg'))