In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
from scipy import stats
from motions.utilities import stim_vel_dir_dataframe, quantize_directions
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import matplotlib.pyplot as plt

from fimpylab.core.lightsheet_experiment import LightsheetExperiment

import json
from pathlib import Path

from vision_and_navigation.imaging.general import corr2_coeff

In [None]:
# make sensory regressors. requires old bouter stimulus_param_log.
def make_sensory_regressors(exp, n_dirs=8, upsampling=5, sampling=1/3):
    """Build one convolved motion regressor per quantized stimulus direction.

    Parameters
    ----------
    exp : bouter.Experiment
        Experiment whose stimulus log (read by ``stim_vel_dir_dataframe``)
        provides ``t``, ``theta`` and ``vel`` columns.
    n_dirs : int
        Number of direction bins. Forwarded to ``quantize_directions`` so the
        bins and the number of regressors always agree.
    upsampling : int
        Temporal upsampling factor used while convolving.
    sampling : float
        Imaging frame period (1 / frame rate), in the time units of ``stim.t``.

    Returns
    -------
    pandas.DataFrame
        One row per imaging frame, columns ``motion_0`` ... ``motion_{n_dirs-1}``.
    """
    stim = stim_vel_dir_dataframe(exp)
    # n_dirs must be forwarded; otherwise the helper's own default number of
    # bins is used and any other n_dirs yields mismatched regressors.
    _, dir_bins = quantize_directions(stim.theta, n_dirs)

    # Binary indicator per direction: stimulus moving (vel > 0.1) in bin i_dir.
    ind_regs = np.zeros((n_dirs, len(stim)))
    for i_dir in range(n_dirs):
        ind_regs[i_dir, :] = (np.abs(dir_bins - i_dir) < 0.1) & (stim.vel > 0.1)

    # Resample the indicators on a grid `upsampling` times finer than imaging.
    dt_upsampled = sampling / upsampling
    t_imaging_up = np.arange(0, stim.t.values[-1], dt_upsampled)
    reg_up = interp1d(stim.t.values, ind_regs, axis=1, fill_value="extrapolate")(
        t_imaging_up
    )

    # Normalised exponential-decay kernel with a half-life of 1.5 time units
    # (seconds, assuming stim.t is in seconds): ~6% of the peak remains after 6.
    # It spans the whole upsampled recording.
    u_steps = t_imaging_up.shape[0]
    u_time = np.arange(u_steps) * dt_upsampled
    decay = np.exp(-u_time / (1.5 / np.log(2)))
    kernel = decay / np.sum(decay)

    # Full convolution truncated to the first u_steps samples (causal), then
    # decimated back to the imaging rate.
    convolved = convolve2d(reg_up, kernel[None, :])[:, 0:u_steps]
    reg_sensory = convolved[:, ::upsampling]

    return pd.DataFrame(reg_sensory.T, columns=[f"motion_{i}" for i in range(n_dirs)])

In [None]:
# calculate directional tuning from zscored traces for each roi
def get_tuning_map(traces, sens_regs, n_dirs=8):
    """Directional tuning vector (amplitude, angle) of every ROI.

    Each ROI trace is projected onto every direction regressor. The resulting
    weights are summed as 2D unit vectors pointing at the direction bin
    centres.

    Parameters
    ----------
    traces : numpy.ndarray
        (n_timepoints, n_rois) activity, z-scored per the original note.
    sens_regs : pandas.DataFrame
        (n_timepoints, n_regressors) regressors, e.g. from make_sensory_regressors.
    n_dirs : int
        Number of direction bins passed to quantize_directions.

    Returns
    -------
    amp, angle : numpy.ndarray
        Per-ROI vector length and angle (radians, from arctan2).
    """
    # Use only the time span covered by both inputs; previously a trace
    # shorter than the regressors made the matrix product below fail.
    n_t = min(sens_regs.shape[0], traces.shape[0])
    reg = sens_regs.values[:n_t].T @ traces[:n_t, :]  # (n_regressors, n_rois)

    # Unit vectors pointing at the direction bin centres.
    bin_centers, _ = quantize_directions([0], n_dirs)
    vectors = np.stack([np.cos(bin_centers), np.sin(bin_centers)], 0)
    reg_vectors = vectors @ reg  # (2, n_rois)

    angle = np.arctan2(reg_vectors[1], reg_vectors[0])
    amp = np.sqrt(np.sum(reg_vectors ** 2, 0))

    return amp, angle

In [None]:
# make a color map
def JCh_to_RGB255(x):
    """Convert JCh colours (lightness, chroma, hue in degrees) to 8-bit sRGB.

    Channels outside the sRGB gamut are clipped to [0, 1] before scaling.
    """
    srgb = colorspacious.cspace_convert(x, "JCh", "sRGB1")
    srgb_in_gamut = np.clip(srgb, 0, 1)
    return (srgb_in_gamut * 255).astype(np.uint8)

def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
    ):
    """Map tuning (amplitude, angle) to one 8-bit RGB colour per ROI.

    Hue encodes the angle. Lightness and chroma encode the amplitude,
    normalised to its ``amp_percentile``-th percentile and clipped to [0, 1].

    Parameters
    ----------
    amp, angle : numpy.ndarray
        1D arrays of equal length; angle in radians.
    hueshift : float
        Rotation (radians) of the hue wheel.
    amp_percentile : float
        Percentile of ``amp`` mapped to full chroma; NaNs are ignored.
    maxsat : float
        Chroma at full (normalised) amplitude.
    lightness_min : float
        Lightness at zero amplitude.
    lightness_delta : float
        Lightness change between zero and full amplitude.

    Returns
    -------
    numpy.ndarray
        (n_rois, 3) uint8 RGB array.
    """
    # nanpercentile: one NaN amplitude no longer makes every colour NaN.
    maxamp = np.nanpercentile(amp, amp_percentile)
    norm_amp = np.clip(amp / maxamp, 0, 1)

    output_lch = np.zeros((amp.shape[0], 3))
    output_lch[:, 0] = lightness_min + norm_amp * lightness_delta
    output_lch[:, 1] = norm_amp * maxsat
    output_lch[:, 2] = (-angle + hueshift) * 180 / np.pi  # hue in degrees

    return JCh_to_RGB255(output_lch)

# Set fish path

In [None]:
# Experimental group to analyse: 'ntr' or 'control'.
treatment = 'ntr'

ablation_root = r'\\portulab.synology.me\data\Hagar and Ot\E0040\v10\LS ablation\{}'
master = Path(ablation_root.format(treatment))

# Output folder for figures; set to None to skip saving.
fig_path = master

# All fish folders of this treatment.
fish_glob = "*_f*_{}*".format(treatment)
fish_list = list(master.glob(fish_glob))
fish_list

# Load morphed coordinates

In [None]:
# Pooled (registered) ROI coordinates and in-brain masks per session, cached in
# `master`. Each session's arrays are concatenated in master.glob order.
# NOTE(review): Path.glob order is filesystem-dependent, and every per-session
# pooled array in this notebook assumes the same order; keep it stable.
# NOTE(review): `morphed_coords`, `in_brain_arr` and `session_list` exist only
# if the except branch ran, and then hold the last session ('post') only.
# Later cells that read `in_brain_arr` depend on that hidden state.
try:
    coords_pooled = fl.load(master / 'coords_pooled_{}.h5'.format(treatment))
    in_brain_arr_pooled = fl.load(master / 'in_brain_arr_pooled_{}.h5'.format(treatment))

