In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
from motions.utilities import stim_vel_dir_dataframe, quantize_directions
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import napari
import matplotlib.pyplot as plt

from fimpylab.core.twop_experiment import TwoPExperiment
from fimpy.pipeline.general import calc_f0, dff

from pathlib import Path

from vision_and_navigation.imaging.visual_motion import color_stack_3d, make_sensory_regressors

In [None]:
# find the frames to calculate the baseline.
def no_regressor_frames(regressors, threshold=0.01):
    """Return the row indices at which every regressor is below ``threshold``.

    Parameters
    ----------
    regressors : pandas.DataFrame
        One row per time point, one column per regressor.
    threshold : float
        Strict upper bound a value must stay under to count as "no stimulus".

    Returns
    -------
    np.ndarray
        Integer indices of the rows where all columns are below ``threshold``.
    """
    below_threshold = regressors.values < threshold
    quiet_rows = below_threshold.all(axis=1)
    return np.flatnonzero(quiet_rows)

# calculate the baseline, plane-wise
def calc_f0(stack, frames):
    """Average selected frames of a single-plane stack into an F0 image.

    Note: this shadows the ``calc_f0`` imported from ``fimpy.pipeline.general``
    at the top of the notebook; cells after this one call this version.

    Parameters
    ----------
    stack : array-like
        3-D stack indexed as ``stack[frame, :, :]``.
    frames : iterable of int-like
        Indices of the frames to average (e.g. from ``no_regressor_frames``).

    Returns
    -------
    np.ndarray
        float64 image, the mean of the selected frames.

    Raises
    ------
    ValueError
        If ``frames`` is empty.
    """
    frame_sum = None
    n_frames = 0
    for i_frame in frames:
        # np.array(..., dtype=float64) always makes a copy: accumulating in place
        # on the first frame would otherwise write into ``stack`` whenever
        # indexing returns a view, and integer movies would overflow.
        frame = np.array(stack[int(i_frame), :, :], dtype=np.float64)
        if frame_sum is None:
            frame_sum = frame
        else:
            frame_sum += frame
        n_frames += 1
    if frame_sum is None:
        raise ValueError("calc_f0: no baseline frames given")
    return frame_sum / n_frames

In [None]:
# calculate directional tuning from dF/F traces, px-wise
def get_tuning_map(img, sens_regs, n_dirs=8):
    """Compute a pixel-wise direction-tuning vector from dF/F traces.

    Each pixel trace is projected onto every sensory regressor (dot product over
    time). The ``n_dirs`` projections are then summed as weights on unit vectors
    pointing at the direction-bin centres returned by ``quantize_directions``.

    Parameters
    ----------
    img : np.ndarray
        Stack whose first axis is time and whose last two axes are the image;
        it is flattened to ``(time, pixels)`` internally.
    sens_regs : pandas.DataFrame or array-like
        Regressors, one row per time point and one column per direction bin.
    n_dirs : int
        Number of direction bins; must equal the number of regressor columns.

    Returns
    -------
    amp, angle : np.ndarray
        Length and angle (radians, from ``np.arctan2``) of the tuning vector,
        each with shape ``img.shape[-2:]``.

    Raises
    ------
    ValueError
        If the number of regressor columns differs from ``n_dirs``.
    """
    regressors = np.asarray(sens_regs)
    if regressors.shape[1] != n_dirs:
        raise ValueError(
            f"sens_regs has {regressors.shape[1]} columns but n_dirs={n_dirs}"
        )

    traces = img.reshape(img.shape[0], -1)

    # Regressors and imaging may differ in length; project only over the time
    # points both cover (previously a shorter recording crashed the matmul).
    n_t = min(regressors.shape[0], traces.shape[0])
    reg = regressors[:n_t].T @ traces[:n_t, :]  # (n_dirs, pixels)

    # tuning vector: weighted sum of unit vectors at the bin centres
    bin_centers, _ = quantize_directions([0], n_dirs)
    vectors = np.stack([np.cos(bin_centers), np.sin(bin_centers)], 0)  # (2, n_dirs)
    reg_vectors = (vectors @ reg).reshape((2,) + tuple(img.shape[-2:]))

    angle = np.arctan2(reg_vectors[1], reg_vectors[0])
    amp = np.sqrt(np.sum(reg_vectors ** 2, 0))

    return amp, angle

