In [None]:
%matplotlib widget

In [None]:
from glob import glob
import numpy as np
import pandas as pd
import flammkuchen as fl
from split_dataset import SplitDataset
from bouter import Experiment
from fimpy.pipeline.general import calc_f0, dff
from motions.utilities import stim_vel_dir_dataframe, quantize_directions
from scipy.interpolate import interp1d 
from scipy.signal import convolve2d
import colorspacious
import napari
import matplotlib.pyplot as plt
import json

from fimpylab.core.twop_experiment import TwoPExperiment
from skimage.filters import threshold_otsu

from pathlib import Path
import tifffile as tiff

from general_analysis.helper_functions_imaging.general_imaging import normalize_traces

In [None]:
# make sensory regressors. requires old bouter stimulus_param_log.
def make_sensory_regressors(exp, n_dirs=8, upsampling=5, sampling=1/3):
    """Build convolved per-direction motion regressors for one experiment.

    For each of ``n_dirs`` direction bins, an indicator trace is 1 where the
    bin index returned by ``quantize_directions`` equals that bin and
    ``stim.vel > 0.1``. The indicators are linearly interpolated onto a time
    grid ``upsampling`` times finer than ``sampling``, convolved causally with
    a unit-sum exponential decay (half-life 1.5 time units, presumably
    seconds), and decimated back to one sample per ``sampling`` interval.

    Parameters
    ----------
    exp : experiment object
        Passed to ``stim_vel_dir_dataframe``; per the note above this cell it
        requires the old bouter stimulus_param_log.
    n_dirs : int
        Number of direction bins; forwarded to ``quantize_directions`` so the
        bin indices agree with the number of regressor rows.
    upsampling : int
        Temporal upsampling factor used for the convolution.
    sampling : float
        Output sampling interval (presumably 1 / imaging frame rate).

    Returns
    -------
    pandas.DataFrame
        One row per output sample, columns ``motion_0`` ... ``motion_{n_dirs-1}``.
    """
    stim = stim_vel_dir_dataframe(exp)
    # Pass n_dirs explicitly (as get_tuning_map does); otherwise the helper's
    # own default bin count is used regardless of n_dirs.
    _, dir_bins = quantize_directions(stim.theta, n_dirs)
    moving = stim.vel > 0.1
    ind_regs = np.zeros((n_dirs, len(stim)))
    for i_dir in range(n_dirs):
        # tolerance comparison in case bin indices are stored as floats
        ind_regs[i_dir, :] = (np.abs(dir_bins - i_dir) < 0.1) & moving

    dt_upsampled = sampling / upsampling
    t_imaging_up = np.arange(0, stim.t.values[-1], dt_upsampled)
    reg_up = interp1d(stim.t.values, ind_regs, axis=1, fill_value="extrapolate")(
        t_imaging_up
    )

    # Exponential decay kernel with a 1.5 s half-life
    # (exp(-t * ln2 / 1.5) == 0.5 ** (t / 1.5)), spanning the whole upsampled
    # time axis and normalized to unit sum.
    u_steps = t_imaging_up.shape[0]
    u_time = np.arange(u_steps) * dt_upsampled
    decay = np.exp(-u_time / (1.5 / np.log(2)))
    kernel = decay / np.sum(decay)

    # Full convolution truncated to the first u_steps samples == causal filter.
    convolved = convolve2d(reg_up, kernel[None, :])[:, 0:u_steps]
    reg_sensory = convolved[:, ::upsampling]

    return pd.DataFrame(reg_sensory.T, columns=[f"motion_{i}" for i in range(n_dirs)])

