In [None]:
# Test setup. Ignore warnings during production runs.

%run ./setup_tests.py

# Specify input data

* `data_dir` (`str`): Directory where the data is located. The data is normally in the current directory; change this only if it is not.
* `data` (`str`): TIFF or HDF5 file to use as input data.
* `data_basename` (`str`): Basename to use for intermediate and final result files.
* `dataset` (`str`): HDF5 dataset to use as input data (only used when `data` is an HDF5 file).

In [None]:
data_dir = ""
data = "data.tif"
data_basename = "data"
dataset = "images"


import os

data_ext = os.path.splitext(data)[1].lower()
data_dir = os.path.abspath(data_dir)

subgroup_raw = "raw"
subgroup_trim = "trim"
subgroup_dn = "dn"
subgroup_reg = "reg"
subgroup_reg_images = "reg/images"
subgroup_reg_shifts = "reg/shifts"
subgroup_sub = "sub"
subgroup_norm = "norm"
subgroup_dict = "dict"
subgroup_post = "post"
subgroup_post_mask = "post/mask"
subgroup_rois = "rois"
subgroup_rois_masks = "rois/masks"
subgroup_rois_masks_j = "rois/masks_j"
subgroup_rois_labels = "rois/labels"
subgroup_rois_labels_j = "rois/labels_j"
subgroup_traces = "traces"
subgroup_proj = "proj"
subgroup_proj_hmean = "proj/hmean"
subgroup_proj_max = "proj/max"
subgroup_proj_mean = "proj/mean"
subgroup_proj_std = "proj/std"

postfix_rois = "_rois"
postfix_traces = "_traces"
postfix_html = "_proj"

h5_ext = os.path.extsep + "h5"
tiff_ext = os.path.extsep + "tif"
zarr_ext = os.path.extsep + "zarr"
html_ext = os.path.extsep + "html"

In [None]:
import os
from psutil import cpu_count

# Keyword arguments handed to `startup_distributed` when the cluster is started.
cluster_kwargs = {
    "ip": ""
}
client_kwargs = {}
adaptive_kwargs = {
    "minimum": 0,
    # One core is subtracted from the available count (presumably to keep it
    # free for the notebook/scheduler -- confirm). The cap is clamped to at
    # least one worker so a single-core setting cannot leave the cluster with
    # no workers at all. An unset/empty `CORES` or an undetermined CPU count
    # falls back instead of raising.
    "maximum": max(1, int(os.environ.get("CORES") or cpu_count() or 1) - 1)
}

In [None]:
import zarr

from nanshe_workflow.data import DistributedDirectoryStore

zarr_store = zarr.open_group(DistributedDirectoryStore(data_basename + zarr_ext), "a")

# Configure and startup Cluster

In [None]:
from nanshe_workflow.par import startup_distributed
from nanshe_workflow.data import DistributedArrayStore

client = startup_distributed(0, cluster_kwargs, client_kwargs, adaptive_kwargs)

dask_store = DistributedArrayStore(zarr_store, client=client)

client

In [None]:
client.cluster

# Define functions for computation.

In [None]:
%matplotlib notebook

import matplotlib
import matplotlib.cm
import matplotlib.pyplot

import matplotlib as mpl
import matplotlib.pyplot as plt

from mplview.core import MatplotlibViewer as MPLViewer

In [None]:
import logging
import os
import sys

from builtins import range as irange

try:
    from contextlib import suppress
except ImportError:
    from contextlib2 import suppress

import numpy
import scipy
import scipy.ndimage
import h5py

import numpy as np
import scipy as sp
import scipy.ndimage as spim
import h5py as hp

import dask
import dask.array
import dask.array.fft
import dask.distributed

import dask.array as da

import dask_imread
import dask_ndfilters
import dask_ndfourier

import zarr

import nanshe
from nanshe.imp.segment import generate_dictionary

import nanshe_workflow
from nanshe_workflow.data import io_remove, dask_io_remove, dask_load_hdf5, dask_store_zarr, zip_zarr, open_zarr

zarr.blosc.set_nthreads(1)
zarr.blosc.use_threads = False
client.run(zarr.blosc.set_nthreads, 1)
client.run(setattr, zarr.blosc, "use_threads", False)

logging.getLogger("nanshe").setLevel(logging.INFO)

In [None]:
from nanshe_workflow.data import DistributedDirectoryStore
from nanshe_workflow.data import hdf5_to_zarr, zarr_to_hdf5
from nanshe_workflow.data import save_tiff

In [None]:
try:
    import pyfftw.interfaces.numpy_fft as numpy_fft
except ImportError:
    import numpy.fft as numpy_fft

rfftn = da.fft.fft_wrap(numpy_fft.rfftn)
irfftn = da.fft.fft_wrap(numpy_fft.irfftn)

In [None]:
from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data

from nanshe_workflow.par import halo_block_generate_dictionary_parallel
from nanshe_workflow.imp import block_postprocess_data_parallel

In [None]:
from nanshe_workflow.proj2 import compute_traces

from nanshe_workflow.proj2 import compute_adj_harmonic_mean_projection

from nanshe_workflow.proj2 import norm_layer

# Begin workflow. Set parameters and run each cell.

### Convert TIFF/HDF5 to Zarr

* `block_chunks` (`tuple` of `int`s): chunk size for each block loaded into memory.

In [None]:
block_chunks = (100, -1, -1)

# Drop any previous conversion result so this cell can be re-run.
for k in [subgroup_raw]:
    with suppress(KeyError):
        del dask_store[k]

# Resolve the input file against `data_dir`; an absolute `data` path is
# returned unchanged by `os.path.join`.
data_path = os.path.join(data_dir, data)

if data_ext == tiff_ext:
    dask_store[subgroup_raw] = dask_imread.imread(data_path, nframes=block_chunks[0])
elif data_ext == h5_ext:
    dask_store[subgroup_raw] = dask_load_hdf5(data_path, dataset, chunks=block_chunks)
