In [None]:
import matplotlib
%matplotlib widget

import numpy as np
from split_dataset import SplitDataset
from pathlib import Path
import flammkuchen as fl
from tifffile import imread
import matplotlib.pyplot as plt 
from fimpylab.core.lightsheet_experiment import LightsheetExperiment
from bouterin.plots.stimulus_log_plot import get_paint_function

from bouter.utilities import reliability 
from skimage.filters import threshold_otsu
import xarray as xr
from scipy.signal import detrend 

from motions.utilities import stim_vel_dir_dataframe, quantize_directions

from scipy.signal import find_peaks

In [None]:
from scipy.cluster.hierarchy import dendrogram, cut_tree, set_link_color_palette
from matplotlib import cm
from bouterin.plots.stimulus_log_plot import get_paint_function

In [None]:
# Root folder with one entry per fish; this notebook analyses only the first one.
# NOTE(review): absolute network-drive path — consider moving it to a config cell.
master_path = Path(r"Z:\Hagar\E0040\ablations\post\v31")
# Path.glob yields entries in filesystem order, which is not guaranteed to be
# stable; sort so that fish_list[0] is the same fish on every run.
fish_list = sorted(master_path.glob("*f*"))
if not fish_list:
    raise FileNotFoundError(f"No entries matching '*f*' found in {master_path}")
path = fish_list[0]
print(path)

In [None]:
# Load the "/detr" traces (indexed time x ROI) and the suite2p ROI coordinates
# for the selected fish, then the experiment metadata.
traces_file = path / "filtered_traces.h5"
cells_file = path / "data_from_suite2p_cells.h5"
traces = fl.load(traces_file, "/detr")
coords = fl.load(cells_file, "/coords")

exp = LightsheetExperiment(path)
# NOTE(review): int() truncates; if exp.fn is not an integer rate, every sample
# count derived from fs later (e.g. 10 * fs) is slightly off — confirm.
fs = int(exp.fn)

In [None]:
# Keep only the ROIs listed under 'coords_idx' in the *_brain.h5 file
# (presumably the in-brain subset, per the file name).
# NOTE: this overwrites `traces` and `coords`; re-running this cell without
# re-running the loading cell would index already-filtered arrays.
suite2p_brain = fl.load(path / "data_from_suite2p_cells_brain.h5")
in_brain_idx = suite2p_brain['coords_idx']
traces, coords = traces[:, in_brain_idx], coords[in_brain_idx]

In [None]:
# Recording dimensions: `traces` is indexed (time, roi).
len_rec, num_traces = np.shape(traces)
# Time axis in seconds, one entry per time point. It must be built from axis 0
# (time), not axis 1 (the number of ROIs).
t = np.arange(len_rec) / fs
print("num_traces: ", num_traces)
print("len_rec: ", len_rec)
print("sampling rate: ", fs)

In [None]:
# Stimulus regressors; column 0 serves as the "right" regressor and column 4
# as the "left" one.
regs = fl.load(path / "sensory_regressors.h5", "/regressors")
right = np.asarray(regs.iloc[:, 0])
left = np.asarray(regs.iloc[:, 4])

num_traces = np.shape(traces)[1]

# Pearson r between each ROI trace (a column of `traces`) and each regressor.
right_corr = np.array([np.corrcoef(right, traces[:, roi])[0, 1] for roi in range(num_traces)])
left_corr = np.array([np.corrcoef(left, traces[:, roi])[0, 1] for roi in range(num_traces)])

In [None]:
# Select ROIs whose |r| with a regressor exceeds `thresh`; both positively and
# negatively correlated ROIs are kept.
thresh = 0.15

right_tuned = np.flatnonzero(np.abs(right_corr) > thresh)
print(right_tuned.shape)
n_right_tuned = right_tuned.size

left_tuned = np.flatnonzero(np.abs(left_corr) > thresh)
print(left_tuned.shape)
n_left_tuned = left_tuned.size

In [None]:
# Traces of the selected ROIs, transposed so rows are ROIs and columns are time.
right_traces = traces[:, right_tuned].T
left_traces = traces[:, left_tuned].T
print(left_traces.shape)

In [None]:
#### Trial windows around every left/right stimulus onset.
# Per the original note, regressor columns run 0 = right ... 7 = right-up.
pause_duration = 10 * fs  # samples kept before each onset (10 s if fs is in Hz)
stim_duration = 10 * fs   # samples kept after each onset

left_diff = np.diff(left)
right_diff = np.diff(right)

