In [None]:
%matplotlib widget

from pathlib import Path
import numpy as np
import flammkuchen as fl
import pandas as pd

from fimpylab import LightsheetExperiment

from matplotlib import  pyplot as plt
import seaborn as sns
from tqdm import tqdm
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()
import ipywidgets as widgets

from lotr.utils import zscore
from lotr.pca import pca_and_phase, get_fictive_heading, fictive_heading_and_fit, \
        fit_phase_neurons,qap_sorting_and_phase
from circle_fit import hyper_fit
from lotr import LotrExperiment, A_FISH

In [None]:
# Start from a clean slate of figures.
plt.close("all")

# Root folder; each entry matching "*_f*" is treated as one fish folder.
master_path = Path(r"\\funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\ls_fixed\spont_plus_v13\huc")
# Directory iteration order is not guaranteed, so sort before indexing by
# position: otherwise the same index may point to a different fish after the
# folder content changes or on another machine.
fish_list = sorted(master_path.glob("*_f*"))
path = fish_list[10]

# Traces are read from the "/detr" dataset; the code below indexes them as
# (time, cell). NaNs are replaced by 0 in place.
traces = fl.load(path / "filtered_traces.h5", "/detr")
traces[np.isnan(traces)] = 0

# Per-cell values from the motor regressors table (columns "all_bias_abs" and
# "all_bias_abs_dfdt"). NOTE(review): `.values` may share memory with `reg_df`,
# and these arrays are modified in place further down.
reg_df = fl.load(path / "motor_regressors.h5")
cc_motor = reg_df["all_bias_abs"].values
cc_motor_integr = reg_df["all_bias_abs_dfdt"].values

# Cell coordinates (column 0 is used as plane index below) and anatomy stack.
coords = fl.load(path / "data_from_suite2p_cells.h5", "/coords")
anat = fl.load(path / "data_from_suite2p_cells.h5", "/anatomy_stack")

exp = LotrExperiment(path)
fn = int(exp.fn)
beh_df = exp.behavior_log

