In [13]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [14]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
import matplotlib.pyplot as plt

from lavian_et_al_2025.imaging.imaging_classes import LightsheetExperiment
from lavian_et_al_2025.visual_motion.stimulus_functions import stim_vel_dir_dataframe, quantize_directions, get_tuning_map_rois, make_sensory_regressors
from lavian_et_al_2025.visual_motion.colors import JCh_to_RGB255
from lavian_et_al_2025.data_location import master_motion

import json
from pathlib import Path

In [15]:
def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
    ):
    """Map per-ROI tuning amplitude and angle to RGB colours via JCh.

    The amplitude is divided by its ``amp_percentile``-th percentile and clipped
    to [0, 1]. This normalised amplitude sets lightness (J) and chroma (C); the
    angle sets hue (h, degrees). The resulting [J, C, h] rows are passed to
    ``JCh_to_RGB255``.

    Parameters
    ----------
    amp : 1-D array, one tuning amplitude per ROI. Must not contain NaN: a
        NaN makes the percentile, and therefore every colour, NaN.
    angle : 1-D array, same length as ``amp``, in radians (converted to
        degrees for the hue channel).
    hueshift : hue offset in radians, added to ``-angle``.
    amp_percentile : percentile of ``amp`` that maps to full saturation.
    maxsat : chroma at normalised amplitude 1.
    lightness_min : lightness at normalised amplitude 0.
    lightness_delta : lightness change from normalised amplitude 0 to 1.

    Returns
    -------
    The output of ``JCh_to_RGB255`` for an (n, 3) array of [J, C, h] rows.
    """
    amp = np.asarray(amp, dtype=float)
    angle = np.asarray(angle, dtype=float)

    maxamp = np.percentile(amp, amp_percentile)
    if maxamp == 0:
        # Degenerate scale (e.g. most amplitudes are zero). A plain division
        # would give 0/0 = NaN for zero amplitudes. Zero stays zero, and any
        # positive amplitude saturates, which matches clip(+inf) of amp / 0.
        amp_norm = (amp > 0).astype(float)
    else:
        amp_norm = np.clip(amp / maxamp, 0, 1)

    output_lch = np.zeros((amp.shape[0], 3))
    output_lch[:, 0] = lightness_min + amp_norm * lightness_delta  # J (lightness)
    output_lch[:, 1] = amp_norm * maxsat                           # C (chroma)
    output_lch[:, 2] = (-angle + hueshift) * 180 / np.pi           # h (hue, degrees)

    return JCh_to_RGB255(output_lch)

In [17]:
#Enter path to dataset and generate a list of all fish
master = master_motion / "LS" / "whole brain"
# Sorted so that positional picks (example_fish, fish_list[0]) and the pooled
# concatenation order are reproducible; glob yields entries in filesystem order,
# which is not guaranteed. Only directories are kept, since every fish entry is
# used as a folder (fish / 'registration' / ..., fish / "filtered_traces.h5").
fish_list = sorted(p for p in master.glob("*_f*") if p.is_dir())

# Load morphed coordinates and in-brain ROI masks

In [19]:
# For every fish: ROI coordinates morphed to the reference brain, and a boolean
# mask (one entry per ROI) that is True for the ROIs listed as inside the brain
# in the suite2p brain file.
morphed_coords = {}
in_brain_arr = {}

for fish_dir in fish_list:
    print(fish_dir)
    fish_name = fish_dir.name

    coords_ref = fl.load(fish_dir / 'registration' / 'to_h2b_baier_ref' / 'antspy' / 'mov_coords_transformed.h5')
    morphed_coords[fish_name] = coords_ref

    brain_rois = fl.load(fish_dir / "data_from_suite2p_cells_brain.h5")['coords_idx']
    brain_mask = np.zeros(coords_ref.shape[0], dtype=bool)
    brain_mask[brain_rois] = True
    in_brain_arr[fish_name] = brain_mask