else:
    # Fail here with a clear message rather than with a `KeyError` on the
    # missing raw dataset further down.
    raise ValueError(
        "Unsupported input file extension %r for %r; expected %r or %r."
        % (data_ext, data, tiff_ext, h5_ext)
    )

dask.distributed.progress(dask_store[subgroup_raw], notebook=False)
print("")

### View Input Data

In [None]:
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_raw]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Trimming

* `front` (`int`): amount to trim off the front
* `back` (`int`): amount to trim off the back

In [None]:
front = 0
back = 0


for k in [subgroup_trim]:
    with suppress(KeyError):
        del dask_store[k]


# Load and prep data for computation.
da_imgs = dask_store[subgroup_raw]

# Trim frames from front and back
da_imgs_trim = da_imgs[front:len(da_imgs)-back]

# Store trimmed data
dask_store[subgroup_trim] = da_imgs_trim

# Check progress of store step
dask.distributed.progress(dask_store[subgroup_trim], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_trim]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Denoising

* `med_filt_size` (`int`): footprint size for median filter


In [None]:
med_filt_size = 10


for k in [subgroup_dn]:
    with suppress(KeyError):
        del dask_store[k]


# Load and prep data for computation.
da_imgs = dask_store[subgroup_trim]

da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Median filter frames
da_imgs_filt = dask_ndfilters.median_filter(
    da_imgs_flt, (1,) + (da_imgs_flt.ndim - 1) * (med_filt_size,)
)

# Reset minimum to original value.
da_imgs_filt += da_imgs.min() - da_imgs_filt.min()

# Store denoised data
dask.distributed.fire_and_forget(dask.persist(da_imgs.min(), da_imgs_filt, da_imgs_filt.min()))
dask_store[subgroup_dn] = da_imgs_filt

# Check progress of store step
dask.distributed.progress(dask_store[subgroup_dn], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_dn]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Registration

In [None]:
def fourier_shift_wrap(array, shift, n=-1, axis=-1):
    """Apply a per-frame Fourier-domain shift to a stack of transformed frames.

    Parameters
    ----------
    array : array-like
        Stack of Fourier-transformed frames; the first axis indexes frames.
    shift : sequence
        Only ``shift[0]`` is used; ``shift[0][i]`` is the shift applied to
        frame ``i``.
    n : int, optional
        Forwarded to ``scipy.ndimage.fourier_shift``. A negative value (the
        default, matching the previous behavior) treats each frame as the
        result of a complex FFT. A non-negative value treats each frame as the
        result of a real FFT and gives the pre-transform length along
        ``axis``.
    axis : int, optional
        Forwarded to ``scipy.ndimage.fourier_shift``: the real-transform axis
        of a *single frame* (i.e. of ``array[i]``), only relevant when
        ``n >= 0``.

    Returns
    -------
    numpy.ndarray
        Array with the shape and dtype of ``array`` holding the shifted frames.

    Notes
    -----
    NOTE(review): the registration loop in this notebook passes ``rfftn``
    output without ``n``; confirm whether that call should supply the original
    last-axis length.
    """
    result = numpy.empty_like(array)
    for i in irange(len(array)):
        result[i] = spim.fourier_shift(array[i], shift[0][i], n=n, axis=axis)
    return result


def find_best_match(matches):
    """Return the candidate offset closest to the origin.

    ``matches`` holds one candidate per column (rows are coordinates). The
    column with the smallest sum of squared coordinates is returned; when there
    are no candidates, a zero vector with one entry per row is returned.
    """
    if not matches.size:
        # No candidates: fall back to a zero offset of matching length/dtype.
        return numpy.zeros(matches.shape[:1], dtype=matches.dtype)

    squared_norms = numpy.square(matches).sum(axis=0)
    return matches[:, numpy.argmin(squared_norms)]


def compute_offset(match_mask):
    """Turn per-frame boolean match masks into one integer offset per frame.

    The stack of masks is taken from ``match_mask[0][0]``. For every frame the
    coordinates of its nonzero entries are collected, coordinates past the
    midpoint of an axis are wrapped to negative values by subtracting the axis
    length, and the candidate chosen by ``find_best_match`` is stored.

    Returns an integer array with one row per frame and one column per frame
    axis.
    """
    mask_stack = match_mask[0][0]

    offsets = numpy.empty((len(mask_stack), mask_stack.ndim - 1), dtype=int)
    for idx, frame_mask in enumerate(mask_stack):
        # Column vector of axis lengths so it broadcasts over candidates.
        extent = np.array(frame_mask.shape)[:, None]

        candidates = np.array(frame_mask.nonzero())
        past_midpoint = (candidates > extent // 2).astype(candidates.dtype)
        candidates -= past_midpoint * extent

        offsets[idx] = find_best_match(candidates)

    return offsets


def roll_frames_chunk(frames, shifts):
    """Circularly roll every frame of a chunk by its own shift.

    ``shifts[i]`` gives the roll applied to ``frames[i]`` along each of the
    frame's axes. A rolled copy is returned; the input is left untouched.
    """
    # Work on a private copy: Dask may share the input object between tasks,
    # and writing into it in place would corrupt other consumers.
    rolled = numpy.copy(frames)
    frame_axes = tuple(irange(frames.ndim - 1))

    for idx in irange(len(rolled)):
        rolled[idx] = numpy.roll(rolled[idx], tuple(shifts[idx]), axis=frame_axes)

    return rolled


def roll_frames(frames, shifts):
    """Lazily roll each frame of a Dask array by its per-frame shift.

    Parameters
    ----------
    frames : dask array
        Stack of frames; the first axis indexes frames.
    shifts : dask array
        Two-dimensional array whose row ``i`` is the shift for frame ``i``.

    Returns
    -------
    dask array
        Array with the same dtype as ``frames`` whose blocks are produced by
        ``roll_frames_chunk``.
    """
    # Every non-frame axis is collapsed into a single chunk so each block
    # holds whole frames, and each block of shifts holds whole rows.
    frames = frames.rechunk({
        k: v for k, v in enumerate(frames.shape[1:], 1)
    })
    shifts = shifts.rechunk({1: shifts.shape[1]})

    rolled_frames = da.atop(
        roll_frames_chunk, tuple(irange(frames.ndim)),
        frames, tuple(irange(frames.ndim)),
        shifts, (0, frames.ndim),
        # Take the dtype from the argument, not from a notebook-level global,
        # so the function is correct for whatever array it is given.
        dtype=frames.dtype,
        concatenate=True
    )

    return rolled_frames

In [None]:
num_reps = 5
tmpl_hist_wght = 0.25
thld_rel_dist = 0.0


for k in [subgroup_reg_images, subgroup_reg_shifts]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_reg]
zarr_store.require_group(subgroup_reg)


