In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from fimpylab.core.twop_experiment import TwoPExperiment
from bouter import Experiment
from fimpy.pipeline.general import calc_f0, dff
from motions.utilities import stim_vel_dir_dataframe, quantize_directions
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import napari
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
# make sensory regressors. requires old bouter stimulus_param_log.
def make_sensory_regressors(exp, n_dirs=8, upsampling=5, sampling=1/2):
    """Build one convolved motion regressor per stimulus direction bin.

    Requires the old bouter ``stimulus_param_log`` (read through
    ``stim_vel_dir_dataframe``).

    Parameters
    ----------
    exp : bouter.Experiment
        Experiment whose stimulus log is turned into regressors.
    n_dirs : int
        Number of direction bins; one regressor column per bin.
    upsampling : int
        Temporal upsampling factor used for the convolution; the result is
        decimated back by the same factor.
    sampling : float
        Output sample period, in the time units of the stimulus log ``t``
        column (presumably seconds). Rows are spaced ``sampling`` apart,
        starting at t=0.

    Returns
    -------
    pandas.DataFrame
        Columns ``motion_0`` ... ``motion_{n_dirs-1}``, one row per output sample.
    """
    stim = stim_vel_dir_dataframe(exp)
    # Pass n_dirs explicitly so the bin indices agree with range(n_dirs) below.
    _, dir_bins = quantize_directions(stim.theta, n_dirs)
    # Boxcar indicators: stimulus moving (vel > 0.1) in direction bin i_dir.
    ind_regs = np.zeros((n_dirs, len(stim)))
    for i_dir in range(n_dirs):
        ind_regs[i_dir, :] = (np.abs(dir_bins - i_dir) < 0.1) & (stim.vel > 0.1)

    dt_upsampled = sampling / upsampling
    t_imaging_up = np.arange(0, stim.t.values[-1], dt_upsampled)
    reg_up = interp1d(stim.t.values, ind_regs, axis=1, fill_value="extrapolate")(
        t_imaging_up
    )

    # Exponential decay kernel with a 1.5 time-unit half-life
    # (exp(-t / (1.5 / ln 2)) == 2 ** (-t / 1.5)), spanning the whole
    # upsampled time axis and normalized to unit sum.
    u_steps = t_imaging_up.shape[0]
    u_time = np.arange(u_steps) * dt_upsampled
    decay = np.exp(-u_time / (1.5 / np.log(2)))
    kernel = decay / np.sum(decay)

    # Causal convolution: keep the first u_steps samples of the full output.
    convolved = convolve2d(reg_up, kernel[None, :])[:, 0:u_steps]
    # Decimate back to one sample every `sampling`.
    reg_sensory = convolved[:, ::upsampling]

    return pd.DataFrame(reg_sensory.T, columns=[f"motion_{i}" for i in range(n_dirs)])

In [None]:
# find the frames to calculate the baseline.
def no_regressor_frames(regressors, threshold=0.01):
    """Return the row indices where every regressor column is below ``threshold``."""
    quiet_rows = (regressors.values < threshold).all(axis=1)
    return np.flatnonzero(quiet_rows)

# calculate the baseline, plane-wise
# NOTE: this redefines (shadows) the `calc_f0` imported from fimpy.pipeline.general.
def calc_f0(stack, frames):
    """Mean image over the selected frames of a (time, y, x) stack.

    Frames are read one at a time, so ``stack`` only needs to support
    per-frame indexing. The sum is accumulated in a separate float64 buffer.
    This avoids integer overflow for integer-typed imaging data and never
    writes into ``stack``.

    Parameters
    ----------
    stack : array-like
        Stack indexed as ``stack[t, :, :]``.
    frames : sequence of int
        Frame indices to average.

    Returns
    -------
    numpy.ndarray
        float64 image of shape ``stack.shape[1:]``.

    Raises
    ------
    ValueError
        If ``frames`` is empty (no baseline frames available).
    """
    if len(frames) == 0:
        raise ValueError("calc_f0: no frames given to compute the baseline from")
    fr_sum = np.zeros(stack.shape[1:], dtype=np.float64)
    for i_frame in frames:
        fr_sum += stack[int(i_frame), :, :]
    return fr_sum / len(frames)

