In [None]:
# Test setup. Ignore warnings during production runs.
# `%run` executes ./setup_tests.py and merges its globals into this notebook's
# namespace. NOTE(review): what it configures is not visible here -- confirm
# whether it is needed outside of test runs.

%run ./setup_tests.py

# Specify input data

* `data_dir` (`str`): Directory containing the data. Change this if the data is not in the current directory (it normally is).
* `data` (`str`): HDF5 file to use as input data.
* `data_basename` (`str`): Basename to use for intermediate and final result files.
* `dataset` (`str`): HDF5 dataset to use as input data.

</br>
* `num_workers` (`int`): Number of workers for the compute cluster (default: all cores except one, which is left for the client).

In [None]:
# Input parameters (documented in the list above).
data_dir = ""
data = "data.tif"
data_basename = "data"
dataset = "images"

num_workers = None


import os

data_dir = os.path.abspath(data_dir)
# Resolve `data` against `data_dir`. Previously `data_dir` was normalized but
# never used, so setting it had no effect. An absolute `data` is kept as-is.
data = os.path.join(data_dir, data)
data_ext = os.path.splitext(data)[1].lower()

# Suffixes appended to `data_basename` for each stage's output files.
postfix_trim = "_trim"
postfix_dn = "_dn"
postfix_reg = "_reg"
postfix_sub = "_sub"
postfix_f_f0 = "_f_f0"
postfix_wt = "_wt"
postfix_norm = "_norm"
postfix_dict = "_dict"
postfix_post = "_post"
postfix_rois = "_rois"
postfix_traces = "_traces"
postfix_proj = "_proj"
# NOTE(review): same value as `postfix_proj`; presumably intentional so the
# interactive projection page is named after the projections -- confirm.
postfix_html = "_proj"

# File extensions, built with the platform's extension separator.
h5_ext = os.path.extsep + "h5"
tiff_ext = os.path.extsep + "tif"
zarr_ext = os.path.extsep + "zarr"
html_ext = os.path.extsep + "html"

# Configure and startup Cluster

In [None]:
from nanshe_workflow.par import set_num_workers, startup_distributed

# Resolve the requested worker count (None -> the default described in the
# parameter notes above) and start the cluster. Later cells submit all dask
# work through `client`.
num_workers = set_num_workers(num_workers)

client = startup_distributed(num_workers)

# Define functions for computation.

In [None]:
%matplotlib notebook

import matplotlib
import matplotlib.cm
import matplotlib.pyplot

import matplotlib as mpl
import matplotlib.pyplot as plt

from mplview.core import MatplotlibViewer as MPLViewer

In [None]:
import logging
import os
import sys

from builtins import range as irange

import numpy
import scipy
import scipy.ndimage
import h5py

import numpy as np
import scipy as sp
import scipy.ndimage as spim
import h5py as hp

import dask
import dask.array
import dask.array.fft
import dask.distributed

import dask.array as da

import dask_imread
import dask_ndfilters
import dask_ndfourier

import zarr

import nanshe
from nanshe.imp.segment import generate_dictionary

import nanshe_workflow
from nanshe_workflow.data import io_remove, dask_io_remove, dask_load_hdf5, dask_store_zarr, zip_zarr, open_zarr

# Make Blosc (de)compression single-threaded, both in this process and on
# every worker (via `client.run`). Presumably this keeps Blosc threads from
# competing with dask's own parallelism.
zarr.blosc.set_nthreads(1)
zarr.blosc.use_threads = False
client.run(zarr.blosc.set_nthreads, 1)
client.run(setattr, zarr.blosc, "use_threads", False)

# Show nanshe's INFO-level log messages.
logging.getLogger("nanshe").setLevel(logging.INFO)

In [None]:
from nanshe_workflow.data import hdf5_to_zarr, zarr_to_hdf5
from nanshe_workflow.data import save_tiff

In [None]:
# Prefer pyFFTW's numpy-compatible FFT interface when it is installed;
# otherwise fall back to numpy's FFT module.
try:
    import pyfftw.interfaces.numpy_fft as numpy_fft
except ImportError:
    import numpy.fft as numpy_fft

# Wrap the real-input N-D FFTs so they can be applied to dask arrays.
rfftn = da.fft.fft_wrap(numpy_fft.rfftn)
irfftn = da.fft.fft_wrap(numpy_fft.irfftn)

In [None]:
from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data

from nanshe_workflow.par import halo_block_generate_dictionary_parallel
from nanshe_workflow.imp import block_postprocess_data_parallel

In [None]:
from nanshe_workflow.proj2 import compute_traces

from nanshe_workflow.proj2 import compute_adj_harmonic_mean_projection
from nanshe_workflow.proj2 import compute_min_projection
from nanshe_workflow.proj2 import compute_max_projection

from nanshe_workflow.proj2 import compute_moment_projections

from nanshe_workflow.proj2 import norm_layer

# Begin workflow. Set parameters and run each cell.

### Convert TIFF/HDF5 to Zarr

In [None]:
# Supported input extensions (`data_ext` is already lower-cased).
tiff_exts = (tiff_ext, os.path.extsep + "tiff")
h5_exts = (h5_ext, os.path.extsep + "hdf5")

# Fail fast on unsupported input, before deleting any previous output.
# Previously an unknown extension left `a` undefined, so the store step
# failed with NameError after the old Zarr store had already been removed.
if data_ext not in tiff_exts + h5_exts:
    raise ValueError(
        "Unsupported input extension %r for %r; expected one of %r."
        % (data_ext, data, tiff_exts + h5_exts)
    )

dask_io_remove(data_basename + zarr_ext, client)

# Lazily load the raw image stack.
if data_ext in tiff_exts:
    a = dask_imread.imread(data)
else:
    a = dask_load_hdf5(data, dataset)

# Store it under the name `dataset`; the following stages read from this store.
dask_store_zarr(data_basename + zarr_ext, [dataset], [a], client)

del a

In [None]:
import itertools
import hashlib

import dask
import zarr

from builtins import map as imap
from builtins import range as irange


def hash_chunk(dataset, chunk_id, hashname="sha1"):
    """Hash one stored chunk of a Zarr array together with the array metadata.

    Parameters
    ----------
    dataset : zarr.Array
        Array whose stored (encoded) bytes are read from its stores.
    chunk_id : iterable of int
        Chunk-grid index, e.g. ``(0, 1, 0)``; joined with "." to form the key.
    hashname : str
        Any algorithm name accepted by :func:`hashlib.new`.

    Returns
    -------
    bytes
        Digest of the metadata bytes followed by the chunk bytes.
    """
    def _key(name):
        # An array at the store root has ``path == ""``. Joining with "/"
        # would give keys like "/.zarray" that the store does not contain.
        if dataset.path:
            return dataset.path + "/" + name
        return name

    str_i = ".".join(imap(str, chunk_id))

    h = hashlib.new(hashname)

    h.update(dataset.store[_key(zarr.storage.array_meta_key)])
    h.update(dataset.chunk_store[_key(str_i)])

    return h.digest()