except OSError:
    coords_pooled = {}
    in_brain_arr_pooled = {}

    for session in ['pre', 'post']:

        session_list = list(master.glob("*{}".format(session)))
        morphed_coords = {}
        in_brain_arr = {}

        for fish in session_list:
            print(fish)

            # Registered ROI coordinates, one row per ROI (presumably in the
            # to_h2b_baier_ref reference space, per the path).
            morphed_coords[fish.name] = fl.load(fish / 'registration' / 'to_h2b_baier_ref' / 'antspy' / 'mov_coords_transformed.h5')

            # Boolean mask over all ROIs: True where the ROI index is listed in
            # suite2p_brain['coords_idx'] (the file name suggests in-brain ROIs).
            suite2p_brain = fl.load(fish / "data_from_suite2p_cells_brain.h5")

            in_brain_arr[fish.name] = np.full(morphed_coords[fish.name].shape[0], False)
            in_brain_arr[fish.name][suite2p_brain['coords_idx']] = True

        # Pool across the fish of this session, in session_list order.
        coords_pooled[session] = np.concatenate([morphed_coords[fish.name] for fish in session_list], 0)
        in_brain_arr_pooled[session] = np.concatenate([in_brain_arr[fish.name] for fish in session_list])

    fl.save(master / 'coords_pooled_{}.h5'.format(treatment), coords_pooled)
    fl.save(master / 'in_brain_arr_pooled_{}.h5'.format(treatment), in_brain_arr_pooled)

In [None]:
# Number of in-brain ROIs listed for each 'pre' fish.
session_list = list(master.glob("*pre"))

for fish in session_list:
    print(fish)
    suite2p_brain = fl.load(fish / "data_from_suite2p_cells_brain.h5")
    n_brain_rois = len(suite2p_brain['coords_idx'])
    print(f"{n_brain_rois}\n")

# Load tuning data

In [None]:
# Per-session tuning amplitude and angle, pooled in the same fish order as
# coords_pooled / in_brain_arr_pooled.
amp_pooled = {}
angle_pooled = {}

for session in ['pre', 'post']:
    tuning_arrs = fl.load(master / 'tuning_arrs_{}_{}.h5'.format(treatment, session))
    session_list = list(master.glob("*{}".format(session)))
    # Concatenate following session_list (the order used for the coordinates
    # and in-brain masks), not the loaded dict's key order, which need not
    # match it. Fish stored in the file but absent from session_list are
    # skipped; a missing fish raises KeyError instead of silently misaligning.
    amp_pooled[session] = np.concatenate(
        [tuning_arrs['amp_pooled'][fish.name] for fish in session_list]
    )
    angle_pooled[session] = np.concatenate(
        [tuning_arrs['angle_pooled'][fish.name] for fish in session_list]
    )

In [None]:
# Reliability index per ROI, pooled over the fish of each session.
# NOTE(review): later cells index this with in-brain positions only, which
# suggests the file stores in-brain ROIs only — confirm.
rel_arr_pooled = {}

for session in ['pre', 'post']:
    session_list = list(master.glob("*{}".format(session)))
    per_fish_rel = [
        fl.load(fish / "reliability_index_arr.h5", "/reliability_arr_combined")
        for fish in session_list
    ]
    rel_arr_pooled[session] = np.concatenate(per_fish_rel)

# Load correlation data

In [None]:
n_dirs = 8  # number of direction regressors (columns of sensory_regressors.h5 used below)

In [None]:
# Correlation of every ROI trace with each direction regressor, pooled over the
# fish of each session and cached as {direction: 1D array}. Entries cover all
# ROIs in filtered_traces; later cells mask them with in_brain_arr_pooled.
for session in ['pre', 'post']:
    try:
        corrcoefs_all = fl.load(master / 'reg_corrcoefs_pooled_{}_{}_all.h5'.format(treatment, session))

    except OSError:
        session_list = list(master.glob("*{}".format(session)))

        corrcoefs_all = {direction:[] for direction in range(n_dirs)}

        for path in session_list:
            traces = fl.load(path / "filtered_traces.h5", "/detr")
            sensory_regressors = fl.load(path / "sensory_regressors.h5", "/regressors")

            for direction in range(n_dirs):
                current_dir = np.asarray(sensory_regressors.iloc[:, direction])
                # NOTE(review): assumes corr2_coeff returns one coefficient per
                # row of traces.T (i.e. per ROI) — confirm in its definition.
                corrcoefs_all[direction].append(corr2_coeff(traces.T, current_dir.reshape(1, -1)).ravel())

        for direction in range(n_dirs):
            corrcoefs_all[direction] = np.concatenate(corrcoefs_all[direction])

        fl.save(master / 'reg_corrcoefs_pooled_{}_{}_all.h5'.format(treatment, session), corrcoefs_all)

In [None]:
# For each ROI, keep the correlation coefficient with the largest magnitude
# across the n_dirs direction regressors (sign preserved).
session_corrcoefs = {}

for session in ['pre', 'post']:
    corrcoefs_by_dir = fl.load(master / 'reg_corrcoefs_pooled_{}_{}_all.h5'.format(treatment, session))

    full_mat = np.stack([corrcoefs_by_dir[direction] for direction in range(n_dirs)])  # (n_dirs, n_rois)
    best_dir = np.abs(full_mat).argmax(0)
    # Vectorised pick of full_mat[best_dir[j], j] for every ROI j, replacing
    # the previous Python-level loop over all ROIs.
    session_corrcoefs[session] = np.take_along_axis(full_mat, best_dir[None, :], axis=0)[0]

## Voxelization

In [None]:
from numba import jit
import numba
import napari

In [None]:
# Reference anatomy as mapped for the first fish in fish_list. Its shape defines
# the voxel grid below, and it is used as the background image in napari.
ref_anatomy = fl.load(fish_list[0] / 'registration' / 'to_h2b_baier_ref' / 'antspy' / 'ref_mapped.h5')

In [None]:
# Safety check: pooled in-brain ROI positions (pre = red, post = blue) drawn
# over the reference anatomy.
viewer = napari.Viewer()
viewer.add_image(ref_anatomy)
viewer.add_points(coords_pooled['pre'][in_brain_arr_pooled['pre']], face_color='red')
viewer.add_points(coords_pooled['post'][in_brain_arr_pooled['post']], face_color='blue')

In [None]:
# Compare the reference stack shape with the largest pooled coordinate along
# each axis, to check that the ROIs fall inside the stack.
print(ref_anatomy.shape)
print(coords_pooled['pre'].max(0))
print(coords_pooled['post'].max(0))

In [None]:
# Voxel edge length (in reference-stack index units) and the lower corner of
# every voxel along each of the three reference axes.
vox_size = 5
xvx, yvx, zvx = [np.arange(0, ref_anatomy.shape[i], vox_size) for i in range(3)]

