# Setting Up

## load modules

In [None]:
%%capture
%load_ext autoreload
%autoreload 2
import itertools as itt
import os
import sys

import cv2
import holoviews as hv
import numpy as np
import xarray as xr
from dask.distributed import Client, LocalCluster
from holoviews.operation.datashader import datashade, regrid
from holoviews.util import Dynamic
from IPython.core.display import display
from scipy.ndimage import binary_fill_holes

## set path and parameters

In [None]:
# Set up Initial Basic Parameters#
minian_path = "."
dpath = "../data/m15/2022_08_04/12_34_43/miniscope_top/"
minian_ds_path = os.path.join(dpath, "minian")  # output dataset directory
intpath = "~/var/2s_validation/minian_intermediate"  # intermediate results
intpath = os.path.normpath(os.path.expanduser(intpath))
subset = dict(frame=slice(0, None))  # frames to analyse (default: all frames)
# Optional ROI for motion estimation; may be replaced by a mask drawn in the viewer.
subset_mc = None
interactive = True  # show the interactive viewers
output_size = 100  # base size passed to hv.output
n_workers = int(os.getenv("MINIAN_NWORKERS", 4))
param_save_minian = {
    "dpath": minian_ds_path,
    "overwrite": True,
}

# Pre-processing Parameters#
param_load_videos = {
    # Raw string: a backslash-dot in a plain string literal is an invalid
    # escape sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
    "pattern": r"[0-9]+\.avi$",
    "dtype": np.uint8,
    "downsample": dict(frame=1, height=1, width=1),
    "downsample_strategy": "subset",
}
param_denoise = {"method": "median", "ksize": 5}
param_background_removal = {"method": "tophat", "wnd": 10}

# Motion Correction Parameters#
# (`subset_mc` is defined once, together with the basic parameters above.)
param_estimate_motion = {"dim": "frame", 'alt_error': None, 'upsample': 10}

# Initialization Parameters#
param_seeds_init = {
    "wnd_size": 10000,
    "method": "rolling",
    "max_wnd": 15,
    "diff_thres": 5,
    'stp_size': 5000
}
param_pnr_refine = {"noise_freq": 0.06, "thres": 1}
param_ks_refine = {"sig": 0.05}
param_seeds_merge = {"thres_dist": 10, "thres_corr": 0.8, "noise_freq": 0.06}
param_initialize = {"thres_corr": 0.8, "wnd": 10, "noise_freq": 0.06}
param_init_merge = {"thres_corr": 0.8}

# CNMF Parameters#
param_get_noise = {"noise_range": (0.06, 0.5)}
param_first_spatial = {
    "dl_wnd": 10,
    "sparse_penal": 0.01,
    "size_thres": (25, None),
}
param_first_temporal = {
    "noise_freq": 0.06,
    "sparse_penal": 1,
    "p": 1,
    "add_lag": 20,
    "jac_thres": 0.2,
}
param_first_merge = {"thres_corr": 0.8}
param_second_spatial = {
    "dl_wnd": 10,
    "sparse_penal": 0.01,
    "size_thres": (25, None),
}
param_second_temporal = {
    "noise_freq": 0.06,
    "sparse_penal": 1,
    "p": 1,
    "add_lag": 20,
    "jac_thres": 0.4,
}

# Limit BLAS/OpenMP threading; processes started later (e.g. the dask workers)
# inherit these variables. MINIAN_INTERMEDIATE presumably tells minian where to
# place intermediate results — confirm against minian docs.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MINIAN_INTERMEDIATE"] = intpath

## import minian

In [None]:
%%capture
sys.path.append(minian_path)
from minian.cnmf import (
    compute_AtC,
    compute_trace,
    get_noise_fft,
    smooth_sig,
    unit_merge,
    update_spatial,
    update_temporal,
    update_background,
)
from minian.initialization import (
    gmm_refine,
    initA,
    initC,
    intensity_refine,
    ks_refine,
    pnr_refine,
    seeds_init,
    seeds_merge,
)
from minian.motion_correction import apply_transform, estimate_motion
from minian.preprocessing import denoise, remove_background
from minian.utilities import (
    TaskAnnotation,
    get_optimal_chk,
    load_videos,
    open_minian,
    save_minian,
)
from minian.visualization import (
    CNMFViewer,
    VArrayViewer,
    generate_videos,
    visualize_gmm_fit,
    visualize_motion,
    visualize_preprocess,
    visualize_seeds,
    visualize_spatial_update,
    visualize_temporal_update,
    write_video,
)

