# setup

In [1]:
%%capture
%load_ext autoreload
%autoreload 2
import sys
import os
import gc
import psutil
import numpy as np
import xarray as xr
import holoviews as hv
import matplotlib.pyplot as plt
import bokeh.plotting as bpl
import dask.array as da
import pandas as pd
import dask
import datashader as ds
import itertools as itt
import papermill as pm
import ast
import functools as fct
import SimpleITK as sitk
import cv2
import itertools as itt
import numba as nb
from scipy.stats import zscore
from holoviews.operation.datashader import datashade, regrid, dynspread
from datashader.colors import Sets1to3
from dask.diagnostics import ProgressBar, Profiler
from IPython.core.display import display, HTML
from dask.distributed import Client, progress, LocalCluster, fire_and_forget
# Directory containing the `minian` package; appended to sys.path so the
# `from minian...` imports that follow can resolve it.
minian_path = "."
sys.path.append(minian_path)
from minian.utilities import load_params, load_videos, scale_varr, scale_varr_da, save_variable, open_minian, save_minian, handle_crash
from minian.preprocessing import remove_brightspot, gradient_norm, denoise, remove_background, stripe_correction
from minian.motion_correction import estimate_shift_fft, apply_shifts, interpolate_frame, mask_shifts, mser, kaze, match_dsc, desc_vec, estimate_homo
from minian.initialization import seeds_init, gmm_refine, pnr_refine, intensity_refine, ks_refine, seeds_merge, initialize
from minian.cnmf import psd_welch, psd_fft, get_noise, update_spatial, update_temporal, unit_merge, smooth_sig
from minian.visualization import VArrayViewer, CNMFViewer, generate_videos, visualize_seeds, visualize_gmm_fit, visualize_spatial_update, visualize_temporal_update, roi_draw
from IPython.core.debugger import set_trace

In [2]:
# Use the bokeh plotting backend for HoloViews output in this notebook.
hv.notebook_extension('bokeh', width=100)
# Register a global dask progress bar for every subsequent computation;
# `minimum=2` hides it for computations finishing in under 2 seconds.
pbar = ProgressBar(minimum=2)
pbar.register()

# load data

In [3]:
# Session directory holding the stored minian dataset (hard-coded for one
# animal/session) and the dataset name inside it.
dpath = "/media/share/csstorage/Tristan/Epilepsy Revision Experiments/PFD/OrganizedandAnalyzed/PFD2/PFD2ByAnimal/TS45-4/S13"
fname = "minian"

In [4]:
# Open the dataset from zarr storage; `chunks=...'auto'` requests dask
# chunking with automatically chosen chunk sizes.
minian = open_minian(
    dpath=dpath,
    fname=fname,
    backend='zarr', chunks=dict(frame='auto', height='auto', width='auto'))

In [5]:
# The 'Y' variable is the movie registered in the rest of this notebook.
mov = minian['Y']

In [None]:
%%time
# Load the entire movie into memory.
mov = mov.compute()

In [None]:
# 8-bit copy of the movie, rescaled to 0-255; used as the MSER input below.
mov_int = scale_varr(mov, (0, 255)).astype(np.uint8)

In [None]:
%%time
# NOTE(review): `mov` was computed into memory above, so `mov_int` is
# presumably not dask-backed here and persist() has nothing to do — confirm
# what scale_varr returns.
mov_int = mov_int.persist()

# demons

In [None]:
from minian.motion_correction import demon, apply_displacement, mser, hist_match, demon_reg, mser_vec