In [None]:
@jit(nopython=True)
def get_voxel_centroids(xvx, yvx, zvx, vox_size):
    """Centre coordinates of every voxel of the (xvx, yvx, zvx) grid.

    xvx, yvx and zvx hold the lower corner of each voxel along one axis. The
    result has shape (len(xvx), len(yvx), len(zvx), 3); entry [i, j, k] is
    (xvx[i], yvx[j], zvx[k]) + vox_size / 2.
    """
    half = vox_size / 2
    n_x, n_y, n_z = xvx.shape[0], yvx.shape[0], zvx.shape[0]
    centroids = np.zeros((n_x, n_y, n_z, 3))

    for i in range(n_x):
        cx = xvx[i] + half
        for j in range(n_y):
            cy = yvx[j] + half
            for k in range(n_z):
                centroids[i, j, k, 0] = cx
                centroids[i, j, k, 1] = cy
                centroids[i, j, k, 2] = zvx[k] + half
    return centroids

In [None]:
def assign_to_voxels(coords, vx_centroids, vox_size):
    """Integer voxel index (ix, iy, iz) of every ROI coordinate.

    Parameters
    ----------
    coords : numpy.ndarray
        (n_rois, 3) coordinates in the same space as the voxel grid.
    vx_centroids : numpy.ndarray
        (nx, ny, nz, 3) voxel centres as returned by get_voxel_centroids,
        i.e. a regular grid with spacing ``vox_size``.
    vox_size : float
        Voxel edge length.

    Returns
    -------
    numpy.ndarray
        (n_rois, 3) int32 voxel indices. ROIs outside the grid, or with
        non-finite coordinates, get np.iinfo(np.int32).min in every column, so
        that using them as an index fails loudly instead of wrapping around.

    Notes
    -----
    Replaces a per-ROI brute-force search over all voxel centres
    (O(n_rois * n_voxels)). That version also left unmatched ROIs with
    undefined values (NaN cast to int32) and missed ROIs lying exactly on a
    voxel face because of a strict ``<`` test. Here each voxel is half-open,
    ``[lower, lower + vox_size)``, along every axis.
    """
    unassigned = np.iinfo(np.int32).min
    coords = np.asarray(coords, dtype=np.float64)
    coord_vox = np.full(coords.shape, unassigned, dtype=np.int32)
    if coords.shape[0] == 0 or vx_centroids.size == 0:
        return coord_vox

    grid_shape = np.asarray(vx_centroids.shape[:3])
    origin = vx_centroids[0, 0, 0] - vox_size / 2  # lower corner of the grid
    idx = np.floor((coords - origin) / vox_size)
    # NaN/inf compare False, so non-finite coordinates stay unassigned.
    inside = np.all((idx >= 0) & (idx < grid_shape), axis=1)
    coord_vox[inside] = idx[inside].astype(np.int32)
    return coord_vox