## module initialization

In [None]:
# Resolve the raw-data directory to an absolute path and load the bokeh
# plotting backend for holoviews.
# NOTE(review): `minian_ds_path` was built from the relative `dpath` in the
# parameter cell and is not updated here.
dpath = os.path.abspath(dpath)
hv.notebook_extension("bokeh", width=100)

## start cluster

In [None]:
# Start a local dask cluster: `n_workers` workers with 2 threads and a 4GB
# memory limit each, each advertising one unit of a custom "MEM" resource.
# The dashboard is served on port 23456.
cluster = LocalCluster(
    n_workers=n_workers,
    memory_limit="4GB",
    resources={"MEM": 1},
    threads_per_worker=2,
    dashboard_address=":23456",
)
# Register minian's TaskAnnotation plugin on the scheduler before connecting.
annt_plugin = TaskAnnotation()
cluster.scheduler.add_plugin(annt_plugin)
client = Client(cluster)

# Pre-processing

## loading videos and visualization

In [None]:
# Load the raw videos in `dpath` matching `param_load_videos["pattern"]` and ask
# minian for a chunking suited to float computations on this array.
varr = load_videos(dpath, **param_load_videos)
chk, _ = get_optimal_chk(varr, dtype=float)

In [None]:
%%time
# Rechunk so that every chunk holds whole frames (height/width = -1) and
# `chk["frame"]` frames, save it to the intermediate store as "varr", and
# continue with the array returned by save_minian.
varr = save_minian(
    varr.chunk({"frame": chk["frame"], "height": -1, "width": -1}).rename("varr"),
    intpath,
    overwrite=True,
)

## visualize raw data and optionally set an ROI for motion correction

In [None]:
# Interactive viewer of the raw movie with mean/max summaries; a mask drawn
# here is picked up as the motion-correction ROI by the next cell.
hv.output(size=output_size)
if interactive:
    vaviewer = VArrayViewer(varr, framerate=5, summary=["mean", "max"])
    display(vaviewer.show())

In [None]:
# Use the first mask drawn in the viewer (if any) as the motion-correction ROI.
# When nothing was drawn, `subset_mc` keeps its current value.
if interactive:
    drawn_masks = list(vaviewer.mask.values())
    if drawn_masks:
        subset_mc = drawn_masks[0]

## subset part of video

In [None]:
# Restrict the movie to the frames selected by `subset` (parameter cell).
varr_ref = varr.sel(subset)

## glow removal and visualization

In [None]:
%%time
# Glow removal: subtract each pixel's minimum over all frames.
varr_min = varr_ref.min("frame").compute()
varr_ref = varr_ref - varr_min

In [None]:
# Side-by-side view of the raw movie and the glow-removed subset.
# NOTE(review): `varr` is the full movie while `varr_ref` is only `subset`.
hv.output(size=int(output_size * 0.7))
if interactive:
    vaviewer = VArrayViewer(
        [varr.rename("original"), varr_ref.rename("glow_removed")],
        framerate=5,
        summary=None,
        layout=True,
    )
    display(vaviewer.show())

## denoise

In [None]:
# Preview median denoising of the first frame with several kernel sizes,
# to help choose `param_denoise`.
hv.output(size=int(output_size * 0.6))
if interactive:
    display(
        visualize_preprocess(
            varr_ref.isel(frame=0).compute(),
            denoise,
            method=["median"],
            ksize=[5, 7, 9],
        )
    )

