In [None]:
# Directory appended to sys.path in the import cell so that `minian` can be imported.
minian_path = "."
# Data directory handed to open_minian_mf; the mapping pickle is also written here.
dpath = "./demo_movies"

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings
sys.path.append(minian_path)
import itertools as itt
import numpy as np
import xarray as xr
import holoviews as hv
import pandas as pd
from holoviews.operation.datashader import datashade, regrid
from minian.cross_registration import (estimate_shifts, calculate_centroids,
                                       calculate_centroid_distance, calculate_mapping,
                                       group_by_session, resolve_mapping, fill_mapping)
from minian.motion_correction import apply_shifts
from minian.utilities import open_minian, open_minian_mf
from minian.visualization import AlignViewer
hv.notebook_extension('bokeh', width=100)

In [None]:
# Load the datasets under `dpath` through the zarr backend and return them in
# 'pandas' format, with ['animal', 'session'] passed as the index dimensions.
# NOTE(review): the pattern presumably selects directories named
# `minian.<digits>` - confirm against open_minian_mf.
minian_df = open_minian_mf(
    dpath, ['animal', 'session'], result_format='pandas',
    backend='zarr', pattern=r'minian\.[0-9]+$', chunks=dict(frame='auto', height='auto', width='auto', unit_id='auto'))

In [None]:
%%time
# Estimate shifts from the loaded datasets using the 'mean' template option.
# NOTE(review): pct_thres=99.99 looks like a percentile threshold - confirm in
# estimate_shifts; the same value is passed to AlignViewer below.
shiftds = estimate_shifts(minian_df, template='mean', pct_thres=99.99)

In [None]:
%%output size=70
# Build the alignment viewer on the estimated shifts and display it.
alignviewer = AlignViewer(shiftds, pct_thres=99.99)
alignviewer.show()

In [None]:
# Rebind `shiftds` to the dataset held by the viewer, so later cells use it.
# NOTE(review): presumably this carries any adjustments made interactively in
# the viewer - confirm; if so, downstream results depend on widget state.
shiftds = alignviewer.shiftds

In [None]:
# Open the data again, this time as a single xarray object
# (result_format='xarray') and without an explicit name pattern.
minian = open_minian_mf(
    dpath, ['animal', 'session'], result_format='xarray', backend='zarr',
    chunks=dict(height='auto', width='auto', frame='auto', unit_id='auto'))

In [None]:
# Rechunk 'A' so height and width are each one chunk (-1), then apply the
# estimated shifts to it.
# NOTE(review): 'A' presumably holds the per-unit spatial footprints - confirm.
A_shifted = apply_shifts(minian['A'].chunk(dict(height=-1, width=-1, unit_id='auto')), shiftds['shifts'])

In [None]:
# Count null values of the shifted templates along the 'session' dimension.
window = shiftds['temps_shifted'].isnull().sum('session')

In [None]:
%%output size=50
%%opts Image [height=480, width=752, colorbar=True] (cmap='Viridis')
# Broadcast the null-count array against the shifted templates so both share
# the same dimensions. Note that this rebinds `window` to the broadcast result.
window, temps_sh = xr.broadcast(window, shiftds['temps_shifted'])
# Wrap each array as holoviews Images over (width, height); the remaining
# dimensions become key dimensions of the resulting containers.
hv_wnd = hv.Dataset(window, kdims=list(window.dims)).to(hv.Image, ['width', 'height'])
hv_temps = hv.Dataset(temps_sh, kdims=list(temps_sh.dims)).to(hv.Image, ['width', 'height'])
# Display both side by side; the window is regridded with a 'max' aggregator.
regrid(hv_wnd, aggregator='max').relabel("Window") + regrid(hv_temps).relabel("Shifted Templates")

In [None]:
def set_window(wnd):
    """Return a boolean mask that is True wherever `wnd` equals its minimum."""
    lowest = wnd.min()
    return wnd == lowest