In [None]:
# calculate directional tuning from dF/F traces, px-wise
def get_tuning_map(img, sens_regs, n_dirs=8):
    """Pixel-wise direction tuning from projections onto the sensory regressors.

    Parameters
    ----------
    img : numpy.ndarray
        dF/F stack of shape (time, y, x); y and x need not be equal.
    sens_regs : pandas.DataFrame
        Regressors, one row per imaging frame and ``n_dirs`` columns.
        If the regressors and ``img`` differ in length, only their common
        leading frames are used.
    n_dirs : int
        Number of direction bins; must equal the number of regressor columns.

    Returns
    -------
    amp, angle : numpy.ndarray
        Arrays of shape (y, x). They are the length and the angle (radians) of
        the sum of the direction unit vectors, each weighted by the pixel's
        projection on that direction's regressor.
    """
    n_t = min(sens_regs.shape[0], img.shape[0])
    traces = img[:n_t].reshape(n_t, -1)
    # (n_dirs, time) @ (time, pixels) -> projection of each pixel on each regressor
    reg = sens_regs.values[:n_t].T @ traces
    # Restore the plane shape (works for non-square planes).
    reg = reg.reshape((reg.shape[0],) + img.shape[1:])

    # Tuning vector: direction unit vectors weighted by the projections.
    bin_centers, _ = quantize_directions([0], n_dirs)
    vectors = np.stack([np.cos(bin_centers), np.sin(bin_centers)], 0)
    reg_vectors = np.reshape(
        vectors @ np.reshape(reg, (n_dirs, -1)),
        (2,) + reg.shape[1:],
    )
    angle = np.arctan2(reg_vectors[1], reg_vectors[0])
    amp = np.sqrt(np.sum(reg_vectors ** 2, 0))

    return amp, angle

In [None]:
# make a color map

def JCh_to_RGB255(x):
    """Convert colorspacious "JCh" colors to 8-bit sRGB, clipping out-of-gamut values."""
    rgb_unit = colorspacious.cspace_convert(x, "JCh", "sRGB1")
    rgb_unit = np.clip(rgb_unit, 0, 1)
    scaled = rgb_unit * 255
    return scaled.astype(np.uint8)

def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
        max_amp=None
    ):
    """Map tuning amplitude/angle to an RGB image.

    Hue encodes ``angle`` (radians, offset by ``hueshift``). Chroma and
    lightness encode ``amp``, normalized to [0, 1] by ``max_amp``, or, if that
    is None, by the ``amp_percentile``-th percentile of ``amp``.

    Parameters
    ----------
    amp, angle : numpy.ndarray
        Same-shaped arrays (any rank); ``angle`` in radians.
    hueshift : float
        Offset (radians) added to ``angle`` before conversion to hue degrees.
    amp_percentile : float
        Percentile of ``amp`` used as full scale when ``max_amp`` is None.
    maxsat : float
        Chroma at full-scale amplitude.
    lightness_min : float
        Lightness J at zero amplitude. Despite the name, with the default
        negative ``lightness_delta`` this is the brightest value.
    lightness_delta : float
        Change in lightness from zero to full-scale amplitude.
    max_amp : float, optional
        Explicit full-scale amplitude (e.g. one value shared by all planes).

    Returns
    -------
    numpy.ndarray
        uint8 RGB array of shape ``amp.shape + (3,)``.
    """
    if max_amp is None:
        maxamp = np.percentile(amp, amp_percentile)
    else:
        maxamp = max_amp

    # A flat plane (e.g. all zeros) gives a non-positive or NaN scale; dividing
    # by it would produce NaN colors, so treat the plane as zero amplitude.
    if maxamp > 0:
        amp_norm = np.clip(amp / maxamp, 0, 1)
    else:
        amp_norm = np.zeros(amp.shape)

    output_lch = np.empty(amp.shape + (3,))
    output_lch[..., 0] = lightness_min + amp_norm * lightness_delta
    output_lch[..., 1] = amp_norm * maxsat
    output_lch[..., 2] = (angle + hueshift) * 180 / np.pi

    return JCh_to_RGB255(output_lch)

In [None]:
master = Path(r"\\Funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\2p\ipn\itpr1b - fixed")
# Glob order is filesystem-dependent; sort for a reproducible choice of fish.
fishes = sorted(master.glob("*_f*"))
fish = fishes[0]

aligned = SplitDataset(fish / "aligned")
behavior_path = fish / "behavior"
# A sorted list, not a one-shot generator: it is iterated more than once, and
# reg_list[i] is paired with imaging plane i.
# NOTE(review): assumes the .json names sort in plane order — TODO confirm.
# If there are more logs than planes, truncate with exp_list[:aligned.shape[1]].
exp_list = sorted(behavior_path.glob("*.json"))

sampling = 1/2  # imaging frame period
# One timestamp per frame, spaced exactly `sampling` apart.
time = np.arange(aligned.shape[0]) * sampling

In [None]:
# make a list of sensory regressors for each plane (one behavior log per plane),
# sampled with the same frame period as the imaging time axis
reg_list = [
    make_sensory_regressors(Experiment(exp), sampling=sampling)
    for exp in sorted(exp_list)
]

In [None]:
# Sanity check: every imaging plane needs its own regressor table.
n_planes = aligned.shape[1]
print(len(reg_list))
if reg_list:
    print(np.shape(reg_list[0]))
if len(reg_list) < n_planes:
    raise ValueError(
        f"only {len(reg_list)} regressor tables for {n_planes} imaging planes"
    )
n_planes

In [None]:
# calculate the baseline image for each plane