The following cell carries out the denoising step.
Be sure to [change the parameters](https://minian.readthedocs.io/page/start_guide/faq.html#i-don-t-know-python-can-i-still-use-the-pipeline) based on visualization results before running the following cell.

In [None]:
# Denoise every frame with the settings in `param_denoise`.
varr_ref = denoise(varr_ref, **param_denoise)

## glow removal (large-window uniform background subtraction)

In [None]:
# Subtract a large-window (wnd=50) uniform-filter background from the
# float-cast movie.
# NOTE(review): method/window are hard-coded here instead of living in the
# parameter cell next to `param_background_removal`; consider moving them.
varr_ref = remove_background(varr_ref.astype(float), method='uniform', wnd=50)

## background removal

In [None]:
# Preview tophat background removal of the first frame with several window
# sizes, to help choose `param_background_removal`.
hv.output(size=int(output_size * 0.6))
if interactive:
    display(
        visualize_preprocess(
            varr_ref.isel(frame=0).compute(),
            remove_background,
            method=["tophat"],
            wnd=[10, 15, 20],
        )
    )

The following cell carries out the background-removal step.
Be sure to [change the parameters](https://minian.readthedocs.io/page/start_guide/faq.html#i-don-t-know-python-can-i-still-use-the-pipeline) based on visualization results before running the following cell.

In [None]:
# Remove background with the settings in `param_background_removal`.
varr_ref = remove_background(varr_ref, **param_background_removal)

## save result

In [None]:
%%time
# Persist the pre-processed movie to the intermediate store as "varr_ref".
varr_ref = save_minian(varr_ref.rename("varr_ref"), dpath=intpath, overwrite=True)

# Motion Correction

## estimate motion

In [None]:
%%time
# Estimate frame-wise motion, optionally restricted to the `subset_mc` ROI
# (None selects the whole field of view).
motion = estimate_motion(varr_ref.sel(subset_mc), **param_estimate_motion)

## save motion

In [None]:
%%time
# Save the motion estimate to the output dataset, chunked along frames like the movie.
motion = save_minian(
    motion.rename("motion").chunk({"frame": chk["frame"]}), **param_save_minian
)

## visualization of motion

In [None]:
# Plot the estimated motion.
hv.output(size=output_size)
visualize_motion(motion)

## apply transform

In [None]:
# Apply the motion estimate to the movie; `fill=0` is presumably the value for
# pixels shifted in from outside the field of view — confirm in minian docs.
Y = apply_transform(varr_ref, motion, fill=0)

## save result

In [None]:
%%time
# Save the corrected movie twice: "Y_fm_chk" with its existing chunking, and
# "Y_hw_chk" rechunked to full time series (frame=-1) over `chk`-sized
# height/width tiles.
Y_fm_chk = save_minian(Y.astype(float).rename("Y_fm_chk"), intpath, overwrite=True)
Y_hw_chk = save_minian(
    Y_fm_chk.rename("Y_hw_chk"),
    intpath,
    overwrite=True,
    chunks={"frame": -1, "height": chk["height"], "width": chk["width"]},
)

## visualization of motion-correction

In [None]:
# Compare max projections before and after motion correction, side by side.
im_opts = dict(
    frame_width=500,
    aspect=varr_ref.sizes["width"] / varr_ref.sizes["height"],
    cmap="Viridis",
    colorbar=True,
)
before_img = hv.Image(
    varr_ref.max("frame").compute().astype(np.float32),
    ["width", "height"],
    label="before_mc",
).opts(**im_opts)
after_img = hv.Image(
    Y_hw_chk.max("frame").compute().astype(np.float32),
    ["width", "height"],
    label="after_mc",
).opts(**im_opts)
regrid(before_img) + regrid(after_img)

# Initialization

In [None]:
# Reload the motion-corrected movie from the intermediate store, so the
# initialization can start here without re-running the steps above.
temp_ds = open_minian(intpath, return_dict=True)
Y_fm_chk = temp_ds['Y_fm_chk']

In [None]:
# Reload the saved motion estimate from the output dataset.
minian_ds = open_minian(param_save_minian['dpath'], return_dict=True)
motion = minian_ds['motion']

## compute max projection

In [None]:
# Max projection over frames, saved to the output dataset and loaded into memory.
max_proj = save_minian(
    Y_fm_chk.max("frame").rename("max_proj"), **param_save_minian
).compute()

## generating over-complete set of seeds

In [None]:
from minian.initialization import local_max_roll

def find_seed(
    max_proj: xr.DataArray,
    max_wnd=10,
    diff_thres=2,
    min_wnd=8,
):
    """Find candidate seeds as local maxima of a max-projection image.

    Thin wrapper that applies minian's ``local_max_roll`` over the
    ("height", "width") plane of ``max_proj``.

    Parameters
    ----------
    max_proj : xr.DataArray
        Image with "height" and "width" dimensions.
    max_wnd : int, optional
        Forwarded to ``local_max_roll`` as ``k1 = max_wnd + 1``.
    diff_thres : optional
        Forwarded to ``local_max_roll`` as ``diff``.
    min_wnd : int, optional
        Forwarded to ``local_max_roll`` as ``k0``. Defaults to 8, the value that
        used to be hard-coded, so existing calls behave exactly as before.

    Returns
    -------
    pandas.DataFrame
        One row per pixel where ``local_max_roll`` returned a positive value,
        with columns "height", "width" and "seeds".
    """
    # NOTE(review): k0/k1 presumably bound the range of window sizes scanned by
    # local_max_roll — confirm against minian.initialization.
    loc_max = xr.apply_ufunc(
        local_max_roll,
        max_proj,
        input_core_dims=[["height", "width"]],
        output_core_dims=[["height", "width"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[np.uint8],
        kwargs=dict(k0=min_wnd, k1=max_wnd + 1, diff=diff_thres),
    )
    # where() masks non-positive pixels to NaN so dropna() keeps only the maxima.
    seeds = (
        loc_max.where(loc_max > 0).rename("seeds").to_dataframe().dropna().reset_index()
    )
    return seeds[["height", "width", "seeds"]]

In [None]:
%%time
# Detect candidate seeds on the max projection (max_wnd=15, diff_thres=8).
seeds = find_seed(max_proj, max_wnd=15, diff_thres=8)

In [None]:
# Overlay the detected seeds on the max projection.
hv.output(size=output_size)
visualize_seeds(max_proj, seeds)

In [None]:
# Exploratory watershed: for background thresholds 0..9, flood the max
# projection from all seeds at once and record every seed's region size,
# to compare against the expected area `exp_size`.
# NOTE(review): uses `norm`, which is defined in a later cell — run that cell
# first (or move it above) so Restart & Run All works.
exp_size = 9 * 9 * np.pi  # expected region area in pixels (disk of radius 9)
img = cv2.cvtColor((norm(max_proj.values) * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
# Label 1 marks background below, so seed labels start at 2.
sd_idx = seeds.index + 2
ws_dict = dict()
ws_df = []
for bg_thres in range(10):
    marker = np.zeros(max_proj.shape, dtype=np.int32)
    marker[max_proj.values <= bg_thres] = 1
    marker[seeds['height'], seeds['width']] = sd_idx
    ws = cv2.watershed(img, marker)
    ws_dict[bg_thres] = ws
    # The original loop body here was empty (a SyntaxError); record each
    # seed's region size, the quantity that `exp_size` is meant to be
    # compared with.
    for isd in sd_idx:
        ws_df.append({"bg_thres": bg_thres, "seed": isd, "size": int(np.sum(ws == isd))})
ws = xr.DataArray(ws, dims=['height', 'width'], coords={'height': max_proj.coords['height'], 'width': max_proj.coords['width']})

In [None]:
from skimage.morphology import disk
# Watershed footprint search: for every seed, grow a watershed region against
# eroded background markers taken at several intensity thresholds and keep the
# region whose area is closest to `exp_size`.
# NOTE(review): uses `norm`, which is defined in a later cell — run that cell
# first (or move it above) so Restart & Run All works.
exp_size = 8 * 8 * np.pi  # target footprint area in pixels (disk of radius 8)
max_bg_thres = 10  # background thresholds tried: 0 .. max_bg_thres - 1
bg_erode = 5  # radius of the disk used to erode each background mask
img = cv2.cvtColor((norm(max_proj.values) * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
erd_ele = disk(bg_erode)
# Background masks depend only on the threshold, not on the seed: build once.
bg_masks = [
    cv2.erode((max_proj.values <= bg_thres).astype(np.uint8), erd_ele).astype(bool)
    for bg_thres in range(max_bg_thres)
]
A_ls = []
for isd, row in seeds.iterrows():
    marker = np.zeros(max_proj.shape, dtype=np.int32)
    # label 2 = this seed; label 1 = background; 0 is left for watershed to flood
    marker[int(row['height']), int(row['width'])] = 2
    sizes = np.zeros(max_bg_thres)
    wss = np.zeros((max_bg_thres, max_proj.shape[0], max_proj.shape[1]))
    for ibg, bg_mask in enumerate(bg_masks):
        mk = marker.copy()
        mk[bg_mask] = 1
        ws = cv2.watershed(img, mk)  # watershed writes its labels into `mk`
        wss[ibg] = ws
        sizes[ibg] = np.sum(ws == 2)
    ws = wss[np.argmin(np.abs(sizes - exp_size))]
    curA = np.where(ws == 2, max_proj, np.nan)
    curA[np.isnan(curA)] = np.nanmin(curA)
    A_ls.append(norm(curA))
A = xr.DataArray(np.stack(A_ls, axis=0), dims=['unit_id', 'height', 'width'], coords={'unit_id': seeds.index, 'height': max_proj.coords['height'], 'width': max_proj.coords['width']})

### intensity-based method

In [None]:
import itertools as itt
from skimage.measure import label as imlabel
from skimage.morphology import disk
# Intensity-based footprint search: down-weight the max projection by distance
# to the nearest seed, threshold it at every 8-bit level, and for each seed keep
# the connected component (within the size bounds) with the highest convexity.
# NOTE(review): uses `norm` and `convexity_score`, which are defined in later
# cells — run those first (or move them up) so Restart & Run All works.
exp_size = 6 * 6 * np.pi  # not used here; read by the watershed cell that follows
min_size = 6 * 6  # component area bounds in pixels (both exclusive)
max_size = 15 * 15
dist_pow = 0.5  # strength of the distance-to-seed down-weighting
dist = np.ones_like(max_proj, dtype=np.uint8)
dist[seeds['height'], seeds['width']] = 0
# L2 distance from every pixel to the nearest seed pixel
dist = cv2.distanceTransform(dist, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
img = (norm(max_proj.values * ((dist + 1) ** -dist_pow)) * 255).astype(np.uint8)
# one connected-component labelling per threshold: shape (255, height, width)
im_labs = np.stack([imlabel(img > thres) for thres in range(255)], axis=0)
A_ls = []
idx_ls = []
for isd, row in seeds.iterrows():
    h, w = int(row['height']), int(row['width'])
    labs = im_labs[:, h, w]
    # Only thresholds where the seed lies inside a component (label 0 is
    # background); skips building masks that would be discarded anyway.
    valid = labs > 0
    im_sd = im_labs[valid] == labs[valid, np.newaxis, np.newaxis]
    sizes = im_sd.sum(axis=(1, 2))
    im_sd = im_sd[np.logical_and(sizes > min_size, sizes < max_size), :, :]
    if len(im_sd) > 0:
        cvx = np.array([convexity_score(im) for im in im_sd])
        sidx = np.argmax(cvx)
        mask = im_sd[sidx, :, :]
        curA = np.where(mask, max_proj, np.nan)
        curA[np.isnan(curA)] = np.nanmin(curA)
        A_ls.append(norm(curA))
        idx_ls.append(isd)
A = xr.DataArray(np.stack(A_ls, axis=0), dims=['unit_id', 'height', 'width'], coords={'unit_id': idx_ls, 'height': max_proj.coords['height'], 'width': max_proj.coords['width']})

In [None]:
# Watershed footprint search (same procedure as the watershed cell above) run
# on the current `img`; `exp_size` comes from the intensity-based cell and
# `max_bg_thres` / `erd_ele` from the watershed cell.
# cv2.watershed requires an 8-bit 3-channel image, but the intensity-based cell
# leaves `img` single-channel, so convert it without modifying `img`.
# NOTE(review): assumes `img` is uint8, as produced by the intensity-based cell.
ws_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if img.ndim == 2 else img
# Background masks depend only on the threshold, not on the seed: build once.
bg_masks = [
    cv2.erode((max_proj.values <= bg_thres).astype(np.uint8), erd_ele).astype(bool)
    for bg_thres in range(max_bg_thres)
]
A_ls = []
for isd, row in seeds.iterrows():
    marker = np.zeros(max_proj.shape, dtype=np.int32)
    # label 2 = this seed; label 1 = background; 0 is left for watershed to flood
    marker[int(row['height']), int(row['width'])] = 2
    sizes = np.zeros(max_bg_thres)
    wss = np.zeros((max_bg_thres, max_proj.shape[0], max_proj.shape[1]))
    for ibg, bg_mask in enumerate(bg_masks):
        mk = marker.copy()
        mk[bg_mask] = 1
        ws = cv2.watershed(ws_img, mk)  # watershed writes its labels into `mk`
        wss[ibg] = ws
        sizes[ibg] = np.sum(ws == 2)
    ws = wss[np.argmin(np.abs(sizes - exp_size))]
    curA = np.where(ws == 2, max_proj, np.nan)
    curA[np.isnan(curA)] = np.nanmin(curA)
    A_ls.append(norm(curA))
A = xr.DataArray(np.stack(A_ls, axis=0), dims=['unit_id', 'height', 'width'], coords={'unit_id': seeds.index, 'height': max_proj.coords['height'], 'width': max_proj.coords['width']})

### gradient-based method

In [None]:
import cv2
from skimage.measure import label as imlabel
from skimage.morphology import disk

def im_floodfill(im):
    """Fill the interior holes of a 2-D binary mask.

    A background pixel is a hole when it is not 4-connected to the image
    border; holes become True and all other pixels are unchanged.

    Parameters
    ----------
    im : numpy.ndarray
        2-D mask; every nonzero pixel counts as foreground.

    Returns
    -------
    numpy.ndarray of bool
        The mask with its holes filled.
    """
    # binary_fill_holes floods the background from the whole border, so the
    # result is correct even when the mask covers pixel (0, 0) or splits the
    # background into several border-touching pieces — cases where a single
    # cv2.floodFill seeded at (0, 0) inverts the mask or fills real background.
    return binary_fill_holes(np.asarray(im) != 0)


# Parameters and image gradients for the gradient-based footprint search below.
min_size = 8 * 8  # accepted component area bounds in pixels (both exclusive)
max_size = 25 * 25
# thresholds applied to the radial gradient projection: -150, -149, ..., 0
grd_thres = np.arange(-150, 1, 1)
# pad = int(np.sqrt(min_size) / 2 - 1)
# half-width of the window around each seed that is forced into the foreground
pad = 2

img = np.array(max_proj)
# dx = cv2.Scharr(img, ddepth=-1, dx=1, dy=0)
# dy = cv2.Scharr(img, ddepth=-1, dx=0, dy=1)
# Sobel derivatives along x (width) and y (height), same depth as `img`
dx = cv2.Sobel(img, ddepth=-1, dx=1, dy=0)
dy = cv2.Sobel(img, ddepth=-1, dx=0, dy=1)
mag = np.sqrt(dx**2 + dy**2)  # gradient magnitude
ang = np.arctan2(dy, dx)  # gradient direction in radians
# mag = cv2.medianBlur(mag.astype(np.float32), 5)
# ang = cv2.medianBlur(ang.astype(np.float32), 5)
# ang = cv2.medianBlur(ang.astype(np.float32), 5)

In [None]:
# Gradient-based footprint search: for each seed, project the image gradient on
# the direction pointing away from the seed, threshold that projection at every
# value in `grd_thres`, and keep the seed's component (within the size bounds)
# with the highest convexity, hole-filled.
# NOTE(review): uses `norm` and `convexity_score`, which are defined in later
# cells — run those first (or move them up) so Restart & Run All works.
A_ls = []
idx_ls = []
for isd, row in seeds.iterrows():
    sd_h = int(row['height'])
    sd_w = int(row['width'])
    # offsets from the seed as an (H, 1) column and a (1, W) row; they
    # broadcast to the full (H, W) grid inside arctan2
    gy = np.arange(img.shape[0])[:, np.newaxis] - sd_h
    gx = np.arange(img.shape[1])[np.newaxis, :] - sd_w
    gang = np.arctan2(gy, gx)
    proj = mag * np.cos(gang - ang)
    # window [sd - pad, sd + pad) is forced to True so the seed is always labelled
    win_h = slice(max(sd_h - pad, 0), sd_h + pad)
    win_w = slice(max(sd_w - pad, 0), sd_w + pad)
    thresholded = np.stack([proj < thres for thres in grd_thres], axis=0)
    thresholded[:, win_h, win_w] = True
    im_labs = np.stack([imlabel(layer) for layer in thresholded], axis=0)
    labs = im_labs[:, sd_h, sd_w]
    im_sd = im_labs == labs[:, np.newaxis, np.newaxis]
    sizes = im_sd.sum(axis=(1, 2))
    keep = (sizes > min_size) & (sizes < max_size)
    im_sd = im_sd[keep, :, :]
    if len(im_sd) == 0:
        continue
    cvx = np.array([convexity_score(layer) for layer in im_sd])
    mask = im_floodfill(im_sd[np.argmax(cvx), :, :])
    curA = np.where(mask, max_proj, np.nan)
    curA[np.isnan(curA)] = np.nanmin(curA)
    A_ls.append(norm(curA))
    idx_ls.append(isd)
A = xr.DataArray(np.stack(A_ls, axis=0), dims=['unit_id', 'height', 'width'], coords={'unit_id': idx_ls, 'height': max_proj.coords['height'], 'width': max_proj.coords['width']})

In [None]:
def norm(a):
    """Min-max rescale ``a`` to the range [0, 1].

    Parameters
    ----------
    a : array-like
        Anything supporting ``.min()``, ``.max()`` and arithmetic
        (numpy arrays, xarray DataArrays).

    Returns
    -------
    Same type as ``a - a.min()``: ``(a - min) / (max - min)``. A constant
    input maps to all zeros instead of NaN (which a 0/0 division would give).
    """
    amin, amax = a.min(), a.max()
    scale = amax - amin
    if scale == 0:
        # avoid 0/0; dividing by 1 still yields a float result of zeros
        scale = 1
    return (a - amin) / scale

In [None]:
def convexity_score(im):
    """Ratio of convex-hull perimeter to contour perimeter of a binary mask.

    Parameters
    ----------
    im : numpy.ndarray
        2-D mask; nonzero pixels are foreground.

    Returns
    -------
    float or int
        ``arcLength(hull) / arcLength(contour)`` of the largest outer contour,
        or 0 when the mask is empty or its contour has zero length.
    """
    # The contour list is at index -2 of the return value in both OpenCV 3
    # (image, contours, hierarchy) and OpenCV 4 (contours, hierarchy).
    contours = cv2.findContours(
        im.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )[-2]
    if len(contours) == 0:
        return 0  # empty mask: nothing to score
    # Masks passed here are normally a single component; if not, score the largest.
    cnt = max(contours, key=cv2.contourArea)
    peri = cv2.arcLength(cnt, True)
    hull = cv2.convexHull(cnt)
    peri_hull = cv2.arcLength(hull, True)
    if peri > 0:
        return peri_hull / peri
    else:
        return 0

In [None]:
from scipy.optimize import minimize_scalar

# Accumulators for the optimisation loop below: footprints and their seed ids.
A_ls, idx_ls = [], []
def cvx_opt_cb(thres, proj, min_size, pad, sd_h, sd_w):
    """Objective for ``minimize_scalar``: negative convexity of a seed's region.

    ``proj`` is thresholded at ``thres`` (pixels below it are foreground), a
    window of half-width ``pad`` around the seed ``(sd_h, sd_w)`` is forced to
    foreground, and the connected component containing the seed is scored.
    Components of at most ``min_size`` pixels return 0.
    """
    seg = proj < thres
    seg[max(sd_h - pad, 0):sd_h + pad, max(sd_w - pad, 0):sd_w + pad] = True
    labelled = imlabel(seg)
    region = labelled == labelled[sd_h, sd_w]
    if region.sum() <= min_size:
        return 0
    return -convexity_score(region)


# For each seed, search for the gradient threshold in [-200, 5] that maximises
# the convexity of the seed's region, then keep that region (hole-filled).
for isd, row in seeds.iterrows():
    sd_h, sd_w = int(row['height']), int(row['width'])
    # (H, 1) and (1, W) offsets broadcast to the full grid inside arctan2
    gy = np.arange(img.shape[0])[:, np.newaxis] - sd_h
    gx = np.arange(img.shape[1])[np.newaxis, :] - sd_w
    gang = np.arctan2(gy, gx)
    proj = mag * np.cos(gang - ang)
    res = minimize_scalar(
        cvx_opt_cb,
        bounds=(-200, 5),
        args=(proj, min_size, pad, sd_h, sd_w),
        method='Bounded',
    )
    if not res.success:
        continue
    # rebuild the region at the optimal threshold (same steps as cvx_opt_cb)
    seg = proj < res.x
    seg[max(sd_h - pad, 0):sd_h + pad, max(sd_w - pad, 0):sd_w + pad] = True
    im_labs = imlabel(seg)
    labs = im_labs[sd_h, sd_w]
    mask = im_floodfill(im_labs == labs)
    curA = np.where(mask, max_proj, np.nan)
    curA[np.isnan(curA)] = np.nanmin(curA)
    A_ls.append(norm(curA))
    idx_ls.append(isd)
A = xr.DataArray(np.stack(A_ls, axis=0), dims=['unit_id', 'height', 'width'], coords={'unit_id': idx_ls, 'height': max_proj.coords['height'], 'width': max_proj.coords['width']})

In [None]:
# Show `proj` (radial gradient projection) left over from the last seed of the
# loop above, clipped to [-100, 100], next to the seed overlay.
vis = xr.DataArray(proj.clip(-100, 100), dims=['height', 'width'], coords={'height': max_proj.coords['height'], 'width': max_proj.coords['width']})
visualize_seeds(max_proj, seeds) + hv.Image(vis, ['width', 'height']).opts(cmap='RdBu', frame_width=608, frame_height=608, symmetric=True, colorbar=True)

In [None]:
# Pairwise Jaccard index between binarised footprints; units whose overlap
# reaches `jac_thres` are grouped with `label_connected`.
# label_connected is otherwise only imported in the next cell, which breaks
# Restart & Run All here.
from minian.cnmf import label_connected

jac_thres = 1
A_bl = np.array((A > 0).astype(float))  # (unit_id, height, width) 0/1 masks
# intersections: contract height/width of every pair of masks -> (unit, unit)
A_inter = np.tensordot(A_bl, np.moveaxis(A_bl, 0, -1))
areas = A_bl.sum(axis=(1, 2))
A_sum = areas[:, np.newaxis] + areas[np.newaxis, :]
jac = A_inter / (A_sum - A_inter)
np.fill_diagonal(jac, 0)
lab = label_connected(jac >= jac_thres)

In [None]:
from sklearn.metrics import pairwise_distances
from minian.cnmf import label_connected
# Merge footprints whose cosine similarity reaches `cos_thres`: connected
# groups are averaged and the group label becomes the new unit_id.
cos_thres = 0.5
A_flat = np.array(A).reshape((A.shape[0], -1))  # one row per unit
cos = 1 - pairwise_distances(A_flat, metric='cosine', n_jobs=-1)
np.fill_diagonal(cos, 0)
lab = label_connected(cos >= cos_thres)
A_merged = (
    A.assign_coords(unit_labels=("unit_id", lab))
    .groupby("unit_labels")
    .mean("unit_id")
    .rename(unit_labels="unit_id")
)

In [None]:
# Flag the units whose merge label `lab` is shared with at least one other unit.
u, c = np.unique(lab, return_counts=True)
dup = u[c > 1]
duplicated = np.isin(lab, dup)

In [None]:
# Seeds, max over all footprints, and max over the footprints that get merged.
# NOTE(review): the last panel reduces over an empty unit_id selection when
# nothing is duplicated.
(
    visualize_seeds(max_proj, seeds)
    + hv.Image(A.max('unit_id'), ['width', 'height']).opts(cmap='viridis', frame_width=608, frame_height=608)
    + hv.Image(A.sel(unit_id=duplicated).max('unit_id'), ['width', 'height']).opts(cmap='viridis', frame_width=608, frame_height=608)
)

In [None]:
# Overlay the outline of every merged footprint on the max projection.
im = hv.Image(max_proj, ['width', 'height']).opts(cmap='viridis', frame_width=608, frame_height=608)
for uid in A_merged.coords['unit_id'].values:
    curA = (np.array(A_merged.sel(unit_id=uid)) > 0).astype(np.uint8)
    # The contour list is at index -2 in both OpenCV 3 and 4 return tuples.
    contours = cv2.findContours(curA, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
    # A merged footprint can be empty or span several disjoint regions:
    # draw every outer contour instead of indexing the first one.
    for cnt in contours:
        cnt = cnt.squeeze(axis=1)  # (n, 1, 2) -> (n, 2) x/y pixel coordinates
        if len(cnt) > 1:
            im = im * hv.Path(cnt)

In [None]:
# Seeds, outlines of the merged footprints, and max over the merged footprints.
hv.output(size=80)
visualize_seeds(max_proj, seeds) + im + hv.Image(A_merged.max('unit_id'), ['width', 'height']).opts(cmap='viridis', frame_width=608, frame_height=608)

In [None]:
# Bundle motion, max projection and merged footprints ("A") into one dataset
# and write it to netCDF.
# NOTE(review): the output directory and file name ("m15-rec1.nc") are
# hard-coded; consider deriving them from `dpath` in the parameter cell.
out_ds = xr.merge([motion, max_proj, A_merged.rename('A')])
out_path = "../intermediate/processed/red/"
os.makedirs(out_path, exist_ok=True)
out_ds.to_netcdf(os.path.join(out_path, "m15-rec1.nc"))