# Onsets = peaks (height > 0.1) in the first difference of each regressor,
# i.e. its rising steps. Computed once per regressor and reused below.
# NOTE(review): np.diff index i is the step between samples i and i+1, so the
# first "on" sample is arguably peak + 1 — confirm whether the one-sample
# offset matters.
left_onsets = find_peaks(left_diff, height=0.1)[0]
right_onsets = find_peaks(right_diff, height=0.1)[0]

left_start = left_onsets - pause_duration
left_end = left_onsets + stim_duration
right_start = right_onsets - pause_duration
right_end = right_onsets + stim_duration

In [None]:
# Trial tensors indexed (roi, direction, session, time sample).
n_dir = 8       # number of direction regressor columns used as trial labels
n_sessions = 4  # max trials stored per direction; further trials are dropped
num_left_trials = np.shape(left_start)[0]
num_right_trials = np.shape(right_start)[0]
len_segment = pause_duration + stim_duration
print(len_segment)

# Pre-fill with NaN rather than 0: sessions that never receive a trial are then
# ignored by the np.nanmean over sessions instead of pulling the average toward
# zero, and a genuine 0.0 sample can never be mistaken for a placeholder.
left_trials = np.full((n_left_tuned, n_dir, n_sessions, len_segment), np.nan)
right_trials = np.full((n_right_tuned, n_dir, n_sessions, len_segment), np.nan)

In [None]:
# Sort each left-stimulus trial by the direction shown in its "history" window:
# the stim_duration-long regressor window that ends at the trial's segment start
# (i.e. pause_duration before the onset).
regs_array = np.asarray(regs)
curr_session = np.zeros(n_dir, dtype=int)
n_recorded = left_traces.shape[1]
for i in range(num_left_trials):
    seg_start = left_start[i]
    seg_end = seg_start + len_segment
    hist_start = seg_start - stim_duration
    hist_end = seg_start

    # Explicit bounds check instead of a bare except: a negative index would
    # silently wrap to the end of the recording, and a window running past the
    # end gives a short slice that cannot be stored.
    if hist_start < 0 or seg_end > n_recorded:
        print(f"Skipping left trial {i}: window outside the recording")
        continue

    hist_mean = np.nanmean(regs_array[hist_start:hist_end], axis=0)
    # First direction column whose mean regressor value exceeds 0.1.
    active = np.flatnonzero(hist_mean[:n_dir] > 0.1)
    if active.size == 0:
        print(f"Skipping left trial {i}: no direction regressor > 0.1 in history window")
        continue
    curr_dir = active[0]

    if curr_session[curr_dir] < n_sessions:
        left_trials[:, curr_dir, curr_session[curr_dir], :] = left_traces[:, seg_start:seg_end]
        curr_session[curr_dir] += 1

# Sessions that never received a trial are placeholders; mark them NaN so the
# np.nanmean over sessions ignores them (same convention as right_trials).
for d in range(n_dir):
    left_trials[:, d, curr_session[d]:, :] = np.nan

print("left trials stored per direction:", curr_session)

In [None]:
# Same sorting as for left trials: label each right-stimulus trial by the
# direction active in the stim_duration-long window ending at its segment start.
curr_session = np.zeros(n_dir, dtype=int)
n_recorded = right_traces.shape[1]
for i in range(num_right_trials):
    seg_start = right_start[i]
    seg_end = seg_start + len_segment
    hist_start = seg_start - stim_duration
    hist_end = seg_start

    # Explicit bounds check instead of a bare except: negative indices would
    # wrap to the end of the recording; overruns give short slices.
    if hist_start < 0 or seg_end > n_recorded:
        print(f"Skipping right trial {i}: window outside the recording")
        continue

    hist_mean = np.nanmean(regs_array[hist_start:hist_end], axis=0)
    # First direction column whose mean regressor value exceeds 0.1.
    active = np.flatnonzero(hist_mean[:n_dir] > 0.1)
    if active.size == 0:
        print(f"Skipping right trial {i}: no direction regressor > 0.1 in history window")
        continue
    curr_dir = active[0]

    if curr_session[curr_dir] < n_sessions:
        right_trials[:, curr_dir, curr_session[curr_dir], :] = right_traces[:, seg_start:seg_end]
        curr_session[curr_dir] += 1

# Mark only the sessions that were never filled as NaN. Using the session counts
# instead of `right_trials == 0` avoids turning genuine 0.0 samples into NaN.
for d in range(n_dir):
    right_trials[:, d, curr_session[d]:, :] = np.nan

print("right trials stored per direction:", curr_session)

In [None]:
####### Average responses over sessions, then compute per-ROI reliability

