In [None]:
# Interactive (ipympl) matplotlib backend for the figures produced in this notebook.
%matplotlib widget

In [None]:
import numpy as np
from pathlib import Path
import flammkuchen as fl
import napari
import matplotlib.pyplot as plt

from bg_atlasapi.bg_atlas import BrainGlobeAtlas
from bg_space import AnatomicalSpace
from quickdisplay import *

from fimpylab.core.twop_experiment import TwoPExperiment

In [None]:
# Enable IPython's Qt5 event-loop integration; the napari viewers opened below are Qt windows.
%gui qt5

In [None]:
import tifffile as tiff
import cv2

In [None]:
# Root folder of the experiment; fish sub-folders are matched by the "*_f*" pattern.
master = Path(r"Z:\Hagar and Ot\E0040\v10\2p\gad1b")

# Path.glob yields entries in filesystem order, which is not guaranteed. Sort them so
# that the reference-fish index and the fish order in `tuning_maps_all` are reproducible.
# Only directories are kept, because the morph loop opens sub-paths of every entry.
fish_list = sorted(p for p in master.glob("*_f*") if p.is_dir())

# Index (into the sorted fish_list) of the fish that all others are morphed onto.
REF_FISH_IDX = 1
ref_fish = fish_list[REF_FISH_IDX]

# Presumably a hand-picked subset of fish indices; currently unused.
#nice_fish = [2, 4, 6, 10, 13, 14, 15, 16, 18, 21, 22, 23, 24]

In [None]:
# load tuning map of ref fish
ref_tuning_map = fl.load(ref_fish / "tuning_map_90_fixed_2024.h5")

exp_ref = TwoPExperiment(ref_fish)
ref_res = exp_ref.resolution
ref_as = AnatomicalSpace('ipl', resolution=ref_res, shape=ref_tuning_map.shape[:3])

In [None]:
# Spatial shape of a single channel of the reference tuning map.
ref_tuning_map[:, :, :, 0].shape

In [None]:
#Define morphing space...
morphing_as = AnatomicalSpace('rai', resolution=(0.3,0.3,1))

#... and transform references to morphing space
ref_mapped = np.zeros((413, 413, 65, 3))
for i in range(3):
    ref_mapped[:,:,:,i] = ref_as.map_stack_to(morphing_as, ref_tuning_map[:,:,:,i])

In [None]:
# The morphed stacks share ref_mapped's grid. The plotting cell selects planes along
# the third axis, so that axis is the plane count and the first two are in-plane.
n_x, n_y, num_planes = ref_mapped.shape[:3]
num_fish = len(fish_list)

# One morphed 3-channel tuning map per fish, in fish_list order.
tuning_maps_all = np.zeros((num_fish, n_x, n_y, num_planes, 3))

In [None]:
def _normalize_to_255(stack):
    """Return a float copy of `stack` min-max rescaled to [0, 255], ignoring NaNs.

    NaNs stay NaN. A stack with zero value range comes back as all zeros; it is
    not divided by zero, which would turn every voxel into NaN.
    """
    rescaled = np.asarray(stack, dtype=float) - np.nanmin(stack)
    value_range = np.nanmax(rescaled)
    if value_range > 0:
        rescaled /= value_range
    return rescaled * 255


for count, fish in enumerate(fish_list):
    print(fish)

    # Compare paths by value, not identity, so ref_fish need not be the same object.
    if fish != ref_fish:
        path = fish / "registration"

        # Affine registration to the reference fish, plus this fish's raw tuning map.
        transform_mat = fl.load(path / "initial_transform_mapped.h5")
        mov_stack = fl.load(fish / "tuning_map_90_fixed_2024.h5")

        exp_mov = TwoPExperiment(fish)
        mov_res = exp_mov.resolution

        # As for the reference, only the first three axes define the anatomical space;
        # the trailing axis holds the 3 channels mapped one at a time below.
        mov_as = AnatomicalSpace('ipl', resolution=mov_res, shape=mov_stack.shape[:3])

        transformed = np.zeros(ref_mapped.shape)
        for i in range(3):
            # Bring channel i into the common morphing space, then warp it onto the
            # reference grid with this fish's affine transform.
            mov_mapped = mov_as.map_stack_to(morphing_as, mov_stack[:, :, :, i])
            warped = map_affine(mov_mapped, transform_mat, ref_mapped.shape[:3])
            transformed[:, :, :, i] = _normalize_to_255(warped)

        tuning_maps_all[count] = transformed
        fl.save(fish / 'tuning_map_morphed.h5', {'tuning_map_morphed': transformed})
    else:
        print("loading ref fish")
        # Rescale the reference channels exactly like the moving fish, so that every
        # fish contributes on the same 0-255 scale to the median/mean computed later.
        for i in range(3):
            tuning_maps_all[count, :, :, :, i] = _normalize_to_255(ref_mapped[:, :, :, i])

