## load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import itertools as itt
import os
import sys
from typing import Optional, Union, Dict, List

import numpy as np
import numpy.typing as npt
import xarray as xr
from dask.distributed import Client, LocalCluster
import dask.array as da
from IPython.display import display
from vispy.color import colormap

## set path and parameters

In [None]:
# Set up initial basic parameters
minian_path = "."  # location of the minian package; appended to sys.path later
dpath = "./demo_movies/"  # folder containing the input videos
minian_ds_path = os.path.join(dpath, "minian")  # final output dataset
intpath = "./minian_intermediate"  # folder for intermediate arrays
subset = dict(frame=slice(0, None))  # frame range to process (all frames)
interactive = True  # used by the (currently commented-out) CNMF viewer cells
output_size = 100  # used by the (currently commented-out) CNMF viewer cells
n_workers = int(os.getenv("MINIAN_NWORKERS", 4))  # overridable via env var
param_save_minian = {
    "dpath": minian_ds_path,
    "meta_dict": dict(session=-1, animal=-2),
    "overwrite": True,
}

# Pre-processing parameters
param_load_videos = {
    "pattern": r"msCam[0-9]+\.avi$",
    "dtype": np.uint8,
    "downsample": dict(frame=1, height=1, width=1),
    "downsample_strategy": "subset",
}
param_denoise = {"method": "median", "ksize": 7}
param_background_removal = {"method": "tophat", "wnd": 15}

# Motion correction parameters
# `subset_mc` is passed to `varr_ref.sel(...)` before motion estimation;
# None selects the full frames. Override it after inspecting the raw video.
subset_mc = None
param_estimate_motion = {"dim": "frame"}

# Initialization parameters
param_seeds_init = {
    "wnd_size": 1000,
    "method": "rolling",
    "stp_size": 500,
    "max_wnd": 15,
    "diff_thres": 3,
}
param_pnr_refine = {"noise_freq": 0.06, "thres": 1}
param_ks_refine = {"sig": 0.05}
param_seeds_merge = {"thres_dist": 10, "thres_corr": 0.8, "noise_freq": 0.06}
param_initialize = {"thres_corr": 0.8, "wnd": 10, "noise_freq": 0.06}
param_init_merge = {"thres_corr": 0.8}

# CNMF parameters
param_get_noise = {"noise_range": (0.06, 0.5)}
param_first_spatial = {
    "dl_wnd": 10,
    "sparse_penal": 0.01,
    "size_thres": (25, None),
}
param_first_temporal = {
    "noise_freq": 0.06,
    "sparse_penal": 1,
    "p": 1,
    "add_lag": 20,
    "jac_thres": 0.2,
}
param_first_merge = {"thres_corr": 0.8}
param_second_spatial = {
    "dl_wnd": 10,
    "sparse_penal": 0.01,
    "size_thres": (25, None),
}
param_second_temporal = {
    "noise_freq": 0.06,
    "sparse_penal": 1,
    "p": 1,
    "add_lag": 20,
    "jac_thres": 0.4,
}

# Limit BLAS/OpenMP threads to one per process and expose `intpath` to minian
# through MINIAN_INTERMEDIATE.
# NOTE(review): numpy is already imported in the first cell, so the thread
# limits likely come too late for this kernel itself. They should still be
# inherited by the dask worker processes started later. Move these lines
# above the imports if the kernel must be limited too.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MINIAN_INTERMEDIATE"] = intpath

## import minian

In [None]:
%%capture
# Enable autoreload so edits to the minian sources are picked up without
# restarting the kernel.
# NOTE(review): the extension is already loaded in the first cell, so this
# repeated %load_ext is redundant.
%load_ext autoreload
%autoreload 2
# Make the minian package located at `minian_path` importable.
sys.path.append(minian_path)
from minian.cnmf import (
    compute_AtC,
    compute_trace,
    get_noise_fft,
    smooth_sig,
    unit_merge,
    update_spatial,
    update_temporal,
    update_background,
)
from minian.initialization import (
    gmm_refine,
    initA,
    initC,
    intensity_refine,
    ks_refine,
    pnr_refine,
    seeds_init,
    seeds_merge,
)
from minian.motion_correction import apply_transform, estimate_motion
from minian.preprocessing import denoise, remove_background
from minian.utilities import (
    TaskAnnotation,
    get_optimal_chk,
    load_videos,
    open_minian,
    save_minian,
)
from minian.visualization import (
    visualize_raw_video,
    visualize_before_after,
    visualize_preprocess,
    visualize_motion,
    visualize_seeds,
    visualize_pnr_refine,
    visualize_initialization,
    visualize_spatial_params,
    visualize_spatial_update,
    visualize_spatial_bg,
    visualize_temporal_params,

    # NOTE(review): later cells also call `visualize_temporal_components`,
    # `jackson_pollock_plot` and `visualize_gmm_fit`. None of them is
    # imported or defined in this notebook; confirm where they live.
    write_video,
    generate_videos
)