In [None]:
# calculate directional tuning from dF/F traces, px-wise
def get_tuning_map(traces, sens_regs, n_dirs=8):
    """Estimate directional tuning by projecting traces onto sensory regressors.

    Every column of ``traces`` is dotted (over the first ``sens_regs.shape[0]``
    rows) with each regressor column. The per-direction scores are then
    summed as 2-D vectors pointing at the bin-centre angles returned by
    ``quantize_directions``; the resulting vector's length and angle are the
    tuning amplitude and preferred direction.

    Parameters
    ----------
    traces : np.ndarray
        Presumably (time, units); needs at least ``sens_regs.shape[0]`` rows.
    sens_regs : np.ndarray
        (time, n_regressors) matrix; its rows are aligned with ``traces`` rows.
    n_dirs : int
        Number of direction bins passed to ``quantize_directions``.

    Returns
    -------
    amp : np.ndarray
        Length of the summed tuning vector per unit.
    angle : np.ndarray
        Angle of the summed tuning vector per unit (radians, ``np.arctan2``).
    """
    n_frames = sens_regs.shape[0]
    # (n_regressors, time) @ (time, units) -> (n_regressors, units)
    scores = sens_regs.T @ traces[:n_frames, :]

    centres, _ = quantize_directions([0], n_dirs)
    unit_vectors = np.stack([np.cos(centres), np.sin(centres)], axis=0)
    summed = unit_vectors @ scores

    x_part = summed[0]
    y_part = summed[1]
    angle = np.arctan2(y_part, x_part)
    amp = np.sqrt((summed ** 2).sum(axis=0))
    return amp, angle

In [None]:
# make a color map

def JCh_to_RGB255(x):
    """Convert colorspacious "JCh" coordinates to 8-bit sRGB.

    The input is converted to "sRGB1" (floats nominally in [0, 1]), clipped to
    that range, scaled by 255 and truncated to ``np.uint8``.
    """
    srgb_unit = colorspacious.cspace_convert(x, "JCh", "sRGB1")
    srgb_unit = np.clip(srgb_unit, 0, 1)
    scaled = srgb_unit * 255
    return scaled.astype(np.uint8)

def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
    ):
    """Map tuning amplitude/angle pairs to 8-bit RGB colours.

    Amplitudes are divided by their ``amp_percentile``-th percentile and
    clipped to [0, 1] (call this ``a``). Lightness is
    ``lightness_min + a * lightness_delta`` (darker for stronger tuning with
    the defaults), chroma is ``a * maxsat``, and hue is
    ``angle + hueshift`` converted from radians to degrees.

    Parameters
    ----------
    amp : array-like
        1-D tuning amplitudes.
    angle : array-like
        Tuning angles in radians, same length as ``amp``.
    hueshift : float
        Offset added to ``angle`` (radians) before conversion to hue.
    amp_percentile : float
        Percentile of ``amp`` that maps to full saturation/lightness change.
    maxsat : float
        Chroma at normalized amplitude 1.
    lightness_min, lightness_delta : float
        Lightness at amplitude 0, and its change at normalized amplitude 1.

    Returns
    -------
    np.ndarray
        (len(amp), 3) uint8 RGB array.
    """
    amp = np.asarray(amp, dtype=float)
    if amp.shape[0] == 0:
        # np.percentile raises on empty input; nothing to colour.
        return np.zeros((0, 3), dtype=np.uint8)

    maxamp = np.percentile(amp, amp_percentile)
    # An all-zero (or NaN) percentile would turn every colour into NaN/inf;
    # treat that case as "no tuning" instead.
    if maxamp > 0:
        amp_norm = np.clip(amp / maxamp, 0, 1)
    else:
        amp_norm = np.zeros_like(amp)

    output_lch = np.zeros((amp.shape[0], 3))
    output_lch[:, 0] = lightness_min + amp_norm * lightness_delta
    output_lch[:, 1] = amp_norm * maxsat
    output_lch[:, 2] = (np.asarray(angle) + hueshift) * 180 / np.pi

    return JCh_to_RGB255(output_lch)

In [None]:
# Root data folder; one sub-folder per fish matching "*_f*".
# NOTE(review): absolute network-drive path — consider a configurable DATA_DIR.
master = Path(r"Z:\Hagar and Ot\e0075\ipn h2b")
FISH_INDEX = 1  # which fish folder (in sorted order) to analyse

# Path.glob yields entries in filesystem-dependent order; sort so that
# FISH_INDEX selects the same fish on every machine.
fish_list = sorted(master.glob("*_f*"))
if not fish_list:
    raise FileNotFoundError(f"No fish folders matching '*_f*' in {master}")
fish = fish_list[FISH_INDEX]
print(fish)

In [None]:
# Registered anatomy stack; transposed before use.
# NOTE(review): confirm the resulting axis order matches the plotting cell,
# which averages over axis 0.
anatomy = fl.load(fish / 'registration' / 'ref_mapped.h5').T
fish_path = fish / 'suite2p'
# Per-plane suite2p folders (presumably one per imaging plane). Sorted so the
# per-plane arrays are concatenated in a deterministic order, which must match
# the row order of ref_roi_coords_mapped.h5 used later — TODO confirm.
paths = sorted(fish_path.glob("*00*"))

