# Setting up

## set module paths and data path

In [None]:
# Location of the minian package and of the directory holding the videos.
minian_path, dpath = ".", "./demo_movies"
# Metadata passed to every save_minian call below (values presumably index
# components of dpath — TODO confirm against minian.utilities.save_minian).
meta_dict = dict(session_id=-1, session=-2, animal=-3)
# Chunk size per dimension used when rechunking/loading arrays.
chunks = dict(frame=1000, height=50, width=50, unit_id=100)
# When True, intermediate results are persisted/computed in memory as the
# pipeline runs; otherwise they stay lazy.
in_memory = False

## load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append(minian_path)
import gc
import psutil
import numpy as np
import xarray as xr
import holoviews as hv
import paramnb
import matplotlib.pyplot as plt
import bokeh.plotting as bpl
import dask.array as da
import pandas as pd
import dask
import datashader as ds
import lantern as lant
from holoviews.operation.datashader import datashade, regrid, dynspread
from datashader.colors import Sets1to3
from dask.diagnostics import ProgressBar
from IPython.core.display import display, HTML
from dask.distributed import Client, progress, LocalCluster, fire_and_forget
from minian.utilities import load_videos, scale_varr, scale_varr_da, save_variable, open_minian, save_minian
from minian.preprocessing import remove_brightspot, gradient_norm, denoise, remove_background, stripe_correction
from minian.motion_correction import estimate_shift_fft, apply_shifts, interpolate_frame, mask_shifts
from minian.initialization import seeds_init, gmm_refine, pnr_refine, intensity_refine, ks_refine, seeds_merge, initialize
from minian.cnmf import psd_welch, psd_fft, get_noise, update_spatial, update_temporal, unit_merge, smooth_sig
from minian.visualization import VArrayViewer, CNMFViewer, generate_videos, visualize_spatial_update, visualize_temporal_update

## module initialization

In [None]:
# Make the data path absolute so later loads/saves do not depend on the cwd.
dpath = os.path.abspath(dpath)
# Enable holoviews notebook rendering with the bokeh backend.
hv.notebook_extension('bokeh', width=100)
# Presumably displays lantern's variable-inspector widget.
lant.VariableInspector()

# Pre-processing

## loading videos and visualization

In [None]:
%%time
# Load the videos under dpath as float32; downsample=dict(frame=2) requests
# downsampling along the frame axis (in_memory forwarded — see load_videos).
varr = load_videos(dpath, in_memory=in_memory, dtype=np.float32, downsample=dict(frame=2))

In [None]:
# Optionally keep the raw movie in memory.
if in_memory:
    varr = varr.persist()

In [None]:
%%output size=60
%%opts Image (cmap='Viridis')
# Interactive viewer for the raw movie.
vaviewer = VArrayViewer(varr, framerate=5, compute=True)
display(vaviewer.widgets)
vaviewer.show()

## subset part of video

In [None]:
varr_ref = varr
# varr_ref = varr.sel(frame=slice(None, 10000))

In [None]:
varr_ref = varr_ref.chunk(dict(frame=int(chunks['frame']/10), height=-1, width=-1))

## stripe correction

In [None]:
%%time
# Apply minian's stripe correction to the reference movie.
varr_ref = stripe_correction(varr_ref)
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

## glow removal (uniform background subtraction)

In [None]:
# Remove the large-scale "glow" with a uniform-filter background subtraction
# (wnd=51 is presumably the filter window in pixels — confirm in remove_background).
varr_ref = remove_background(varr_ref, method='uniform', wnd=51)
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

## bright spot removal

In [None]:
%%time
# Remove bright spots; thres=2 is forwarded to remove_brightspot.
varr_ref = remove_brightspot(varr_ref, thres=2)
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

## denoise

In [None]:
%%time
# Gaussian denoising with a 3x3 kernel (sigmaX=0 presumably lets OpenCV derive
# sigma from ksize, as in cv2.GaussianBlur — confirm in minian.preprocessing.denoise).
varr_ref = denoise(varr_ref, 'gaussian', sigmaX=0, ksize=(3, 3))
if in_memory:
    with ProgressBar():
        varr_ref = varr_ref.persist()