In [None]:
# Root folder containing one sub-folder per fish (matched by "*_f*").
# Previously used location: Path(r"Z:\Hagar\s11\e0040")
master = Path(r"\\portulab.synology.me\data\Hagar\e0040\v10\gad1b")

# Sorted because later cells pick fish by position (fish_list[-1], fish_list[-2])
# and Path.glob yields entries in a filesystem-dependent order.
fish_list = sorted(master.glob("*_f*"))
fish_list

In [None]:
#### dff all fish
# Compute dF/F for the selected fish; fish whose output already exists are skipped.

# Sampling interval used to build the regressors (same value as in the per-fish
# cells below); defined here so this cell also runs on a fresh kernel.
sampling = 1/3

# NOTE(review): only the last fish is processed despite the header; widen the
# slice to process every fish.
for f in fish_list[-1:]:
    print(f)
    try:
        # NOTE(review): existence is checked for "dff2" while later cells load
        # fish / "dff" — confirm which folder dff() writes to.
        if (f / "dff2").exists():
            print("dfffized")
            continue

        # Sorted so that list position -> imaging plane is the same on every run.
        exp_list = sorted(glob(str(f / "behavior/*.json")))
        reg_list = [make_sensory_regressors(Experiment(exp), sampling=sampling) for exp in exp_list]
        frame_list = [no_regressor_frames(reg) for reg in reg_list]

        aligned = SplitDataset(f / "aligned")
        len_rec, num_planes, x_pix, y_pix = aligned.shape
        if len(frame_list) != num_planes:
            raise ValueError(
                f"{len(frame_list)} behavior files but {num_planes} imaging planes"
            )

        # One F0 image per plane. The array is indexed by plane, so it needs
        # num_planes rows (it was previously allocated with len_rec rows).
        f0_stack = np.empty((num_planes, x_pix, y_pix))
        for i, frames in enumerate(frame_list):
            print(f"plane {i}: {len(frames)} baseline frames")
            f0_stack[i, :, :] = calc_f0(aligned[:, i, :, :], frames)

        # will create a dff split-dataset folder
        stack = dff(aligned, f0_stack)
    except Exception as err:
        # Best effort: report the failure and move on to the next fish.
        # (A bare except also swallowed KeyboardInterrupt and hid the cause.)
        print(f"{f}: dF/F computation failed: {err!r}")

In [None]:
# NOTE(review): the "dff all fish" cell processes fish_list[-1:], not this fish;
# confirm a dF/F dataset exists for the fish selected here.
fish = fish_list[-2]
print(fish)

aligned = SplitDataset(fish / "aligned")
# Sorted so that list position maps deterministically to the imaging plane.
exp_list = sorted(glob(str(fish / "behavior/*.json")))

sampling = 1/3
# One timestamp per imaging frame, spaced exactly `sampling` apart
# (np.linspace(0, n*sampling, n) spaced them n*sampling/(n-1) apart).
time = np.arange(aligned.shape[0]) * sampling

In [None]:
# Movie dimensions: frames, imaging planes, and the two image axes.
# `sampling` and `time` come from the previous cell; this cell used to recompute
# them verbatim.
len_rec, num_planes, x_pix, y_pix = aligned.shape

In [None]:
# Build one set of sensory regressors per behavior file (i.e. per imaging plane).
reg_list = []
for exp_path in exp_list:
    plane_regs = make_sensory_regressors(Experiment(exp_path), sampling=sampling)
    reg_list.append(plane_regs)

In [None]:
# Later cells index reg_list by plane, so there must be exactly one regressor set per plane.
assert len(reg_list) == num_planes, f"{len(reg_list)} regressor sets for {num_planes} planes"
len(reg_list)