def hash_reduce(*hashes, hashname="sha1"):
    """Fold several digests into one by hashing their in-order concatenation.

    Parameters
    ----------
    *hashes : bytes
        Digests to combine; order matters.
    hashname : str
        Any algorithm name accepted by :func:`hashlib.new`.

    Returns
    -------
    bytes
        Digest of ``b"".join(hashes)`` (the empty-input digest if none given).
    """
    joined = b"".join(hashes)
    return hashlib.new(hashname, joined).digest()


def hash_dataset(dataset, hashname="sha1"):
    """Build a lazy, pairwise (Merkle-tree style) hash over every chunk of a Zarr array.

    Parameters
    ----------
    dataset : zarr.Array
        Array to checksum; every chunk-grid position is hashed with
        ``hash_chunk``.
    hashname : str
        Any algorithm name accepted by :func:`hashlib.new`.

    Returns
    -------
    dask.delayed.Delayed
        Computes to the final ``bytes`` digest.
    """
    level = [
        dask.delayed(hash_chunk)(dataset, chunk_id, hashname=hashname)
        for chunk_id in itertools.product(*[irange(s) for s in dataset.cdata_shape])
    ]

    # An array with a zero-length dimension has no chunks at all. This used
    # to raise IndexError on ``d[0]``; hash the empty input instead.
    if not level:
        return dask.delayed(hash_reduce)(hashname=hashname)

    # Combine neighbouring digests level by level until one remains.
    while len(level) > 1:
        # Duplicate the last digest so every entry has a partner.
        if len(level) % 2:
            level.append(level[-1])

        level = [
            dask.delayed(hash_reduce)(*level[k:k + 2], hashname=hashname)
            for k in irange(0, len(level), 2)
        ]

    return level[0]


# Checksum the converted input Zarr store.
fn = data_basename + zarr_ext
f = zarr.open_group(fn, "r")

imgs = f[dataset]

checksum_sha1 = hash_dataset(f[dataset], hashname="sha1").compute()
# NOTE(review): despite its name, this is the MD5 of the SHA-1 tree digest
# above, not an MD5 computed over the data itself.
checksum_md5 = hashlib.md5(checksum_sha1).hexdigest()

In [None]:
# Scratch check: True on Python 3 (bytes and str differ), False on Python 2.
bytes is not str

In [None]:
import binascii

# Scratch: display the SHA-1 tree digest as hex bytes.
binascii.hexlify(checksum_sha1)

In [None]:
# Scratch: show the "__dask_name__" attribute stored on the Zarr array, if any.
imgs.attrs.get("__dask_name__")

In [None]:
# Scratch: show the docstring of dask.array.from_array (interactive aid only).
%pdoc da.from_array

### View Input Data

* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
# Scratch: show the dask graph name of `da_imgs`. On a fresh kernel `da_imgs`
# is only created by the next cell, so referencing it directly raised
# NameError during "Restart & Run All". Fall back to None in that case.
getattr(globals().get("da_imgs"), "name", None)

In [None]:
# Frames per dask chunk used while scanning the data for its display range.
norm_frames = 100

fn = data_basename + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f[dataset]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

# Compute the global min/max on the cluster, showing a text progress bar.
status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

# Interactive frame viewer over the input, scaled to its full intensity range.
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Trimming

* `front` (`int`): amount to trim off the front
* `back` (`int`): amount to trim off the back

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
# Number of frames to drop from the start / end of the recording.
front = 0
back = 0

block_frames = 1
norm_frames = 100


dask_io_remove(data_basename + postfix_trim + zarr_ext, client)


# Load and prep data for computation.
f = zarr.open_group(data_basename + zarr_ext, "r")
imgs = f[dataset]
da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

# Trim frames from front and back
da_imgs_trim = da_imgs[front:len(da_imgs)-back]

# Store trimmed data
dask_store_zarr(data_basename + postfix_trim + zarr_ext, ["images"], [da_imgs_trim], client)


# View results
fn = data_basename + postfix_trim + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Denoising

* `med_filt_size` (`int`): footprint size for median filter
* `norm_filt_sigma` (`int`/`float`): sigma for Gaussian filter

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
med_filt_size = 3
norm_filt_sigma = 10

block_frames = 1
norm_frames = 100


dask_io_remove(data_basename + postfix_dn + zarr_ext, client)


# Load and prep data for computation.
f = zarr.open_group(data_basename + postfix_trim + zarr_ext, "r")
imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

# Work in at least single-precision floating point.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Median filter frames (footprint of 1 along time, i.e. spatial only)
da_imgs_medf = dask_ndfilters.median_filter(
    da_imgs_flt, (1,) + (da_imgs_flt.ndim - 1) * (med_filt_size,)
)

# Compute the Gaussian filter of frames (sigma 0 along time, i.e. spatial only)
da_imgs_smoothed = dask_ndfilters.gaussian_filter(
    da_imgs_medf, (0,) + (da_imgs_medf.ndim - 1) * (norm_filt_sigma,)
)

# Apply high pass filter to images (subtract the smoothed background)
da_imgs_filt = da_imgs_medf - da_imgs_smoothed

# Reset minimum to original value (global minimum of the trimmed input).
da_imgs_filt += da_imgs.min() - da_imgs_filt.min()

# Store denoised data
dask_store_zarr(data_basename + postfix_dn + zarr_ext, ["images"], [da_imgs_filt], client)


# View results
fn = data_basename + postfix_dn + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Registration

In [None]:
def fourier_shift_wrap(array, shift, n=-1):
    """Apply ``scipy.ndimage.fourier_shift`` to every frame of a block.

    Parameters
    ----------
    array : numpy.ndarray
        Frames already in the Fourier domain, stacked along axis 0.
    shift : sequence
        Block list as passed by ``da.atop`` for a contracted axis:
        ``shift[0][i]`` holds the per-axis shift for frame ``i``.
    n : int, optional
        Forwarded to ``fourier_shift``. The default ``-1`` treats the input as
        a complex FFT. For ``rfftn`` output, pass the length of the last axis
        before the transform so real-FFT frequencies are used along it.

    Returns
    -------
    numpy.ndarray
        Shifted frames with the same shape and dtype as ``array``.
    """
    block_shifts = shift[0]
    result = numpy.empty_like(array)
    for i in irange(len(array)):
        result[i] = spim.fourier_shift(array[i], block_shifts[i], n=n)
    return result


