In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
import flammkuchen as fl
import json

from matplotlib import  pyplot as plt
import ipywidgets as widgets

from bouter.utilities import reliability 
from skimage.filters import threshold_otsu
import xarray as xr

from lavian_et_al_2025.imaging.imaging_classes import LightsheetExperiment

In [None]:
# Stimulus design: n_dirs stimulus directions, each presented n_blocks times.
n_dirs = 8
n_blocks = 10

In [None]:
# Experiment root; fish folders are picked out by the "*_f*" name pattern.
master = Path(r"\\portulab.synology.me\data\Hagar and Ot\E0040\v10\LS")
# Path.glob yields entries in an arbitrary, filesystem-dependent order; sort so
# that fish_list[0] (the fish analysed below) is the same on every run.
fish_list = sorted(master.glob("*_f*"))
if not fish_list:
    raise FileNotFoundError(f"No fish folders matching '*_f*' found in {master}")
path = fish_list[0]
path

In [None]:
# Index used below to keep only the in-brain ROIs (columns of the traces,
# rows of the suite2p coordinates).
brain_file = path / "data_from_suite2p_cells_brain.h5"
suite2p_brain = fl.load(brain_file)
in_brain_idx = suite2p_brain['coords_idx']

In [None]:
# Detrended traces, restricted to in-brain ROIs by column; rows are frames
# (they are unpacked later as len_rec, num_neurons).
traces_file = path / "filtered_traces.h5"
traces = fl.load(traces_file, "/detr")[:, in_brain_idx]

# ROI coordinates (in-brain only) and the anatomy stack from suite2p.
suite2p_data = fl.load(path / "data_from_suite2p_cells.h5")
coords = suite2p_data['coords'][in_brain_idx]
anatomy = suite2p_data['anatomy_stack']

# Experiment metadata: sampling rate used for dt = 1 / fs, plus behaviour log.
# NOTE(review): int() truncates a fractional rate — confirm exp.fn is integral.
exp = LightsheetExperiment(path)
fs = int(exp.fn)
beh_df = exp.behavior_log


In [None]:
# One regressor row per direction, rounded; show the distinct values as a sanity check.
sens_regs = np.round(fl.load(path / 'sensory_regressors.h5', '/individual_theta_interp'))
np.unique(sens_regs)

In [None]:
######################### Part 1 - getting trial timing for each direction
# Onset frame of every presentation of each direction, taken as the rising
# edges of that direction's (rounded) regressor.
# NOTE(review): a presentation already "on" at frame 0 has no rising edge and
# cannot be detected this way.
trial_times = np.zeros((n_dirs, n_blocks), dtype=int)

for direction in range(n_dirs):
    tmp_reg = sens_regs[direction]
    # np.diff(x)[k] > 0 means x[k + 1] > x[k], so the first "on" frame is k + 1.
    onsets = np.flatnonzero(np.diff(tmp_reg) > 0) + 1
    if onsets.size != n_blocks:
        raise ValueError(
            f"Direction {direction}: found {onsets.size} stimulus onsets, "
            f"expected n_blocks={n_blocks}."
        )
    trial_times[direction] = onsets


In [None]:
# `traces` is (frames, ROIs) at this point: record both sizes, then flip it to
# (ROIs, frames) for the trial-chunking cells.
# NOTE(review): not idempotent — re-running this cell transposes traces back.
len_rec, num_neurons = traces.shape
traces = np.transpose(traces)

In [None]:
######################### Part 2 - chunking the traces into trials
# Length of one chunk when the recording is cut into n_dirs * n_blocks equal
# pieces; frames left over by the integer division are ignored.
n_trials = n_dirs * n_blocks
new_len_rec = len_rec // n_trials

# Indexed as trial_traces[direction, neuron, repeat, frame].
trial_traces = np.zeros((n_dirs, num_neurons, n_blocks, new_len_rec))
print(trial_traces.shape)




In [None]:
# Assign each back-to-back chunk of new_len_rec frames to the direction whose
# regressor is on during it, filling trial_traces[direction, :, repeat, :].
count = np.zeros(n_dirs, dtype=int)  # repeats filled so far, per direction
for i in range(n_blocks * n_dirs):
    t1 = i * new_len_rec
    t2 = t1 + new_len_rec

    ### find direction type for current trial
    # Direction(s) whose mean regressor value over this chunk exceeds 0.25.
    active_dirs = np.flatnonzero(np.nanmean(sens_regs[:, t1:t2], axis=1) > 0.25)
    if active_dirs.size != 1:
        raise ValueError(
            f"Trial {i} (frames {t1}:{t2}) matched {active_dirs.size} "
            f"directions {active_dirs.tolist()}; expected exactly one."
        )
    curr_direction = int(active_dirs[0])
    if count[curr_direction] >= n_blocks:
        raise ValueError(
            f"Direction {curr_direction} matched more than n_blocks={n_blocks} trials."
        )
    trial_traces[curr_direction, :, count[curr_direction]] = traces[:, t1:t2]
    count[curr_direction] += 1