In [None]:
img_exp = TwoPExperiment(fish)

# Frame rate from the microscope scanning config; `sampling` is the frame interval.
scan_cfg = img_exp['imaging']['microscope_config']['scanning']
fs = scan_cfg['framerate']
sampling = 1 / fs

res = img_exp.resolution
z_res, x_res, y_res = res

# NOTE(review): overwritten with 0.15 in the aggregation cell below.
thresh = 0.3

In [None]:
ind_count = 0

# ROI coordinates (reference space) for all cells of this fish. Their row order
# is assumed to match the plane-by-plane concatenation below — TODO confirm.
coords = fl.load(fish / 'registration' / 'ref_roi_coords_mapped.h5')

if not paths:
    raise FileNotFoundError(f"No plane folders matching '*00*' in {fish_path}")

# Collect per-plane results and concatenate once at the end: calling
# np.append inside the loop copies the accumulated array on every iteration.
reliability_per_plane = []
amp_per_plane = []
angle_per_plane = []
for path in paths:
    # reliability index of every cell in this plane
    rel_index = fl.load(path / 'reliability_index_arr.h5')['reliability_arr_combined']

    # traces stored under "/detr" (presumably detrended)
    traces = fl.load(path / "filtered_traces.h5", "/detr")

    # Use sensory_regressors.h5 when present, otherwise the *_cells variant.
    # An explicit existence check replaces a bare `except:` that also hid
    # unrelated errors (corrupt files, KeyboardInterrupt, ...).
    reg_file = path / "sensory_regressors.h5"
    if not reg_file.exists():
        reg_file = path / "sensory_regressors_cells.h5"
    sensory_regressors = fl.load(reg_file, "/regressors_conv")

    # directional tuning (amplitude, preferred angle) per cell
    amp, angle = get_tuning_map(traces, sensory_regressors.T)

    # ravel to mirror np.append, which flattens its inputs
    reliability_per_plane.append(np.ravel(rel_index))
    amp_per_plane.append(np.ravel(amp))
    angle_per_plane.append(np.ravel(angle))

all_reliability = np.concatenate(reliability_per_plane)
all_amp = np.concatenate(amp_per_plane)
all_angle = np.concatenate(angle_per_plane)

# Otsu's threshold is only reported; the manual cutoff below is what is used.
otsu_thresh = threshold_otsu(all_reliability)
print("Reliability threshold: ", otsu_thresh)
thresh = 0.15
# NOTE(review): the same cutoff is applied both to the (unnormalized) tuning
# amplitude and to the reliability index — confirm both uses are intended.
amp_cutoff = thresh

colors = color_stack(all_amp, all_angle)
weak = all_amp < amp_cutoff

amp_thresh = np.copy(all_amp)
amp_thresh[weak] = 0

# Weakly tuned cells are drawn in uniform light grey (220, 220, 220).
colors_thresh = np.copy(colors)
colors_thresh[weak] = 220

# Keep only reliable cells for plotting.
selected_vis = np.where(all_reliability > thresh)[0]
coords_vis = coords[selected_vis]
colors_vis = colors[selected_vis]  # previously indexed `coords` by mistake
colors_thresh = colors_thresh[selected_vis]
amp_vis = all_amp[selected_vis]

# Ascending amplitude, so strongly tuned cells are drawn on top.
mp_ind = np.argsort(amp_vis)

ind_count += 1

In [None]:
# NaN-ignoring mean of the anatomy along axis 0 (rotated with np.rot90, k=3),
# overlaid with the reliable cells coloured by tuning, plotted in ascending
# amplitude order.
fig2, axs2 = plt.subplots(1, 1, figsize=(3, 3))
for side in ('right', 'top'):
    axs2.spines[side].set_visible(False)

anatomy_proj = np.rot90(np.nanmean(anatomy, axis=0), 3)
axs2.imshow(anatomy_proj, cmap='gray_r')

ordered_coords = coords_vis[mp_ind]
axs2.scatter(
    ordered_coords[:, 1],
    ordered_coords[:, 0],
    c=colors_thresh[mp_ind] / 255,
    s=10,
)
axs2.invert_yaxis()