In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
from motions.utilities import stim_vel_dir_dataframe, quantize_directions
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import matplotlib.pyplot as plt

from fimpylab.core.lightsheet_experiment import LightsheetExperiment

import json
from pathlib import Path

In [None]:
# make sensory regressors. requires old bouter stimulus_param_log.
def make_sensory_regressors(exp, n_dirs=8, upsampling=5, sampling=1/3):
    """Build one convolved motion regressor per stimulus direction bin.

    For each of ``n_dirs`` direction bins, an indicator trace is built that is
    1 while the stimulus moves (velocity > 0.1) in that bin and 0 otherwise.
    The indicators are linearly interpolated onto a time grid ``upsampling``
    times finer than the imaging sampling. They are then convolved causally
    with a normalized exponential-decay kernel and downsampled back to one
    value per imaging frame.

    Parameters
    ----------
    exp : bouter.Experiment
        Experiment read by ``stim_vel_dir_dataframe``; requires the old bouter
        ``stimulus_param_log`` format.
    n_dirs : int
        Number of direction bins, i.e. number of regressor columns.
    upsampling : int
        Upsampling factor applied before the convolution.
    sampling : float
        Imaging sampling period in seconds (the notebook passes ``1 / fs``).

    Returns
    -------
    pandas.DataFrame
        Rows are frames spaced ``sampling`` apart starting at t = 0; columns
        are ``motion_0`` ... ``motion_{n_dirs - 1}``.
    """
    from scipy.signal import lfilter

    stim = stim_vel_dir_dataframe(exp)
    # Quantize with the same n_dirs used to build the regressors, so the bin
    # indices compared against i_dir below cover exactly n_dirs bins
    # (get_tuning_map also calls quantize_directions with n_dirs).
    bin_centres, dir_bins = quantize_directions(stim.theta, n_dirs)
    ind_regs = np.zeros((n_dirs, len(stim)))
    for i_dir in range(n_dirs):
        ind_regs[i_dir, :] = (np.abs(dir_bins - i_dir) < 0.1) & (stim.vel > 0.1)

    dt_upsampled = sampling / upsampling
    t_imaging_up = np.arange(0, stim.t.values[-1], dt_upsampled)
    reg_up = interp1d(stim.t.values, ind_regs, axis=1, fill_value="extrapolate")(
        t_imaging_up
    )

    # Exponential decay kernel with a 1.5 s half-life (tau = 1.5 / ln 2).
    # "6s" presumably refers to a GCaMP6s-like decay (TODO confirm). The
    # kernel spans the whole upsampled recording and is normalized to unit sum.
    u_steps = t_imaging_up.shape[0]
    tau = 1.5 / np.log(2)
    decay = np.exp(-np.arange(u_steps) * dt_upsampled / tau)
    kernel_sum = np.sum(decay)

    # The kernel is geometric: kernel[k] = r**k / kernel_sum with
    # r = exp(-dt / tau). The causal convolution truncated to u_steps samples
    # is therefore y[n] = x[n] / kernel_sum + r * y[n - 1]. lfilter evaluates
    # this in O(u_steps), instead of the O(u_steps**2) direct convolution
    # with a recording-length kernel.
    r = np.exp(-dt_upsampled / tau)
    convolved = lfilter([1 / kernel_sum], [1, -r], reg_up, axis=1)
    reg_sensory = convolved[:, ::upsampling]

    return pd.DataFrame(reg_sensory.T, columns=[f"motion_{i}" for i in range(n_dirs)])

