In [None]:
%matplotlib widget

In [None]:
from pathlib import Path
import numpy as np
import flammkuchen as fl
import matplotlib.pyplot as plt
import json
from lavian_et_al_2025.imaging.imaging_classes import TwoPExperiment

In [None]:
def exp_decay_kernel(tau, dt, len_rec, upsample=10):
    """Build a unit-sum exponential-decay kernel on an upsampled time grid.

    Parameters
    ----------
    tau : float
        Decay time constant, in the same time units as ``dt``.
    dt : float
        Sampling interval of the original recording.
    len_rec : int
        Number of samples of the original recording; the kernel has
        ``len_rec * upsample`` points.
    upsample : int, optional
        Factor by which the time grid is made finer than ``dt``
        (default 10, the value that used to be hard-coded).

    Returns
    -------
    numpy.ndarray
        1-D kernel of length ``len_rec * upsample`` whose entries sum to 1.
    """
    # Time axis sampled ``upsample`` times finer than the recording.
    t = np.arange(len_rec * upsample) * dt / upsample

    decay = np.exp(-t / tau)
    # Normalise to unit sum so convolving with it does not change the signal's total.
    decay /= np.sum(decay)
    return decay

In [None]:
def scale_square(square_size, radius, proj_height):
    """Rescale ``square_size`` by the ratio circumference / projection height.

    The result is ``square_size * (2 * pi * radius / proj_height)``.
    Works element-wise when the inputs are numpy arrays.
    """
    # Circumference of a circle with this radius, relative to the projection height.
    ratio = (2 * np.pi * radius) / proj_height
    return ratio * square_size

In [None]:
# set the path to the dataset
master = Path(r"")
# glob() order depends on the filesystem; sort so the same fish is picked on every run.
files = sorted(master.glob("*_f*"))
if not files:
    raise FileNotFoundError(f"No fish directories matching '*_f*' found in {master.resolve()}")

# choose a fish (index into the sorted list of fish directories)
FISH_INDEX = 0
fish = files[FISH_INDEX]

# choose a plane 
path = fish / 'suite2p' / '0001'

In [None]:
# loading suite2p data: ROI coordinates, anatomy stack and ROI traces
suite2p_data = fl.load(path / 'data_from_suite2p_unfiltered.h5')
coords, anat, traces = (
    suite2p_data[key] for key in ("coords", "anatomy_stack", "traces")
)

In [None]:
# normalizing traces: z-score each row (one ROI) across time
row_mean = traces.mean(axis=1, keepdims=True)
row_std = traces.std(axis=1, keepdims=True)
traces = (traces - row_mean) / row_std

In [None]:
# Getting stimulus information from the first *_metadata.json in the plane folder
metadata_file = list(path.glob("*_metadata.json"))[0]

with metadata_file.open("r") as fh:
    metadata = json.load(fh)

# logged sequence of stimulus epochs
stim = metadata["stimulus"]["log"]

In [None]:
# mapping stimulus position to the way it is saved in the metadata file
num_rows = 8
n_regs = num_rows
# square size before rescaling
s_size = 0.5 / num_rows
radius = 0.3
# one projection height per stimulus position, used to rescale the square size
proj_height = [2, 1.75, 1.5, 1.25, 1.25, 1.5, 1.75, 2]
assert len(proj_height) == num_rows, "need one projection height per stimulus position"

# Each entry has the form [0, x_pos / num_rows, 1, scaled size]; the trial loop
# compares these against the logged 'clip_mask' values.
choices = [
    [0, x_pos / num_rows, 1, scale_square(s_size, radius, proj_height[x_pos])]
    for x_pos in range(num_rows)
]

In [None]:
# getting tail information from the first *_behavior* file in the plane folder
beh_candidates = list(path.glob("*_behavior*"))
beh_file = beh_candidates[0]
beh_log = fl.load(beh_file)['data']

In [None]:
# getting stimulus timing parameters 
# NOTE(review): int() truncates a non-integer framerate (e.g. 29.97 -> 29),
# which would make frame-based timing drift; confirm the framerate is integral.
fs = int(metadata['imaging']['microscope_config']['scanning']['framerate'])
# Durations in whole frames: they are used as array slice bounds later, and a
# float duration * fs would make that slicing raise TypeError.
pause_duration = int(round(stim[0]['duration'] * fs))
stim_duration = int(round(stim[1]['duration'] * fs))