def find_best_match(matches):
    """Select the candidate offset nearest the origin.

    Parameters
    ----------
    matches : numpy.ndarray
        Candidate offsets stored column-wise, shape ``(ndim, k)``.

    Returns
    -------
    numpy.ndarray
        The column with the smallest squared Euclidean norm (the first one on
        ties), or zeros of length ``ndim`` and the same dtype when there are no
        candidates.
    """
    if not matches.size:
        return numpy.zeros(matches.shape[:1], dtype=matches.dtype)

    squared_norms = numpy.square(matches).sum(axis=0)
    return matches[:, squared_norms.argmin()]


def compute_offset(match_mask):
    """Turn per-frame correlation-peak masks into signed integer offsets.

    Parameters
    ----------
    match_mask : list or numpy.ndarray
        As passed by ``da.atop`` in the registration cell: the block of frames
        wrapped in one single-element list per contracted (spatial) axis.
        Each frame is True where the correlation equals its maximum.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_frames, ndim - 1)``. Row ``i`` is a peak
        position of frame ``i`` wrapped into the signed range (indices above
        half the frame size become negative). When several peaks tie, the one
        nearest the origin is chosen.
    """
    # Unwrap however many list levels ``da.atop`` added. The previous
    # ``match_mask[0][0]`` hard-coded two levels, i.e. 2-D frames only.
    while isinstance(match_mask, (list, tuple)):
        match_mask = match_mask[0]

    # All frames in the block share one shape, so compute these once.
    frame_shape = numpy.array(match_mask.shape[1:])
    half_frame_shape = frame_shape // 2

    result = numpy.empty((len(match_mask), match_mask.ndim - 1), dtype=int)
    for i in irange(len(match_mask)):
        matches = numpy.array(match_mask[i].nonzero())
        # Map circular (FFT-style) indices past the midpoint to negative offsets.
        above = (matches > half_frame_shape[:, None]).astype(matches.dtype)
        matches -= above * frame_shape[:, None]

        result[i] = find_best_match(matches)

    return result

In [None]:
num_reps = 5            # maximum number of template-refinement iterations
tmpl_hist_wght = 0.25   # weight of the previous template when updating it
thld_rel_dist = 0.0     # stop once the average relative shift change is <= this

block_frames = 1
norm_frames = 100


dask_io_remove(data_basename + postfix_reg + zarr_ext, client)


# Load and prep data for computation.
f = zarr.open_group(data_basename + postfix_dn + zarr_ext, "r")
imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

# Work in at least single-precision floating point.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Create frame shape arrays
frame_shape = np.array(da_imgs_flt.shape[1:], dtype=int)
half_frame_shape = frame_shape // 2
frame_shape = da.asarray(frame_shape)
half_frame_shape = da.asarray(half_frame_shape)

# Compute the FFT of frames and template
da_imgs_fft = rfftn(da_imgs_flt, axes=tuple(irange(1, imgs.ndim)))
da_imgs_fft_tmplt = da_imgs_fft.mean(axis=0, keepdims=True)

# Initialize
i = 0
avg_rel_dist = 1.0
tmpl_hist_wght = da_imgs_flt.dtype.type(tmpl_hist_wght)
shifts = da.zeros(
    (len(da_imgs_flt), da_imgs_flt.ndim - 1),
    dtype=int,
    chunks=(1, da_imgs_flt.ndim - 1)
)

# Persist FFT of frames and template
da_imgs_fft, da_imgs_fft_tmplt = client.persist([da_imgs_fft, da_imgs_fft_tmplt])

while avg_rel_dist > thld_rel_dist and i < num_reps:
    # Compute the shifted frames
    # NOTE(review): `da_imgs_fft` is `rfftn` output, but `fourier_shift_wrap`
    # is called without scipy's `n` argument (complex-FFT frequencies), so the
    # shift along the last axis looks incorrect -- confirm.
    shifted_frames = da.atop(
        fourier_shift_wrap,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        da_imgs_fft,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        shifts,
        (0, da_imgs_fft.ndim),
        dtype=da_imgs_fft.dtype
    )

    # Compute the template FFT (blend of the previous template and the mean
    # of the currently shifted frames)
    da_imgs_fft_tmplt = (
        tmpl_hist_wght * da_imgs_fft_tmplt +
        (1 - tmpl_hist_wght) * shifted_frames.mean(axis=0, keepdims=True)
    )

    # Free connected persisted values
    del shifted_frames

    # Find the best overlap with the template.
    # NOTE(review): cross-correlation is usually F * conj(T); the plain product
    # below is a circular convolution. Any constant offset this introduces is
    # removed by the drift correction below -- confirm intent.
    overlap = irfftn(
        da_imgs_fft * da_imgs_fft_tmplt,
        s=da_imgs_flt.shape[1:],
        axes=tuple(irange(1, imgs.ndim))
    )
    overlap_max = overlap.max(axis=tuple(irange(1, imgs.ndim)), keepdims=True)
    overlap_max_match = (overlap == overlap_max)

    # Compute the shift for each frame (spatial axes merged into one chunk).
    old_shifts = shifts
    shifts = da.atop(
        compute_offset,
        (0, overlap_max_match.ndim),
        overlap_max_match.rechunk(dict(enumerate(overlap_max_match.shape[1:], 1))),
        tuple(irange(0, overlap_max_match.ndim)),
        dtype=int,
        new_axes={overlap_max_match.ndim: overlap_max_match.ndim - 1}
    )

    # Free connected persisted values
    del overlap
    del overlap_max
    del overlap_max_match

    # Remove any collective frame drift.
    drift = shifts.mean(axis=0, keepdims=True).round().astype(shifts.dtype)
    shifts = shifts - drift

    # Free connected persisted values
    del drift

    # Find shift change (normalized by frame size and sqrt(ndim)).
    diff_shifts = shifts - old_shifts
    rel_diff_shifts = (
        diff_shifts.astype(da_imgs_flt.dtype) /
        frame_shape.astype(da_imgs_flt.dtype) /
        (da_imgs_flt.dtype.type(len(frame_shape)) ** 0.5)
    )
    rel_dist_shifts = (rel_diff_shifts ** 2.0).sum(axis=1) ** 0.5
    avg_rel_dist = rel_dist_shifts.sum() / da_imgs_flt.dtype.type(len(shifts))

    # Free old shifts
    del old_shifts

    # Free connected persisted values
    del diff_shifts
    del rel_diff_shifts
    del rel_dist_shifts

    # Persist values needed for the next iteration (and end of this one).
    da_imgs_fft_tmplt, shifts, avg_rel_dist = client.persist([da_imgs_fft_tmplt, shifts, avg_rel_dist])

    # Compute change
    status = client.compute(avg_rel_dist)
    dask.distributed.progress(status, notebook=False)
    print("")
    avg_rel_dist = status.result()
    i += 1

    # Show change
    print((i, avg_rel_dist))