# Load and prep data for computation.
da_imgs = dask_store[subgroup_dn]

da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Create frame shape arrays
frame_shape = np.array(da_imgs_flt.shape[1:], dtype=int)
half_frame_shape = frame_shape // 2
frame_shape = da.asarray(frame_shape)
half_frame_shape = da.asarray(half_frame_shape)

# Find the inverse of each frame
da_imgs_flt_min = da_imgs_flt.min()
da_imgs_inv = dask.array.reciprocal(da_imgs_flt - (da_imgs_flt_min - 1))

# Compute the FFT of inverse frames and template
da_imgs_fft = rfftn(da_imgs_inv, axes=tuple(irange(1, da_imgs_flt.ndim)))
da_imgs_fft_tmplt = da_imgs_fft.mean(axis=0, keepdims=True)

# Initialize
i = 0
avg_rel_dist = 1.0
tmpl_hist_wght = da_imgs_flt.dtype.type(tmpl_hist_wght)
da_shifts = da.zeros(
    (len(da_imgs_flt), da_imgs_flt.ndim - 1),
    dtype=int,
    chunks=(1, da_imgs_flt.ndim - 1)
)

# Persist FFT of frames and template
da_imgs_flt_min, da_imgs_fft, da_imgs_fft_tmplt = client.persist([
    da_imgs_flt_min, da_imgs_fft, da_imgs_fft_tmplt
])
dask.distributed.fire_and_forget(da_imgs_flt_min)
del da_imgs_flt_min

while avg_rel_dist > thld_rel_dist and i < num_reps:
    # Compute the shifted frames
    da_shifted_frames = da.atop(
        fourier_shift_wrap,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        da_imgs_fft,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        da_shifts,
        (0, da_imgs_fft.ndim),
        dtype=da_imgs_fft.dtype
    )

    # Compute the template FFT
    da_imgs_fft_tmplt = (
        tmpl_hist_wght * da_imgs_fft_tmplt +
        (1 - tmpl_hist_wght) * da_shifted_frames.mean(axis=0, keepdims=True)
    )

    # Free connected persisted values
    del da_shifted_frames

    # Find the best overlap with the template.
    da_overlap = irfftn(
        da_imgs_fft * da_imgs_fft_tmplt,
        s=da_imgs_flt.shape[1:],
        axes=tuple(irange(1, da_imgs_flt.ndim))
    )
    da_overlap_max = da_overlap.max(axis=tuple(irange(1, da_imgs_flt.ndim)), keepdims=True)
    da_overlap_max_match = (da_overlap == da_overlap_max)

    # Compute the shift for each frame.
    old_da_shifts = da_shifts
    da_shifts = da.atop(
        compute_offset,
        (0, da_overlap_max_match.ndim),
        da_overlap_max_match.rechunk(dict(enumerate(da_overlap_max_match.shape[1:], 1))),
        tuple(irange(0, da_overlap_max_match.ndim)),
        dtype=int,
        new_axes={da_overlap_max_match.ndim: da_overlap_max_match.ndim - 1}
    )

    # Free connected persisted values
    del da_overlap
    del da_overlap_max
    del da_overlap_max_match

    # Remove any collective frame drift.
    da_drift = da_shifts.mean(axis=0, keepdims=True).round().astype(da_shifts.dtype)
    da_shifts = da_shifts - da_drift

    # Free connected persisted values
    del da_drift

    # Find shift change.
    diff_da_shifts = da_shifts - old_da_shifts
    rel_diff_da_shifts = (
        diff_da_shifts.astype(da_imgs_flt.dtype) / 
        frame_shape.astype(da_imgs_flt.dtype) /
        (da_imgs_flt.dtype.type(len(frame_shape)) ** 0.5)
    )
    rel_dist_da_shifts = (rel_diff_da_shifts ** 2.0).sum(axis=1) ** 0.5
    avg_rel_dist = rel_dist_da_shifts.sum() / da_imgs_flt.dtype.type(len(da_shifts))

    # Free old shifts
    del old_da_shifts

    # Free connected persisted values
    del diff_da_shifts
    del rel_diff_da_shifts
    del rel_dist_da_shifts

    # Persist values needed for the next iteration (and end of this one).
    da_imgs_fft_tmplt, da_shifts, avg_rel_dist = client.persist([da_imgs_fft_tmplt, da_shifts, avg_rel_dist])
    dask.distributed.fire_and_forget(da_shifts)

    # Compute change
    dask.distributed.progress(avg_rel_dist, notebook=False)
    print("")
    avg_rel_dist = avg_rel_dist.compute()
    i += 1

    # Show change
    print((i, avg_rel_dist))

# Drop unneeded items
del frame_shape
del half_frame_shape
del da_imgs_flt
del da_imgs_inv
del da_imgs_fft
del da_imgs_fft_tmplt

# Roll all parts to clip to one side
# Keep origin static
da_imgs_shifted = roll_frames(
    da_imgs,
    da.clip(da_shifts, None, 0)
)