In [None]:
def plot_dis(dis, cvt_grd=True):
    """Draw a dense 2-D displacement field as a HoloViews VectorField.

    Parameters
    ----------
    dis : numpy.ndarray
        Array of shape ``(height, width, 2)``.  When ``cvt_grd`` is true the
        two channels are the x and y displacement components; the x component
        is negated and the pair is converted to ``(angle, magnitude)``.  When
        ``cvt_grd`` is false the channels must already be
        ``(angle, magnitude)``.
    cvt_grd : bool
        Whether to convert component form into angle/magnitude form.

    Returns
    -------
    holoviews.VectorField
        One arrow per pixel; the y grid is flipped so row 0 is drawn at the
        top of the plot.
    """
    if cvt_grd:
        dx = -dis[:, :, 0]
        dy = dis[:, :, 1]
        dis = np.dstack((np.arctan2(dy, dx), np.sqrt(dx ** 2 + dy ** 2)))
    nrow, ncol = dis.shape[0], dis.shape[1]
    grid_x, grid_y = np.meshgrid(np.arange(ncol), np.arange(nrow))
    vectors = (grid_x[::-1], grid_y[::-1], dis[:, :, 0], dis[:, :, 1])
    return hv.VectorField(vectors).opts(magnitude='Magnitude')


import itertools as itt
def get_field(trans, shape):
    """Sample a point transform on a pixel grid and return its displacements.

    Parameters
    ----------
    trans : object
        Anything with a ``TransformPoint`` method that takes a point in
        ``(x, y)`` = ``(column, row)`` order and returns the mapped point in
        the same order (e.g. a SimpleITK transform).
    shape : sequence of int
        Grid size; only the first two entries ``(height, width)`` are used.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(height, width, 2)``: ``[..., 0]`` is the
        column (x) displacement and ``[..., 1]`` the row (y) displacement of
        each pixel.
    """
    height, width = shape[0], shape[1]
    fd = np.zeros((height, width, 2))
    for row, col in np.ndindex(height, width):
        # Pass plain Python floats: the `np.float` alias used previously was
        # removed in NumPy 1.24.
        xp, yp = trans.TransformPoint((float(col), float(row)))
        fd[row, col, 0] = xp - col
        fd[row, col, 1] = yp - row
    return fd

# testing on individual frame

In [None]:
# Crop used for the single-frame tests: rows 150-249, columns 300-449.
subh, subw = slice(150, 250), slice(300, 450)
# hv.Image bounds (left, bottom, right, top) spanning the crop in pixels.
bounds = (0, 0, subw.stop - subw.start, subh.stop - subh.start)

In [None]:
fm0 = mov.isel(frame=0, height=subh, width=subw)
fm1 = mov.isel(frame=30, height=subh, width=subw)
fm2 = mov.isel(frame=60, height=subh, width=subw)
fm0_int = mov_int.isel(frame=0, height=subh, width=subw)
fm1_int = mov_int.isel(frame=30, height=subh, width=subw)
fm2_int = mov_int.isel(frame=60, height=subh, width=subw)

In [None]:
regs_fm0 = mser(fm0_int.values)
regs_fm1 = mser(fm1_int.values)
regs_fm2 = mser(fm2_int.values)

In [None]:
# Register frame 1 onto frame 0 and frame 2 onto frame 1, then chain them.
# Every other demon_reg call in this notebook passes the regions of the frame
# being warped before those of the reference frame; the 0<-1 call previously
# had the two region lists swapped.
trans01 = demon_reg(fm1.values, fm0.values, regs_fm1, regs_fm0)
trans12 = demon_reg(fm2.values, fm1.values, regs_fm2, regs_fm1)
# NOTE(review): SimpleITK composite transforms apply the most recently added
# transform first, so this composite evaluates trans12 before trans01 —
# confirm that is the intended frame-2 -> frame-0 composition.
trans02 = sitk.Transform(trans01)
trans02.AddTransform(trans12)

In [None]:
# Dense displacement arrays for plotting.  The pairwise transforms expose
# their displacement-field image directly; the chained trans02 is sampled
# point by point with get_field instead.
trans01_im, trans12_im = (
    sitk.GetArrayFromImage(t.GetDisplacementField()) for t in (trans01, trans12)
)
trans02_im = get_field(trans02, fm2.shape)

In [None]:
# Warp the later frame of each pair with that pair's transform.
fm01_trans, fm12_trans, fm02_trans = (
    apply_displacement(t, f)
    for t, f in ((trans01, fm1), (trans12, fm2), (trans02, fm2))
)