# motion correction

## estimate shifts

In [None]:
%%time
# Estimate frame shifts with on='first' (presumably relative to the first
# frame). Because on='first', the "raw shifts" cell below applies, not the
# cumulative-sum one.
res = estimate_shift_fft(varr_ref, on='first', pct_thres=99.9)
if in_memory:
    with ProgressBar():
        res = res.compute()
# `res` stacks results along a `variable` dimension: height/width shifts and
# a correlation score.
shifts = res.sel(variable = ['height', 'width'])
corr = res.sel(variable='corr')

## visualization of shifts

In [None]:
%%output size=100
%%opts Curve [width=500, tools=['hover']]
hv.NdOverlay(dict(width=hv.Curve(shifts.sel(variable='width')), height=hv.Curve(shifts.sel(variable='height'))))

## masking and interpolation

In [None]:
%%time
# Build a frame mask from the correlation scores (z_thres=-1.5 is forwarded
# to mask_shifts).
# NOTE(review): shifts_ma is never used afterwards — the shift cells below use
# the unmasked `shifts`; confirm whether shifts_ma was intended there.
shifts_ma, mask = mask_shifts(varr_ref, corr, shifts, z_thres=-1.5)

In [None]:
%%time
# Presumably replaces the masked frames by interpolation.
# NOTE(review): .compute() loads the whole movie into memory regardless of
# the in_memory setting.
varr_ref = interpolate_frame(varr_ref.compute().rename('varr_mc'), mask)

## determine shifts

### take cumulative sum if `on='perframe'` when estimating shifts

In [None]:
%%time
shifts_final = shifts.cumsum('frame')
shifts_final = xr.apply_ufunc(da.around, shifts_final.fillna(0), dask='allowed').astype(int)

### use raw shifts otherwise

In [None]:
# Shifts estimated with on='first' are used as-is, rounded to whole pixels.
# DataArray.round works for both numpy- and dask-backed data.
shifts_final = shifts.fillna(0).round().astype(int)

## visualization of final shifts

In [None]:
%%output size=100
%%opts Curve [width=500, tools=['hover']]
hv.NdOverlay(dict(width=hv.Curve(shifts_final.sel(variable='width')), height=hv.Curve(shifts_final.sel(variable='height'))))

## apply shifts

In [None]:
# Apply the integer shifts to the reference movie.
varr_mc = apply_shifts(varr_ref, shifts_final)
# Fill missing values (presumably the borders exposed by shifting) from the
# nearest valid pixel along height, then along width.
varr_mc = varr_mc.ffill('height').bfill('height').ffill('width').bfill('width')
if in_memory:
    with ProgressBar():
        varr_mc = varr_mc.persist()

## visualization of motion-correction

In [None]:
%%output size=60 fps=5
%%opts Image (cmap='Viridis')
# Interactive viewer for the motion-corrected movie.
vaviewer = VArrayViewer(varr_mc.rename('varr_mc'), framerate=5)
display(vaviewer.widgets)
vaviewer.show()

## save result as DataSet

In [None]:
%%time
# Save the motion-corrected movie as variable 'org' in the 'minian' zarr
# dataset under dpath.
with ProgressBar():
    save_minian(varr_mc.rename('org'), dpath, 'minian', meta_dict=meta_dict, backend='zarr')

# background removal

## load in from disk

In [None]:
# Reload the motion-corrected movie from the saved dataset.
varr_mc = open_minian(dpath, 'minian', backend='zarr')['org']

## background removal (tophat filter)

In [None]:
%%time
# Remove remaining background with a tophat filter (wnd=10).
Y = remove_background(varr_mc, method='tophat', wnd=10)
if in_memory:
    # Persist using dask's multiprocessing scheduler.
    with ProgressBar(), dask.config.set(scheduler='processes'):
        Y = Y.persist()

## normalization

In [None]:
%%time
# Rescale intensities with scale_varr (target range not visible here — see
# minian.utilities.scale_varr).
Y = scale_varr(Y)
if in_memory:
    with ProgressBar():
        Y = Y.persist()

## visualization of background removal