# Truncate all frames to smallest one
da_imgs_trunc_shape = da.asarray(da_imgs.shape[1:]) - abs(da_shifts).max(axis=0)
da_imgs_trunc_shape = da_imgs_trunc_shape.compute()

da_imgs_trunc_cut = tuple(map(
    lambda s: slice(None, s), da_imgs_trunc_shape
))

da_imgs_trunc = da_imgs_shifted[(slice(None),) + da_imgs_trunc_cut]

# Free raw data
del da_imgs

# Store registered data
dask_store.update({
    subgroup_reg_images: da_imgs_trunc,
    subgroup_reg_shifts: da_shifts,
})

dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_reg_images],
        dask_store[subgroup_reg_shifts]
    ]),
    notebook=False
)
print("")

# Free truncated frames and shifts
del da_imgs_trunc
del da_shifts


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_reg_images]
da_shifts = dask_store[subgroup_reg_shifts]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

fig, axs = plt.subplots(nrows=da_shifts.shape[1], sharex=True)
fig.subplots_adjust(hspace=0.0)
for i in range(da_shifts.shape[1]):
    axs[i].plot(np.asarray(da_shifts[:, i]))
    axs[i].set_ylabel("%s (px)" % chr(ord("X") + da_shifts.shape[1] - i - 1))
    axs[i].yaxis.set_tick_params(width=1.5)
    [v.set_linewidth(2) for v in axs[i].spines.values()]
axs[-1].set_xlabel("Frame (#)")
axs[-1].set_xlim((0, da_shifts.shape[0] - 1))
axs[-1].xaxis.set_tick_params(width=1.5)

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Projections

In [None]:
for k in [subgroup_proj_hmean, subgroup_proj_max, subgroup_proj_mean, subgroup_proj_std]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_proj]
zarr_store.require_group(subgroup_proj)


# Load and prep data for computation.
da_imgs = dask_store[subgroup_reg_images]

da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

da_imgs_proj_hmean = compute_adj_harmonic_mean_projection(da_imgs_flt)

da_imgs_proj_max = da_imgs_flt.max(axis=0)

da_imgs_proj_mean, da_imgs_proj_std = da_imgs_flt.mean(axis=0), da_imgs_flt.std(axis=0)

# Store projections
dask_store.update(dict(zip(
    [subgroup_proj_hmean, subgroup_proj_max, subgroup_proj_mean, subgroup_proj_std],
    [da_imgs_proj_hmean, da_imgs_proj_max, da_imgs_proj_mean, da_imgs_proj_std]
)))

dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_proj_hmean],
        dask_store[subgroup_proj_max],
        dask_store[subgroup_proj_mean],
        dask_store[subgroup_proj_std]
    ]),
    notebook=False
)
print("")

### Subtract Projection

In [None]:
for k in [subgroup_sub]:
    with suppress(KeyError):
        del dask_store[k]


# Load and prep data for computation.
da_imgs = dask_store[subgroup_reg_images]

da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

da_imgs_sub = da_imgs_flt - compute_adj_harmonic_mean_projection(da_imgs_flt)
da_imgs_sub -= da_imgs_sub.min()

# Store background removed data
dask_store[subgroup_sub] = da_imgs_sub

dask.distributed.progress(dask_store[subgroup_sub], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_sub]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Normalize Data

In [None]:
for k in [subgroup_norm]:
    with suppress(KeyError):
        del dask_store[k]


# Load and prep data for computation.
da_imgs = dask_store[subgroup_sub]

da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

da_imgs_flt_mins = da_imgs_flt.min(
    axis=tuple(irange(1, da_imgs_flt.ndim)),
    keepdims=True
)

da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins

da_result = renormalized_images(da_imgs_flt_shift)

# Store normalized data
dask_store[subgroup_norm] = da_result

dask.distributed.progress(dask_store[subgroup_norm], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_norm]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Dictionary Learning

* `n_components` (`int`): number of basis images in the dictionary.
* `batchsize` (`int`): minibatch size to use.
* `iters` (`int`): number of iterations to run before getting dictionary.
* `lambda1` (`float`): weight for L<sup>1</sup> sparsity enforcement on sparse code.
* `lambda2` (`float`): weight for L<sup>2</sup> penalty on sparse code (set below, but not currently passed to the learner).

<br>
* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).

In [None]:
import dask_ml
import dask_ml.base
import sklearn
import sklearn.decomposition

class MiniBatchDictionaryLearning(dask_ml.base._BigPartialFitMixin,
                                  sklearn.decomposition.MiniBatchDictionaryLearning):
    """scikit-learn's mini-batch dictionary learner combined with dask-ml's
    ``_BigPartialFitMixin``.

    No behavior is added here; everything comes from the two base classes.
    NOTE(review): presumably the mixin lets ``fit`` consume a Dask array block
    by block through ``partial_fit`` -- confirm against the installed dask-ml,
    since ``_BigPartialFitMixin`` is a private name.
    """
    pass


n_components = 5
batchsize = 256
iters = 100
lambda1 = 0.001
lambda2 = 0.0

block_frames = 51


for k in [subgroup_dict]:
    with suppress(KeyError):
        del dask_store[k]


da_imgs = dask_store[subgroup_norm]
da_imgs_mtx = da_imgs.reshape(
    da_imgs.shape[0],
    int(np.prod(da_imgs.shape[1:]))
)

dict_sel_idx = []
for i in range(n_components):
    redraw = True
    while redraw:
        e = np.random.randint(0, len(da_imgs_mtx))
        redraw = bool(e in dict_sel_idx)
    dict_sel_idx.append(e)
dict_sel = da.take(da_imgs_mtx, dict_sel_idx)

learner = MiniBatchDictionaryLearning(
    n_components=n_components, alpha=lambda1, n_iter=iters, fit_algorithm="lars",
    n_jobs=1, batch_size=batchsize, shuffle=True,
    dict_init=dict_sel, transform_algorithm="omp",
    transform_n_nonzero_coefs=None, transform_alpha=None,
    verbose=False, split_sign=False, random_state=None
)
dictionary = dask_store._create_dataset(
    subgroup_dict, shape=(n_components,) + da_imgs.shape[1:], dtype=da_imgs.dtype, chunks=True
)