# Turn the null-count array into a boolean mask. With ['height', 'width'] as
# core dimensions and vectorize=True, set_window is called once per 2-D slice,
# so the minimum is taken separately for each slice of the remaining dimensions.
# NOTE: this rebinds `window`. Re-running this cell on its own applies
# set_window to the boolean mask from the previous run, which gives a different
# result; re-run the cells that compute `window` first.
window = xr.apply_ufunc(
    set_window,
    window,
    input_core_dims=[['height', 'width']],
    output_core_dims=[['height', 'width']],
    vectorize=True)

In [None]:
%%time
# Compute centroids from the shifted 'A' array, passing the boolean window.
# Later cells treat `cents` as a table (hv.Dataset input and DataFrame.merge).
cents = calculate_centroids(A_shifted, window)

In [None]:
%%output size=50
%%opts Points [height=480, width=752] {+axiswise +framewise}
# Plot the centroids as points in (width, height), overlaying all unit_ids;
# the dimensions not overlaid (animal, session) remain as selectors.
cents_hv = hv.Dataset(cents, kdims=['height', 'width', 'unit_id', 'animal', 'session'])
cents_hv.to(hv.Points, kdims=['width', 'height']).overlay('unit_id')

In [None]:
%%time
# Compute distances from the centroid table.
# NOTE(review): presumably pairwise distances between units of different
# sessions - confirm in calculate_centroid_distance.
dist = calculate_centroid_distance(cents)

In [None]:
# Keep only rows whose ('variable', 'distance') value is below 5.
# NOTE(review): 5 is a hard-coded threshold and its units are not visible here
# (presumably pixels) - confirm, and consider naming it in the config cell.
dist_ft = dist[dist['variable', 'distance'] < 5]
# Pass the filtered distances through group_by_session before mapping.
dist_ft = group_by_session(dist_ft)

In [None]:
%%time
# Build the mapping in three steps: compute it from the filtered distances,
# resolve it, then fill it using the centroid table. Each step feeds the next.
mappings = calculate_mapping(dist_ft)
mappings_meta = resolve_mapping(mappings)
mappings_meta_fill = fill_mapping(mappings_meta, cents)

In [None]:
# Persist the filled mapping table as a pickle inside the data directory.
mappings_meta_fill.to_pickle(os.path.join(dpath, "mappings_meta_fill.pkl"))

In [None]:
def subset_by_session(sessions, mappings):
    """Split a mapping table into matched and non-matched rows for `sessions`.

    Parameters
    ----------
    sessions : list
        Labels of the columns under the top-level 'session' column group
        that should be considered.
    mappings : pandas.DataFrame
        Table with two-level columns. It must contain a ('meta', 'group')
        column and a 'session' column group holding one column per session.

    Returns
    -------
    mappings_ma : pandas.DataFrame
        Rows with a non-null entry in every one of `sessions`.
    mappings_non : pandas.DataFrame
        Rows whose ('meta', 'group') is null and that have a non-null entry
        in at least one of `sessions`.
    """
    # Both subsets are derived from the `sessions` argument. Reading a
    # notebook-level variable here would ignore the caller's selection and
    # fail with NameError when that variable has not been defined yet.
    present = mappings['session'][sessions].notnull()
    mappings_ma = mappings[present.all(axis='columns')]
    mappings_non = mappings[
        mappings['meta', 'group'].isnull() & present.any(axis='columns')]
    return mappings_ma, mappings_non