frame_list = [no_regressor_frames(reg) for reg in reg_list]
# aligned is indexed as (time, plane, y, x)
_, n_planes, n_y, n_x = aligned.shape
# NaN (instead of uninitialized memory) marks planes whose baseline failed.
f0_stack = np.full((n_planes, n_y, n_x), np.nan)
for i, frames in enumerate(frame_list[:n_planes]):
    try:
        f0_stack[i, :, :] = calc_f0(aligned[:, i, :, :], frames)
    except Exception as err:
        # Best effort: report which plane failed and why, then keep going.
        print(f"plane {i}: baseline not computed ({type(err).__name__}: {err})")

# will create a dff split-dataset folder
stack = dff(aligned, f0_stack)

In [None]:
print(aligned.shape)  # (time, plane, y, x)

In [None]:
# calculate tuning
AMP_PERCENTILE = 80
n_planes = aligned.shape[1]
all_amp_percentile = np.zeros(n_planes)
amps = []
angles = []
for i in range(n_planes):
    img = stack[:, i, :, :]
    amp, angle = get_tuning_map(img, reg_list[i])
    amps.append(amp)
    angles.append(angle)
    # amp can contain NaN pixels (the colour-map step applies np.nan_to_num),
    # and np.percentile would return NaN for the whole plane.
    all_amp_percentile[i] = np.nanpercentile(amp, AMP_PERCENTILE)

df = pd.DataFrame(list(zip(amps, angles)), columns=["amp", "angle"])
# Largest per-plane percentile: a candidate shared scale for color_stack(max_amp=...).
max_amp = np.nanmax(all_amp_percentile)

In [None]:
# To cache the tuning results: fl.save(fish / "tuning.h5", df)
# Global full-scale amplitude: 80th percentile over all pixels of all planes,
# ignoring NaN pixels.
max_amp = np.nanpercentile(amps, 80)
max_amp

In [None]:
# make a color map from the amplitude/angle

pctl = 90  # per-plane full-scale percentile (color_stack's default is 80)

# One RGB tuning image per plane, each normalized by its own percentile.
color_maps = np.array([
    color_stack(
        np.nan_to_num(df.loc[plane, "amp"]),
        np.nan_to_num(df.loc[plane, "angle"]),
        amp_percentile=pctl,
    )
    for plane in range(stack.shape[1])
])

In [None]:
# Save the RGB tuning maps, tagged with the percentile used for scaling.
fl.save(fish / f"tuning_map_{pctl}.h5", color_maps)

In [None]:
#with napari.gui_qt():
#    v = napari.view_image(color_maps)

In [None]:
# Grid of all planes' tuning maps, each rotated 90° clockwise for display.
n_planes = color_maps.shape[0]
n_col = 7
n_row = max(1, int(np.ceil(n_planes / n_col)))
# squeeze=False keeps `ax` 2-D even for a single row.
fig, ax = plt.subplots(n_row, n_col, figsize=(12, 12), squeeze=False)
for i, panel in enumerate(ax.flat):
    panel.axis('off')
    if i < n_planes:
        # RGB uint8 image: vmin/vmax would be ignored by imshow.
        panel.imshow(np.rot90(color_maps[i], k=1, axes=(1, 0)))
file_name = "tuning_plot_all_planes_210901.jpg"
# Save before showing so the saved file never depends on the backend closing the figure.
fig.savefig(str(fish / file_name), dpi=300)
plt.show()

### Check the regressors produced by `make_sensory_regressors`

In [None]:
# For the first behavior log, plot the stimulus, the raw direction
# indicators, and the regressors the pipeline actually uses.
exp = Experiment(sorted(behavior_path.glob("*.json"))[0])
n_dirs = 8
upsampling = 5
# `sampling` is deliberately not redefined here: reuse the pipeline's frame
# period so this check reflects the regressors used above.

stim = stim_vel_dir_dataframe(exp)
bin_centres, dir_bins = quantize_directions(stim.theta, n_dirs)
ind_regs = np.zeros((n_dirs, len(stim)))
for i_dir in range(n_dirs):
    ind_regs[i_dir, :] = (np.abs(dir_bins - i_dir) < 0.1) & (stim.vel > 0.1)

# Regressors from the real function, not a pasted copy of its body.
sens_regs = make_sensory_regressors(
    exp, n_dirs=n_dirs, upsampling=upsampling, sampling=sampling
)

fig, ax = plt.subplots(3, 1, figsize=(8, 4), constrained_layout=True)
ax[0].plot(stim["t"], stim["theta"])
ax[0].set(xlabel="t", ylabel="theta")
for i in range(n_dirs):
    ax[1].plot(ind_regs[i, :])
    ax[2].plot(sens_regs.values.T[i, :])
ax[1].set(xlabel="stimulus log sample", ylabel="direction indicator")
ax[2].set(xlabel="imaging frame", ylabel="regressor")