# Pipeline
## Setting up
### set module paths and data path

In [None]:
# Path to the minian package (appended to sys.path in the next cell) and to the input movies.
minian_path = "."
dpath = "./demo_movies/"
# Passed as `meta_dict` to every save_variable call; the negative values presumably index
# path components of dpath (e.g. -1 = last folder) — TODO confirm against save_variable.
meta_dict = dict(session_id=-1, session=-2, animal=-3)
# Dask chunk sizes per dimension, used when re-opening minian.nc for the CNMF stage.
chunks = dict(frame=1000, height=200, width=200, unit_id=20)
# When True, intermediate results are persisted into memory after each step.
in_memory = True

### load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append(minian_path)
import gc
import psutil
import numpy as np
import xarray as xr
import holoviews as hv
import paramnb
import matplotlib.pyplot as plt
import bokeh.plotting as bpl
import dask.array as da
import pandas as pd
import dask
import datashader as ds
from holoviews.operation.datashader import datashade, regrid, dynspread
from datashader.colors import Sets1to3
from dask.diagnostics import ProgressBar
from IPython.core.display import display, HTML
from dask.distributed import Client, progress
from minian.utilities import load_videos, varray_to_tif, save_cnmf, save_movies, scale_varr, save_variable
from minian.preprocessing import remove_brightspot, gradient_norm, denoise, remove_background
from minian.motion_correction import estimate_shift_fft, apply_shifts, interpolate_frame, mask_shifts
from minian.initialization import seeds_init, gmm_refine, pnr_refine, intensity_refine, ks_refine, seeds_merge, initialize
from minian.cnmf import get_noise_fft, get_noise_welch, update_spatial, update_temporal, unit_merge
from minian.visualization import VArrayViewer, MCViewer, CNMFViewer, generate_videos, visualize_temporal_update, normalize

### module initialization

In [None]:
# Make the data directory absolute so later os.path.join(dpath, ...) calls do not depend
# on the kernel's working directory.
dpath = os.path.abspath(dpath)
# Activate the bokeh backend for holoviews plots in this notebook.
hv.notebook_extension('bokeh', width=100)

## Pre-processing
### loading videos and visualization

In [None]:
%%time
# Load the videos found under dpath as float32; resample=dict(frame=1) presumably keeps
# every frame (TODO confirm against load_videos).
varr = load_videos(dpath, in_memory=in_memory, dtype=np.float32, resample=dict(frame=1))

In [None]:
%%time
# Build the working copy `varr_ref`: rechunk to 200 frames per chunk and pass through
# scale_varr (presumably an intensity rescale — confirm). `varr` is left untouched and is
# shown next to the processed movie in the pre-processing viewer later.
varr_ref = scale_varr(varr.chunk(dict(height='auto', width='auto', frame=200)))
if in_memory:
    # Materialize the dask graph once so downstream steps reuse the result.
    with ProgressBar():
        varr_ref = varr_ref.persist()

In [None]:
%%output size=100
# Interactive viewer for the rescaled movie.
vaviewer = VArrayViewer([varr_ref], framerate=5)
display(vaviewer.widgets)
vaviewer.show()

In [None]:
# Drop the last two pixel rows of the working copy only (`varr` keeps its full height).
# NOTE(review): dataset-specific magic crop, presumably removing an edge artifact — adjust per recording.
varr_ref = varr_ref.isel(height=slice(None, -2))

In [None]:
%%time
# Store the cropped, rescaled movie as variable "org" in the 'minian' dataset under dpath.
save_variable(varr_ref.rename("org"), dpath, 'minian', meta_dict=meta_dict)

### bright spots removal

In [None]:
%%time
# Remove bright spots (thres=2). Chunks span the full height and width, so every chunk
# contains whole frames; only the frame axis is split.
varr_ref = remove_brightspot(varr_ref.chunk(dict(height=-1, width=-1, frame='auto')), thres=2)
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

### estimate gradient

In [None]:
%%time
# Gradient norm of the first frame only; used solely to choose the diffusion parameter below.
with ProgressBar():
    varr_gradient = gradient_norm(varr_ref.isel(frame=0)).compute()

In [None]:
# kappa for the anisotropic diffusion step = 90th percentile of the first-frame gradient norm.
kappa = varr_gradient.quantile(0.9).values

### anisotropic diffusion

In [None]:
%%time
# Anisotropic-diffusion denoising (10 iterations) using the data-driven kappa computed above.
varr_ref = denoise(varr_ref, 'anisotropic', niter=10, kappa=kappa, gamma=0.25, option=2)
if in_memory:
    # Persist with the multiprocessing scheduler (presumably because this step is CPU-bound).
    with ProgressBar(), dask.config.set(scheduler='processes'):
        varr_ref = varr_ref.persist()

### background removal