print('Done')

Done


# Load example fish data

In [22]:
#Choose a fish for example plot (position in fish_list)
example_fish = 7
# Fail with an explanatory message instead of a bare "list index out of range".
if not -len(fish_list) <= example_fish < len(fish_list):
    raise IndexError(
        "example_fish={} is out of range: only {} fish found under {}".format(
            example_fish, len(fish_list), master
        )
    )
path = fish_list[example_fish]

IndexError: list index out of range

In [None]:
#Load experiment details and neural traces
# NOTE(review): `exp` (the first *behavior* match) is not used further in this
# notebook; Experiment(path) is constructed directly where it is needed.
exp = glob(str(path / "*behavior*"))[0]
# The "/detr" group of filtered_traces.h5 (presumably detrended traces). Used as
# traces[:, roi], with time along axis 0.
traces = fl.load(path / "filtered_traces.h5", "/detr")

In [None]:
# Morphed coordinates of every ROI of the example fish, plus the indices of the
# ROIs inside the brain (these index the per-ROI arrays such as traces columns).
coords = morphed_coords[path.name]
suite2p_brain = fl.load(path / "data_from_suite2p_cells_brain.h5")
in_brain_idx = suite2p_brain['coords_idx']

In [None]:
# Data resolution: pixel size of x, y, z. The in-plane size is hard-coded; the
# z step comes from the piezo scan range in the acquisition metadata.
# NOTE(review): the span is divided by n_planes (not n_planes - 1); confirm this
# matches the acquisition convention. `res` is not used further in this notebook.
metadata = json.loads(next(Path(path).glob("*metadata.json")).read_text())
lsconfig = metadata["imaging"]["microscope_config"]['lightsheet']['scanning']

z_scan = lsconfig["z"]
z_tot_span = z_scan["piezo_max"] - z_scan["piezo_min"]
n_planes = lsconfig["triggering"]["n_planes"]
z_res = z_tot_span / n_planes

res = [0.6, 0.6, z_res]

# Sensory regressors - single fish

In [None]:
# Per-direction sensory regressors of the example fish; columns are selected
# with .iloc[:, plot_dir] below.
sensory_regressors = fl.load(path / "sensory_regressors.h5", "/regressors")

# Human-readable direction labels, presumably in regressor column order.
titles = [
    'right',
    'backward right',
    'backward',
    'backward left',
    'left',
    'forward left',
    'forward',
    'forward right',
]

In [None]:
#Choose the direction of visual motion to plot (column of sensory_regressors;
#its label is presumably titles[plot_dir])
plot_dir = 0

current_dir = np.asarray(sensory_regressors.iloc[:, plot_dir])
num_traces = np.shape(traces)[1]

# Pearson correlation between the chosen regressor and every ROI trace (column).
reg_corr = np.array([np.corrcoef(current_dir, traces[:, roi])[0, 1] for roi in range(num_traces)])

coords_ib = coords[in_brain_idx]
reg_corr_ib = reg_corr[in_brain_idx]

# Scatter draw order: weakest |correlation| first, so the strongest ROIs end up on top.
mp_ind_regressor = np.argsort(np.abs(reg_corr_ib))

# Direction tuning - single fish

In [None]:
# Imaging rate and time axis of the example fish.
img_exp = LightsheetExperiment(path)
fs = img_exp.fn  # imaging rate; NB: the figure section later reuses `fs` as a font size
sampling = 1/fs  # sampling period (presumably seconds per frame)

len_rec, num_cells = np.shape(traces)
# One timestamp per frame at integer multiples of the sampling period.
# np.linspace(0, len_rec*sampling, len_rec) includes its end point, which
# stretches the spacing to len_rec*sampling/(len_rec-1) instead of `sampling`.
time = np.arange(len_rec) * sampling

