In [None]:
%matplotlib widget 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import flammkuchen as fl

from glob import glob
import pandas as pd
from bouter import Experiment
import colorspacious
import json
from scipy import stats
from quickdisplay import *
from motions.utilities import stim_vel_dir_dataframe, quantize_directions
from skimage.filters import threshold_otsu

from fimpylab.core.twop_experiment import TwoPExperiment

In [None]:
def corr2_coeff(A, B):
    """Row-by-row Pearson correlation between two 2-D arrays.

    Parameters
    ----------
    A, B : 2-D arrays with the same number of columns.

    Returns
    -------
    Array of shape ``(A.shape[0], B.shape[0])`` whose entry ``[i, j]`` is the
    correlation coefficient between row ``i`` of ``A`` and row ``j`` of ``B``.
    """
    # Centre every row on its own mean.
    a_centered = A - A.mean(axis=1, keepdims=True)
    b_centered = B - B.mean(axis=1, keepdims=True)

    # Per-row sums of squared deviations.
    a_sumsq = np.sum(a_centered ** 2, axis=1)
    b_sumsq = np.sum(b_centered ** 2, axis=1)

    # Cross products divided by the product of the row norms.
    cross = a_centered @ b_centered.T
    norms = np.sqrt(np.outer(a_sumsq, b_sumsq))
    return cross / norms

In [None]:
# calculate directional tuning from dF/F traces, px-wise
def get_tuning_map(traces, sens_regs, n_dirs=8):
    """Compute a tuning amplitude and preferred angle for every trace column.

    Parameters
    ----------
    traces : 2-D array; only its first ``sens_regs.shape[0]`` rows are used.
    sens_regs : 2-D array of regressors; its transpose is multiplied with the
        truncated traces.
        NOTE(review): the caller in this notebook passes a transposed array,
        so which axis is time should be confirmed against the stored data.
    n_dirs : number of direction bins requested from ``quantize_directions``.

    Returns
    -------
    amp : Euclidean norm of the 2-D tuning vector, one value per trace column.
    angle : ``arctan2`` of the tuning vector components, one value per column.
    """
    # Project the leading rows of the traces onto each regressor.
    n_rows = sens_regs.shape[0]
    projections = sens_regs.T @ traces[:n_rows, :]

    # Build one (cos, sin) vector per direction bin and sum the projections
    # along them. Presumably the bin centres are angles in radians -- confirm
    # against quantize_directions.
    bin_centers, _ = quantize_directions([0], n_dirs)
    direction_vectors = np.stack([np.cos(bin_centers), np.sin(bin_centers)], 0)
    tuning_xy = direction_vectors @ projections

    x_component, y_component = tuning_xy[0], tuning_xy[1]
    angle = np.arctan2(y_component, x_component)
    amp = np.sqrt(np.sum(tuning_xy ** 2, 0))

    return amp, angle

In [None]:
# make a color map

def JCh_to_RGB255(x):
    """Convert JCh colour values to 8-bit sRGB.

    The input is converted from "JCh" to "sRGB1" with ``colorspacious``,
    clipped to the [0, 1] range, scaled by 255 and cast to ``uint8``.
    """
    srgb_unit = colorspacious.cspace_convert(x, "JCh", "sRGB1")
    srgb_unit = np.clip(srgb_unit, 0, 1)
    return (srgb_unit * 255).astype(np.uint8)