In [None]:
# calculate directional tuning from zscored traces for each roi
def get_tuning_map(traces, sens_regs, n_dirs=8):
    """Compute a direction-tuning vector for every ROI.

    Each ROI trace is projected onto the per-direction sensory regressors. The
    resulting per-direction weights are summed as 2-D vectors pointing at the
    direction bin centres.

    Parameters
    ----------
    traces : numpy.ndarray
        2-D activity array indexed as (frame, roi). The comment above says the
        traces are z-scored; this is assumed, not checked.
    sens_regs : pandas.DataFrame
        Regressors with one row per frame and one column per direction, as
        returned by ``make_sensory_regressors``.
    n_dirs : int
        Number of direction bins; should match the columns of ``sens_regs``.

    Returns
    -------
    amp : numpy.ndarray
        Length of the tuning vector, one value per ROI.
    angle : numpy.ndarray
        Direction of the tuning vector in radians (``np.arctan2``), per ROI.
    """
    # Regressors span the stimulus log and traces span the imaging session.
    # Use only the frames present in both so the shapes always agree.
    n_t = min(sens_regs.shape[0], traces.shape[0])
    reg = sens_regs.values[:n_t].T @ traces[:n_t, :]  # (n_regressors, n_rois)

    # Unit vectors pointing at the centre of each direction bin: (2, n_dirs).
    bin_centers, _ = quantize_directions([0], n_dirs)
    vectors = np.stack([np.cos(bin_centers), np.sin(bin_centers)], 0)
    reg_vectors = vectors @ reg  # (2, n_rois)

    angle = np.arctan2(reg_vectors[1], reg_vectors[0])
    amp = np.sqrt(np.sum(reg_vectors ** 2, 0))

    return amp, angle

In [None]:
# make a color map
def JCh_to_RGB255(x):
    """Convert colours given in colorspacious "JCh" coordinates to 8-bit sRGB.

    Out-of-gamut components are clipped to [0, 1] before scaling by 255. The
    cast to uint8 truncates rather than rounds.
    """
    srgb = colorspacious.cspace_convert(x, "JCh", "sRGB1")
    srgb_in_gamut = np.clip(srgb, 0, 1)
    scaled = srgb_in_gamut * 255
    return scaled.astype(np.uint8)

def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
    ):
    """Turn per-ROI tuning (amplitude, angle) into RGB colours.

    The amplitude is divided by its ``amp_percentile``-th percentile and
    clipped to [0, 1]. This normalized strength sets:
    - lightness: ``lightness_min + strength * lightness_delta``
    - chroma: ``strength * maxsat``
    Hue is ``hueshift - angle`` (radians) converted to degrees. With the
    default negative ``lightness_delta``, ``lightness_min`` is the lightness
    of the weakest ROIs.

    Returns the uint8 array produced by ``JCh_to_RGB255``, one row per ROI.
    """
    strength = np.clip(amp / np.percentile(amp, amp_percentile), 0, 1)

    jch = np.empty((amp.shape[0], 3))
    jch[:, 0] = lightness_min + strength * lightness_delta  # J: lightness
    jch[:, 1] = strength * maxsat                           # C: chroma
    jch[:, 2] = (-angle + hueshift) * 180 / np.pi           # h: hue in degrees

    return JCh_to_RGB255(jch)

In [None]:
# Data root: every fish lives in a sub-folder whose name matches "*_f*".
#master = Path(r"Z:\Hagar and Ot\E0040\v10\LS ablation\ntr")
master = Path(r'\\portulab.synology.me\data\Hagar and Ot\E0040\v10\LS')

fish_list = [*master.glob("*_f*")]
# The single-fish analysis below uses the first folder returned by glob.
path = fish_list[0]
print(path)


In [None]:
# Behaviour file path, detrended traces and ROI coordinates for the selected fish.
# `exp` is not referenced again in this notebook (Experiment(path) is used instead).
exp = glob(str(path / "*behavior*"))[0]
traces_file = path / "filtered_traces.h5"
rois_file = path / "data_from_suite2p_cells.h5"
traces = fl.load(traces_file, "/detr")
coords = fl.load(rois_file, "/coords")

In [None]:
# Indices of ROIs kept as in-brain; used below to index coords, colors and amp.
brain_file = path / "data_from_suite2p_cells_brain.h5"
suite2p_brain = fl.load(brain_file)
in_brain_idx = suite2p_brain["coords_idx"]

In [None]:
# Imaging rate and frame timing for this fish.
img_exp = LightsheetExperiment(path)
fs = img_exp.fn  # presumably frames (volumes) per second; confirm against LightsheetExperiment
sampling = 1/fs

len_rec, num_cells = np.shape(traces)
# One timestamp per frame, exactly `sampling` apart. np.linspace(0, N*dt, N)
# would include the endpoint and space the samples N*dt/(N-1) apart.
time = np.arange(len_rec) * sampling

In [None]:
# Sensory regressors for the selected fish (also wrapped in a one-element list).
reg_list = [make_sensory_regressors(Experiment(path), sampling=sampling)]
reg = reg_list[0]

