In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from scipy import stats
import colorspacious
import matplotlib.pyplot as plt

from pathlib import Path

from vision_and_navigation.imaging.general import corr2_coeff

# Set fish path

In [None]:
# Experiment group to analyse: 'ntr' or 'control'
treatment = 'ntr'

master = Path(r'\\portulab.synology.me\data\Hagar and Ot\E0040\v10\LS ablation\{}'.format(treatment))
fish_list = list(master.glob("*_f*_{}*".format(treatment)))

# Output folder for the figures saved at the end of the notebook.
# Leave as None to skip saving; set to a Path to write the PDFs there.
# The figure-saving cells read this name, so it must always be defined.
fig_path = None

# Load morphed coordinates

In [None]:
# Pooled (all-fish) morphed ROI coordinates and a boolean in-brain mask per session.
# Loaded from the on-disk cache when present, otherwise rebuilt from per-fish files.
try:
    coords_pooled = fl.load(master / 'coords_pooled_{}.h5'.format(treatment))
    in_brain_arr_pooled = fl.load(master / 'in_brain_arr_pooled_{}.h5'.format(treatment))

except OSError:
    coords_pooled, in_brain_arr_pooled = {}, {}

    for session in ('pre', 'post'):
        session_list = list(master.glob("*{}".format(session)))
        fish_xyz, fish_masks = [], []

        for fish in session_list:
            print(fish)

            # ROI coordinates mapped into the reference space
            xyz = fl.load(fish / 'registration' / 'to_h2b_baier_ref' / 'antspy' / 'mov_coords_transformed.h5')

            # 'coords_idx' presumably lists the ROIs kept as in-brain; expand it into a
            # boolean mask over all ROIs of this fish
            suite2p_brain = fl.load(fish / "data_from_suite2p_cells_brain.h5")
            mask = np.zeros(xyz.shape[0], dtype=bool)
            mask[suite2p_brain['coords_idx']] = True

            fish_xyz.append(xyz)
            fish_masks.append(mask)

        # Stack all fish of the session (fish order = glob order)
        coords_pooled[session] = np.concatenate(fish_xyz, 0)
        in_brain_arr_pooled[session] = np.concatenate(fish_masks)

    fl.save(master / 'coords_pooled_{}.h5'.format(treatment), coords_pooled)
    fl.save(master / 'in_brain_arr_pooled_{}.h5'.format(treatment), in_brain_arr_pooled)

In [None]:
# Sanity check: number of in-brain ROI indices stored for each 'pre' fish
session_list = list(master.glob("*pre"))

for fish in session_list:
    print(fish)
    brain_idx = fl.load(fish / "data_from_suite2p_cells_brain.h5")['coords_idx']
    print(len(brain_idx))
    print('')

# Load tuning data

In [None]:
# Tuning amplitude and angle per session. Each stored entry is a dict of arrays
# (presumably one per fish); flatten it into a single array in stored key order.
amp_pooled = {}
angle_pooled = {}

for session in ['pre', 'post']:
    tuning_arrs = fl.load(master / 'tuning_arrs_{}_{}.h5'.format(treatment, session))
    amp_pooled[session] = np.concatenate(list(tuning_arrs['amp_pooled'].values()))
    angle_pooled[session] = np.concatenate(list(tuning_arrs['angle_pooled'].values()))

In [None]:
# Reliability index arrays of all fish, concatenated per session in glob order
rel_arr_pooled = {}

for session in ['pre', 'post']:
    session_list = list(master.glob("*{}".format(session)))
    fish_rel = [fl.load(path / "reliability_index_arr.h5", "/reliability_arr_combined")
                for path in session_list]
    rel_arr_pooled[session] = np.concatenate(fish_rel)

# Load correlation data

In [None]:
# Number of sensory-regressor columns (one per direction) correlated with the traces below
n_dirs = 8

In [None]:
# Correlation of the traces with each of the n_dirs sensory regressors, pooled over
# fish and cached per session. Only builds the cache; the next cell loads it.
for session in ['pre', 'post']:
    cache_file = master / 'reg_corrcoefs_pooled_{}_{}_all.h5'.format(treatment, session)
    try:
        corrcoefs_all = fl.load(cache_file)

    except OSError:
        session_list = list(master.glob("*{}".format(session)))
        per_direction = {direction: [] for direction in range(n_dirs)}

        for path in session_list:
            traces = fl.load(path / "filtered_traces.h5", "/detr")
            sensory_regressors = fl.load(path / "sensory_regressors.h5", "/regressors")

            for direction in range(n_dirs):
                regressor = np.asarray(sensory_regressors.iloc[:, direction]).reshape(1, -1)
                # traces are transposed so corr2_coeff correlates each row with the regressor
                # (presumably one row per ROI -- confirm against the traces file)
                per_direction[direction].append(corr2_coeff(traces.T, regressor).ravel())

        corrcoefs_all = {direction: np.concatenate(per_direction[direction]) for direction in range(n_dirs)}
        fl.save(cache_file, corrcoefs_all)

In [None]:
# Per ROI, keep the signed correlation coefficient of the direction whose
# absolute correlation is largest.
session_corrcoefs = {}

for session in ['pre', 'post']:
    per_direction = fl.load(master / 'reg_corrcoefs_pooled_{}_{}_all.h5'.format(treatment, session))

    full_mat = np.stack([per_direction[direction] for direction in range(n_dirs)])  # (n_dirs, n_rois)
    best_dir = np.abs(full_mat).argmax(0)
    # Fancy indexing picks full_mat[best_dir[j], j] for every column j in one step
    session_corrcoefs[session] = full_mat[best_dir, np.arange(full_mat.shape[1])]

## Voxelization

In [None]:
from numba import jit
import numba
import napari

In [None]:
# Reference anatomy: every fish is registered to the same 'to_h2b_baier_ref' target,
# so the reference volume stored with the first fish is used for the whole group.
ref_anatomy_path = fish_list[0] / 'registration' / 'to_h2b_baier_ref' / 'antspy' / 'ref_mapped.h5'
ref_anatomy = fl.load(ref_anatomy_path)

In [None]:
# Visual safety check: in-brain ROIs of both sessions overlaid on the reference anatomy
viewer = napari.Viewer()
viewer.add_image(ref_anatomy)
for session, colour in (('pre', 'red'), ('post', 'blue')):
    viewer.add_points(coords_pooled[session][in_brain_arr_pooled[session]], face_color=colour)

In [None]:
# Compare the reference volume size with the largest ROI coordinate along each axis
print(ref_anatomy.shape)
for session in ['pre', 'post']:
    print(coords_pooled[session].max(0))

In [None]:
# Voxel edge length (in reference-anatomy pixels) and the lower corner of every
# voxel along each of the first three axes of the reference volume
vox_size = 5
xvx, yvx, zvx = (np.arange(0, n_along_axis, vox_size) for n_along_axis in ref_anatomy.shape[:3])

In [None]:
@jit(nopython=True)
def get_voxel_centroids(xvx, yvx, zvx, vox_size):
    """Return the centre of every voxel of a regular 3D grid.

    xvx, yvx, zvx : 1D arrays with the lower-corner coordinate of the voxels
        along each axis.
    vox_size : voxel edge length.

    Returns a float array of shape (len(xvx), len(yvx), len(zvx), 3) whose last
    axis holds the (x, y, z) centre of each voxel.
    """
    half = vox_size / 2
    n_x, n_y, n_z = xvx.shape[0], yvx.shape[0], zvx.shape[0]
    centroids = np.zeros((n_x, n_y, n_z, 3))

    for i in range(n_x):
        for j in range(n_y):
            for k in range(n_z):
                centroids[i, j, k, 0] = xvx[i] + half
                centroids[i, j, k, 1] = yvx[j] + half
                centroids[i, j, k, 2] = zvx[k] + half
    return centroids

In [None]:
@jit(nopython=True)
def assign_to_voxels(coords, vx_centroids, vox_size):
    """Return, for every ROI, the integer index of the voxel that contains it.

    Parameters
    ----------
    coords : 2D array, one row per ROI and one column per axis.
    vx_centroids : (nx, ny, nz, 3) array of voxel centres on a regular grid with
        spacing `vox_size`, as produced by `get_voxel_centroids`.
    vox_size : voxel edge length.

    Returns
    -------
    int32 array with the same shape as `coords`. Row r holds the voxel index
    along each axis for ROI r. Voxel i along an axis covers
    [origin + i * vox_size, origin + (i + 1) * vox_size), where origin is the
    lower corner of the first voxel.

    ROIs outside the grid, or with NaN coordinates, get -2147483648 in every
    column. That value is out of bounds for any array, so indexing with such a
    row fails loudly instead of silently hitting a wrong voxel.

    Notes
    -----
    The previous implementation compared each ROI with every voxel centre.
    It filled the output with NaN, which cannot be represented in int32.
    When no voxel matched, which happened for coordinates exactly on a voxel
    boundary, it read element 0 of an empty array. The index is now computed
    directly by floor division.
    """
    missing = -2147483648  # np.iinfo(np.int32).min
    half = vox_size / 2
    n_rois = coords.shape[0]
    n_axes = coords.shape[1]
    coord_vox = np.full((n_rois, n_axes), missing, dtype=np.int32)

    if vx_centroids.size == 0:
        # Empty grid: nothing can be assigned
        return coord_vox

    for roi in range(n_rois):
        for ax in range(n_axes):
            # Position in voxel units relative to the lower corner of the first voxel
            rel = (coords[roi, ax] - (vx_centroids[0, 0, 0, ax] - half)) / vox_size
            # Negated test so NaN (every comparison False) also counts as outside
            if not (rel >= 0 and rel < vx_centroids.shape[ax]):
                coord_vox[roi, :] = missing
                break
            coord_vox[roi, ax] = int(np.floor(rel))

    return coord_vox

In [None]:
# Voxel index of every ROI per session, cached on disk per treatment and voxel size.
# NOTE: the cache holds whatever assign_to_voxels produced when it was written;
# delete the file to recompute after changing the voxel assignment.
try:
    voxeled_rois = fl.load(master / 'voxeled_rois_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    voxeled_rois = {}

    # The grid is identical for both sessions, so build the centroids only once
    vx_centroids = get_voxel_centroids(xvx, yvx, zvx, vox_size)

    for session in ['pre', 'post']:
        voxeled_rois[session] = assign_to_voxels(coords_pooled[session], vx_centroids, vox_size)

    fl.save(master / 'voxeled_rois_{}_{}voxsize.h5'.format(treatment, vox_size), voxeled_rois)

In [None]:
# Occupancy map: number of in-brain ROIs per voxel (NaN where a voxel has none)
occ_map = {}

for session in ['pre', 'post']:
    coords = voxeled_rois[session][in_brain_arr_pooled[session]]
    unique_coords, counts = np.unique(coords, axis=0, return_counts=True)

    session_occ = np.full((xvx.shape[0], yvx.shape[0], zvx.shape[0]), np.nan)
    # Each unique row is one voxel; transpose into per-axis index arrays
    session_occ[tuple(unique_coords.T)] = counts
    occ_map[session] = session_occ

In [None]:
# Occupancy maps overlaid in one viewer: 'pre' in reds, 'post' in blues
viewer = napari.Viewer()
for session, colormap_name in [('pre', 'Reds'), ('post', 'Blues')]:
    viewer.add_image(occ_map[session], colormap=colormap_name)

# Voxel-wise maps (reliability, amplitude, correlation)

In [None]:
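# NOTE(review): unused draft of a generic voxel-averaging helper; the map cells below
# inline this logic instead. As written it would not work: the first parameter is
# misspelled `vxv`, the loop writes into `rel_map` instead of `map_arr`, and nothing is returned.
# @jit(nopython=True)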
# def make_map_from_values(vxv, yvx, zvx, coords, metric_arr):
#     map_arr = np.full((xvx.shape[0], yvx.shape[0], zvx.shape[0]), np.nan)
#     unique_coords = np.unique(coords, axis=0)
    
#     for coord in unique_coords:
#         vox_rois = np.argwhere(((coords == coord).all(axis=1)))
#         rel_map[tuple(coord)] = np.mean(metric_arr[vox_rois])

In [None]:
# Reliability map: mean reliability index of the in-brain ROIs in each voxel
# (NaN for voxels without ROIs), cached per treatment and voxel size.
# NOTE: a cache written before the in-brain mask below was applied may hold
# misaligned values; delete it to recompute.
try:
    rel_map = fl.load(master / 'rel_map_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    rel_map = {}

    for session in ['pre', 'post']:
        in_brain = in_brain_arr_pooled[session]
        coords = voxeled_rois[session][in_brain]

        # `coords` only holds in-brain ROIs, so the reliability values must be restricted
        # the same way before positions line up. The amplitude and correlation maps
        # apply the same mask.
        rel_values = rel_arr_pooled[session]
        if len(rel_values) == len(in_brain):
            rel_values = rel_values[in_brain]
        elif len(rel_values) != len(coords):
            raise ValueError(
                "{}: {} reliability values match neither all ROIs ({}) nor in-brain ROIs ({})".format(
                    session, len(rel_values), len(in_brain), len(coords)))

        # Group ROI indices by voxel in one pass instead of rescanning every ROI per voxel.
        # The stable sort keeps ROIs in ascending order inside each voxel.
        unique_coords, roi_vox = np.unique(coords, axis=0, return_inverse=True)
        roi_vox = roi_vox.ravel()  # NumPy 2.0.0 may return a 2D inverse for axis=0
        vox_bounds = np.cumsum(np.bincount(roi_vox, minlength=len(unique_coords)))[:-1]
        rois_by_vox = np.split(np.argsort(roi_vox, kind='stable'), vox_bounds)

        rel_map[session] = np.full((xvx.shape[0], yvx.shape[0], zvx.shape[0]), np.nan)
        for coord, vox_rois in zip(unique_coords, rois_by_vox):
            rel_map[session][tuple(coord)] = np.mean(rel_values[vox_rois])

    fl.save(master / 'rel_map_{}_{}voxsize.h5'.format(treatment, vox_size), rel_map)

In [None]:
# Amplitude map: mean tuning amplitude of the in-brain ROIs in each voxel
# (NaN for voxels without ROIs), cached per treatment and voxel size.
try:
    amp_map = fl.load(master / 'amp_map_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    amp_map = {}

    for session in ['pre', 'post']:
        in_brain = in_brain_arr_pooled[session]
        coords = voxeled_rois[session][in_brain]
        # Mask once per session rather than once per voxel
        amp_in_brain = amp_pooled[session][in_brain]

        # Group ROI indices by voxel in one pass instead of rescanning every ROI per voxel.
        # The stable sort keeps ROIs in ascending order inside each voxel.
        unique_coords, roi_vox = np.unique(coords, axis=0, return_inverse=True)
        roi_vox = roi_vox.ravel()  # NumPy 2.0.0 may return a 2D inverse for axis=0
        vox_bounds = np.cumsum(np.bincount(roi_vox, minlength=len(unique_coords)))[:-1]
        rois_by_vox = np.split(np.argsort(roi_vox, kind='stable'), vox_bounds)

        amp_map[session] = np.full((xvx.shape[0], yvx.shape[0], zvx.shape[0]), np.nan)
        for coord, vox_rois in zip(unique_coords, rois_by_vox):
            amp_map[session][tuple(coord)] = np.mean(amp_in_brain[vox_rois])

    fl.save(master / 'amp_map_{}_{}voxsize.h5'.format(treatment, vox_size), amp_map)

In [None]:
# Correlation map: NaN-ignoring mean of the best-direction correlation of the
# in-brain ROIs in each voxel (NaN for voxels without ROIs), cached per treatment and voxel size.
try:
    corr_map = fl.load(master / 'corr_map_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    corr_map = {}

    for session in ['pre', 'post']:
        in_brain = in_brain_arr_pooled[session]
        coords = voxeled_rois[session][in_brain]
        # Mask once per session rather than once per voxel
        corr_in_brain = session_corrcoefs[session][in_brain]

        # Group ROI indices by voxel in one pass instead of rescanning every ROI per voxel.
        # The stable sort keeps ROIs in ascending order inside each voxel.
        unique_coords, roi_vox = np.unique(coords, axis=0, return_inverse=True)
        roi_vox = roi_vox.ravel()  # NumPy 2.0.0 may return a 2D inverse for axis=0
        vox_bounds = np.cumsum(np.bincount(roi_vox, minlength=len(unique_coords)))[:-1]
        rois_by_vox = np.split(np.argsort(roi_vox, kind='stable'), vox_bounds)

        corr_map[session] = np.full((xvx.shape[0], yvx.shape[0], zvx.shape[0]), np.nan)
        for coord, vox_rois in zip(unique_coords, rois_by_vox):
            corr_map[session][tuple(coord)] = np.nanmean(corr_in_brain[vox_rois])

    fl.save(master / 'corr_map_{}_{}voxsize.h5'.format(treatment, vox_size), corr_map)

In [None]:
def slice_stack(z_size, n_planes):
    """Split the z range [0, z_size) into `n_planes` contiguous slabs.

    Parameters
    ----------
    z_size : int
        Number of planes along z.
    n_planes : int
        Number of slabs; must be >= 1.

    Returns
    -------
    np.ndarray of int, length n_planes + 1
        Slab boundaries: slab i spans z_levels[i]:z_levels[i + 1]. The first
        boundary is 0 and the last is z_size, so every plane belongs to exactly
        one slab. If z_size is not a multiple of n_planes, slab sizes differ by
        at most one. If z_size < n_planes, some slabs are empty.

    Raises
    ------
    ValueError
        If n_planes < 1.
    """
    if n_planes < 1:
        raise ValueError("n_planes must be >= 1, got {}".format(n_planes))
    # Matches a fixed step of z_size // n_planes whenever z_size divides evenly.
    # Otherwise the remainder is spread over the slabs instead of ending up in an
    # extra trailing slab that callers (which read only n_planes slabs) never plot.
    return (np.arange(n_planes + 1) * z_size) // n_planes

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl

In [None]:
# Reliability figure: columns 'pre', 'post' and 'post - pre'. Each row shows the
# voxel-wise NaN-mean over one z slab; slab 0 (lowest z) is in the bottom row.
n_planes = 3

rel_cmap = 'Reds'
rel_vlims = [0, 1]
diff_cmap = 'RdBu_r'
diff_vlims = [-.5, .5]

z_levels = slice_stack(rel_map['pre'].shape[2], n_planes)

# (column, volume, colormap, colour limits, title) for each figure column
panels = [
    (0, rel_map['pre'], rel_cmap, rel_vlims, 'pre'),
    (1, rel_map['post'], rel_cmap, rel_vlims, 'post'),
    (2, rel_map['post'] - rel_map['pre'], diff_cmap, diff_vlims, 'post - pre'),
]

fig, axes = plt.subplots(n_planes, 3, figsize=(6, 8))

for col, volume, col_cmap, (vmin, vmax), title in panels:
    for plane in range(n_planes):
        slab = volume[:, :, z_levels[plane]:z_levels[plane + 1]]
        axes[n_planes - 1 - plane, col].imshow(np.nanmean(slab, 2).T, cmap=col_cmap, vmin=vmin, vmax=vmax)
    axes[0, col].set_title(title)

for ax in axes.ravel():
    ax.axis('off')

# Colour bars beside the bottom-left (absolute values) and bottom-right (difference) panels
for col, col_cmap, (vmin, vmax) in [(0, rel_cmap, rel_vlims), (2, diff_cmap, diff_vlims)]:
    cax = make_axes_locatable(axes[-1, col]).append_axes('right', size='5%', pad=0.05)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=col_cmap), cax=cax, fraction=.5)

plt.suptitle('Reliability ({})'.format(treatment))
plt.tight_layout()

In [None]:
# Save the reliability figure only when an output folder is configured
if fig_path is not None:
    out_file = fig_path / 'reliability_voxelwise_{}_{}voxsize.pdf'.format(treatment, vox_size)
    fig.savefig(out_file, dpi=350)

In [None]:
# Amplitude figure: columns 'pre', 'post' and 'post - pre'. Each row shows the
# voxel-wise NaN-mean over one z slab; slab 0 (lowest z) is in the bottom row.
n_planes = 3

amp_cmap = 'Reds'
amp_vlims = [0, 1300]
diff_cmap = 'RdBu_r'
diff_vlims = [-1000, 1000]

# Slab boundaries come from the map actually being plotted
z_levels = slice_stack(amp_map['pre'].shape[2], n_planes)

# (column, volume, colormap, colour limits, title) for each figure column
panels = [
    (0, amp_map['pre'], amp_cmap, amp_vlims, 'pre'),
    (1, amp_map['post'], amp_cmap, amp_vlims, 'post'),
    (2, amp_map['post'] - amp_map['pre'], diff_cmap, diff_vlims, 'post - pre'),
]

fig, axes = plt.subplots(n_planes, 3, figsize=(6, 8))

for col, volume, col_cmap, (vmin, vmax), title in panels:
    for plane in range(n_planes):
        slab = volume[:, :, z_levels[plane]:z_levels[plane + 1]]
        axes[n_planes - 1 - plane, col].imshow(np.nanmean(slab, 2).T, cmap=col_cmap, vmin=vmin, vmax=vmax)
    axes[0, col].set_title(title)

for ax in axes.ravel():
    ax.axis('off')

# Colour bars beside the bottom-left (absolute values) and bottom-right (difference) panels
for col, col_cmap, (vmin, vmax) in [(0, amp_cmap, amp_vlims), (2, diff_cmap, diff_vlims)]:
    cax = make_axes_locatable(axes[-1, col]).append_axes('right', size='5%', pad=0.05)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=col_cmap), cax=cax, fraction=.5)

plt.suptitle('Amplitude ({})'.format(treatment))
plt.tight_layout()

In [None]:
# Save the amplitude figure only when an output folder is configured
if fig_path is not None:
    out_file = fig_path / 'amplitude_voxelwise_{}_{}voxsize.pdf'.format(treatment, vox_size)
    fig.savefig(out_file, dpi=350)

In [None]:
# Correlation figure: columns 'pre', 'post' and 'post - pre'. Each row shows the
# voxel-wise NaN-mean over one z slab; slab 0 (lowest z) is in the bottom row.
n_planes = 3

corr_cmap = 'RdBu_r'
corr_vlims = [-.5, .5]
diff_cmap = 'RdBu_r'
diff_vlims = [-.5, .5]

z_levels = slice_stack(corr_map['pre'].shape[2], n_planes)

# (column, volume, colormap, colour limits, title) for each figure column
panels = [
    (0, corr_map['pre'], corr_cmap, corr_vlims, 'pre'),
    (1, corr_map['post'], corr_cmap, corr_vlims, 'post'),
    (2, corr_map['post'] - corr_map['pre'], diff_cmap, diff_vlims, 'post - pre'),
]

fig, axes = plt.subplots(n_planes, 3, figsize=(6, 8))

for col, volume, col_cmap, (vmin, vmax), title in panels:
    for plane in range(n_planes):
        slab = volume[:, :, z_levels[plane]:z_levels[plane + 1]]
        axes[n_planes - 1 - plane, col].imshow(np.nanmean(slab, 2).T, cmap=col_cmap, vmin=vmin, vmax=vmax)
    axes[0, col].set_title(title)

for ax in axes.ravel():
    ax.axis('off')

# Colour bars beside the bottom-left (absolute values) and bottom-right (difference) panels
for col, col_cmap, (vmin, vmax) in [(0, corr_cmap, corr_vlims), (2, diff_cmap, diff_vlims)]:
    cax = make_axes_locatable(axes[-1, col]).append_axes('right', size='5%', pad=0.05)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=col_cmap), cax=cax, fraction=.5)

plt.suptitle('Correlation ({})'.format(treatment))
plt.tight_layout()

In [None]:
# Save the correlation figure only when an output folder is configured
if fig_path is not None:
    out_file = fig_path / 'corrcoef_voxelwise_{}_{}voxsize.pdf'.format(treatment, vox_size)
    fig.savefig(out_file, dpi=350)