In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import napari
import matplotlib.pyplot as plt

from lavian_et_al_2025.imaging.imaging_classes import TwoPExperiment
from lavian_et_al_2025.visual_motion.stimulus_functions import stim_vel_dir_dataframe, quantize_directions

from pathlib import Path

In [None]:
# calculate directional tuning from dF/F traces, px-wise
def get_tuning_map(img, sens_regs, n_dirs=8):
    """Compute a pixel-wise directional tuning map from an image time series.

    Each pixel trace is projected onto every sensory regressor. The resulting
    per-direction weights are then combined into a 2-D tuning vector using the
    cosine and sine of the direction bin centers.

    Parameters
    ----------
    img : ndarray
        Image series with time on the first axis and the two spatial axes
        last (read as ``img.shape[0]``, ``img.shape[-2]``, ``img.shape[-1]``).
    sens_regs : ndarray
        Regressor matrix; its first axis must match the time axis of ``img``
        and its second axis must hold one regressor per direction bin.
    n_dirs : int, optional
        Number of direction bins passed to ``quantize_directions``.

    Returns
    -------
    amp : ndarray
        Length of the tuning vector at each pixel.
    angle : ndarray
        Angle of the tuning vector at each pixel, as returned by
        ``np.arctan2`` (radians).

    Raises
    ------
    ValueError
        If the number of regressors differs from ``n_dirs``.
    """
    n_rows, n_cols = img.shape[-2], img.shape[-1]

    # Flatten the spatial axes: one column per pixel.
    traces = img.reshape(img.shape[0], -1)

    # Project every pixel trace onto every regressor -> (n_regressors, n_pixels).
    weights = sens_regs.T @ traces
    if weights.ndim != 2 or weights.shape[0] != n_dirs:
        raise ValueError(
            f"sens_regs must provide exactly n_dirs={n_dirs} regressors; "
            f"projection has shape {weights.shape}"
        )

    # Unit vector for each direction bin (first row cosine, second row sine).
    # NOTE(review): bin centers are presumably in radians, since they are fed
    # straight to np.cos / np.sin -- confirm against quantize_directions.
    bin_centers, _ = quantize_directions([0], n_dirs)
    vectors = np.stack([np.cos(bin_centers), np.sin(bin_centers)], 0)

    # Weighted sum of the direction vectors at each pixel, back in image shape.
    reg_vectors = np.reshape(vectors @ weights, (2, n_rows, n_cols))

    angle = np.arctan2(reg_vectors[1], reg_vectors[0])
    amp = np.sqrt(np.sum(reg_vectors ** 2, 0))

    return amp, angle

In [None]:
# make a color map

def JCh_to_RGB255(x):
    """Convert JCh colour values to 8-bit sRGB.

    The input is converted from the "JCh" space to "sRGB1" with
    ``colorspacious.cspace_convert``, clipped to the [0, 1] range, scaled by
    255 and cast to ``np.uint8``.
    """
    srgb = colorspacious.cspace_convert(x, "JCh", "sRGB1")
    # Values outside [0, 1] are clamped before scaling.
    srgb_clipped = np.clip(srgb, 0, 1)
    scaled = srgb_clipped * 255
    return scaled.astype(np.uint8)

def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
    ):
    """Turn an amplitude/angle map into an 8-bit RGB image.

    The amplitude is normalised by its ``amp_percentile``-th percentile and
    clipped to [0, 1]. It then sets both the first channel (lightness) and the
    second channel (chroma) of a JCh image. The angle sets the third channel
    (hue). The result is converted with ``JCh_to_RGB255``.

    Parameters
    ----------
    amp : ndarray
        Tuning amplitude; any number of dimensions.
    angle : ndarray
        Tuning angle in radians, broadcastable to ``amp.shape``.
    hueshift : float, optional
        Offset (radians) added to the negated angle before conversion to
        degrees.
    amp_percentile : float, optional
        Percentile of ``amp`` that maps to full saturation.
    maxsat : float, optional
        Chroma assigned to pixels at or above the percentile.
    lightness_min : float, optional
        Lightness assigned to pixels with zero normalised amplitude.
    lightness_delta : float, optional
        Lightness change between zero and full normalised amplitude.

    Returns
    -------
    The value returned by ``JCh_to_RGB255`` for an array of shape
    ``amp.shape + (3,)``.
    """
    output_lch = np.empty(amp.shape + (3,))
    maxamp = np.percentile(amp, amp_percentile)

    if maxamp == 0:
        # Dividing by a zero percentile would give NaN for zero-amplitude
        # pixels (0 / 0). Map them to 0 and strictly positive pixels to 1,
        # which is what clipping +inf would have produced.
        rel_amp = (amp > 0).astype(float)
    else:
        rel_amp = np.clip(amp / maxamp, 0, 1)

    # Ellipsis indexing keeps this valid for inputs of any dimensionality.
    output_lch[..., 0] = lightness_min + rel_amp * lightness_delta
    output_lch[..., 1] = rel_amp * maxsat
    # Negate and shift the angle, then convert radians to degrees.
    output_lch[..., 2] = (-angle + hueshift) * 180 / np.pi

    return JCh_to_RGB255(output_lch)