## start cluster

In [None]:
# Start a local dask cluster: `n_workers` workers with 2 threads each, a
# per-worker memory limit, a custom "MEM" resource, and the dashboard on :8787.
cluster = LocalCluster(
    n_workers=n_workers,
    threads_per_worker=2,
    memory_limit="20GB",
    resources={"MEM": 1},
    dashboard_address=":8787",
)
# Register minian's task-annotation plugin on the scheduler, then connect.
annt_plugin = TaskAnnotation()
cluster.scheduler.add_plugin(annt_plugin)
client = Client(cluster)

# Pre-processing

## loading videos and visualization

In [None]:
# Load the videos from `dpath` (file filter and options in param_load_videos),
# then choose chunk sizes for float computation; the second value is unused.
varr = load_videos(
    dpath,
    **param_load_videos,
)
chk, _ = get_optimal_chk(varr, dtype=float)

In [None]:
%%time
varr = save_minian(
    varr.rename("varr"),
    dpath=intpath,
    overwrite=True,
    chunks={"frame": chk["frame"], "height": -1, "width": -1}
)

## visualize raw data and optionally set an ROI for motion correction

In [None]:
%gui qt
visualize_raw_video(varr, title='Original video')

In [None]:
## ** OPTIONAL: DEFINE subset_mc HERE, e.g. subset_mc = dict(height=slice(...), width=slice(...)) ** ##

## subset part of video

In [None]:
# Restrict processing to the frame subset chosen in the parameters cell.
varr_ref = varr.sel(indexers=subset)

## glow removal and visualization

In [None]:
%%time
varr_min = varr_ref.min("frame").compute()
varr_ref = varr_ref - varr_min

In [None]:
%gui qt
visualize_before_after(
    before=varr,
    after=varr_ref,
    title='Glow Removal'
)

## denoise

In [None]:
%gui qt
visualize_preprocess(
    frame=varr_ref.isel(frame=0).compute(),
    func=denoise,
    title='Denoise',
    method=['median'],
    ksize=[5, 7, 9]
)

In [None]:
# Denoise with the settings chosen in param_denoise.
varr_ref = denoise(
    varr_ref, **param_denoise
)

## background removal

In [None]:
%gui qt
visualize_preprocess(
    frame=varr_ref.isel(frame=0).compute(),
    func=remove_background,
    title='Background subtraction',
    method=["tophat"],
    wnd=[10, 15, 20],
)

In [None]:
# Remove background with the settings chosen in param_background_removal.
varr_ref = remove_background(
    varr_ref, **param_background_removal
)

## save result

In [None]:
%%time
varr_ref = save_minian(varr_ref.rename("varr_ref"), dpath=intpath, overwrite=True)

In [None]:
%gui qt
visualize_before_after(
    before=varr.max('frame'),
    after=varr_ref.max('frame'),
    title='Preprocessing before vs after'
)

# Motion Correction

## estimate motion

In [None]:
%%time
motion = estimate_motion(varr_ref.sel(subset_mc), **param_estimate_motion)

## save motion

In [None]:
%%time
motion = save_minian(
    motion.rename("motion"),
    **param_save_minian
)

## visualization of motion

In [None]:
%gui qt
visualize_motion(motion, magnify=False)

## apply transform

In [None]:
# Apply the estimated motion to the preprocessed movie (fill value 0).
Y = apply_transform(
    varr_ref, motion, fill=0
)

## save result