In [None]:
# Sessions to compare, and the metadata column(s) used to group units.
sess = ['4', '5']
group_dim = ['animal']
# Split the mapping table into the matched and non-matched subsets for `sess`.
mappings_match, mappings_nonmatch = subset_by_session(sess, mappings_meta_fill)
# Per-session holoviews objects, keyed by session label.
A_dict = dict()
cent_dict = dict()
for cur_ss in sess:
    # Select the group column(s) plus this session's column from each subset;
    # the non-matched table additionally drops rows with nulls in those columns.
    cur_uid_ma = mappings_match[[('meta', d) for d in group_dim] + [('session', cur_ss)]]
    cur_uid_nm = mappings_nonmatch[[('meta', d) for d in group_dim] + [('session', cur_ss)]].dropna()
    # Drop the top column level so the columns are plain names.
    cur_uid_ma.columns = cur_uid_ma.columns.droplevel()
    cur_uid_nm.columns = cur_uid_nm.columns.droplevel()
    # The session column holds the values used as unit ids below; rename it.
    cur_uid_ma = cur_uid_ma.rename(columns={cur_ss:'unit_id'})
    cur_uid_nm = cur_uid_nm.rename(columns={cur_ss:'unit_id'})
    # Record the session label as a regular column.
    cur_uid_ma['session'] = cur_ss
    cur_uid_nm['session'] = cur_ss
    # Attach centroid data by merging with `cents` on the columns they share
    # (pandas default when no key is given).
    cur_cents_ma = cur_uid_ma.merge(cents)
    cur_cents_nm = cur_uid_nm.merge(cents)
    # Summed images per group, keyed by the groupby key tuple.
    A_ma_dict = dict()
    A_nm_dict = dict()
    for igrp, grp_ma in cur_uid_ma.groupby(group_dim + ['session']):
        # Pair each grouping dimension with this group's key value for .sel().
        cur_keys = {d: k for d, k in zip(group_dim + ['session'], igrp)}
        A_sub = A_shifted.sel(**cur_keys)
        A_ma = A_sub.sel(unit_id=grp_ma['unit_id'].values)
        # Build a query string of the form "dim=='value' and ..." to pick the
        # non-matched rows belonging to the same group.
        grp_nm = cur_uid_nm.query(" and ".join(["==".join((d, "'{}'".format(k))) for d, k in cur_keys.items()]))
        A_nm = A_sub.sel(unit_id=grp_nm['unit_id'].values)
        # Sum over units so each group yields a single image; compute() is
        # called here, before the result is handed to hv.Image.
        A_ma_dict[igrp] = hv.Image(A_ma.sum('unit_id').compute(), kdims=['width', 'height'])
        A_nm_dict[igrp] = hv.Image(A_nm.sum('unit_id').compute(), kdims=['width', 'height'])
    hv_A_ma = hv.HoloMap(A_ma_dict, kdims=group_dim + ['session'])
    hv_A_nm = hv.HoloMap(A_nm_dict, kdims=group_dim + ['session'])
    hv_cent_ma = hv.Dataset(cur_cents_ma).to(hv.Points, kdims=['width', 'height'])
    hv_cent_nm = hv.Dataset(cur_cents_nm).to(hv.Points, kdims=['width', 'height'])
    # Combine matched / non-matched views under a new 'matching' dimension.
    hv_A = hv.HoloMap(dict(match=hv_A_ma, nonmatch=hv_A_nm), kdims=['matching']).collate()
    hv_cent = hv.HoloMap(dict(match=hv_cent_ma, nonmatch=hv_cent_nm), kdims=['matching']).collate()
    A_dict[cur_ss] = hv_A
    cent_dict[cur_ss] = hv_cent
# Combine the per-session objects under a 'session' dimension. This rebinds
# hv_A / hv_cent, which inside the loop held the per-session objects.
hv_A = hv.HoloMap(A_dict, kdims=['session']).collate()
hv_cent = hv.HoloMap(cent_dict, kdims=['session']).collate()

In [None]:
%%output size=60
# `regrid` is already imported in the notebook's import cell, so it is not
# re-imported here (imports are kept in one place at the top).
# Regrid the summed images and lay them out in two columns, one panel per
# combination of 'animal' and 'matching'.
(regrid(hv_A).opts(plot=dict(width=752, height=480), style=dict(cmap='Viridis'))).layout(['animal', 'matching']).cols(2)