def color_stack(
        amp,
        angle,
        hueshift=2.5,
        amp_percentile=80,
        maxsat=50,
        lightness_min=100,
        lightness_delta=-40,
    ):
    """Encode tuning amplitude and angle as one RGB colour per entry.

    Amplitudes are normalised by their ``amp_percentile``-th percentile and
    clipped to [0, 1]. The normalised amplitude drives both lightness and
    chroma; the angle drives the hue.

    Parameters
    ----------
    amp : 1-D array of amplitudes.
    angle : 1-D array of angles in radians, same length as ``amp``.
    hueshift : offset (radians) added to ``angle`` before conversion to degrees.
    amp_percentile : percentile of ``amp`` used as the normalisation value.
    maxsat : chroma assigned to a normalised amplitude of 1.
    lightness_min : lightness assigned to a normalised amplitude of 0.
    lightness_delta : lightness change between normalised amplitude 0 and 1.

    Returns
    -------
    ``(len(amp), 3)`` uint8 array as produced by ``JCh_to_RGB255``.
    """
    amp = np.asarray(amp)
    angle = np.asarray(angle)

    maxamp = np.percentile(amp, amp_percentile)
    if maxamp == 0:
        # Dividing by zero would give inf (clipped to 1) for positive
        # amplitudes but NaN for zero amplitudes; keep the former and map
        # the latter to 0 so no NaN reaches the colour conversion.
        rel_amp = (amp > 0).astype(float)
    else:
        rel_amp = np.clip(amp / maxamp, 0, 1)

    # Columns are J (lightness), C (chroma), h (hue in degrees).
    output_lch = np.zeros((amp.shape[0], 3))
    output_lch[:, 0] = lightness_min + rel_amp * lightness_delta
    output_lch[:, 1] = rel_amp * maxsat
    output_lch[:, 2] = (angle + hueshift) * 180 / np.pi

    return JCh_to_RGB255(output_lch)

In [None]:
# Root folder that is searched for per-fish sub-folders.
master = Path(r"Z:\Hagar and Ot\e0075\ipn h2b")

# Every entry directly under the root whose name matches "*_f*".
fish_list = [fish_dir for fish_dir in master.glob("*_f*")]
num_fish = len(fish_list)

num_positions = 8

In [None]:
# Position in fish_list of the fish whose coordinates are read from
# "ref_roi_coords_mapped.h5"; every other fish uses
# "mov_coords_transformed.h5".
# NOTE(review): this relies on the order returned by glob -- confirm that
# index 1 is always the intended fish.
ref_fish_index = 1

# Per-fish / per-plane results are gathered in lists and concatenated once at
# the end, instead of growing arrays with np.append inside the loop.
coords_per_fish = []
reliability_per_plane = []
amp_per_plane = []
angle_per_plane = []

for fish, fish_path in enumerate(fish_list):
    exp = TwoPExperiment(fish_path)
    fish_id = exp.fish_id
    print(fish_id)

    ### transformed coords for this fish
    if fish == ref_fish_index:
        coords_file = "ref_roi_coords_mapped.h5"
    else:
        coords_file = "mov_coords_transformed.h5"
    coords_per_fish.append(fl.load(fish_path / "registration" / coords_file))

    #### plane-wise data
    suite2p_path = fish_path / 'suite2p'
    for path in suite2p_path.glob("*00*"):
        ### reliability index of every ROI in this plane
        rel_index = fl.load(path / 'reliability_index_arr.h5')['reliability_arr_combined']

        ### traces used for the regression
        traces = fl.load(path / "filtered_traces.h5", "/detr")

        # Sensory regressors: fall back to the "_cells" file when the first
        # one cannot be loaded. `except Exception` keeps the fallback but no
        # longer swallows KeyboardInterrupt / SystemExit.
        try:
            sensory_regressors = fl.load(path / "sensory_regressors.h5", "/regressors_conv")
        except Exception:
            sensory_regressors = fl.load(path / "sensory_regressors_cells.h5", "/regressors_conv")

        # calculate tuning
        amp, angle = get_tuning_map(traces, sensory_regressors.T)

        # np.append (used previously) flattens its inputs; ravel keeps the
        # concatenated results 1-D in the same way.
        reliability_per_plane.append(np.ravel(rel_index))
        amp_per_plane.append(np.ravel(amp))
        angle_per_plane.append(np.ravel(angle))

if not coords_per_fish or not amp_per_plane:
    raise RuntimeError("No fish folders or suite2p planes found under " + str(master))

