In [None]:
# Location of the minian source tree (appended to sys.path in the next cell)
# and of the data directory, which holds one sub-directory per animal.
minian_path, dpath = ".", "./demo_movies/"

In [None]:
# Notebook-wide setup: auto-reload edited modules, make the local minian
# package importable, and enable the bokeh backend for holoviews.
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings
# NOTE(review): this silences every warning for the whole kernel session
# (including xarray/pandas deprecation warnings); consider a narrower filter.
warnings.simplefilter('ignore')
# make the minian package located at `minian_path` importable
sys.path.append(minian_path)
import itertools as itt
import numpy as np
import xarray as xr
import holoviews as hv
import pandas as pd
from holoviews.operation.datashader import datashade, regrid
from minian.cross_registration import load_cnm_dataset, get_minian_list, estimate_shifts, apply_shifts, calculate_centroids, calculate_centroid_distance, calculate_mapping, group_by_session, resolve_mapping, fill_mapping
from minian.utilities import resave_varr, update_meta
from minian.visualization import AlignViewer
# NOTE(review): set_trace is a leftover debugging aid; no cell shown here uses it.
from IPython.core.debugger import set_trace
hv.notebook_extension('bokeh', width=100)

In [None]:
# For every animal sub-directory of `dpath`, estimate the shifts between its
# sessions' templates and collect shifts, correlations and the original and
# shifted templates into one dataset stacked along a new 'animal' dimension.
# (The 'first' argument presumably selects the first session as the
# reference for each animal — confirm against estimate_shifts.)
if not os.path.isdir(dpath):
    # next(os.walk(...)) would otherwise fail with a bare StopIteration
    raise FileNotFoundError("data directory not found: {}".format(dpath))
regi_list = []
# sorted() makes the order of the 'animal' dimension independent of the
# filesystem's directory-listing order
for anm_name in sorted(next(os.walk(dpath))[1]):
    print("processing: {}".format(anm_name))
    anm_path = os.path.join(dpath, anm_name)
    flist = get_minian_list(anm_path)
    if not flist:
        continue
    shifts, corrs, temps = estimate_shifts(flist, ['first'] * len(flist))
    temps_sh = apply_shifts(temps, shifts)
    temps = temps.astype(float)
    temps_sh = temps_sh.astype(float)
    cross_regi = xr.merge([shifts, corrs, temps, temps_sh])
    regi_list.append(cross_regi)
if not regi_list:
    # xr.concat would fail here with a much less specific message
    raise ValueError("no minian datasets found in any sub-directory of {}".format(dpath))
shiftds = xr.concat(regi_list, dim='animal')

In [None]:
%%output size=70
%%opts Image [height=480, width=752]
%%opts Layout [shared_datasource=True]
# Interactive inspection of the estimated shifts. The next cell saves the
# viewer's `shiftds` attribute (not the module-level `shiftds`) to disk.
alignviewer = AlignViewer(shiftds, sampling=5)
alignviewer.show()

In [None]:
# Persist the viewer's shift dataset so later cells (and later kernel
# sessions) can reload it with xr.open_dataset.
shiftds_file = os.path.join(dpath, "shiftds.nc")
alignviewer.shiftds.to_netcdf(shiftds_file)

In [None]:
def sub_ds(ds):
    """Return only the 'A' and 'b' variables of ``ds``.

    Passed as the ``preprocess`` hook of ``xr.open_mfdataset`` so that just
    these two variables are combined across the session files.
    """
    wanted = ['A', 'b']
    return ds[wanted]

# Apply each animal's per-session shifts to its 'A' and 'b' spatial matrices
# and save the result as <animal dir>/minian_anm_sh.nc.
# The shifts are read from the shiftds.nc written by the AlignViewer cell, so
# what was saved there is exactly what gets applied (the in-memory `shiftds`
# may not reflect the viewer's copy) and this cell also works after a restart.
with xr.open_dataset(os.path.join(dpath, "shiftds.nc")) as saved_shiftds:
    shifts_all = saved_shiftds['shifts'].load()

# sorted() gives a deterministic processing order across filesystems
for anm_name in sorted(next(os.walk(dpath))[1]):
    try:
        shifts = shifts_all.sel(animal=anm_name)
    except KeyError:
        print("no shift presented for animal {}".format(anm_name))
        continue
    print("processing: {}".format(anm_name))
    anm_path = os.path.join(dpath, anm_name)
    flist = get_minian_list(anm_path)
    with xr.open_mfdataset(flist, concat_dim='session', preprocess=sub_ds) as cnmds:
        print("loading spatial matrix")
        cnmds['A'].load()
        cnmds['b'].load()
        print("applying shift to spatial matrix")
        A_sh = apply_shifts(cnmds['A'], shifts)
        b_sh = apply_shifts(cnmds['b'], shifts)
        cnmds_sh = xr.merge([A_sh, b_sh])
        print("saving results")
        cnmds_sh.to_netcdf(os.path.join(anm_path, "minian_anm_sh.nc"))