In [None]:
######################### Part 3 - looking for neurons that reliably respond to the visual stimulus
# selecting reliable neurons: one reliability index per (direction, neuron)

reliability_arr_all_dirs = np.zeros((n_dirs, num_neurons))
dt = 1 / fs  # loop-invariant; only used to label the time coordinate

for direction in range(n_dirs):
    tmp_trial_traces = trial_traces[direction]  # (neuron, block, frame)
    traces_xr = xr.DataArray(
        data=tmp_trial_traces,
        dims=['roi', 'block', 't'],
        coords={
            'roi': np.arange(tmp_trial_traces.shape[0]),
            'block': np.arange(n_blocks),
            't': np.arange(tmp_trial_traces.shape[2]) * dt,
        },
    )
    # Reorder by dimension name to (t, block, roi). np.swapaxes on a DataArray
    # falls back through __array_wrap__ and can keep the old dim labels on the
    # swapped data. NOTE(review): (t, block, roi) mirrors the previous call —
    # confirm it is the layout bouter.utilities.reliability expects.
    reliability_arr = reliability(traces_xr.transpose('t', 'block', 'roi').values)
    reliability_arr_all_dirs[direction] = reliability_arr

In [None]:
######################### Part 5 - chunking the traces into trials in a new way
# Same chunk length as Part 2 (integer division drops any leftover frames).
n_trials = n_dirs * n_blocks
new_len_rec = len_rec // n_trials


In [None]:
# Concatenated layout, indexed [neuron, repeat, frame]: within each repeat the
# n_dirs direction chunks sit side by side in direction order, so frames
# [d * new_len_rec, (d + 1) * new_len_rec) hold direction d.
trial_traces_new = np.zeros((num_neurons, n_blocks, new_len_rec * n_dirs))
count = np.zeros(n_dirs, dtype=int)  # repeats filled so far, per direction
for i in range(n_blocks * n_dirs):
    t1 = i * new_len_rec
    t2 = t1 + new_len_rec

    ### find direction type for current trial
    # Direction(s) whose mean regressor value over this chunk exceeds 0.25.
    active_dirs = np.flatnonzero(np.nanmean(sens_regs[:, t1:t2], axis=1) > 0.25)
    if active_dirs.size != 1:
        raise ValueError(
            f"Trial {i} (frames {t1}:{t2}) matched {active_dirs.size} "
            f"directions {active_dirs.tolist()}; expected exactly one."
        )
    curr_direction = int(active_dirs[0])
    if count[curr_direction] >= n_blocks:
        raise ValueError(
            f"Direction {curr_direction} matched more than n_blocks={n_blocks} trials."
        )

    t3 = curr_direction * new_len_rec
    t4 = t3 + new_len_rec

    trial_traces_new[:, count[curr_direction], t3:t4] = traces[:, t1:t2]
    count[curr_direction] += 1

In [None]:
######################### Part 6 - looking for neurons that reliably respond to the visual stimulus
# selecting reliable neurons from the concatenated all-direction trials

dt = 1 / fs
traces_xr = xr.DataArray(
    data=trial_traces_new,  # (neuron, block, frame)
    dims=['roi', 'block', 't'],
    coords={
        'roi': np.arange(trial_traces_new.shape[0]),
        'block': np.arange(n_blocks),
        't': np.arange(trial_traces_new.shape[2]) * dt,
    },
)
# Name-based reorder to (t, block, roi) rather than np.swapaxes on the
# DataArray, which falls back through __array_wrap__ and can mislabel dims.
reliability_arr_combined = reliability(traces_xr.transpose('t', 'block', 'roi').values)

In [None]:
# Otsu threshold on the combined (all-direction) reliability index.
# `reliability_arr` must not be used here: it is the leftover loop variable of
# the per-direction cell and holds only the last direction's values.
# NOTE(review): confirm the combined index is the intended population.
rel_thresh = threshold_otsu(reliability_arr_combined)
print("Reliability threshold: ", rel_thresh)


In [None]:
# Persist the per-direction and combined reliability indices together with
# the concatenated trial traces (stored under the key 'trial_traces').
d = dict(
    reliability_index_per_direction=reliability_arr_all_dirs,
    reliability_arr_combined=reliability_arr_combined,
    trial_traces=trial_traces_new,
)
out_file = path / 'reliability_index_arr.h5'
fl.save(out_file, d)