learner.fit(da_imgs_mtx, get=client.get)
dictionary[:] = learner.components_.reshape((n_components,) + da_imgs.shape[1:])

del dictionary

In [None]:
algorithm = "lasso_lars"
alpha = 0.001
n_nonzero_coefs = None
max_iter = 1000


import sklearn
import sklearn.decomposition

for k in ["code"]:
    with suppress(KeyError):
        del dask_store[k]


imgs = dask_store[subgroup_norm]
dictionary = dask_store[subgroup_dict]

imgs = imgs.astype(np.float64)
dictionary = dictionary.astype(np.float64)

dictionary = dictionary.rechunk(dictionary.ndim * (-1,))

imgs = imgs.reshape((imgs.shape[0], np.prod(imgs.shape[1:])))
dictionary = dictionary.reshape((dictionary.shape[0], np.prod(dictionary.shape[1:])))

gram = da.dot(dictionary, dictionary.T)
cov = da.dot(dictionary, imgs.T)


def sparse_encode_wrapper(*args, **kwargs):
    """Call ``sklearn.decomposition.sparse_encode`` on block arguments.

    Any positional argument that arrives wrapped in a list is replaced by its
    first element before the call. The resulting code matrix is returned
    transposed.
    """
    unwrapped = []
    for arg in args:
        unwrapped.append(arg[0] if isinstance(arg, list) else arg)

    encoded = sklearn.decomposition.sparse_encode(*unwrapped, **kwargs)
    return encoded.T


code = da.atop(
    sparse_encode_wrapper,
    (0, 1),
    imgs,
    (1, 2),
    dictionary,
    (0, 2),
    gram,
    (0, 0),
    cov,
    (0, 1),
    dtype=np.float64,
    algorithm=algorithm,
    n_nonzero_coefs=n_nonzero_coefs,
    alpha=alpha,
    copy_cov=True,
    init=None,
    max_iter=max_iter,
    n_jobs=1,
    check_input=True,
    verbose=0
)

dask_store["code"] = code

dask.distributed.progress(dask_store["code"], notebook=False)
print("")

In [None]:
import ipywidgets

# The slider range is taken from the stored dictionary itself rather than from
# a leftover notebook variable, so this cell also works when the store was
# filled in an earlier session.
@ipywidgets.interact(i=ipywidgets.IntSlider(min=0,max=len(dask_store[subgroup_dict])-1,step=1,value=0))
def show_basis_code_plts(i):
    """Show dictionary basis image ``i`` next to row ``i`` of the sparse code.

    Parameters
    ----------
    i : int
        Index of the basis image (and matching code row) to display.
    """
    fig = plt.figure()
    fig.add_subplot(1,2,1)
    plt.imshow(dask_store[subgroup_dict][i])
    fig.add_subplot(1,2,2)
    plt.plot(dask_store["code"][i])
    plt.show()

In [None]:
da_imgs = dask_store[subgroup_norm]
da_imgs_recons = da.tensordot(dask_store["code"], dask_store[subgroup_dict], axes=(0, 0))

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    (da_imgs - da_imgs_recons).max(axis=0),
    vmin=0,
    vmax=0.2
)

In [None]:
import functools
import io
import logging
import sys

try:
    from contextlib import ExitStack, redirect_stdout, redirect_stderr
except ImportError:
    from contextlib2 import ExitStack, redirect_stdout, redirect_stderr


def func_log_stdoe(func):
    """Decorate ``func`` so its stdout and stderr output is sent to loggers.

    While ``func`` runs, ``sys.stdout`` and ``sys.stderr`` are redirected into
    separate in-memory buffers. When it finishes (normally or by raising), the
    captured text is logged at INFO level to the ``distributed.worker.stdout``
    and ``distributed.worker.stderr`` loggers respectively.

    Parameters
    ----------
    func : callable
        Function to wrap.

    Returns
    -------
    callable
        Wrapper with the same signature and return value as ``func``.
    """
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        with ExitStack() as stack:
            out, err = io.StringIO(), io.StringIO()

            # Each stream gets its own buffer; stderr must be redirected with
            # `redirect_stderr`, otherwise stdout would land in `err` and
            # stderr would not be captured at all.
            stack.enter_context(redirect_stdout(out))
            stack.enter_context(redirect_stderr(err))

            try:
                return func(*args, **kwargs)
            finally:
                # Log in `finally` so output is kept even if `func` raises.
                logging.getLogger("distributed.worker.stdout").info(out.getvalue())
                logging.getLogger("distributed.worker.stderr").info(err.getvalue())

    return wrapped


# Decorated with the wrapper defined in this cell (`func_log_stdoe`) so the
# printed text is routed to the worker loggers.
@func_log_stdoe
def print_hello():
    """Print a greeting; used to check that worker output is being logged."""
    print("Hello my Friends!")

client.run(print_hello)

### Postprocessing

* `significance_threshold` (`float`): number of standard deviations below which to include in "noise" estimate
* `wavelet_scale` (`int`): scale of wavelet transform to apply (should be the same as the one used above)
* `noise_threshold` (`float`): number of units of "noise" above which something needs to be to be significant
* `accepted_region_shape_constraints` (`dict`): if ROIs don't match this, reduce the `wavelet_scale` once.
* `percentage_pixels_below_max` (`float`): upper bound on ratio of ROI pixels not at max intensity vs. all ROI pixels
* `min_local_max_distance` (`float`): minimum allowable Euclidean distance between two ROIs' maximum intensities
* `accepted_neuron_shape_constraints` (`dict`): shape constraints for ROI to be kept.