In [None]:
# Reload what earlier cells saved: every animal's shifted spatial matrices,
# concatenated along a new 'animal' dimension, and the shift dataset.
shifted_files = get_minian_list(dpath, pattern=r'^minian_anm_sh.nc$')
cnmds = xr.open_mfdataset(shifted_files, concat_dim='animal')
shiftds_file = os.path.join(dpath, "shiftds.nc")
shiftds = xr.open_dataset(shiftds_file)

In [None]:
# Per animal: after dropping sessions whose shifted template is entirely NaN,
# count at each pixel how many sessions are NaN there. Later cells use
# `windowds == 0` (no session NaN) as the shared field-of-view mask.
window_list = [
    temp_anm.dropna('session', how='all').isnull().sum('session')
    for _, temp_anm in shiftds['temps_shifted'].groupby('animal')
]
windowds = xr.concat(window_list, dim='animal')

In [None]:
%%output size=70
%%opts Image [height=480, width=752] {+axiswise +framewise}
hv_wnd = hv.Dataset(windowds, kdims=['animal', 'height', 'width'])
hv_temps = hv.Dataset(shiftds['temps_shifted'], kdims=['animal', 'session', 'height', 'width'])
regrid(hv_wnd.to(hv.Image, ['width', 'height'])) + regrid(hv_temps.to(hv.Image, ['width', 'height']))

In [None]:
%%time
# Force the shifted spatial matrices of all animals into memory (and time it).
cnmds['A_shifted'].load()

In [None]:
# NOTE(review): dead cell — these cached-distance loads are commented out and
# point at absolute paths on one developer's machine; delete the cell, or
# rebuild the paths from `dpath` if the cached files are still needed.
# dist = pd.read_pickle("/home/phild/Documents/sync/project/miniscope/data_temp/dist.pkl")
# dist_shifted_crop = pd.read_pickle("/home/phild/Documents/sync/project/miniscope/data_temp/dist_shifted_crop.pkl")
# dist_shifted_crop_dig = pd.read_pickle("/home/phild/Documents/sync/project/miniscope/data_temp/dist_shifted_crop_dig.pkl")

In [None]:
%%time
cents = calculate_centroids(cnmds, windowds==0)
try:
    cents = cents.drop('unit_labels', axis='columns')
except KeyError:
    pass
cents.to_pickle(os.path.join(dpath, "centroid.pkl"))

In [None]:
# Reload the centroids cached by the centroid cell above, so later cells can
# run without recomputing them. The pickle is one this notebook wrote itself.
cents = pd.read_pickle(os.path.join(dpath, "centroid.pkl"))

In [None]:
%%output size=70
%%opts Points [height=480, width=752] {+axiswise +framewise}
# Centroid positions in (width, height): one Points element per unit_id,
# overlaid; 'animal' and 'session' remain as key dimensions to browse.
cents_hv = hv.Dataset(cents, kdims=['height', 'width', 'unit_id', 'animal', 'session'])
cents_hv.to(hv.Points, kdims=['width', 'height']).overlay('unit_id')

In [None]:
%%time
dist = calculate_centroid_distance(cents, cnmds, windowds==0, shift=False, hamming=False, corr=False)
dist.to_pickle(os.path.join(dpath, "distance.pkl"))

In [None]:
# NOTE(review): identical call and arguments to the distance cell above, and
# `dist_ps` is not used by any cell shown here. Either remove this (expensive)
# recomputation or change the arguments if a different variant was intended.
dist_ps = calculate_centroid_distance(cents, cnmds, windowds==0, shift=False, hamming=False, corr=False)

In [None]:
# Datashaded scatter of the distance table: 'coeff' against 'distance'.
dist_points = hv.Points(dist['variable'], kdims=['coeff', 'distance'])
hv_dist = datashade(dist_points).opts(plot=dict(width=500, height=500))
hv_dist

In [None]:
# Reload the distances cached by the distance cell above (a pickle written by
# this notebook), so the mapping cells can run without recomputing them.
dist = pd.read_pickle(os.path.join(dpath, "distance.pkl"))

In [None]:
# Keep candidate unit pairs whose centroid distance is below DIST_THRES (in
# the units of the 'distance' column — presumably pixels, confirm), then
# group the remaining candidates by session.
DIST_THRES = 5
dist_ft = group_by_session(dist[dist['variable', 'distance'] < DIST_THRES])

In [None]:
# Build cross-session unit mappings from the distance-filtered candidates and
# resolve them (matching rules live in minian.cross_registration).
mappings = calculate_mapping(dist_ft)
mappings_meta = resolve_mapping(mappings)

In [None]:
# Combine the resolved mapping with the centroid table; the overlap cells
# below read this result. (Presumably adds units that matched nothing as
# rows with NaN in the other sessions — confirm against fill_mapping.)
mappings_meta_fill = fill_mapping(mappings_meta, cents)

