In [None]:
# Interactive ipympl backend; some figures below (e.g. fig1) are populated across more than one cell.
%matplotlib widget

In [None]:
import numpy as np
from pathlib import Path
import flammkuchen as fl
import matplotlib.pyplot as plt
import cv2
from bg_atlasapi.bg_atlas import BrainGlobeAtlas
from bg_space import AnatomicalSpace
from quickdisplay import *
from scipy import stats

from lavian_et_al_2025.imaging.imaging_classes import TwoPExperiment
from lavian_et_al_2025.data_location import master_motion

In [None]:
# One recording folder per fish, matching "*_f*".
# Path.glob yields entries in arbitrary (filesystem-dependent) order, so sort them.
# Both the reference fish (first entry) and the fish skipped by the registration loop
# (last entry) depend on this order.
master = master_motion / "2p" / "habenula"
fish_list = sorted(p for p in master.glob("*_f*") if p.is_dir())
# Fish whose registration folder provides the reference map (ref_mapped.h5) loaded below.
mov_fish = fish_list[0]

In [None]:
# Common morphing space: 'rai' axis orientation with per-axis resolution (0.3, 0.3, 1).
# Every fish's correlation map is reoriented into this space before its affine
# registration in the loop below.
morphing_as = AnatomicalSpace('rai', resolution=(0.3,0.3,1))

In [None]:
# Load the reference fish's mapped tuning map and describe its anatomical space.
# NOTE(review): ref_res and ref_as are not used in the visible cells below. Confirm whether
# they are still needed.
ref_tuning_map = fl.load(mov_fish / "registration" / "ref_mapped.h5")
ref_res = (0.3,0.3,1)
# Only the first three axes describe the spatial volume; presumably any further axis holds
# directions. TODO confirm.
ref_as = AnatomicalSpace('ipl', resolution=ref_res, shape=ref_tuning_map.shape[:3])

In [None]:
n_dir = 8  # number of visual-motion directions (size of the last axis of the tuning maps below)

In [None]:
# Placeholder for the registered volume: presumably three spatial axes in morphing_as,
# plus one axis per direction.
# Only its shape is used: the target shape for map_affine and the size of tuning_maps_all.
# NOTE(review): (413, 413, 66) is hard-coded. Confirm it matches the reference volume
# reoriented into morphing_as.
ref_mapped = np.zeros((413, 413, 66, n_dir))

In [None]:
# Spatial size of the registered volume. The axis order follows ref_mapped and matches the
# later unpacking of tuning_maps_all: x, y, then the plane axis.
n_x, n_y, num_planes = np.shape(ref_mapped)[:3]
num_fish = len(fish_list)
# One registered map per processed fish. The registration loop below skips the last entry
# of fish_list, hence num_fish - 1.
tuning_maps_all = np.zeros((num_fish - 1, n_x, n_y, num_planes, n_dir))

In [None]:
# Register every fish except the last entry of fish_list into the reference volume.
# For each fish:
#   - reorient each direction's correlation map into morphing_as,
#   - apply the fish's affine transform,
#   - store the result in tuning_maps_all,
#   - save it as corr_map_morphed.h5 in the fish folder.
for count, fish in enumerate(fish_list[:-1]):
    print(fish)

    # Affine transform produced by this fish's registration.
    transform_mat = fl.load(fish / "registration" / "initial_transform_mapped.h5")

    # Correlation map; move axis 1 to the end so directions are indexed by the last axis.
    mov_stack_org = fl.load(fish / "plane_corrmap_corrvalues.h5")['plane_corr']
    mov_stack = np.transpose(mov_stack_org, (0, 2, 3, 1))

    exp_mov = TwoPExperiment(fish)
    mov_res = exp_mov.resolution

    # The anatomical space describes only the three spatial axes (as ref_as does).
    # The direction axis is handled by mapping one direction at a time.
    mov_as = AnatomicalSpace('ipl', resolution=mov_res, shape=mov_stack.shape[:3])

    transformed = np.zeros(ref_mapped.shape)
    mov_mapped = None
    for i in range(n_dir):
        mapped_dir = mov_as.map_stack_to(morphing_as, mov_stack[:, :, :, i])
        if mov_mapped is None:
            # Allocate once the reoriented shape is known, so no direction is mapped twice.
            mov_mapped = np.zeros(mapped_dir.shape + (n_dir,))
        mov_mapped[:, :, :, i] = mapped_dir

        # map_affine presumably comes from the quickdisplay star import. TODO confirm.
        transformed[:, :, :, i] = map_affine(mov_mapped[:, :, :, i], transform_mat, ref_mapped.shape[:3])

    tuning_maps_all[count] = transformed

    fl.save(fish / 'corr_map_morphed.h5', {'corr_map_morphed': transformed})
    