In [None]:
# make a list of sensory regressors 
# Built at the imaging sampling period so they can be compared with the traces.
reg = make_sensory_regressors(Experiment(path), sampling=sampling)
# NOTE(review): reg_list is not used further in this notebook.
reg_list = [reg]

In [None]:
# Number of regressor rows (presumably time points) and the dot product of every
# regressor column with every ROI trace (regressors x ROIs).
# NOTE(review): neither n_t nor regi is used later in this notebook.
n_t = reg.shape[0]
regi = reg.values.T @ traces[:, :]

In [None]:
# Per-ROI direction tuning from the traces and the sensory regressors: one
# amplitude and one angle per ROI (NaN for some ROIs, filtered in the next cells).
amp, angle = get_tuning_map_rois(traces, reg)

# Tabular view of the same values. NOTE(review): df is not used further below.
df = pd.DataFrame({"amp": amp, "angle": angle})

In [None]:
# True for ROIs with a defined (non-NaN) tuning amplitude.
# NB: the pooled section further down reuses the name nan_filt with the
# opposite meaning (True = NaN).
nan_filt = ~np.isnan(amp)

In [None]:
# color_stack only receives ROIs with a defined amplitude, so NaNs do not poison
# its percentile scale. Its rows are scattered back into an array with one row
# per ROI because the following cells index `colors` with ROI indices
# (in_brain_idx). Indexing the compacted array would misalign colours as soon
# as any amplitude is NaN.
# ROIs without a tuning estimate are filled with 255 in every channel, i.e.
# white after the /255 used in the plotting cells (assumes JCh_to_RGB255 returns
# 0-255 values, as that /255 implies).
colors_valid = color_stack(amp[nan_filt], angle[nan_filt])
colors = np.full((amp.shape[0],) + colors_valid.shape[1:], 255, dtype=colors_valid.dtype)
colors[nan_filt] = colors_valid

In [None]:
# Restrict coordinates, colours and amplitudes to the in-brain ROIs, in the order
# of in_brain_idx. Assumes `colors` has one row per ROI, like `amp`.
coords_ib = coords[in_brain_idx]
colors_ib = colors[in_brain_idx]
amp_ib = amp[in_brain_idx]
# Amplitude relative to the largest in-brain amplitude (NaNs ignored by nanmax).
# NOTE(review): amp_norm is not used by the figure cells below.
amp_norm = amp_ib / np.nanmax(amp_ib)
amp_norm.shape

In [None]:
# Keep the most reliably responding in-brain ROIs of the example fish: threshold
# at the 95th percentile of the reliability index (0.5 is an alternative fixed
# threshold). nanpercentile, as in the pooled section, ignores NaN entries;
# np.percentile would return NaN and then select no ROI at all.
reliable_arr = fl.load(path / "reliability_index_arr.h5", "/reliability_arr_combined")
rel_thresh = np.nanpercentile(reliable_arr, 95)
# Indices into the in-brain arrays (coords_ib, colors_ib, amp_ib).
selected_vis = np.where(reliable_arr > rel_thresh)[0]

In [None]:
# Subset of in-brain ROIs passing the reliability threshold.
coords_vis, colors_vis, amp_vis = (arr[selected_vis] for arr in (coords_ib, colors_ib, amp_ib))

# Draw order: ascending tuning amplitude, so strongly tuned ROIs are plotted last (on top).
mp_ind_tuning = np.argsort(amp_vis)

# Morphed datasets: tuning of all fish pooled in reference space

In [None]:
def _fish_tuning(fish_dir):
    """Return the per-ROI ``(amp, angle)`` tuning arrays of one fish directory.

    All intermediates stay local, so computing the pooled tuning no longer
    overwrites the example-fish globals (``traces``, ``reg``, ``fs``, ...) that
    the single-fish cells above rely on.
    """
    traces_fish = fl.load(fish_dir / "filtered_traces.h5", "/detr")
    sampling_fish = 1 / LightsheetExperiment(fish_dir).fn
    reg_fish = make_sensory_regressors(Experiment(fish_dir), sampling=sampling_fish)
    return get_tuning_map_rois(traces_fish, reg_fish)