* `alignment_min_threshold` (`float`): similarity measure of the intensity of two ROIs images used for merging.
* `overlap_min_threshold` (`float`): similarity measure of the masks of two ROIs used for merging.

In [None]:
significance_threshold = 3.0
wavelet_scale = 3
noise_threshold = 3.0
percentage_pixels_below_max = 0.8
min_local_max_distance = 16.0

alignment_min_threshold = 0.6
overlap_min_threshold = 0.6


for k in zarr_store.get(subgroup_post, {}).keys():
    with suppress(KeyError):
        del dask_store[subgroup_post + "/" + k]
with suppress(KeyError):
    del zarr_store[subgroup_post]
zarr_store.require_group(subgroup_post)


imgs = dask_store._diskstore[subgroup_dict]
da_imgs = da.from_array(imgs, chunks=((1,) + imgs.shape[1:]))

result = block_postprocess_data_parallel(client)(da_imgs,
                              **{
                                    "wavelet_denoising" : {
                                        "estimate_noise" : {
                                            "significance_threshold" : significance_threshold
                                        },
                                        "wavelet.transform" : {
                                            "scale" : wavelet_scale
                                        },
                                        "significant_mask" : {
                                            "noise_threshold" : noise_threshold
                                        },
                                        "accepted_region_shape_constraints" : {
                                            "major_axis_length" : {
                                                "min" : 0.0,
                                                "max" : 25.0
                                            }
                                        },
                                        "remove_low_intensity_local_maxima" : {
                                            "percentage_pixels_below_max" : percentage_pixels_below_max
                                        },
                                        "remove_too_close_local_maxima" : {
                                            "min_local_max_distance" : min_local_max_distance
                                        },
                                        "accepted_neuron_shape_constraints" : {
                                            "area" : {
                                                "min" : 25,
                                                "max" : 600
                                            },
                                            "eccentricity" : {
                                                "min" : 0.0,
                                                "max" : 0.9
                                            }
                                        }
                                    },
                                    "merge_neuron_sets" : {
                                        "alignment_min_threshold" : alignment_min_threshold,
                                        "overlap_min_threshold" : overlap_min_threshold,
                                        "fuse_neurons" : {
                                            "fraction_mean_neuron_max_threshold" : 0.01
                                        }
                                    }
                              }
)

# Store postprocessing results: one dataset under `subgroup_post` for each
# named field of the structured `result` array.
dask_store.update(dict(zip(
    ["%s/%s" % (subgroup_post, e) for e in result.dtype.names],
    [result[e] for e in result.dtype.names]
)))

dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store["%s/%s" % (subgroup_post, e)]
        for e in result.dtype.names
    ]),
    notebook=False
)
print("")

### ROI and trace extraction

In [None]:
dask_io_remove(data_basename + postfix_rois + h5_ext, client)
for k in [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j, subgroup_rois]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_rois]
zarr_store.require_group(subgroup_rois)


da_roi_masks = dask_store[subgroup_post_mask]

da_lbls = da.arange(
    1,
    len(da_roi_masks) + 1,
    chunks=da_roi_masks.chunks[0],
    dtype=np.uint64
)
da_lblimg = (
    da_lbls[(slice(None),) + (da_roi_masks.ndim - 1) * (None,)] * 
    da_roi_masks.astype(np.uint64)
).max(axis=0)

dask_store.update(dict(zip(
    [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j],
    [da_roi_masks, da_roi_masks.astype(numpy.uint8), da_lblimg, da_lblimg.astype(numpy.uint8)]
)))

dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[e] for e in
    [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j]
    ]),
    notebook=False
)
print("")


with h5py.File(data_basename + postfix_rois + h5_ext, "w") as f2:
    for k in [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j]:
        zarr.copy(dask_store._diskstore[k], f2)


dask_io_remove(data_basename + postfix_traces + h5_ext, client)
for k in [subgroup_traces]:
    with suppress(KeyError):
        del dask_store[k]


# Load and prep data for computation.
da_images = dask_store[subgroup_sub]
da_masks = dask_store[subgroup_rois_masks]

da_result = compute_traces(da_images, da_masks)

# Store traces
dask_store[subgroup_traces] = da_result

dask.distributed.progress(dask_store[subgroup_traces], notebook=False)
print("")


with h5py.File(data_basename + postfix_traces + h5_ext, "w") as f2:
    zarr.copy(dask_store._diskstore[subgroup_traces], f2)


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_sub]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

lblimg = dask_store[subgroup_rois_labels].compute()
lblimg_msk = numpy.ma.masked_array(lblimg, mask=(lblimg==0))

mplsv.viewer.matshow(lblimg_msk, alpha=0.3, cmap=mpl.cm.jet)


mskimg = None
mskimg_j = None
lblimg = None
traces = None
traces_j = None

del mskimg
del mskimg_j
del lblimg
del traces
del traces_j

# End of workflow. Shutdown cluster.

In [None]:
import distributed
from nanshe_workflow.par import shutdown_distributed

try:
    del dask_store
except NameError:
    pass

client = distributed.client.default_client()

shutdown_distributed(client)

client

# Prepare interactive projection graph

In [None]:
import io
import textwrap
import zlib

import numpy
import numpy as np

import bokeh.plotting
import bokeh.plotting as bp

import bokeh.io
import bokeh.io as bio

import bokeh.embed
import bokeh.embed as be

from bokeh.models.mappers import LinearColorMapper

import webcolors

from bokeh.models import CustomJS, ColumnDataSource, HoverTool
from bokeh.models.layouts import Row

from builtins import (
    map as imap,
    range as irange
)

from past.builtins import basestring

import nanshe_workflow
from nanshe_workflow.data import io_remove, open_zarr
from nanshe_workflow.vis import get_rgb_array, get_rgba_array, get_all_greys, masks_to_contours_2d

In [None]:
# Read the ROI masks, the traces and the projections out of the zarr store as
# in-memory arrays (`[...]` selects the full contents of each array).
mskimg = zarr_store[subgroup_rois_masks][...]