In [None]:
# Sizes read back from the registered stack (fish, x, y, planes).
# Note: this rebinds num_fish to the number of registered fish, i.e. len(fish_list) - 1.
num_fish, n_x, n_y, num_planes = tuning_maps_all.shape[:4]

In [None]:
# Two side-by-side panels for previewing planes while choosing the IPN crop window.
fig, ax = plt.subplots(nrows=1, ncols=2)

In [None]:
# Left: registered plane 5, direction 0, of the last processed fish, cropped to the same
# window used in the cropping cell below.
# Right: plane 2, direction 0, of mov_stack, i.e. that fish's unregistered stack.
# NOTE(review): mov_stack is left over from the last iteration of the registration loop
# (hidden state). This cell only works after that loop has run.
ax[0].imshow(tuning_maps_all[-1,100:340,130:300,5,0])
ax[1].imshow(mov_stack[2,:,:,0])

In [None]:
# Cropping the IPN. tuning_maps_all is indexed (fish, row, column, plane, direction).
# Each window bound is named once, so the slice and the sizes derived from it always agree.
y_start, y_end = 100, 340          # rows (axis 1)
col_start, col_end = 130, 300      # columns (axis 2)
plane_start, plane_end = 10, 25    # planes averaged together (axis 3)

n_col = col_end - col_start
n_row = y_end - y_start
# Number of registered fish, read from the stack itself.
n_fish = tuning_maps_all.shape[0]

tuning_maps_cropped = tuning_maps_all[:, y_start:y_end, col_start:col_end, plane_start:plane_end, :]
# Slicing silently truncates at the volume border. Fail loudly instead of producing
# profiles whose length disagrees with n_row / n_col.
assert tuning_maps_cropped.shape[1:3] == (n_row, n_col), tuning_maps_cropped.shape
# Plane-averaged maps, shape (fish, row, column, direction).
tuning_maps_cropped_avg1 = np.nanmean(tuning_maps_cropped, axis=3)
tuning_maps_cropped_avg1 = np.nanmean(tuning_maps_cropped, axis=3)

In [None]:
# Profile along the column axis: average each fish's plane-averaged map over rows,
# separately for every direction. The result has shape (fish, direction, column).
# Vectorised, so no loop variable overwrites the `fish` Path from the registration loop.
corr_values = np.nanmean(tuning_maps_cropped_avg1, axis=1).transpose(0, 2, 1)
assert corr_values.shape == (n_fish, n_dir, n_col), corr_values.shape

In [None]:
# Heatmaps of the across-fish mean profiles.
# Right panel: rows are column positions of the crop, columns are motion directions.
fig1, ax1 = plt.subplots(1, 2, figsize=(10,4))
ax1[1].imshow(np.nanmean(corr_values,axis=0).T, cmap='coolwarm', aspect='auto', vmin=-0.2, vmax=0.2, interpolation='none')
ax1[1].set_title('Tg(16715:Gal4; UAS:GCaMP6s)')
ax1[0].set_xlabel('Left <-----------------------> Right')

# Direction symbols in 45° steps. Assumed to follow the order of the direction axis;
# TODO confirm against the stimulus protocol.
arrow_symbols = ['→', '↘', '↓', '↙', '←', '↖', '↑', '↗']
# The .T above puts directions in the columns, so they belong on the x axis.
# Putting them on the y axis would tag the first 8 position rows instead.
ax1[1].set_xticks(np.arange(len(arrow_symbols)))
ax1[1].set_xticklabels(arrow_symbols)