In [None]:
%%output size=60
%%opts Image (cmap='Viridis')
# Interactive viewer for the background-removed, rescaled movie.
vaviewer = VArrayViewer(Y.rename('Y'), framerate=5)
display(vaviewer.widgets)
vaviewer.show()

In [None]:
%%time
# Save the processed movie as variable 'Y' using dask's multiprocessing scheduler.
with ProgressBar(), dask.config.set(scheduler='processes'):
    save_minian(Y.rename('Y'), dpath, 'minian', meta_dict=meta_dict, backend='zarr')

# initialization

In [None]:
%%time
# Reopen the dataset; it stays open until closed before the next save.
minian = open_minian(dpath, 'minian', backend='zarr')

In [None]:
Y = minian['Y']

## generating over-complete set of seeds

In [None]:
%%time
# Generate an over-complete set of candidate seeds.
seeds = seeds_init(Y, method='rolling')

In [None]:
%%time
# Maximum projection over frames; background image for all seed plots.
with ProgressBar():
    max_proj = Y.max('frame').compute()

In [None]:
# Flatten height/width into one `spatial` dimension for per-pixel access.
Y_flt = Y.stack(spatial=['height', 'width'])

In [None]:
# Candidate seeds over the max projection; marker size follows `seeds`.
opts_im = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(height=480, width=752, size_index='seeds', tools=['hover']),
    style=dict(fill_alpha=0.6, line_alpha=0, fill_color='white'))
(regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
 * hv.Points(seeds, kdims=['width', 'height'], vdims=['index', 'seeds']).opts(**opts_pts))

## Gaussian mixture model (GMM) refinement

In [None]:
%%time
# GMM refinement; adds a boolean `mask_gmm` column to the seeds table.
seeds_gmm = gmm_refine(Y_flt, seeds)