traces = zarr_store[subgroup_traces][...]

# Each projection must come from its own subgroup. The max and mean reads were
# previously crossed, which put the mean data under the "Max" plot title and
# the max data under the "Mean" plot title.
imgproj_max = zarr_store[subgroup_proj_max][...]
imgproj_mean = zarr_store[subgroup_proj_mean][...]
imgproj_std = zarr_store[subgroup_proj_std][...]

### Result visualization

* `proj_img` (`str` or `list` of `str`): which projection or projections to plot; any of "max", "mean", and "std".
* `block_size` (`int`): number of screen pixels used to draw each image pixel along each dimension.
* `roi_alpha` (`float`): transparency of the ROIs in a range of [0.0, 1.0].
* `roi_border_width` (`int`): width of the line border on each ROI.

<br>
* `trace_plot_width` (`int`): width of the trace plot.

In [None]:
# Display settings; see the parameter list in the text above this cell.
proj_img = "std"
block_size = 1
roi_alpha = 0.3
roi_border_width = 3
trace_plot_width = 500


# Start from an empty current Bokeh document.
bio.curdoc().clear()

# Grey palette and the linear colour mapper built from it.
grey_range = get_all_greys()
grey_cm = LinearColorMapper(grey_range)

# One hex colour string per mask.
colors_rgb = [
    webcolors.rgb_to_hex(each_rgb)
    for each_rgb in get_rgb_array(len(mskimg)).tolist()
]

# Contour coordinates of every mask, stored in the smallest dtype able to
# hold the largest index along the masks' non-leading dimensions.
mskctr_pts_y, mskctr_pts_x = masks_to_contours_2d(mskimg)

mskctr_pts_dtype = np.min_scalar_type(max(mskimg.shape[1:]) - 1)
mskctr_pts_y = [
    np.array(each_pts, dtype=mskctr_pts_dtype) for each_pts in mskctr_pts_y
]
mskctr_pts_x = [
    np.array(each_pts, dtype=mskctr_pts_dtype) for each_pts in mskctr_pts_x
]

mskctr_srcs = ColumnDataSource(
    data={"x": mskctr_pts_x, "y": mskctr_pts_y, "color": colors_rgb}
)


# Always work with a list of projection names, even if only one was given.
proj_img = [proj_img] if isinstance(proj_img, basestring) else list(proj_img)


proj_plot_width = block_size*mskimg.shape[2]
proj_plot_height = block_size*mskimg.shape[1]


def _make_proj_plot(title, imgproj):
    """
        Builds one projection figure with the ROI contours drawn on top.

        The figure size, axis ranges, colour mapper, contour source and ROI
        styling are taken from the notebook namespace (``proj_plot_width``,
        ``proj_plot_height``, ``mskimg``, ``grey_cm``, ``mskctr_srcs``,
        ``roi_alpha``, ``roi_border_width``) so that every projection is
        drawn the same way.

        Args:
            title(str):         title shown above the figure.
            imgproj(array):     2D projection image to display.

        Returns:
            The configured Bokeh figure.
    """

    plot = bp.Figure(plot_width=proj_plot_width, plot_height=proj_plot_height,
                     x_range=[0, mskimg.shape[2]], y_range=[mskimg.shape[1], 0],
                     tools=["tap", "pan", "box_zoom", "wheel_zoom", "save", "reset"],
                     title=title, border_fill_color="black")

    # The y range above is inverted, so the image is flipped vertically and
    # anchored at the far end of the y range. Its extent always follows the
    # mask dimensions, matching the axis ranges and the contour coordinates.
    plot.image(image=[numpy.flipud(imgproj)], x=[0], y=[mskimg.shape[1]],
               dw=[mskimg.shape[2]], dh=[mskimg.shape[1]], color_mapper=grey_cm)
    plot.patches('x', 'y', source=mskctr_srcs, alpha=roi_alpha, line_width=roi_border_width, color="color")

    # White outline and axis lines against the black border fill.
    plot.outline_line_color = "white"
    for each_axis in plot.axis:
        each_axis.axis_line_color = "white"

    return plot


# Build only the requested projections, in max, mean, std order.
plot_projs = []

if "max" in proj_img:
    plot_max = _make_proj_plot("Max Projection with ROIs", imgproj_max)
    plot_projs.append(plot_max)


if "mean" in proj_img:
    plot_mean = _make_proj_plot("Mean Projection with ROIs", imgproj_mean)
    plot_projs.append(plot_mean)


if "std" in proj_img:
    plot_std = _make_proj_plot("Std Dev Projection with ROIs", imgproj_std)
    plot_projs.append(plot_std)


# Data sources consumed by the JavaScript selection callback:
#   * a one-element array holding zero in the traces' dtype,
#   * the shape of the traces array,
#   * the raw bytes of the traces, zlib-compressed and viewed as uint8.
all_tr_dtype_srcs = ColumnDataSource(
    data={"traces_dtype": traces.dtype.type(0)[None]}
)
all_tr_shape_srcs = ColumnDataSource(data={"traces_shape": traces.shape})

traces_deflated = zlib.compress(traces.tobytes())
all_tr_srcs = ColumnDataSource(
    data={"traces": numpy.frombuffer(traces_deflated, dtype=np.uint8)}
)

# Starts out empty; holds the lines currently drawn in the trace plot.
tr_srcs = ColumnDataSource(
    data={"times_sel": [], "traces_sel": [], "colors_sel": []}
)

# Trace figure: x spans the second dimension of `traces`, y spans the range
# of values found in `traces`.
plot_tr = bp.Figure(
    plot_width=trace_plot_width,
    plot_height=proj_plot_height,
    x_range=(0.0, float(traces.shape[1])),
    y_range=(float(traces.min()), float(traces.max())),
    tools=["pan", "box_zoom", "wheel_zoom", "save", "reset"],
    title="ROI traces",
    background_fill_color="black",
    border_fill_color="black"
)
plot_tr.multi_line("times_sel", "traces_sel", source=tr_srcs, color="colors_sel")