In [None]:
# Projection of every trace onto every direction regressor: (n_dirs, n_rois).
# The regressors cover the stimulus log and the traces cover the imaging
# session, so only the frames present in both are used.
n_t = min(reg.shape[0], traces.shape[0])
regi = reg.values[:n_t].T @ traces[:n_t]
np.min(regi)

In [None]:
# Tuning amplitude and angle per ROI, also collected in a two-column table.
amp, angle = get_tuning_map(traces, reg)

df = pd.DataFrame({"amp": amp, "angle": angle})

In [None]:
nan_filt = np.logical_not(np.isnan(amp))  # True where the tuning amplitude is defined

In [None]:
# Colour every ROI, keeping `colors` aligned row-for-row with amp/coords so it
# can be indexed with in_brain_idx below. Only non-NaN ROIs go through
# color_stack; ROIs with NaN tuning are left light grey (220), the same fill
# used later for sub-threshold ROIs.
colors = np.full((amp.shape[0], 3), 220, dtype=np.uint8)
colors[nan_filt] = color_stack(amp[nan_filt], angle[nan_filt])
print(np.unique(colors))

In [None]:
# Restrict coordinates, colours and amplitudes to in-brain ROIs, and scale amp to a max of 1.
coords_ib, colors_ib, amp_ib = (arr[in_brain_idx] for arr in (coords, colors, amp))
amp_norm = np.divide(amp_ib, np.nanmax(amp_ib))
amp_norm.shape

In [None]:
# Axial step between imaging planes, read from the acquisition metadata.
with open(next(Path(path).glob("*metadata.json")), "r") as f:
    metadata = json.load(f)

lsconfig = metadata["imaging"]["microscope_config"]["lightsheet"]["scanning"]
z_scan = lsconfig["z"]
z_tot_span = z_scan["piezo_max"] - z_scan["piezo_min"]
n_planes = lsconfig["triggering"]["n_planes"]
# NOTE(review): span / n_planes; if planes sit on both ends of the span the
# spacing would be span / (n_planes - 1). Confirm with the acquisition setup.
z_res = z_tot_span / n_planes

In [None]:
thresh = 700  # tuning-amplitude threshold, in the units returned by get_tuning_map

below_thresh = amp_ib < thresh

# Amplitudes with sub-threshold ROIs zeroed.
amp_thresh = np.copy(amp_ib)
amp_thresh[below_thresh] = 0

# Colours with sub-threshold ROIs set to light grey (220, 220, 220).
colors_thresh = np.copy(colors_ib)
colors_thresh[below_thresh] = 220

In [None]:
# ROIs above threshold, with their coordinates, colours and amplitudes.
selected_vis = np.where(amp_ib > thresh)[0]
coords_vis = coords_ib[selected_vis]
colors_vis = colors_ib[selected_vis]
# Supra-threshold rows of colors_thresh are unmodified copies of colors_ib, so
# this equals colors_thresh[selected_vis]. Unlike re-slicing colors_thresh
# from itself, it gives the same result when the cell is re-run.
colors_thresh = colors_vis.copy()
amp_vis = amp_ib[selected_vis]

In [None]:
# Left half (columns 0-1): all in-brain ROIs in linen, supra-threshold ROIs
# coloured by tuning on top. Right half (columns 2-3): every in-brain ROI
# coloured by tuning. Coordinate column 0 is scaled by z_res, columns 1-2 by 0.6.
fig3, axs3 = plt.subplots(
    2, 4, figsize=(12, 6),
    gridspec_kw={'width_ratios': [5, 2, 5, 2], 'height_ratios': [1, 3]},
)
mp_ind = np.argsort(amp_vis)  # ascending amplitude: strongest ROIs drawn last

# ((row, left-half column), x coord column, x scale, y coord column, y scale)
projections = [
    ((1, 0), 2, 0.6, 1, 0.6),
    ((1, 1), 0, z_res, 1, 0.6),
    ((0, 0), 2, 0.6, 0, z_res),
]