In [None]:
# Seeds colored by whether they pass the GMM refinement.
opts_im = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(height=480, width=752, size_index='seeds', color_index='mask_gmm', tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
(regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
 * hv.Points(seeds_gmm, kdims=['width', 'height'], vdims=['index', 'seeds', 'mask_gmm']).opts(**opts_pts))

## peak-to-noise ratio (PNR) refinement

In [None]:
%%time
noise_freq_list = [0.005, 0.01, 0.02, 0.06, 0.2, 0.3, 0.45]
example_seeds = seeds_gmm[seeds_gmm['mask_gmm']].sample(12, axis='rows')
example_trace = (Y_flt
                 .sel(spatial=[tuple(hw) for hw in example_seeds[['height', 'width']].values])
                 .assign_coords(spatial=np.arange(12)))
smooth_dict = dict()
for freq in noise_freq_list:
    trace_smth = smooth_sig(example_trace, freq)
    with ProgressBar():
        trace_smth = trace_smth.compute()
    hv_trace = (hv.Dataset(trace_smth, kdims=['spatial', 'frame'])
                .to(hv.Curve, kdims=['frame']).layout('spatial'))
    smooth_dict[freq] = hv_trace

In [None]:
%%output size=80
hv.HoloMap(smooth_dict, kdims=['noise_freq']).collate()

In [None]:
seeds_pnr, pnr = pnr_refine(Y_flt, seeds_gmm[seeds_gmm['mask_gmm']], noise_freq=0.06, thres=.9)

In [None]:
opts_im = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(height=480, width=752, size_index='seeds', color_index='mask_pnr', tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
(regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
 * hv.Points(seeds_pnr, kdims=['width', 'height'], vdims=['index', 'seeds', 'mask_pnr']).opts(**opts_pts))

## intensity refinement

In [None]:
%%time
# Intensity refinement (on the max projection) of the PNR-passing seeds;
# adds `mask_int`.
seeds_int = intensity_refine(max_proj, seeds_pnr[seeds_pnr['mask_pnr']])

In [None]:
# Seeds colored by whether they pass the intensity refinement.
opts_im = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(height=480, width=752, size_index='seeds', color_index='mask_int', tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
(regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
 * hv.Points(seeds_int, kdims=['width', 'height'], vdims=['index', 'seeds', 'mask_int']).opts(**opts_pts))

## Kolmogorov-Smirnov (KS) refinement

In [None]:
%%time
# KS refinement of the intensity-passing seeds (sig=0.1 forwarded to
# ks_refine); adds `mask_ks`.
seeds_ks = ks_refine(Y_flt, seeds_int[seeds_int['mask_int']], sig=0.1)

In [None]:
# Seeds colored by whether they pass the KS refinement.
opts_im = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(height=480, width=752, size_index='seeds', color_index='mask_ks', tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
(regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
 * hv.Points(seeds_ks, kdims=['width', 'height'], vdims=['index', 'seeds', 'mask_ks']).opts(**opts_pts))

## merge seeds

In [None]:
%%time
seeds_final = seeds_gmm[seeds_gmm['mask_gmm']].reset_index(drop=True)
seeds_mrg = seeds_merge(Y_flt, seeds_final, thres_dist=5, thres_corr=0.7)

In [None]:
opts_im = dict(plot=dict(height=480, width=752), style=dict(cmap='Viridis'))
opts_pts = dict(
    plot=dict(height=480, width=752, size_index='seeds', color_index='mask_mrg', tools=['hover']),
    style=dict(fill_alpha=0.8, line_alpha=0, cmap='Set1'))
(regrid(hv.Image(max_proj, kdims=['width', 'height'])).opts(**opts_im)
 * hv.Points(seeds_mrg, kdims=['width', 'height'], vdims=['index', 'seeds', 'mask_mrg']).opts(**opts_pts))

## initialize spatial and temporal matrices from seeds

In [None]:
%%time
# Rechunk so every chunk spans all frames, then initialize the spatial
# footprints (A), temporal traces (C) and b/f (presumably background terms)
# from the seeds kept after merging.
Y = Y.chunk(dict(frame=-1, height=200, width=200))
A, C, b, f = initialize(Y, seeds_mrg[seeds_mrg['mask_mrg']])

In [None]:
# Summed footprints next to the initial traces.
opts = dict(plot=dict(height=300, width=300), style=dict(cmap='Viridis'))
(regrid(hv.Image(A.sum('unit_id'), kdims=['width', 'height'])).opts(**opts)
 + regrid(hv.Image(C, kdims=['frame', 'unit_id'])).opts(**opts))

## save results

In [None]:
%%time
# Close the open dataset before writing new variables into the same store.
minian.close()
# Initial estimates are stored with a `unit_id_init` dimension (presumably to
# keep them distinct from later `unit_id` results).
save_minian(A.rename('A_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(C.rename('C_init').rename(unit_id='unit_id_init'), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(b.rename('b_init'), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(f.rename('f_init'), dpath, 'minian', meta_dict=meta_dict, backend='zarr')

# CNMF

## loading data

In [None]:
%%time
# Reopen the dataset with the configured chunk sizes. The initial variables
# use a `unit_id_init` dimension, so move the unit_id chunk size to that name
# and rename back to `unit_id` after loading.
chk = chunks.copy()
chk['unit_id_init'] = chk.pop('unit_id')
minian = open_minian(dpath, 'minian', backend='zarr', chunks=chk)
Y = minian['Y']
A_init = minian['A_init'].rename(unit_id_init='unit_id')
C_init = minian['C_init'].rename(unit_id_init='unit_id')
b_init = minian['b_init']
f_init = minian['f_init']

## estimate spatial noise

In [None]:
%%time
# Per-pixel power spectral density (presumably via Welch's method).
psd = psd_welch(Y)
with ProgressBar():
    psd = psd.persist()

In [None]:
%%opts Image [height=300, width=800, colorbar=True, logz=True] (cmap='Viridis')
# PSD of all pixels (flattened onto one integer axis) vs. frequency.
psd_flt = psd.stack(spatial=['height', 'width'])
hv_psd = hv.Image(
    psd_flt.assign_coords(spatial=range(psd_flt.sizes['spatial'])).rename('psd'),
    kdims=['spatial', 'freq'])
regrid(hv_psd).redim.range(psd=(0, 5e-3))

In [None]:
# Per-pixel noise estimate from the PSD over the given frequency range.
sn_spatial = get_noise(psd, noise_range=(0.06, 0.5))

## test parameters for spatial update

In [None]:
units = np.random.choice(A_init.coords['unit_id'], 10)
units.sort()

In [None]:
%%time
sprs_ls = [0.1, 0.5, 1]
A_dict = dict()
C_dict = dict()
for cur_sprs in sprs_ls:
    cur_A, cur_b, cur_C, cur_f = update_spatial(
        Y, A_init.sel(unit_id=units),
        b_init, C_init.sel(unit_id=units), f_init, sn_spatial, dl_wnd=20, sparse_penal=cur_sprs)
    if cur_A.sizes['unit_id']:
        A_dict[cur_sprs] = cur_A.compute()
        C_dict[cur_sprs] = cur_C.compute()
hv_res = visualize_spatial_update(A_dict, C_dict, kdims=['sparse penalty'])

In [None]:
%%output size=80
hv_res

## first spatial update

In [None]:
%%time
A_spatial, b_spatial, C_spatial, f_spatial = update_spatial(
    Y, A_init, b_init, C_init, f_init, sn_spatial, sparse_penal=0.1, post_scal=True)
A_spatial = A_spatial.chunk(dict(unit_id = chunks['unit_id']))

In [None]:
%%output size=60
opts = dict(plot=dict(height=A_init.sizes['height'], width=A_init.sizes['width'], colorbar=True), style=dict(cmap='Viridis'))
(regrid(hv.Image(A_init.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**opts).relabel("Spatial Footprints Initial")
+ regrid(hv.Image((A_init.fillna(0) > 0).sum('unit_id').compute().rename('A'), kdims=['width', 'height']), aggregator='max').opts(**opts).relabel("Binary Spatial Footprints Initial")
+ regrid(hv.Image(A_spatial.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**opts).relabel("Spatial Footprints First Update")
+ regrid(hv.Image((A_spatial > 0).sum('unit_id').compute().rename('A'), kdims=['width', 'height']), aggregator='max').opts(**opts).relabel("Binary Spatial Footprints First Update")).cols(2)

## test parameters for temporal update

In [None]:
units = np.random.choice(A_spatial.coords['unit_id'], 10)
units.sort()

In [None]:
%%time
import itertools as itt
p_ls = [2]
sprs_ls = [1, 3, 5]
add_ls = [20]
noise_ls = [0.06]
YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict = [dict() for _ in range(6)]
for cur_p, cur_sprs, cur_add, cur_noise in itt.product(p_ls, sprs_ls, add_ls, noise_ls):
    ks = (cur_p, cur_sprs, cur_add, cur_noise)
    print("p:{}, sparse penalty:{}, additional lag:{}, noise frequency:{}"
          .format(cur_p, cur_sprs, cur_add, cur_noise))
    YrA, cur_C, cur_S, cur_B, cur_C0, cur_sig, cur_g, cur_scal = update_temporal(
        Y, A_spatial.sel(unit_id=units), b_spatial, C_spatial.sel(unit_id=units),
        f_spatial, sn_spatial, sparse_penal=cur_sprs, p=cur_p, use_spatial=False, use_smooth=True,
        add_lag = cur_add, noise_freq=cur_noise, cvx_sched="processes", chk=chunks)
    cur_A = A_spatial.sel(unit_id = cur_C.coords['unit_id'])
    YA_dict[ks], C_dict[ks], S_dict[ks], g_dict[ks], sig_dict[ks], A_dict[ks] = (
        YrA, cur_C, cur_S, cur_g, cur_sig, cur_A)
hv_res = visualize_temporal_update(
    YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict,
    kdims=['p', 'sparse penalty', 'additional lag', 'noise frequency'])

In [None]:
%%output size=60
hv_res

## first temporal update

In [None]:
%%time
YrA, C_temporal, S_temporal, B_temporal, C0_temporal, sig_temporal, g_temporal, scale = update_temporal(
    Y, A_spatial,
    b_spatial, C_spatial, f_spatial, sn_spatial, jac_thres=0.1,
    noise_freq=0.06, sparse_penal=1, p=2, add_lag=20, use_spatial=False,
    chk=chunks, cvx_sched='processes')
A_temporal = (A_spatial.sel(unit_id = C_temporal.coords['unit_id'])
              .chunk(chunks['unit_id']))

In [None]:
%%output size=60
opts_im = dict(plot=dict(height=500, width=1000, colorbar=True), style=dict(cmap='Viridis'))
ranges = dict(c=(0, 1.5), s=(0, 0.04))
(regrid(hv.Image(C_init.rename('c'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel("Temporal Trace Initial").redim.range(**ranges)
 + hv.Div('')
 + regrid(hv.Image(C_temporal.rename('c'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel("Temporal Trace First Update").redim.range(**ranges)
 + regrid(hv.Image(S_temporal.rename('s'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel("Spikes First Update").redim.range(**ranges)).cols(2)

In [None]:
%%output size=60
# Units present in A_spatial but dropped by the first temporal update: their
# raw traces and footprints, next to the summed footprints of accepted units.
h, w = A_spatial.sizes['height'], A_spatial.sizes['width']
im_opts = dict(plot=dict(height=h, width=w), style=dict(cmap='Viridis'))
cr_opts = dict(plot=dict(height=h, width=2*w))
bad_units = list(set(A_spatial.coords['unit_id'].values) - set(A_temporal.coords['unit_id'].values))
bad_units.sort()
(datashade(hv.Dataset(YrA.sel(unit_id=bad_units).rename('raw')).to(hv.Curve, kdims=['frame'])).opts(**cr_opts).relabel("Temporal Trace")
 + hv.Div('')
 + regrid(hv.Dataset(A_spatial.sel(unit_id=bad_units).rename('A')).to(hv.Image, kdims=['width', 'height'])).opts(**im_opts).relabel("Spatial Footprint")
 + regrid(hv.Image(A_temporal.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**im_opts).relabel("Spatial Footprints of Accepted Units")).cols(2)

In [None]:
%%output size=60
# Overview of the accepted units' temporal results.
visualize_temporal_update(YrA, C_temporal, S_temporal, g_temporal, sig_temporal, A_temporal)

## merge units

In [None]:
%%time
# Merge units using a correlation threshold of 0.9.
# NOTE(review): this call passes sig_temporal and unpacks two values, whereas
# the second-iteration merge passes C_temporal_it2 plus extra arrays and
# unpacks three — confirm both match unit_merge's signature.
A_mrg, sig_mrg = unit_merge(A_temporal, sig_temporal, thres_corr=0.9)

In [None]:
%%output size=70
# Signals before vs. after merging.
opts_im = dict(plot=dict(height=500, width=1000, colorbar=True), style=dict(cmap='Viridis'))
ranges = dict(c=(0, 4), s=(0, 0.04))
(regrid(hv.Image(sig_temporal.rename('c'), kdims=['frame', 'unit_id'])).relabel("Temporal Signals Before Merge").opts(**opts_im).redim.range(**ranges) +
regrid(hv.Image(sig_mrg.rename('c'), kdims=['frame', 'unit_id'])).relabel("Temporal Signals After Merge").opts(**opts_im).redim.range(**ranges))

## test parameters for the second spatial update

In [None]:
A_mrg, sig_mrg = (A_mrg.chunk({c: chunks[c] for c in ['height', 'width', 'unit_id']}),
                  sig_mrg.chunk({c: chunks[c] for c in ['frame', 'unit_id']}))

In [None]:
units = np.random.choice(A_mrg.coords['unit_id'], 10)
units.sort()

In [None]:
%%time
sprs_ls = [0.05, 0.1, 0.5]
A_dict = dict()
C_dict = dict()
for cur_sprs in sprs_ls:
    cur_A, cur_b, cur_C, cur_f = update_spatial(
        Y, A_mrg.sel(unit_id=units),
        b_init, sig_mrg.sel(unit_id=units), f_init, sn_spatial, dl_wnd=20, sparse_penal=cur_sprs)
    if cur_A.sizes['unit_id']:
        A_dict[cur_sprs] = cur_A
        C_dict[cur_sprs] = cur_C
hv_res = visualize_spatial_update(A_dict, C_dict, kdims=['sparse penalty'])

In [None]:
%%output size=80
hv_res

## second spatial update

In [None]:
%%time
A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2 = update_spatial(
    Y, A_mrg, b_spatial, sig_mrg, f_spatial, sn_spatial, sparse_penal=0.05, dl_wnd=5)

In [None]:
%%output size=60
opts = dict(plot=dict(height=A_init.sizes['height'], width=A_init.sizes['width'], colorbar=True), style=dict(cmap='Viridis'))
(regrid(hv.Image(A_spatial.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**opts).relabel("Spatial Footprints First Update")
+ regrid(hv.Image((A_spatial.fillna(0) > 0).sum('unit_id').compute().rename('A'), kdims=['width', 'height']), aggregator='max').opts(**opts).relabel("Binary Spatial Footprints First Update")
+ regrid(hv.Image(A_spatial_it2.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**opts).relabel("Spatial Footprints Second Update")
+ regrid(hv.Image((A_spatial_it2 > 0).sum('unit_id').compute().rename('A'), kdims=['width', 'height']), aggregator='max').opts(**opts).relabel("Binary Spatial Footprints Second Update")).cols(2)

## test parameters for the second temporal update

In [None]:
units = np.random.choice(A_spatial_it2.coords['unit_id'], 10)
units.sort()

In [None]:
%%time
import itertools as itt
p_ls = [2]
sprs_ls = [0.5, 1, 3]
add_ls = [20]
noise_ls = [0.06]
YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict = [dict() for _ in range(6)]
for cur_p, cur_sprs, cur_add, cur_noise in itt.product(p_ls, sprs_ls, add_ls, noise_ls):
    ks = (cur_p, cur_sprs, cur_add, cur_noise)
    print("p:{}, sparse penalty:{}, additional lag:{}, noise frequency:{}"
          .format(cur_p, cur_sprs, cur_add, cur_noise))
    YrA, cur_C, cur_S, cur_B, cur_C0, cur_sig, cur_g, cur_scal = update_temporal(
        Y, A_spatial_it2.sel(unit_id=units), b_spatial, C_spatial_it2.sel(unit_id=units),
        f_spatial, sn_spatial, sparse_penal=cur_sprs, p=cur_p, use_spatial=False, use_smooth=True,
        add_lag = cur_add, noise_freq=cur_noise, cvx_sched="processes", chk=chunks)
    cur_A = A_spatial.sel(unit_id = cur_C.coords['unit_id'])
    YA_dict[ks], C_dict[ks], S_dict[ks], g_dict[ks], sig_dict[ks], A_dict[ks] = (
        YrA, cur_C, cur_S, cur_g, cur_sig, cur_A)
hv_res = visualize_temporal_update(
    YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict,
    kdims=['p', 'sparse penalty', 'additional lag', 'noise frequency'])

In [None]:
%%output size=60
hv_res

## second temporal update

In [None]:
%%time
# Second temporal update on the second-iteration spatial results; footprints
# are then restricted to the units it returns.
YrA, C_temporal_it2, S_temporal_it2, B_temporal_it2, C0_temporal_it2, sig_temporal_it2, g_temporal_it2, scale_temporal_it2 = update_temporal(
    Y, A_spatial_it2, b_spatial_it2, C_spatial_it2, f_spatial_it2, sn_spatial, jac_thres=0.1,
    noise_freq=0.06, sparse_penal=1, p=2, add_lag=20, max_iters=500, chk=chunks)
A_temporal_it2 = A_spatial_it2.sel(unit_id=C_temporal_it2.coords['unit_id'])

In [None]:
%%output size=60
# Traces and spikes after the first vs. the second temporal update. The
# second update's unit dimension is renamed to `unit_id_it2`, presumably so
# holoviews does not link the two unit axes.
opts_im = dict(plot=dict(height=500, width=1000, colorbar=True), style=dict(cmap='Viridis'))
ranges = dict(c=(0, 1.5), s=(0, 0.04))
(regrid(hv.Image(C_temporal.rename('c'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel("Temporal Trace First Update").redim.range(**ranges)
 + regrid(hv.Image(S_temporal.rename('s'), kdims=['frame', 'unit_id'])).opts(**opts_im).relabel("Spikes First Update").redim.range(**ranges)
 + regrid(hv.Image(C_temporal_it2.rename('c').rename(unit_id='unit_id_it2'), kdims=['frame', 'unit_id_it2'])).opts(**opts_im).relabel("Temporal Trace Second Update").redim.range(**ranges)
 + regrid(hv.Image(S_temporal_it2.rename('s').rename(unit_id='unit_id_it2'), kdims=['frame', 'unit_id_it2'])).opts(**opts_im).relabel("Spikes Second Update").redim.range(**ranges)).cols(2)

In [None]:
%%output size=60
# Units dropped by the second temporal update, next to the accepted footprints.
h, w = A_spatial_it2.sizes['height'], A_spatial_it2.sizes['width']
im_opts = dict(plot=dict(height=h, width=w), style=dict(cmap='Viridis'))
cr_opts = dict(plot=dict(height=h, width=2*w))
bad_units = list(set(A_spatial_it2.coords['unit_id'].values) - set(A_temporal_it2.coords['unit_id'].values))
bad_units.sort()
(datashade(hv.Dataset(YrA.sel(unit_id=bad_units).rename('raw')).to(hv.Curve, kdims=['frame'])).opts(**cr_opts).relabel("Temporal Trace")
 + hv.Div('')
 + regrid(hv.Dataset(A_spatial_it2.sel(unit_id=bad_units).rename('A')).to(hv.Image, kdims=['width', 'height'])).opts(**im_opts).relabel("Spatial Footprint")
 + regrid(hv.Image(A_temporal_it2.sum('unit_id').compute().rename('A'), kdims=['width', 'height'])).opts(**im_opts).relabel("Spatial Footprints of Accepted Units")).cols(2)

In [None]:
%%output size=60
# Overview of the accepted units' second-iteration temporal results.
visualize_temporal_update(YrA, C_temporal_it2, S_temporal_it2, g_temporal_it2, sig_temporal_it2, A_temporal_it2)

## merge units (second iteration)

In [None]:
%%time
A_mrg_it2, C_mrg_it2, add_list = unit_merge(A_temporal_it2, C_temporal_it2, [S_temporal_it2, C0_temporal_it2, g_temporal_it2, B_temporal_it2], thres_corr=0.9)

In [None]:
%%output size=70
opts_im = dict(plot=dict(height=500, width=1000, colorbar=True), style=dict(cmap='Viridis'))
ranges = dict(c=(0, 2), s=(0, 0.04))
(regrid(hv.Image(C_temporal_it2.rename('c'), kdims=['frame', 'unit_id'])).relabel("Temporal Signals Before Merge").opts(**opts_im).redim.range(**ranges) +
regrid(hv.Image(C_mrg_it2.rename('c'), kdims=['frame', 'unit_id'])).relabel("Temporal Signals After Merge").opts(**opts_im).redim.range(**ranges))

In [None]:
S_mrg_it2, C0_mrg_it2, g_mrg_it2, B_mrg_it2 = add_list[:]

## save results

In [None]:
%%time
# Close the open dataset, then compute and save each final result under its
# short name (A, C, S, g, C0, B from the merge; b, f from the second spatial update).
minian.close()
save_minian(A_mrg_it2.rename('A').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(C_mrg_it2.rename('C').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(S_mrg_it2.rename('S').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(g_mrg_it2.rename('g').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(C0_mrg_it2.rename('C0').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(B_mrg_it2.rename('B').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(b_spatial_it2.rename('b').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')
save_minian(f_spatial_it2.rename('f').compute(), dpath, 'minian', meta_dict=meta_dict, backend='zarr')

## visualization

In [None]:
minian = open_minian(dpath, 'minian', backend='zarr')

In [None]:
%%time
generate_videos(minian, os.path.join(dpath, "minian.mp4"), chk=dict(height=100, width=100, frame=1000))

In [None]:
%%time
cnmfviewer = CNMFViewer(minian)

In [None]:
%output size=60
cnmfviewer.show()