In [None]:
# Voxel index of every pooled ROI, cached per treatment and voxel size.
# NOTE(review): if this cache was written while assign_to_voxels left
# unmatched ROIs with undefined (e.g. negative) indices, delete it to recompute.
try:
    voxeled_rois = fl.load(master / 'voxeled_rois_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    # The grid is identical for both sessions, so build it only once.
    vx_centroids = get_voxel_centroids(xvx, yvx, zvx, vox_size)
    voxeled_rois = {}

    for session in ['pre', 'post']:
        voxeled_rois[session] = assign_to_voxels(coords_pooled[session], vx_centroids, vox_size)

    fl.save(master / 'voxeled_rois_{}_{}voxsize.h5'.format(treatment, vox_size), voxeled_rois)

In [None]:
# Occupancy: number of in-brain ROIs in each voxel (NaN = empty voxel).
occ_map = {}
grid_shape = (xvx.shape[0], yvx.shape[0], zvx.shape[0])

for session in ['pre', 'post']:
    occ_map[session] = np.full(grid_shape, np.nan)
    coords = voxeled_rois[session][in_brain_arr_pooled[session]]

    # Skip ROIs whose voxel index lies outside the grid (unassigned ROIs);
    # they would otherwise raise IndexError or wrap around.
    on_grid = np.all((coords >= 0) & (coords < np.asarray(grid_shape)), axis=1)
    if not on_grid.all():
        print('{}: {} in-brain ROIs outside the voxel grid ignored'.format(session, (~on_grid).sum()))

    unique_coords, counts = np.unique(coords[on_grid], axis=0, return_counts=True)
    occ_map[session][tuple(unique_coords.T)] = counts

In [None]:
# Overlay of pre (red) and post (blue) occupancy maps.
viewer = napari.Viewer()
viewer.add_image(occ_map['pre'], colormap='Reds')
viewer.add_image(occ_map['post'], colormap='Blues')

# Voxel-wise maps: reliability, amplitude and correlation (pre vs. post)

In [None]:
def make_map_from_values(xvx, yvx, zvx, coords, metric_arr, ignore_nan=False):
    """Average a per-ROI metric within each voxel of the (xvx, yvx, zvx) grid.

    Parameters
    ----------
    xvx, yvx, zvx : numpy.ndarray
        Voxel grid axes; only their lengths are used.
    coords : numpy.ndarray
        (n_rois, 3) integer voxel indices, e.g. from assign_to_voxels.
    metric_arr : numpy.ndarray
        (n_rois,) metric aligned with ``coords``.
    ignore_nan : bool
        If True, NaNs are ignored (like np.nanmean). Otherwise a NaN makes
        its voxel NaN (like np.mean).

    Returns
    -------
    numpy.ndarray
        (len(xvx), len(yvx), len(zvx)) float map, NaN where a voxel has no ROI.
        ROIs whose indices fall outside the grid are ignored.
    """
    shape = (len(xvx), len(yvx), len(zvx))
    map_arr = np.full(shape, np.nan)
    coords = np.asarray(coords)
    values = np.asarray(metric_arr, dtype=float).ravel()

    on_grid = np.all((coords >= 0) & (coords < np.asarray(shape)), axis=1)
    coords, values = coords[on_grid], values[on_grid]
    if ignore_nan:
        keep = ~np.isnan(values)
        coords, values = coords[keep], values[keep]
    if coords.shape[0] == 0:
        return map_arr

    # One pass over all ROIs instead of one full scan per voxel.
    unique_coords, inverse = np.unique(coords, axis=0, return_inverse=True)
    inverse = inverse.ravel()  # some numpy versions return a non-1D inverse
    sums = np.bincount(inverse, weights=values, minlength=len(unique_coords))
    counts = np.bincount(inverse, minlength=len(unique_coords))
    map_arr[tuple(unique_coords.T)] = sums / counts
    return map_arr

In [None]:
# Reliability map: mean reliability index of the in-brain ROIs in each voxel.
try:
    rel_map = fl.load(master / 'rel_map_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    rel_map = {}
    grid_shape = (xvx.shape[0], yvx.shape[0], zvx.shape[0])

    for session in ['pre', 'post']:
        rel_map[session] = np.full(grid_shape, np.nan)
        coords = voxeled_rois[session][in_brain_arr_pooled[session]]
        # rel_arr_pooled is used with in-brain positions directly (no
        # in_brain mask, unlike amp/corr); presumably it already holds
        # in-brain ROIs only. A length mismatch now raises instead of
        # silently misaligning.
        values = np.asarray(rel_arr_pooled[session], dtype=float)

        # Ignore ROIs whose voxel index lies outside the grid (unassigned).
        on_grid = np.all((coords >= 0) & (coords < np.asarray(grid_shape)), axis=1)
        coords_ok, values_ok = coords[on_grid], values[on_grid]

        # Per-voxel mean in one pass (previously every ROI was scanned once per voxel).
        unique_coords, inverse = np.unique(coords_ok, axis=0, return_inverse=True)
        inverse = inverse.ravel()
        sums = np.bincount(inverse, weights=values_ok, minlength=len(unique_coords))
        counts = np.bincount(inverse, minlength=len(unique_coords))
        rel_map[session][tuple(unique_coords.T)] = sums / counts

    fl.save(master / 'rel_map_{}_{}voxsize.h5'.format(treatment, vox_size), rel_map)

In [None]:
# Amplitude map: mean tuning amplitude of the in-brain ROIs in each voxel.
try:
    amp_map = fl.load(master / 'amp_map_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    amp_map = {}
    grid_shape = (xvx.shape[0], yvx.shape[0], zvx.shape[0])

    for session in ['pre', 'post']:
        amp_map[session] = np.full(grid_shape, np.nan)
        coords = voxeled_rois[session][in_brain_arr_pooled[session]]
        values = np.asarray(amp_pooled[session], dtype=float)[in_brain_arr_pooled[session]]

        # Ignore ROIs whose voxel index lies outside the grid (unassigned).
        on_grid = np.all((coords >= 0) & (coords < np.asarray(grid_shape)), axis=1)
        coords_ok, values_ok = coords[on_grid], values[on_grid]

        # Per-voxel mean in one pass; a NaN amplitude makes its voxel NaN, as np.mean did.
        unique_coords, inverse = np.unique(coords_ok, axis=0, return_inverse=True)
        inverse = inverse.ravel()
        sums = np.bincount(inverse, weights=values_ok, minlength=len(unique_coords))
        counts = np.bincount(inverse, minlength=len(unique_coords))
        amp_map[session][tuple(unique_coords.T)] = sums / counts

    fl.save(master / 'amp_map_{}_{}voxsize.h5'.format(treatment, vox_size), amp_map)

In [None]:
# Correlation map: NaN-ignoring mean of the best regressor correlation of the
# in-brain ROIs in each voxel.
try:
    corr_map = fl.load(master / 'corr_map_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    corr_map = {}
    grid_shape = (xvx.shape[0], yvx.shape[0], zvx.shape[0])

    for session in ['pre', 'post']:
        corr_map[session] = np.full(grid_shape, np.nan)
        coords = voxeled_rois[session][in_brain_arr_pooled[session]]
        values = np.asarray(session_corrcoefs[session], dtype=float)[in_brain_arr_pooled[session]]

        # Ignore ROIs whose voxel index lies outside the grid (unassigned).
        on_grid = np.all((coords >= 0) & (coords < np.asarray(grid_shape)), axis=1)
        coords_ok, values_ok = coords[on_grid], values[on_grid]

        # NaN-ignoring per-voxel mean in one pass (same result as np.nanmean).
        unique_coords, inverse = np.unique(coords_ok, axis=0, return_inverse=True)
        inverse = inverse.ravel()
        notnan = ~np.isnan(values_ok)
        sums = np.bincount(inverse[notnan], weights=values_ok[notnan], minlength=len(unique_coords))
        counts = np.bincount(inverse[notnan], minlength=len(unique_coords))
        with np.errstate(invalid='ignore', divide='ignore'):
            # Voxels whose ROIs are all NaN give 0/0 = NaN, as nanmean does.
            corr_map[session][tuple(unique_coords.T)] = sums / counts

    fl.save(master / 'corr_map_{}_{}voxsize.h5'.format(treatment, vox_size), corr_map)

In [None]:
def slice_stack(z_size, n_planes):
    """Split ``z_size`` planes into ``n_planes`` contiguous slabs.

    Returns the ``n_planes + 1`` integer slab edges from 0 to ``z_size``, so
    slab ``i`` is ``[z_levels[i], z_levels[i + 1])``. Slab sizes differ by at
    most one plane, and every plane belongs to exactly one slab.

    The previous version stepped by ``z_size // n_planes``. When ``z_size``
    was not a multiple of ``n_planes`` it returned extra edges, so callers
    iterating over ``n_planes`` slabs never showed the last planes. It also
    failed when ``z_size < n_planes`` (zero step).

    Raises
    ------
    ValueError
        If ``n_planes`` is smaller than 1.
    """
    if n_planes < 1:
        raise ValueError("n_planes must be >= 1")
    return (np.arange(n_planes + 1) * z_size) // n_planes

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl

In [None]:
n_planes = 3

rel_cmap = 'Reds'
rel_vlims = [0,1]
diff_cmap = 'RdBu_r'
diff_vlims = [-.5, .5]

# Slab edges along z; each panel shows the NaN-ignoring mean over one slab.
z_levels = slice_stack(rel_map['pre'].shape[2], n_planes)

# Rows: z slabs (highest slab on top); columns: pre, post, post - pre.
fig, axes = plt.subplots(n_planes, 3, figsize=(6, 8))

for i_session, session in enumerate(['pre', 'post']):
    session_map = rel_map[session]

    for plane in range(n_planes):
        map_slice = session_map[:, :, z_levels[plane]:z_levels[plane+1]]
        axes[n_planes-1-plane, i_session].imshow(np.nanmean(map_slice, 2).T, cmap=rel_cmap, vmin=rel_vlims[0], vmax=rel_vlims[1])

    axes[0, i_session].set_title(session)

diff_map = rel_map['post']-rel_map['pre']
for plane in range(n_planes):
    map_slice = diff_map[:, :, z_levels[plane]:z_levels[plane+1]]
    axes[n_planes-1-plane, 2].imshow(np.nanmean(map_slice, 2).T, cmap=diff_cmap, vmin=diff_vlims[0], vmax=diff_vlims[1])
axes[0, 2].set_title('post - pre')

for ax in axes.ravel():
    ax.axis('off')

# Colour bars beside the bottom-left (session maps) and bottom-right
# (difference) panels. The unused `cmap = mpl.cm.cool` and the `fraction`
# argument, which has no effect when `cax` is given, were dropped.
for ax_cb, cb_cmap, cb_vlims in [(axes[-1, 0], rel_cmap, rel_vlims), (axes[-1, 2], diff_cmap, diff_vlims)]:
    cax = make_axes_locatable(ax_cb).append_axes('right', size='5%', pad=0.05)
    norm = mpl.colors.Normalize(vmin=cb_vlims[0], vmax=cb_vlims[1])
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cb_cmap), cax=cax)

fig.suptitle('Reliability ({})'.format(treatment))
fig.tight_layout()

In [None]:
# Save the reliability figure (skipped when fig_path is None).
if fig_path is not None:
    fig.savefig(fig_path / 'reliability_voxelwise_{}_{}voxsize.pdf'.format(treatment, vox_size), dpi=350)

In [None]:
n_planes = 3

# Colour scales for the session maps (amp_*) and the post - pre difference.
# Renamed from rel_*; they describe amplitude, not reliability.
amp_cmap = 'Reds'
amp_vlims = [0, 1300]
diff_cmap = 'RdBu_r'
diff_vlims = [-1000, 1000]

# Slab edges along z, taken from the map that is actually plotted.
z_levels = slice_stack(amp_map['pre'].shape[2], n_planes)

# Rows: z slabs (highest slab on top); columns: pre, post, post - pre.
fig, axes = plt.subplots(n_planes, 3, figsize=(6, 8))

for i_session, session in enumerate(['pre', 'post']):
    session_map = amp_map[session]

    for plane in range(n_planes):
        map_slice = session_map[:, :, z_levels[plane]:z_levels[plane+1]]
        axes[n_planes-1-plane, i_session].imshow(np.nanmean(map_slice, 2).T, cmap=amp_cmap, vmin=amp_vlims[0], vmax=amp_vlims[1])

    axes[0, i_session].set_title(session)

diff_map = amp_map['post']-amp_map['pre']
for plane in range(n_planes):
    map_slice = diff_map[:, :, z_levels[plane]:z_levels[plane+1]]
    axes[n_planes-1-plane, 2].imshow(np.nanmean(map_slice, 2).T, cmap=diff_cmap, vmin=diff_vlims[0], vmax=diff_vlims[1])
axes[0, 2].set_title('post - pre')

for ax in axes.ravel():
    ax.axis('off')

# Colour bars beside the bottom-left (session maps) and bottom-right
# (difference) panels.
for ax_cb, cb_cmap, cb_vlims in [(axes[-1, 0], amp_cmap, amp_vlims), (axes[-1, 2], diff_cmap, diff_vlims)]:
    cax = make_axes_locatable(ax_cb).append_axes('right', size='5%', pad=0.05)
    norm = mpl.colors.Normalize(vmin=cb_vlims[0], vmax=cb_vlims[1])
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cb_cmap), cax=cax)