# Drop unneeded items
del frame_shape
del half_frame_shape
del da_imgs_flt
del da_imgs_fft
del da_imgs_fft_tmplt

# Gather all shifts in one computation. Indexing the dask array per frame
# (`numpy.array(shifts[i])`) previously launched one computation per frame.
shifts_arr = shifts.compute()

# Truncate shifted part of each frame
da_imgs_trunc = []
da_imgs_trunc_shape = da_imgs.shape[1:]
for i in irange(len(da_imgs)):
    slice_i = [i]
    shift_i = shifts_arr[i]
    for j in irange(len(shift_i)):
        shifts_ij = shift_i[j]
        if shifts_ij < 0:
            slice_i.append(slice(-shifts_ij, None))
        elif shifts_ij > 0:
            slice_i.append(slice(None, -shifts_ij))
        else:
            slice_i.append(slice(None))
    slice_i = tuple(slice_i)
    da_imgs_trunc.append(da_imgs[slice_i])
    da_imgs_trunc_shape = tuple(np.minimum(
        da_imgs_trunc_shape, da_imgs_trunc[-1].shape
    ))

# Free raw data
del da_imgs

# Truncate all frames to smallest one
da_imgs_trunc_cut = tuple(map(
    lambda s: slice(None, s), da_imgs_trunc_shape
))
for i in irange(len(da_imgs_trunc)):
    da_imgs_trunc[i] = da_imgs_trunc[i][da_imgs_trunc_cut]
da_imgs_trunc = da.stack(da_imgs_trunc)

# Store registered data
dask_store_zarr(
    data_basename + postfix_reg + zarr_ext,
    ["images", "shifts"],
    [da_imgs_trunc, shifts],
    client
)

# Free truncated frames and shifts
del da_imgs_trunc
del shifts
del shifts_arr


# View results
fn = data_basename + postfix_reg + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
shifts = f["shifts"]

da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

# One stacked subplot per shift axis. squeeze=False keeps `axs` indexable even
# when there is only a single shift axis.
fig, axs = plt.subplots(nrows=shifts.shape[1], sharex=True, squeeze=False)
axs = axs[:, 0]
fig.subplots_adjust(hspace=0.0)
for i in range(shifts.shape[1]):
    axs[i].plot(shifts[:, i][...])
    axs[i].set_ylabel("%s (px)" % chr(ord("X") + shifts.shape[1] - i - 1))
    axs[i].yaxis.set_tick_params(width=1.5)
    for spine in axs[i].spines.values():
        spine.set_linewidth(2)
axs[-1].set_xlabel("Frame (#)")
axs[-1].set_xlim((0, shifts.shape[0] - 1))
axs[-1].xaxis.set_tick_params(width=1.5)

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Projections

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).

In [None]:
block_frames = 100


dask_io_remove(data_basename + postfix_proj + zarr_ext, client)


# Load and prep data for computation.
f = zarr.open_group(data_basename + postfix_reg + zarr_ext, "r")
imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

# Work in at least single-precision floating point.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

da_imgs_proj_hmean = compute_adj_harmonic_mean_projection(da_imgs_flt)

da_imgs_proj_max = compute_max_projection(da_imgs_flt)

# Standard deviation from the first two moments: sqrt(E[x^2] - E[x]^2).
# NOTE(review): assumes compute_moment_projections(..., 3) returns the raw
# moments of order 0, 1, 2 -- confirm against nanshe_workflow.proj2.
da_imgs_proj_mean, da_imgs_proj_sq_mean = compute_moment_projections(da_imgs_flt, 3)[1:]
# Round-off can make E[x^2] - E[x]^2 slightly negative for near-constant
# pixels, which sqrt turns into NaN; clamp at zero first.
da_imgs_proj_std = da.sqrt(
    da.maximum(da_imgs_proj_sq_mean - da_imgs_proj_mean**2, 0)
)

# Store projections
dask_store_zarr(
    data_basename + postfix_proj + zarr_ext,
    ["hmean", "max", "mean", "std"],
    [da_imgs_proj_hmean, da_imgs_proj_max, da_imgs_proj_mean, da_imgs_proj_std],
    client
)

### Subtract Projection

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
block_frames = 100
norm_frames = 100



dask_io_remove(data_basename + postfix_sub + zarr_ext, client)


# Load and prep data for computation.
f = zarr.open_group(data_basename + postfix_reg + zarr_ext, "r")
imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

# Work in at least single-precision floating point.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Subtract the adjusted harmonic-mean projection from every frame, then shift
# so the global minimum is zero.
da_imgs_sub = da_imgs_flt - compute_adj_harmonic_mean_projection(da_imgs_flt)
da_imgs_sub -= da_imgs_sub.min()

# Store projection-subtracted data
dask_store_zarr(data_basename + postfix_sub + zarr_ext, ["images"], [da_imgs_sub], client)


# View results
fn = data_basename + postfix_sub + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Background Subtraction

* `half_window_size` (`int`): the rank filter window size is `2*half_window_size+1`.
* `which_quantile` (`float`): which quantile to return from the rank filter.
* `temporal_smoothing_gaussian_filter_stdev` (`float`): stdev for gaussian filter to convolve over time.
* `temporal_smoothing_gaussian_filter_window_size` (`float`): window for gaussian filter to convolve over time. (Measured in standard deviations)
* `spatial_smoothing_gaussian_filter_stdev` (`float`): stdev for gaussian filter to convolve over space.
* `spatial_smoothing_gaussian_filter_window_size` (`float`): window for gaussian filter to convolve over space. (Measured in standard deviations)

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
half_window_size = 100
which_quantile = 0.5
temporal_smoothing_gaussian_filter_stdev = 0.0
temporal_smoothing_gaussian_filter_window_size = 0
spatial_smoothing_gaussian_filter_stdev = 0.0
spatial_smoothing_gaussian_filter_window_size = 0

block_frames = 1000
block_space = 100
norm_frames = 100