# Per-fish tuning is cached in master/tuning_arrs.h5. Fish missing from the cache
# (or every fish, when there is no cache yet) are computed and the cache is
# rewritten. An existing but unreadable cache now raises instead of being
# silently recomputed and overwritten.
tuning_cache = master / 'tuning_arrs.h5'
cache_exists = tuning_cache.exists()
if cache_exists:
    tuning_arrs = fl.load(tuning_cache)
    amp_pooled, angle_pooled = tuning_arrs['amp_pooled'], tuning_arrs['angle_pooled']
else:
    amp_pooled, angle_pooled = {}, {}

missing_fish = [fish for fish in fish_list if fish.name not in amp_pooled]
for fish in missing_fish:
    print(fish)
    amp_pooled[fish.name], angle_pooled[fish.name] = _fish_tuning(fish)

if missing_fish or not cache_exists:
    tuning_arrs = {'angle_pooled': angle_pooled, 'amp_pooled': amp_pooled}
    fl.save(tuning_cache, tuning_arrs)
    print('Done')

In [None]:
# Pool the per-fish arrays in fish_list order, so that all pooled arrays share a
# single ROI index: tuning amplitude/angle, the in-brain mask, and morphed coordinates.
fish_names = [fish_dir.name for fish_dir in fish_list]

amp_pooled_arr = np.concatenate([amp_pooled[name] for name in fish_names])
angle_pooled_arr = np.concatenate([angle_pooled[name] for name in fish_names])
in_brain_arr_pooled = np.concatenate([in_brain_arr[name] for name in fish_names])
coords_pooled = np.concatenate([morphed_coords[name] for name in fish_names], axis=0)

In [None]:
# Mask of pooled ROIs without a tuning estimate. Both arrays are checked, so the
# filter no longer relies on NaNs always co-occurring in amp and angle.
# NB: here nan_filt is True for *invalid* ROIs (the single-fish cell above uses
# the opposite convention for the same name).
nan_filt = np.isnan(amp_pooled_arr) | np.isnan(angle_pooled_arr)
print(nan_filt.sum(), ' ROIs excluded')

#Combine into a filtering array: inside the brain AND with a defined tuning
valid_rois = np.logical_and(in_brain_arr_pooled, ~nan_filt)

In [None]:
# Colours for the pooled ROIs that are in the brain and have a defined tuning.
# NB: this reuses the name colors_ib from the single-fish section; colors_vis
# computed there is unaffected unless that cell is re-run after this one.
colors_ib = color_stack(amp=amp_pooled_arr[valid_rois], angle=angle_pooled_arr[valid_rois])

In [None]:
# Valid in-brain pooled ROIs (same order as colors_ib): coordinates and amplitudes.
coords_ib_pooled, amp_pooled_ib = coords_pooled[valid_rois], amp_pooled_arr[valid_rois]

# Amplitude relative to the pooled maximum (NaNs ignored by nanmax).
# NOTE(review): amp_norm is not used by the figure cells below.
pooled_max_amp = np.nanmax(amp_pooled_ib)
amp_norm = amp_pooled_ib / pooled_max_amp

In [None]:
#Load reliability arrays and filter ROIs
rel_arr_pooled = np.concatenate([fl.load(fish / "reliability_index_arr.h5", "/reliability_arr_combined") for fish in fish_list])

# The boolean indexing below needs one reliability value per pooled in-brain ROI.
# Check this up front so a mismatch gives a readable message instead of numpy's
# bare boolean-index error.
n_in_brain = int(np.count_nonzero(in_brain_arr_pooled))
if rel_arr_pooled.shape[0] != n_in_brain:
    raise ValueError(
        "pooled reliability array has {} entries but there are {} in-brain ROIs".format(
            rel_arr_pooled.shape[0], n_in_brain
        )
    )