In [None]:
%%time
Y_fm_chk = save_minian(Y.astype(float).rename("Y_fm_chk"), intpath, overwrite=True)
Y_hw_chk = save_minian(
    Y_fm_chk.rename("Y_hw_chk"),
    intpath,
    overwrite=True,
    chunks={"frame": -1, "height": chk["height"], "width": chk["width"]},
)

## visualization of motion-correction

In [None]:
%gui qt
# video
visualize_before_after(
    before=varr_ref,
    after=Y_fm_chk.astype(np.float32),
    title='Before vs After motion correction (Video)'
)

In [None]:
%gui qt
# maximum intensity projections
visualize_before_after(
    before=varr_ref.max('frame'),
    after=Y_fm_chk.max('frame').astype(np.float32),
    title='Before vs After motion correction (MaxIP)'
)

## generate video for motion-correction

In [None]:
%%time
vid_arr = xr.concat([varr_ref, Y_fm_chk], "width").chunk({"width": -1})
write_video(vid_arr, "minian_mc.mp4", dpath)

# Initialization

## compute max projection

In [None]:
# Max projection of the motion-corrected movie, saved to the output dataset
# and then loaded into memory.
max_proj = save_minian(Y_fm_chk.max("frame").rename("max_proj"), **param_save_minian)
max_proj = max_proj.compute()

## generating over-complete set of seeds

In [None]:
%%time
seeds = seeds_init(Y_fm_chk, **param_seeds_init)

In [None]:
%gui qt
visualize_seeds(max_proj.astype(np.float32), seeds)

## peak-to-noise ratio (PNR) refinement

In [None]:
%gui qt
# recommended not to use magnify and link_views together
visualize_pnr_refine(
    Y_hw_chk,
    example_seeds = seeds.sample(6, axis='rows'),
    noise_freq_list = [0.005, 0.01, 0.02, 0.06, 0.1, 0.2, 0.3, 0.45, 0.6, 0.8],
    magnify=False,
    link_views=False
)

In [None]:
%%time
seeds, pnr, gmm = pnr_refine(Y_hw_chk, seeds, **param_pnr_refine)

In [None]:
# Show the Gaussian-mixture fit from PNR refinement, if one was produced.
# TODO: `visualize_gmm_fit` still needs to be written. Until it exists, skip
# the plot instead of failing with a NameError.
if gmm is None:
    print("nothing to show")
elif "visualize_gmm_fit" not in globals():
    print("visualize_gmm_fit is not implemented yet; skipping GMM plot")
else:
    display(visualize_gmm_fit(pnr, gmm, 100))

In [None]:
%gui qt
visualize_seeds(max_proj.astype(np.float32), seeds, "mask_pnr")

## ks refine

In [None]:
%%time
seeds = ks_refine(Y_hw_chk, seeds, **param_ks_refine)

In [None]:
%gui qt
visualize_seeds(max_proj.astype(np.float32), seeds, "mask_ks")

## merge seeds

In [None]:
%%time
seeds_final = seeds[seeds["mask_ks"] & seeds["mask_pnr"]].reset_index(drop=True)
seeds_final = seeds_merge(Y_hw_chk, max_proj, seeds_final, **param_seeds_merge)

In [None]:
%gui qt
visualize_seeds(max_proj.astype(np.float32), seeds_final, "mask_mrg")

## initialize spatial matrix

In [None]:
%%time
A_init = initA(Y_hw_chk, seeds_final[seeds_final["mask_mrg"]], **param_initialize)
A_init = save_minian(A_init.rename("A_init"), intpath, overwrite=True)

## initialize temporal matrix

In [None]:
%%time
C_init = initC(Y_fm_chk, A_init)
C_init = save_minian(
    C_init.rename("C_init"), intpath, overwrite=True, chunks={"unit_id": 1, "frame": -1}
)

## merge units