for (row, col), xi, xs, yi, ys in projections:
    left = axs3[row, col]
    left.scatter(coords_ib[:, xi] * xs, coords_ib[:, yi] * ys, c='linen', s=2, alpha=0.8)
    left.scatter(coords_vis[mp_ind, xi] * xs, coords_vis[mp_ind, yi] * ys,
                 c=colors_thresh[mp_ind] / 255, s=2, alpha=0.8)

    right = axs3[row, col + 2]
    right.scatter(coords_ib[:, xi] * xs, coords_ib[:, yi] * ys, c=colors_ib[:] / 255, s=2, alpha=0.8)

    for ax in (left, right):
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

for ax in (axs3[0, 1], axs3[0, 3]):
    ax.axis('off')

In [None]:
# Save the threshold figure as JPEG and PDF next to the fish data.
for ext in ("jpg", "pdf"):
    file_name = f"tuning_thresh{thresh}_240716.{ext}"
    fig3.savefig(path / file_name, dpi=900)

In [None]:
##### Only ROIs with reliable traces

In [None]:
# Keep ROIs whose reliability score exceeds rel_thresh.
# NOTE(review): selected_vis indexes coords_ib/colors_ib/amp_ib below, so this
# assumes reliable_arr is ordered like the in-brain ROIs. Confirm.
rel_thresh = 0.5
reliable_arr = fl.load(path / "reliable_rois.h5", "/reliability_arr")
selected_vis = np.nonzero(reliable_arr > rel_thresh)[0]

In [None]:
# Coordinates, colours and amplitudes of the selected (reliable) ROIs.
coords_vis, colors_vis, amp_vis = (arr[selected_vis] for arr in (coords_ib, colors_ib, amp_ib))

In [None]:
# Tuning of reliable ROIs in three projections of the suite2p coordinates.
# Coordinate column 0 (the plane axis) is scaled by z_res from the metadata,
# as in the threshold figure; columns 1-2 are scaled by 0.6. All in-brain ROIs
# are drawn in linen underneath, and reliable ROIs on top ordered by amplitude
# (strongest last).
fig2, axs2 = plt.subplots(2, 2, figsize=(4, 4), gridspec_kw={'width_ratios': [8, 2], 'height_ratios': [1, 4]})
mp_ind = np.argsort(amp_vis)
colors_sorted = colors_vis[mp_ind] / 255

# (axes, x coord column, x scale, y coord column, y scale)
projections = [
    (axs2[1, 0], 2, 0.6, 1, 0.6),
    (axs2[1, 1], 0, z_res, 1, 0.6),
    (axs2[0, 0], 2, 0.6, 0, z_res),
]
for ax, xi, xs, yi, ys in projections:
    ax.scatter(coords_ib[:, xi] * xs, coords_ib[:, yi] * ys, c='linen', s=2, alpha=0.8)
    ax.scatter(coords_vis[mp_ind, xi] * xs, coords_vis[mp_ind, yi] * ys, c=colors_sorted, s=2, alpha=0.8)

# Every panel is drawn without axes, so no spine styling is needed. The
# reversed x-limits (500, 0) flip the horizontal axis of the two left panels.
for ax in axs2.flat:
    ax.axis('off')
axs2[1, 0].set_xlim(500, 0)
axs2[0, 0].set_xlim(500, 0)

axs2[1, 0].set_ylim(0, 550)
axs2[1, 1].set_ylim(0, 550)

In [None]:
# Tighten the panel spacing of the reliable-ROI figure explicitly, rather than
# whichever figure pyplot currently considers active.
fig2.subplots_adjust(wspace=0.1, hspace=0.1)

In [None]:
# Save the reliable-ROI figure as PDF and JPEG next to the fish data.
for ext in ("pdf", "jpg"):
    file_name = f"tuning_reliable_b{rel_thresh} 2_240419.{ext}"
    fig2.savefig(path / file_name, dpi=900)


# Morphed datasets: tuning pooled across fish registered to a common reference

In [None]:
# All fish folders ("*_f*") under the data root, pooled in the cells below.
master = Path(r'\\portulab.synology.me\data\Hagar and Ot\E0040\v10\LS')
fish_list = [*master.glob("*_f*")]
fish_list

In [None]:
# Per-fish tuning (amplitude and angle per ROI) is cached in
# master / 'tuning_arrs.h5', keyed by fish folder name. Only fish missing from
# the cache are computed, so a new fish folder does not require deleting the
# cache. An unreadable cache file now raises instead of being silently
# recomputed and overwritten.
tuning_cache = master / 'tuning_arrs.h5'
if tuning_cache.exists():
    tuning_arrs = fl.load(tuning_cache)
    amp_pooled, angle_pooled = tuning_arrs['amp_pooled'], tuning_arrs['angle_pooled']