# White outline and axis lines against the black fill.
plot_tr.outline_line_color = "white"
for each_axis in plot_tr.axis:
    each_axis.axis_line_color = "white"

plot_projs.append(plot_tr)


# JavaScript callback attached to the contour data source. It reads the
# indices selected on that source (`cb_obj.selected['1d'].indices`) and fills
# `tr_srcs` with one line per selected ROI.
#
# The first time it runs, the `traces_dtype` entry still holds the zero-valued
# typed array sent from Python. That array serves two purposes: its value (0)
# marks the traces as not yet decoded, and its constructor tells which typed
# array to rebuild. The callback then inflates the compressed bytes with pako,
# reinterprets them with that constructor, stores the decoded array back into
# `all_tr_srcs` and sets the flag to 1 so later runs skip the decoding.
#
# Each ROI's trace is the run of `trace_len` consecutive values starting at
# `trace_len * index` in the decoded array.
#
# The loop counter is declared with `var`. Left undeclared it would become an
# implicit global, which is a ReferenceError whenever the code is run in
# strict mode.
mskctr_srcs.callback = CustomJS(
    args=dict(
        all_tr_dtype_srcs=all_tr_dtype_srcs,
        all_tr_shape_srcs=all_tr_shape_srcs,
        all_tr_srcs=all_tr_srcs,
        tr_srcs=tr_srcs
    ), code="""
    var range = function(n){ return Array.from(Array(n).keys()); };

    var traces_not_decoded = (all_tr_dtype_srcs.data['traces_dtype'] == 0);
    var traces_dtype = all_tr_dtype_srcs.data['traces_dtype'].constructor;
    var traces_shape = all_tr_shape_srcs.data['traces_shape'];
    var trace_len = traces_shape[1];
    var traces = all_tr_srcs.data['traces'];
    if (traces_not_decoded) {
        traces = window.pako.inflate(traces);
        traces = new traces_dtype(traces.buffer);
        all_tr_srcs.data['traces'] = traces;
        all_tr_dtype_srcs.data['traces_dtype'] = 1;
    }

    var inds = cb_obj.selected['1d'].indices;
    var colors = cb_obj.data['color'];
    var selected = tr_srcs.data;

    var times = range(trace_len);

    selected['times_sel'] = [];
    selected['traces_sel'] = [];
    selected['colors_sel'] = [];

    for (var i = 0; i < inds.length; i++) {
        var inds_i = inds[i];
        var trace_i = traces.slice(trace_len*inds_i, trace_len*(inds_i+1));
        var color_i = colors[inds_i];

        selected['times_sel'].push(times);
        selected['traces_sel'].push(trace_i);
        selected['colors_sel'].push(color_i);
    }

    tr_srcs.change.emit();
""")


# Put every figure collected so far into one Row layout.
plot_group = Row(*plot_projs)


# Get rid of the HTML file from an earlier run before a new one is written.
stale_html_name = data_basename + postfix_html + html_ext
io_remove(stale_html_name)


def indent(text, spaces):
    """
        Prefixes every line of the text with the given number of spaces.

        The text is split with ``splitlines`` and rejoined with ``"\\n"``, so
        a trailing newline is not kept and blank lines receive the prefix too.

        Args:
            text(str):      text to indent.
            spaces(int):    number of spaces to put before each line
                            (converted with ``int``).

        Returns:
            str:            the indented text.
    """

    pad = " " * int(spaces)
    return "\n".join([pad + each_line for each_line in text.splitlines()])

def write_html(filename, title, div, script, cdn):
    """
        Writes a standalone HTML page around the given markup fragments.

        The page declares ``charset="utf-8"``, so the file is always written
        as UTF-8 instead of in the platform's default encoding.

        Args:
            filename(str):  path of the HTML file; overwritten if it exists.
            title(str):     text placed in the page's ``<title>``.
            div(str):       markup placed first inside ``<body>``.
            script(str):    markup placed inside ``<body>`` after ``div``.
            cdn(str):       markup placed inside ``<head>`` after the title
                            (the resource tags the page depends on).
    """

    html_tmplt = textwrap.dedent(u"""\
        <html lang="en">
            <head>
                <meta charset="utf-8">
                <title>{title}</title>
                {cdn}
                <style>
                  html {{
                    width: 100%;
                    height: 100%;
                  }}
                  body {{
                    width: 90%;
                    height: 100%;
                    margin: auto;
                    background-color: black;
                  }}
                </style>
            </head>
            <body>
                {div}
                {script}
            </body>
        </html>
    """)

    # Every fragment is indented by the same amount before being substituted
    # into its slot in the template.
    html_cont = html_tmplt.format(
        title=title,
        div=indent(div, 8),
        script=indent(script, 8),
        cdn=indent(cdn, 8),
    )

    # Write with the encoding that the <meta charset> tag promises.
    with io.open(filename, "w", encoding="utf-8") as fh:
        fh.write(html_cont)

from IPython.display import display, IFrame

# Script and div markup for embedding the layout in a hand-written page.
script, div = be.components(plot_group)

# Head content: the rendered Bokeh CDN resources followed by the pako inflate
# build that the selection callback calls as `window.pako`.
pako_tag = """
<script type="text/javascript" src="https://cdn.jsdelivr.net/pako/1.0.4/pako_inflate.min.js"></script>
"""
cdn = "\n".join([bokeh.resources.CDN.render(), pako_tag, ""])

# The page title is the output file name without its extension.
html_title = data_basename + postfix_html
html_file = html_title + html_ext
write_html(html_file, html_title, div, script, cdn)


# Show the written page inline, slightly taller than the projection plots.
display(IFrame(html_file, "100%", 1.05*proj_plot_height))

In [None]:
# Test teardown. Ignore warnings during production runs.

%run ./teardown_tests.py