In [None]:
import os

# Machine-specific locations. Each one can be overridden with an environment
# variable (MINIAN_PATH, CAIMAN_PATH, WIRED_VALENCE_DATA), so the notebook can run
# on another machine without editing this cell. The defaults are the original paths.
minian_path = os.environ.get("MINIAN_PATH", "/home/phild/Documents/sync/project/miniscope/MiniAn/")
caiman_path = os.environ.get("CAIMAN_PATH", "/home/phild/Documents/sync/project/miniscope/CaImAn/")
# Root folder whose immediate sub-directories are walked one per animal in the
# cross-registration cell.
dpath = os.environ.get("WIRED_VALENCE_DATA", "/media/share/hdda/Denise/Wired_Valence/Wired_Valence_Organized_Data/")

In [None]:
# Imports and notebook-wide configuration.
# Reload edited project modules (e.g. minian) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings
# NOTE(review): this silences *all* warnings for the whole session, including
# deprecation/future warnings from xarray/pandas that may point at real problems.
warnings.simplefilter('ignore')
# Make the local MiniAn / CaImAn checkouts (paths from the first cell) importable.
sys.path.append(minian_path)
sys.path.append(caiman_path)
import numpy as np
import xarray as xr
import holoviews as hv
import pandas as pd
# NOTE(review): load_cnm_dataset is not referenced elsewhere in this notebook.
from minian.cross_registration import load_cnm_dataset, get_cnm_list, estimate_shifts, apply_shifts, calculate_centroids, calculate_centroid_distance, calculate_mapping, group_by_session
# NOTE(review): resave_varr and update_meta appear unused in this notebook.
from minian.utilities import resave_varr, update_meta
# NOTE(review): leftover debugging helper; set_trace is not called anywhere here.
from IPython.core.debugger import set_trace
# Use bokeh as the holoviews rendering backend.
hv.notebook_extension('bokeh', width=100)

In [None]:
%%time
def subsetting_unit(ds):
    """Keep only the units listed in the dataset's ``unit_mask`` attribute.

    Intended as the ``preprocess`` hook of ``xr.open_mfdataset``, so each
    per-session file is filtered before the sessions are concatenated.

    Parameters
    ----------
    ds : dataset-like
        Object exposing ``attrs['unit_mask']`` and ``sel(unit_id=...)``.

    Returns
    -------
    The result of ``ds.sel(unit_id=<unit_mask>)``.
    """
    kept_units = ds.attrs['unit_mask']
    subset = ds.sel(unit_id=kept_units)
    return subset

# Cross-register every animal folder under `dpath`:
#   1. estimate per-session shifts, with correlations and templates,
#   2. save shifts/correlations/templates to <animal>/cross_regi.nc,
#   3. apply the shifts to the spatial footprints 'A' and background 'b' and
#      save them to <animal>/cnm_anm_sh.nc.
# The directory listing is sorted because os.walk's order is filesystem-dependent.
# Sorting makes the processing order and the printed log reproducible.
for anm_name in sorted(next(os.walk(dpath))[1]):
    print("processing: {}".format(anm_name))
    anm_path = os.path.join(dpath, anm_name)
    flist = get_cnm_list(anm_path)
    if not flist:
        # No CNMF result files for this animal, so there is nothing to register.
        continue
    # 'mean' is passed once per file. Presumably it names the variable used as
    # each session's registration template. TODO confirm against minian.
    shifts, corrs, temps = estimate_shifts(flist, ['mean'] * len(flist))
    temps_sh = apply_shifts(temps, shifts)
    temps = temps.astype(float)
    temps_sh = temps_sh.astype(float)
    cross_regi = xr.merge([shifts, corrs, temps, temps_sh])
    cross_regi.to_netcdf(os.path.join(anm_path, "cross_regi.nc"))
    # Re-open all sessions, keeping only the masked units, and shift their spatial matrices.
    with xr.open_mfdataset(flist, concat_dim='session', preprocess=subsetting_unit) as cnmds:
        print("\nloading spatial matrix")
        cnmds['A'].load()
        cnmds['b'].load()
        print("applying shift to spatial matrix")
        A_sh = apply_shifts(cnmds['A'], shifts)
        b_sh = apply_shifts(cnmds['b'], shifts)
        cnmds_sh = xr.merge([A_sh, b_sh])
        print("saving results")
        cnmds_sh.to_netcdf(os.path.join(anm_path, "cnm_anm_sh.nc"))

In [None]:
# Load the per-animal files written by the cross-registration cell and stack them
# along a new 'animal' dimension.
# The dots in the patterns are escaped so that only the literal file names match.
# This assumes get_cnm_list treats `pattern` as a regular expression, which the
# ^...$ anchors suggest.
cnmds = xr.open_mfdataset(get_cnm_list(dpath, pattern=r'^cnm_anm_sh\.nc$'), concat_dim='animal')
shiftds = xr.open_mfdataset(get_cnm_list(dpath, pattern=r'^cross_regi\.nc$'), concat_dim='animal')