In [None]:
# finding trial start and end times
n_options = 8
n_rep = metadata['stimulus']['protocol']['receptive_fields']['v06_front_cols_flashes_1x8']['n_trials']
n_trials = n_options * n_rep
position_list = np.zeros((n_trials, 4))
for_regs = np.zeros((n_options, n_trials * 2 + 1))

len_rec = np.shape(traces)[1]
regs = np.zeros((n_options, len_rec))
t1 = pause_duration

# The loop visits the odd entries of the stimulus log (stim[1], stim[3], ...),
# i.e. the stimulus epochs that follow each pause.
for i in range(1, n_trials * 2, 2):
    curr_trial = stim[i]['clip_mask']
    # i = 1, 3, 5, ... -> trial row 0, 1, 2, ... (the old "- 1" wrote trial 0 into the last row)
    position_list[i // 2, :] = curr_trial

    # Slice bounds must be integers regardless of how the durations were computed.
    start = int(round(t1))
    stop = start + int(round(stim_duration))
    for j in range(n_options):
        # Tolerant comparison: choices are computed floats that need not match
        # the logged values bit-for-bit.
        if np.allclose(curr_trial, choices[j]):
            for_regs[j, i-1] = 1
            regs[j, start:stop] = 1

    t1 = t1 + stim_duration + pause_duration

In [None]:
# Generating regressors 
from lotr.default_vals import REGRESSOR_TAU_S, TURN_BIAS

dt_imaging = 1 / fs
int_fact = 200
t_imaging = np.arange(traces.shape[1])/fs
num_traces, len_rec = np.shape(traces)

tau_fs = REGRESSOR_TAU_S * fs
kernel = np.exp(-np.arange(1000) / tau_fs)
t_imaging_int = np.arange(traces.shape[1]*int_fact)*dt_imaging/int_fact

regs_conv = np.zeros((n_options, len_rec))
regs_vals = np.zeros((n_options, num_traces))

# Per-ROI statistics do not depend on the regressor, so compute them once.
traces_mean = np.mean(traces, 1)
traces_std = np.std(traces, 1)

for i in range(n_options):
    # Convolve the boxcar regressor with the decay kernel and trim to the recording length.
    regs_conv[i] = np.convolve(regs[i], kernel)[:len_rec]

    # Pearson correlation of every ROI trace with this regressor:
    #   r = (sum(x*y) - n*mean(x)*mean(y)) / (n * std(x) * std(y))
    # with n = number of time points and population (ddof=0) standard deviations.
    # (Previously n was the number of ROIs and the denominator used n - 1.)
    cov_sum = np.dot(traces, regs_conv[i]) - len_rec * traces_mean * np.mean(regs_conv[i])
    regs_vals[i] = cov_sum / (len_rec * traces_std * np.std(regs_conv[i]))

In [None]:
# Displaying the different regressors, tail trace and neural traces
fig_regs, ax_regs = plt.subplots(2, 1, figsize=(8, 6))
ax_top, ax_bottom = ax_regs

# convolved regressors, stacked with a vertical offset of 20 per stimulus position
for offset_idx, reg_trace in enumerate(regs_conv[:n_options]):
    ax_top.plot(reg_trace + offset_idx * 20)

# tail trace (time converted to frames), scaled and shifted below the regressors
ax_top.plot(beh_log.t * fs, beh_log.tail_sum * 10 - 80, c='k')

ax_top.set_xlim(0, len_rec)
ax_top.axis('off')
ax_bottom.axis('off')
ax_bottom.imshow(traces, extent=[0, 1500, 0, 500], cmap='coolwarm', vmin=-2, vmax=3)

In [None]:
# Crop the habenula in order to only analyze habenula traces in following sections
# crop limits in pixels: coords[:, 1] must lie in [s1, s2], coords[:, 2] in [s3, s4]
s1, s2 = 20, 330
s3, s4 = 140, 240

fig_crop, ax_crop = plt.subplots(figsize=(3, 3))
ax_crop.imshow(anat.mean(0), vmax=1000, vmin=0, cmap='gray_r')
ax_crop.scatter(coords[:, 2], coords[:, 1], c=(0.9,)*3, s=1)
for row_lim in (s1, s2):
    ax_crop.axhline(row_lim)
for col_lim in (s3, s4):
    ax_crop.axvline(col_lim)

In [None]:
# removing ROIs outside of the habenula crop
# (despite the name `sel_to_nan`, their traces and correlation values are set to 0)
row_px, col_px = coords[:, 1], coords[:, 2]
sel_to_nan = (row_px < s1) | (row_px > s2) | (col_px < s3) | (col_px > s4)
traces[sel_to_nan] = 0

regs_vals_new = regs_vals.copy()
regs_vals_new[:, sel_to_nan] = 0

# overlay the excluded ROIs in darker gray on the current figure
new_coords = coords[sel_to_nan]
plt.scatter(new_coords[:, 2], new_coords[:, 1], c=(0.5,)*3, s=1)

In [None]:
# Map each ROI's correlation with every stimulus-position regressor onto its location.
num_row = 2
num_col = 4
# One panel title per stimulus position (`title_list` is used before any
# definition in run order, so build it here).
title_list = [f"position {k + 1}" for k in range(n_options)]
fig_rf_reg1, ax_rf_reg1 = plt.subplots(num_row, num_col, figsize=(10, 5), sharex=True, sharey=True)

for i in range(n_options):
    r, c = divmod(i, num_col)
    ax = ax_rf_reg1[r, c]
    # With sharey=True, invert_yaxis() flips every shared axis; calling it once per
    # panel toggled the shared y-axis 8 times, cancelling out. Invert only if needed.
    if not ax.yaxis_inverted():
        ax.invert_yaxis()

    ax.scatter(coords[:, 1], coords[:, 2], c=regs_vals_new[i], cmap='coolwarm', s=2, vmin=-0.7, vmax=0.7)
    ax.set_title(title_list[i], fontsize=15)
    ax.axis('off')

plt.subplots_adjust(hspace=0.5)

In [None]:
# Overlay each regressor (dark blue) with the ROI trace best correlated with it (light blue).
fig_regs2, ax_regs2 = plt.subplots(1, 1, figsize=(8, 5))

for i in range(n_options):
    offset = i * 20
    ax_regs2.plot(regs_conv[i] + offset, c='royalblue')

    # nanargmax: constant ROI traces z-score to NaN, giving NaN correlations,
    # and np.argmax would select a NaN entry.
    max_ind = np.nanargmax(regs_vals_new[i])
    ax_regs2.plot(traces[max_ind] + offset, c='skyblue')

ax_regs2.set_xlim(0, len_rec)
ax_regs2.axis('off')

In [None]:
# Trial-averaged responses: row rf = ROI best correlated with regressor rf,
# column i = response to stimulus position i.
num_row = 8
pre_frames = 10   # frames shown before stimulus onset
win_frames = 50   # total frames per trial window
title_list = [f"position {k + 1}" for k in range(n_options)]
t_vec = (np.arange(win_frames) - pre_frames) / fs

# Onsets depend only on the regressor, not on the ROI, so compute them once.
# np.diff(...) > 0 at index k means the stimulus is first on at frame k + 1.
onsets = [np.flatnonzero(np.diff(regs[i]) > 0) + 1 for i in range(n_options)]

fig_regs4, ax_regs4 = plt.subplots(num_row, n_options, figsize=(8, 6), sharex=True, sharey=True)

for rf in range(num_row):
    ax_regs4[0, rf].set_title(title_list[rf], fontsize=10)
    max_ind = np.nanargmax(regs_vals_new[rf])
    trace = traces[max_ind]

    for i in range(n_options):
        # NaN-filled so trials that are missing or cut off by the recording edges are
        # ignored by nanmean / nanstd instead of being averaged in as zeros.
        trial_trace = np.full((n_rep, win_frames), np.nan)
        n_valid = 0
        for trial, onset in enumerate(onsets[i][:n_rep]):
            t_start = onset - pre_frames
            t_end = t_start + win_frames
            if t_start >= 0 and t_end <= len(trace):
                trial_trace[trial] = trace[t_start:t_end]
                n_valid += 1

        trace_avg = np.nanmean(trial_trace, axis=0)
        # SEM over the trials actually collected, not the nominal repetition count.
        trace_sem = np.nanstd(trial_trace, axis=0) / np.sqrt(max(n_valid, 1))
        ax = ax_regs4[rf, i]
        ax.plot(t_vec, trace_avg, c='royalblue')
        ax.fill_between(t_vec, trace_avg - trace_sem, trace_avg + trace_sem, color='lightblue')

        # Keep axes (without top/right spines) only on the bottom-left panel.
        if i != 0 or rf < num_row - 1:
            ax.axis('off')
        else:
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)