else:
    amp_pooled, angle_pooled = {}, {}

missing_fish = [fish for fish in fish_list if fish.name not in amp_pooled]
for fish in missing_fish:
    print(fish)

    # Detrended traces and imaging rate for this fish
    traces = fl.load(fish / "filtered_traces.h5", "/detr")
    fs = LightsheetExperiment(fish).fn
    sampling = 1/fs

    # Sensory regressors, then tuning vector per ROI
    reg = make_sensory_regressors(Experiment(fish), sampling=sampling)
    amp_pooled[fish.name], angle_pooled[fish.name] = get_tuning_map(traces, reg)

if missing_fish:
    tuning_arrs = {'angle_pooled': angle_pooled, 'amp_pooled': amp_pooled}
    fl.save(tuning_cache, tuning_arrs)
    print('Done')

In [None]:
# Registered ROI coordinates and an in-brain ROI mask, per fish.
morphed_coords = {}
in_brain_arr = {}

for fish in fish_list:
    print(fish)

    # ROI coordinates transformed to the reference brain
    registration_dir = fish / 'registration' / 'to_h2b_baier_ref' / 'antspy'
    morphed_coords[fish.name] = fl.load(registration_dir / 'mov_coords_transformed.h5')

    # Boolean mask over ROIs: True for the indices listed in "coords_idx".
    suite2p_brain = fl.load(fish / "data_from_suite2p_cells_brain.h5")
    in_brain_mask = np.zeros(morphed_coords[fish.name].shape[0], dtype=bool)
    in_brain_mask[suite2p_brain['coords_idx']] = True
    in_brain_arr[fish.name] = in_brain_mask

print('Done')

In [None]:
# For each fish, the tuning arrays and the morphed coordinates must have one
# entry per ROI in the same order. Otherwise the pooled arrays below would be
# misaligned.
for fish in fish_list:
    n_rois = len(morphed_coords[fish.name])
    assert len(amp_pooled[fish.name]) == len(angle_pooled[fish.name]) == n_rois, (
        f"{fish.name}: {len(amp_pooled[fish.name])} tuning values "
        f"but {n_rois} morphed ROI coordinates"
    )

#Pool amp and angle arrays
amp_pooled_arr = np.concatenate([amp_pooled[fish.name] for fish in fish_list])
angle_pooled_arr = np.concatenate([angle_pooled[fish.name] for fish in fish_list])

#also boolean array to keep track of ROIs in the brain
in_brain_arr_pooled = np.concatenate([in_brain_arr[fish.name] for fish in fish_list])

#and pooled coordinates
coords_pooled = np.concatenate([morphed_coords[fish.name] for fish in fish_list], 0)

In [None]:
# Exclude ROIs with a NaN tuning amplitude. Amplitude and angle are computed by
# get_tuning_map from the same vector, so NaNs are assumed to co-occur in both.
nan_filt = np.isnan(amp_pooled_arr)
print(nan_filt.sum(), ' ROIs excluded')

# Valid ROIs: inside the brain and with a defined tuning amplitude.
valid_rois = in_brain_arr_pooled & ~nan_filt

In [None]:
# Tuning colours for the valid ROIs only (row-aligned with coords_pooled[valid_rois]).
amp_valid = amp_pooled_arr[valid_rois]
angle_valid = angle_pooled_arr[valid_rois]
colors_ib = color_stack(amp_valid, angle_valid)
colors_ib.shape

In [None]:
# Coordinates and amplitudes of the valid ROIs, plus amplitude scaled to a max of 1.
coords_ib, amp_ib = coords_pooled[valid_rois], amp_pooled_arr[valid_rois]
amp_norm = np.divide(amp_ib, np.nanmax(amp_ib))

In [None]:
# Alternative filtering, not executed (every line is commented out). It selects
# ROIs with in_brain_arr_pooled only, i.e. without the NaN exclusion applied
# through valid_rois, and indexes a full-length `colors` array.
# #Filter ROIs and normalize vector amplitude
# coords_ib = coords_pooled[in_brain_arr_pooled]
# colors_ib = colors[in_brain_arr_pooled]
# amp_ib = amp_pooled_arr[in_brain_arr_pooled]
# amp_norm = amp_ib / np.nanmax(amp_ib)