dask_io_remove(data_basename + postfix_f_f0 + zarr_ext, client)


# Load and prep data for computation (chunked in time AND space).
f = zarr.open_group(data_basename + postfix_sub + zarr_ext, "r")
imgs = f["images"]
da_imgs = da.from_array(
    imgs, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
)

# Work in at least single-precision floating point.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Offset equal to 1 - global minimum. NOTE(review): presumably extract_f0 adds
# this so the data's minimum becomes 1 (avoiding division by ~0) -- confirm.
bias = 1 - da_imgs_flt.min()

da_result = extract_f0(
    da_imgs_flt,
    half_window_size=half_window_size,
    which_quantile=which_quantile,
    temporal_smoothing_gaussian_filter_stdev=temporal_smoothing_gaussian_filter_stdev,
    temporal_smoothing_gaussian_filter_window_size=temporal_smoothing_gaussian_filter_window_size,
    spatial_smoothing_gaussian_filter_stdev=spatial_smoothing_gaussian_filter_stdev,
    spatial_smoothing_gaussian_filter_window_size=spatial_smoothing_gaussian_filter_window_size,
    bias=bias
)

# Store the extract_f0 result (the `_f_f0` stage)
dask_store_zarr(data_basename + postfix_f_f0 + zarr_ext, ["images"], [da_result], client)


# View results
fn = data_basename + postfix_f_f0 + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Wavelet Transform

* `scale` (`int`): the scale of wavelet transform to apply.

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
scale = 3

block_frames = 200
block_space = 300
norm_frames = 100



dask_io_remove(data_basename + postfix_wt + zarr_ext, client)


# Load and prep data for computation (chunked in time AND space).
f = zarr.open_group(data_basename + postfix_f_f0 + zarr_ext, "r")
imgs = f["images"]
da_imgs = da.from_array(
    imgs, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
)

# Work in at least single-precision floating point.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Transform the float data, as every other stage does. This previously passed
# `da_imgs`, so the conversion above was computed and then ignored.
da_result = wavelet_transform(
    da_imgs_flt,
    scale=scale
)

# Store wavelet-transformed data
dask_store_zarr(data_basename + postfix_wt + zarr_ext, ["images"], [da_result], client)


# View results
fn = data_basename + postfix_wt + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Normalize Data

* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
block_frames = 40
block_space = 300
norm_frames = 100



dask_io_remove(data_basename + postfix_norm + zarr_ext, client)


# Load and prep data for computation (chunked in time AND space).
f = zarr.open_group(data_basename + postfix_wt + zarr_ext, "r")
imgs = f["images"]
da_imgs = da.from_array(
    imgs, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
)

# Work in at least single-precision floating point.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Per-frame minimum (reduced over all non-time axes).
da_imgs_flt_mins = da_imgs_flt.min(
    axis=tuple(irange(1, da_imgs_flt.ndim)),
    keepdims=True
)

# Shift every frame so its own minimum is zero before renormalizing.
da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins

da_result = renormalized_images(da_imgs_flt_shift)

# Store normalized data
dask_store_zarr(data_basename + postfix_norm + zarr_ext, ["images"], [da_result], client)


# View results
fn = data_basename + postfix_norm + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Dictionary Learning

* `n_components` (`int`): number of basis images in the dictionary.
* `batchsize` (`int`): minibatch size to use.
* `iters` (`int`): number of iterations to run before getting dictionary.
* `lambda1` (`float`): weight for L<sup>1</sup> sparsity enforcement on the sparse code.
* `lambda2` (`float`): weight for L<sup>2</sup> regularization on the sparse code (currently not passed to the dictionary-learning call).

<br>
* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
n_components = 50
batchsize = 256
iters = 100
lambda1 = 0.2
# NOTE(review): `lambda2` is never passed to the call below (only `alpha` is
# set from `lambda1`), so changing it currently has no effect.
lambda2 = 0.0

block_frames = 51
norm_frames = 100



dask_io_remove(data_basename + postfix_dict + zarr_ext, client)


# Read the normalized movie, chunked into full frames of `block_frames` each.
f1 = zarr.open_group(data_basename + postfix_norm + zarr_ext, "r")
imgs = f1["images"]
block_shape = (block_frames,) + imgs.shape[1:]
da_imgs = da.from_array(imgs, chunks=block_shape)
with open_zarr(data_basename + postfix_dict + zarr_ext, "w") as f2:
    # Output array: one basis image per dictionary component.
    new_result = f2.create_dataset("images", shape=(n_components,) + da_imgs.shape[1:], dtype=da_imgs.dtype, chunks=True)

    # Options are passed through to sklearn's dict_learning_online. The
    # dictionary is written into `out=new_result`; the returned `result` is
    # not used afterwards.
    result = halo_block_generate_dictionary_parallel(client, None)(generate_dictionary)(block_shape)(
        da_imgs,
        n_components=n_components,
        out=new_result,
        **{"sklearn.decomposition.dict_learning_online" : {
                "n_jobs" : 1,
                "n_iter" : iters,
                "batch_size" : batchsize,
                "alpha" : lambda1
            }
        }
    )


# View results
fn = data_basename + postfix_dict + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f = zarr.open_group(fn, "r")

