# Pipeline

## Setting up

### set module paths and data path

In [None]:
# Configuration: the values a user is expected to adjust before running.
# `minian_path` is appended to sys.path in the imports cell so `minian` can be imported.
minian_path = "."
# Folder holding the raw videos; results (minian.nc, minian.mp4) are also written here.
dpath = "./demo_movies"
# Passed to every save_variable call; presumably maps metadata fields to
# (negative) path-component indices of `dpath` — TODO confirm in minian.utilities.
meta_dict={'session_id': -1, 'session': -2, 'animal': -3}
# Dask chunk sizes used when reopening minian.nc for the CNMF stage.
chunks = {'frame': 1000, 'height': 200, 'width': 200, 'unit_id':20}
# When True, intermediate results are persisted/computed rather than kept lazy.
in_memory = True

### load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append(minian_path)
import gc
import psutil
import numpy as np
import xarray as xr
import holoviews as hv
import paramnb
import matplotlib.pyplot as plt
import bokeh.plotting as bpl
import dask.array as da
import pandas as pd
import dask
import datashader as ds
from holoviews.operation.datashader import datashade, regrid, dynspread
from datashader.colors import Sets1to3
from dask.diagnostics import ProgressBar
from IPython.core.display import display, HTML
from dask.distributed import Client, progress, LocalCluster, fire_and_forget
from minian.utilities import load_videos, load_images, video_to_tiffs, varray_to_tif, save_cnmf, save_movies, scale_varr, scale_varr_da, save_variable
from minian.preprocessing import remove_brightspot, gradient_norm, denoise, remove_background, stripe_correction
from minian.motion_correction import estimate_shift_fft, apply_shifts, interpolate_frame, mask_shifts
from minian.initialization import seeds_init, gmm_refine, pnr_refine, intensity_refine, ks_refine, seeds_merge, initialize
from minian.cnmf import psd_welch, get_noise, update_spatial, update_temporal, unit_merge
from minian.visualization import VArrayViewer, MCViewer, CNMFViewer, generate_videos, visualize_temporal_update, normalize

### module initialization

In [None]:
# Resolve the data folder to an absolute path (idempotent, safe to re-run).
dpath = os.path.abspath(dpath)
# Enable the bokeh backend for holoviews plots in this notebook.
hv.notebook_extension('bokeh', width=100)

## Pre-processing
### loading videos and visualization

In [None]:
%%time
# Load every video found under `dpath` as float32. `resample=dict(frame=2)`
# presumably downsamples along time by 2 — TODO confirm in minian.utilities.
varr = load_videos(dpath, in_memory=in_memory, dtype=np.float32, resample=dict(frame=2))

In [None]:
%%output size=100
# Interactive viewer of the raw video: show the control widgets, then the plot.
vaviewer = VArrayViewer([varr], framerate=5)
display(vaviewer.widgets)
vaviewer.show()

### select a subset of the video

In [None]:
# Restrict processing to frames up to (and including, since `.sel` slices by
# label) `last_frame`. Set `last_frame = None` to process the whole video.
last_frame = 10232
varr_ref = varr.sel(frame=slice(None, last_frame))

In [None]:
# One chunk spans the full frame (-1 = a single chunk along height/width);
# dask picks the chunk size along `frame`.
varr_ref = varr_ref.chunk(dict(frame='auto', height=-1, width=-1))

### stripe correction

In [None]:
%%time
# Apply minian's stripe correction, then optionally materialize the result so
# later steps do not recompute the whole graph.
varr_ref = stripe_correction(varr_ref)
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

### glow removal

In [None]:
# Glow removal: background subtraction with a 'uniform' method and a 51-pixel
# window. NOTE(review): unlike the neighbouring cells this one is not %%time'd.
varr_ref = remove_background(varr_ref, method='uniform', wnd=51)
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

### bright spot removal

In [None]:
%%time
# Remove bright spots using threshold `thres=2` (units defined by
# minian.preprocessing.remove_brightspot — TODO confirm), optionally persisting.
varr_ref = remove_brightspot(varr_ref, thres=2)
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