In [None]:
# Root folder that is searched for entries matching "*_f*"
# (presumably one folder per fish -- confirm against the data layout).
master = Path(r"\\portulab.synology.me\data\Hagar\e0075\v06\new ipn")
# Path.glob yields matches in arbitrary order. Sort them so that positional
# indexing into fish_list selects the same entry on every run.
fish_list = sorted(master.glob("*_f*"))

In [None]:
# Work on the last entry of fish_list.
fish = fish_list[-1]
print(fish)

aligned = SplitDataset(fish / "aligned")
# glob() returns matches in arbitrary order; sort for a reproducible list.
exp_list = sorted(glob(str(fish / "behavior/*.json")))

# Spacing between consecutive time points.
# NOTE(review): presumably seconds per volume (i.e. 3 Hz) -- confirm.
sampling = 1/3
# One time stamp per entry along the first axis of `aligned`, exactly
# `sampling` apart. endpoint=False is required: with the default
# endpoint=True the spacing would be n * sampling / (n - 1) instead.
time = np.linspace(0, aligned.shape[0]*sampling, aligned.shape[0], endpoint=False)

In [None]:
# Unpack the four dimensions of the aligned dataset; the second entry is
# used below as the number of planes.
len_rec, num_planes, x_pix, y_pix = aligned.shape

In [None]:
# make a list of sensory regressors for each plane
# Use a forward slash in the pattern: "\*" is an invalid escape sequence in a
# Python string and only acts as a path separator on Windows. The matches are
# sorted because Path.glob order is arbitrary and reg_list[i] must line up
# with plane i (assumes folder names sort in plane order -- TODO confirm).
plane_list = sorted(fish.glob("suite2p/*00*"))
if len(plane_list) > num_planes:
    raise ValueError(
        f"Found {len(plane_list)} plane folders but the aligned data has "
        f"only {num_planes} planes"
    )

# Planes without a matching folder keep a None placeholder.
reg_list = [None] * num_planes
for i, plane_dir in enumerate(plane_list):
    reg_list[i] = fl.load(plane_dir / 'sensory_regressors.h5')['regressors_conv']

In [None]:
# make a list of sensory regressors for each plane
# NOTE(review): this cell repeats the work of the preceding cell (it rebuilds
# plane_list and reg_list from scratch) and can be deleted from the notebook.
# Use a forward slash in the pattern: "\*" is an invalid escape sequence in a
# Python string and only acts as a path separator on Windows. The matches are
# sorted because Path.glob order is arbitrary and reg_list[i] must line up
# with plane i (assumes folder names sort in plane order -- TODO confirm).
plane_list = sorted(fish.glob("suite2p/*00*"))
if len(plane_list) > num_planes:
    raise ValueError(
        f"Found {len(plane_list)} plane folders but the aligned data has "
        f"only {num_planes} planes"
    )

# Planes without a matching folder keep a None placeholder.
reg_list = [None] * num_planes
for i, plane_dir in enumerate(plane_list):
    reg_list[i] = fl.load(plane_dir / 'sensory_regressors.h5')['regressors_conv']

In [None]:
# Open the dataset stored in the "dff" sub-folder of the selected fish.
stack = SplitDataset(fish.joinpath("dff"))

In [None]:
# calculate tuning

# One (amp, angle) pair per plane along the second axis of `stack`.
tuning_maps = []
for i in range(stack.shape[1]):
    print(i)
    img = stack[:, i, :, :]
    # The stored regressors are transposed before being handed over.
    tuning_maps.append(get_tuning_map(img, reg_list[i].T, n_dirs=8))

amps = [plane_amp for plane_amp, _ in tuning_maps]
angles = [plane_angle for _, plane_angle in tuning_maps]

# One row per plane; each cell holds that plane's whole map.
df = pd.DataFrame(list(zip(amps, angles)), columns=["amp", "angle"])

In [None]:
# make a color map from the amplitude/angle
pctl = 90

# Build one RGB image per plane and stack them along a new leading axis.
# NaNs in either map are replaced by np.nan_to_num before colouring.
color_maps = np.array([
    color_stack(
        np.nan_to_num(df.loc[i, "amp"]),
        np.nan_to_num(df.loc[i, "angle"]),
        amp_percentile=pctl,
    )
    for i in range(stack.shape[1])
])

In [None]:
# Show one tuning map per plane on a grid of panels.
num_col = 3
# Keep the 3 x 3 layout for up to nine planes and add rows beyond that, so
# indexing the axes grid can never run past its last row.
num_row = max(3, int(np.ceil(num_planes / num_col)))
fig, axs = plt.subplots(
    num_row,
    num_col,
    figsize=(10, 10 * num_row / 3),
    sharey=True,
    sharex=True,
)

# Hide ticks and frames on every panel, including panels left without an image.
for ax in axs.flat:
    ax.axis('off')

for i in range(num_planes):
    # Fill the grid row by row.
    r, c = divmod(i, num_col)

    # Rotate the map by three quarter turns before display.
    tmp_plane = np.rot90(color_maps[i], 3)
    axs[r, c].imshow(tmp_plane)

In [None]:
# Write the figure into the fish folder, once as PDF and once as JPEG, both at 300 dpi.
for filename in ("tuning_curve2.pdf", "tuning_curve2.jpg"):
    fig.savefig(fish / filename, dpi=300)