# Threshold at the 95th percentile of pooled reliability, ignoring NaNs
# (0.5 is an alternative fixed threshold).
rel_thresh = np.nanpercentile(rel_arr_pooled, 95)
print(rel_thresh)
# Drop ROIs with undefined tuning, so selected_vis indexes the valid in-brain
# pooled arrays (coords_ib_pooled, colors_ib, amp_pooled_ib).
selected_vis = np.where(rel_arr_pooled[~nan_filt[in_brain_arr_pooled]] > rel_thresh)[0]

In [None]:
# Reliable pooled ROIs (selected_vis indexes the valid in-brain pooled arrays).
coords_vis_pooled, colors_vis_pooled, amp_vis = (
    arr[selected_vis] for arr in (coords_ib_pooled, colors_ib, amp_pooled_ib)
)

# Draw order: ascending amplitude, so the most strongly tuned ROIs end up on top.
mp_ind_pooled = np.argsort(amp_vis)

# Make figures

In [None]:
# Shared figure settings. Positions are in the plotted coordinate units, which
# the scale-bar label treats as micrometres (presumably reference-space pixels).
dot_s = 2                                    # scatter marker size
fs = 8                                       # font size; NB: overwrites the imaging rate also named fs
scale_bar_len, scale_bar_xpos = 100, 100
scale_bar_ypos1, scale_bar_ypos2 = 630, 770  # NOTE(review): each figure cell re-sets ypos1; ypos2 is unused
x_lim, t_lim = [500, 0], [0, 550]            # NOTE(review): not referenced by the figure cells

In [None]:
# Column indices of the three coordinates.
# NOTE(review): the figure cells index coordinate columns with literal 0/1/2.
x, y, z = range(3)

In [None]:
# Reference anatomy as mapped in the registration folder of the first fish in
# fish_list. Below, its cropped version is only used for its shape (panel
# ratios and blank np.zeros_like backgrounds).
ref_anatomy = fl.load(
    fish_list[0].joinpath('registration', 'to_h2b_baier_ref', 'antspy', 'ref_mapped.h5')
)

In [None]:
# Crop window for the first three axes of the reference volume.
# NOTE(review): the scatter offsets (-50) match the start of axes 0 and 2 but
# not the start (70) of axis 1; confirm whether that is intended.
anatomy_crop = (slice(50, 550), slice(70, 690), slice(50, None))
ref_anatomy_cropped = ref_anatomy[anatomy_crop]

In [None]:
#Plot
# Example fish: in-brain ROIs coloured by their correlation with the chosen
# direction regressor, shown in three orthogonal projections.
# Relative panel widths/heights are taken from the cropped reference extents.
y_main_rat = ref_anatomy_cropped.shape[1]/ref_anatomy_cropped.shape[0]
x_rat = ref_anatomy_cropped.shape[2]/ref_anatomy_cropped.shape[0]
y_rat = ref_anatomy_cropped.shape[2]/ref_anatomy_cropped.shape[1]*y_main_rat

fig = plt.figure(figsize=(4,4))
gs = fig.add_gridspec(2, 2, width_ratios=[1, x_rat], height_ratios=[y_rat, y_main_rat])

# ax1: main view (coordinate 0 vs 1); ax2 above it (0 vs 2); ax3 to its right (2 vs 1).
ax1 = fig.add_subplot(gs[1, 0])
ax2 = fig.add_subplot(gs[0, 0], sharex=ax1)
ax3 = fig.add_subplot(gs[1, 1], sharey=ax1)

# All-zero volume: blank backgrounds shaped like each projection of the crop.
mock_anatomy = np.zeros_like(ref_anatomy_cropped)
ax1.imshow(mock_anatomy.mean(2).T, cmap='gray_r')
ax2.imshow(mock_anatomy.mean(1).T, cmap='gray_r')
ax3.imshow(mock_anatomy.mean(0), cmap='gray_r')