all_coords = np.concatenate(coords_per_fish, axis=0)
all_reliability = np.concatenate(reliability_per_plane)
all_amp = np.concatenate(amp_per_plane)
all_angle = np.concatenate(angle_per_plane)
# NOTE(review): later cells index all_coords with indices derived from
# all_reliability, so both are assumed to list the same ROIs in the same
# order -- confirm.

In [None]:
# One RGB colour per ROI, derived from its tuning amplitude and angle.
colors = color_stack(amp=all_amp, angle=all_angle)

In [None]:
# Otsu threshold computed on the pooled reliability values.
thresh = threshold_otsu(all_reliability)
print("Reliability threshold: ", thresh)

colors = color_stack(all_amp, all_angle)

# NOTE(review): the tuning amplitude is compared against the threshold that
# was derived from the reliability values -- confirm this is intended.
low_amp_idx = np.where(all_amp < thresh)[0]

# Amplitudes below the threshold are zeroed.
amp_thresh = np.copy(all_amp)
amp_thresh[low_amp_idx] = 0

# Colours of low-amplitude ROIs are replaced by a uniform light grey.
colors_thresh = np.copy(colors)
colors_thresh[low_amp_idx] = 220

# Keep only the ROIs whose reliability exceeds the threshold.
selected_vis = np.where(all_reliability > thresh)[0]
coords_vis = all_coords[selected_vis]
colors_vis = colors[selected_vis]  # was indexing all_coords (copy-paste slip)
colors_thresh = colors_thresh[selected_vis]
amp_vis = all_amp[selected_vis]

# Ascending-amplitude order of the selected ROIs.
mp_ind = np.argsort(amp_vis)

In [None]:
def _apply_threshold(thresh, amp, reliability, coords, colors, grey_level=220):
    """Grey out low-amplitude ROIs and keep those with reliability > thresh.

    Returns, in this order: the amplitude copy with sub-threshold values
    zeroed, the selected colours (low-amplitude ROIs set to ``grey_level``),
    the selected indices, the selected coordinates, the selected original
    colours, the selected amplitudes, and the ascending-amplitude order of
    the selection.
    """
    # NOTE(review): amplitude is compared against a threshold that is also
    # applied to reliability -- confirm this is intended.
    low_amp_idx = np.where(amp < thresh)[0]

    amp_zeroed = np.copy(amp)
    amp_zeroed[low_amp_idx] = 0

    colors_greyed = np.copy(colors)
    colors_greyed[low_amp_idx] = grey_level

    selected = np.where(reliability > thresh)[0]
    amp_selected = amp[selected]
    return (
        amp_zeroed,
        colors_greyed[selected],
        selected,
        coords[selected],
        colors[selected],  # previously taken from the coordinates by mistake
        amp_selected,
        np.argsort(amp_selected),
    )


fig, ax = plt.subplots(1, 3, figsize=(10, 3))

colors = color_stack(all_amp, all_angle)

# Left panel: every ROI, drawn in ascending-amplitude order so that the
# highest-amplitude ROIs end up on top.
mp_ind_all = np.argsort(all_amp)
ax[0].scatter(all_coords[mp_ind_all,0], all_coords[mp_ind_all,1], c=colors[mp_ind_all]/255, s=2, alpha=0.8)
ax[0].set_title("All ROIs")

# Middle panel: Otsu threshold on reliability; right panel: fixed threshold.
manual_thresh = 0.15
for panel, thresh in zip(ax[1:], (threshold_otsu(all_reliability), manual_thresh)):
    (amp_thresh, colors_thresh, selected_vis,
     coords_vis, colors_vis, amp_vis, mp_ind) = _apply_threshold(
        thresh, all_amp, all_reliability, all_coords, colors
    )
    panel.scatter(coords_vis[mp_ind,0], coords_vis[mp_ind,1], c=colors_thresh[mp_ind]/255, s=2, alpha=0.8)
    panel.set_title("Thresh " + str(thresh))