In [None]:
# Interactive ipympl ("widget") backend so the montage figure can be zoomed/panned in the notebook.
%matplotlib widget

In [None]:
import numpy as np
from pathlib import Path
import flammkuchen as fl
import matplotlib.pyplot as plt
import cv2
from bg_atlasapi.bg_atlas import BrainGlobeAtlas
from bg_space import AnatomicalSpace
from quickdisplay import *

from lavian_et_al_2025.imaging.imaging_classes import TwoPExperiment
from lavian_et_al_2025.data_location import master_landmarks

In [None]:
# Each directory under the landmark master folder matching "*_f*" is treated as one fish.
# Path.glob yields entries in filesystem-dependent order, and fish_list[0] is used as the
# reference fish, so sort to make the choice (and the stacking order) reproducible.
master = master_landmarks / 'ipn 16715'
fish_list = sorted(p for p in master.glob("*_f*") if p.is_dir())
assert fish_list, f"no fish folders matching '*_f*' found in {master}"

In [None]:
# Common anatomical space every fish is morphed into: origin string 'rai',
# voxel resolution (0.3, 0.3, 1) along its three axes.
morphing_resolution = (0.3, 0.3, 1)
morphing_as = AnatomicalSpace('rai', resolution=morphing_resolution)

In [None]:
# Reference fish = first entry of fish_list. Its registered map defines the
# reference AnatomicalSpace ('ipl' origin); only the 3 spatial axes go into `shape`.
# NOTE(review): ref_as is not referenced by the other cells shown here — confirm it is still needed.
ref_res = (0.3, 0.3, 1)
ref_path = fish_list[0] / "registration" / "ref_mapped.h5"
ref_tuning_map = fl.load(ref_path)
ref_as = AnatomicalSpace('ipl', resolution=ref_res, shape=ref_tuning_map.shape[:3])

In [None]:
# Zero-filled template of the reference grid in morphing space (3 spatial axes + 3 colour
# channels). Nothing is transformed here: downstream only `ref_mapped.shape` is used
# (pre-allocation and the output shape passed to map_affine).
# NOTE(review): the size is hard-coded rather than derived from ref_as -> morphing_as; confirm it matches.
ref_mapped_shape = (413, 413, 100, 3)
ref_mapped = np.zeros(ref_mapped_shape)

In [None]:
# Pre-allocate one morphed RGB tuning map per fish on the reference grid:
# (fish, axis0, axis1, axis2, channel).
# NOTE(review): despite its name, `num_planes` is the length of axis 0; the plotting
# cell selects planes along axis 2.
num_fish = len(fish_list)
num_planes, n_x, n_y = ref_mapped.shape[:3]
tuning_maps_all = np.zeros((num_fish, num_planes, n_x, n_y, 3))

In [None]:
def _normalize_to_255(channel):
    """Rescale one registered channel linearly onto [0, 255], ignoring NaNs.

    Parameters
    ----------
    channel : array-like
        Output of ``map_affine`` for a single colour channel.

    Returns
    -------
    np.ndarray
        Float64 array of the same shape. A channel with zero range (constant
        values) is returned shifted to 0 instead of being divided by zero,
        which would otherwise turn the whole channel into NaN.
    """
    scaled = np.asarray(channel, dtype=np.float64)
    scaled = scaled - np.nanmin(scaled)
    value_range = np.nanmax(scaled)
    if value_range > 0:
        scaled /= value_range
    return scaled * 255


