In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
from fimpy.pipeline.general import calc_f0, dff
from motions.utilities import stim_vel_dir_dataframe, quantize_directions
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import napari
import matplotlib.pyplot as plt

from fimpylab.core.twop_experiment import TwoPExperiment

from pathlib import Path


In [None]:
def no_regressor_frames(regressors, threshold=0.05):
    """Find the frames that can serve as a baseline for F0.

    A frame qualifies when every regressor is below ``threshold`` at that
    frame. The reduction runs over axis 0, so ``regressors`` is presumably laid
    out as (n_regressors, n_frames) — TODO confirm against the caller.

    Returns the qualifying indices along the remaining axis as an int array.
    """
    all_quiet = (regressors < threshold).all(axis=0)
    return np.nonzero(all_quiet)[0]

def calc_f0(stack, frames):
    """Average the selected frames of one plane into a baseline image (F0).

    NOTE: this definition shadows the ``calc_f0`` imported from
    ``fimpy.pipeline.general`` in the import cell.

    Parameters
    ----------
    stack : array-like indexed as ``stack[t, :, :]``
        Movie of a single plane. It is read one frame at a time, so lazy
        containers only load the requested frames.
    frames : sequence of ints
        Frame indices to average (e.g. the output of ``no_regressor_frames``).

    Returns
    -------
    np.ndarray (float64)
        Mean of the selected frames.

    Raises
    ------
    ValueError
        If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("calc_f0 needs at least one baseline frame")

    total = None
    for i_frame in frames:
        # np.array(..., dtype=float64) always copies. This keeps the in-place
        # accumulation below from writing into the caller's stack (a numpy slice
        # is a view). It also keeps integer movies from overflowing.
        frame = np.array(stack[int(i_frame), :, :], dtype=np.float64)
        if total is None:
            total = frame
        else:
            total += frame
    return total / len(frames)

In [None]:
def get_tuning_map(img, sens_regs, n_dirs=8):
    """Compute a pixel-wise directional tuning map from a dF/F movie.

    Every pixel trace is projected onto the sensory regressors, giving one
    response per direction bin. These responses are vector-summed over the
    direction bin centres returned by ``quantize_directions``.

    Parameters
    ----------
    img : np.ndarray, shape (n_t, H, W)
        dF/F movie of one plane (time first).
    sens_regs : np.ndarray, shape (n_t, n_dirs)
        One regressor per direction bin. Presumably its columns follow the same
        order as the bin centres from ``quantize_directions`` — TODO confirm.
    n_dirs : int
        Number of direction bins; must equal ``sens_regs.shape[1]``.

    Returns
    -------
    amp : np.ndarray, shape (H, W)
        Length of the tuning vector.
    angle : np.ndarray, shape (H, W)
        Direction of the tuning vector in radians (output of ``arctan2``).

    Raises
    ------
    ValueError
        If ``sens_regs`` does not match the movie length or ``n_dirs``.
    """
    n_t = img.shape[0]
    height, width = img.shape[-2], img.shape[-1]
    sens_regs = np.asarray(sens_regs)
    if sens_regs.shape[0] != n_t:
        raise ValueError(
            f"regressors have {sens_regs.shape[0]} time points, movie has {n_t}"
        )
    if sens_regs.shape[1] != n_dirs:
        raise ValueError(
            f"expected {n_dirs} direction regressors, got {sens_regs.shape[1]}"
        )

    traces = img.reshape(n_t, -1)  # (n_t, n_px)
    # (n_dirs, n_px): projection of every pixel trace on every direction regressor
    reg = sens_regs.T @ traces

    bin_centers, _ = quantize_directions([0], n_dirs)
    # (2, n_dirs) unit vectors (cos, sin) for each direction bin
    vectors = np.stack([np.cos(bin_centers), np.sin(bin_centers)], 0)
    # Vector sum of the per-direction responses for every pixel -> (2, H, W)
    reg_vectors = (vectors @ reg).reshape(2, height, width)

    angle = np.arctan2(reg_vectors[1], reg_vectors[0])
    amp = np.hypot(reg_vectors[0], reg_vectors[1])
    return amp, angle

In [None]:
# make a color map

def JCh_to_RGB255(x):
    """Convert colours from colorspacious' JCh space to 8-bit sRGB.

    The converted sRGB1 values are clipped to [0, 1] (out-of-gamut colours)
    and then scaled to 0-255 and truncated to ``np.uint8``.
    """
    srgb_unit = colorspacious.cspace_convert(x, "JCh", "sRGB1")
    srgb_unit = np.clip(srgb_unit, 0, 1)
    scaled = srgb_unit * 255
    return scaled.astype(np.uint8)

def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
    ):
    """Turn tuning amplitude/angle maps into an RGB image via JCh colours.

    The amplitude is normalised by its ``amp_percentile``-th percentile and
    clipped to [0, 1]. It then drives both lightness and chroma:
    ``J = lightness_min + norm * lightness_delta`` and ``C = norm * maxsat``.
    The tuning angle (radians) is mapped to hue in degrees as
    ``h = (hueshift - angle) * 180 / pi``.

    Parameters
    ----------
    amp, angle : array-like of the same shape (any number of dimensions)
        Tuning amplitude and angle in radians.

    Returns
    -------
    np.ndarray, shape ``amp.shape + (3,)``, dtype uint8
        RGB colours.
    """
    amp = np.asarray(amp, dtype=float)
    angle = np.asarray(angle, dtype=float)

    maxamp = np.percentile(amp, amp_percentile)
    if maxamp != 0:
        norm_amp = np.clip(amp / maxamp, 0, 1)
    else:
        # The percentile landed on 0 (e.g. a mostly-empty map). amp / 0 would be
        # NaN for zero pixels and inf for positive ones. Instead, saturate
        # positive pixels (matching the previous inf -> 1 clip) and keep the rest at 0.
        norm_amp = (amp > 0).astype(float)

    output_lch = np.empty(amp.shape + (3,))
    output_lch[..., 0] = lightness_min + norm_amp * lightness_delta
    output_lch[..., 1] = norm_amp * maxsat
    output_lch[..., 2] = (-angle + hueshift) * 180 / np.pi

    return JCh_to_RGB255(output_lch)

In [None]:
# Root folder of the e0075/v06 imaging data set.
# NOTE(review): hard-coded network path; consider moving it to a config cell or
# an environment variable so the notebook runs on other machines.
master = Path(r"\\portulab.synology.me\data\Hagar\vision and navigation - landmarks\e0075\v06\new ipn")

# Path.glob yields entries in filesystem order. Sort so that positional
# selection (fish_list[-2] in the next cell) picks the same fish on every run.
fish_list = sorted(master.glob("*_f*"))
fish_list


In [None]:
fish = fish_list[-2]
print(fish)

aligned = SplitDataset(fish / "aligned")
# Behaviour metadata for this fish, sorted for a reproducible order.
exp_list = sorted(glob(str(fish / "behavior/*.json")))

# Time step per imaging frame (presumably seconds, i.e. a 3 Hz rate — TODO confirm).
sampling = 1/3
# Frame k is acquired at k * sampling. np.linspace(0, N * sampling, N) would
# stretch the spacing to N * sampling / (N - 1) and drift from the true timing.
time = np.arange(aligned.shape[0]) * sampling

In [None]:
# Unpack the 4-D shape of the registered movie (time first, then planes),
# then display it.
len_rec, num_planes, x_pix, y_pix = aligned.shape
aligned.shape

In [None]:
# Load the convolved sensory regressors of each imaging plane into reg_list,
# so that reg_list[i] belongs to plane i (entries stay None for missing planes).
# - Forward slash: works on every OS and avoids the invalid "\*" escape.
# - sorted(): glob order is filesystem-dependent. Assumes zero-padded folder
#   names, so lexicographic order equals plane order — TODO confirm.
plane_list = sorted(fish.glob("suite2p/*00*"))
if len(plane_list) > num_planes:
    raise ValueError(
        f"found {len(plane_list)} suite2p plane folders, but the recording has {num_planes} planes"
    )

reg_list = [None] * num_planes
for i_plane, plane_dir in enumerate(plane_list):
    reg_list[i_plane] = fl.load(plane_dir / 'sensory_regressors.h5')['regressors_conv']

In [None]:
# NOTE(review): this cell repeats the regressor loading of the previous cell;
# re-running it is harmless, and it can be deleted.
# Load the convolved sensory regressors of each imaging plane into reg_list,
# so that reg_list[i] belongs to plane i (entries stay None for missing planes).
# - Forward slash: works on every OS and avoids the invalid "\*" escape.
# - sorted(): glob order is filesystem-dependent. Assumes zero-padded folder
#   names, so lexicographic order equals plane order — TODO confirm.
plane_list = sorted(fish.glob("suite2p/*00*"))
if len(plane_list) > num_planes:
    raise ValueError(
        f"found {len(plane_list)} suite2p plane folders, but the recording has {num_planes} planes"
    )

reg_list = [None] * num_planes
for i_plane, plane_dir in enumerate(plane_list):
    reg_list[i_plane] = fl.load(plane_dir / 'sensory_regressors.h5')['regressors_conv']

In [None]:
# Sanity check: movie shape and stacked regressor shape, one per line.
print(np.shape(aligned), np.shape(reg_list), sep="\n")

In [None]:
# dF/F movie; indexed below as stack[t, plane, :, :].
stack = SplitDataset(Path(fish, "dff2"))

In [None]:
# Directional tuning (amplitude, angle) for every plane of the dF/F stack.
# The plane index is printed as a progress indicator.
plane_tuning = []
for i in range(stack.shape[1]):
    print(i)
    img = stack[:, i, :, :]
    plane_tuning.append(get_tuning_map(img, reg_list[i].T, n_dirs=8))

amps = [plane_amp for plane_amp, _ in plane_tuning]
angles = [plane_angle for _, plane_angle in plane_tuning]
# One row per plane; each cell holds a whole 2-D map.
df = pd.DataFrame(plane_tuning, columns=["amp", "angle"])

In [None]:
# Convert each plane's amplitude/angle maps into an RGB image.
# NaNs are zeroed first. The default amp_percentile of color_stack is 80.
pctl = 90

color_maps = np.array(
    [
        color_stack(
            np.nan_to_num(df.loc[i_plane, "amp"]),
            np.nan_to_num(df.loc[i_plane, "angle"]),
            amp_percentile=pctl,
        )
        for i_plane in range(stack.shape[1])
    ]
)

In [None]:
# Cache the RGB tuning maps next to the fish data, tagged with the percentile used.
fl.save(fish / f"tuning_map_new_dff_{pctl}.h5", color_maps)

In [None]:
# Reload cached tuning maps (must match the percentile used when saving).
pctl = 90
color_maps = fl.load(fish / f"tuning_map_new_dff_{pctl}.h5")

In [None]:
#viewer = napari.Viewer()
#viewer.add_image(color_maps)

In [None]:
# Show the RGB tuning map of every plane in a grid with 4 columns.
# The grid stays 4x4 (the original layout) for up to 16 planes. With more
# planes it gains rows instead of failing with an IndexError.
num_col = 4
num_row = max(4, -(-len(color_maps) // num_col))  # ceiling division
fig, axs = plt.subplots(
    num_row, num_col, figsize=(10, 10 * num_row / 4), sharey=True, sharex=True
)

# Hide frames/ticks on every panel, including unused ones when planes < panels.
for ax in axs.flat:
    ax.axis('off')

# Iterate over the maps actually computed (one per plane of the dF/F stack).
for i_plane, (ax, plane_rgb) in enumerate(zip(axs.flat, color_maps)):
    # k=3: three 90° counter-clockwise turns (= 90° clockwise) for display.
    ax.imshow(np.rot90(plane_rgb, 3))
    ax.set_title(f"plane {i_plane}", fontsize=8)

In [None]:
# Export the tuning grid as vector (PDF) and raster (JPG) at 300 dpi.
for ext in ("pdf", "jpg"):
    fig.savefig(fish / f"tuning_curve new dff.{ext}", dpi=300)