In [None]:
# Profile along the row axis: average each fish's plane-averaged map over columns,
# separately for every direction. The result has shape (fish, direction, row).
# Vectorised, so no loop variable overwrites the `fish` Path from the registration loop.
corr_values_x = np.nanmean(tuning_maps_cropped_avg1, axis=2).transpose(0, 2, 1)
assert corr_values_x.shape == (n_fish, n_dir, n_row), corr_values_x.shape

In [None]:
# Left panel: rows are motion directions, columns are row positions of the crop.
ax1[0].imshow(np.nanmean(corr_values_x,axis=0), cmap='coolwarm', aspect='auto', vmin=-0.2, vmax=0.2, interpolation='none')
ax1[0].set_title('Tg(16715:Gal4; UAS:GCaMP6s)')
ax1[1].set_ylabel('Bottom <-----------------------> Top')

# Directions are the rows of this panel, so the symbols go on its y axis.
# The position axes of both panels carry no ticks.
ax1[0].set_yticks(np.arange(len(arrow_symbols)))
ax1[0].set_yticklabels(arrow_symbols)
ax1[0].set_xticks([])
ax1[1].set_yticks([])
ax1[0].invert_xaxis()

In [None]:
# Save the two-panel heatmap figure into the folder that holds the fish folders.
fig1.savefig(master / '16715 corr matrix - both.pdf', dpi=300)

In [None]:
# Line plots of the across-fish mean ± SEM profile along axis 1 of the crop, with one panel
# per entry of the last axis. The last axis is the motion direction, so n_positions == n_dir.
n_samples, n_xpixels, n_ypixels, n_planes, n_positions = tuning_maps_cropped.shape

mean_correlations = np.zeros((n_positions, n_xpixels))
sem_correlations = np.zeros((n_positions, n_xpixels))

# NaN-aware averaging, consistent with tuning_maps_cropped_avg1 and with the per-sample
# statistics below. With a plain mean, a single NaN pixel would blank a whole position.
data_averaged_planes = np.nanmean(tuning_maps_cropped, axis=3)  # average over planes
data_averaged_y = np.nanmean(data_averaged_planes, axis=2)  # average over y-pixels (axis 2)

# Mean and SEM across fish for each direction and x-pixel.
for pos in range(n_positions):
    for x in range(n_xpixels):
        samples_at_x = data_averaged_y[:, x, pos]
        mean_correlations[pos, x] = np.nanmean(samples_at_x)
        sem_correlations[pos, x] = stats.sem(samples_at_x, nan_policy='omit')

# Centred pixel coordinate with its sign flipped. Presumably this mirrors the inverted x
# axis of the left heatmap panel.
x_positions = (np.arange(n_xpixels) - n_xpixels/2)*-1

fig, axes = plt.subplots(n_positions, 1, figsize=(2.5, 10), sharex=True, sharey=True)

# Direction symbols in 45° steps, in the same order as the heatmap labels (arrow_symbols).
y_label = ['→', '↘', '↓', '↙', '←', '↖', '↑', '↗']

for pos in range(n_positions):
    ax = axes[pos]

    line = ax.plot(x_positions, mean_correlations[pos, :],
                   linewidth=2, alpha=0.8, label='Mean correlation')
    color = line[0].get_color()

    ax.fill_between(x_positions,
                    mean_correlations[pos, :] - sem_correlations[pos, :],
                    mean_correlations[pos, :] + sem_correlations[pos, :],
                    alpha=0.3, color=color, label='±SEM')

    ax.set_ylabel(y_label[pos], fontsize=10)

axes[-1].set_xlabel('X Position (pixels)', fontsize=12)
plt.tight_layout()

In [None]:
# Save the per-direction mean ± SEM line plots into the folder that holds the fish folders.
fig.savefig(master / '16715 corr mean sem - both.pdf', dpi=300)