In [None]:
# Interactive (ipympl) matplotlib figures; the comment sits on its own line
# because text after a line magic may be parsed as a magic argument.
%matplotlib widget

In [None]:
import numpy as np
import flammkuchen as fl
import napari
from pathlib import Path

from split_dataset import SplitDataset
import json

import flammkuchen as fl 
import matplotlib.pyplot as plt

In [None]:
def exp_decay_kernel(tau, dt, len_rec, upsample=1):
    """Return a unit-sum exponential decay kernel ``exp(-t / tau)``.

    The kernel is sampled at ``len_rec * upsample`` points spaced
    ``dt / upsample`` apart and normalised so that its samples sum to 1.

    Parameters
    ----------
    tau : float
        Decay time constant, in the same time units as ``dt``.
    dt : float
        Sampling interval of the (non-upsampled) recording.
    len_rec : int
        Number of samples of the (non-upsampled) recording.
    upsample : int, optional
        Temporal upsampling factor (default 1).

    Returns
    -------
    numpy.ndarray
        1-D array of length ``len_rec * upsample`` whose samples sum to 1.
    """
    # `upsample` comes after the required parameters: a defaulted parameter
    # placed before non-defaulted ones is a SyntaxError in Python.
    t = np.arange(len_rec * upsample) * dt / upsample
    decay = np.exp(-t / tau)
    return decay / decay.sum()

In [None]:
# Experiment folder; each fish is a sub-entry whose name matches "*f*".
master_path = Path(r"\\funes\Shared\Elena\2022_11_17_ls_trb_red_lp650nm_3\07")
# master_path = Path(r"\\funes\Shared\Elena\2022_11_22_ls_trb_red_lp650nm_4\07")

# Path.glob yields entries in arbitrary, filesystem-dependent order, so sort
# before indexing to make the fish selection reproducible across machines.
fish_list = sorted(master_path.glob("*f*"))
FISH_INDEX = 7
path = fish_list[FISH_INDEX]
print(path)


In [None]:
# Load the Suite2p-extracted ROI traces and ROI coordinates for the selected fish.
suite2p_file = path / "data_from_suite2p_cells_brain.h5"
suite2p_data = fl.load(suite2p_file)
traces, coords = suite2p_data['traces'], suite2p_data['coords']
np.shape(traces)


In [None]:
# z-score every ROI trace over time (NaN-aware). Rows are ROIs, columns are time points.
# The result is derived from the raw Suite2p traces rather than from `traces` itself,
# so re-running this cell is idempotent instead of re-normalising its own output.
raw_traces = np.asarray(suite2p_data['traces'])
trace_mean = np.nanmean(raw_traces, axis=1, keepdims=True)
trace_sd = np.nanstd(raw_traces, axis=1, keepdims=True)
# NOTE: a trace with zero variance gives 0/0 -> NaN here (numpy emits a RuntimeWarning).
traces = (raw_traces - trace_mean) / trace_sd

In [None]:
fs = 2  # sampling rate (kernel step is dt = 1 / fs); presumably Hz — confirm
num_traces, len_rec = np.shape(traces)  # (number of ROIs, number of time points)

# Stimulus time course per side; [0] takes the first element of what fl.load returns.
stim_right = fl.load(path / "stimulus_right.h5")[0]
stim_left = fl.load(path / "stimulus_left.h5")[0]

# Exponential calcium kernel (tau in the same time units as dt), truncated to 500 samples.
CA_TAU = 1.8
CA_KERNEL_LEN = 500
ca_kernel = exp_decay_kernel(tau=CA_TAU, dt=1 / fs, len_rec=len_rec)[:CA_KERNEL_LEN]

In [None]:
# Causal convolution of the right-side stimulus with the calcium kernel. The full
# convolution is len(stim) + len(kernel) - 1 long; keeping its first len(stim) samples
# keeps the regressor on the stimulus time base. Slicing with [:len(stim)] rather than
# [:-(len(kernel) - 1)] also works for a length-1 kernel, where [:-0] would be empty.
right_conv = np.convolve(stim_right, ca_kernel, mode='full')[:len(stim_right)]

In [None]:
# Causal convolution of the left-side stimulus with the calcium kernel. The full
# convolution is len(stim) + len(kernel) - 1 long; keeping its first len(stim) samples
# keeps the regressor on the stimulus time base. Slicing with [:len(stim)] rather than
# [:-(len(kernel) - 1)] also works for a length-1 kernel, where [:-0] would be empty.
left_conv = np.convolve(stim_left, ca_kernel, mode='full')[:len(stim_left)]

In [None]:
# Overlay the two convolved stimulus regressors as a sanity check; the legend makes
# the two otherwise identically styled lines distinguishable.
fig, ax_reg = plt.subplots()
ax_reg.plot(left_conv, label='left')
ax_reg.plot(right_conv, label='right')
ax_reg.set(xlabel='sample', ylabel='convolved stimulus (a.u.)', title='Stimulus regressors')
ax_reg.legend();

In [None]:
def corr_with_regressor(trace_matrix, regressor):
    """Pearson correlation of each row of ``trace_matrix`` with ``regressor``.

    Parameters
    ----------
    trace_matrix : array_like, shape (n_rois, n_timepoints)
        One trace per row.
    regressor : array_like, shape (n_timepoints,)
        Regressor sampled on the same time base as the traces.

    Returns
    -------
    numpy.ndarray, shape (n_rois,)
        Correlation coefficients in [-1, 1]. A zero-variance row (or regressor)
        yields NaN, and NaNs in the inputs propagate.
    """
    trace_matrix = np.asarray(trace_matrix, dtype=float)
    regressor = np.asarray(regressor, dtype=float)
    # Centring both sides removes the mean-correction term. That term must scale
    # with the number of time points, not the number of ROIs, and the normalisation
    # must match the std convention; both are handled exactly here.
    traces_c = trace_matrix - trace_matrix.mean(axis=1, keepdims=True)
    regressor_c = regressor - regressor.mean()
    numerator = traces_c @ regressor_c
    denominator = np.sqrt((traces_c ** 2).sum(axis=1) * (regressor_c ** 2).sum())
    return numerator / denominator