### denoise

In [None]:
%%time
# Gaussian denoising with a 3x3 kernel; `sigmaX=0`/`ksize` look like OpenCV
# GaussianBlur arguments passed through — TODO confirm in minian.preprocessing.
varr_ref = denoise(varr_ref, 'gaussian', sigmaX=0, ksize=(3, 3))
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

### visualize and save the processed movie

In [None]:
%%output size=100
# Inspect the pre-processed movie before saving it.
vaviewer = VArrayViewer([varr_ref], framerate=5)
display(vaviewer.widgets)
vaviewer.show()

In [None]:
%%time
# Save the pre-processed movie under the name "org" into `dpath`/'minian'.
# ProgressBar reports any dask computation triggered by the save.
with ProgressBar():
    save_variable(varr_ref.rename("org"), dpath, 'minian', meta_dict=meta_dict)

### background removal

In [None]:
%%time
# Background removal with a 'tophat' method (window 10). `Y` is a new
# variable; `varr_ref` is kept for motion-shift estimation later on.
Y = remove_background(varr_ref, method='tophat', wnd=10)
if in_memory:
    # persist using dask's multiprocessing scheduler for this step
    with ProgressBar(), dask.config.set(scheduler='processes'):
        Y = Y.persist()

### normalization

In [None]:
%%time
# Rescale the background-removed movie with minian's scale_varr.
Y = scale_varr(Y)
if in_memory:
    with ProgressBar():
        Y = Y.persist()

### visualization of pre-processing

In [None]:
%%output size=70
# Inspect the final pre-processed movie `Y`.
vaviewer = VArrayViewer([Y.rename('Y')], framerate=5)
display(vaviewer.widgets)
vaviewer.show()

In [None]:
%%time
# Save the pre-processed (not yet motion-corrected) movie as 'Y'.
with ProgressBar():
    save_variable(Y.rename('Y'), dpath, 'minian', meta_dict=meta_dict)

## motion correction

In [None]:
# Motion shifts are estimated on the denoised movie (before tophat background
# removal). This is an alias, not a copy.
varr_mc = varr_ref

### estimate shifts

In [None]:
%%time
# Estimate per-frame translations. on='mean' presumably registers each frame
# against the mean frame — TODO confirm in minian.motion_correction.
res = estimate_shift_fft(varr_mc, on='mean', pct_thres=99.9)
if in_memory:
    with ProgressBar():
        res = res.compute()
# `res` stacks the two shift components and a correlation score along its
# `variable` dimension; split them apart.
shifts = res.sel(variable = ['height', 'width'])
corr = res.sel(variable='corr')

### masking and interpolation

In [None]:
%%time
# Mask shifts of frames judged unreliable from `corr` (z_thres=-1.5, presumably
# a z-score cut-off). `mask` is consumed by interpolate_frame in the next cell.
# NOTE(review): `shifts_ma` is not used by any later cell; `shifts_final` is
# derived from the unmasked `shifts`. Confirm this is intended.
shifts_ma, mask = mask_shifts(varr_mc, corr, shifts, z_thres=-1.5)

In [None]:
%%time
# Interpolate the frames selected by `mask`. `.compute()` loads the whole
# movie into memory regardless of `in_memory`.
# NOTE(review): the resulting `varr_mc` is not read by any later cell (shifts
# are applied to `Y`); confirm whether this step is still needed.
varr_mc = interpolate_frame(varr_mc.compute().rename('varr_mc'), mask)

### determine shifts

#### take the cumulative sum if `on='perframe'` was used when estimating shifts

In [None]:
%%time
# Only meaningful when shifts were estimated with on='perframe' (frame-to-frame
# shifts must be accumulated). NOTE(review): shifts above use on='mean', and
# on Run All the next cell overwrites `shifts_final` anyway.
shifts_final = shifts.cumsum('frame')
shifts_final = np.around(shifts_final.fillna(0)).astype(int)

#### otherwise, use the raw shifts (run only one of these two cells)

