# MiniAn Pipeline
## Setting up
### Set module path and data path

In [None]:
# Location of the minian package; appended to sys.path in the "load modules" cell.
minian_path = "."
# Directory holding the input videos; converted to an absolute path during module initialization.
dpath = "./demo_movies/"
# Metadata passed to every save_variable call.
# NOTE(review): the values look like negative indices into the data path's components -- confirm in save_variable.
meta_dict = dict(session_id=-1, session=-2, animal=-3)
# Chunk sizes per dimension, used when rechunking the movie and when reopening the saved dataset.
chunks = dict(frame=50, height=80, width=80, unit_id=50)

### load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append(minian_path)
import gc
import psutil
import numpy as np
import xarray as xr
import holoviews as hv
import paramnb
import matplotlib.pyplot as plt
import bokeh.plotting as bpl
import dask.array as da
import pandas as pd
import dask
import datashader as ds
from holoviews.operation.datashader import datashade, regrid, dynspread
from datashader.colors import Sets1to3
from dask.diagnostics import ProgressBar
from IPython.core.display import display, HTML
from dask.distributed import Client, progress
from minian.utilities import load_videos, varray_to_tif, save_cnmf, save_movies, scale_varr, save_variable
from minian.preprocessing import remove_brightspot, gradient_norm, denoise, remove_background
from minian.motion_correction import estimate_shift_fft, apply_shifts, interpolate_frame
from minian.initialization import seeds_init, gmm_refine, pnr_refine, intensity_refine, ks_refine, seeds_merge, initialize
from minian.cnmf import get_noise_fft, get_noise_welch, update_spatial, update_temporal, unit_merge
from minian.visualization import VArrayViewer, MCViewer, CNMFViewer

### module initialization

In [None]:
# Resolve the data directory to an absolute path; later cells pass it to the save/load helpers.
dpath = os.path.abspath(dpath)
# Load HoloViews' Bokeh plotting backend for the interactive figures below.
hv.notebook_extension('bokeh', width=100)

## Pre-processing
### loading videos and visualization

In [None]:
%%time
varr = load_videos(dpath)
varr_ref = scale_varr(varr.astype(float), (0,1), inplace=False).chunk(dict(frame=chunks['frame'], height=chunks['height'], width=chunks['width']))

### Bright spot removal

In [None]:
%%time
with ProgressBar():
    # NOTE(review): this rebinds varr_ref, so re-running the cell applies remove_brightspot a
    # second time; re-run from the loading cell instead. .persist() keeps the result in memory.
    varr_ref = remove_brightspot(varr_ref).persist()

### estimate gradient

In [None]:
%%time
with ProgressBar():
    # Evaluate gradient_norm on the first frame only; its quantile sets kappa for the diffusion step.
    varr_gradient = gradient_norm(varr_ref.isel(frame=0)).compute()

In [None]:
# Use the 0.9 quantile of the first-frame gradient as the kappa parameter of anisotropic diffusion.
kappa = varr_gradient.quantile(0.9).values

### anisotropic diffusion

In [None]:
%%time
# NOTE(review): the single-threaded scheduler only applies to computation triggered inside this
# `with` block; if denoise() returns a lazy dask array, the actual work runs later (at the
# .persist() in the normalization cell) under the default scheduler -- confirm.
with dask.config.set(scheduler='single-threaded'):
    varr_ref_anisotropic = denoise(varr_ref, 'anisotropic', niter=10, kappa=kappa, gamma=0.25, option=2)

### background removal

In [None]:
%%time
# Remove background with the 'tophat' method; wnd=10 is presumably the window size in pixels -- confirm.
varr_ref_tophat = remove_background(varr_ref_anisotropic, method='tophat', wnd=10)

### normalization

In [None]:
%%time
# Scale the background-removed movie to the (0, 1) range and keep it in memory. This rebinds
# varr_ref, which the viewer and motion-correction cells below consume.
varr_ref = scale_varr(varr_ref_tophat, (0, 1)).persist()

### visualization of pre-processing

In [None]:
%%output size=70
# Interactive viewer comparing the raw videos (varr) with the pre-processed movie (varr_ref);
# its widgets are displayed before the plot itself.
vaviewer = VArrayViewer([varr, varr_ref], framerate=5)
display(vaviewer.widgets)
vaviewer.show()

## Motion correction
### estimate shifts

In [None]:
%%time
# Estimate shifts on the pre-processed movie. `mask` is reused by interpolate_frame below;
# `corr` is not used later in this notebook.
with dask.config.set(scheduler='single-threaded'):
    shifts, corr, mask = estimate_shift_fft(varr_ref, z_thres=None, on='perframe')