In [None]:
%%time
# Background removal with a morphological top-hat of window size 10.
varr_ref = remove_background(varr_ref, method='tophat', wnd=10)
if in_memory:
    with ProgressBar(), dask.config.set(scheduler='processes'):
        varr_ref = varr_ref.persist()

### normalization

In [None]:
%%time
# Rechunk to 200 frames per chunk and rescale again after the filtering steps above.
varr_ref = scale_varr(varr_ref.chunk(dict(height='auto', width='auto', frame=200)))
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

### visualization of pre-processing

In [None]:
%%output size=70
# Side-by-side viewer: raw movie `varr` vs. the pre-processed `varr_ref`
# (ds_rate=2 presumably downsamples for display — confirm).
vaviewer = VArrayViewer([varr, varr_ref.rename('varr_ref')], framerate=5, ds_rate=2)
display(vaviewer.widgets)
vaviewer.show()

## Motion correction
### estimate shifts

In [None]:
%%time
# Estimate frame shifts with an FFT-based method. `res` carries a `variable` coordinate
# holding the height/width shifts and a 'corr' entry (split apart below).
varr_fft, res = estimate_shift_fft(varr_ref, on='perframe')
if in_memory:
    with ProgressBar():
        res = res.compute()
shifts = res.sel(variable = ['height', 'width'])
corr = res.sel(variable='corr')

### masking and interpolation

In [None]:
%%time
shifts_ma, mask = mask_shifts(varr_fft, corr, shifts, z_thres=-1.5)

In [None]:
%%time
varr_mc = interpolate_frame(varr_mc.compute().rename('varr_mc'), mask)

### apply shifts

In [None]:
%%time
shifts_cum = shifts.cumsum('frame')
shifts_cum = np.around(shifts_cum.fillna(0)).astype(int)

In [None]:
varr_ref = varr_ref.chunk(dict(height=-1, width=-1, frame='auto'))
varr_mc = apply_shifts(varr_ref, shifts_cum)
if in_memory:
    with ProgressBar():
        varr_mc = varr_mc.persist()

### visualization of motion-correction

In [None]:
%%output size=100 fps=5
# Side-by-side viewer: pre-processed movie before (`varr_ref`) and after motion correction (`mc`).
vaviewer = VArrayViewer([varr_ref.rename('varr_ref'), varr_mc.rename('mc')], framerate=5)
display(vaviewer.widgets)
vaviewer.show()

### visualization of shifts

In [None]:
%%output size=100
%%opts Curve [width=500, tools=['hover']]
hv.NdOverlay(dict(width=hv.Curve(shifts.sel(variable='width')), height=hv.Curve(shifts.sel(variable='height'))))\
+ hv.NdOverlay(dict(width=hv.Curve(shifts_cum.sel(variable='width')), height=hv.Curve(shifts_cum.sel(variable='height'))))

### save the motion-corrected movie to the minian dataset

In [None]:
%%time
# Store the motion-corrected movie as variable "Y"; the initialization and CNMF stages read it back.
save_variable(varr_mc.rename('Y'), dpath, 'minian', meta_dict=meta_dict)

## Initialization

In [None]:
%%time
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'), autoclose=True)
Y = minian['Y'].load()

In [None]:
%%time
# Generate candidate seed pixels from the movie (rolling-window method).
seeds = seeds_init(Y, method='rolling')

In [None]:
%%time
# The following cells refine the seed set step by step: each step takes the previous
# step's output. The criteria are presumably GMM-, peak-to-noise-, intensity- and
# KS-test-based, respectively, judging by the function names — confirm in minian.initialization.
seeds_gmm = gmm_refine(Y, seeds)

In [None]:
%%time
seeds_pnr = pnr_refine(Y, seeds_gmm)

In [None]:
%%time
seeds_int = intensity_refine(Y, seeds_pnr)

In [None]:
%%time
seeds_ks = ks_refine(Y, seeds_int)

In [None]:
%%time
# Merge the surviving seeds (presumably combining nearby/similar seeds).
seeds_mrg = seeds_merge(Y, seeds_ks)

In [None]:
%%time
# Initial spatial footprints A, temporal traces C, and background terms b, f.
# NOTE(review): chunk sizes duplicate the values in the `chunks` config dict.
A, C, b, f = initialize(Y, seeds_mrg, chk=dict(height=200, width=200, frame=1000))

In [None]:
# Left: sum of all initial footprints; right: initial traces as a (frame x unit) image.
opts = dict(plot=dict(height=300, width=300))
regrid(hv.Image(A.sum('unit_id'), kdims=['width', 'height'])).opts(**opts) + regrid(hv.Image(C, kdims=['frame', 'unit_id'])).opts(**opts)