In [None]:
# Round the raw shifts to whole pixels; missing (NaN) shifts become 0.
shifts_final = np.around(shifts.fillna(0)).astype(int)

### visualization of shifts

In [None]:
%%output size=100
%%opts Curve [width=500, tools=['hover']]
hv.NdOverlay(dict(width=hv.Curve(shifts.sel(variable='width')), height=hv.Curve(shifts.sel(variable='height'))))\
+ hv.NdOverlay(dict(width=hv.Curve(shifts_final.sel(variable='width')), height=hv.Curve(shifts_final.sel(variable='height'))))

### apply shifts

In [None]:
# Apply the integer shifts to the background-removed, normalized movie `Y`.
Y_mc = apply_shifts(Y, shifts_final)
if in_memory:
    with ProgressBar():
        Y_mc = Y_mc.persist()

### visualization of motion-correction

In [None]:
%%output size=100 fps=5
%%opts Image (cmap='Viridis')
# View the motion-corrected movie with a fixed colour range.
# NOTE(review): only 'Y_mc' is displayed, so the `Y=` range below targets a
# dimension that is not shown (perhaps `Y` was meant to be compared side by side).
vaviewer = VArrayViewer([Y_mc.rename('Y_mc')], framerate=5)
display(vaviewer.widgets)
vaviewer.show().redim.range(Y=(0,0.2), Y_mc=(0, 0.2))

### save result as DataSet

In [None]:
%%time
# Save the motion-corrected movie under the same name 'Y' as the pre-processed
# movie saved earlier (presumably replacing it — TODO confirm save_variable semantics).
with ProgressBar():
    save_variable(Y_mc.rename('Y'), dpath, 'minian', meta_dict=meta_dict)

## initialization

In [None]:
%%time
# Reopen the saved results and load the motion-corrected 'Y' fully into memory.
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'))
Y = minian['Y'].load()

In [None]:
%%time
# Generate candidate seeds (putative cell centres) with the 'rolling' method.
seeds = seeds_init(Y, method='rolling')

In [None]:
# Max projection over time (used as the backdrop for seed plots) and a view of
# `Y` with height/width stacked into a single 'spatial' dimension.
max_proj = Y.max('frame')
Y_flt = Y.stack(spatial=['height', 'width'])

In [None]:
# Overlay all candidate seeds (white, sized by 'seeds') on the max projection.
frame_size = dict(height=480, width=752)
opts_im = dict(plot=dict(frame_size), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(frame_size, size_index='seeds', tools=['hover']),
    style=dict(fill_alpha=0.6, line_alpha=0, fill_color='white'))
hv_max_proj = regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
hv_seeds = hv.Points(
    seeds, kdims=['width', 'height'],
    vdims=['index', 'seeds']).opts(**opts_pts)
hv_max_proj * hv_seeds

In [None]:
%%time
# Refine seeds with a GMM; the result carries a boolean 'mask_gmm' column.
seeds_gmm = gmm_refine(Y_flt, seeds)