### apply shifts

In [None]:
%%time
# Load varr_ref into memory, then apply the estimated shifts; shifts_final is plotted in the
# shift-visualization cell below.
varr_mc, shifts_final = apply_shifts(varr_ref.load(), shifts, aggregate=True)

### interpolation

In [None]:
%%time
# Interpolate frames of the shifted movie using the mask returned by estimate_shift_fft.
varr_mc_int = interpolate_frame(varr_mc, mask)

### visualization of motion-correction

In [None]:
%%output size=70 fps=5
# Interactive viewer comparing the pre-processed (varr_ref) and motion-corrected,
# interpolated (varr_mc_int) movies.
vaviewer = VArrayViewer([varr_ref, varr_mc_int], framerate=5)
display(vaviewer.widgets)
vaviewer.show()

### visualization of shifts

In [None]:
%%output size=100
%%opts Curve [width=1500, tools=['hover']]
hv.NdOverlay(dict(width=hv.Curve(shifts_final.sel(shift_dim='width')), height=hv.Curve(shifts_final.sel(shift_dim='height'))))

### Save the motion-corrected movie to the dataset

In [None]:
%%time
# Save the motion-corrected movie under the name 'Y'; the CNMF section reopens it from minian.nc.
save_variable(varr_mc_int.rename('Y'), dpath, 'minian', meta_dict=meta_dict)

## Initialization

In [None]:
%%time
# Generate seed candidates from the motion-corrected movie using the 'rolling' method.
seeds = seeds_init(varr_mc_int, method='rolling')

In [None]:
%%time
# Refinement chain: each cell takes the previous cell's seeds and returns a refined set.
seeds_gmm = gmm_refine(varr_mc_int, seeds)

In [None]:
%%time
# Next refinement step, applied to the gmm_refine output.
seeds_pnr = pnr_refine(varr_mc_int, seeds_gmm)

In [None]:
%%time
# Next refinement step, applied to the pnr_refine output.
seeds_int = intensity_refine(varr_mc_int, seeds_pnr)

In [None]:
%%time
# Final refinement step, applied to the intensity_refine output.
seeds_ks = ks_refine(varr_mc_int, seeds_int)

In [None]:
%%time
# Merge the refined seeds before building the initial estimates.
seeds_mrg = seeds_merge(varr_mc_int, seeds_ks)

In [None]:
%%time
# Build the initial A, C, b, f from the merged seeds; the next cell saves them as *_init.
A, C, b, f = initialize(varr_mc_int, seeds_mrg)