In [None]:
# Disabled cell: the single-fish baseline + dF/F computation is kept inside a
# string literal, so none of the code below executes. The "dff all fish" cell
# above runs the same steps in a loop.
# NOTE(review): if re-enabled, f0_stack is filled by plane index, so it needs
# num_planes rows rather than len_rec.
'''

# calculate the baseline image for each plane

frame_list = [no_regressor_frames(reg) for reg in reg_list]

f0_stack = np.empty((len_rec, x_pix, y_pix))
for i, frames in enumerate(frame_list):
    print(frames)
    f0 = calc_f0(aligned[:,i,:,:], frames)
    f0_stack[i,:,:] = f0


#f0_stack = np.empty((aligned.shape[0], x_pix, y_pix))
#f0 = calc_f0(aligned[:,:,:,:], frames)
#f0_stack = f0

    
# will created a dff split-dataset folder
stack = dff(aligned, f0_stack)

'''

In [None]:
# Rebind `aligned` so the raw registered dataset is no longer held by this name;
# from here on only the dF/F dataset is used.
aligned = 0
dff_stack_path = fish / "dff"
stack = SplitDataset(dff_stack_path)

In [None]:
# Direction-tuning amplitude and angle maps, one pair per imaging plane.
amps, angles = [], []
for plane_idx in range(stack.shape[1]):
    print(plane_idx)
    plane_amp, plane_angle = get_tuning_map(stack[:, plane_idx, :, :], reg_list[plane_idx])
    amps.append(plane_amp)
    angles.append(plane_angle)

plane_rows = list(zip(amps, angles))
df = pd.DataFrame(plane_rows, columns=["amp", "angle"])

In [None]:
# Turn each plane's amplitude/angle maps into a color image.
pctl = 90  # amplitude percentile passed to color_stack_3d (default percentile was 80)

color_maps = np.array([
    color_stack_3d(
        np.nan_to_num(df.loc[plane_idx, "amp"]),
        np.nan_to_num(df.loc[plane_idx, "angle"]),
        amp_percentile=pctl,
    )
    for plane_idx in range(stack.shape[1])
])

In [None]:
# Cache the per-plane color maps in the fish folder so they can be reloaded
# without recomputing the tuning.
# NOTE(review): the "90" in the filename mirrors `pctl`; keep them in sync.
fl.save(fish / "tuning_map_90_fixed_2024.h5", color_maps)

In [None]:
# Reload the cached per-plane color maps from the fish folder.
color_maps = fl.load(fish / "tuning_map_90_fixed_2024.h5")

In [None]:
# Browse the per-plane color maps interactively in napari.
# `napari.gui_qt()` is deprecated; when run from IPython/Jupyter napari starts
# the Qt event loop itself, so creating a Viewer directly is sufficient.
v = napari.Viewer()
v.add_image(color_maps)

In [None]:
# Number of panels to draw: one per computed color map. The plotting loop indexes
# color_maps, so counting reg_list could run past its end.
num_planes = len(color_maps)

In [None]:
# Grid of per-plane tuning maps, `num_col` panels per row; rows are added as needed
# (a fixed 3x3 grid raised IndexError for more than 9 planes).
num_col = 3
num_row = max(1, -(-num_planes // num_col))  # ceil division, at least one row
fig, axs = plt.subplots(
    num_row,
    num_col,
    figsize=(10, 10 * num_row / num_col),
    sharey=True,
    sharex=True,
    squeeze=False,  # keep axs 2-D even when there is a single row
)

for i in range(num_planes):
    ax = axs[i // num_col, i % num_col]
    # Ticks/frame are kept on the first panel only; all others are hidden.
    if i > 0:
        ax.axis('off')
    # Rotate each plane by 180 degrees (two 90-degree turns) for display.
    # (Optionally mask with an anatomy image: np.ma.masked_where(anatomy < 1, plane).)
    ax.imshow(np.rot90(color_maps[i], 2))
    ax.set_title(f"plane {i}")

# Hide grid cells that have no plane to show.
for ax in axs.flat[num_planes:]:
    ax.axis('off')

In [None]:
# Save the tuning-map figure as both PDF and JPG in the fish folder.
for ext in ("pdf", "jpg"):
    fig.savefig(fish / f"tuning_curve2.{ext}", dpi=300)