imgs = f["images"]
da_imgs = da.from_array(imgs, chunks=(norm_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

### Postprocessing

* `significance_threshold` (`float`): number of standard deviations below which to include in "noise" estimate
* `wavelet_scale` (`int`): scale of wavelet transform to apply (should be the same as the one used above)
* `noise_threshold` (`float`): number of units of "noise" above which something needs to be to be significant
* `accepted_region_shape_constraints` (`dict`): if ROIs don't match this, reduce the `wavelet_scale` once.
* `percentage_pixels_below_max` (`float`): upper bound on ratio of ROI pixels not at max intensity vs. all ROI pixels
* `min_local_max_distance` (`float`): minimum allowable Euclidean distance between the maximum-intensity points of two ROIs
* `accepted_neuron_shape_constraints` (`dict`): shape constraints for ROI to be kept.

* `alignment_min_threshold` (`float`): similarity measure of the intensity of two ROIs images used for merging.
* `overlap_min_threshold` (`float`): similarity measure of the masks of two ROIs used for merging.

In [None]:
significance_threshold = 3.0
wavelet_scale = 3
noise_threshold = 3.0
percentage_pixels_below_max = 0.8
min_local_max_distance = 16.0

# Shape constraints listed in the parameter notes above; previously these
# were hard-coded inside the call and could not be tuned from this cell.
accepted_region_shape_constraints = {
    "major_axis_length" : {
        "min" : 0.0,
        "max" : 25.0
    }
}
accepted_neuron_shape_constraints = {
    "area" : {
        "min" : 25,
        "max" : 600
    },
    "eccentricity" : {
        "min" : 0.0,
        "max" : 0.9
    }
}

alignment_min_threshold = 0.6
overlap_min_threshold = 0.6
# Passed to nanshe's "fuse_neurons" step when merging ROIs.
fraction_mean_neuron_max_threshold = 0.01



dask_io_remove(data_basename + postfix_post + zarr_ext, client)


# Load the learned dictionary, one basis image per chunk.
f1 = zarr.open_group(data_basename + postfix_dict + zarr_ext, "r")
imgs = f1["images"]
da_imgs = da.from_array(imgs, chunks=((1,) + imgs.shape[1:]))

with open_zarr(data_basename + postfix_post + zarr_ext, "w") as f2:
    result = block_postprocess_data_parallel(client)(
        da_imgs,
        **{
            "wavelet_denoising" : {
                "estimate_noise" : {
                    "significance_threshold" : significance_threshold
                },
                "wavelet.transform" : {
                    "scale" : wavelet_scale
                },
                "significant_mask" : {
                    "noise_threshold" : noise_threshold
                },
                "accepted_region_shape_constraints" : accepted_region_shape_constraints,
                "remove_low_intensity_local_maxima" : {
                    "percentage_pixels_below_max" : percentage_pixels_below_max
                },
                "remove_too_close_local_maxima" : {
                    "min_local_max_distance" : min_local_max_distance
                },
                "accepted_neuron_shape_constraints" : accepted_neuron_shape_constraints
            },
            "merge_neuron_sets" : {
                "alignment_min_threshold" : alignment_min_threshold,
                "overlap_min_threshold" : overlap_min_threshold,
                "fuse_neurons" : {
                    "fraction_mean_neuron_max_threshold" : fraction_mean_neuron_max_threshold
                }
            }
        }
    )

    # Store each field of the structured result as its own dataset under
    # "rois/" (e.g. "rois/mask", which the next stage reads).
    rois_grp = f2.create_group("rois")
    for each_name in result.dtype.names:
        rois_grp.create_dataset(
            each_name,
            data=result[each_name],
            chunks=True
        )

### ROI and trace extraction

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).

In [None]:
block_frames = 100



dask_io_remove(data_basename + postfix_rois + zarr_ext, client)
dask_io_remove(data_basename + postfix_rois + h5_ext, client)


with open_zarr(data_basename + postfix_post + zarr_ext, "r") as f1:
    # ROI masks stacked along axis 0 (one mask per ROI).
    roi_masks = f1["rois/mask"]

    da_roi_masks = da.from_array(
        roi_masks, chunks=(block_frames,) + roi_masks.shape[1:]
    )

    # Label image: ROI k (1-based) contributes k on its mask; taking the max
    # over ROIs means overlapping pixels keep the largest label.
    da_lbls = da.arange(
        1,
        len(da_roi_masks) + 1,
        chunks=da_roi_masks.chunks[0],
        dtype=np.uint64
    )
    da_lblimg = (
        da_lbls[(slice(None),) + (da_roi_masks.ndim - 1) * (None,)] * 
        da_roi_masks.astype(np.uint64)
    ).max(axis=0)

    # NOTE(review): "labels_j" is cast to uint8, so labels above 255 wrap
    # around when there are more than 255 ROIs.
    dask_store_zarr(
        data_basename + postfix_rois + zarr_ext,
        ["masks", "masks_j", "labels", "labels_j"],
        [da_roi_masks, da_roi_masks.astype(numpy.uint8), da_lblimg, da_lblimg.astype(numpy.uint8)],
        client
    )


# Mirror the ROI Zarr store into an HDF5 file.
with h5py.File(data_basename + postfix_rois + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_rois + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)



dask_io_remove(data_basename + postfix_traces + zarr_ext, client)
dask_io_remove(data_basename + postfix_traces + h5_ext, client)


with open_zarr(data_basename + postfix_f_f0 + zarr_ext, "r") as fh_f_f0:
    with open_zarr(data_basename + postfix_rois + zarr_ext, "r") as fh_rois:
        # Load and prep data for computation.
        images = fh_f_f0["images"]
        da_images = da.from_array(
            images, chunks=(block_frames,) + images.shape[1:]
        )
        masks = fh_rois["masks"]
        da_masks = da.from_array(
            masks, chunks=(block_frames,) + masks.shape[1:]
        )

        # Per-ROI traces over the F/F0 movie.
        da_result = compute_traces(da_images, da_masks)

        # Store traces
        dask_store_zarr(data_basename + postfix_traces + zarr_ext, ["traces"], [da_result], client)


# Mirror the traces Zarr store into an HDF5 file.
with h5py.File(data_basename + postfix_traces + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_traces + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


# View results
fn = data_basename + postfix_f_f0 + zarr_ext

# Placeholder display range; replaced by the computed min/max below.
imgs_min, imgs_max = 0, 100

f1 = zarr.open_group(data_basename + postfix_f_f0 + zarr_ext, "r")

imgs = f1["images"]
da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max], lock=False, compute=False)
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

# Overlay the ROI label image, with background (label 0) masked out.
f2 = zarr.open_group(data_basename + postfix_rois + zarr_ext, "r")
lblimg = f2["labels"][...]
lblimg_msk = numpy.ma.masked_array(lblimg, mask=(lblimg==0))

mplsv.viewer.matshow(lblimg_msk, alpha=0.3, cmap=mpl.cm.jet)


# Bind each name first so the `del` statements below cannot raise NameError,
# then drop them from the namespace.
mskimg = None
mskimg_j = None
lblimg = None
traces = None
traces_j = None

del mskimg
del mskimg_j
del lblimg
del traces
del traces_j

# End of workflow. Shut down the cluster.

In [None]:
# Shut down the dask.distributed cluster used by the workflow.
# `default_client()` returns the client currently registered as the default,
# presumably the `client` created by `startup_distributed` at the top.
import distributed
from nanshe_workflow.par import shutdown_distributed

shutdown_distributed(distributed.client.default_client())

# Prepare interactive projection graph

In [None]:
import io
import textwrap
import zlib

import numpy
import numpy as np

import bokeh.plotting
import bokeh.plotting as bp

import bokeh.io
import bokeh.io as bio

import bokeh.embed
import bokeh.embed as be

from bokeh.models.mappers import LinearColorMapper

import webcolors

from bokeh.models import CustomJS, ColumnDataSource, HoverTool
from bokeh.models.layouts import Row