In [None]:
from holoviews.operation import contours
# 3 x 5 grid, one row per registration (frame 1 onto 0, 2 onto 1, 2 onto 0):
# reference frame, frame being warped, MSER mask, displacement field, and the
# warped frame (frames drawn as contour lines).
# NOTE(review): `mser_mask` is not imported or defined anywhere in this
# notebook (only `mser` / `mser_vec` are imported), so this cell raises
# NameError as written.
# NOTE(review): rows 1-2 show the reference frame's mask while row 3 shows
# fm2's (the warped frame's) — confirm which is intended.
(contours(hv.Image(fm0.values, bounds=bounds))
 + contours(hv.Image(fm1.values, bounds=bounds))
 + hv.Image(mser_mask(fm0_int.values), bounds=bounds)
 + plot_dis(trans01_im)
 + contours(hv.Image(fm01_trans, bounds=bounds))
 + contours(hv.Image(fm1.values, bounds=bounds))
 + contours(hv.Image(fm2.values, bounds=bounds))
 + hv.Image(mser_mask(fm1_int.values), bounds=bounds)
 + plot_dis(trans12_im)
 + contours(hv.Image(fm12_trans, bounds=bounds))
 + contours(hv.Image(fm0.values, bounds=bounds))
 + contours(hv.Image(fm2.values, bounds=bounds))
 + hv.Image(mser_mask(fm2_int.values), bounds=bounds)
 + plot_dis(trans02_im)
 + contours(hv.Image(fm02_trans, bounds=bounds))
).cols(5)

# testing movie

In [None]:
# One chunk spans each full frame (height/width unchunked) so the per-frame
# functions applied below always receive whole images.
mov, mov_int = (
    arr.chunk(dict(height=-1, width=-1, frame='auto')) for arr in (mov, mov_int)
)

In [None]:
mser_res = [dask.delayed(mser_vec)(np.asscalar(p)) for p in mov_int.data.to_delayed()]

In [None]:
%%time
mser_res = dask.compute(mser_res)[0]

In [None]:
mser_res = sum(mser_res, [])

In [None]:
mser_xr = xr.DataArray(
    mser_res,
    dims=['frame'],
    coords=dict(frame=mov.coords['frame']))

In [None]:
trans = xr.apply_ufunc(
    demon_reg,
    mov,
    mov.shift(frame=1),
    mser_xr,
    mser_xr.shift(frame=1),
    input_core_dims=[['height', 'width'], ['height', 'width'], [], []],
    output_core_dims=[[]],
    vectorize=True,
    output_dtypes=[np.object],
    dask='parallelized'
)

In [None]:
%%time
trans = trans.compute()

In [None]:
%%time
dec_fac = 0.9
trans_last = None
mov_trans = []
for tr, fm in zip(trans.values, mov.values):
    if tr is np.nan:
        mov_trans.append(fm)
        continue
    if trans_last is None:
        trans_cur = sitk.Transform(tr)
    else:
        trans_cur = sitk.Transform(trans_last)
        trans_cur.AddTransform(tr)
        trans_cur.FlattenTransform()
    fm_trans = apply_displacement(trans_cur, fm)
    mov_trans.append(fm_trans)
    trans_last = trans_cur

In [None]:
%%time
from tqdm import tqdm_notebook
mov_trans = [mov.values[0],]
for fm1 in tqdm_notebook(mov.values[1:]):
    fm0 = mov_trans[-1]
    fm0_int = ((fm0 - fm0.min()) / (fm0.max() - fm0.min()) * 255).astype(np.uint8)
    fm1_int = ((fm1 - fm1.min()) / (fm1.max() - fm1.min()) * 255).astype(np.uint8)
    reg0 = mser(fm0_int)
    reg1 = mser(fm1_int)
    trans = demon_reg(fm1, fm0, reg1, reg0)
    fm_trans = apply_displacement(trans, fm1)
    mov_trans.append(fm_trans)

