In [None]:
# Use the interactive (ipympl) matplotlib backend for the figures in this notebook.
%matplotlib widget

In [4]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import matplotlib.pyplot as plt
import json

from lavian_et_al_2025.imaging.imaging_classes import TwoPExperiment
from lavian_et_al_2025.imaging.imaging_functions import normalize_traces
from lavian_et_al_2025.landmarks.stimulus_functions import get_tuning_map_rois, color_stack
from lavian_et_al_2025.data_location import master_landmarks

from skimage.filters import threshold_otsu

from pathlib import Path

In [None]:
# Root folder of the 'ipn h2b' two-photon dataset; every sub-path matching
# '*_f*' is treated as one fish recording.
master = master_landmarks / '2p' / 'ipn h2b'
# Path.glob yields entries in arbitrary, filesystem-dependent order; sort so
# that `fish_list[0]` selects the same fish on every machine and every run.
fish_list = sorted(master.glob("*_f*"))
if not fish_list:
    raise FileNotFoundError(f"No fish folders matching '*_f*' found in {master}")
fish = fish_list[0]

In [None]:
# Registered reference anatomy stack, transposed from its stored axis order.
anatomy = fl.load(fish / 'registration' / 'ref_mapped.h5').T
fish_path = fish / 'suite2p'
# suite2p sub-folders processed below. Path.glob order is arbitrary, and the
# analysis cell concatenates per-folder ROI arrays in this order before
# indexing them into `coords`, so the order must be deterministic.
# NOTE(review): sorted order is assumed to match the ROI order stored in
# ref_roi_coords_mapped.h5 — confirm against the script that wrote that file.
paths = sorted(fish_path.glob("*00*"))

In [None]:
# Experiment metadata for the selected fish: acquisition frame rate and the
# per-axis (z, x, y) resolution.
img_exp = TwoPExperiment(fish)
scanning_cfg = img_exp['imaging']['microscope_config']['scanning']
fs = scanning_cfg['framerate']
sampling = 1 / fs  # frame period (seconds, assuming framerate is in Hz)
res = img_exp.resolution
z_res, x_res, y_res = res

In [None]:
# Pool the reliability index and tuning (amplitude, angle) of every ROI across
# all suite2p folders of this fish, then select reliable ROIs for plotting.
ind_count = 0
coords = fl.load(fish / 'registration' / 'ref_roi_coords_mapped.h5')

if not paths:
    raise FileNotFoundError(f"No suite2p folders matching '*00*' found in {fish_path}")

# Collect per-folder arrays in lists and concatenate once at the end; calling
# np.append inside the loop copies the whole accumulated array every iteration.
reliability_parts, amp_parts, angle_parts = [], [], []
for path in paths:
    # per-ROI reliability index for this folder
    rel_index = fl.load(path / 'reliability_index_arr.h5')['reliability_arr_combined']

    # traces stored under the '/detr' entry (presumably detrended)
    traces = fl.load(path / "filtered_traces.h5", "/detr")

    # convolved sensory regressors; some folders store them under another file name
    try:
        sensory_regressors = fl.load(path / "sensory_regressors.h5", "/regressors_conv")
    except Exception:
        # `except Exception` rather than a bare `except:` so that
        # KeyboardInterrupt / SystemExit still stop the loop.
        sensory_regressors = fl.load(path / "sensory_regressors_cells.h5", "/regressors_conv")
    reg_list = [sensory_regressors]
    n_t = sensory_regressors.shape[0]

    # tuning amplitude and preferred angle per ROI
    amp, angle = get_tuning_map_rois(traces, sensory_regressors.T)

    # ravel to match np.append's flattening semantics
    reliability_parts.append(np.ravel(rel_index))
    amp_parts.append(np.ravel(amp))
    angle_parts.append(np.ravel(angle))
count = len(paths)

all_reliability = np.concatenate(reliability_parts)
all_amp = np.concatenate(amp_parts)
all_angle = np.concatenate(angle_parts)

# Otsu's threshold on the pooled reliability is only reported; a fixed value is used below.
otsu_thresh = threshold_otsu(all_reliability)
print("Reliability threshold: ", otsu_thresh)
# NOTE(review): the same fixed value thresholds both tuning amplitude and
# reliability below — confirm this is intended.
thresh = 0.15

colors = color_stack(all_amp, all_angle)
low_amp = all_amp < thresh

amp_thresh = np.copy(all_amp)
amp_thresh[low_amp] = 0

# Weakly tuned ROIs are painted a uniform grey (220 on a 0-255 scale).
colors_thresh = np.copy(colors)
colors_thresh[low_amp] = 220

selected_vis = np.where(all_reliability > thresh)[0]
coords_vis = coords[selected_vis]
colors_vis = colors[selected_vis]  # previously coords[...] by mistake
colors_thresh = colors_thresh[selected_vis]
amp_vis = all_amp[selected_vis]
# Ascending amplitude: scatter draws later points on top, so strongly tuned ROIs are most visible.
mp_ind = np.argsort(amp_vis)

ind_count += 1

In [None]:
# Mean projection of the anatomy stack with the reliable ROIs overlaid,
# coloured by tuning; points are drawn in ascending-amplitude order (mp_ind).
fig2, axs2 = plt.subplots(1, 1, figsize=(3, 3))
for side in ('right', 'top'):
    axs2.spines[side].set_visible(False)
mean_anatomy = np.rot90(np.nanmean(anatomy, axis=0), 3)
axs2.imshow(mean_anatomy, cmap='gray_r')
roi_xy = coords_vis[mp_ind]
axs2.scatter(roi_xy[:, 1], roi_xy[:, 0], c=colors_thresh[mp_ind] / 255, s=10)
axs2.invert_yaxis()