for fish_idx, fish in enumerate(fish_list):
    print(fish)
    # Per-fish affine saved by the registration step; presumably maps this fish
    # (in morphing space) onto the reference grid -- TODO confirm.
    path = fish / "registration"
    transform_mat = fl.load(path / "initial_transform_mapped.h5")

    # Tuning map of this fish; the last axis is indexed as 3 colour channels below.
    mov_stack = fl.load(fish / "tuning_map.h5")

    exp_mov = TwoPExperiment(fish)
    mov_res = exp_mov.resolution

    # Describe the fish's native space by its spatial shape only (channel axis dropped),
    # consistent with how ref_as is built from the reference map.
    mov_as = AnatomicalSpace('ipl', resolution=mov_res, shape=mov_stack.shape[:3])

    # Map each channel into the common morphing space exactly once; cast to float64
    # so the normalisation below never runs in an integer dtype.
    mov_mapped = np.stack(
        [mov_as.map_stack_to(morphing_as, mov_stack[:, :, :, ch]) for ch in range(3)],
        axis=-1,
    ).astype(np.float64)

    transformed = np.zeros(ref_mapped.shape)
    for ch in range(3):
        # map_affine is not defined in this notebook; presumably it comes from
        # `from quickdisplay import *`.
        registered = map_affine(mov_mapped[:, :, :, ch], transform_mat, ref_mapped.shape[:3])
        transformed[:, :, :, ch] = _normalize_to_255(registered)

    tuning_maps_all[fish_idx] = transformed

    fl.save(fish / 'tuning_map_morphed.h5', {'tuning_map_morphed': transformed})

In [None]:
# Quantise to integer grey levels but stay in float, so any NaNs (e.g. left by map_affine)
# survive for the nan-aware mean/median below. Casting NaN to int32 instead yields an
# arbitrary platform-dependent integer that nanmean/nanmedian would silently include.
# Values are >= 0 after normalisation, so truncation matches the previous int32 cast.
tuning_maps_all = np.trunc(tuning_maps_all)

In [None]:
# Pixel-wise mean and median across fish (axis 0), ignoring NaNs. A pixel that is NaN
# in every fish stays NaN, and NaN has no int32 representation, so such pixels are set
# to 0 before casting.
tuning_map_avg = np.nan_to_num(np.nanmean(tuning_maps_all, axis=0), nan=0.0).astype('int32')
tuning_map_med = np.nan_to_num(np.nanmedian(tuning_maps_all, axis=0), nan=0.0).astype('int32')

In [None]:
# Montage of selected planes (third spatial axis) of the across-fish median RGB tuning map.
n_col = 4
planes = [4, 7, 10, 13, 16, 19, 22, 25]
n_planes = len(planes)
n_row = -(-n_planes // n_col)  # ceil division: enough rows for every plane
contrast = 1      # gain applied to every channel
brightness = 50   # offset added to every channel
crop = (40, 390)  # displayed pixel window along both image axes

# squeeze=False keeps `ax` 2-D even if all planes fit in a single row.
fig, ax = plt.subplots(n_row, n_col, figsize=(10, 6), squeeze=False)

for i, plane in enumerate(planes):
    r, c = divmod(i, n_col)
    # 180-degree rotation of the image plane for display.
    plane_rgb = np.rot90(tuning_map_med[:, :, plane], 2)
    # contrast * x + brightness (what cv2.addWeighted(a, contrast, a, 0, brightness)
    # computed), rounded and clipped to 0-255 uint8: the range imshow expects for
    # integer RGB data, so it no longer has to clip values above 255 itself.
    plane_rgb = np.clip(np.rint(contrast * plane_rgb + brightness), 0, 255).astype(np.uint8)
    ax[r, c].imshow(plane_rgb)
    ax[r, c].axis('off')
    ax[r, c].set_xlim(*crop)
    # top > bottom keeps imshow's default top-down image orientation.
    ax[r, c].set_ylim(crop[1], crop[0])

# Blank out grid cells left over when the planes do not fill the last row.
for unused_ax in ax.flat[n_planes:]:
    unused_ax.axis('off')

In [None]:
# Save the montage next to the fish folders. The sample size in the filename comes from
# the number of fish actually loaded instead of a hard-coded "17", so it cannot go stale.
out_stem = f"average morphed ipns n{num_fish} median"
for ext in ("jpg", "pdf"):
    fig.savefig(master / f"{out_stem}.{ext}", dpi=300)