In [None]:
%%time
A, C = unit_merge(A_init, C_init, **param_init_merge)
A = save_minian(A.rename("A"), intpath, overwrite=True)
C = save_minian(C.rename("C"), intpath, overwrite=True)
C_chk = save_minian(
    C.rename("C_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)

## initialize background terms

In [None]:
%%time
b, f = update_background(Y_fm_chk, A, C_chk)
f = save_minian(f.rename("f"), intpath, overwrite=True)
b = save_minian(b.rename("b"), intpath, overwrite=True)

## visualization of initialization

In [None]:
%gui qt
visualize_initialization(
    A, C, b, f
)

# CNMF

## estimate spatial noise

In [None]:
%%time
sn_spatial = get_noise_fft(Y_hw_chk, **param_get_noise)
sn_spatial = save_minian(sn_spatial.rename("sn_spatial"), intpath, overwrite=True)

## first spatial update

### parameter exploration

In [None]:
%%time
units = np.random.choice(A.coords["unit_id"], 10, replace=False)
units.sort()
A_sub = A.sel(unit_id=units).persist()
C_sub = C.sel(unit_id=units).persist()

sprs_ls = [0.005, 0.01, 0.05]
A_dict = dict()
C_dict = dict()
for cur_sprs in sprs_ls:
    cur_A, cur_mask, cur_norm = update_spatial(
        Y_hw_chk,
        A_sub,
        C_sub,
        sn_spatial,
        in_memory=True,
        dl_wnd=param_first_spatial["dl_wnd"],
        sparse_penal=cur_sprs,
    )
    if cur_A.sizes["unit_id"]:
        A_dict[cur_sprs] = cur_A.compute()
        C_dict[cur_sprs] = C_sub.sel(unit_id=cur_mask).compute()

In [None]:
%gui qt
visualize_spatial_params(
    units,
    A_dict,
    C_dict
)

### spatial update

In [None]:
%%time
A_new, mask, norm_fac = update_spatial(
    Y_hw_chk, A, C, sn_spatial, **param_first_spatial
)
C_new = save_minian(
    (C.sel(unit_id=mask) * norm_fac).rename("C_new"), intpath, overwrite=True
)
C_chk_new = save_minian(
    (C_chk.sel(unit_id=mask) * norm_fac).rename("C_chk_new"), intpath, overwrite=True
)

In [None]:
%%time
b_new, f_new = update_background(Y_fm_chk, A_new, C_chk_new)

### visualization of spatial footprints

In [None]:
%gui qt
visualize_spatial_update(
    A,
    A_new
)

### visualization of background

In [None]:
%gui qt
visualize_spatial_bg(
    b,
    f,
    b_new,
    f_new
)

### save results

In [None]:
%%time
A = save_minian(
    A_new.rename("A"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "height": -1, "width": -1},
)
b = save_minian(b_new.rename("b"), intpath, overwrite=True)
f = save_minian(
    f_new.chunk({"frame": chk["frame"]}).rename("f"), intpath, overwrite=True
)
C = save_minian(C_new.rename("C"), intpath, overwrite=True)
C_chk = save_minian(C_chk_new.rename("C_chk"), intpath, overwrite=True)

## first temporal update

### parameter exploration

In [None]:
# Sample at most 10 units for exploring the first temporal update. The cap
# avoids np.random.choice(..., replace=False) raising when fewer units exist.
units = np.random.choice(
    A.coords["unit_id"], min(10, A.sizes["unit_id"]), replace=False
)
units.sort()
A_sub = A.sel(unit_id=units).persist()
C_sub = C_chk.sel(unit_id=units).persist()

# Candidate values: AR order p, sparse penalty, additional lag, noise frequency.
params = dict(
    p_ls=[1],
    sprs_ls=[0.1, 0.5, 1, 2],
    add_ls=[20],
    noise_ls=[0.06],
)

In [None]:
%%time
YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict = [dict() for _ in range(6)]
YrA = (
    compute_trace(Y_fm_chk, A_sub, b, C_sub, f)
    .persist()
    .chunk({"unit_id": 1, "frame": -1})
)
for cur_p, cur_sprs, cur_add, cur_noise in itt.product(
    params['p_ls'], params['sprs_ls'], params['add_ls'], params['noise_ls']
):
    ks = (cur_p, cur_sprs, cur_add, cur_noise)
    print(
        "p:{}, sparse penalty:{}, additional lag:{}, noise frequency:{}".format(
            cur_p, cur_sprs, cur_add, cur_noise
        )
    )
    cur_C, cur_S, cur_b0, cur_c0, cur_g, cur_mask = update_temporal(
        A_sub,
        C_sub,
        YrA=YrA,
        sparse_penal=cur_sprs,
        p=cur_p,
        use_smooth=True,
        add_lag=cur_add,
        noise_freq=cur_noise,
    )
    YA_dict[ks], C_dict[ks], S_dict[ks], g_dict[ks], sig_dict[ks], A_dict[ks] = (
        YrA.compute().astype(np.float32),
        cur_C.compute().astype(np.float32),
        cur_S.compute().astype(np.float32),
        cur_g.compute().astype(np.float32),
        (cur_C + cur_b0 + cur_c0).compute().astype(np.float32),
        A_sub.compute().astype(np.float32),
    )

In [None]:
%gui qt
visualize_temporal_params(
    units,
    params,
    YA_dict,
    C_dict,
    S_dict,
    g_dict,
    sig_dict,
    A_dict,
    magnify=True
)

### temporal update

In [None]:
%%time
YrA = save_minian(
    compute_trace(Y_fm_chk, A, b, C_chk, f).rename("YrA"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "frame": -1},
)

In [None]:
%%time
C_new, S_new, b0_new, c0_new, g, mask = update_temporal(
    A, C, YrA=YrA, **param_first_temporal
)

### visualization of temporal components

In [None]:
%gui qt
visualize_temporal_components(
    C=C,
    S=None,
    C_new=C_new,
    S_new=S_new,
    title='First temporal update'
)

### visualization of dropped units

In [None]:
%gui qt
params = {param:[param_first_temporal[param]] for param in ['noise_freq', 'sparse_penal', 'p', 'add_lag']}
param_key = tuple([param[0] for param in params.values()])
bad_units = mask.where(mask == False, drop=True).coords["unit_id"].values
if len(bad_units) > 0:
    visualize_temporal_params(
        units=YrA.sel(unit_id=mask).unit_id.values,
        params=params,
        YA_dict={param_key:YrA.sel(unit_id=mask)},
        C_dict={param_key:C_new},
        S_dict={param_key:S_new},
        g_dict={param_key:g},
        sig_dict={param_key:g},
        A_dict={param_key:A.sel(unit_id=mask)},
        magnify=False
    )
else:
    print("No rejected units to display")

### visualization of accepted units

In [None]:
%gui qt
params = {param:[param_first_temporal[param]] for param in ['noise_freq', 'sparse_penal', 'p', 'add_lag']}
param_key = tuple([param[0] for param in params.values()])
visualize_temporal_params(
    units=YrA.sel(unit_id=mask).unit_id.values,
    params=params,
    YA_dict={param_key:YrA.sel(unit_id=mask)},
    C_dict={param_key:C_new},
    S_dict={param_key:S_new},
    g_dict={param_key:g},
    sig_dict={param_key:g},
    A_dict={param_key:A.sel(unit_id=mask)},
    magnify=False
)

### save results

In [None]:
%%time
C = save_minian(
    C_new.rename("C").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
C_chk = save_minian(
    C.rename("C_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)
S = save_minian(
    S_new.rename("S").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
b0 = save_minian(
    b0_new.rename("b0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
c0 = save_minian(
    c0_new.rename("c0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
A = A.sel(unit_id=C.coords["unit_id"].values)

## merge units

In [None]:
%%time
A_mrg, C_mrg, [sig_mrg] = unit_merge(A, C, [C + b0 + c0], **param_first_merge)

In [None]:
%gui qt
visualize_temporal_components(
    C,
    C_mrg
)

In [None]:
%%time
A = save_minian(A_mrg.rename("A_mrg"), intpath, overwrite=True)
C = save_minian(C_mrg.rename("C_mrg"), intpath, overwrite=True)
C_chk = save_minian(
    C.rename("C_mrg_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)
sig = save_minian(sig_mrg.rename("sig_mrg"), intpath, overwrite=True)

## second spatial update

### parameter exploration

In [None]:
%%time
units = np.random.choice(A.coords["unit_id"], 10, replace=False)
units.sort()
A_sub = A.sel(unit_id=units).persist()
C_sub = C.sel(unit_id=units).persist()

sprs_ls = [0.005, 0.01, 0.05]
A_dict = dict()
C_dict = dict()
for cur_sprs in sprs_ls:
    cur_A, cur_mask, cur_norm = update_spatial(
        Y_hw_chk,
        A_sub,
        C_sub,
        sn_spatial,
        in_memory=True,
        dl_wnd=param_second_spatial["dl_wnd"],
        sparse_penal=cur_sprs,
    )
    if cur_A.sizes["unit_id"]:
        A_dict[cur_sprs] = cur_A.compute()
        C_dict[cur_sprs] = C_sub.sel(unit_id=cur_mask).compute()

In [None]:
%gui qt
visualize_spatial_params(
    units,
    A_dict,
    C_dict
)

### spatial update

In [None]:
%%time
A_new, mask, norm_fac = update_spatial(
    Y_hw_chk, A, C, sn_spatial, **param_second_spatial
)
C_new = save_minian(
    (C.sel(unit_id=mask) * norm_fac).rename("C_new"), intpath, overwrite=True
)
C_chk_new = save_minian(
    (C_chk.sel(unit_id=mask) * norm_fac).rename("C_chk_new"), intpath, overwrite=True
)

In [None]:
%%time
b_new, f_new = update_background(Y_fm_chk, A_new, C_chk_new)

### visualization of spatial footprints

In [None]:
%gui qt
visualize_spatial_update(
    A,
    A_new
)

### visualization of background

In [None]:
%gui qt
visualize_spatial_bg(
    b,
    f,
    b_new,
    f_new
)

### save results

In [None]:
%%time
A = save_minian(
    A_new.rename("A"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "height": -1, "width": -1},
)
b = save_minian(b_new.rename("b"), intpath, overwrite=True)
f = save_minian(
    f_new.chunk({"frame": chk["frame"]}).rename("f"), intpath, overwrite=True
)
C = save_minian(C_new.rename("C"), intpath, overwrite=True)
C_chk = save_minian(C_chk_new.rename("C_chk"), intpath, overwrite=True)

## second temporal update

### parameter exploration

In [None]:
# Sample at most 10 units for exploring the second temporal update. The cap
# avoids np.random.choice(..., replace=False) raising when fewer units exist.
units = np.random.choice(
    A.coords["unit_id"], min(10, A.sizes["unit_id"]), replace=False
)
units.sort()
A_sub = A.sel(unit_id=units).persist()
C_sub = C_chk.sel(unit_id=units).persist()

# Candidate values: AR order p, sparse penalty, additional lag, noise frequency.
params = dict(
    p_ls=[1],
    sprs_ls=[0.1, 0.5, 1, 2],
    add_ls=[20],
    noise_ls=[0.06],
)

In [None]:
%%time
YA_dict, C_dict, S_dict, g_dict, sig_dict, A_dict = [dict() for _ in range(6)]
YrA = (
    compute_trace(Y_fm_chk, A_sub, b, C_sub, f)
    .persist()
    .chunk({"unit_id": 1, "frame": -1})
)
for cur_p, cur_sprs, cur_add, cur_noise in itt.product(
    params['p_ls'], params['sprs_ls'], params['add_ls'], params['noise_ls']
):
    ks = (cur_p, cur_sprs, cur_add, cur_noise)
    print(
        "p:{}, sparse penalty:{}, additional lag:{}, noise frequency:{}".format(
            cur_p, cur_sprs, cur_add, cur_noise
        )
    )
    cur_C, cur_S, cur_b0, cur_c0, cur_g, cur_mask = update_temporal(
        A_sub,
        C_sub,
        YrA=YrA,
        sparse_penal=cur_sprs,
        p=cur_p,
        use_smooth=True,
        add_lag=cur_add,
        noise_freq=cur_noise,
    )
    YA_dict[ks], C_dict[ks], S_dict[ks], g_dict[ks], sig_dict[ks], A_dict[ks] = (
        YrA.compute().astype(np.float32),
        cur_C.compute().astype(np.float32),
        cur_S.compute().astype(np.float32),
        cur_g.compute().astype(np.float32),
        (cur_C + cur_b0 + cur_c0).compute().astype(np.float32),
        A_sub.compute().astype(np.float32),
    )

In [None]:
%gui qt
visualize_temporal_params(
    units,
    params,
    YA_dict,
    C_dict,
    S_dict,
    g_dict,
    sig_dict,
    A_dict,
    magnify=True
)

### temporal update

In [None]:
%%time
YrA = save_minian(
    compute_trace(Y_fm_chk, A, b, C_chk, f).rename("YrA"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "frame": -1},
)

In [None]:
%%time
C_new, S_new, b0_new, c0_new, g, mask = update_temporal(
    A, C, YrA=YrA, **param_second_temporal
)

### visualization of temporal components

In [None]:
%gui qt
visualize_temporal_components(
    C=C,
    S=S,
    C_new=C_new,
    S_new=S_new,
    title='Second temporal update'
)

### visualization of dropped units

In [None]:
%gui qt
params = {param:[param_second_temporal[param]] for param in ['noise_freq', 'sparse_penal', 'p', 'add_lag']}
param_key = tuple([param[0] for param in params.values()])
bad_units = mask.where(mask == False, drop=True).coords["unit_id"].values
if len(bad_units) > 0:
    visualize_temporal_params(
        units=YrA.sel(unit_id=mask).unit_id.values,
        params=params,
        YA_dict={param_key:YrA.sel(unit_id=mask)},
        C_dict={param_key:C_new},
        S_dict={param_key:S_new},
        g_dict={param_key:g},
        sig_dict={param_key:g},
        A_dict={param_key:A.sel(unit_id=mask)},
        magnify=False
    )
else:
    print("No rejected units to display")

### visualization of accepted units

In [None]:
%gui qt
params = {param:[param_second_temporal[param]] for param in ['noise_freq', 'sparse_penal', 'p', 'add_lag']}
param_key = tuple([param[0] for param in params.values()])
visualize_temporal_params(
    units=YrA.sel(unit_id=mask).unit_id.values,
    params=params,
    YA_dict={param_key:YrA.sel(unit_id=mask)},
    C_dict={param_key:C_new},
    S_dict={param_key:S_new},
    g_dict={param_key:g},
    sig_dict={param_key:g},
    A_dict={param_key:A.sel(unit_id=mask)},
    magnify=False
)

### save results

In [None]:
%%time
C = save_minian(
    C_new.rename("C").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
C_chk = save_minian(
    C.rename("C_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)
S = save_minian(
    S_new.rename("S").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
b0 = save_minian(
    b0_new.rename("b0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
c0 = save_minian(
    c0_new.rename("c0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
A = A.sel(unit_id=C.coords["unit_id"].values)

In [None]:
### TODO: should YrA be saved to the output dataset by default?

### TODO: should a JSON file with all parameters be saved by default?

## visualization

In [None]:
%gui qt
jackson_pollock_plot(
    A,
    max_proj,
    title='Final spatial footprints',
    method='matmul',
    threshold=0.2, # trims off small weightings, set to 0 for full footprints
    cm=colormap.get_colormap('Spectral_r'),
    alpha=0.5
    )

In [None]:
%%time
generate_videos(varr.sel(subset), Y_fm_chk, A=A, C=C_chk, vpath=dpath)

In [None]:
# %%time
# if interactive:
#     cnmfviewer = CNMFViewer(A=A, C=C, S=S, org=Y_fm_chk)

In [None]:
# hv.output(size=int(output_size * 0.35))
# if interactive:
#     display(cnmfviewer.show())

## save unit labels

In [None]:
# if interactive:
#     A = A.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
#     C = C.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
#     S = S.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
#     c0 = c0.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
#     b0 = b0.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))

## save final results

In [None]:
%%time
A = save_minian(A.rename("A"), **param_save_minian)
C = save_minian(C.rename("C"), **param_save_minian)
S = save_minian(S.rename("S"), **param_save_minian)
c0 = save_minian(c0.rename("c0"), **param_save_minian)
b0 = save_minian(b0.rename("b0"), **param_save_minian)
b = save_minian(b.rename("b"), **param_save_minian)
f = save_minian(f.rename("f"), **param_save_minian)

## close cluster

In [None]:
# Shut down the dask client first, then the cluster.
for resource in (client, cluster):
    resource.close()

In [None]:
# Delete the intermediate folder once the final results are saved. Skip it
# when it is already gone, so re-running this cell does not raise.
import shutil
if os.path.isdir(intpath):
    shutil.rmtree(intpath)