fig.suptitle('Amplitude ({})'.format(treatment))
fig.tight_layout()

In [None]:
# Save the amplitude figure (skipped when fig_path is None).
if fig_path is not None:
    fig.savefig(fig_path / 'amplitude_voxelwise_{}_{}voxsize.pdf'.format(treatment, vox_size), dpi=350)

In [None]:
n_planes = 3

corr_cmap = 'RdBu_r'
corr_vlims = [-.5, .5]
diff_cmap = 'RdBu_r'
diff_vlims = [-.5, .5]

# Slab edges along z; each panel shows the NaN-ignoring mean over one slab.
z_levels = slice_stack(corr_map['pre'].shape[2], n_planes)

# Rows: z slabs (highest slab on top); columns: pre, post, post - pre.
fig, axes = plt.subplots(n_planes, 3, figsize=(6, 8))

for i_session, session in enumerate(['pre', 'post']):
    session_map = corr_map[session]

    for plane in range(n_planes):
        map_slice = session_map[:, :, z_levels[plane]:z_levels[plane+1]]
        axes[n_planes-1-plane, i_session].imshow(np.nanmean(map_slice, 2).T, cmap=corr_cmap, vmin=corr_vlims[0], vmax=corr_vlims[1])

    axes[0, i_session].set_title(session)

diff_map = corr_map['post']-corr_map['pre']
for plane in range(n_planes):
    map_slice = diff_map[:, :, z_levels[plane]:z_levels[plane+1]]
    axes[n_planes-1-plane, 2].imshow(np.nanmean(map_slice, 2).T, cmap=diff_cmap, vmin=diff_vlims[0], vmax=diff_vlims[1])
axes[0, 2].set_title('post - pre')

for ax in axes.ravel():
    ax.axis('off')

# Colour bars beside the bottom-left (session maps) and bottom-right
# (difference) panels.
for ax_cb, cb_cmap, cb_vlims in [(axes[-1, 0], corr_cmap, corr_vlims), (axes[-1, 2], diff_cmap, diff_vlims)]:
    cax = make_axes_locatable(ax_cb).append_axes('right', size='5%', pad=0.05)
    norm = mpl.colors.Normalize(vmin=cb_vlims[0], vmax=cb_vlims[1])
    fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cb_cmap), cax=cax)

# The title was previously commented out (and mislabelled 'Amplitude').
fig.suptitle('Correlation ({})'.format(treatment))
fig.tight_layout()

In [None]:
# Save the correlation figure (skipped when fig_path is None).
if fig_path is not None:
    fig.savefig(fig_path / 'corrcoef_voxelwise_{}_{}voxsize.pdf'.format(treatment, vox_size), dpi=350)

In [None]:
# Scratch: voxel grid shape and slab edges.
amp_map['pre'].shape

In [None]:
# Presumably a voxel count times vox_size (5), i.e. the extent in stack index units.
72*5

In [None]:
z_levels

In [None]:
# Presumably the slab thickness (in voxels) times vox_size.
24*5

In [None]:
# Reliability averaged over the whole z extent: pre, post and their difference.
fig, axes = plt.subplots(1,3, figsize=(10,5))

panels = [
    ('pre', rel_map['pre']),
    ('post', rel_map['post']),
    # Same sign convention as the slab-wise figures (post - pre); this panel
    # used to show pre - post under the title 'diff'.
    ('post - pre', rel_map['post'] - rel_map['pre']),
]
for ax, (title, volume) in zip(axes, panels):
    ax.imshow(np.nanmean(volume, 2).T, cmap='RdBu_r', vmin=-1, vmax=1)
    ax.set_title(title)
    ax.axis('off')

# Title follows the configured treatment instead of a hard-coded 'Ntr'.
fig.suptitle(treatment.capitalize(), fontsize=15)

In [None]:
# NOTE(review): scratch — relies on `coords` and `coord` left over from the
# last iteration of an earlier loop (hidden kernel state).
unique_coords, counts = np.unique(coords, axis=0, return_counts=True)


In [None]:
unique_coords, counts

In [None]:
# Number of ROIs whose voxel index equals `coord`.
((coords == coord).all(axis=1)).sum()

In [None]:
coord

In [None]:
# NOTE(review): duplicate of the occupancy-map cell, plus a print of the number
# of occupied voxels (rows of unique_coords) per session.
occ_map = {}

for session in ['pre', 'post']:
    occ_map[session] = np.full((xvx.shape[0], yvx.shape[0], zvx.shape[0]), np.nan)
    coords = voxeled_rois[session][in_brain_arr_pooled[session]]

    unique_coords, counts = np.unique(coords, axis=0, return_counts=True)
    print(unique_coords.shape)

    for coord, count in zip(unique_coords, counts):
        occ_map[session][tuple(coord)] = count

In [None]:
master  # was `mastera`, an undefined name (NameError)

In [None]:
# NOTE(review): in-place and irreversible. Every negative voxel index
# (presumably ROIs that assign_to_voxels failed to place) is set to 0, which
# piles those ROIs into index 0 of the maps instead of excluding them.
voxeled_rois['pre'][np.nonzero(voxeled_rois['pre']<0)] = 0

In [None]:
# Largest voxel index per axis, to compare with the grid size in the next cell.
voxeled_rois['pre'].max(0)

In [None]:
xvx.shape[0], yvx.shape[0], zvx.shape[0]

In [None]:
# Occupancy map per session (NaN = empty voxel), saved next to the other maps.
occ_map = {}

for session in ['pre', 'post']:
    occ_map[session] = np.full((xvx.shape[0], yvx.shape[0], zvx.shape[0]), np.nan)
    coords = voxeled_rois[session][in_brain_arr_pooled[session]]

    unique_coords, counts = np.unique(coords, axis=0, return_counts=True)
    for coord, count in zip(unique_coords, counts):
        occ_map[session][tuple(coord)] = count

# File name follows treatment / vox_size like every other cached map
# (previously hard-coded to 'ntr' and 5).
fl.save(master / 'occupancy_map_{}_{}voxsize.h5'.format(treatment, vox_size), occ_map)