In [None]:
%%time
save_variable(A.rename('A_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(C.rename('C_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(b.rename('b_init'), dpath, 'minian', meta_dict=meta_dict)
save_variable(f.rename('f_init'), dpath, 'minian', meta_dict=meta_dict)

## CNMF
### loading data

In [None]:
# Reopen the saved dataset with dask chunking; the initial A/C were saved with a 'unit_id_init' dim.
# NOTE(review): `autoclose` is deprecated in newer xarray releases (and later removed) -- drop it
# if open_dataset rejects the argument.
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'), chunks=dict(height=chunks['height'], width=chunks['width'], unit_id_init=chunks['unit_id'], frame=chunks['frame']), autoclose=True)
Y = minian['Y']
# Rename back to 'unit_id', the dimension name these arrays had before saving.
A_init = minian['A_init'].rename(unit_id_init='unit_id')
C_init = minian['C_init'].rename(unit_id_init='unit_id')
b_init = minian['b_init']
f_init = minian['f_init']

### estimate spatial noise

In [None]:
%%time
# sn_spatial feeds every spatial and temporal update below; psd is not used later in this notebook.
sn_spatial, psd = get_noise_fft(Y)

### first spatial update

In [None]:
%%time
# First spatial update from the initial estimates, run under dask's process-based scheduler.
with dask.config.set(scheduler='processes'):
    A_spatial, b_spatial, C_spatial, f_spatial = update_spatial(Y, A_init, b_init, C_init, f_init, sn_spatial, sparse_penal=0.1)

In [None]:
# Footprints summed over all units: A_init next to A_spatial (after the first spatial update).
regrid(hv.Image(A_init.sum('unit_id'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))\
+ regrid(hv.Image(A_spatial.sum('unit_id'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))

### first temporal update

In [None]:
%%time
# First temporal update, run under dask's threaded scheduler.
# NOTE(review): this call passes use_spatial=False while the second temporal update does not -- confirm intended.
with dask.config.set(scheduler='threads'):
    YrA, C_temporal, S_temporal, B_temporal, C0_temporal, g_temporal = update_temporal(Y, A_spatial, b_spatial, C_spatial, f_spatial, sn_spatial, noise_freq=0.05, sparse_penal=1, use_spatial=False)

In [None]:
from scipy import linalg
def construct_G(g, T):
    """Build the T x T lower-triangular Toeplitz matrix for AR coefficients ``g``.

    The first column is ``[1, -g[0], ..., -g[p-1], 0, ...]`` and the first row is
    ``[1, 0, ..., 0]``, so ``(G @ c)[t] == c[t] - sum_i g[i] * c[t - 1 - i]``.

    Parameters
    ----------
    g : array-like of float
        Coefficients, one per lag. Lists are accepted as well as arrays.
    T : int
        Size of the square output matrix.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(T, T)``.
    """
    g = np.asarray(g, dtype=float).ravel()
    cur_c, cur_r = np.zeros(T), np.zeros(T)
    cur_c[0] = 1
    cur_r[0] = 1
    # Lags beyond T - 1 have no place in a T x T matrix; truncate them instead of
    # failing with a shape mismatch in the slice assignment.
    n_lag = min(len(g), T - 1)
    cur_c[1:n_lag + 1] = -g[:n_lag]
    return linalg.toeplitz(cur_c, cur_r)

def normalize(a):
    """Linearly rescale ``a`` so that its minimum maps to 0 and its maximum maps to 1.

    NOTE(review): for a constant input the interpolation range (lo, hi) is degenerate,
    so the output is not a meaningful normalization.
    """
    lo = a.min()
    hi = a.max()
    return np.interp(a, (lo, hi), (0, +1))

def convolve_G(s, g):
    """Convolve a spike train ``s`` with the kernel implied by coefficients ``g``.

    Solves ``G @ c = s`` for ``c``, where ``G = construct_G(g, len(s))``.

    Parameters
    ----------
    s : numpy.ndarray
        1-D spike train.
    g : array-like of float
        Coefficients passed to ``construct_G``.

    Returns
    -------
    numpy.ndarray
        The solution ``c`` (same length as ``s``), or a copy of ``s`` if the linear
        system cannot be solved.
    """
    G = construct_G(g, len(s))
    try:
        # construct_G only sets the first entry of the first row, so G is lower-triangular;
        # a triangular solve avoids forming inv(G) explicitly (cheaper and more accurate).
        c = linalg.solve_triangular(G, s, lower=True)
    except linalg.LinAlgError:
        # The original handler named `LinAlgError` without importing it, so this fallback
        # could never run; scipy.linalg re-exports the exception.
        c = s.copy()
    return c

def construct_pulse_response(g):
    """Simulate the response to a single spike under coefficients ``g``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(s, c)``: a length-500 spike train whose only non-zero entry is a 1 at
        index 10, and the result of ``convolve_G`` on that spike train.
    """
    spike = np.zeros(500)
    spike[10] = 1
    return spike, convolve_G(spike, g)

def visualize_temporal_update(YA, C, S, g):
    """Plot per-unit normalized traces above each unit's simulated single-spike response.

    Parameters
    ----------
    YA : xarray.DataArray
        Traces with 'unit_id' and 'frame' dims (the calling cells pass the first output
        of update_temporal). Shown with the label 'YA'.
    C : xarray.DataArray
        Calcium traces with 'unit_id' and 'frame' dims.
    S : xarray.DataArray
        Spike traces with 'unit_id' and 'frame' dims.
    g : xarray.DataArray
        Coefficients with a 'lag' dim, one set per unit.

    Returns
    -------
    holoviews.Layout
        One column. The top panel overlays the [0, 1]-normalized C, S and YA curves; the
        bottom panel overlays the simulated spike and calcium response from construct_pulse_response.
    """
    # Must match the pulse length hard-coded in construct_pulse_response.
    n_pulse = 500

    def _normalize_frames(arr):
        # Rescale every unit's trace to [0, 1] independently along 'frame'.
        return xr.apply_ufunc(normalize, arr, input_core_dims=[['frame']], output_core_dims=[['frame']], vectorize=True, dask='parallelized', output_dtypes=[arr.dtype])

    C_norm = _normalize_frames(C)
    S_norm = _normalize_frames(S)
    YA_norm = _normalize_frames(YA.compute())
    # g is computed up front, so the ufunc runs eagerly and needs no dask output-size hints
    # (the previous `output_sizes=dict(t=500)` named a dimension that does not exist).
    s_pul, c_pul = xr.apply_ufunc(construct_pulse_response, g.compute(), input_core_dims=[['lag']], output_core_dims=[['frame'], ['frame']], vectorize=True)
    s_pul = s_pul.assign_coords(frame=np.arange(n_pulse))
    c_pul = c_pul.assign_coords(frame=np.arange(n_pulse))
    hv_s_pul = hv.Dataset(s_pul.rename('s_pul'), kdims=['unit_id', 'frame'])
    hv_c_pul = hv.Dataset(c_pul.rename('c_pul'), kdims=['unit_id', 'frame'])
    with ProgressBar():
        hv_C = hv.Dataset(C_norm.compute().rename('Calcium trace'), kdims=['unit_id', 'frame'])
        hv_S = hv.Dataset(S_norm.compute().rename('Spike'), kdims=['unit_id', 'frame'])
        hv_YA = hv.Dataset(YA_norm.compute().rename('Raw'), kdims=['unit_id', 'frame'])
    traces = (
        hv_C.to(hv.Curve, kdims=['frame'], label='Calcium trace')
        * hv_S.to(hv.Curve, kdims=['frame'], label='Spike')
        * hv_YA.to(hv.Curve, kdims=['frame'], label='YA')
    )
    pulse = (
        hv_c_pul.to(hv.Curve, kdims=['frame'], label='Simulated Calcium')
        * hv_s_pul.to(hv.Curve, kdims=['frame'], label='Simulated Spike')
    )
    return (traces + pulse).cols(1)

In [None]:
%%opts Curve [width=1200] {+framewise}
# Inspect the first temporal update: normalized traces plus the simulated pulse response per unit.
visualize_temporal_update(YrA, C_temporal, S_temporal, g_temporal)

### merge units

In [None]:
# Merge units using a correlation threshold of 0.85 (thres_corr); the merged A/C seed the second iteration.
A_mrg, C_mrg = unit_merge(A_spatial, C_temporal, thres_corr=0.85)

In [None]:
%%output size=80
# Per-unit footprints before (A_spatial) and after (A_merged) merging; one image per unit_id.
regrid(hv.Dataset(A_spatial.rename('A_spatial'), kdims=['height', 'width', 'unit_id']).to(hv.Image, kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))\
+ regrid(hv.Dataset(A_mrg.rename('A_merged'), kdims=['height', 'width', 'unit_id']).to(hv.Image, kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))

### second spatial update

In [None]:
%%time
# Second spatial update, starting from the merged A/C and the first-iteration b_spatial/f_spatial.
with dask.config.set(scheduler='processes'):
    A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2 = update_spatial(Y, A_mrg, b_spatial, C_mrg, f_spatial, sn_spatial, sparse_penal=0.1)

In [None]:
# Footprints summed over all units: first-iteration A_spatial next to A_spatial_it2.
regrid(hv.Image(A_spatial.sum('unit_id'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))\
+ regrid(hv.Image(A_spatial_it2.sum('unit_id'), kdims=['width', 'height'])).opts(plot=dict(height=480, width=752))

### second temporal update

In [None]:
%%time
YrA, C_temporal_it2, S_temporal_it2, B_temporal_it2, C0_temporal_it2, g_temporal_it2 = update_temporal(Y, A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2, sn_spatial, noise_freq=0.05, sparse_penal=1)

In [None]:
%%opts Curve [width=1200] {+framewise}
visualize_temporal_update(YrA, C_temporal_it2, S_temporal_it2, g_temporal_it2)

### save results

In [None]:
%%time
save_variable(A_spatial_it2.rename('A'), dpath, 'minian', meta_dict=meta_dict)
save_variable(C_temporal_it2.rename('C'), dpath, 'minian', meta_dict=meta_dict)
save_variable(S_temporal_it2.rename('S'), dpath, 'minian', meta_dict=meta_dict)
save_variable(g_temporal_it2.rename('g'), dpath, 'minian', meta_dict=meta_dict)
save_variable(b_spatial_it2.rename('b'), dpath, 'minian', meta_dict=meta_dict)
save_variable(f_spatial_it2.rename('f'), dpath, 'minian', meta_dict=meta_dict)

### visualization

In [None]:
# Reopen the dataset including the final results, chunking the 'unit_id' dim they were saved with.
# NOTE(review): `autoclose` is deprecated in newer xarray releases (and later removed) -- drop it
# if open_dataset rejects the argument.
minian = xr.open_dataset(os.path.join(dpath, 'minian.nc'), chunks=dict(height=chunks['height'], width=chunks['width'], unit_id=chunks['unit_id'], frame=chunks['frame']), autoclose=True)

In [None]:
# Build the CNMF result viewer from the reopened dataset and its 'Y' movie.
cnmfviewer = CNMFViewer(minian, minian['Y'])

In [None]:
# Render the CNMF viewer.
cnmfviewer.show()