In [None]:
# Test setup. Ignore warnings during production runs.
# Executes `setup_tests.py` via IPython's `%run` before anything else in the notebook.

%run ./setup_tests.py

# Specify input data

* `data_dir` (`str`): Directory containing the data. An empty string means the current directory, which is the usual case.
* `data` (`str`): TIFF or HDF5 file to use as input data.
* `data_basename` (`str`): Basename to use for intermediate and final result files.
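* `dataset` (`str`): HDF5 dataset to use as input data. It is not used to read TIFF input, but it is always used as the dataset name in the converted Zarr file.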

<br>
* `num_workers` (`int`): Number of workers for the IPython cluster (defaults to all cores except one, which is reserved for the client).

In [None]:
# User-tunable inputs (described in the parameter list above).
data_dir = ""
data = "data.tif"
data_basename = "data"
dataset = "images"

num_workers = None


import os

# Lower-cased extension of the input file; compared against `tiff_ext` / `h5_ext`
# to choose the loader when converting to Zarr.
data_ext = os.path.splitext(data)[1].lower()
# Absolute path; an empty `data_dir` resolves to the current working directory.
data_dir = os.path.abspath(data_dir)

# Suffixes appended to `data_basename` to name each intermediate/final result file.
postfix_trim = "_trim"
postfix_dn = "_dn"
postfix_reg = "_reg"
postfix_sub = "_sub"
postfix_f_f0 = "_f_f0"
postfix_wt = "_wt"
postfix_norm = "_norm"
postfix_dict = "_dict"
postfix_cc = "_cc"
postfix_post = "_post"
postfix_thrd = "_thrd"
postfix_rois = "_rois"
postfix_traces = "_traces"
postfix_proj = "_proj"
# NOTE(review): same value as `postfix_proj`; confirm whether "_html" was intended
# (outputs still differ by extension, e.g. `.html` vs `.zarr`/`.h5`).
postfix_html = "_proj"

# File extensions, built with the platform's extension separator.
h5_ext = os.path.extsep + "h5"
tiff_ext = os.path.extsep + "tif"
zarr_ext = os.path.extsep + "zarr"
html_ext = os.path.extsep + "html"

# Configure and start up the cluster

In [None]:
from nanshe_workflow.par import cleanup_cluster_files, get_client, set_num_workers

# ipyparallel profile used both to start the cluster and to connect to it.
ipypar_prof = "sge"

# Resolve the worker count; `None` selects a default (per the parameter notes,
# all cores except one).
num_workers = set_num_workers(num_workers)

# Clean up cluster files left over for this profile before starting a new cluster.
cleanup_cluster_files(ipypar_prof)

# Start the ipcluster daemon with the same Python interpreter as this kernel.
from sys import executable as PYTHON
!$PYTHON -m ipyparallel.apps.ipclusterapp start --daemon --profile=$ipypar_prof
del PYTHON

# Connect a client to the cluster that was just started.
client = get_client(ipypar_prof)

In [None]:
from builtins import range as irange


def getcwdi(i):
    """Return the current working directory of whichever process runs this.

    ``i`` is ignored. It only exists so that mapping over
    ``irange(len(client))`` issues one call per engine.
    """
    from os import getcwd as _engine_getcwd
    return _engine_getcwd()


def _cluster_cwds():
    """Working directory of this client, followed by that of every engine."""
    local_cwd = os.getcwd()
    engine_cwds = client[:].map(getcwdi, irange(len(client))).get()
    return [local_cwd] + engine_cwds


# Keep moving the client and every engine into `data_dir` until all of them agree.
while any(cwd != data_dir for cwd in _cluster_cwds()):
    os.chdir(data_dir)
    client[:].map(os.chdir, [data_dir] * len(client)).get()

del _cluster_cwds
del getcwdi

# Define functions for computation.

In [None]:
%matplotlib notebook

import matplotlib
import matplotlib.cm
import matplotlib.pyplot

import matplotlib as mpl
import matplotlib.pyplot as plt

from mplview.core import MatplotlibViewer as MPLViewer

In [None]:
# Switch serialization to cloudpickle, both locally and on every engine.
client[:].use_cloudpickle().get()

# Every import below is performed both locally and on every engine.
with client[:].sync_imports():
    import collections
    import contextlib
    import copy
    import functools
    import gc
    import inspect
    import itertools
    import logging
    import math
    import numbers
    import os
    import sys

    from contextlib import contextmanager

    from builtins import range as irange

    import numpy
    import scipy
    import scipy.ndimage
    import h5py

    import numpy as np
    import scipy as sp
    import scipy.ndimage as spim
    import h5py as hp

    import dask
    import dask.array
    import dask.array.fft
    import dask.utils
    import dask.distributed

    import dask.array as da

    import dask_imread
    import dask_ndfilters
    import dask_ndfourier
    import dask_ndmeasure

    from toolz import sliding_window

    import zarr

    import imgroi
    import imgroi.core
    from imgroi.core import label_mask_stack

    import nanshe
    from nanshe.imp.segment import generate_dictionary

    import nanshe_workflow
    from nanshe_workflow.data import io_remove, dask_load_hdf5, zip_zarr, open_zarr, DataBlocks, LazyZarrDataset
    from nanshe_workflow.par import get_executor

# Restrict Blosc to a single thread with threading disabled, locally and then on
# every engine.
# NOTE(review): presumably to avoid oversubscribing cores alongside the cluster
# workers — confirm.
zarr.blosc.set_nthreads(1)
zarr.blosc.use_threads = False
client[:].apply(zarr.blosc.set_nthreads, 1).get();
client[:].apply(setattr, zarr.blosc, "use_threads", False).get();

# Set the `nanshe` logger's level to INFO.
logging.getLogger("nanshe").setLevel(logging.INFO)

In [None]:
from nanshe_workflow.data import hdf5_to_zarr, zarr_to_hdf5
from nanshe_workflow.data import save_tiff

In [None]:
# Use pyFFTW's NumPy-compatible FFT interface when it is installed; otherwise use NumPy's.
try:
    from pyfftw.interfaces import numpy_fft
except ImportError:
    from numpy import fft as numpy_fft

# Dask-array versions of the real N-D FFT and its inverse.
rfftn, irfftn = (
    da.fft.fft_wrap(numpy_fft.rfftn),
    da.fft.fft_wrap(numpy_fft.irfftn),
)

In [None]:
from nanshe_workflow.par import halo_block_parallel

from nanshe_workflow.imp2 import extract_f0, wavelet_transform, normalize_data

from nanshe_workflow.par import halo_block_generate_dictionary_parallel
from nanshe_workflow.imp import block_postprocess_data_parallel

# Build the cluster-backed versions of dictionary learning and post-processing
# (block-parallel, judging by the helper names).
_dictionary_wrapper = halo_block_generate_dictionary_parallel(client, None)
par_generate_dictionary = _dictionary_wrapper(generate_dictionary)
del _dictionary_wrapper

par_postprocess_data = block_postprocess_data_parallel(client)

In [None]:
from nanshe_workflow.par import frame_stack_calculate_parallel

from nanshe_workflow.proj2 import compute_traces

from nanshe_workflow.proj2 import compute_adj_harmonic_mean_projection
from nanshe_workflow.proj2 import compute_min_projection
from nanshe_workflow.proj2 import compute_max_projection

from nanshe_workflow.proj2 import compute_moment_projections

from nanshe_workflow.proj2 import norm_layer

from nanshe_workflow.proj import stack_norm_layer_parallel
from nanshe_workflow.proj import stack_compute_min_projection_parallel
from nanshe_workflow.proj import stack_compute_max_projection_parallel

# Cluster-backed, frame-stack versions of the normalization and min/max projection helpers.
(
    par_norm_layer,
    par_compute_min_projection,
    par_compute_max_projection,
) = [
    frame_stack_calculate_parallel(client, stack_fn)
    for stack_fn in (
        stack_norm_layer_parallel,
        stack_compute_min_projection_parallel,
        stack_compute_max_projection_parallel,
    )
]

# Begin workflow. Set parameters and run each cell.

### Convert TIFF/HDF5 to Zarr

In [None]:
# Supported input extensions (`data_ext` is already lower-cased).
_tiff_exts = (tiff_ext, os.path.extsep + "tiff")
_h5_exts = (h5_ext, os.path.extsep + "hdf5")

# Reject unsupported inputs before deleting any previous output. Otherwise no loader
# runs, `a` is left undefined, and the failure happens after the old Zarr file is gone.
if data_ext not in _tiff_exts + _h5_exts:
    raise ValueError(
        "Unsupported input file extension %r for %r; expected one of %r."
        % (data_ext, data, _tiff_exts + _h5_exts)
    )

io_remove(data_basename + zarr_ext)
with open_zarr(data_basename + zarr_ext, "w") as f1:
    with get_executor(client) as executor:
        # Lazily load the raw input with the reader that matches its format.
        if data_ext in _tiff_exts:
            a = dask_imread.imread(data)
        else:
            a = dask_load_hdf5(data, dataset)

        # Store under `dataset`, letting Zarr pick the chunk shape (`chunks=True`).
        d = f1.create_dataset(
            dataset,
            shape=a.shape,
            dtype=a.dtype,
            chunks=True
        )
        # Align dask chunks with Zarr chunks so no two tasks write the same Zarr
        # chunk, which is what makes `lock=False` safe.
        a = a.rechunk(d.chunks)
        status = executor.compute(da.store(a, d, lock=False, compute=False))
        dask.distributed.progress(status, notebook=False)

        del a
        del d

del _tiff_exts, _h5_exts

zip_zarr(data_basename + zarr_ext)

### View Input Data

* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
norm_frames = 100

# Interactive preview of the converted input. The display range is the global
# minimum/maximum, computed by the cluster-backed projection helpers with
# `num_frames=norm_frames`.
if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + zarr_ext, dataset)

    mplsv = plt.figure(FigureClass=MPLViewer)
    _vmin = par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min()
    _vmax = par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    mplsv.set_images(result_image_stack, vmin=_vmin, vmax=_vmax)
    del _vmin, _vmax

### Trimming

* `front` (`int`): number of frames to trim off the front
* `back` (`int`): number of frames to trim off the back

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


front = 0
back = 0

block_frames = 1
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_trim + zarr_ext)
io_remove(data_basename + postfix_trim + h5_ext)


with open_zarr(data_basename + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

        # Trim frames from front and back
        da_imgs_trim = da_imgs[front:len(da_imgs)-back]

        # Store denoised data
        with open_zarr(data_basename + postfix_trim + zarr_ext, "w") as f2:
            result = f2.create_dataset(
                "images",
                shape=da_imgs_trim.shape,
                dtype=da_imgs_trim.dtype,
                chunks=True
            )
            da_imgs_trim = da_imgs_trim.rechunk(result.chunks)
            status = executor.compute(da.store(da_imgs_trim, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_trim + zarr_ext)

with h5py.File(data_basename + postfix_trim + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_trim + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_trim + zarr_ext, dataset)

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Denoising

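* `med_filt_size` (`int`): spatial footprint size (in pixels) of the per-frame median filter
* `norm_filt_sigma` (`int`/`float`): spatial sigma of the Gaussian filter whose output is subtracted from the median-filtered frames (high-pass filtering)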

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


med_filt_size = 3
norm_filt_sigma = 10

block_frames = 1
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_dn + zarr_ext)
io_remove(data_basename + postfix_dn + h5_ext)


with open_zarr(data_basename + postfix_trim + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

        da_imgs_flt = da_imgs
        if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
                da_imgs_flt.dtype.itemsize >= 4):
            da_imgs_flt = da_imgs_flt.astype(np.float32)

        # Median filter frames
        da_imgs_medf = dask_ndfilters.median_filter(
            da_imgs_flt, (1,) + (da_imgs_flt.ndim - 1) * (med_filt_size,)
        )

        # Compute the Gaussian filter of frames
        da_imgs_smoothed = dask_ndfilters.gaussian_filter(
            da_imgs_medf, (0,) + (da_imgs_medf.ndim - 1) * (norm_filt_sigma,)
        )

        # Apply high pass filter to images
        da_imgs_filt = da_imgs_medf - da_imgs_smoothed

        # Reset minimum to original value.
        da_imgs_filt += da_imgs.min() - da_imgs_filt.min()

        # Store denoised data
        with open_zarr(data_basename + postfix_dn + zarr_ext, "w") as f2:
            result = f2.create_dataset(
                "images",
                shape=da_imgs_filt.shape,
                dtype=da_imgs_filt.dtype,
                chunks=True
            )
            da_imgs_filt = da_imgs_filt.rechunk(result.chunks)
            status = executor.compute(da.store(da_imgs_filt, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_dn + zarr_ext)

with h5py.File(data_basename + postfix_dn + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_dn + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_dn + zarr_ext, dataset)

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Registration

In [None]:
def find_best_match(matches):
    """Return the candidate offset closest to the origin.

    Args:
        matches: integer array of shape ``(ndim, k)`` whose columns are
            candidate offsets.

    Returns:
        The column of ``matches`` with the smallest squared Euclidean norm.
        On ties, the first such column is returned.
    """
    i = numpy.argmin((matches ** 2).sum(axis=0))

    return matches[:, i]


def compute_offset(match_mask):
    """Turn a boolean mask of correlation-peak positions into a signed offset.

    Args:
        match_mask: boolean array, one axis per spatial dimension, that is True
            wherever the (circular) overlap reaches its maximum.

    Returns:
        Integer array of length ``match_mask.ndim``. It is the peak position
        closest to the origin, after indices past the middle of each axis are
        wrapped to negative offsets. If the mask has no True entries, the result
        is all zeros.
    """
    frame_shape = np.array(match_mask.shape)
    half_frame_shape = frame_shape // 2

    matches = np.array(match_mask.nonzero())
    if matches.size == 0:
        # No peak: fall back to a zero offset with one entry per axis. This was
        # previously hard-coded to 2 axes, which failed to broadcast for 3-D frames.
        matches = np.zeros((match_mask.ndim, 1), dtype=int)

    # Indices past the middle of an axis wrap around to negative offsets.
    above = (matches > half_frame_shape[:, None]).astype(matches.dtype)
    matches -= above * frame_shape[:, None]

    return find_best_match(matches)

In [None]:
%%time


num_reps = 0
tmpl_hist_wght = 0.25
thld_rel_dist = 0.0

block_frames = 1
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_reg + zarr_ext)
io_remove(data_basename + postfix_reg + h5_ext)


with open_zarr(data_basename + postfix_dn + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

        da_imgs_flt = da_imgs
        if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
                da_imgs_flt.dtype.itemsize >= 4):
            da_imgs_flt = da_imgs_flt.astype(np.float32)

        # Create frame array
        frame_shape = da.from_array(
            np.array(da_imgs_flt.shape[1:], dtype=int),
            chunks=(da_imgs_flt.ndim - 1,)
        )

        # Persist frame shape
        frame_shape = frame_shape.persist()

        # Compute the FFT of frames
        da_imgs_fft = rfftn(da_imgs_flt, axes=tuple(irange(1, imgs.ndim)))

        # Persist FFT of frames
        da_imgs_fft = da_imgs_fft.persist()

        # Initialize
        i = 0
        avg_rel_dist = 1.0
        tmpl_hist_wght = da_imgs_flt.dtype.type(tmpl_hist_wght)
        shifts = da.zeros(
            (len(da_imgs_flt), da_imgs_flt.ndim - 1),
            dtype=int,
            chunks=(1, da_imgs_flt.ndim - 1)
        )
        da_imgs_fft_tmplt = da_imgs_fft.mean(axis=0)

        # Persist FFT template image
        da_imgs_fft_tmplt = da_imgs_fft_tmplt.persist()

        while avg_rel_dist > thld_rel_dist and i < num_reps:
            # Compute the shifted frames
            shifted_frames = []
            for j in irange(len(da_imgs_fft)):
                shifted_frames.append(dask_ndfourier.fourier_shift(
                    da_imgs_fft[i], shifts[i]
                ))
            shifted_frames = da.stack(shifted_frames)

            # Compute the template FFT
            da_imgs_fft_tmplt = (
                tmpl_hist_wght * da_imgs_fft_tmplt +
                (1 - tmpl_hist_wght) * shifted_frames.mean(axis=0)
            )

            # Free connected persisted values
            del shifted_frames

            # Persist FFT template image
            da_imgs_fft_tmplt = da_imgs_fft_tmplt.persist()

            # Find the best overlap with the template.
            overlap = irfftn(
                da_imgs_fft * da_imgs_fft_tmplt[None],
                s=da_imgs_flt.shape[1:],
                axes=tuple(irange(1, imgs.ndim))
            )
            overlap_max = overlap.max(axis=tuple(irange(1, imgs.ndim)))
            overlap_max_match = (overlap == overlap_max[(Ellipsis,) + (None,) * (imgs.ndim - 1)])

            # Compute the shift for each frame.
            old_shifts = shifts
            shifts = []
            for j in irange(len(overlap_max_match)):
                shift_j = dask.delayed(compute_offset)(overlap_max_match[j])
                shift_j = da.from_delayed(shift_j, (2,), int)
                shifts.append(shift_j)
                del shift_j
            shifts = da.stack(shifts)

            # Free connected persisted values
            del overlap
            del overlap_max
            del overlap_max_match

            # Remove any collective frame drift.
            drift = shifts.mean(axis=0).round().astype(shifts.dtype)
            shifts = shifts - drift[None]

            # Free connected persisted values
            del drift

            # Persist shifts
            shifts = shifts.persist()

            # Find shift change.
            diff_shifts = shifts - old_shifts
            rel_diff_shifts = (
                diff_shifts.astype(da_imgs_flt.dtype) / 
                frame_shape.astype(da_imgs_flt.dtype) /
                (da_imgs_flt.dtype.type(len(frame_shape)) ** 0.5)
            )
            rel_dist_shifts = (rel_diff_shifts ** 2.0).sum(axis=1) ** 0.5
            avg_rel_dist = rel_dist_shifts.sum() / da_imgs_flt.dtype.type(len(shifts))

            # Free old shifts
            del old_shifts

            # Free connected persisted values
            del diff_shifts
            del rel_diff_shifts
            del rel_dist_shifts

            # Compute change
            status = executor.compute(avg_rel_dist)
            dask.distributed.progress(status, notebook=False)
            avg_rel_dist = status.result()
            i += 1

            # Show change
            print("")
            print((i, avg_rel_dist))

        # Drop unneeded items
        del frame_shape
        del da_imgs_flt
        del da_imgs_fft
        del da_imgs_fft_tmplt

        # Truncate shifted part of each frame
        da_imgs_trunc = []
        da_imgs_trunc_shape = da_imgs.shape[1:]
        for i in irange(len(da_imgs)):
            slice_i = [i]
            for j in irange(shifts.shape[1]):
                shifts_ij = numpy.array(shifts[i, j])[()]
                if shifts_ij < 0:
                    slice_i.append(slice(-shifts_ij, None))
                elif shifts_ij > 0:
                    slice_i.append(slice(None, -shifts_ij))
                else:
                    slice_i.append(slice(None))
            slice_i = tuple(slice_i)
            da_imgs_trunc.append(da_imgs[slice_i])
            da_imgs_trunc_shape = tuple(np.minimum(
                da_imgs_trunc_shape, da_imgs_trunc[-1].shape
            ))

        # Free raw data and shifts
        del da_imgs
        del shifts

        # Truncate all frames to smallest one
        da_imgs_trunc_cut = tuple(map(
            lambda s: slice(None, s), da_imgs_trunc_shape
        ))
        for i in irange(len(da_imgs_trunc)):
            da_imgs_trunc[i] = da_imgs_trunc[i][da_imgs_trunc_cut]
        da_imgs_trunc = da.stack(da_imgs_trunc)

        # Store registered data
        with open_zarr(data_basename + postfix_reg + zarr_ext, "w") as f2:
            result = f2.create_dataset(
                "images",
                shape=da_imgs_trunc.shape,
                dtype=da_imgs_trunc.dtype,
                chunks=True
            )
            da_imgs_trunc = da_imgs_trunc.rechunk(result.chunks)
            status = executor.compute(da.store(da_imgs_trunc, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)

        # Free truncated frames
        del da_imgs_trunc


zip_zarr(data_basename + postfix_reg + zarr_ext)

with h5py.File(data_basename + postfix_reg + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_reg + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_reg + zarr_ext, dataset)

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Projections

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).

In [None]:
%%time


block_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_proj + zarr_ext)
io_remove(data_basename + postfix_proj + h5_ext)


with open_zarr(data_basename + postfix_reg + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

        da_imgs_flt = da_imgs
        if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
                da_imgs_flt.dtype.itemsize >= 4):
            da_imgs_flt = da_imgs_flt.astype(np.float32)

        da_imgs_proj_hmean = compute_adj_harmonic_mean_projection(da_imgs_flt)

        da_imgs_proj_max = compute_max_projection(da_imgs_flt)

        da_imgs_proj_mean, da_imgs_proj_std = compute_moment_projections(da_imgs_flt, 3)[1:]
        da_imgs_proj_std -= da_imgs_proj_mean**2
        da_imgs_proj_std = da.sqrt(da_imgs_proj_std)

        # Store denoised data
        with open_zarr(data_basename + postfix_proj + zarr_ext, "w") as f2:
            statuses = []

            zarr_proj_hmean = f2.create_dataset(
                "hmean",
                shape=da_imgs_proj_hmean.shape,
                dtype=da_imgs_proj_hmean.dtype,
                chunks=True
            )
            da_imgs_proj_hmean = da_imgs_proj_hmean.rechunk(zarr_proj_hmean.chunks)
            statuses.append(executor.compute(
                da.store(da_imgs_proj_hmean, zarr_proj_hmean, lock=False, compute=False
            )))

            zarr_proj_max = f2.create_dataset(
                "max",
                shape=da_imgs_proj_max.shape,
                dtype=da_imgs_proj_max.dtype,
                chunks=True
            )
            da_imgs_proj_max = da_imgs_proj_max.rechunk(zarr_proj_max.chunks)
            statuses.append(executor.compute(
                da.store(da_imgs_proj_max, zarr_proj_max, lock=False, compute=False
            )))

            zarr_proj_mean = f2.create_dataset(
                "mean",
                shape=da_imgs_proj_mean.shape,
                dtype=da_imgs_proj_mean.dtype,
                chunks=True
            )
            da_imgs_proj_mean = da_imgs_proj_mean.rechunk(zarr_proj_mean.chunks)
            statuses.append(executor.compute(
                da.store(da_imgs_proj_mean, zarr_proj_mean, lock=False, compute=False
            )))

            zarr_proj_std = f2.create_dataset(
                "std",
                shape=da_imgs_proj_std.shape,
                dtype=da_imgs_proj_std.dtype,
                chunks=True
            )
            da_imgs_proj_std = da_imgs_proj_std.rechunk(zarr_proj_std.chunks)
            statuses.append(executor.compute(
                da.store(da_imgs_proj_std, zarr_proj_std, lock=False, compute=False
            )))

            dask.distributed.progress(statuses, notebook=False)


zip_zarr(data_basename + postfix_proj + zarr_ext)

with h5py.File(data_basename + postfix_proj + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_proj + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)

### Subtract Projection

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


block_frames = 100
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_sub + zarr_ext)
io_remove(data_basename + postfix_sub + h5_ext)


with open_zarr(data_basename + postfix_reg + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(imgs, chunks=(block_frames,) + imgs.shape[1:])

        da_imgs_flt = da_imgs
        if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
                da_imgs_flt.dtype.itemsize >= 4):
            da_imgs_flt = da_imgs_flt.astype(np.float32)

        da_imgs_adj = da_imgs_flt.min() - 1
        da_imgs_flt_shifted = da_imgs_flt - da_imgs_adj
        da_imgs_hmean = da_imgs_adj + 1 / (1 / da_imgs_flt_shifted).mean(axis=0)

        da_imgs_sub = da_imgs_flt - da_imgs_hmean
        da_imgs_sub -= da_imgs_sub.min()

        # Store denoised data
        with open_zarr(data_basename + postfix_sub + zarr_ext, "w") as f2:
            result = f2.create_dataset(
                "images",
                shape=da_imgs_sub.shape,
                dtype=da_imgs_sub.dtype,
                chunks=True
            )
            da_imgs_sub = da_imgs_sub.rechunk(result.chunks)
            status = executor.compute(da.store(da_imgs_sub, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_sub + zarr_ext)

with h5py.File(data_basename + postfix_sub + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_sub + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_sub + zarr_ext, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Background Subtraction

* `half_window_size` (`int`): the rank filter window size is `2*half_window_size+1`.
* `which_quantile` (`float`): which quantile to return from the rank filter.
* `temporal_smoothing_gaussian_filter_stdev` (`float`): stdev for gaussian filter to convolve over time.
* `temporal_smoothing_gaussian_filter_window_size` (`float`): window for gaussian filter to convolve over time. (Measured in standard deviations)
* `spatial_smoothing_gaussian_filter_stdev` (`float`): stdev for gaussian filter to convolve over space.
* `spatial_smoothing_gaussian_filter_window_size` (`float`): window for gaussian filter to convolve over space. (Measured in standard deviations)

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


half_window_size = 100
which_quantile = 0.5
temporal_smoothing_gaussian_filter_stdev = 0.0
temporal_smoothing_gaussian_filter_window_size = 0
spatial_smoothing_gaussian_filter_stdev = 0.0
spatial_smoothing_gaussian_filter_window_size = 0

block_frames = 1000
block_space = 100
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_f_f0 + zarr_ext)
io_remove(data_basename + postfix_f_f0 + h5_ext)


with open_zarr(data_basename + postfix_sub + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(
            imgs, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
        )

        da_imgs_flt = da_imgs
        if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
                da_imgs_flt.dtype.itemsize >= 4):
            da_imgs_flt = da_imgs_flt.astype(np.float32)

        bias = 1 - da_imgs_flt.min()

        da_result = extract_f0(
            da_imgs_flt,
            half_window_size=half_window_size,
            which_quantile=which_quantile,
            temporal_smoothing_gaussian_filter_stdev=temporal_smoothing_gaussian_filter_stdev,
            temporal_smoothing_gaussian_filter_window_size=temporal_smoothing_gaussian_filter_window_size,
            spatial_smoothing_gaussian_filter_stdev=spatial_smoothing_gaussian_filter_stdev,
            spatial_smoothing_gaussian_filter_window_size=spatial_smoothing_gaussian_filter_window_size,
            bias=bias
        )

        # Store denoised data
        with open_zarr(data_basename + postfix_f_f0 + zarr_ext, "w") as f2:
            result = f2.create_dataset(
                "images",
                shape=da_result.shape,
                dtype=da_result.dtype,
                chunks=True
            )
            da_result = da_result.rechunk(result.chunks)
            status = executor.compute(da.store(da_result, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_f_f0 + zarr_ext)

with h5py.File(data_basename + postfix_f_f0 + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_f_f0 + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_f_f0 + zarr_ext, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Wavelet Transform

* `scale` (`int`): the scale of wavelet transform to apply.

<br>
* `block_frames` (`int`): number of frames to work with in each block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).
* `norm_frames` (`int`): number of frames for use during normalization of each full frame block (run in parallel).

In [None]:
%%time


scale = 3

block_frames = 200
block_space = 300
norm_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_wt + zarr_ext)
io_remove(data_basename + postfix_wt + h5_ext)


with open_zarr(data_basename + postfix_f_f0 + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(
            imgs, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
        )

        da_imgs_flt = da_imgs
        if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
                da_imgs_flt.dtype.itemsize >= 4):
            da_imgs_flt = da_imgs_flt.astype(np.float32)

        da_result = wavelet_transform(
            da_imgs,
            scale=scale
        )

        # Store denoised data
        with open_zarr(data_basename + postfix_wt + zarr_ext, "w") as f2:
            result = f2.create_dataset(
                "images",
                shape=da_result.shape,
                dtype=da_result.dtype,
                chunks=True
            )
            da_result = da_result.rechunk(result.chunks)
            status = executor.compute(da.store(da_result, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_wt + zarr_ext)

with h5py.File(data_basename + postfix_wt + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_wt + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_wt + zarr_ext, "images")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=norm_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=norm_frames)(result_image_stack).max()
    )

### Project

* `proj_type` (`str`): type of projection to take across frames (`"max"` or `"std"`).

<br>

* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).

In [None]:
%%time


proj_type = "max"

block_frames = 40
block_space = 300


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_dict + zarr_ext)
io_remove(data_basename + postfix_dict + h5_ext)


with open_zarr(data_basename + postfix_wt + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        # Load and prep data for computation.
        imgs = f["images"]
        da_imgs = da.from_array(
            imgs, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
        )

        da_imgs_flt = da_imgs
        if not (issubclass(da_imgs_flt.dtype.type, np.floating) and 
                da_imgs_flt.dtype.itemsize >= 4):
            da_imgs_flt = da_imgs_flt.astype(np.float32)

        da_result = da_imgs
        if proj_type == "max":
            da_result = da_result.max(axis=0, keepdims=True)
        elif proj_type == "std":
            da_result = da_result.std(axis=0, keepdims=True)

        # Store denoised data
        with open_zarr(data_basename + postfix_dict + zarr_ext, "w") as f2:
            result = f2.create_dataset(
                "images",
                shape=da_result.shape,
                dtype=da_result.dtype,
                chunks=True
            )
            da_result = da_result.rechunk(result.chunks)
            status = executor.compute(da.store(da_result, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_dict + zarr_ext)

with h5py.File(data_basename + postfix_dict + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_dict + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_dict + zarr_ext, "images")[...][...]

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=result_image_stack.min(),
        vmax=result_image_stack.max()
    )

### Connected components

* `significance_threshold` (`float`): pixels below this many standard deviations (of the projection) are used to estimate the "noise" level
* `noise_threshold` (`float`): number of units of that "noise" estimate a pixel must exceed to be considered significant

In [None]:
%%time


significance_threshold = 3.0
noise_threshold = 1.0


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_cc + zarr_ext)
io_remove(data_basename + postfix_cc + h5_ext)


with open_zarr(data_basename + postfix_dict + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        imgs = f["images"]
        da_imgs = da.from_array(
            imgs, chunks=imgs.shape
        )
        da_imgs = da_imgs[0]

        da_imgs_thrd = (da_imgs - noise_threshold * (da_imgs - significance_threshold * da_imgs.std()).std()) > 0

        da_lbl_img, da_num_lbls = dask_ndmeasure.label(da_imgs_thrd)
        da_lbl_img, da_num_lbls = executor.persist([da_lbl_img, da_num_lbls])

        da_result = []
        for i in irange(1, 1 + int(da_num_lbls)):
            da_result.append((da_lbl_img == i)[None])
        da_result = da.concatenate(da_result)

        with open_zarr(data_basename + postfix_cc + zarr_ext, "w") as f2:
            result = f2.create_group("rois").create_dataset(
                "mask",
                shape=da_result.shape,
                dtype=da_result.dtype,
                chunks=True
            )
            da_result = da_result.rechunk(result.chunks)
            status = executor.compute(da.store(da_result, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_cc + zarr_ext)

with h5py.File(data_basename + postfix_cc + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_cc + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_cc + zarr_ext, "rois/mask")[...][...].astype(np.uint8)

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=result_image_stack.min(),
        vmax=result_image_stack.max()
    )

### ROI Refinement

* `area_min_threshold` (`float`): minimum area (in pixels) an ROI must have to be kept; smaller ROIs are discarded

In [None]:
%%time


area_min_threshold = 20.0


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_post + zarr_ext)
io_remove(data_basename + postfix_post + h5_ext)


with open_zarr(data_basename + postfix_cc + zarr_ext, "r") as f:
    with get_executor(client) as executor:
        imgs = f["rois/mask"]
        da_imgs = da.from_array(
            imgs, chunks=imgs.shape
        )

        da_num_lbls = len(da_imgs)
        da_lbl_img = (
            np.arange(1, 1 + da_num_lbls)[(slice(None),) + da_lbl_img.ndim * (None,)] * da_imgs
        ).sum(axis=0)

        da_area_lbls = dask_ndmeasure.sum(
            da.ones(da_lbl_img.shape, dtype=int, chunks=da_lbl_img.chunks),
            da_lbl_img,
            list(irange(1, 1 + int(da_num_lbls)))
        )

        da_lbl_img, da_num_lbls = dask_ndmeasure.label(
            (
                ((da_area_lbls >= area_min_threshold)[(slice(None),) + da_lbl_img.ndim * (None,)] * da_lbl_img) > 0
            ).sum(axis=0)
        )

        da_lbl_img, da_num_lbls = executor.persist([da_lbl_img, da_num_lbls])

        da_result = []
        for i in irange(1, 1 + int(da_num_lbls)):
            da_result.append((da_lbl_img == i)[None])
        da_result = da.concatenate(da_result)

        with open_zarr(data_basename + postfix_post + zarr_ext, "w") as f2:
            result = f2.create_group("rois").create_dataset(
                "mask",
                shape=da_result.shape,
                dtype=da_result.dtype,
                chunks=True
            )
            da_result = da_result.rechunk(result.chunks)
            status = executor.compute(da.store(da_result, result, lock=False, compute=False))
            dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_post + zarr_ext)

with h5py.File(data_basename + postfix_post + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_post + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_post + zarr_ext, "rois/mask")[...][...].astype(np.uint8)

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=result_image_stack.min(),
        vmax=result_image_stack.max()
    )

### Threshold data

* `block_frames` (`int`): number of frames to work with in each full frame block (run in parallel).
* `block_space` (`int`): extent of each spatial dimension for each block (run in parallel).

In [None]:
%%time


block_frames = 40
block_space = 300


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_thrd + zarr_ext)
io_remove(data_basename + postfix_thrd + h5_ext)

with open_zarr(data_basename + postfix_wt + zarr_ext, "r") as f:
    with open_zarr(data_basename + postfix_post + zarr_ext, "r") as f2:
        with get_executor(client) as executor:
            # Load and prep data for computation.
            imgs = f["images"]
            da_imgs = da.from_array(
                imgs, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
            )
            msks = f2["rois/mask"]
            da_msks = da.from_array(
                msks, chunks=(block_frames,) + (imgs.ndim - 1) * (block_space,)
            )

            da_result = da_imgs * da_msks.max(axis=0, keepdims=True).astype(da_imgs.dtype)

            # Store data
            with open_zarr(data_basename + postfix_thrd + zarr_ext, "w") as f2:
                result = f2.create_dataset(
                    "images",
                    shape=da_result.shape,
                    dtype=da_result.dtype,
                    chunks=True
                )
                da_result = da_result.rechunk(result.chunks)
                status = executor.compute(da.store(da_result, result, lock=False, compute=False))
                dask.distributed.progress(status, notebook=False)


zip_zarr(data_basename + postfix_thrd + zarr_ext)

with h5py.File(data_basename + postfix_thrd + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_thrd + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_thrd + zarr_ext, "images")[...][...]

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=result_image_stack.min(),
        vmax=result_image_stack.max()
    )

### ROI and trace extraction

* `block_frames` (`int`): number of frames to work with in each block (run in parallel).

In [None]:
%%time


# Frames per block for the parallel `par_norm_layer` pass below.
block_frames = 100


# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_rois + zarr_ext)
io_remove(data_basename + postfix_rois + h5_ext)

# Build the ROI store: copy the refined masks, then add derived datasets.
with open_zarr(data_basename + postfix_rois + zarr_ext, "w") as f2:
    with open_zarr(data_basename + postfix_post + zarr_ext, "r") as f1:
        f2["masks"] = f1["rois/mask"]

    mskimg = f2["masks"]
    # `masks_j`: uint8 dataset filled by the project helper `par_norm_layer`
    # (presumably a per-mask normalization for display -- TODO confirm).
    mskimg_j = f2.create_dataset("masks_j", shape=mskimg.shape, dtype=numpy.uint8, chunks=True)
    par_norm_layer(num_frames=block_frames)(mskimg, out=mskimg_j)

    # Labels computed from the mask stack by `label_mask_stack`, stored as uint64.
    lblimg = label_mask_stack(mskimg, np.uint64)
    f2["labels"] = lblimg
    # NOTE(review): if `lblimg` is a NumPy array, casting to uint16 wraps
    # label ids above 65535 -- confirm ROI counts stay below that.
    f2["labels_j"] = lblimg.astype(np.uint16)
    lblimg = f2["labels"]

zip_zarr(data_basename + postfix_rois + zarr_ext)

# Mirror the ROI store into HDF5.
with h5py.File(data_basename + postfix_rois + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_rois + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)

# Somehow we can't overwrite the file in the container so this is needed.
io_remove(data_basename + postfix_traces + zarr_ext)
io_remove(data_basename + postfix_traces + h5_ext)

# Extract traces from the `_f_f0` image stack using the ROI masks
# (`compute_traces` is a project helper; presumably one trace per mask --
# TODO confirm).
with open_zarr(data_basename + postfix_f_f0 + zarr_ext, "r") as fh_f_f0:
    with open_zarr(data_basename + postfix_rois + zarr_ext, "r") as fh_rois:
        with get_executor(client) as executor:
            # Load and prep data for computation. Both inputs are chunked only
            # along their first axis (`block_frames` entries per chunk), so
            # each chunk holds whole frames / whole masks.
            images = fh_f_f0["images"]
            da_images = da.from_array(
                images, chunks=(block_frames,) + images.shape[1:]
            )
            masks = fh_rois["masks"]
            da_masks = da.from_array(
                masks, chunks=(block_frames,) + masks.shape[1:]
            )

            da_result = compute_traces(da_images, da_masks)

            # Store traces
            with open_zarr(data_basename + postfix_traces + zarr_ext, "w") as fh_traces:
                result = fh_traces.create_dataset(
                    "traces",
                    shape=da_result.shape,
                    dtype=da_result.dtype,
                    chunks=True
                )
                da_result = da_result.rechunk(result.chunks)
                status = executor.compute(da.store(da_result, result, lock=False, compute=False))
                dask.distributed.progress(status, notebook=False)

zip_zarr(data_basename + postfix_traces + zarr_ext)

# Mirror the traces store into HDF5.
with h5py.File(data_basename + postfix_traces + h5_ext, "w") as f2:
    with open_zarr(data_basename + postfix_traces + zarr_ext, "r") as f1:
        zarr_to_hdf5(f1, f2)


# Interactive preview: the `_f_f0` stack with the ROI labels overlaid at
# 30% opacity.
if __IPYTHON__:
    result_image_stack = LazyZarrDataset(data_basename + postfix_f_f0 + zarr_ext, "images")
    lblimg = LazyZarrDataset(data_basename + postfix_rois + zarr_ext, "labels")

    mplsv = plt.figure(FigureClass=MPLViewer)
    mplsv.set_images(
        result_image_stack,
        vmin=par_compute_min_projection(num_frames=block_frames)(result_image_stack).min(),
        vmax=par_compute_max_projection(num_frames=block_frames)(result_image_stack).max()
    )

    # Materialize the labels and mask out label 0 (background) so only ROI
    # pixels are colored by the overlay.
    lblimg = lblimg[...][...]
    lblimg_msk = numpy.ma.masked_array(lblimg, mask=(lblimg==0))

    mplsv.viewer.matshow(lblimg_msk, alpha=0.3, cmap=mpl.cm.jet)


# Drop references to intermediates. Each name is bound to None first so the
# `del` statements below succeed even for names not otherwise defined.
mskimg = None
mskimg_j = None
lblimg = None
traces = None
traces_j = None

del mskimg
del mskimg_j
del lblimg
del traces
del traces_j

# End of workflow. Shutdown cluster.

In [None]:
from nanshe_workflow.par import cleanup_cluster_files

# Must match the profile the cluster was started with at the top of the notebook.
ipypar_prof = "sge"

# Stop the ipyparallel cluster using the kernel's own Python interpreter.
from sys import executable as PYTHON
!$PYTHON -m ipyparallel.apps.ipclusterapp stop --profile=$ipypar_prof
del PYTHON

# Project helper; presumably removes leftover cluster files for this profile.
cleanup_cluster_files(ipypar_prof)

# Prepare interactive projection graph

In [None]:
import io
import os
import textwrap
import zlib

import numpy
import numpy as np

import scipy
import scipy as sp

import scipy.ndimage
import scipy.ndimage as spim

import h5py
import h5py as hp

import bokeh.plotting
import bokeh.plotting as bp

import bokeh.io
import bokeh.io as bio

import bokeh.embed
import bokeh.embed as be

from bokeh.models.mappers import LinearColorMapper

import matplotlib
import matplotlib.cm

from matplotlib.colors import ColorConverter
from matplotlib.cm import gist_rainbow

import webcolors

from bokeh.models import CustomJS, ColumnDataSource, HoverTool
from bokeh.models.layouts import Row

from builtins import (
    map as imap,
    range as irange
)

from past.builtins import basestring

import nanshe

import xnumpy
import xnumpy.core
from xnumpy.core import expand

import nanshe_workflow
from nanshe_workflow.data import io_remove, open_zarr
from nanshe_workflow.vis import get_rgb_array, get_rgba_array, get_all_greys, masks_to_contours_2d

In [None]:
# Read everything the visualization needs fully into memory (`[...]`): the
# ROI masks, the traces, and the mean/max/std projections from the `_proj`
# store.
with open_zarr(data_basename + postfix_rois + zarr_ext, "r") as f:
    mskimg = f["masks"][...]

with open_zarr(data_basename + postfix_traces + zarr_ext, "r") as f:
    traces = f["traces"][...]

with open_zarr(data_basename + postfix_proj + zarr_ext, "r") as f:
    imgproj_mean = f["mean"][...]
    imgproj_max = f["max"][...]
    imgproj_std = f["std"][...]

### Result visualization

* `proj_img` (`str` or `list` of `str`): which projection or projections to plot (e.g. "max", "mean", "std").
* `block_size` (`int`): number of screen pixels used to draw each image pixel along each dimension.
* `roi_alpha` (`float`): transparency of the ROIs in a range of [0.0, 1.0].
* `roi_border_width` (`int`): width of the line border on each ROI.

<br>
* `trace_plot_width` (`int`): width of the trace plot.

In [None]:
proj_img = "std"
block_size = 1
roi_alpha = 0.3
roi_border_width = 3
trace_plot_width = 500


# Clear the current Bokeh document before building new plots.
bio.curdoc().clear()

# Greyscale color mapper for the projection images.
grey_range = get_all_greys()
grey_cm = LinearColorMapper(grey_range)

# One color per ROI (len(mskimg) of them), converted to hex strings for Bokeh.
colors_rgb = get_rgb_array(len(mskimg))
colors_rgb = colors_rgb.tolist()
colors_rgb = list(imap(webcolors.rgb_to_hex, colors_rgb))

# Contour polygon coordinates (y and x lists) for each ROI mask.
mskctr_pts_y, mskctr_pts_x = masks_to_contours_2d(mskimg)

# Store coordinates in the smallest dtype that can hold the largest spatial
# index (presumably to keep the data embedded in the HTML small).
mskctr_pts_dtype = np.min_scalar_type(max(mskimg.shape[1:]) - 1)
mskctr_pts_y = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_y]
mskctr_pts_x = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_x]

# Data source shared by every projection figure: one patch (and color) per ROI.
mskctr_srcs = ColumnDataSource(data=dict(x=mskctr_pts_x, y=mskctr_pts_y, color=colors_rgb))


if isinstance(proj_img, basestring):
    proj_img = [proj_img]
else:
    proj_img = list(proj_img)


proj_plot_width = block_size*mskimg.shape[2]
proj_plot_height = block_size*mskimg.shape[1]


def _make_proj_figure(proj, title):
    """Return a Bokeh figure showing one projection image with the ROI contours overlaid.

    Parameters
    ----------
    proj : array
        2-D projection image (rows, columns). It is drawn flipped vertically
        (``numpy.flipud``) to match the figure's inverted y-range.
    title : str
        Figure title.

    Reads the notebook globals ``mskimg``, ``proj_plot_width``,
    ``proj_plot_height``, ``grey_cm``, ``mskctr_srcs``, ``roi_alpha`` and
    ``roi_border_width``. Every projection is placed using the mask stack's
    spatial extent, so all figures share the same coordinates.
    """
    fig = bp.Figure(plot_width=proj_plot_width, plot_height=proj_plot_height,
                    x_range=[0, mskimg.shape[2]], y_range=[mskimg.shape[1], 0],
                    tools=["tap", "pan", "box_zoom", "resize", "wheel_zoom", "save", "reset"],
                    title=title, border_fill_color="black")
    fig.image(image=[numpy.flipud(proj)], x=[0], y=[mskimg.shape[1]],
              dw=[mskimg.shape[2]], dh=[mskimg.shape[1]], color_mapper=grey_cm)
    fig.patches('x', 'y', source=mskctr_srcs, alpha=roi_alpha, line_width=roi_border_width, color="color")

    # White outline and axis lines so they stay visible on the black border.
    fig.outline_line_color = "white"
    for fig_axis in fig.axis:
        fig_axis.axis_line_color = "white"

    return fig


# (name accepted in `proj_img`, image, title), in the order the figures are
# laid out -- independent of the order given in `proj_img`.
proj_specs = [
    ("max", imgproj_max, "Max Projection with ROIs"),
    ("mean", imgproj_mean, "Mean Projection with ROIs"),
    ("std", imgproj_std, "Std Dev Projection with ROIs"),
]
proj_names = [spec[0] for spec in proj_specs]

# Fail loudly on typos instead of silently plotting nothing for them.
unknown_projs = sorted(set(proj_img) - set(proj_names))
if unknown_projs:
    raise ValueError(
        "Unknown projection(s) in `proj_img`: %r; choose from %r." % (unknown_projs, proj_names)
    )

plot_projs = []
proj_figs = {}
for proj_name, proj_data, proj_title in proj_specs:
    if proj_name in proj_img:
        proj_figs[proj_name] = _make_proj_figure(proj_data, proj_title)
        plot_projs.append(proj_figs[proj_name])

# Per-projection handles for interactive tweaking (None if not requested).
plot_max = proj_figs.get("max")
plot_mean = proj_figs.get("mean")
plot_std = proj_figs.get("std")


# Trace data handed to the browser. `traces_dtype` holds a single zero of the
# traces' dtype: the JS callback uses its constructor to rebuild the typed
# array and its value (0) as a "not yet decoded" flag. The traces themselves
# are shipped as zlib-compressed bytes and inflated client-side with pako.
all_tr_dtype_srcs = ColumnDataSource(data=dict(traces_dtype=traces.dtype.type(0)[None]))
all_tr_shape_srcs = ColumnDataSource(data=dict(traces_shape=traces.shape))
all_tr_srcs = ColumnDataSource(data=dict(
    traces=numpy.frombuffer(
        zlib.compress(traces.tobytes()),
        dtype=np.uint8
    )
))
# Traces of the currently selected ROIs; filled in by the tap callback.
tr_srcs = ColumnDataSource(data=dict(times_sel=[], traces_sel=[], colors_sel=[]))
# Trace figure: x spans the frame indices, y spans the range of all traces.
plot_tr = bp.Figure(plot_width=trace_plot_width, plot_height=proj_plot_height,
                    x_range=(0.0, float(traces.shape[1])), y_range=(float(traces.min()), float(traces.max())),
                    tools=["pan", "box_zoom", "resize", "wheel_zoom", "save", "reset"], title="ROI traces",
                    background_fill_color="black", border_fill_color="black")
plot_tr.multi_line("times_sel", "traces_sel", source=tr_srcs, color="colors_sel")

# White outline and axis lines so they stay visible on black.
plot_tr.outline_line_color = "white"
for i in irange(len(plot_tr.axis)):
    plot_tr.axis[i].axis_line_color = "white"

plot_projs.append(plot_tr)


# Client-side tap handler: on first use, inflate and decode the compressed
# traces (caching the decoded typed array back into the source), then copy
# the traces and colors of the selected ROIs into `tr_srcs` for plotting.
# The loop index is declared with `var` so it doesn't leak into (or throw
# under strict mode for) the global scope.
mskctr_srcs.callback = CustomJS(
    args=dict(
        all_tr_dtype_srcs=all_tr_dtype_srcs,
        all_tr_shape_srcs=all_tr_shape_srcs,
        all_tr_srcs=all_tr_srcs,
        tr_srcs=tr_srcs
    ), code="""
    var range = function(n){ return Array.from(Array(n).keys()); };

    var traces_not_decoded = (all_tr_dtype_srcs.get('data')['traces_dtype'] == 0);
    var traces_dtype = all_tr_dtype_srcs.get('data')['traces_dtype'].constructor;
    var traces_shape = all_tr_shape_srcs.get('data')['traces_shape'];
    var trace_len = traces_shape[1];
    var traces = all_tr_srcs.get('data')['traces'];
    if (traces_not_decoded) {
        traces = window.pako.inflate(traces);
        traces = new traces_dtype(traces.buffer);
        all_tr_srcs.get('data')['traces'] = traces;
        all_tr_dtype_srcs.get('data')['traces_dtype'] = 1;
    }

    var inds = cb_obj.get('selected')['1d'].indices;
    var colors = cb_obj.get('data')['color'];
    var selected = tr_srcs.get('data');

    var times = range(trace_len);

    selected['times_sel'] = [];
    selected['traces_sel'] = [];
    selected['colors_sel'] = [];

    for (var i = 0; i < inds.length; i++) {
        var inds_i = inds[i];
        var trace_i = traces.slice(trace_len*inds_i, trace_len*(inds_i+1));
        var color_i = colors[inds_i];

        selected['times_sel'].push(times);
        selected['traces_sel'].push(trace_i);
        selected['colors_sel'].push(color_i);
    }

    tr_srcs.trigger('change');
""")


# Lay out the projection figures and the trace figure in a single row.
plot_group = Row(*plot_projs)


# Clear out the old HTML file before writing a new one.
io_remove(data_basename + postfix_html + html_ext)


def indent(text, spaces):
    """Return ``text`` with every line prefixed by ``spaces`` space characters.

    ``spaces`` is coerced with ``int()``. Lines are split with
    ``str.splitlines`` and re-joined with a newline, so a trailing newline is
    dropped and blank lines receive the padding as well.
    """
    pad = " " * int(spaces)
    return "\n".join(pad + line for line in text.splitlines())

def write_html(filename, title, div, script, cdn):
    """Write a standalone HTML page that embeds a Bokeh layout.

    Parameters
    ----------
    filename : str
        Output path; an existing file is overwritten.
    title : str
        Text for the page ``<title>``.
    div : str
        Bokeh ``<div>`` markup for the layout.
    script : str
        Bokeh ``<script>`` markup that renders into ``div``.
    cdn : str
        ``<head>`` markup that loads the JS/CSS resources the script needs.

    The fragments are inserted with ``indent`` (8 extra spaces per line).
    The file is written as UTF-8, matching the page's ``<meta charset>``.
    """
    html_tmplt = textwrap.dedent(u"""\
        <html lang="en">
            <head>
                <meta charset="utf-8">
                <title>{title}</title>
                {cdn}
                <style>
                  html {{
                    width: 100%;
                    height: 100%;
                  }}
                  body {{
                    width: 90%;
                    height: 100%;
                    margin: auto;
                    background-color: black;
                  }}
                </style>
            </head>
            <body>
                {div}
                {script}
            </body>
        </html>
    """)

    html_cont = html_tmplt.format(
        title=title,
        div=indent(div, 8),
        script=indent(script, 8),
        cdn=indent(cdn, 8),
    )

    # The page declares UTF-8, so encode it as UTF-8 explicitly rather than
    # in the platform's locale-dependent default encoding.
    with io.open(filename, "w", encoding="utf-8") as fh:
        fh.write(html_cont)

# Split the layout into an embeddable <script> and <div>.
script, div = be.components(plot_group)
# Head markup: Bokeh's CDN resources plus pako's inflate build, which the tap
# callback uses (as `window.pako.inflate`) to decompress the traces.
# NOTE(review): `bokeh.resources` is not imported explicitly in this notebook
# section; it appears to rely on being loaded by the other bokeh imports --
# confirm.
cdn = bokeh.resources.CDN.render() + "\n"
cdn += """
<script type="text/javascript" src="https://cdn.jsdelivr.net/pako/1.0.4/pako_inflate.min.js"></script>
"""
cdn += "\n"

write_html(data_basename + postfix_html + html_ext, data_basename + postfix_html, div, script, cdn)


# In a notebook, show the generated page inline, sized slightly taller than
# the plots.
if __IPYTHON__:
    from IPython.display import display, IFrame
    display(IFrame(data_basename + postfix_html + html_ext, "100%", 1.05*proj_plot_height))

In [None]:
# Test teardown. Ignore warnings during production runs.

%run ./teardown_tests.py