# Points drawn in order of increasing |correlation| (mp_ind_regressor), on a fixed [-1, 1] colour scale.
ax1.scatter(coords_ib[mp_ind_regressor,0]-50, coords_ib[mp_ind_regressor,1]-50,
                 c=reg_corr_ib[mp_ind_regressor], s=dot_s, alpha=0.8, cmap='coolwarm', vmin=-1, vmax=1)
ax3.scatter(coords_ib[mp_ind_regressor,2]-50, coords_ib[mp_ind_regressor,1]-50, 
                 c=reg_corr_ib[mp_ind_regressor], s=dot_s, alpha=0.8, cmap='coolwarm', vmin=-1, vmax=1)
ax2.scatter(coords_ib[mp_ind_regressor,0]-50, coords_ib[mp_ind_regressor,2]-50, 
                 c=reg_corr_ib[mp_ind_regressor], s=dot_s, alpha=0.8, cmap='coolwarm', vmin=-1, vmax=1)

#Scale bars
scale_bar_ypos1 = 600
ax1.plot((scale_bar_xpos, scale_bar_xpos+scale_bar_len), (scale_bar_ypos1, scale_bar_ypos1), c='black')
ax1.text(scale_bar_xpos+(scale_bar_len/2), scale_bar_ypos1+10, r'{}$\mu$m'.format(scale_bar_len), va='top', ha='center', fontsize=fs)


# NOTE(review): ax3 shares its y-axis with ax1, so ax3.invert_yaxis() and the
# ax1.invert_yaxis() below presumably toggle the same shared axis twice and
# cancel out. The net effect would be that only ax2 is inverted, as in the
# other two figures.
ax3.invert_yaxis()

plt.subplots_adjust(hspace=0.001, wspace=0.01, left=0.05, right=0.95, top=0.95, bottom=0.05)
ax1.invert_yaxis()
ax2.invert_yaxis()

for ax in [ax1, ax2, ax3]:
     ax.axis('off')

In [None]:
# Example fish tuning map: every in-brain ROI in pale 'linen', with the reliable
# ROIs on top coloured by color_stack (0-255 RGB scaled to 0-1), in three
# orthogonal projections.
fig_tun = plt.figure(figsize=(4,4))
gs = fig_tun.add_gridspec(2, 2, width_ratios=[1, x_rat], height_ratios=[y_rat, y_main_rat])
ax_tun1 = fig_tun.add_subplot(gs[1, 0])                  # main view: coordinate 0 vs 1
ax_tun2 = fig_tun.add_subplot(gs[0, 0], sharex=ax_tun1)  # top view: coordinate 0 vs 2
ax_tun3 = fig_tun.add_subplot(gs[1, 1], sharey=ax_tun1)  # side view: coordinate 2 vs 1

# Blank backgrounds shaped like the projections of the cropped reference volume.
mock_anatomy = np.zeros_like(ref_anatomy_cropped)
for panel, background in (
        (ax_tun1, mock_anatomy.mean(2).T),
        (ax_tun2, mock_anatomy.mean(1).T),
        (ax_tun3, mock_anatomy.mean(0)),
):
    panel.imshow(background, cmap='gray_r')

# (panel, horizontal coordinate column, vertical coordinate column)
tuned_rgb = colors_vis[mp_ind_tuning]/255
for panel, col_h, col_v in ((ax_tun1, 0, 1), (ax_tun3, 2, 1), (ax_tun2, 0, 2)):
    panel.scatter(coords_ib[:, col_h]-50, coords_ib[:, col_v]-50, s=dot_s, alpha=1, c='linen')
    panel.scatter(coords_vis[mp_ind_tuning, col_h]-50, coords_vis[mp_ind_tuning, col_v]-50,
                  c=tuned_rgb, s=dot_s/1, alpha=1)