In [None]:
# Grid of the morphed reference map (the last axis holds the 3 channels).
np.shape(ref_mapped)

In [None]:
# Inspect one morphed fish: one layer per channel, overlaid additively in red/green/blue.
fish_idx = 5
channel_cmaps = ('Reds', 'Greens', 'Blues')
viewer = napari.view_image(tuning_maps_all[fish_idx, :, :, :, 0], colormap=channel_cmaps[0])
for ch in (1, 2):
    viewer.add_image(tuning_maps_all[fish_idx, :, :, :, ch], colormap=channel_cmaps[ch], blending='additive')

In [None]:
# Optionally reload the per-fish morphed maps that the final cell saves. The load is
# skipped when the file does not exist yet, so a fresh Restart & Run All does not
# crash here and keeps the maps computed above.
morphed_all_path = master / 'tuning_map_morphed_all.h5'
if morphed_all_path.exists():
    tuning_maps_all = fl.load(morphed_all_path)['tuning_map_all']

In [None]:
# Per-voxel median over fish. A voxel that is NaN in every fish has a NaN median, and
# casting NaN to int32 is undefined, so those voxels are set to 0 before the cast.
tuning_map_med_float = np.nanmedian(tuning_maps_all, axis=0)
tuning_map_med = np.where(np.isnan(tuning_map_med_float), 0, tuning_map_med_float).astype('int32')

In [None]:
# Integer copy of the per-fish maps, used for the mean and the saved file. NaN voxels
# (which the nan-aware reductions in this notebook anticipate) are set to 0 first,
# because casting NaN to int32 is undefined. Re-running this cell is harmless.
tuning_maps_all = np.where(np.isnan(tuning_maps_all), 0, tuning_maps_all).astype('int32')

In [None]:
# Planes (indices along the third axis of the median map) to show, n_col panels per row.
planes = [10, 15, 20, 25, 30, 35, 40, 45]
n_planes = len(planes)
n_col = 4
n_row = -(-n_planes // n_col)  # ceil division: enough rows for every plane
# squeeze=False keeps `ax` 2-D even when a single row is needed.
fig, ax = plt.subplots(n_row, n_col, figsize=(3 * n_col, 3 * n_row), squeeze=False)

# Display adjustment: shown value = contrast * value + brightness, clipped to 0-255.
contrast = 1
brightness = 22

# Crop window in pixels of the rotated image. The y-limits are given top-to-bottom,
# matching imshow's default upper origin.
X_LIM = (50, 370)
Y_LIM = (400, 80)

for i, axis in enumerate(ax.flat):
    axis.axis('off')
    if i >= n_planes:
        continue  # unused panel in the last row

    a = np.rot90(tuning_map_med[:, :, planes[i]], 3)
    # Same arithmetic as cv2.addWeighted(a, contrast, a, 0, brightness), but clipped to
    # the 0-255 range imshow expects for integer RGB data (it would otherwise clip
    # int32 input itself and warn).
    b = np.clip(np.rint(a * contrast + brightness), 0, 255).astype(np.uint8)
    axis.imshow(b)
    axis.set_xlim(*X_LIM)
    axis.set_ylim(*Y_LIM)

In [None]:
# Label the output files with the settings actually used for the figure: the number of
# fish in the stack and the brightness offset applied. These are not hard-coded, so
# the names cannot drift from the data.
fig_stem = f"median morphed ipns gad1b n={len(tuning_maps_all)} brightness {brightness}"
for ext in ("jpg", "pdf"):
    fig.savefig(master / f"{fig_stem}.{ext}", dpi=300)

In [None]:
# Per-voxel mean over fish. A voxel that is NaN in every fish has a NaN mean, and
# casting NaN to int32 is undefined, so those voxels are set to 0 before the cast.
tuning_map_avg_float = np.nanmean(tuning_maps_all, axis=0)
tuning_map_avg = np.where(np.isnan(tuning_map_avg_float), 0, tuning_map_avg_float).astype('int32')

In [None]:
# Show the average map with one channel per layer, like the per-fish viewer: channel 0
# in red here, with channels 1 and 2 added in green/blue by the next cell. Passing the
# whole array would make napari guess RGB (last axis has size 3) and double-count
# the green/blue channels once they are added on top.
viewer2 = napari.view_image(tuning_map_avg[:, :, :, 0], colormap='Reds')

In [None]:
# Overlay the remaining channels of the average map additively.
for channel, cmap in ((1, 'Greens'), (2, 'Blues')):
    viewer2.add_image(tuning_map_avg[:, :, :, channel], colormap=cmap, blending='additive')

In [None]:
# Shape of the averaged 3-channel map.
tuning_map_avg.shape

In [None]:
# Persist the average map and all per-fish morphed maps together in one file.
morphed_all_path = master / 'tuning_map_morphed_all.h5'
fl.save(morphed_all_path, dict(tuning_map_avg=tuning_map_avg, tuning_map_all=tuning_maps_all))