# Analysis window: from sample 500 to half of the experiment's points.
t_lims = [500, exp.n_pts // 2]
t_slice = slice(*t_lims)

In [None]:
# Crop limits drawn on every plane and reused by the cropping cell below:
# s1/s2 are the vertical lines (image x), s3/s4 the horizontal lines (image y).
# They are defined before the loop so they exist even if no plane is drawn.
s1 = 240
s2 = 540
s3 = 220
s4 = 400

# NOTE(review): this is the largest value in coords[:, 0]; if plane indices are
# 0-based the last plane is not shown — confirm.
num_planes = int(np.max(coords[:,0]))
print(num_planes)

# Planes are laid out column by column, three per column. Keep the original
# five columns as a minimum, but grow the grid when there are more than 15
# planes instead of failing with an IndexError.
n_rows = 3
n_cols = max(5, -(-num_planes // n_rows))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(14, 8), squeeze=False)

# r and c are read again by the cropping cell (it draws on axs[r, c]); give
# them a value even when the loop body never runs.
r, c = 0, 0
for i in range(num_planes):
    c, r = divmod(i, n_rows)
    axs[r, c].imshow(anat[i], vmax=400, vmin=0)

    axs[r, c].axvline(s1)
    axs[r, c].axvline(s2)

    axs[r, c].axhline(s3)
    axs[r, c].axhline(s4)
    axs[r, c].set_title(str(i))

In [None]:
### Set plane range
# Lower and upper limits compared against coords[:, 0] (the plane column)
# when building the crop mask in the next cell.
s5 = 1
s6 = 19

In [None]:
# Cells lying outside the crop box in x (column 2), y (column 1) or plane
# (column 0).
sel_to_nan = (coords[:, 2] < s1) | (coords[:, 2] > s2) | (coords[:, 1] < s3) | (coords[:, 1] > s4) | (coords[:, 0] < s5) | (coords[:, 0] > s6)
# The kept set is the exact complement of the discarded set. Building both
# masks from strict inequalities left cells sitting exactly on a limit (e.g.
# every cell of plane s5) in neither of them, so `ahb_idx` did not describe
# the cells that survive the cropping done with `sel_to_nan`.
ahb_idx = ~sel_to_nan
# Fancy indexing copies, so this is unaffected by the in-place edits below.
coords_in = coords[ahb_idx]

# Blank out the discarded cells in place: zero traces and coordinates, NaN
# regressor values.
traces[:, sel_to_nan] = 0
coords[sel_to_nan] = 0
cc_motor[sel_to_nan] = np.nan
cc_motor_integr[sel_to_nan] = np.nan

# `axs`, `r` and `c` are left over from the anatomy-plot cell, so the cells are
# overlaid on the last plane that was drawn there. Discarded cells now sit at
# (0, 0) because their coordinates were zeroed above.
axs[r, c].scatter(coords[:, 2], coords[:, 1], c=(0.9,)*3)
axs[r, c].scatter(coords_in[:, 2], coords_in[:, 1], c=(0.2,)*3)

In [None]:
# Highest plane index among the cells kept by the crop.
print(coords_in[:, 0].max())

In [None]:
# Drop the discarded cells with plain boolean indexing. np.delete only treats
# a boolean array as a mask from NumPy 1.19 on; older versions cast it to the
# integer indices 0/1 and would silently remove the wrong entries.
keep_mask = ~sel_to_nan
coords_crop = coords[keep_mask]
traces_crop = traces[:, keep_mask]  # cells are along the second axis
cc_motor_crop = cc_motor[keep_mask]
cc_motor_integr_crop = cc_motor_integr[keep_mask]

In [None]:
# Shape of the cropped traces array, shown as the cell output.
traces_crop.shape

In [None]:
# Bundle the cropped arrays with the mask and the limits that produced them,
# then write everything to a single file in the fish folder.
d = dict(
    traces=traces_crop,
    coords=coords_crop,
    cc_motor=cc_motor_crop,
    cc_motor_integr=cc_motor_integr_crop,
    ahb_idx=ahb_idx,
    x_limits=[s1, s2],
    y_limits=[s3, s4],
    z_limits=[s5, s6],
)
fl.save(path / 'ahb_cropped.h5', d)


In [None]:
# Pairwise correlation between all cells, computed in consecutive windows of
# cc_wnd*fn samples inside t_slice and then averaged (ignoring NaNs).
# NOTE(review): presumably cc_wnd is in seconds and fn is the sampling rate —
# confirm. Cells with constant traces (e.g. zeroed ones) are expected to yield
# NaN correlations.
cc_wnd = 4000
i_array = np.arange(t_slice.start, t_slice.stop, cc_wnd*fn)
cc_mats = np.zeros((traces.shape[1], traces.shape[1], len(i_array)))

for n, i in enumerate(i_array):
    cc_mats[:, :, n] = np.corrcoef(traces[i:i + cc_wnd*fn, :].T)
corr_mat = np.nanmean(cc_mats, 2)
# Most negative average correlation of each cell with any other cell. It does
# not depend on the sliders, so compute it once instead of on every update.
min_corr = np.nanmin(corr_mat, 0)

# 0/1 flag per cell, rewritten in place by the widget callback and read by the
# following cell.
selection_arr = np.zeros(traces.shape[1])

f = plt.figure(figsize=(3, 3))
x = np.arange(-0.2, np.nanmax(cc_motor), 0.05)
s = plt.scatter(cc_motor, cc_motor_integr, s=10, c=selection_arr, vmin=0, vmax=1)

l_plot = plt.plot(x, x*0.2 + 0.15)
l_max = plt.axvline(1)
l_min = plt.axhline(0)

# The callback looks names up at call time, and later cells rebind `s` (and
# could rebind `x`). Keep private handles so the sliders keep working after the
# rest of the notebook has run.
_sel_scatter = s
_line_x = x


@widgets.interact(c=(0.05, 2, 0.05), o=(-0.5, 1, 0.02), mot_max=(0, 1, 0.05),
                 integr_min=(0, 1, 0.02), max_corr=(-1, 0, 0.05))
def update(o=0.3, c=0.2, mot_max=1, integr_min=0, max_corr=-0.7):
    """Recompute the cell selection from the slider values and refresh the plot.

    A cell is selected when it passes both range checks
    (|cc_motor| < mot_max and |cc_motor_integr| > integr_min) and either lies
    above the line cc_motor * c + o, or has an average correlation with some
    other cell below max_corr. Comparisons involving NaN are False, so cells
    with NaN regressor values are never selected.
    """
    _line_x_vals = _line_x
    l_plot[0].set_data(_line_x_vals, _line_x_vals*c + o)

    in_range = (np.abs(cc_motor) < mot_max) & (np.abs(cc_motor_integr) > integr_min)
    above_line = cc_motor_integr > cc_motor*c + o
    anticorrelated = min_corr < max_corr
    selection_arr[:] = in_range & (above_line | anticorrelated)

    # set_xdata / set_ydata expect sequences; recent Matplotlib versions reject
    # scalars. The axvline / axhline artists are two-point lines.
    l_max.set_xdata([mot_max, mot_max])
    l_min.set_ydata([integr_min, integr_min])

    _sel_scatter.set_array(selection_arr)

plt.ylim(-0.15, 0.4)
plt.xlim(-0.3, 1.01)
plt.xlabel("cc. traces - motor regressor")
plt.ylabel("cc. d(traces)/dt - regressor")
sns.despine()

In [None]:
# Indices of the cells currently flagged by the interactive selection.
selected = np.flatnonzero(selection_arr)
print(selected.size)

In [None]:
# PCA with the selected cells as rows (note the transposes): first projecting
# the selected cells themselves, then every cell, onto the same components.
# NOTE(review): assumes the first argument is the data used for fitting and the
# second the data that gets projected — confirm against lotr.pca.
pcaed_t, phase_t, _, _ = pca_and_phase(traces[t_slice, selected].T, traces[t_slice, selected].T)
hf_c = hyper_fit(pcaed_t)  # indexed below as (centre x, centre y, radius, ...)
pcaed_t_all, _, _, _ = pca_and_phase(traces[t_slice, selected].T, traces[t_slice, :].T)

# Flag selected cells whose distance from the origin in the first two
# components exceeds the threshold (used only to colour the scatter).
thr = 35
sel = (pcaed_t[:, 0]**2+pcaed_t[:, 1]**2)**(1/2) > thr

# PCA over time: fitted inside t_slice, projected for the whole recording.
pcaed, phase, _, _ = pca_and_phase(traces[t_slice, selected], traces[:, selected])

# Outline of the fitted circle.
theta = np.linspace(0, 2*np.pi, 100)
x1 = hf_c[2]*np.cos(theta) + hf_c[0]
x2 = hf_c[2]*np.sin(theta) + hf_c[1]

plt.figure(figsize=(7, 3))
plt.scatter(pcaed_t[:, 0], pcaed_t[:, 1], c=sel)
plt.scatter(pcaed_t_all[:, 0], pcaed_t_all[:, 1], edgecolor="k", facecolor="none", lw=0.2)
plt.axis("equal")
plt.plot(x1, x2)

In [None]:
# Signed distance of every cell's projection from the fitted circle
# (positive = outside). hf_c holds centre x, centre y and radius.
radial_offset = np.sqrt((pcaed_t_all[:, 0] - hf_c[0])**2 + (pcaed_t_all[:, 1] - hf_c[1])**2) - hf_c[2]

# Half-width of the band around the circle that is also accepted. It was
# hard-coded as 0, which makes the band test impossible to satisfy (an absolute
# value is never negative), so only cells outside the circle are selected.
# The value is kept at 0 to preserve that result; raise it to also include
# cells lying close to the circle.
ring_tol = 0
new_selection_arr = (np.abs(radial_offset) < ring_tol) | (radial_offset > 0)
selected = np.argwhere(new_selection_arr)[:, 0]

In [None]:
# Selected traces, shifted up by a fixed offset, with the tail signal
# underneath. The behaviour time column is multiplied by fn to put it on the
# same x axis as the trace samples.
trace_offset = 4
fig_tr, ax_tr = plt.subplots(figsize=(7, 2.5))
ax_tr.plot(traces[:, selected] + trace_offset)
print(len(selected))
ax_tr.plot(beh_df["t"]*fn, beh_df["tail_sum"])
plt.show()

In [None]:
# PCA over time for the current selection: fitted inside t_slice, projected
# for the whole recording.
pcaed, phase, _, _ = pca_and_phase(traces[t_slice, selected], traces[:, selected])
# Second half of the recording.
mot_t_slice = slice(traces.shape[0] // 2, traces.shape[0])
f, axs = plt.subplots(1, 3, figsize=(9., 4.), sharex=True, sharey=True)

# One panel per time window: grey trajectory with phase-coloured samples.
# The loop variable is deliberately not called `s`: that name holds the scatter
# handle used by the selection widget, and overwriting it here breaks the
# sliders once this cell has run.
# NOTE(review): the first and third windows are identical — confirm whether the
# last panel was meant to show a different period.
for ax, time_sl in zip(axs, [t_slice, mot_t_slice, t_slice]):
    ax.plot(pcaed[time_sl, 0], pcaed[time_sl, 1],
            c=(0.6,)*3, lw=0.5, zorder=-100)
    ax.scatter(pcaed[time_sl, 0], pcaed[time_sl, 1],
               c=phase[time_sl], lw=0.5, s=5, cmap="twilight")
sns.despine()

In [None]:
%%time
# Two orderings of the selected cells are computed here:
#  - `perm` (with `com_phase`) from qap_sorting_and_phase on the full traces;
#  - `perm_pca`, the argsort of the per-cell values returned by
#    fit_phase_neurons on the t_slice window.
# `phase` comes from the PCA cell above and is indexed with the same t_slice.
# NOTE(review): `os` is imported here rather than in the top import cell.
import os
perm, com_phase = qap_sorting_and_phase(traces[:, selected], t_lims=t_lims)

phases_neuron, _ = fit_phase_neurons(traces[t_slice, selected], phase[t_slice])
perm_pca = np.argsort(phases_neuron)
# Runs the shell command `say` as a completion signal.
# NOTE(review): presumably the macOS text-to-speech tool; on other systems the
# command is likely missing and this only yields a non-zero status — confirm.
os.system('say "Fit completed"')

In [None]:
# Colour range (+/- l) for the trace rasters.
l = 2

# Correlation between the selected cells inside t_slice; computed once and
# reordered for each of the two sortings.
corr_sel = np.corrcoef(traces[t_slice, selected].T)
matrix_style = dict(vmax=1, vmin=-1, cmap="RdBu_r", aspect="auto")
raster_style = dict(cmap="gray_r", interpolation="none", aspect="auto", vmin=-l, vmax=l)

f, axs = plt.subplots(2,2, figsize=(7, 7), sharey=True)

# Top row: cells ordered by `perm`.
axs[0, 0].imshow(corr_sel[perm][:, perm], **matrix_style)
axs[0, 1].imshow(traces[:, selected[perm]].T, **raster_style)

# Bottom row: cells ordered by `perm_pca`.
axs[1, 0].imshow(corr_sel[perm_pca][:, perm_pca], **matrix_style)
axs[1,1].imshow(traces[:, selected[perm_pca]].T, **raster_style)


In [None]:
# Manually discard cells, given as row positions in the `perm`-sorted plots.
rm_from_selected = np.array([84])

# Remove the corresponding entries by position. The previous approach marked
# them with -1 and kept `selected > 0`, which also threw away cell index 0
# whenever it was part of the selection, and modified `selected` in place.
# NOTE: `perm`, `perm_pca` and `phases_neuron` still refer to the selection as
# it was before this removal; re-run the sorting cell before using them again,
# and do not run this cell twice against the same `perm`.
selected = np.delete(selected, perm[rm_from_selected])

In [None]:
# Unwrap both phase estimates so they can be compared with a cumulative
# trajectory.
unwrapped_phase = np.unwrap(phase)
unwrapped_com_phase = np.unwrap(com_phase)

# NOTE(review): `df` is not defined in any of the cells above, so this line
# raises NameError on a fresh kernel. It presumably should be a behaviour/bout
# table from `exp` — confirm which one fictive_heading_and_fit expects.
traj, params = fictive_heading_and_fit(unwrapped_phase, df, min_bias=0.1)
print(params)

# z-scored unwrapped phases (the com-based one and the trajectory are drawn
# with flipped sign), coloured by the wrapped phase.
plt.figure(figsize=(7, 3))
plt.scatter(np.arange(len(traj[:])), zscore(unwrapped_phase), 
            c=phase[:], cmap="twilight", s=2)
plt.scatter(np.arange(len(traj[:])), -zscore(unwrapped_com_phase), 
            c=com_phase[:], cmap="twilight", s=0.2)
plt.plot(-zscore(traj), c=cols[1])

In [None]:
n_cells = coords.shape[0]

# Cells with a positive plane coordinate (cropped-out cells were set to 0).
s = coords[:, 0] > 0

# Per-cell lookups covering all cells: selection flag, fitted phase value and
# the entry of `perm` (-1 where the cell is not selected).
selection = np.full(n_cells, False)
selection[selected] = True
all_phases = np.zeros(n_cells)
all_phases[selected] = phases_neuron
all_perm = -np.ones(n_cells)
all_perm[selected] = perm

# Selected cells among the displayed ones, and their `perm` entries.
shown_selected = s & selection
shown_perm = all_perm[s][all_perm[s] >= 0]
# Evenly spaced values in [-pi, pi], picked through the argsort of those entries.
rank_colors = np.linspace(-np.pi, np.pi, shown_perm.size + 1)[np.argsort(shown_perm)]

f, axs = plt.subplots(1, 2, figsize=(6, 3))
# Left: coloured by the per-cell fitted phase; right: by the sorting-based value.
for panel, panel_colors in zip(axs, (all_phases[shown_selected], rank_colors)):
    panel.scatter(coords[s, 1], coords[s, 2], c=(0.5,)*3)
    panel.scatter(coords[shown_selected, 1], coords[shown_selected, 2],
                  c=panel_colors, cmap="twilight")
    panel.axis("equal")
axs[0].axis("off")
axs[1].axis("off")

In [None]:
# Store the final array of selected cell indices in the fish folder.
fl.save(path / "selected.h5", selected)