In [None]:
# Earlier overlap measure, computed from the unresolved `mappings` grouped by
# (animal, session group).
# NOTE(review): the next cell recomputes `overlaps` (and redefines `map_dict`
# and `group_dict`), so this cell's result is overwritten before it is used;
# consider deleting one of the two cells.
overlap_list = []
map_dict = {'1': 'A1', '2': 'B1', '3': 'BS', '4': 'A2', '5': 'B2', '6': 'C'}
for (cur_anm, cur_map), cur_grp in mappings.groupby([mappings['meta', 'animal'], mappings['meta', 'group']]):
    # number of mapping rows for this animal and session group
    novlp = len(cur_grp)
    # number of centroids (units) in each session of the group
    nunit = [len(cents[(cents['animal'] == cur_anm) & (cents['session'] == ss)]) for ss in cur_map]
    # NOTE(review): only the first two sessions are used, i.e. every group is
    # assumed to be a session pair; a group with fewer than two sessions
    # raises IndexError on nunit[1].
    nA = nunit[0]
    nB = nunit[1]
    # union size by inclusion-exclusion, treating novlp as the intersection
    # size (exact only for pairs)
    nSum = np.sum(nunit) - novlp
    # translate session labels to short names; a label missing from map_dict
    # raises KeyError
    cur_map = tuple([map_dict[m] for m in cur_map])
    cur_ovlp = pd.Series([cur_anm, cur_map, novlp/nSum, novlp/nA, novlp/nB, novlp/(nA*nB)], index=['animal', 'session', 'overlap', 'overlap-onA', 'overlap-onB', 'overlap-prod'])
    overlap_list.append(cur_ovlp)
overlaps = pd.concat(overlap_list, axis=1, ignore_index=True).T
# animal -> experimental group; an animal missing here raises KeyError below
group_dict = dict(MS101='negative', MS104='negative', NS20='negative', NS22='negative', MS102='neutral', MS103='neutral', NS24='neutral')
overlaps['group'] = overlaps['animal'].apply(lambda anm: group_dict[anm])

In [None]:
# Pairwise session-overlap statistics per animal, from the filled mapping.
# `cur_grp['session']` has one column per session; NaN marks a mapped unit
# that is absent from that session.
map_dict = {'1': 'A1', '2': 'B1', '3': 'BS', '4': 'A2', '5': 'B2', '6': 'C'}
group_dict = dict(MS101='negative', MS104='negative', NS20='negative', NS22='negative', MS102='neutral', MS103='neutral', NS24='neutral')
overlap_cols = ['animal', 'session', 'overlap', 'overlap-onA', 'overlap-onB', 'overlap-prod']
overlap_list = []
for cur_anm, cur_grp in mappings_meta_fill.groupby(mappings_meta_fill['meta', 'animal']):
    ss_units = cur_grp['session']
    # sessions in which this animal has at least one unit
    cur_ss = ss_units.dropna(axis='columns', how='all').columns
    # number of units present in at least one session of this animal
    T = ss_units.dropna(axis='rows', how='all').shape[0]
    for ss_a, ss_b in itt.combinations(cur_ss, 2):
        pair = ss_units[[ss_a, ss_b]]
        nint = pair.dropna(axis='rows', how='any').shape[0]  # units in both sessions
        nuni = pair.dropna(axis='rows', how='all').shape[0]  # units in either session
        nA = ss_units[ss_a].dropna().size
        nB = ss_units[ss_b].dropna().size
        overlap_list.append({
            'animal': cur_anm,
            'session': (map_dict[ss_a], map_dict[ss_b]),
            'overlap': nint / nuni,
            'overlap-onA': nint / nA,
            'overlap-onB': nint / nB,
            # nint divided by nA*nB/T, the intersection size expected if both
            # sessions drew their units independently from the T units
            'overlap-prod': (nint * T) / (nA * nB),
        })
# Building the frame from records gives the numeric columns a float dtype
# (concat-then-transpose of mixed-type Series yields object columns) and
# keeps the column schema when no session pair exists.
overlaps = pd.DataFrame(overlap_list, columns=overlap_cols)
overlaps['group'] = overlaps['animal'].apply(lambda anm: group_dict[anm])

In [None]:
# Reshape to long format: one row per (animal, session pair, group, overlap
# measure). Guarded so that re-running this cell does not melt an
# already-long frame a second time.
if 'overlap-type' not in overlaps.columns:
    overlaps = overlaps.melt(id_vars=['animal', 'session', 'group'], var_name='overlap-type', value_name='overlap-value')
# the overlap table may hold object-dtype values; plotting needs floats
overlaps['overlap-value'] = overlaps['overlap-value'].astype(float)

In [None]:
%%opts BoxWhisker [width=1200, height=600, xrotation=90]
# Distribution of overlap values across animals per session pair and group;
# 'overlap-type' is left as a key dimension for choosing the measure.
overlap_hv = hv.Dataset(overlaps, kdims=['session', 'group', 'overlap-type'], vdims=['overlap-value'])
overlap_hv.to(hv.BoxWhisker, kdims=['session', 'group'])

In [None]:
%%opts Curve [width=1000, height=400, xrotation=90, tools=['hover']]
overlap_hv_anm = hv.Dataset(overlaps, kdims=['animal', 'session', 'group', 'overlap-type'], vdims=['overlap-value'])
overlap_hv_anm.to(hv.Curve, kdims=['session']).overlay('animal').layout('group').cols(1)