In [None]:
viewer = napari.Viewer()
viewer.add_image(occ_map['pre'])

In [None]:
unique_coords

In [None]:
# Axes 0 vs 1 of the pooled in-brain 'pre' ROI coordinates.
plt.figure()
plt.scatter(coords_pooled['pre'][in_brain_arr_pooled['pre']][:, 0], coords_pooled['pre'][in_brain_arr_pooled['pre']][:, 1])

In [None]:
import napari

In [None]:
# First 100000 in-brain 'pre' ROIs drawn over the reference anatomy.
viewer = napari.Viewer()
viewer.add_image(ref_anatomy)
a = coords_pooled['pre'][in_brain_arr_pooled['pre']][:100000, :]
viewer.add_points(a)

In [None]:
# Voxel indices of the last 1000 in-brain 'pre' ROIs.
a = voxeled_rois['pre'][in_brain_arr_pooled['pre']][-1000:, :]

In [None]:
# NOTE(review): `a` holds voxel indices, but the background is the
# full-resolution anatomy, so the points land at 1/vox_size of their position.
# Also, imshow puts axis 0 on rows while scatter puts roi[0] on x.
plt.figure()
plt.imshow(np.nanmean(ref_anatomy, 2), cmap='gray')

for roi in a:
    plt.scatter(roi[0], roi[1])

In [None]:
voxeled_rois['pre'][:1000, :]

In [None]:
# Occupancy map on the voxel grid (NaN = empty voxel), saved per treatment and voxel size.
occ_map = {}
grid_shape = (xvx.shape[0], yvx.shape[0], zvx.shape[0])

for session in ['pre', 'post']:
    # Voxel-grid shape, float dtype so empty voxels can be NaN. The previous
    # np.full_like(ref_anatomy, np.nan) used the full-resolution stack shape
    # and its (possibly integer) dtype.
    occ_map[session] = np.full(grid_shape, np.nan)
    coords = voxeled_rois[session][in_brain_arr_pooled[session]]

    unique_coords, counts = np.unique(coords, axis=0, return_counts=True)
    for coord, count in zip(unique_coords, counts):
        # tuple(): indexing with the bare array selected three whole slabs
        # along axis 0 instead of the single voxel.
        occ_map[session][tuple(coord)] = count

fl.save(master / 'occupancy_map_{}_{}voxsize.h5'.format(treatment, vox_size), occ_map)

In [None]:
coords = voxeled_rois['pre'][in_brain_arr_pooled['pre']]

In [None]:
unique_coords, counts = np.unique(coords, axis=0, return_counts=True)

In [None]:
np.unique(voxeled_rois['pre'], axis=0)

In [None]:
a = np.unique(voxeled_rois['pre'], axis=0)

In [None]:
# NOTE(review): points are voxel indices drawn over the full-resolution
# anatomy, and include out-of-brain ROIs (no in_brain mask is applied).
plt.figure()
plt.imshow(np.nanmean(ref_anatomy, 2), cmap='gray')

for i in np.unique(voxeled_rois['pre'], axis=0):
    plt.scatter(i[0], i[1])

In [None]:
import napari
viewer = napari.Viewer()
viewer.add_image(occ_map['pre'])

In [None]:
occ_map['pre'].shape

In [None]:
plt.figure()
plt.imshow(ref_anatomy[:, :, 100])
# plt.imshow(occ_map['pre'][:, :, 100])

In [None]:
# Print every occupied voxel and its ROI count. `*` unpacks the
# (unique_coords, counts) tuple; without it zip yields 1-tuples and the
# two-name unpacking fails.
for i, z in zip(*np.unique(coords, axis=0, return_counts=True)):
    print(i, z)

In [None]:
# NOTE(review): unfinished — `metric` and `metric_map` are not used by any later cell.
metric = amp_pooled

metric_map = {}


In [None]:
# NOTE(review): `session` is whatever value the last executed loop left behind.
coords = voxeled_rois[session][in_brain_arr_pooled[session]]
coords.shape

In [None]:
# Was `amp_pooled[]` (SyntaxError). Shows the in-brain amplitude count to
# compare with coords.shape from the previous cell.
amp_pooled[session][in_brain_arr_pooled[session]].shape

In [None]:
np.unique(coords, axis=0, return_counts=True)

In [None]:
coords

In [None]:
# NOTE(review): `maps` is not defined anywhere in this notebook, so this cell
# raises NameError.
maps['pre']

In [None]:
# One-cell re-run of the voxelization section (grid, helpers, cached assignment).
# NOTE(review): this redefines get_voxel_centroids / assign_to_voxels from the
# earlier cells; keep both copies in sync, or delete one.
from numba import jit
import numba

#Import reference anatomy [from first fish]
ref_anatomy = fl.load(fish_list[0] / 'registration' / 'to_h2b_baier_ref' / 'antspy' / 'ref_mapped.h5')

#Define voxel size and define shape of each new anatomical axis
vox_size = 5
xvx, yvx, zvx = [np.arange(0, ref_anatomy.shape[i], vox_size) for i in range(3)]

@jit(nopython=True)
def get_voxel_centroids(xvx, yvx, zvx, vox_size):
    """Centres of the voxels whose lower corners are given by xvx, yvx, zvx.

    Returns an array of shape (len(xvx), len(yvx), len(zvx), 3).
    """
    vx_centroids = np.zeros((xvx.shape[0], yvx.shape[0], zvx.shape[0], 3))

    for ix, x in enumerate(xvx):
        for iy, y in enumerate(yvx):
            for iz, z in enumerate(zvx):
                vx_centroids[ix, iy, iz, :] = np.array((x+(vox_size/2),  y+(vox_size/2),  z+(vox_size/2)))
    return(vx_centroids)

def assign_to_voxels(coords, vx_centroids, vox_size):
    """Integer voxel index (ix, iy, iz) of every (n_rois, 3) ROI coordinate.

    Uses the regular grid implied by vx_centroids (spacing vox_size), with
    half-open voxels ``[lower, lower + vox_size)``. ROIs outside the grid, or
    with non-finite coordinates, get np.iinfo(np.int32).min in every column.
    This replaces an O(n_rois * n_voxels) search that left unmatched ROIs
    with undefined values and missed ROIs lying exactly on voxel faces.
    """
    unassigned = np.iinfo(np.int32).min
    coords = np.asarray(coords, dtype=np.float64)
    coord_vox = np.full(coords.shape, unassigned, dtype=np.int32)
    if coords.shape[0] == 0 or vx_centroids.size == 0:
        return coord_vox

    grid_shape = np.asarray(vx_centroids.shape[:3])
    origin = vx_centroids[0, 0, 0] - vox_size / 2  # lower corner of the grid
    idx = np.floor((coords - origin) / vox_size)
    inside = np.all((idx >= 0) & (idx < grid_shape), axis=1)
    coord_vox[inside] = idx[inside].astype(np.int32)
    return coord_vox

# NOTE(review): delete this cache if it was written while unmatched ROIs
# received undefined indices, so it is recomputed.
try:
    voxeled_rois = fl.load(master / 'voxeled_rois_{}_{}voxsize.h5'.format(treatment, vox_size))