from builtins import (
    map as imap,
    range as irange
)

from past.builtins import basestring

import nanshe_workflow
from nanshe_workflow.data import io_remove, open_zarr
from nanshe_workflow.vis import get_rgb_array, get_rgba_array, get_all_greys, masks_to_contours_2d

In [None]:
# Load the ROI masks, their traces and the image projections fully into
# memory (`[...]` reads the whole dataset) for the Bokeh plots below.
with open_zarr(data_basename + postfix_rois + zarr_ext, "r") as f:
    mskimg = f["masks"][...]

with open_zarr(data_basename + postfix_traces + zarr_ext, "r") as f:
    traces = f["traces"][...]

# Mean, max and standard-deviation projections of the image stack.
with open_zarr(data_basename + postfix_proj + zarr_ext, "r") as f:
    imgproj_mean = f["mean"][...]
    imgproj_max = f["max"][...]
    imgproj_std = f["std"][...]

### Result visualization

* `proj_img` (`str` or `list` of `str`): which projection(s) to plot; any of "max", "mean", "std".
* `block_size` (`int`): number of screen pixels used to draw each image pixel along each axis.
* `roi_alpha` (`float`): opacity of the ROIs, from 0.0 (fully transparent) to 1.0 (fully opaque).
* `roi_border_width` (`int`): width of the line border on each ROI.

<br>
* `trace_plot_width` (`int`): width of the trace plot.

In [None]:
# Display parameters (described in the markdown cell above).
proj_img = "std"
block_size = 1
roi_alpha = 0.3
roi_border_width = 3
trace_plot_width = 500


# Clear the current Bokeh document (e.g. models left over from a previous
# run of this cell).
bio.curdoc().clear()

# Grey-scale color mapper for the projection images.
grey_range = get_all_greys()
grey_cm = LinearColorMapper(grey_range)

# One color per ROI, converted to hex strings for Bokeh.
colors_rgb = get_rgb_array(len(mskimg))
colors_rgb = colors_rgb.tolist()
colors_rgb = list(imap(webcolors.rgb_to_hex, colors_rgb))

# Contour outline of each ROI mask, as per-ROI lists of y and x points.
mskctr_pts_y, mskctr_pts_x = masks_to_contours_2d(mskimg)

# Store contour coordinates in the smallest unsigned dtype able to hold the
# largest image index (presumably to shrink the data embedded in the page).
# NOTE(review): this assumes contour coordinates never exceed `shape - 1`;
# if `masks_to_contours_2d` can emit edge coordinates equal to the image
# size, values could wrap (e.g. 256 in uint8) — confirm.
mskctr_pts_dtype = np.min_scalar_type(max(mskimg.shape[1:]) - 1)
mskctr_pts_y = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_y]
mskctr_pts_x = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_x]

# Data source for the ROI patches; selecting patches in it triggers the
# trace-plot callback attached further down in this cell.
mskctr_srcs = ColumnDataSource(data=dict(x=mskctr_pts_x, y=mskctr_pts_y, color=colors_rgb))


# `proj_img` may be a single projection name or any iterable of names;
# from here on it is always a list of names.
proj_img = [proj_img] if isinstance(proj_img, basestring) else list(proj_img)


# Figure size in screen pixels: `block_size` screen pixels per image pixel.
proj_plot_width = block_size*mskimg.shape[2]
proj_plot_height = block_size*mskimg.shape[1]
plot_projs = []

# Report misspelled projection names instead of silently plotting nothing
# for them.
_known_projs = ("max", "mean", "std")
_unknown_projs = [_p for _p in proj_img if _p not in _known_projs]
if _unknown_projs:
    raise ValueError(
        "Unknown projection(s) in `proj_img`: %s (expected any of %s)."
        % (", ".join(_unknown_projs), ", ".join(_known_projs))
    )


def _make_proj_plot(title, imgproj):
    """Create a projection figure with the ROI contours drawn on top.

    Parameters
    ----------
    title : str
        Figure title.
    imgproj : 2-D array
        Projection image with the same height/width as the ROI masks.

    Returns
    -------
    The configured Bokeh figure. It shares `mskctr_srcs`, so selecting an
    ROI in any projection figure selects it in all of them.
    """
    img_h, img_w = mskimg.shape[1], mskimg.shape[2]

    plot = bp.Figure(plot_width=proj_plot_width, plot_height=proj_plot_height,
                     x_range=[0, img_w], y_range=[img_h, 0],
                     tools=["tap", "pan", "box_zoom", "wheel_zoom", "save", "reset"],
                     title=title, border_fill_color="black")
    # The y-axis is inverted (row 0 at the top), so the image is flipped
    # vertically and anchored at y = height to line up with the contours.
    plot.image(image=[numpy.flipud(imgproj)], x=[0], y=[img_h],
               dw=[img_w], dh=[img_h], color_mapper=grey_cm)
    plot.patches('x', 'y', source=mskctr_srcs, alpha=roi_alpha,
                 line_width=roi_border_width, color="color")

    plot.outline_line_color = "white"
    for plot_axis in plot.axis:
        plot_axis.axis_line_color = "white"

    return plot


# Figures are laid out in the fixed order max, mean, std, regardless of the
# order given in `proj_img`.
if "max" in proj_img:
    plot_max = _make_proj_plot("Max Projection with ROIs", imgproj_max)
    plot_projs.append(plot_max)

if "mean" in proj_img:
    plot_mean = _make_proj_plot("Mean Projection with ROIs", imgproj_mean)
    plot_projs.append(plot_mean)

if "std" in proj_img:
    plot_std = _make_proj_plot("Std Dev Projection with ROIs", imgproj_std)
    plot_projs.append(plot_std)


# Trace data shipped to the browser for the selection callback below:
# - `traces_dtype`: a one-element array (value 0) of the traces' dtype. The
#   callback uses its `.constructor` to rebuild the traces in the same
#   element type, and the value 0 as a "not decoded yet" flag. Assumes Bokeh
#   delivers it to JS as a typed array — confirm with the Bokeh version used.
# - `traces_shape`: shape of `traces`; element 1 is used as the trace length.
# - `traces`: the raw C-order bytes of `traces`, zlib-compressed and passed
#   as uint8; inflated client-side (pako) the first time it is needed.
all_tr_dtype_srcs = ColumnDataSource(data=dict(traces_dtype=traces.dtype.type(0)[None]))
all_tr_shape_srcs = ColumnDataSource(data=dict(traces_shape=traces.shape))
all_tr_srcs = ColumnDataSource(data=dict(
    traces=numpy.frombuffer(
        zlib.compress(traces.tobytes()),
        dtype=np.uint8
    )
))
# Traces of the currently selected ROIs; filled in by the callback.
tr_srcs = ColumnDataSource(data=dict(times_sel=[], traces_sel=[], colors_sel=[]))
# x spans the frame indices, y the global min/max over all traces.
# NOTE(review): `traces.min()` raises on an empty array, i.e. when there are
# no ROIs — confirm at least one ROI is always present.
plot_tr = bp.Figure(plot_width=trace_plot_width, plot_height=proj_plot_height,
                    x_range=(0.0, float(traces.shape[1])), y_range=(float(traces.min()), float(traces.max())),
                    tools=["pan", "box_zoom", "wheel_zoom", "save", "reset"], title="ROI traces",
                    background_fill_color="black", border_fill_color="black")