In [None]:
# Centroid table of every unit across animals/sessions, with its columns cast to
# plain numeric/string dtypes for the comparisons and groupbys further down.
cents = calculate_centroids(cnmds).astype({
    'height': float,
    'width': float,
    'unit_id': int,
    'animal': str,
    'session': str,
    'session_id': str,
})

In [None]:
# Distances between unit centroids from the centroid table, computed by minian's
# calculate_centroid_distance.
dist = calculate_centroid_distance(cents)

In [None]:
# Keep only candidate matches whose centroids are closer than MAX_CENTROID_DIST,
# then group the surviving pairs by session.
# NOTE(review): the distance units are presumably pixels. TODO confirm.
MAX_CENTROID_DIST = 5
dist_ft = group_by_session(dist[dist['variable', 'distance'] < MAX_CENTROID_DIST])

In [None]:
# Build the cross-session unit mapping from the distance-filtered, session-grouped pairs.
# The result has ('meta', 'animal') and ('meta', 'group') columns, which the
# overlap cell groups on.
mappings = calculate_mapping(dist_ft)

In [None]:
# For each (animal, session group), compute the fraction of units that were
# matched across the sessions:
#   overlap     = matched / (sum of per-session unit counts - matched)
#   overlap-onA = matched / units in the group's first session
#   overlap-onB = matched / units in the group's second session
# NOTE(review): the A/B columns assume each group is a pair of sessions. A group
# with more sessions only uses its first two here.
# Unit counts are looked up from one precomputed table instead of re-filtering
# `cents` for every group.
unit_counts = cents.groupby(['animal', 'session']).size()
overlap_records = []
for (cur_anm, cur_map), cur_grp in mappings.groupby([mappings['meta', 'animal'], mappings['meta', 'group']]):
    novlp = len(cur_grp)
    nunit = [int(unit_counts.get((cur_anm, ss), 0)) for ss in cur_map]
    nA = nunit[0]
    nB = nunit[1]
    nSum = np.sum(nunit) - novlp
    overlap_records.append({
        'animal': cur_anm,
        'session': cur_map,
        'overlap': novlp / nSum,
        'overlap-onA': novlp / nA,
        'overlap-onB': novlp / nB,
    })
# Build the frame once from the records, so the ratio columns get a float dtype.
# Concatenating mixed-type Series and transposing left every column as object.
# This also yields an empty table instead of raising when there are no groups.
overlaps = pd.DataFrame(overlap_records, columns=['animal', 'session', 'overlap', 'overlap-onA', 'overlap-onB'])
group_dict = dict(MS101='negative', MS104='negative', NS20='negative', NS22='negative', MS102='neutral', MS103='neutral', NS24='neutral')
# Explicit lookup rather than .map, so an animal missing from group_dict raises
# KeyError instead of silently becoming NaN.
overlaps['group'] = [group_dict[anm] for anm in overlaps['animal']]

In [None]:
%%opts BoxWhisker [width=1200, height=600, xrotation=90]
# One box-and-whisker plot per overlap metric, split by session and group.
# The HoloMap key is the metric name.
overlap_ds = hv.Dataset(overlaps, kdims=['animal', 'session', 'group'], vdims=['overlap', 'overlap-onA', 'overlap-onB'])
box_plots = {}
for metric in overlap_ds.vdims:
    box_plots[metric.name] = hv.BoxWhisker(overlap_ds, kdims=['session', 'group'], vdims=[metric])
hv.HoloMap(box_plots)

In [None]:
%%opts Curve [width=1000, height=400, xrotation=90, tools=['hover']]
def _animal_curves(grp, grp_df, metric):
    """Overlay one Curve per animal in group `grp`, plotting `metric` against session."""
    return hv.NdOverlay({
        anm: hv.Curve(overlap_ds.select(group=grp, animal=anm), kdims=['session'], vdims=[metric])
        for anm, _ in grp_df.groupby('animal')
    })


def _group_layout(metric):
    """Stack the per-group animal overlays for `metric` in a single column."""
    per_group = {grp: _animal_curves(grp, grp_df, metric) for grp, grp_df in overlaps.groupby('group')}
    return hv.NdLayout(per_group, kdims=['group']).cols(1)


hv.HoloMap({metric.name: _group_layout(metric) for metric in overlap_ds.vdims}).collate()

In [None]:
%%opts Image [height=480, width=752]
# Per-session registration templates before shifting (the 'temps' variable saved
# in cross_regi.nc), rendered as images over (width, height).
template_var = 'temps'
temps_ds = hv.Dataset(shiftds[template_var])
temps_ds.to(hv.Image, kdims=['width', 'height'])

In [None]:
%%opts Image [height=480, width=752]
# The same templates after the estimated shifts were applied (the 'temps_shifted'
# variable), rendered as images over (width, height) for visual comparison.
template_var = 'temps_shifted'
temps_ds = hv.Dataset(shiftds[template_var])
temps_ds.to(hv.Image, kdims=['width', 'height'])