# Minian reduced

## Load packages

In [None]:
# Automatically reload edited modules (e.g. the local minian checkout) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
#%%capture  ## "%%capture var" redirects stdout and stderr to the variable var that can be used latter on. If no var is provided, it will just suppress the output
import itertools as itt
import os
import sys

import holoviews as hv
import numpy as np
import xarray as xr
from dask.distributed import Client, LocalCluster
from holoviews.operation.datashader import datashade, regrid
from holoviews.util import Dynamic
from IPython.display import display
from ipyfilechooser import FileChooser


In [None]:
# NOTE(review): machine-specific absolute path. The minian import cell resolves '..'
# relative to the working directory, so it depends on this `cd` having run first.
cd "C:/Users/Manip2/SCRIPTS/Code python audrey/code python aurelie/interfaceJupyter/minian"
# Master switch: when False, the exploratory / visualisation cells are skipped.
interactive = True

In [None]:
#%%capture
# Make the local minian checkout importable. abspath('..') is resolved against the
# current working directory, i.e. whatever the `cd` cell selected.
minian_path = os.path.join(os.path.abspath('..'),'minian')
print("The folder used for minian procedures is : {}".format(minian_path))

sys.path.append(minian_path)
from minian.cnmf import (
    compute_AtC,
    compute_trace,
    get_noise_fft,
    smooth_sig,
    unit_merge,
    update_spatial,
    update_temporal,
    update_background,
)
from minian.initialization import (
    gmm_refine,
    initA,
    initC,
    intensity_refine,
    ks_refine,
    pnr_refine,
    seeds_init,
    seeds_merge,
)
from minian.motion_correction import apply_transform, estimate_motion
from minian.preprocessing import denoise, remove_background
from minian.utilities import (
    TaskAnnotation,
    get_optimal_chk,
    load_videos,
    open_minian,
    save_minian,
)
from minian.visualization import (
    CNMFViewer,
    VArrayViewer,
    generate_videos,
    visualize_gmm_fit,
    visualize_motion,
    visualize_preprocess,
    visualize_seeds,
    visualize_spatial_update,
    visualize_temporal_update,
    write_video,
)

## Configuration

### Select folder

In [None]:
dpath = "C:/Users/Manip2/DATA/MINISCOPE/"
try:
    %store -r dpath
except:
    print("data not in strore")
    #dpath = "/Users/mb/Documents/Syntuitio/AudreyHay/PlanB/ExampleRedLines/2022_08_06/13_30_01/My_V4_Miniscope/"
    dpath = "C:/Users/Manip2/DATA/MINISCOPE/ThreeColoredDots/Baseline_recording/2022_05_20/09_33_30/My_V4_Miniscope/"

# Set up Initial Basic Parameters#
minian_path = "."

fc1 = FileChooser(dpath,select_default=True, show_only_dirs = True, title = "<b>Folder with videos</b>")
display(fc1)

# Sample callback function
def update_my_folder(chooser):
    global dpath
    dpath = chooser.selected
    %store dpath
    return 

# Register callback function
fc1.register_callback(update_my_folder)



In [None]:
# Output locations derived from the selected data folder: final results (minian_ds_path)
# and intermediate arrays (intpath). Re-run this cell after choosing another folder.
minian_ds_path = os.path.join(dpath, "minianAB")
intpath = os.path.join(dpath, "minian_intermediateAB")
minian_ds_path

### Initial parameters

In [None]:
# Frame subset used when writing the final videos (all frames).
subset = dict(frame=slice(0, None))
# Size passed to hv.output() for the interactive plots.
output_size = 100
n_workers = int(os.getenv("MINIAN_NWORKERS", 4))
param_save_minian = {
    "dpath": minian_ds_path,
    # NOTE(review): presumably maps metadata names to parent-folder levels of the data path — verify in minian docs.
    "meta_dict": dict(session=-1, animal=-2),
    "overwrite": True,
}

# Pre-processing Parameters#
param_load_videos = {
    # Raw string: "\." in a normal string literal is an invalid escape sequence (warning in recent Python).
    "pattern": r"[0-9]+\.avi$",
    "dtype": np.uint8,
    "downsample": dict(frame=1, height=1, width=1),
    "downsample_strategy": "subset",
}
param_denoise ={"method": "median", "ksize": 7} #{"method": "median", "ksize": 5} #Default minian = {"method": "median", "ksize": 7}
param_background_removal = {"method": "tophat", "wnd": 15}

# Motion Correction Parameters#
# Frame subset for motion estimation (None = all frames). Previously assigned twice in this cell.
subset_mc = None
param_estimate_motion = {"dim": "frame"}

# Initialization Parameters#
param_seeds_init = {
    "wnd_size": 1000, # 100, #Default minian = 1000
    "method": "rolling",
    "stp_size": 500, #50, #Default minian =500
    "max_wnd": 15, #20,#generally 10 updated here to 20 to account for L1 wide dendritic trees #Default minian =15
    "diff_thres": 3,
}
param_pnr_refine = {"noise_freq": 0.06, "thres": 1}
param_ks_refine = {"sig": 0.05}
param_seeds_merge = {"thres_dist": 10, "thres_corr": 0.8, "noise_freq": 0.06}
param_initialize = {"thres_corr": 0.8, "wnd": 10, "noise_freq": 0.06}
param_init_merge = {"thres_corr": 0.8}

# CNMF Parameters# 0.025 for threecolordots
param_get_noise = {"noise_range": (0.06, 0.5)}
param_first_spatial = {
    "dl_wnd": 10, #15, #Default minian = 10
    "sparse_penal": 0.01, #0.012, #Default minian =0.01
    "size_thres": (25, None),
}
param_first_temporal = {
    "noise_freq": 0.06,
    "sparse_penal": 1,
    "p": 1,
    "add_lag": 20,
    "jac_thres": 0.2,
}
param_first_merge = {"thres_corr": 0.8}
param_second_spatial = {
    "dl_wnd": 10,
    "sparse_penal": 0.01, #0.005, #Default minian =0.01
    "size_thres": (25, None),
}
param_second_temporal = {
    "noise_freq": 0.06,
    "sparse_penal": 1,
    "p": 1,
    "add_lag": 20,
    "jac_thres": 0.4,
}

# Environment for the dask cluster started in the next section: one BLAS/OpenMP thread per
# worker, and minian's intermediate folder.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MINIAN_INTERMEDIATE"] = intpath

## start cluster

In [None]:
# Load the bokeh plotting backend for holoviews.
hv.notebook_extension("bokeh", width=70)
# Default colormap for hv.Image elements. This replaces the former call
# hv.notebook_extension("style=dict(cmap=Viridis256),"): that string is not a backend name,
# so it never set a default colormap.
hv.opts.defaults(hv.opts.Image(cmap="Viridis"))

# Local dask cluster: n_workers workers with 2 threads and an 8GB memory limit each.
cluster = LocalCluster(
    n_workers=n_workers,
    memory_limit="8GB",  # was 4GB
    resources={"MEM": 1},
    threads_per_worker=2,
    dashboard_address=":8780",  # port 8787 already used by jupyter
)

In [None]:
# Register minian's TaskAnnotation scheduler plugin on the cluster, then connect a client to it.
annt_plugin = TaskAnnotation()
cluster.scheduler.add_plugin(annt_plugin)
client = Client(cluster)
print(cluster)
print(client)

## Pre-processing

In [None]:
%%time
#%%capture
# Load every video in dpath matching param_load_videos["pattern"] into one lazy array.

varr = load_videos(dpath, **param_load_videos)
# Chunk sizes chosen by minian for this movie; reused for all later (re)chunking.
chk, _ = get_optimal_chk(varr, dtype=np.float64)

# Persist the raw movie (chunked along frames, whole frames per chunk) to the intermediate folder.
varr = save_minian(
    varr.chunk({"frame": chk["frame"], "height": -1, "width": -1}).rename("varr"),
    intpath,
    overwrite=True,
)

In [None]:
#possibility to crop data
# Label-based crop on the height/width coordinates; widen the slices to keep the full field of view.
varr_ref = varr.sel(height=slice(0, 600), width=slice(0, 600))

In [None]:
if interactive:
    hv.output(size=output_size)
    # Browse the raw movie together with its mean and max projections.
    vaviewer = VArrayViewer(varr, framerate=5, summary=["mean", "max"])
    display(vaviewer.show())

## Clean up

### Glow removal

In [None]:
%%time
# Glow removal: subtract the per-pixel minimum over all frames.
varr_min = varr_ref.min("frame").compute()
varr_ref = varr_ref - varr_min

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.7))
    # NOTE(review): "original" shows the uncropped `varr`, while varr_ref may be cropped,
    # so the two panels can differ in size.
    vaviewer = VArrayViewer(
        [varr.rename("original"), varr_ref.rename("glow_removed")],
        framerate=5,
        summary=None,
        layout=True,
    )
    display(vaviewer.show())

### Denoise
Make sure to update the [denoise parameters](#Initial-parameters) based on what you see before proceeding

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.6))
    display(
        visualize_preprocess(
            varr_ref.isel(frame=0).compute(),
            denoise,
            method=["median"],
            ksize=[5, 7, 9],
        )
    )

In [None]:
%%time
param_denoise

varr_ref = denoise(varr_ref, **param_denoise)


### Background removal
Make sure to update the [background removal parameters](#Initial-parameters) based on what you see before proceeding

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.6))
    # Preview tophat background removal of the first frame for several window sizes,
    # to choose param_background_removal["wnd"].
    display(
        visualize_preprocess(
            varr_ref.isel(frame=0).compute(),
            remove_background,
            method=["tophat"],
            wnd=[10, 15, 20],
        )
    )

In [None]:
param_background_removal

In [None]:
%%time
# NOTE: varr_ref is overwritten, so re-running this cell removes the background a second time.
varr_ref = remove_background(varr_ref, **param_background_removal)

### Save results

In [None]:
%%time
# Persist the pre-processed movie to the intermediate folder and continue from the saved copy.
varr_ref = save_minian(varr_ref.rename("varr_ref"), dpath=intpath, overwrite=True)

## Motion correction

### Estimation motion

In [None]:
%%time
# Estimate motion along the frame axis; subset_mc=None selects all frames.
motion = estimate_motion(varr_ref.sel(subset_mc), **param_estimate_motion)

### Save motion

In [None]:
%%time
# The motion estimate is saved with the final results (param_save_minian -> minian_ds_path).
motion = save_minian(
    motion.rename("motion").chunk({"frame": chk["frame"]}), **param_save_minian
)

### Visualization of motion

In [None]:
if interactive:
    hv.output(size=output_size)
    # display() is required: inside an `if` block the returned plot is not the cell's
    # output value, so it was previously never rendered.
    display(visualize_motion(motion))

### Apply transform

In [None]:
# Apply the estimated motion to the pre-processed movie (fill=0 presumably sets pixels without data to 0).
Y = apply_transform(varr_ref, motion, fill=0)

### Save result

In [None]:
%%time
# Two saved copies of the corrected movie with different chunking:
#  - Y_fm_chk: float copy of Y (presumably keeps Y's frame-wise chunks),
#  - Y_hw_chk: whole frame axis per chunk, chunked over height/width with chk.
Y_fm_chk = save_minian(Y.astype(float).rename("Y_fm_chk"), intpath, overwrite=True)
Y_hw_chk = save_minian(
    Y_fm_chk.rename("Y_hw_chk"),  
    intpath,
    overwrite=True,
    chunks={"frame": -1, "height": chk["height"], "width": chk["width"]},
)

### Visualization of motion-correction

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.7))
    # Side-by-side viewer of the movie before and after motion correction.
    vaviewer = VArrayViewer(
        [varr_ref.rename("before_mc"), Y_fm_chk.rename("after_mc")],
        framerate=5,
        summary=None,
        layout=True,
    )
    display(vaviewer.show())

### Generate video for motion-correction

In [None]:
%%time
vid_arr = xr.concat([varr_ref, Y_fm_chk], "width").chunk({"width": -1})
if interactive:
    write_video(Y_fm_chk, "minian_mc.mp4", dpath)

## Initialisation 

In [None]:
if interactive:
    # Placeholder: minian's extra initialisation visualisations are not included in this notebook.
    print("plenty of cool visualization that i did not add yet, will see if it can be usefull, cf website")

### Compute maximal projection

In [None]:
# Maximum projection over frames of the corrected movie, saved to minian_ds_path and loaded in memory.
max_proj = save_minian(
    Y_fm_chk.max("frame").rename("max_proj"), **param_save_minian
).compute()

### Generating seeds

In [None]:
param_seeds_init

In [None]:
%%time
seeds = seeds_init(Y_fm_chk, **param_seeds_init)

In [None]:
if interactive:
    seeds

In [None]:
if interactive:
    hv.output(size=output_size)
    visualize_seeds(max_proj, seeds)

In [None]:
%%time
if interactive:
    noise_freq_list = [0.005, 0.01, 0.02, 0.06, 0.1, 0.2, 0.3, 0.45, 0.6, 0.8]
    example_seeds = seeds.sample(6, axis="rows")
    example_trace = Y_hw_chk.sel(
        height=example_seeds["height"].to_xarray(),
        width=example_seeds["width"].to_xarray(),
    ).rename(**{"index": "seed"})
    smooth_dict = dict()
    for freq in noise_freq_list:
        trace_smth_low = smooth_sig(example_trace, freq)
        trace_smth_high = smooth_sig(example_trace, freq, btype="high")
        trace_smth_low = trace_smth_low.compute()
        trace_smth_high = trace_smth_high.compute()
        hv_trace = hv.HoloMap(
            {
                "signal": (
                    hv.Dataset(trace_smth_low)
                    .to(hv.Curve, kdims=["frame"])
                    .opts(frame_width=300, aspect=2, ylabel="Signal (A.U.)")
                ),
                "noise": (
                    hv.Dataset(trace_smth_high)
                    .to(hv.Curve, kdims=["frame"])
                    .opts(frame_width=300, aspect=2, ylabel="Signal (A.U.)")
                ),
            },
            kdims="trace",
        ).collate()
        smooth_dict[freq] = hv_trace

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.9))
    hv_res = (
        hv.HoloMap(smooth_dict, kdims=["noise_freq"])
        .collate()
        .opts(aspect=2)
        .overlay("trace")
        .layout("seed")
        .cols(3)
    )
    display(hv_res)

In [None]:
#param_pnr_refine = {"noise_freq": 0.02, "thres": 1}
param_pnr_refine


### Noise refined 
The visualisation above can be used to refine `param_pnr_refine`, but `noise_freq = 0.06` is generally fine.

In [None]:
%%time
seeds, pnr, gmm = pnr_refine(Y_hw_chk, seeds, **param_pnr_refine)

In [None]:
if interactive:
    seeds

In [None]:
if interactive:
    if gmm:
        display(visualize_gmm_fit(pnr, gmm, 100))
    else:
        print("nothing to show")



In [None]:
if interactive:
    hv.output(size=output_size)
    visualize_seeds(max_proj, seeds, "mask_pnr")

### Refine using KS test to look at bimodal distribution

In [None]:
param_ks_refine

In [None]:
%%time
seeds = ks_refine(Y_hw_chk, seeds, **param_ks_refine)

In [None]:
if interactive:
    hv.output(size=output_size)
    visualize_seeds(max_proj, seeds, "mask_ks")

### Merge seeds

In [None]:
param_seeds_merge

In [None]:
%%time
seeds_final = seeds[seeds["mask_ks"] & seeds["mask_pnr"]].reset_index(drop=True)
seeds_final = seeds_merge(Y_hw_chk, max_proj, seeds_final, **param_seeds_merge)
print("{} units found".format(seeds_final["mask_mrg"].count()))

In [None]:
hv.output(size=output_size)
visualize_seeds(max_proj, seeds_final, "mask_mrg")

### Initialise spatial matrix

In [None]:
param_initialize

In [None]:
%%time
# Initial spatial footprints from the merged seeds only.
A_init = initA(Y_hw_chk, seeds_final[seeds_final["mask_mrg"]], **param_initialize)
A_init = save_minian(A_init.rename("A_init"), intpath, overwrite=True)

### Initialise temporal matrix

In [None]:
%%time
# Initial temporal traces for each footprint, saved with one chunk per unit over all frames.
C_init = initC(Y_fm_chk, A_init)
C_init = save_minian(
    C_init.rename("C_init"), intpath, overwrite=True, chunks={"unit_id": 1, "frame": -1}
)

### Merge unit

In [None]:
param_init_merge

In [None]:
%%time
# Merge units whose traces correlate above param_init_merge["thres_corr"].
A_merged, C_merged = unit_merge(A_init, C_init, **param_init_merge)
A_merged = save_minian(A_merged.rename("A"), intpath, overwrite=True)
C_merged = save_minian(C_merged.rename("C"), intpath, overwrite=True)
# Second copy of C chunked along frames (all units per chunk), used by background/trace updates.
C_chk_merged = save_minian(
    C_merged.rename("C_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)

### Initialise background terms

In [None]:
%%time
# Initial background terms: spatial b and temporal f.
b_init, f_init = update_background(Y_fm_chk, A_merged, C_chk_merged)
f_init = save_minian(f_init.rename("f"), intpath, overwrite=True)
b_init = save_minian(b_init.rename("b"), intpath, overwrite=True)

## CNMF

### Estimate spatial noise

In [None]:
param_get_noise

In [None]:
%%time
# Noise estimate from the frequency band in param_get_noise (presumably one value per pixel).
sn_spatial = get_noise_fft(Y_hw_chk, **param_get_noise)
sn_spatial = save_minian(sn_spatial.rename("sn_spatial"), intpath, overwrite=True)

### First spatial update

#### Randomly select a subset of units for exploration

In [None]:
if interactive:
    # Random subset of (at most) 10 units for parameter exploration. Sampling without
    # replacement raises when fewer units exist, so cap the sample size.
    all_units = A_merged.coords["unit_id"].values
    unitsSub = np.random.choice(all_units, min(10, len(all_units)), replace=False)
    unitsSub.sort()
    A_sub = A_merged.sel(unit_id=unitsSub).persist()
    C_sub = C_merged.sel(unit_id=unitsSub).persist()

#### Parameter exploration
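This is the only mandatory interactive step: choose the spatial sparse penalty (generally between 0.01 and 0.02).
> **WARNING** 
> **Be very careful here!** The penalties explored below are only displayed; the spatial update uses the value set in `param_first_spatial["sparse_penal"]`, so update it in the [initial parameters](#Initial-parameters).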

In [None]:
%%time
if interactive:
    # Candidate sparse penalties. These are only displayed; the actual update below uses
    # param_first_spatial["sparse_penal"].
    sprs_ls = [ 0.01, 0.05, 0.1, 0.5]
    A_dict = dict()
    C_dict = dict()
    for cur_sprs in sprs_ls:
        cur_A, cur_mask, cur_norm = update_spatial(
            Y_hw_chk,
            A_sub,
            C_sub,
            sn_spatial,
            in_memory=True,
            dl_wnd=param_first_spatial["dl_wnd"],
            sparse_penal=cur_sprs,
        )
        # Skip penalties for which no unit survives.
        if cur_A.sizes["unit_id"]:
            A_dict[cur_sprs] = cur_A.compute()
            C_dict[cur_sprs] = C_sub.sel(unit_id=cur_mask).compute()
    hv_res = visualize_spatial_update(A_dict, C_dict, kdims=["sparse penalty"])
    display(hv_res)

#### Spatial updates

In [None]:
param_first_spatial


In [None]:
%%time
A_firstS, mask_firstS, norm_fac_firstS = update_spatial(
    Y_hw_chk, A_merged, C_merged, sn_spatial, **param_first_spatial
)
# Keep only the units retained by the spatial update and rescale their traces by the
# normalisation factor returned with the new footprints.
C_firstS = save_minian(
    (C_merged.sel(unit_id=mask_firstS) * norm_fac_firstS).rename("C_new"), intpath, overwrite=True
)
C_chk_firstS = save_minian(
    (C_chk_merged.sel(unit_id=mask_firstS) * norm_fac_firstS).rename("C_chk_new"), intpath, overwrite=True
)

#### Background updates

In [None]:
%%time
# Re-estimate the background terms with the updated footprints and traces.
b_firstS, f_firstS = update_background(Y_fm_chk, A_firstS, C_chk_firstS)

#### visualization of spatial footprints

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.6))
    opts = dict(
        height=A_merged.sizes["height"],
        width=A_merged.sizes["width"],
        colorbar=True,
        cmap="Viridis",
    )
    # Left column: max over units of the footprints; right column: number of units covering each pixel.
    # Top row: before the first spatial update; bottom row: after it.
    footprints_layout = (
        regrid(
            hv.Image(
                A_merged.max("unit_id").compute().astype(np.float32).rename("A"),
                kdims=["width", "height"],
            ).opts(**opts)
        ).relabel("Spatial Footprints Initial")
        + regrid(
            hv.Image(
                (A_merged.fillna(0) > 0).sum("unit_id").compute().astype(np.uint8).rename("A"),
                kdims=["width", "height"],
            ).opts(**opts)
        ).relabel("Binary Spatial Footprints Initial")
        + regrid(
            hv.Image(
                A_firstS.max("unit_id").compute().astype(np.float32).rename("A"),
                kdims=["width", "height"],
            ).opts(**opts)
        ).relabel("Spatial Footprints First Update")
        + regrid(
            hv.Image(
                (A_firstS > 0).sum("unit_id").compute().astype(np.uint8).rename("A"),
                kdims=["width", "height"],
            ).opts(**opts)
        ).relabel("Binary Spatial Footprints First Update")
    ).cols(2)
    # display() is required: a layout built inside an `if` block is not the cell's output.
    display(footprints_layout)

#### visualization of background

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.55))
    opts_im = dict(
        height=b_init.sizes["height"],
        width=b_init.sizes["width"],
        colorbar=True,
        cmap="Viridis",
    )
    opts_cr = dict(height=b_init.sizes["height"], width=b_init.sizes["height"] * 2)
    # Rows: initial background vs. after the first spatial update; columns: spatial b, temporal f.
    background_layout = (
        regrid(
            hv.Image(b_init.compute().astype(np.float32), kdims=["width", "height"]).opts(
                **opts_im
            )
        ).relabel("Background Spatial Initial")
        + hv.Curve(f_init.compute().rename("f").astype(np.float16), kdims=["frame"])
        .opts(**opts_cr)
        .relabel("Background Temporal Initial")
        + regrid(
            hv.Image(b_firstS.compute().astype(np.float32), kdims=["width", "height"]).opts(
                **opts_im
            )
        ).relabel("Background Spatial First Update")
        + hv.Curve(f_firstS.compute().rename("f").astype(np.float16), kdims=["frame"])
        .opts(**opts_cr)
        .relabel("Background Temporal First Update")
    ).cols(2)
    # display() is required: a layout built inside an `if` block is not the cell's output.
    display(background_layout)

#### Save results first spatial update

In [None]:
%%time
# Persist the first spatial update. Footprints are saved with one chunk per unit (whole frame each).
A_firstS = save_minian(
    A_firstS.rename("A"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "height": -1, "width": -1},
)
b_firstS = save_minian(b_firstS.rename("b"), intpath, overwrite=True)
f_firstS = save_minian(
    f_firstS.chunk({"frame": chk["frame"]}).rename("f"), intpath, overwrite=True
)
C_firstS = save_minian(C_firstS.rename("C"), intpath, overwrite=True)
C_chk_firstS = save_minian(C_chk_firstS.rename("C_chk"), intpath, overwrite=True)

### First temporal update

#### Randomly select a subset of units for exploration

In [None]:
if interactive:
    # Random subset of (at most) 10 units for parameter exploration. Sampling without
    # replacement raises when fewer units exist, so cap the sample size.
    all_units = A_firstS.coords["unit_id"].values
    unitsSub = np.random.choice(all_units, min(10, len(all_units)), replace=False)
    unitsSub.sort()
    A_sub = A_firstS.sel(unit_id=unitsSub).persist()
    C_sub = C_firstS.sel(unit_id=unitsSub).persist()

#### Parameter exploration
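Optional interactive step: explore the temporal-update parameters (sparse penalty, AR order `p`, additional lag, noise frequency) on the unit subset.
> **WARNING** 
> **Be very careful here!** The values explored below are only displayed; the temporal update uses `param_first_temporal`, so copy the chosen values into the [initial parameters](#Initial-parameters).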

In [None]:
%%time
%env SPARSE_AUTO_DENSIFY=1

sprs_ls = [ 0.01, 0.05, 0.1, 0.5]
A_dict, C_dict = [dict() for _ in range(2)]

if interactive:
    p_ls = [1]
    add_ls = [20]
    noise_ls = [0.06]
    YA_dict, S_dict, g_dict, sig_dict = [dict() for _ in range(4)]
    YrA = (
        compute_trace(Y_fm_chk, A_sub, b_firstS, C_sub, f_firstS)
        .persist()
        .chunk({"unit_id": 1, "frame": -1})
    )
    for cur_p, cur_sprs, cur_add, cur_noise in itt.product(
        p_ls, sprs_ls, add_ls, noise_ls
    ):
        ks = (cur_p, cur_sprs, cur_add, cur_noise)
        print(
            "p:{}, sparse penalty:{}, additional lag:{}, noise frequency:{}".format(
                cur_p, cur_sprs, cur_add, cur_noise
            )
        )
        cur_C, cur_S, cur_b0, cur_c0, cur_g, cur_mask = update_temporal(
            A_sub,
            C_sub,
            YrA=YrA,
            sparse_penal=cur_sprs,
            p=cur_p,
            use_smooth=True,
            add_lag=cur_add,
            noise_freq=cur_noise,
        )
        YA_dict[ks], C_dict[ks], S_dict[ks], g_dict[ks], sig_dict[ks], A_dict[ks] = (
            YrA.compute(),
            cur_C.compute(),
            cur_S.compute(),
            cur_g.compute(),
            (cur_C + cur_b0 + cur_c0).compute(),
            A_sub.compute(),
        )
    hv_res = visualize_temporal_update(
        YA_dict,
        C_dict,
        S_dict,
        g_dict,
        sig_dict,
        A_dict,
        kdims=["p", "sparse penalty", "additional lag", "noise frequency"],
    )
else:
    
    for cur_sprs in sprs_ls:
        cur_A, cur_mask, cur_norm = update_spatial(
            Y_hw_chk,
            A_sub,
            C_sub,
            sn_spatial,
            in_memory=True,
            dl_wnd=param_first_spatial["dl_wnd"],
            sparse_penal=cur_sprs,
        )
        if cur_A.sizes["unit_id"]:
            A_dict[cur_sprs] = cur_A.compute()
            C_dict[cur_sprs] = C_sub.sel(unit_id=cur_mask).compute()
    hv_res = visualize_spatial_update(A_dict, C_dict, kdims=["sparse penalty"])

display(hv_res)

In [None]:
if interactive:
    # NOTE(review): same plot as the background visualisation of the first spatial update.
    hv.output(size=int(output_size * 0.55))
    opts_im = dict(
        height=b_init.sizes["height"],
        width=b_init.sizes["width"],
        colorbar=True,
        cmap="Viridis",
    )
    opts_cr = dict(height=b_init.sizes["height"], width=b_init.sizes["height"] * 2)
    # Rows: initial background vs. after the first spatial update; columns: spatial b, temporal f.
    background_layout = (
        regrid(
            hv.Image(b_init.compute().astype(np.float32), kdims=["width", "height"]).opts(
                **opts_im
            )
        ).relabel("Background Spatial Initial")
        + hv.Curve(f_init.compute().rename("f").astype(np.float16), kdims=["frame"])
        .opts(**opts_cr)
        .relabel("Background Temporal Initial")
        + regrid(
            hv.Image(b_firstS.compute().astype(np.float32), kdims=["width", "height"]).opts(
                **opts_im
            )
        ).relabel("Background Spatial First Update")
        + hv.Curve(f_firstS.compute().rename("f").astype(np.float16), kdims=["frame"])
        .opts(**opts_cr)
        .relabel("Background Temporal First Update")
    ).cols(2)
    # display() is required: a layout built inside an `if` block is not the cell's output.
    display(background_layout)

#### Temporal update

In [None]:
%%time
# Per-unit traces for the temporal update (compute_trace), one chunk per unit over all frames.
YrA_firstT = save_minian(
    compute_trace(Y_fm_chk, A_firstS, b_firstS, C_chk_firstS, f_firstS).rename("YrA"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "frame": -1},
)

In [None]:
param_first_temporal

In [None]:
%%time
# mask_firstT flags the units kept by the temporal update (False = dropped).
C_firstT, S_firstT, b0_firstT, c0_firstT, g_firstT, mask_firstT = update_temporal(
    A_firstS, C_firstS, YrA=YrA_firstT, **param_first_temporal
)

#### Visualization of temporal components

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.6))
    opts_im = dict(frame_width=500, aspect=2, colorbar=True, cmap="Viridis")
    # Unit x frame heatmaps: traces before and after the first temporal update, and the spikes.
    temporal_layout = (
        regrid(
            hv.Image(
                C_firstS.compute().astype(np.float32).rename("ci"), kdims=["frame", "unit_id"]
            ).opts(**opts_im)
        ).relabel("Temporal Trace Initial")
        + hv.Div("")
        + regrid(
            hv.Image(
                C_firstT.compute().astype(np.float32).rename("c1"), kdims=["frame", "unit_id"]
            ).opts(**opts_im)
        ).relabel("Temporal Trace First Update")
        + regrid(
            hv.Image(
                S_firstT.compute().astype(np.float32).rename("s1"), kdims=["frame", "unit_id"]
            ).opts(**opts_im)
        ).relabel("Spikes First Update")
    ).cols(2)
    # display() is required: a layout built inside an `if` block is not the cell's output.
    display(temporal_layout)

#### Visualization of dropped units

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.6))
    h, w = A_firstS.sizes["height"], A_firstS.sizes["width"]
    im_opts = dict(aspect=w / h, frame_width=500, cmap="Viridis")
    cr_opts = dict(aspect=3, frame_width=1000)
    # unit_ids rejected by the first temporal update (mask value False).
    bad_units = mask_firstT.where(mask_firstT == False, drop=True).coords["unit_id"].values
    print(str(len(bad_units)), "dropped units")
    if len(bad_units) > 0:
        # Top: individual and summed footprints of the dropped units; bottom: their raw traces.
        hv_res = (
            hv.NdLayout(
                {
                    "Spatial Footprint": Dynamic(
                        hv.Dataset(A_firstS.sel(unit_id=bad_units).compute().rename("A"))
                        .to(hv.Image, kdims=["width", "height"])
                        .opts(**im_opts)
                    ),
                    "Spatial Footprints of dropped Units": Dynamic(
                        hv.Image(
                            A_firstS.sel(unit_id=bad_units).sum("unit_id").compute().rename("A"),
                            kdims=["width", "height"],
                        ).opts(**im_opts)
                    ),
                }
            )
            + datashade(
                hv.Dataset(YrA_firstT.sel(unit_id=bad_units).rename("raw")).to(
                    hv.Curve, kdims=["frame"]
                )
            )
            .opts(**cr_opts)
            .relabel("Temporal Trace")
        ).cols(1)
        display(hv_res)
    else:
        print("No rejected units to display")

#### Visualization of accepted units

In [None]:
if interactive:
    hv.output(size=int(output_size * 0.6))
    # Count the units kept by the temporal update. len(A_firstS) counted every unit,
    # including the dropped ones.
    n_accepted = int(mask_firstT.sum())
    print(str(n_accepted), "accepted units")
    # Fitted signal = denoised trace + baseline b0 + initial-decay term c0.
    sig = C_firstT + b0_firstT + c0_firstT
    hv_res = visualize_temporal_update(
        YrA_firstT.sel(unit_id=mask_firstT),
        C_firstT,
        S_firstT,
        g_firstT,
        sig,
        A_firstS.sel(unit_id=mask_firstT),
    )
    display(hv_res)
    


#### Save results

In [None]:
%%time
# Persist the first temporal update (one chunk per unit over all frames), plus a
# frame-chunked copy of C for the background/trace computations.
C_firstT = save_minian(
    C_firstT.rename("C").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
C_chk_firstT = save_minian(
    C_firstT.rename("C_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)
S_firstT = save_minian(
    S_firstT.rename("S").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
b0_firstT = save_minian(
    b0_firstT.rename("b0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
c0_firstT = save_minian(
    c0_firstT.rename("c0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
# Footprints restricted to the units that survived the temporal update.
A_firstT = A_firstS.sel(unit_id=C_firstT.coords["unit_id"].values)

### Merge units

In [None]:
param_first_merge

In [None]:
%%time
# Merge correlated units; the full signal (C + b0 + c0) is merged alongside as an extra array.
A_mrg, C_mrg, [sig_mrg] = unit_merge(A_firstT, C_firstT, [C_firstT + b0_firstT + c0_firstT], **param_first_merge)

#### Save merged units

In [None]:
%%time
# Persist the merged units, plus a frame-chunked copy of the merged traces.
A_mrg = save_minian(A_mrg.rename("A_mrg"), intpath, overwrite=True)
C_mrg = save_minian(C_mrg.rename("C_mrg"), intpath, overwrite=True)
C_chk_mrg = save_minian(
    C_mrg.rename("C_mrg_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)
sig_mrg = save_minian(sig_mrg.rename("sig_mrg"), intpath, overwrite=True)

### Second spatial and temporal updates

#### Spatial update
Generally not much happens at that stage.

In [None]:
param_second_spatial#={'dl_wnd': 10, 'sparse_penal': 0.01, 'size_thres': (10, None)}

In [None]:
%%time
A_secS, mask_secS, norm_fac_secS = update_spatial(
    Y_hw_chk, A_mrg, C_mrg, sn_spatial, **param_second_spatial
)
# Keep only the retained units and rescale their traces by the returned normalisation factor.
C_secS = save_minian(
    (C_mrg.sel(unit_id=mask_secS) * norm_fac_secS).rename("C_new"), intpath, overwrite=True
)
C_chk_secS = save_minian(
    (C_chk_mrg.sel(unit_id=mask_secS) * norm_fac_secS).rename("C_chk_new"), intpath, overwrite=True
)

#### Update background

In [None]:
%%time
# Re-estimate the background terms with the second-update footprints and traces.
b_secS, f_secS = update_background(Y_fm_chk, A_secS, C_chk_secS)

#### Save spatial update

In [None]:
%%time
# Persist the second spatial update (same chunking scheme as the first one).
A_secS = save_minian(
    A_secS.rename("A"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "height": -1, "width": -1},
)
b_secS = save_minian(b_secS.rename("b"), intpath, overwrite=True)
f_secS = save_minian(
    f_secS.chunk({"frame": chk["frame"]}).rename("f"), intpath, overwrite=True
)
C_secS = save_minian(C_secS.rename("C"), intpath, overwrite=True)
C_chk_secS = save_minian(C_chk_secS.rename("C_chk"), intpath, overwrite=True)

#### Second temporal update

In [None]:
%%time
# Per-unit traces for the second temporal update, one chunk per unit over all frames.
YrA_secT = save_minian(
    compute_trace(Y_fm_chk, A_secS, b_secS, C_chk_secS, f_secS).rename("YrA"),
    intpath,
    overwrite=True,
    chunks={"unit_id": 1, "frame": -1},
)

In [None]:
%%time
C_secT, S_secT, b0_secT, c0_secT, g_secT, mask_secT = update_temporal(
    A_secS, C_secS, YrA=YrA_secT, **param_second_temporal
)

### Save all

In [None]:
%%time
# Persist the second temporal update (same chunking scheme as the first one).
C_secT = save_minian(
    C_secT.rename("C").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
C_chk_secT = save_minian(
    C_secT.rename("C_chk"),
    intpath,
    overwrite=True,
    chunks={"unit_id": -1, "frame": chk["frame"]},
)
S_secT = save_minian(
    S_secT.rename("S").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
b0_secT = save_minian(
    b0_secT.rename("b0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
c0_secT = save_minian(
    c0_secT.rename("c0").chunk({"unit_id": 1, "frame": -1}), intpath, overwrite=True
)
# Final footprints restricted to the units that survived the second temporal update.
A_secT = A_secS.sel(unit_id=C_secT.coords["unit_id"].values)

## Generate videos and close 

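**Note:** the unit-label code kept (commented out) in the next cell only works if a `CNMFViewer` instance named `cnmfviewer` has been created; this notebook does not create one.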

In [None]:
%%time
if interactive:
    # Generate video
    generate_videos(varr_ref.sel(subset), Y_fm_chk, A=A_secT, C=C_chk_secT, vpath=dpath)

    """
    # Generate labels
    A = A_secT.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
    C = C_secT.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
    S = S_secT.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
    c0 = c0_secT.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))
    b0 = b0_secT.assign_coords(unit_labels=("unit_id", cnmfviewer.unit_labels))"""

In [None]:
%%time
# Save final
A = save_minian(A_secT.rename("A"), **param_save_minian)
C = save_minian(C_secT.rename("C"), **param_save_minian)
S = save_minian(S_secT.rename("S"), **param_save_minian)
c0 = save_minian(c0_secT.rename("c0"), **param_save_minian)
b0 = save_minian(b0_secT.rename("b0"), **param_save_minian)
b = save_minian(b_init.rename("b"), **param_save_minian)
f = save_minian(f_init.rename("f"), **param_save_minian)

# Close cluster
client.close()
cluster.close()

In [None]:
if interactive:
    import matplotlib.pyplot as plt

    # Reload the saved results and plot the traces of (up to) the first four units.
    # The previous hardcoded indices 2, 0, 1, 3 failed when fewer than four units were found.
    minian_ds = open_minian(minian_ds_path)
    C = minian_ds['C']
    print(C)
    fig, ax = plt.subplots()
    for unit_idx in range(min(4, C.sizes["unit_id"])):
        unit_trace = C.isel(unit_id=unit_idx)
        ax.plot(unit_trace.values, label="unit {}".format(unit_trace.coords["unit_id"].item()))
    ax.set(xlabel="frame", ylabel="C (A.U.)", title="Final temporal traces")
    ax.legend()
    plt.show()