plot_tr.multi_line("times_sel", "traces_sel", source=tr_srcs, color="colors_sel")

plot_tr.outline_line_color = "white"
for i in irange(len(plot_tr.axis)):
    plot_tr.axis[i].axis_line_color = "white"

plot_projs.append(plot_tr)


# Selection callback for the ROI patches. When the selection changes, the
# traces and colors of the selected ROIs are copied into `tr_srcs`, which
# the trace plot draws. On first use the zlib-compressed traces are inflated
# with pako (loaded in the page head), reinterpreted in their original dtype,
# and cached back into `all_tr_srcs`.
# The loop counter is declared with `var`; previously it was an implicit
# global (`window.i`), which leaks state and throws in strict-mode JS.
mskctr_srcs.callback = CustomJS(
    args=dict(
        all_tr_dtype_srcs=all_tr_dtype_srcs,
        all_tr_shape_srcs=all_tr_shape_srcs,
        all_tr_srcs=all_tr_srcs,
        tr_srcs=tr_srcs
    ), code="""
    var range = function(n){ return Array.from(Array(n).keys()); };

    var traces_not_decoded = (all_tr_dtype_srcs.get('data')['traces_dtype'] == 0);
    var traces_dtype = all_tr_dtype_srcs.get('data')['traces_dtype'].constructor;
    var traces_shape = all_tr_shape_srcs.get('data')['traces_shape'];
    var trace_len = traces_shape[1];
    var traces = all_tr_srcs.get('data')['traces'];
    if (traces_not_decoded) {
        traces = window.pako.inflate(traces);
        traces = new traces_dtype(traces.buffer);
        all_tr_srcs.get('data')['traces'] = traces;
        all_tr_dtype_srcs.get('data')['traces_dtype'] = 1;
    }

    var inds = cb_obj.get('selected')['1d'].indices;
    var colors = cb_obj.get('data')['color'];
    var selected = tr_srcs.get('data');

    var times = range(trace_len);

    selected['times_sel'] = [];
    selected['traces_sel'] = [];
    selected['colors_sel'] = [];

    for (var i = 0; i < inds.length; i++) {
        var inds_i = inds[i];
        var trace_i = traces.slice(trace_len*inds_i, trace_len*(inds_i+1));
        var color_i = colors[inds_i];

        selected['times_sel'].push(times);
        selected['traces_sel'].push(trace_i);
        selected['colors_sel'].push(color_i);
    }

    tr_srcs.trigger('change');
""")


# Lay out the projection figures and the trace figure side by side.
plot_group = Row(*plot_projs)


# Clear out the old HTML file before writing a new one.
io_remove(data_basename + postfix_html + html_ext)


def indent(text, spaces):
    """Prefix every line of ``text`` with ``spaces`` space characters.

    ``spaces`` is converted with ``int``. Lines are split with
    ``str.splitlines`` and rejoined with single newline characters. Blank
    lines also get the prefix, and a trailing newline is not kept.
    """
    pad = " " * int(spaces)
    return "\n".join(pad + line for line in text.splitlines())

def write_html(filename, title, div, script, cdn):
    """Write a standalone HTML page that embeds Bokeh components.

    Parameters
    ----------
    filename : str
        Path of the HTML file to (over)write.
    title : str
        Page title. It is XML-escaped, so characters such as ``&`` or ``<``
        cannot break the markup.
    div : str
        Markup placed in ``<body>`` (e.g. the div from ``bokeh.embed.components``).
    script : str
        Script markup placed in ``<body>`` after ``div``.
    cdn : str
        Markup placed in ``<head>`` (stylesheet/script tags for dependencies).

    The page declares ``<meta charset="utf-8">``, so it is written as UTF-8.
    Previously it was written in the platform's default encoding, which
    could mis-encode or reject non-ASCII text.
    """
    from xml.sax.saxutils import escape as xml_escape

    html_tmplt = textwrap.dedent(u"""\
        <html lang="en">
            <head>
                <meta charset="utf-8">
                <title>{title}</title>
                {cdn}
                <style>
                  html {{
                    width: 100%;
                    height: 100%;
                  }}
                  body {{
                    width: 90%;
                    height: 100%;
                    margin: auto;
                    background-color: black;
                  }}
                </style>
            </head>
            <body>
                {div}
                {script}
            </body>
        </html>
    """)

    html_cont = html_tmplt.format(
        title=xml_escape(title),
        div=indent(div, 8),
        script=indent(script, 8),
        cdn=indent(cdn, 8),
    )

    # Match the encoding declared by the <meta charset> tag above.
    with io.open(filename, "w", encoding="utf-8") as fh:
        fh.write(html_cont)

# `bokeh.resources` is used below. Import it explicitly rather than relying
# on it having been loaded as a side effect of the other bokeh imports.
import bokeh.resources

# Render the layout as embeddable <script>/<div> pieces. Then collect the
# <head> markup: Bokeh's CDN resources plus pako, which the ROI selection
# callback uses (`window.pako.inflate`) to decompress the traces.
script, div = be.components(plot_group)
cdn = bokeh.resources.CDN.render() + "\n"
cdn += """
<script type="text/javascript" src="https://cdn.jsdelivr.net/pako/1.0.4/pako_inflate.min.js"></script>
"""
cdn += "\n"

write_html(data_basename + postfix_html + html_ext, data_basename + postfix_html, div, script, cdn)


# Show the generated page inline. The frame is padded slightly beyond the
# plot height.
from IPython.display import display, IFrame
display(IFrame(data_basename + postfix_html + html_ext, "100%", 1.05*proj_plot_height))

In [None]:
# Test teardown. Ignore warnings during production runs.
# Counterpart of the `setup_tests.py` script run at the top of the notebook.

%run ./teardown_tests.py