except OSError:
    # Same grid for both sessions: build it once.
    vx_centroids = get_voxel_centroids(xvx, yvx, zvx, vox_size)
    voxeled_rois = {}

    for session in ['pre', 'post']:
        voxeled_rois[session] = assign_to_voxels(coords_pooled[session], vx_centroids, vox_size)

    fl.save(master / 'voxeled_rois_{}_{}voxsize.h5'.format(treatment, vox_size), voxeled_rois)

# Sensory regressors

In [None]:
# Display name of each direction regressor, indexed by regressor number.
titles = [
    'right', 'backward right', 'backward', 'backward left',
    'left', 'forward left', 'forward', 'forward right',
]
# Index of the direction regressor plotted in this section.
plot_dir = 0

In [None]:
# Correlation of every ROI with the `plot_dir` regressor, pooled over fish_list.
# NOTE(review): `session` is not set in this cell (it is left over from an
# earlier loop), and fish_list is globbed by treatment only, so the session
# label in the file name may not describe the fish actually pooled.
try:
    reg_corrcoefs_pooled = fl.load(master / 'reg_corrcoefs_pooled_{}_{}_dir{}.h5'.format(treatment, session, plot_dir))

except OSError:
    reg_corrcoefs = []

    for path in fish_list:
        traces = fl.load(path / "filtered_traces.h5", "/detr")
        sensory_regressors = fl.load(path / "sensory_regressors.h5", "/regressors")

        current_dir = np.asarray(sensory_regressors.iloc[:, plot_dir])
        reg_corrcoefs.append(corr2_coeff(traces.T, current_dir.reshape(1, -1)).ravel())

    reg_corrcoefs_pooled = np.concatenate(reg_corrcoefs)
    fl.save(master / 'reg_corrcoefs_pooled_{}_{}_dir{}.h5'.format(treatment, session, plot_dir), reg_corrcoefs_pooled)

In [None]:
# NOTE(review): commented-out copy of the computation in the previous cell; safe to delete.
# titles = ['right', 'backward right', 'backward', 'backward left', 'left', 'forward left', 'forward', 'forward right', ]
# plot_dir = 0

# reg_corrcoefs = []

# for path in fish_list:
#     traces = fl.load(path / "filtered_traces.h5", "/detr")
#     sensory_regressors = fl.load(path / "sensory_regressors.h5", "/regressors")

#     current_dir = np.asarray(sensory_regressors.iloc[:, plot_dir])
#     reg_corrcoefs.append(corr2_coeff(traces.T, current_dir.reshape(1, -1)).ravel())

In [None]:
# NOTE(review): `in_brain_arr` exists only if the coordinate-pooling cell ran
# its except branch, and then holds the 'post' fish only; other fish raise
# KeyError. This also replaces the per-session dict in_brain_arr_pooled with an array.
in_brain_arr_pooled = np.concatenate([in_brain_arr[fish.name] for fish in fish_list])

In [None]:
# Keep in-brain ROIs and order them by |correlation| (ascending), so the most
# strongly correlated ROIs are drawn last, i.e. on top.
# NOTE(review): coords_pooled is built as a per-session dict earlier in this
# notebook, so indexing it with a boolean array fails; this section predates
# the pre/post split.
coords_ib_pooled = coords_pooled[in_brain_arr_pooled]
reg_corrcoefs_ib_pooled = reg_corrcoefs_pooled[in_brain_arr_pooled]
mp_ind_regressor = np.argsort(np.abs(reg_corrcoefs_ib_pooled))

In [None]:
# Marker size and scale-bar geometry (in data coordinates of the projections).
dot_s = 2

scale_bar_len = 100
scale_bar_xpos = 100
scale_bar_ypos1 = 750
fs = 8  # font size of the scale-bar label

In [None]:
# Three orthogonal projections of the in-brain ROIs, coloured by their
# correlation with the `plot_dir` regressor: axes 0/1 (bottom-left),
# axes 2/1 (bottom-right) and axes 0/2 (top-left); the top-right panel is empty.
fig, axs = plt.subplots(2, 2, figsize=(7,8), gridspec_kw={'width_ratios': [6, 2], 'height_ratios': [1, 3]}, sharex='col', sharey='row')

fig.subplots_adjust(left=0.05, wspace=0.05, hspace=0.05)

#Regressors
axs[1,0].scatter(coords_ib_pooled[mp_ind_regressor,0], coords_ib_pooled[mp_ind_regressor,1], c=reg_corrcoefs_ib_pooled[mp_ind_regressor], s=dot_s, alpha=0.8, cmap='coolwarm', vmin=-1, vmax=1)
axs[1,1].scatter(coords_ib_pooled[mp_ind_regressor,2], coords_ib_pooled[mp_ind_regressor,1], c=reg_corrcoefs_ib_pooled[mp_ind_regressor], s=dot_s, alpha=0.8, cmap='coolwarm', vmin=-1, vmax=1)
axs[0,0].scatter(coords_ib_pooled[mp_ind_regressor,0], coords_ib_pooled[mp_ind_regressor,2], c=reg_corrcoefs_ib_pooled[mp_ind_regressor], s=dot_s, alpha=0.8, cmap='coolwarm', vmin=-1, vmax=1)
axs[0,0].set_title(titles[plot_dir])

#Scale bars
axs[1,0].plot((scale_bar_xpos, scale_bar_xpos+scale_bar_len), (scale_bar_ypos1, scale_bar_ypos1), c='black')
axs[1,0].text(scale_bar_xpos+(scale_bar_len/2), scale_bar_ypos1+10, r'{}$\mu$m'.format(scale_bar_len), va='top', ha='center', fontsize=fs)

for ax in axs.flatten():
    ax.axis('off')

axs[0, 1].axis('off')  # redundant: the loop above already turned every axis off
axs[1, 0].invert_yaxis()
# axs[1, 1].invert_yaxis()

In [None]:
# Save the regressor-correlation projections (skipped when fig_path is None).
if fig_path is not None:
    fig.savefig(fig_path / 'pooled_regressor_corrcoefs_{}_{}.png'.format(treatment, session), dpi=300)

# Morphed datasets

In [None]:
# Tuning amplitude and angle per fish, cached per treatment and session.
# NOTE(review): `session` is not set in this cell (it keeps whatever value the
# last executed loop left), while fish_list is globbed by treatment only and
# may contain fish of both sessions. Set both explicitly before relying on the cache.
try:
    tuning_arrs = fl.load(master / 'tuning_arrs_{}_{}.h5'.format(treatment, session))
    amp_pooled, angle_pooled = tuning_arrs['amp_pooled'], tuning_arrs['angle_pooled']

except OSError:
    amp_pooled = {}
    angle_pooled = {}

    for fish in fish_list:
        print(fish)

        traces = fl.load(fish / "filtered_traces.h5", "/detr")

        # Imaging frame period. Named frame_rate rather than `fs`, which the
        # plotting cells use as the font size.
        img_exp = LightsheetExperiment(fish)
        frame_rate = img_exp.fn
        sampling = 1/frame_rate

        # Sensory regressors sampled at the imaging rate, then per-ROI tuning.
        reg = make_sensory_regressors(Experiment(fish), sampling=sampling)
        amp_pooled[fish.name], angle_pooled[fish.name] = get_tuning_map(traces, reg)

    tuning_arrs = {'angle_pooled':angle_pooled, 'amp_pooled':amp_pooled}
    fl.save(master / 'tuning_arrs_{}_{}.h5'.format(treatment, session), tuning_arrs)

    print('Done')