In [None]:
%%time
save_variable(A.rename('A_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(C.rename('C_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(b.rename('b_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(f.rename('f_init'), dpath, 'minian', meta_dict=meta_dict)

## CNMF
### loading data

In [None]:
%%time
chk = chunks.copy()
chk['unit_id_init'] = chk.pop('unit_id')
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'), chunks=chk)
Y = minian['Y']
A_init = minian['A_init'].rename(unit_id_init='unit_id')
C_init = minian['C_init'].rename(unit_id_init='unit_id')
b_init = minian['b_init']
f_init = minian['f_init']

### estimate spatial noise

In [None]:
%%time
# Per-pixel noise estimate (`sn_spatial`, used by the CNMF updates) and the power spectral density.
sn_spatial, psd = get_noise_fft(Y)

In [None]:
%%opts Image [height=300, width=800] (cmap='Viridis')
# Flatten (height, width) into one `spatial` dimension and replace its MultiIndex with
# integer positions so the PSD can be drawn as a (pixel x freq) image.
psd_flt = psd.stack(spatial=['height', 'width'])
hv_psd = hv.Image(psd_flt.assign_coords(spatial=range(psd_flt.sizes['spatial'])), kdims=['spatial', 'freq'])
regrid(hv_psd)

### test parameters for spatial update

In [None]:
opts_A = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_C = dict(plot=dict(height=480, width=1000), style=dict(cmap='Viridis'))
sprs_ls = [5e-6, 5e-3, 0.5, 5]
units = np.random.choice(A_init.coords['unit_id'], 20)
A_dict = dict()
for cur_sprs in sprs_ls:
    cur_A, cur_b, cur_C, cur_f = update_spatial(
        Y, A_init.sel(unit_id=units),
        b_init, C_init.sel(unit_id=units), f_init, sn_spatial, sparse_penal=cur_sprs)
    hv_cur_A = hv.Image(cur_A.sum('unit_id'), kdims=['width', 'height']).opts(**opts_A)
    hv_cur_C = hv.Image(cur_C, kdims=['frame', 'unit_id']).opts(**opts_C)
    A_dict[cur_sprs] = (hv_cur_A + hv_cur_C).cols(1)
hv_res = hv.HoloMap(A_dict, kdims=['sparse_penalty'])

In [None]:
%%opts Image [colorbar=True] {+axiswise}
hv_res.collate()

### first spatial update

In [None]:
%%time
# First spatial update on all units, with the sparsity penalty chosen from the sweep (0.1).
A_spatial, b_spatial, C_spatial, f_spatial = update_spatial(
    Y, A_init, b_init, C_init, f_init, sn_spatial, sparse_penal=0.1)

In [None]:
%%opts Image [colorbar=True] (cmap='Viridis')
# Left: summed initial footprints; right: summed footprints after the first spatial
# update (colour range fixed to [0, 0.5]).
regrid(hv.Image(A_init.sum('unit_id'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))\
+ regrid(hv.Image(A_spatial.sum('unit_id').rename('A_spatial'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752)).redim.range(A_spatial=(0, 0.5))

### test parameters for temporal update

In [None]:
%%time
import itertools as itt
p_ls = [1]
sprs_ls = [0.5, 5]
add_ls = [0, 3, 5]
noise_ls = [0.45, 0.48, 0.499]
C_dict = dict()
S_dict = dict()
g_dict = dict()
vis_dict = dict()
for cur_sprs, cur_p, cur_add, cur_noise in itt.product(sprs_ls, p_ls, add_ls, noise_ls):
    print("processing {}".format((cur_p, cur_sprs, cur_add, cur_noise)))
    YrA, cur_C, cur_S, cur_B, cur_C0, cur_sig, cur_g, = update_temporal(
        Y, A_spatial.isel(unit_id=slice(0, 10)), b_spatial, C_spatial.isel(unit_id=slice(0, 10)),
        f_spatial, sn_spatial, sparse_penal=cur_sprs, p=cur_p, use_spatial=False, use_smooth=True,
        add_lag = cur_add, noise_freq=cur_noise, chk=dict(frame=200, unit_id=20),
        cvx_sched="processes")
    C_dict[(cur_p, cur_sprs, cur_add, cur_noise)] = cur_C
    S_dict[(cur_p, cur_sprs, cur_add, cur_noise)] = cur_S
    g_dict[(cur_p, cur_sprs, cur_add, cur_noise)] = cur_g
    vis_dict[(cur_p, cur_sprs, cur_add, cur_noise)] = visualize_temporal_update(
        YrA, cur_C, cur_S, cur_g, cur_sig)

In [None]:
%%opts Curve [width=1200] {+framewise}
hv_res = hv.HoloMap(vis_dict, kdims=['p', 'sparse_penalty', 'add_lag', 'noise_freq']).collate()
hv_res.select(unit_id=slice(5, 10))

### first temporal update

In [None]:
%%time
# First temporal update on all units. noise_freq=0.45, sparse_penal=0.5, p=1 and add_lag=0
# are all among the values tried in the sweep above.
YrA, C_temporal, S_temporal, B_temporal, C0_temporal, sig_temporal, g_temporal = update_temporal(
        Y, A_spatial,
        b_spatial, C_spatial, f_spatial, sn_spatial, jac_thres=0.2,
        noise_freq=0.45, sparse_penal=0.5, p=1, add_lag = 0, use_spatial=False, chk=dict(frame=2000, unit_id=200))

In [None]:
%%opts Curve [width=1200] {+framewise}
# Per-unit trace plots for the first 50 units (normalize=False, unlike the later call).
visualize_temporal_update(YrA, C_temporal, S_temporal, g_temporal, sig_temporal, normalize=False).select(unit_id = slice(0, 50))

In [None]:
%%opts Image [colorbar=True] (cmap='Viridis')
# Heatmaps of C (top, range [0, 1]) and S (bottom, range [0, 0.006]) as (frame x unit) images.
hv_c = regrid(hv.Image(C_temporal.rename('c'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(c=(0, 1))
hv_s = regrid(hv.Image(S_temporal.rename('s'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(s=(0, 0.006))
(hv_c + hv_s).cols(1)

### merge units

In [None]:
%%time
# Merge units using a correlation threshold of 0.8 (presumably merging overlapping units
# whose signals correlate above it — confirm); returns merged footprints and signals.
A_mrg, sig_mrg = unit_merge(A_spatial, sig_temporal, thres_corr=0.8)

### second spatial update

In [None]:
%%time
# Second spatial update from the merged units; the merged signals `sig_mrg` take the place of C.
A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2 = update_spatial(
    Y, A_mrg, b_spatial, sig_mrg, f_spatial, sn_spatial, sparse_penal=0.1, dl_wnd=5)

In [None]:
# Apply `normalize` to each unit's (height, width) footprint independently
# (vectorize=True loops over the remaining unit dimension).
A_spatial_it2_norm = xr.apply_ufunc(normalize, A_spatial_it2, input_core_dims=[['height', 'width']], output_core_dims=[['height', 'width']], vectorize=True)

In [None]:
# Left: summed footprints after the first spatial update; right: summed normalized
# footprints after the second.
regrid(hv.Image(A_spatial.sum('unit_id'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))\
+ regrid(hv.Image(A_spatial_it2_norm.sum('unit_id'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))

### second temporal update

In [None]:
%%time
# Second temporal update on the refined footprints, with the same temporal parameters as
# the first update (use_spatial is left at its default here).
YrA, C_temporal_it2, S_temporal_it2, B_temporal_it2, C0_temporal_it2, sig_temporal_it2, g_temporal_it2 = update_temporal(
    Y, A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2, sn_spatial, jac_thres=0.2,
    noise_freq=0.45, sparse_penal=0.5, p=1, add_lag=0, chk=dict(frame=2000, unit_id=200))

In [None]:
%%opts Curve [width=1200] {+framewise}
# Per-unit trace plots for the first 50 units.
visualize_temporal_update(
    YrA, C_temporal_it2, S_temporal_it2, g_temporal_it2, sig_temporal_it2).select(unit_id=slice(0, 50))

In [None]:
%%opts Image [colorbar=True, tools=['hover']] (cmap='Viridis')
# Heatmaps of the final C (range [0, 1]) and S (range [0, 0.006]).
hv_c = regrid(hv.Image(C_temporal_it2.rename('c'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(c=(0, 1))
hv_s = regrid(hv.Image(S_temporal_it2.rename('s'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(s=(0, 0.006))
(hv_c + hv_s).cols(1)

### save results

In [None]:
%%time
minian.close()
save_variable(A_spatial_it2.rename('A'), dpath, 'minian', meta_dict=meta_dict)
save_variable(C_temporal_it2.rename('C'), dpath, 'minian', meta_dict=meta_dict)
save_variable(S_temporal_it2.rename('S'), dpath, 'minian', meta_dict=meta_dict)
save_variable(g_temporal_it2.rename('g'), dpath, 'minian', meta_dict=meta_dict)
save_variable(b_spatial_it2.rename('b'), dpath, 'minian', meta_dict=meta_dict)
save_variable(f_spatial_it2.rename('f'), dpath, 'minian', meta_dict=meta_dict)

### visualization

In [None]:
# Re-open the dataset containing the saved results for visualization (left open for the viewer).
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'))

In [None]:
%%time
# Render a summary video of the results to minian.mp4 under dpath.
generate_videos(minian, os.path.join(dpath, "minian.mp4"), chk=dict(height=100, width=100, frame=1000))

In [None]:
# Interactive viewer for the CNMF results, using the saved movie Y.
cnmfviewer = CNMFViewer(minian, minian['Y'])

In [None]:
cnmfviewer.show()