In [None]:
# Seeds coloured by whether they passed the GMM refinement.
mask_col = 'mask_gmm'
frame_size = dict(height=480, width=752)
opts_im = dict(plot=dict(frame_size), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(frame_size, size_index='seeds', color_index=mask_col, tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
hv_max_proj = regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
hv_seeds = hv.Points(
    seeds_gmm, kdims=['width', 'height'],
    vdims=['index', 'seeds', mask_col]).opts(**opts_pts)
hv_max_proj * hv_seeds

In [None]:
%%time
# Peak-to-noise refinement, applied only to the seeds that passed the GMM step.
seeds_pnr = pnr_refine(Y_flt, seeds_gmm[seeds_gmm['mask_gmm']])

In [None]:
# Seeds coloured by whether they passed the peak-to-noise refinement.
mask_col = 'mask_pnr'
frame_size = dict(height=480, width=752)
opts_im = dict(plot=dict(frame_size), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(frame_size, size_index='seeds', color_index=mask_col, tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
hv_max_proj = regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
hv_seeds = hv.Points(
    seeds_pnr, kdims=['width', 'height'],
    vdims=['index', 'seeds', mask_col]).opts(**opts_pts)
hv_max_proj * hv_seeds

In [None]:
%%time
# Intensity refinement on the max projection, applied to the PNR-passing seeds.
seeds_int = intensity_refine(max_proj, seeds_pnr[seeds_pnr['mask_pnr']])

In [None]:
# Seeds coloured by whether they passed the intensity refinement.
mask_col = 'mask_int'
frame_size = dict(height=480, width=752)
opts_im = dict(plot=dict(frame_size), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(frame_size, size_index='seeds', color_index=mask_col, tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
hv_max_proj = regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
hv_seeds = hv.Points(
    seeds_int, kdims=['width', 'height'],
    vdims=['index', 'seeds', mask_col]).opts(**opts_pts)
hv_max_proj * hv_seeds

In [None]:
%%time
# KS-test refinement, applied to the intensity-passing seeds.
seeds_ks = ks_refine(Y_flt, seeds_int[seeds_int['mask_int']])

In [None]:
# Seeds coloured by whether they passed the KS refinement.
mask_col = 'mask_ks'
frame_size = dict(height=480, width=752)
opts_im = dict(plot=dict(frame_size), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(frame_size, size_index='seeds', color_index=mask_col, tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
hv_max_proj = regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
hv_seeds = hv.Points(
    seeds_ks, kdims=['width', 'height'],
    vdims=['index', 'seeds', mask_col]).opts(**opts_pts)
hv_max_proj * hv_seeds

In [None]:
%%time
seeds_fm = (seeds_gmm[seeds_gmm['mask_gmm']]
            .set_index(['height', 'width'])['seeds']
            .to_xarray().reindex_like(max_proj).fillna(0))
seeds_mrg = seeds_merge(Y, seeds_fm)

In [None]:
%%time
A, C, b, f = initialize(Y, seeds_mrg, chk=dict(height=200, width=200, frame=1000))

In [None]:
# Overview of the initialization: summed spatial footprints (left) and the
# initial temporal traces as a frame x unit_id image (right).
opts = dict(plot=dict(height=300, width=300))
regrid(hv.Image(A.sum('unit_id'), kdims=['width', 'height'])).opts(**opts) + regrid(hv.Image(C, kdims=['frame', 'unit_id'])).opts(**opts)

In [None]:
%%time
minian.close()
save_variable(A.rename('A_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(C.rename('C_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(b.rename('b_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(f.rename('f_init'), dpath, 'minian', meta_dict=meta_dict)

## CNMF

### loading data

In [None]:
%%time
# Reopen the saved results lazily with the configured chunking. The initial
# A/C were saved with their unit dimension named `unit_id_init` (see the save
# cell of the initialization section), so the chunk spec is renamed to match
# and the dimension is renamed back to `unit_id` after loading.
chk = chunks.copy()
chk['unit_id_init'] = chk.pop('unit_id')
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'), chunks=chk)
Y = minian['Y']
A_init = minian['A_init'].rename(unit_id_init='unit_id')
C_init = minian['C_init'].rename(unit_id_init='unit_id')
b_init = minian['b_init']
f_init = minian['f_init']

### estimate spatial noise

In [None]:
%%time
# Per-pixel power spectral density via Welch's method. Persisted
# unconditionally (not gated by `in_memory`), since later cells reuse it.
psd = psd_welch(Y)
with ProgressBar():
    psd = psd.persist()

In [None]:
%%opts Image [height=300, width=800, colorbar=True, logz=True] (cmap='Viridis')
# Show the PSD of every pixel: pixels are flattened to an integer 'spatial'
# index on the x axis, frequency on the y axis.
psd_flt = psd.stack(spatial=['height', 'width'])
hv_psd = hv.Image(psd_flt.assign_coords(spatial=range(psd_flt.sizes['spatial'])).rename('psd'), kdims=['spatial', 'freq'])
regrid(hv_psd).redim.range(psd=(0, 5e-3))

In [None]:
# Per-pixel noise level estimated from the PSD within `noise_range`
# (presumably normalized frequency — TODO confirm in minian.cnmf.get_noise).
sn_spatial = get_noise(psd, noise_range=(0.02, 0.5))

### randomly select units for parameter exploration

In [None]:
# Draw a reproducible sample of DISTINCT units for parameter exploration.
# np.random.choice samples with replacement by default, which could pick the
# same unit twice and duplicate `unit_id` labels in `A_init.sel(unit_id=units)`.
unit_ids = A_init.coords['unit_id'].values
units = np.random.RandomState(0).choice(
    unit_ids, size=min(10, unit_ids.size), replace=False)

### test parameters for spatial update

In [None]:
opts_A = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_C = dict(plot=dict(height=480, width=1600), style=dict(cmap='Viridis'))
sprs_ls = [0.1, 0.3, 0.5, 0.7]
A_dict = dict()
# For each candidate sparse penalty, run the spatial update on the sampled
# units and collect summed footprints, a count of units with A > 0 per pixel,
# and the traces into a 2-column layout.
for penalty in sprs_ls:
    A_test, b_test, C_test, f_test = update_spatial(
        Y, A_init.sel(unit_id=units), b_init, C_init.sel(unit_id=units),
        f_init, sn_spatial, dl_wnd=20, sparse_penal=penalty)
    try:
        panels = [
            hv.Image(A_test.sum('unit_id'), kdims=['width', 'height']).opts(**opts_A),
            hv.Image((A_test > 0).sum('unit_id'), kdims=['width', 'height']).opts(**opts_A),
            hv.Image(C_test, kdims=['frame', 'unit_id']).opts(**opts_C),
        ]
    except ValueError:
        # presumably raised when no unit survives this penalty — TODO confirm
        print("unable to find units with sparse penalty {}".format(penalty))
        continue
    A_dict[penalty] = (panels[0] + panels[1] + panels[2] + hv.Div('')).cols(2)
hv_res = hv.HoloMap(A_dict, kdims=['sparse_penalty'])

In [None]:
%%output size=60
%%opts Image [colorbar=True] {+axiswise}
# Browse the spatial-update results across sparse penalties.
hv_res.collate()

### first spatial update

In [None]:
%%time
# First spatial update on all units with the chosen sparse penalty, then
# apply minian's `normalize` to the resulting footprints.
A_spatial, b_spatial, C_spatial, f_spatial = update_spatial(
    Y, A_init, b_init, C_init, f_init, sn_spatial, sparse_penal=0.5)
A_spatial = xr.apply_ufunc(normalize, A_spatial)

In [None]:
%%output size=60
opts = dict(plot=dict(height=480, width=752, cmap='Viridis'))
(regrid(hv.Image(A_init.sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(**opts)
+ (hv.Image((A_init > 0).sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(**opts)
+ regrid(hv.Image(A_spatial.sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(**opts)
+ (hv.Image((A_spatial > 0).sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(**opts)).cols(2)

### test parameters for temporal update

In [None]:
%%time
import itertools as itt
p_ls = [2]
sprs_ls = [5, 20, 40]
add_ls = [20]
noise_ls = [0.02]
vis_dict = dict()
for cur_sprs, cur_p, cur_add, cur_noise in itt.product(sprs_ls, p_ls, add_ls, noise_ls):
    print("processing {}".format((cur_p, cur_sprs, cur_add, cur_noise)))
    YrA, cur_C, cur_S, cur_B, cur_C0, cur_sig, cur_g, = update_temporal(
        Y, A_spatial.isel(unit_id=slice(10, 20)), b_spatial, C_spatial.isel(unit_id=slice(10, 20)),
        f_spatial, sn_spatial, sparse_penal=cur_sprs, p=cur_p, use_spatial=False, use_smooth=True,
        add_lag = cur_add, noise_freq=cur_noise, chk=dict(frame=200, unit_id=20),
        cvx_sched="processes")
    vis_dict[(cur_p, cur_sprs, cur_add, cur_noise)] = visualize_temporal_update(
        YrA, cur_C, cur_S, cur_g, cur_sig)

In [None]:
%%opts Curve [width=800] {+framewise}
# Browse the temporal-update results across the parameter grid.
hv_res = hv.HoloMap(vis_dict, kdims=['p', 'sparse_penalty', 'add_lag', 'noise_freq']).collate()
hv_res

### first temporal update

In [None]:
%%time
# First temporal update on all units. The returned traces may cover fewer
# units than `A_spatial`, hence the re-selection of footprints below.
YrA, C_temporal, S_temporal, B_temporal, C0_temporal, sig_temporal, g_temporal = update_temporal(
        Y, A_spatial,
        b_spatial, C_spatial, f_spatial, sn_spatial, jac_thres=0.1,
        noise_freq=0.02, sparse_penal=40, p=2, add_lag=20, use_spatial=False, chk=dict(frame=2000, unit_id=200))
# keep only the footprints of units present in the updated traces
A_temporal = A_spatial.sel(unit_id = C_temporal.coords['unit_id'])

In [None]:
%%output size=60
%%opts Image [colorbar=True] (cmap='Viridis')
hv_c = regrid(hv.Image(C_temporal.rename('c'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(c=(0, 1))
hv_s = regrid(hv.Image(S_temporal.rename('s'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(s=(0, 0.006))
(hv_c + hv_s).cols(1)

In [None]:
%%opts Curve [width=1200] {+framewise}
# Per-unit trace view (normalized), restricted to unit_id values in slice(0, 50).
visualize_temporal_update(YrA, C_temporal, S_temporal, g_temporal, sig_temporal, norm=True).select(unit_id = slice(0, 50))

### merge units

In [None]:
%%time
A_mrg, C_mrg = unit_merge(A_spatial, C_temporal, thres_corr=0.9)

In [None]:
%%opts Image [height=400, width=800]
# Temporal traces before (left) and after (right) merging units.
regrid(hv.Image(C_temporal, kdims=['frame', 'unit_id'])) +\
regrid(hv.Image(C_mrg, kdims=['frame', 'unit_id']))

### randomly select units for parameter exploration

In [None]:
# Draw a reproducible sample of DISTINCT merged units for parameter exploration.
# Sampling without replacement avoids duplicated `unit_id` labels in
# `A_mrg.sel(unit_id=units)`.
unit_ids = A_mrg.coords['unit_id'].values
units = np.random.RandomState(0).choice(
    unit_ids, size=min(10, unit_ids.size), replace=False)

### test parameters for spatial update

In [None]:
opts_A = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_C = dict(plot=dict(height=480, width=1600), style=dict(cmap='Viridis'))
sprs_ls = [0.1, 0.3, 0.5, 0.7]
A_dict = dict()
# Explore the sparse penalty for the second spatial update on the sampled
# merged units. Use the background terms from the first spatial update
# (`b_spatial`, `f_spatial`), the same inputs the second spatial update
# receives, instead of the stale initial `b_init`/`f_init`.
# NOTE(review): the second spatial update uses dl_wnd=5 and sparse_penal=1,
# which differ from the values explored here.
for cur_sprs in sprs_ls:
    cur_A, cur_b, cur_C, cur_f = update_spatial(
        Y, A_mrg.sel(unit_id=units),
        b_spatial, C_mrg.sel(unit_id=units), f_spatial, sn_spatial, dl_wnd=20, sparse_penal=cur_sprs)
    try:
        hv_cur_A = hv.Image(cur_A.sum('unit_id'), kdims=['width', 'height']).opts(**opts_A)
        hv_cur_A_sps = hv.Image((cur_A > 0).sum('unit_id'), kdims=['width', 'height']).opts(**opts_A)
        hv_cur_C = hv.Image(cur_C, kdims=['frame', 'unit_id']).opts(**opts_C)
    except ValueError:
        # presumably raised when no unit survives this penalty — TODO confirm
        print("unable to find units with sparse penalty {}".format(cur_sprs))
        continue
    A_dict[cur_sprs] = (hv_cur_A + hv_cur_A_sps + hv_cur_C + hv.Div('')).cols(2)
hv_res = hv.HoloMap(A_dict, kdims=['sparse_penalty'])

In [None]:
%%output size=60
%%opts Image [colorbar=True] {+axiswise}
# Browse the second-round spatial-update results across sparse penalties.
hv_res.collate()

### second spatial update

In [None]:
%%time
# Second spatial update starting from the merged units and the first-round
# background terms, then apply minian's `normalize` to the footprints.
A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2 = update_spatial(
    Y, A_mrg, b_spatial, C_mrg, f_spatial, sn_spatial, sparse_penal=1, dl_wnd=5)
A_spatial_it2 = xr.apply_ufunc(normalize, A_spatial_it2)

In [None]:
%%output size=60
%%opts Image [colorbar=True] (cmap='Viridis')
(regrid(hv.Image(A_spatial.sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))
+ (hv.Image((A_spatial > 0).sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))
+ regrid(hv.Image(A_spatial_it2.sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))
+ (hv.Image((A_spatial_it2 > 0).sum('unit_id').rename('A'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))).cols(2)

### second temporal update

In [None]:
%%time
# Second temporal update. NOTE(review): unlike the first pass, `use_spatial`
# is not set here; confirm the default is intended.
YrA, C_temporal_it2, S_temporal_it2, B_temporal_it2, C0_temporal_it2, sig_temporal_it2, g_temporal_it2 = update_temporal(
    Y, A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2, sn_spatial, jac_thres=0.1,
    noise_freq=0.03, sparse_penal=10, p=2, add_lag=20, chk=dict(frame=2000, unit_id=200))
# keep only the footprints of units present in the updated traces
A_temporal_it2 = A_spatial_it2.sel(unit_id=C_temporal_it2.coords['unit_id'])

In [None]:
%%opts Image [colorbar=True, tools=['hover']] (cmap='Viridis')
hv_c = regrid(hv.Image(C_temporal_it2.rename('c'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(c=(0, 1))
hv_s = regrid(hv.Image(S_temporal_it2.rename('s'), kdims=['frame', 'unit_id'])).opts(plot=dict(height=500, width=1000)).redim.range(s=(0, 0.006))
(hv_c + hv_s).cols(1)

In [None]:
%%opts Curve [width=1200] {+framewise}
# Per-unit trace view restricted to unit_id values in slice(0, 50).
# NOTE(review): the first-pass view passes norm=True; this one does not.
visualize_temporal_update(
    YrA, C_temporal_it2, S_temporal_it2, g_temporal_it2, sig_temporal_it2).select(unit_id=slice(0, 50))

### save results

In [None]:
%%time
minian.close()
save_variable(A_temporal_it2.rename('A'), dpath, 'minian', meta_dict=meta_dict)
save_variable(C_temporal_it2.rename('C'), dpath, 'minian', meta_dict=meta_dict)
save_variable(S_temporal_it2.rename('S'), dpath, 'minian', meta_dict=meta_dict)
save_variable(g_temporal_it2.rename('g'), dpath, 'minian', meta_dict=meta_dict)
save_variable(b_spatial_it2.rename('b'), dpath, 'minian', meta_dict=meta_dict)
save_variable(f_spatial_it2.rename('f'), dpath, 'minian', meta_dict=meta_dict)

### visualization

In [None]:
# Reopen minian.nc, which now holds the final results, for visualization.
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'))

In [None]:
%%time
# Render a summary video of the results to `dpath`/minian.mp4, processing with
# smaller spatial chunks than the CNMF stage.
generate_videos(minian, os.path.join(dpath, "minian.mp4"), chk=dict(height=100, width=100, frame=1000))

In [None]:
# Interactive viewer of the CNMF results over the saved movie 'Y'.
cnmfviewer = CNMFViewer(minian, minian['Y'])

In [None]:
# Display the CNMF viewer.
cnmfviewer.show()