In [None]:
# Index (position in fish_list) of the source fish for every pooled ROI.
try:
    fish_source = fl.load(master / 'fish_source_{}_{}.h5'.format(treatment, session))

except OSError:
    fish_source_list = [
        np.full((fl.load(fish / "filtered_traces.h5", "/detr").shape[1]), i)
        for i, fish in enumerate(fish_list)
    ]
    fish_source = np.concatenate(fish_source_list)
    fl.save(master / 'fish_source_{}_{}.h5'.format(treatment, session), fish_source)


In [None]:
#Pool amp and angle arrays (in fish_list order)
amp_pooled_arr = np.concatenate([amp_pooled[fish.name] for fish in fish_list])
angle_pooled_arr = np.concatenate([angle_pooled[fish.name] for fish in fish_list])

#also boolean array to keep track of ROIs in the brain
# NOTE(review): `in_brain_arr` is only defined if the coordinate-pooling cell
# ran its except branch, and then holds the last session only.
in_brain_arr_pooled = np.concatenate([in_brain_arr[fish.name] for fish in fish_list])

# #and pooled coordinates
# coords_pooled = np.concatenate([morphed_coords[fish.name] for fish in fish_list], 0)

In [None]:
#Make filter to exclude NaNs from amp array (we assume NaNs co-occur in amp and angle array, it seems to be the case)
# NOTE: the printed count includes NaN ROIs both inside and outside the brain.
nan_filt = np.isnan(amp_pooled_arr)
print(nan_filt.sum(), ' ROIs excluded')

#Combine into a filtering array: in-brain ROIs with a non-NaN amplitude
valid_rois = np.logical_and(in_brain_arr_pooled, ~nan_filt)

In [None]:
#Color stack: one RGB colour per valid ROI (same order as valid_rois)
colors_ib = color_stack(amp_pooled_arr[valid_rois], angle_pooled_arr[valid_rois])
# colors = color_stack(amp_pooled_arr, angle_pooled_arr)
colors_ib.shape

In [None]:
#Filter ROIs and normalize vector amplitude
# NOTE(review): coords_pooled is built as a per-session dict earlier in this
# notebook, so boolean indexing fails here; amp_norm is not used by any later cell.
coords_ib_pooled = coords_pooled[valid_rois]
amp_pooled_ib = amp_pooled_arr[valid_rois]
amp_norm = amp_pooled_ib / np.nanmax(amp_pooled_ib)

In [None]:
#Load reliability arrays and filter ROIs
perct = 95

# NOTE(review): overwrites the per-session rel_arr_pooled dict with a flat array.
rel_arr_pooled = np.concatenate([fl.load(fish / "reliability_index_arr.h5", "/reliability_arr_combined") for fish in fish_list])
# nanpercentile: a NaN reliability value would otherwise make the threshold NaN
# and silently select no ROI at all.
rel_thresh = np.nanpercentile(rel_arr_pooled, perct)
# rel_arr_pooled is masked with a mask over in-brain ROIs, so it is assumed to
# hold in-brain ROIs only. The mask keeps those with a non-NaN amplitude, so
# selected_vis indexes the valid_rois order.
selected_vis = np.where(rel_arr_pooled[~nan_filt[in_brain_arr_pooled]] > rel_thresh)[0]

In [None]:
# Per fish: all ROIs (white bars) and ROIs passing the reliability threshold (black).
sel_roi_count = fish_source[valid_rois][selected_vis]
# bincount gives one entry per fish even when a fish has no selected ROI;
# two separate np.unique calls could return arrays of different lengths and
# misalign (or fail to draw) the bars.
total_count = np.bincount(fish_source, minlength=len(fish_list))
roi_count = np.bincount(sel_roi_count, minlength=len(total_count))
fish_i = np.arange(len(total_count))

fig, ax = plt.subplots()
ax.bar(fish_i, total_count, color='white', edgecolor='black')
ax.bar(fish_i, roi_count, color='black')

ax.set_xticks(np.arange(len(fish_list)))
ax.set_xlabel('Fish')
ax.set_ylabel('# of ROIs')

plt.tight_layout()

In [None]:
# Reliability-selected ROIs; selected_vis indexes the valid-ROI order.
coords_vis_pooled = coords_ib_pooled[selected_vis]
colors_vis_pooled = colors_ib[selected_vis]
amp_vis = amp_pooled_ib[selected_vis]

# Ascending amplitude, so the strongest ROIs are drawn last (on top).
mp_ind_pooled = np.argsort(amp_vis)

In [None]:
# Marker size and scale-bar geometry for the tuning projections.
# x_lim = [500, 0]
# t_lim = [0, 550]
dot_s = 2

scale_bar_len = 100
scale_bar_xpos = 100
scale_bar_ypos1 = 730
fs = 8  # font size of the scale-bar label

In [None]:
# Three orthogonal projections: all valid ROIs in 'linen' as background, with
# the reliability-selected ROIs on top, coloured by tuning (colour_stack RGB,
# scaled to [0, 1]) and drawn in ascending amplitude order.
fig, axs = plt.subplots(2, 2, figsize=(5,5), gridspec_kw={'width_ratios': [6, 2], 'height_ratios': [1, 3]}, sharex='col', sharey='row')
fig.subplots_adjust(left=0.05, wspace=0.05, hspace=0.05)

#Pooled
axs[1,0].scatter(coords_ib_pooled[:,0], coords_ib_pooled[:,1], c='linen', s=dot_s/10, alpha=0.8)
axs[1,0].scatter(coords_vis_pooled[mp_ind_pooled,0], coords_vis_pooled[mp_ind_pooled,1], c=colors_vis_pooled[mp_ind_pooled]/255, s=dot_s/10, alpha=0.8)

axs[1,1].scatter(coords_ib_pooled[:,2], coords_ib_pooled[:,1], c='linen', s=dot_s/10, alpha=0.8)
axs[1,1].scatter(coords_vis_pooled[mp_ind_pooled,2], coords_vis_pooled[mp_ind_pooled,1], c=colors_vis_pooled[mp_ind_pooled]/255, s=dot_s/10, alpha=0.8)

axs[0,0].scatter(coords_ib_pooled[:,0], coords_ib_pooled[:,2], c='linen', s=dot_s/10, alpha=0.8)
axs[0,0].scatter(coords_vis_pooled[mp_ind_pooled,0], coords_vis_pooled[mp_ind_pooled,2], c=colors_vis_pooled[mp_ind_pooled]/255, s=dot_s/10, alpha=0.8)

#Scale bars
axs[1,0].plot((scale_bar_xpos, scale_bar_xpos+scale_bar_len), (scale_bar_ypos1, scale_bar_ypos1), c='black')
axs[1,0].text(scale_bar_xpos+(scale_bar_len/2), scale_bar_ypos1+10, r'{}$\mu$m'.format(scale_bar_len), va='top', ha='center', fontsize=fs)

for ax in axs.flatten():
    ax.axis('off')

axs[0, 1].axis('off')  # redundant: the loop above already turned every axis off
axs[1, 0].invert_yaxis()
# axs[1, 1].invert_yaxis()

In [None]:
# Save the tuning projections (skipped when fig_path is None).
if fig_path is not None:
    fig.savefig(fig_path / 'pooled_tuning_{}th_{}_{}.png'.format(perct, treatment, session), dpi=300)