In [None]:
%matplotlib widget

In [None]:
import numpy as np
from pathlib import Path
import flammkuchen as fl
import napari
import matplotlib.pyplot as plt
import tifffile as tiff
import cv2
from bg_atlasapi.bg_atlas import BrainGlobeAtlas
from bg_space import AnatomicalSpace
from quickdisplay import *

from lavian_et_al_2025.imaging.imaging_classes import TwoPExperiment

In [None]:
# Dataset root: one sub-folder per fish (matched by "*_f*") plus the anatomy reference fish.
# NOTE(review): absolute network-drive path — consider moving it to a config cell / env var.
master = Path(r"Z:\Hagar and Ot\E0040\v10\2p\s1186t")
# Path.glob yields entries in filesystem order, which is not guaranteed to be stable;
# sorting makes the fish index used in tuning_maps_all (and fish_list[0]) reproducible.
fish_list = sorted(master.glob("*_f*"))
if not fish_list:
    raise FileNotFoundError(f"No fish folders matching '*_f*' found in {master}")
# The reference fish lives inside the same dataset root; derive it instead of repeating the path.
ref_fish = master / "240725_4_anatomy"
mov_fish = fish_list[0]

In [None]:
#Define morphing space...
morphing_as = AnatomicalSpace('rai', resolution=(0.3,0.3,1))

In [None]:
# load tuning map of ref fish
ref_tuning_map = fl.load(ref_fish / "registration" / "ref_mapped.h5")

#exp_ref = TwoPExperiment(ref_fish)
#ref_res = exp_ref.resolution
ref_res = (0.3,0.3,1)
ref_as = AnatomicalSpace('ipl', resolution=ref_res, shape=ref_tuning_map.shape[:3])

In [None]:
n_dir = 8

In [None]:
np.shape(ref_tuning_map[:,:,:])

In [None]:
ref_mapped = np.zeros((413, 413, 78, n_dir))

In [None]:
num_planes, n_x, n_y = np.shape(ref_mapped)[:3]
num_fish = len(fish_list)
tuning_maps_all = np.zeros((num_fish, num_planes, n_x, n_y, n_dir))

In [None]:
# Register every fish onto the reference grid, one direction channel at a time, store the
# result in tuning_maps_all[count] and cache it next to the fish's data.
# NOTE(review): map_affine is not defined in this notebook; presumably it comes from
# `from quickdisplay import *` — confirm.
for count, fish in enumerate(fish_list):
    print(fish)

    path = fish / "registration"

    # load the affine transform saved by the registration step
    transform_mat = fl.load(path / "initial_transform_mapped.h5")

    # per-fish correlation maps; the transpose moves axis 1 to the end so the last
    # axis indexes the n_dir direction channels
    mov_stack_org = fl.load(fish / "plane_corrmap_corrvalues.h5")['plane_corr']
    mov_stack = np.transpose(mov_stack_org, (0,2,3,1))

    exp_mov = TwoPExperiment(fish)
    mov_res = exp_mov.resolution

    # The anatomical space is 3-D: give it only the spatial shape (first three axes),
    # matching how the reference space is built from ref_tuning_map.shape[:3].
    mov_as = AnatomicalSpace('ipl', resolution=mov_res, shape=mov_stack.shape[:3])

    # Morph each direction channel into the common morphing space exactly once;
    # cast to float64 so map_affine always receives float64 input.
    mov_mapped = np.stack(
        [mov_as.map_stack_to(morphing_as, mov_stack[:,:,:,i]) for i in range(n_dir)],
        axis=-1,
    ).astype(float, copy=False)

    # Warp every morphed channel onto the reference grid.
    transformed = np.zeros(ref_mapped.shape)
    for i in range(n_dir):
        transformed[:,:,:,i] = map_affine(mov_mapped[:,:,:,i], transform_mat, ref_mapped.shape[:3])

    tuning_maps_all[count] = transformed

    fl.save(fish / 'corr_map_morphed.h5', {'corr_map_morphed': transformed})
    

In [None]:
np.shape(tuning_map_all)

In [None]:
num_fish, n_x, n_y, num_planes = np.shape(tuning_map_all)[:4]

In [None]:
np.shape(np.unique(tuning_maps_all[-1]))


In [None]:
fig, ax = plt.subplots(1,2)

In [None]:
ax[0].imshow(tuning_maps_all[-1,80:320,60:200,35,0])
ax[1].imshow(mov_stack[2,:,:,0])

In [None]:
np.shape(tuning_maps_all)

In [None]:
y_start = 80
y_end = 320
n_col = y_end - y_start
n_row = 200-60
n_fish = len(fish_list)

tuning_maps_cropped = tuning_maps_all[:,y_start:y_end,60:200,15:35,:]
tuning_maps_cropped_avg1 = np.nanmean(tuning_maps_cropped, axis=3)

In [None]:
corr_values = np.zeros((n_fish, n_dir, n_col))

for fish in range(n_fish):
    for i_dir in range(n_dir):
        corr_values[fish, i_dir] = np.nanmean(tuning_maps_cropped_avg1[fish,:,:,i_dir], axis=1)

In [None]:
fig1, ax1 = plt.subplots(1, 2, figsize=(10,4))
# Left panel: fish-averaged profile along crop axis 1, one row per direction channel.
# NOTE(review): the dataset folder is 's1186t' but this title (and the saved file names)
# say 's1168t' — confirm which line name is correct.
ax1[0].imshow(np.nanmean(corr_values,axis=0), cmap='coolwarm', aspect='auto', vmin=-0.2, vmax=0.2, interpolation='none')
ax1[0].set_title('Tg(s1168t:Gal4; UAS:GCaMP6s)')
ax1[0].set_xlabel('Left <-----------------------> Right')

# Row i (direction channel i) is labelled with arrow_symbols[i].
arrow_symbols = ['→', '↘', '↓', '↙', '←', '↖', '↑', '↗']
assert len(arrow_symbols) == n_dir, "need exactly one arrow label per direction"
# Set the y-ticks and labels
ax1[0].set_yticks(np.arange(n_dir))
ax1[0].set_yticklabels(arrow_symbols);

In [None]:

# Profile along crop axis 2: average over crop axis 1 for every fish and direction,
# giving (n_fish, n_dir, n_row).
corr_values_x = np.nanmean(tuning_maps_cropped_avg1, axis=1).transpose(0, 2, 1)

In [None]:
# Right panel: fish-averaged profile along crop axis 2, transposed so direction channels
# run along x and crop axis 2 runs along y.
ax1[1].imshow(np.nanmean(corr_values_x,axis=0).T, cmap='coolwarm', aspect='auto', vmin=-0.2, vmax=0.2, interpolation='none')
ax1[1].set_title('Tg(s1168t:Gal4; UAS:GCaMP6s)')
ax1[1].set_ylabel('Bottom <-----------------------> Top')

# Column i (direction channel i) is labelled with arrow_symbols[i], as in the left panel.
arrow_symbols = ['→', '↘', '↓', '↙', '←', '↖', '↑', '↗']
assert len(arrow_symbols) == n_dir, "need exactly one arrow label per direction"

# Label the x-axis with one arrow per direction; hide the spatial ticks on both panels.
ax1[1].set_xticks(np.arange(n_dir))
ax1[1].set_xticklabels(arrow_symbols)
ax1[1].set_yticks([])
ax1[0].set_xticks([]);

In [None]:
# Export the two-panel summary figure as PDF and PNG into the dataset root.
for out_name in ('s1168t corr matrix - both.pdf', 's1168t corr matrix - both.png'):
    fig1.savefig(master / out_name, dpi=300)