In [None]:
# Wrap the registered frames in a DataArray on the movie's coordinates.
mov_trans = xr.DataArray(
    np.stack(mov_trans),
    dims=['frame', 'height', 'width'],
    coords={dim: mov.coords[dim] for dim in ('frame', 'height', 'width')},
)

In [None]:
# Per-frame cumulative transforms as an object DataArray along `frame`.
# NOTE(review): `trans_ls` is not defined anywhere in this notebook (the
# chaining loop keeps only the latest composite in `trans_last`), so this
# cell raises NameError as written.  `trans` must also still be the
# apply_ufunc result (a DataArray with a `frame` coord) when this runs.
trans_cum = xr.DataArray(
    np.asarray(trans_ls), dims=['frame'], coords=dict(frame=trans.coords['frame']))

In [None]:
# Warp every frame with its cumulative transform.  vectorize=True calls
# apply_displacement once per frame on a whole (height, width) image; the
# movie is split into chunks of 40 frames for dask.
mov_trans = xr.apply_ufunc(
    apply_displacement,
    trans_cum,
    mov.chunk(dict(height=-1, width=-1, frame=40)),
    input_core_dims=[[], ['height', 'width']],
    output_core_dims=[['height', 'width']],
    vectorize=True,
    output_dtypes=[mov.dtype],
    dask='parallelized'
)

In [None]:
%%time
mov_trans = mov_trans.compute()

In [19]:
%%opts Image [width=752, height=480] (cmap='Viridis')
opts_im = dict(plot=dict(width=752, height=480), style=dict(cmap='viridis'))
# mov_max = mov.max('frame').compute()
# mov_trans_max = mov_trans.max('frame').compute()
# mov_mean = mov.mean('frame').compute()
# mov_trans_mean = mov_trans.mean('frame').compute()
sum_fm = (hv.Image(mov_max.rename('mov'), ['width', 'height'], label="Max Before").opts(**opts_im)
 + hv.Image(mov_trans_max.rename('mov_trans'), ['width', 'height'], label="Max After").opts(**opts_im)
 + hv.Image(mov_mean.rename('mov_mean'), ['width', 'height'], label="Mean Before").opts(**opts_im)
 + hv.Image(mov_trans_mean.rename('mov_trans_mean'), ['width', 'height'], label="Mean After").opts(**opts_im)).cols(2)

In [21]:
# Write the before/after projection layout to an HTML file in the session dir.
hv.save(sum_fm, filename=os.path.join(dpath, "mc_sum.html"))

In [11]:
%%opts Image [width=752, height=480] (cmap='Viridis')
mov_mean = mov.mean('frame').compute()
mov_trans_mean = mov_trans.mean('frame').compute()
hv.Image(mov_mean, ['width', 'height']) + hv.Image(mov_trans_mean, ['width', 'height'])

[########################################] | 100% Completed |  2min 15.1s
[########################################] | 100% Completed |  3min 23.3s


In [None]:
# Save the registered movie as NetCDF in the session directory.
mov_trans.to_netcdf(os.path.join(dpath, "mov_trans.nc"))

In [9]:
# Reload the saved registered movie and re-chunk it as a dask array.
mov_trans = xr.open_dataarray(os.path.join(dpath, 'mov_trans.nc'))
mov_trans = mov_trans.chunk(dict(height='auto', width='auto', frame='auto'))

In [12]:
%%time
from skvideo.io import vwrite
vid = xr.concat([mov.drop(['animal', 'session']), mov_trans], 'width')
vid = scale_varr(vid, (0, 255)).astype(np.uint8)
with ProgressBar():
    vwrite(os.path.join(dpath, "mc_test_dem_pfm_hist.mp4"), vid.transpose('frame', 'height', 'width'))

[########################################] | 100% Completed | 38.7s
[########################################] | 100% Completed | 38.7s
CPU times: user 2min 8s, sys: 1min 10s, total: 3min 18s
Wall time: 3min 4s