right_traces = corr_with_regressor(traces, right_conv)
left_traces = corr_with_regressor(traces, left_conv)

In [None]:
# Directionality index per ROI: correlation with the right regressor minus correlation
# with the left one (positive -> more right-correlated, negative -> more left-correlated).
directionality_i = np.subtract(right_traces, left_traces)

In [None]:
# Three orthogonal projections of the ROI coordinates, coloured by directionality index
# (colour scale clipped to +/-0.8). The top-left panel of the 2x2 grid stays empty.
fig, ax = plt.subplots(
    2, 2,
    figsize=(6, 6),
    gridspec_kw={'width_ratios': [1, 2], 'height_ratios': [1, 2]},
)
ax[0, 0].axis('off')

di_scatter_kw = dict(c=directionality_i, cmap='PiYG', alpha=1, vmin=-0.8, vmax=0.8, s=2)
# (panel, coords column on x, coords column on y)
for panel, col_x, col_y in ((ax[0, 1], 2, 0), (ax[1, 1], 2, 1), (ax[1, 0], 0, 1)):
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)
    panel.scatter(coords[:, col_x], coords[:, col_y], **di_scatter_kw)


In [None]:
# Export the directionality-index figure into the fish folder at 300 dpi.
file_name = "directionality index.jpg"
fig.savefig(path.joinpath(file_name), dpi=300)

In [None]:
# Same three projections, coloured by each ROI's correlation with the left regressor.
fig_r, ax_r = plt.subplots(
    2, 2,
    figsize=(6, 6),
    gridspec_kw={'width_ratios': [1, 2], 'height_ratios': [1, 2]},
)
ax_r[0, 0].axis('off')

corr_scatter_kw = dict(c=left_traces, cmap='coolwarm', alpha=0.7, s=2, vmin=-0.8, vmax=0.8)
# (panel, coords column on x, coords column on y)
for panel, col_x, col_y in ((ax_r[0, 1], 2, 0), (ax_r[1, 1], 2, 1), (ax_r[1, 0], 0, 1)):
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)
    panel.scatter(coords[:, col_x], coords[:, col_y], **corr_scatter_kw)

In [None]:
# Export the left-regressor correlation figure into the fish folder at 300 dpi.
file_name = "corr with left regressor.jpg"
fig_r.savefig(path.joinpath(file_name), dpi=300)

In [None]:
# Read the acquisition metadata to recover the spacing between imaging planes.
dir_path = path / 'beh'
# Sorted so the choice is deterministic if several files match; fail with a clear
# message instead of a bare StopIteration when none does.
metadata_files = sorted(dir_path.glob("*metadata.json"))
if not metadata_files:
    raise FileNotFoundError(f"No '*metadata.json' file found in {dir_path}")
with open(metadata_files[0], "r") as f:
    metadata = json.load(f)
lsconfig = metadata["imaging"]["microscope_config"]['lightsheet']['scanning']
z_tot_span = lsconfig["z"]["piezo_max"] - lsconfig["z"]["piezo_min"]
n_planes = lsconfig["triggering"]["n_planes"]
# NOTE(review): span / n_planes assumes each plane owns an equal slab of the range; if
# the planes sit on both end points the spacing is span / (n_planes - 1) — confirm.
z_res = z_tot_span / n_planes

In [None]:
# ROI indices ordered by increasing |directionality index|. Scatter draws points in
# array order, so plotting in this order puts the most strongly tuned ROIs on top.
mp_ind = np.abs(directionality_i).argsort()
mp_ind

In [None]:
# Scaled projections of the ROI coordinates, coloured by directionality index and drawn
# in `mp_ind` order so the strongest |index| values end up on top.
fig_mp, ax_mp = plt.subplots(
    2, 2,
    figsize=(6, 6),
    gridspec_kw={'width_ratios': [1, 2], 'height_ratios': [1, 2]},
)
ax_mp[0, 0].axis('off')

# Per-column scale: column 0 by the plane spacing z_res, columns 1 and 2 by 0.6
# (presumably the in-plane pixel size in um — confirm).
coord_scale = (z_res, 0.6, 0.6)
mp_scatter_kw = dict(c=directionality_i[mp_ind], cmap='PiYG', alpha=1, vmin=-2, vmax=2, s=2)
# (panel, coords column on x, coords column on y)
for panel, col_x, col_y in ((ax_mp[0, 1], 2, 0), (ax_mp[1, 1], 2, 1), (ax_mp[1, 0], 0, 1)):
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)
    panel.scatter(
        coords[mp_ind, col_x] * coord_scale[col_x],
        coords[mp_ind, col_y] * coord_scale[col_y],
        **mp_scatter_kw,
    )


In [None]:
# Export the scaled projection figure (with axes) at 300 dpi.
file_name = "directionality index max proj.png"
fig_mp.savefig(path.joinpath(file_name), dpi=300)

In [None]:
# Hide the axes of the three projection panels for an axis-free export.
for panel in (ax_mp[0, 1], ax_mp[1, 0], ax_mp[1, 1]):
    panel.axis('off')

In [None]:
# Export the axis-free version of the projection figure at 300 dpi.
file_name = "directionality index max proj no axes.png"
fig_mp.savefig(path.joinpath(file_name), dpi=300)