# NaN placeholders (unfilled sessions) are ignored by nanmean; a direction with
# no stored session at all yields NaN (with a RuntimeWarning).
left_trials_avg = np.nanmean(left_trials, axis=2)
right_trials_avg = np.nanmean(right_trials, axis=2)
print(np.shape(left_trials_avg))

dt = 1 / fs


def _history_reliability(trials_avg):
    """Return bouter's `reliability` for session-averaged responses.

    `trials_avg` is indexed (roi, direction, time sample). It is wrapped in a
    labelled DataArray and passed to `reliability` with axes ordered
    (t, block, roi), the same order as the original np.swapaxes(..., 0, 2).
    """
    traces_xr = xr.DataArray(
        data=trials_avg,
        dims=['roi', 'block', 't'],
        coords={
            'roi': np.arange(trials_avg.shape[0]),
            'block': np.arange(n_dir),
            't': np.arange(trials_avg.shape[2]) * dt,  # seconds, if fs is in Hz
        },
    )
    # Use xarray's own transpose rather than np.swapaxes on a DataArray, which
    # relies on numpy's implicit __array_wrap__ fallback.
    return reliability(traces_xr.transpose('t', 'block', 'roi').values)


reliability_arr_left = _history_reliability(left_trials_avg)
reliability_arr_right = _history_reliability(right_trials_avg)

In [None]:
# ROI coordinates of the selected cells, in the same order as *_tuned / *_traces.
coords_left, coords_right = coords[left_tuned], coords[right_tuned]

In [None]:
# Anatomical maps of tuned ROIs coloured by reliability (fixed 0-1 'Reds' scale):
# left-tuned ROIs in grid columns 0-1, right-tuned ROIs in columns 2-3.
# NOTE(review): scale factors — xy_res for coords columns 1 and 2, z_res for
# column 0 — are presumably µm per pixel / per plane; confirm.
z_res = 10
xy_res = 0.6

fig, axs = plt.subplots(
    2, 4, figsize=(12, 6),
    gridspec_kw={'width_ratios': [3, 2, 3, 2], 'height_ratios': [1, 2]},
)


def _plot_reliability_maps(ax_top, ax_main, ax_side, tuned_coords, tuned_reliability):
    """Draw all in-brain ROIs in grey and the tuned ROIs coloured on top.

    ax_main: coords[:, 2] vs coords[:, 1]; ax_side: coords[:, 0] vs coords[:, 1];
    ax_top: coords[:, 2] vs coords[:, 0]. Top/right spines are hidden.
    """
    ax_main.scatter(coords[:, 2] * xy_res, coords[:, 1] * xy_res, c='lightgray', s=2, alpha=0.8)
    ax_side.scatter(coords[:, 0] * z_res, coords[:, 1] * xy_res, c='lightgray', alpha=0.8)
    ax_top.scatter(coords[:, 2] * xy_res, coords[:, 0] * z_res, c='lightgray', alpha=0.8)

    colour_kw = dict(c=tuned_reliability, alpha=0.8, cmap='Reds', vmin=0, vmax=1)
    ax_main.scatter(tuned_coords[:, 2] * xy_res, tuned_coords[:, 1] * xy_res, s=2, **colour_kw)
    ax_side.scatter(tuned_coords[:, 0] * z_res, tuned_coords[:, 1] * xy_res, **colour_kw)
    ax_top.scatter(tuned_coords[:, 2] * xy_res, tuned_coords[:, 0] * z_res, **colour_kw)

    for ax in (ax_top, ax_main, ax_side):
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)


# Drawing order of the tuned ROIs (later points end up on top). Index order is
# used; substitute np.argsort(reliability_arr_*) to draw the most reliable last.
mp_ind_l = np.arange(n_left_tuned)
mp_ind_r = np.arange(n_right_tuned)

_plot_reliability_maps(axs[0, 0], axs[1, 0], axs[1, 1],
                       coords_left[mp_ind_l], reliability_arr_left[mp_ind_l])
_plot_reliability_maps(axs[0, 2], axs[1, 2], axs[1, 3],
                       coords_right[mp_ind_r], reliability_arr_right[mp_ind_r])

axs[0, 1].axis('off')
axs[0, 3].axis('off')
axs[0, 0].set_title('left-tuned ROIs')
axs[0, 2].set_title('right-tuned ROIs')

In [None]:
# Save the reliability figure into the fish folder; the correlation threshold
# used for ROI selection is part of the file name.
file_name = f'color reliability between different history stimuli (thresh {thresh}).jpg'
out_file = path / file_name
fig.savefig(str(out_file), dpi=300)