In [None]:
thresh = 700  # tuning-amplitude threshold, in the units returned by get_tuning_map

below_thresh = amp_ib < thresh

# Amplitudes with sub-threshold ROIs zeroed.
amp_thresh = np.copy(amp_ib)
amp_thresh[below_thresh] = 0

# Colours with sub-threshold ROIs set to light grey (220, 220, 220).
colors_thresh = np.copy(colors_ib)
colors_thresh[below_thresh] = 220

In [None]:
# ROIs above threshold, with their coordinates, colours and amplitudes.
selected_vis = np.where(amp_ib > thresh)[0]
coords_vis = coords_ib[selected_vis]
colors_vis = colors_ib[selected_vis]
# Supra-threshold rows of colors_thresh are unmodified copies of colors_ib, so
# this equals colors_thresh[selected_vis]. Unlike re-slicing colors_thresh
# from itself, it gives the same result when the cell is re-run.
colors_thresh = colors_vis.copy()
amp_vis = amp_ib[selected_vis]

In [None]:
# Pooled ROIs plotted with morphed coordinate columns 0 and 1: all valid ROIs
# in linen, supra-threshold ROIs coloured by tuning (strongest drawn last).
fig_pooled, ax_pooled = plt.subplots()

mp_ind = np.argsort(amp_vis)

ax_pooled.scatter(coords_ib[:, 0] * 0.6, coords_ib[:, 1] * 0.6, c='linen', s=.1, alpha=0.8)
ax_pooled.scatter(coords_vis[mp_ind, 0] * 0.6, coords_vis[mp_ind, 1] * 0.6,
                  c=colors_thresh[mp_ind] / 255, s=.1, alpha=0.8)

In [None]:
# Pooled tuning in reference space: all valid ROIs in linen, supra-threshold
# ROIs coloured by tuning (strongest drawn last).
# Morphed coordinates are treated as columns (0, 1, 2) with columns 0 and 1
# forming the main view, as in the neighbouring cells that plot [:, 0] vs
# [:, 1]. Each side panel shares one axis with panel [1, 0]:
#   [1, 0] col 0 vs col 1, [1, 1] col 2 vs col 1, [0, 0] col 0 vs col 2.
# NOTE(review): confirm the column order of mov_coords_transformed.h5.
# NOTE(review): 0.6 and z_res are raw-imaging scalings (z_res comes from the
# single fish loaded earlier); whether they apply to reference-space
# coordinates is unverified.
fig3, axs3 = plt.subplots(2, 4, figsize=(12, 6), gridspec_kw={'width_ratios': [5, 2, 5, 2], 'height_ratios': [1, 3]})
mp_ind = np.argsort(amp_vis)

axs3[1,0].scatter(coords_ib[:,0]*0.6, coords_ib[:,1]*0.6, c='linen', s=.1, alpha=0.8)
axs3[1,0].scatter(coords_vis[mp_ind,0]*0.6, coords_vis[mp_ind,1]*0.6, c=colors_thresh[mp_ind]/255, s=.1, alpha=0.8)

axs3[1,1].scatter(coords_ib[:,2]*z_res, coords_ib[:,1]*0.6, c='linen', s=2, alpha=0.8)
axs3[1,1].scatter(coords_vis[mp_ind,2]*z_res, coords_vis[mp_ind,1]*0.6, c=colors_thresh[mp_ind]/255, s=2, alpha=0.8)

axs3[0,0].scatter(coords_ib[:,0]*0.6, coords_ib[:,2]*z_res, c='linen', s=2, alpha=0.8)
axs3[0,0].scatter(coords_vis[mp_ind,0]*0.6, coords_vis[mp_ind,2]*z_res, c=colors_thresh[mp_ind]/255, s=2, alpha=0.8)



In [None]:
# Overlay of the morphed coordinates (columns 0 and 1) of every fish. Each
# scatter call takes the next colour in matplotlib's colour cycle.
fig_fish, ax_fish = plt.subplots()

for fish in fish_list:
    fish_coords = morphed_coords[fish.name]
    ax_fish.scatter(fish_coords[:, 0], fish_coords[:, 1], s=.25)