#Scale bar with label
scale_bar_ypos1 = 600
ax_tun1.plot((scale_bar_xpos, scale_bar_xpos+scale_bar_len), (scale_bar_ypos1, scale_bar_ypos1), c='black')
ax_tun1.text(scale_bar_xpos+(scale_bar_len/2), scale_bar_ypos1+10, r'{}$\mu$m'.format(scale_bar_len), va='top', ha='center', fontsize=fs)

ax_tun2.invert_yaxis()

for ax in [ax_tun1, ax_tun2, ax_tun3]:
    ax.axis('off')

plt.subplots_adjust(hspace=0.001, wspace=0.01, left=0.05, right=0.95, top=0.95, bottom=0.05)

In [None]:
# Rebuild the reliable pooled selection right before the pooled figure.
# NOTE(review): this duplicates the earlier cell that builds coords_vis_pooled;
# it only matters if the figure cell is re-run after those names were changed.
coords_vis_pooled = coords_ib_pooled[selected_vis]
colors_vis_pooled = colors_ib[selected_vis]
amp_vis = amp_pooled_ib[selected_vis]

# Ascending amplitude: strongest tuned ROIs drawn last (on top).
mp_ind_pooled = np.argsort(amp_vis)

In [None]:
#Plot: pooled tuning of all fish in reference space, three orthogonal projections.
# Relative panel widths/heights are taken from the cropped reference extents.
crop_shape = ref_anatomy_cropped.shape
y_main_rat = crop_shape[1] / crop_shape[0]
x_rat = crop_shape[2] / crop_shape[0]
y_rat = crop_shape[2] / crop_shape[1] * y_main_rat

fig_pool = plt.figure(figsize=(4,4))
gs = fig_pool.add_gridspec(2, 2, width_ratios=[1, x_rat], height_ratios=[y_rat, y_main_rat])
ax_pool1 = fig_pool.add_subplot(gs[1, 0])                   # main view: coordinate 0 vs 1
ax_pool2 = fig_pool.add_subplot(gs[0, 0], sharex=ax_pool1)  # top view: coordinate 0 vs 2
ax_pool3 = fig_pool.add_subplot(gs[1, 1], sharey=ax_pool1)  # side view: coordinate 2 vs 1

# Blank backgrounds shaped like the projections of the cropped reference volume.
mock_anatomy = np.zeros_like(ref_anatomy_cropped)
for panel, background in (
        (ax_pool1, mock_anatomy.mean(2).T),
        (ax_pool2, mock_anatomy.mean(1).T),
        (ax_pool3, mock_anatomy.mean(0)),
):
    panel.imshow(background, cmap='gray_r')

# (panel, horizontal column, vertical column, vertical offset): all valid in-brain
# ROIs in 'linen', reliable ROIs on top in their tuning colour.
pooled_rgb = colors_vis_pooled[mp_ind_pooled]/255
for panel, col_h, col_v, off_v in ((ax_pool1, 0, 1, -50), (ax_pool3, 2, 1, -50), (ax_pool2, 0, 2, 30)):
    panel.scatter(coords_ib_pooled[:, col_h], coords_ib_pooled[:, col_v] + off_v, s=dot_s, alpha=1, c='linen')
    panel.scatter(coords_vis_pooled[mp_ind_pooled, col_h], coords_vis_pooled[mp_ind_pooled, col_v] + off_v,
                  c=pooled_rgb, s=dot_s/10, alpha=0.8)

#Scale bar (drawn without a text label in this figure)
scale_bar_ypos1 = 700
ax_pool1.plot((scale_bar_xpos, scale_bar_xpos+scale_bar_len), (scale_bar_ypos1, scale_bar_ypos1), c='black')

ax_pool2.invert_yaxis()

for ax in [ax_pool1, ax_pool2, ax_pool3]:
    ax.axis('off')

plt.subplots_adjust(hspace=0.01, wspace=0.00, left=0.05, right=0.95, top=0.95, bottom=0.05)