# Example 3D stitching workflow

This notebook demonstrates a workflow for stitching (potentially large) 3D data available as tif stacks.

#### 1) Download the example dataset

This notebook uses the 3D example dataset (6 tiles, 3 channels) provided by BigStitcher: https://imagej.net/plugins/bigstitcher/#example-datasets.

#### 2) Load and preposition the input tiles

The input tiles are loaded as numpy or dask arrays. The tiles are prepositioned on a grid.

#### 3) Tile registration

The tiles are registered using one of the input channels. The obtained transform parameters can be read out.

#### 4) Fuse the tiles into a single output image

The registered tiles are combined (fused) into a single output image. Weighted averaging smoothly blends the tiles where they meet, using pixel-wise weights that decrease towards the tile boundaries.

### Downloading the example dataset

In [None]:
import os
import zipfile
from pathlib import Path

url = "https://preibischlab.mdc-berlin.de/BigStitcher/Grid_3d.zip"

# directory to save the data (archive name without its ".zip" suffix)
base_dir = os.path.join('./data', os.path.basename(url)[:-4])
os.makedirs(base_dir, exist_ok=True)

zip_filepath = os.path.join(base_dir, os.path.basename(url))

# download into a temporary file and only move it into place once complete:
# an interrupted download must not leave a truncated archive at zip_filepath,
# because the existence check below would then skip re-downloading it and
# the extraction would fail on every subsequent run
if not os.path.exists(zip_filepath):
    from urllib.request import urlretrieve
    partial_filepath = zip_filepath + '.part'
    urlretrieve(url, partial_filepath)
    os.replace(partial_filepath, zip_filepath)

# unzip
with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
    zip_ref.extractall(base_dir)

# get the list of files
sorted(Path(base_dir).glob('*/*.tif'))

## Defining the stitching input

### Specifying the input files

In [None]:
import numpy as np

# indicate the tiles and channels to process
tiles = np.arange(73, 79)    # tile indices 73, ..., 78
channels = np.arange(1, 4)   # channel indices 1, 2, 3

def get_filename_from_tile_and_channel(tile, channel):
    """
    Return the path of the tif file that holds one channel of one tile.

    Files live in './data/Grid_3d/Grid1' and are named 'C<channel>-<tile>.tif',
    with the tile index zero-padded to at least two digits.
    """
    directory = './data/Grid_3d/Grid1'
    return '{}/C{:01d}-{:02d}.tif'.format(directory, channel, tile)

print('Example filename:\n', get_filename_from_tile_and_channel(tiles[0], channels[0]))

### Reading the input files

In [None]:
import tifffile
import aicsimageio

def read_image(filename):
    """
    Read a single tif file with tifffile and return it as a numpy array.

    The commented lines below show an alternative reader (aicsimageio) that
    can return either a numpy array or a dask array.
    """
    # alternatively: use aicsimageio to read the file
    #   aicsim = aicsimageio.AICSImage(filename)
    # as a numpy array:
    #   return aicsim.get_image_data().squeeze()
    # or as a dask array:
    #   return aicsim.get_dask_stack().squeeze()
    return tifffile.imread(filename)

filename = get_filename_from_tile_and_channel(tiles[0], channels[0])
ar = read_image(filename)
ar

In [None]:
# read all tiles; for every tile, stack its channels along a new leading
# axis so that each entry of tile_data is a (c, z, y, x) array
tile_data = [
    np.stack([
        read_image(get_filename_from_tile_and_channel(tile, channel))
        for channel in channels
    ])
    for tile in tiles
]

tile_data[0].shape, len(tile_data)

### Attaching metadata to arrays

In [None]:
from multiview_stitcher import spatial_image_utils as si_utils

# example with a single tile: attach dimension names, pixel spacing (scale),
# origin (translation) and channel names to the raw (c, z, y, x) array
sim = si_utils.get_sim_from_array(
    tile_data[0],
    dims=['c', 'z', 'y', 'x'],
    c_coords=['RFP', 'GFP', 'BFP'],
    scale={'z': 5, 'y': 1, 'x': 1},
    translation={'z': 2, 'y': 0, 'x': 0},
    transform_key='manual_prepositioning',
    # affine=np.eye(4),
)

sim

### Converting a SpatialImage into a MultiscaleSpatialImage

MultiscaleSpatialImages are the input format for the registration below (shown here for a single tile).

In [None]:
from multiview_stitcher import msi_utils

# scale factors of the additional (presumably downsampled) resolution levels
pyramid_scale_factors = [2, 4]
msim = msi_utils.get_msim_from_sim(sim, scale_factors=pyramid_scale_factors)
msim

### Prepositioning images on a regular grid

First, we define the grid layout.

In [None]:
def get_tile_grid_position_from_tile_index(
    tile_index,
    tile_grid_shape=(2, 3)
    ):
    """
    Return the grid position of a tile given its linear tile index.

    Tiles are filled in row by row, with tile_grid_shape[0] tiles per row
    (along x); tile_grid_shape[1] is not used by the computation.
    E.g. with the default tile_grid_shape=(2, 3), the tile indices are:
    0 1
    2 3
    4 5

    Returns a dict with the grid position along 'z' (always 0), 'y' (row)
    and 'x' (column).
    """
    tiles_per_row = tile_grid_shape[0]
    row, column = divmod(tile_index, tiles_per_row)
    return {'z': 0, 'y': row, 'x': column}

get_tile_grid_position_from_tile_index(0, tile_grid_shape=(2, 3))

Second, we use the grid layout to preposition the images in physical space. Here we also set the expected overlap between neighboring tiles, as a fraction of the tile size.

In [None]:
dims = ['z', 'y', 'x']
scale = {'z': 5, 'y': 1, 'x': 1}
overlap = 0.1  # fraction of the tile size shared with the neighboring tile
grid_indices = [get_tile_grid_position_from_tile_index(itile, (2, 3)) for itile in range(len(tile_data))]

# Spatial size (in pixels) of a tile along each dimension. The tile arrays are
# (c, z, y, x), so the spatial sizes are the trailing len(dims) axes, matched
# to `dims` in order. (Indexing with shape[-idim] over enumerate() maps 'z' to
# shape[0], i.e. the channel axis, and swaps the y and x sizes.)
shape = dict(zip(dims, tile_data[0].shape[-len(dims):]))

# physical offset of each tile: grid position times the tile extent
# (in physical units) reduced by the overlap
translations = [
    {dim: grid_indices[itile][dim] * (1 - overlap) * shape[dim] * scale[dim]
     for dim in dims}
    for itile in range(len(tile_data))
]

translations

Using the previously calculated positions we can now preposition all images.

In [9]:
# wrap every tile into a SpatialImage placed at its prepositioned translation
sims = [
    si_utils.get_sim_from_array(
        tile_array,
        dims=['c', 'z', 'y', 'x'],
        scale=scale,
        translation=tile_translation,
        transform_key='manual_prepositioning',
        c_coords=['RFP', 'GFP', 'BFP'],
    )
    for tile_array, tile_translation in zip(tile_data, translations)
]

# scale_factors=[]: no additional resolution levels are requested
msims = [msi_utils.get_msim_from_sim(tile_sim, scale_factors=[]) for tile_sim in sims]

### Visualizing the prepositioned tiles

In [None]:
# visualize the positions of the tiles

from multiview_stitcher import vis_utils

# uncomment the following line for 3D interactivity with the plot (requires ipympl to be installed);
# it is left commented so that the notebook also runs where ipympl is not installed
# %matplotlib widget

fig, ax = vis_utils.plot_positions(
    msims,
    use_positional_colors=True, # set to False for faster execution in case of more than 20 tiles/views
    transform_key='manual_prepositioning'
    )

In [None]:
from napari_stitcher import viewer_utils
import napari

viewer = napari.Viewer(ndisplay=3)

# one image layer per tile (GFP channel), placed by the manual prepositioning
lds = viewer_utils.create_image_layer_tuples_from_msims(
    msims,
    transform_key='manual_prepositioning',
    ch_coord='GFP',
    )
viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

## Registering the tiles

During registration, the tile positions are refined.


In [None]:
from multiview_stitcher import registration
from dask.diagnostics import ProgressBar

# register the tiles on the GFP channel, starting from the manual
# prepositioning; y and x are binned by 2 during registration. The resulting
# transforms are stored in the msims under the 'registered' transform key.
with ProgressBar():
    params = registration.register(
        msims,
        reg_channel='GFP',
        transform_key='manual_prepositioning',
        new_transform_key='registered',
        registration_binning={'z': 1, 'y': 2, 'x': 2},
        plot_summary=True,
    )

### Printing the obtained parameters

How to interpret these?
   - parameters are expressed as homogeneous matrices of dimensions (n+1, n+1), where n is the number of spatial dimensions (here n = 3)
   - the first n rows and columns correspond to the linear part of the transformation
   - the first n elements of the last column correspond to the translation (here in z, y, x)
   - the parameters map the coordinates of the input files (considering their scale and translation) into a world coordinate system in which the tiles are registered


In [None]:
# affine of the first tile in the 'registered' coordinate system;
# [0] takes the first entry along the leading axis (presumably time) — TODO confirm
registered_transform = msi_utils.get_transform_from_msim(msims[0], transform_key='registered')
affine = registered_transform[0]
affine

In [None]:
from multiview_stitcher import param_utils

# translation part of the homogeneous affine obtained above
t = param_utils.translation_from_affine(
    affine,
)
t

### Visualize registration

In [None]:
from napari_stitcher import viewer_utils
import napari

viewer = napari.Viewer(ndisplay=3)

# one image layer per tile (GFP channel), placed by the 'registered' transforms
lds = viewer_utils.create_image_layer_tuples_from_msims(
    msims,
    transform_key='registered',
    ch_coord='GFP',
    )
viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

## Refining the registration

In [None]:
# refine the 'registered' positions with pairwise ANTsPy registration
# (transform types 'Rigid' and 'Affine') and an affine groupwise resolution;
# the results are stored under the 'affine_registered' transform key
with ProgressBar():
    params = registration.register(
        msims,
        reg_channel='GFP',
        transform_key='registered',
        new_transform_key='affine_registered',
        registration_binning={'z': 1, 'y': 2, 'x': 2},
        pairwise_reg_func=registration.registration_ANTsPy,
        pairwise_reg_func_kwargs={'transform_types': ['Rigid', 'Affine']},
        groupwise_resolution_kwargs={'transform': 'affine'},
        plot_summary=True,
        )

### Visualizing the refined results

In [None]:
from napari_stitcher import viewer_utils
import napari

viewer = napari.Viewer(ndisplay=3)

lds = viewer_utils.create_image_layer_tuples_from_msims(
    msims,
    ch_coord='GFP',
    transform_key='affine_registered',
    positional_cmaps=False,
    )
viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

# automatic coloring not working yet for affine transformations:
# instead, color the layers in a checkerboard pattern of their grid positions
# (assumes one layer per tile, in tile order)
checkerboard_cmaps = ['Green', 'Red']
for layer_index, layer in enumerate(viewer.layers):
    grid_position = get_tile_grid_position_from_tile_index(layer_index, (2, 3))
    layer.colormap = checkerboard_cmaps[sum(grid_position.values()) % 2]

## Fusion

During fusion, the tiles are combined into a single image.

In [None]:
from multiview_stitcher import fusion

fused = fusion.fuse(
    [msi_utils.get_sim_from_msim(msim) for msim in msims],
    transform_key='affine_registered',
    output_chunksize=256,
    output_spacing={'z': 10, 'y': 2, 'x': 2},
    )

# this is a SpatialImage object. Only the last expression of a cell is
# rendered automatically, so display it explicitly.
display(fused)

# this is a dask array (nothing has been fused yet; see the next cell)
fused.data

In [None]:
# fuse in memory. The result gets its own name so that `fused` keeps its lazy
# (dask-backed) data, which the zarr cell below needs (da.to_zarr expects a
# dask array, not the numpy array produced by .compute()).

with ProgressBar():
    fused_in_memory = fused.compute()

from matplotlib import pyplot as plt
plt.figure()
# maximum projection along z (axis -3) of the first entry after squeezing
# (presumably the first channel)
plt.imshow(fused_in_memory.data.squeeze()[0].max(-3));

### Stream fusion into zarr

In [None]:
from dask import array as da

# da.to_zarr requires a dask array. If `fused` was computed in memory
# (fused.data is then a numpy array), wrap it into a chunked dask array first;
# a dask-backed `fused` is streamed as is.
fused_dask = fused.data
if not isinstance(fused_dask, da.Array):
    fused_dask = da.from_array(fused_dask, chunks=256)

with ProgressBar():
    fused.data = da.to_zarr(
        fused_dask,
        "./data/Grid_3d/fused.zarr",
        overwrite=True, return_stored=True, compute=True)

### Visualize fusion in napari

In [None]:
from napari_stitcher import viewer_utils
import napari

viewer = napari.Viewer(ndisplay=3)

# wrap the fused SpatialImage into a multiscale image for display
fused_msim = msi_utils.get_msim_from_sim(fused)
lds = viewer_utils.create_image_layer_tuples_from_msim(
    fused_msim,
    transform_key='affine_registered',
    )
viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

### Using a custom fusion function

In [None]:
# taken from multiview_stitcher.fusion
from multiview_stitcher import weights
def weighted_average_fusion(
    transformed_views,
    blending_weights,
    fusion_weights=None,
):
    """
    Fuse views by a weighted average over the first axis.

    Parameters
    ----------
    transformed_views : list of ndarrays
        Transformed input views. NOTE(review): they are combined by
        elementwise multiplication with the weights, so presumably they are
        stacked along axis 0 (or broadcastable to the weights) — confirm.
    blending_weights : list of ndarrays
        Blending weights for each view.
    fusion_weights : list of ndarrays, optional
        Additional view weights for fusion, e.g. contrast weighted scores.
        By default None, in which case only the blending weights are used.

    Returns
    -------
    ndarray
        Fusion of the input views, cast to the dtype of the first view.
    """
    # blending weights alone, or multiplied by the optional fusion weights
    combined_weights = (
        blending_weights if fusion_weights is None
        else blending_weights * fusion_weights
    )
    normalized_weights = weights.normalize_weights(combined_weights)

    # weighted sum over the views (axis 0), ignoring NaNs
    weighted_sum = np.nansum(transformed_views * normalized_weights, axis=0)
    return weighted_sum.astype(transformed_views[0].dtype)


# custom fusion function
def max_fusion(
    transformed_views,
    blending_weights,
    fusion_weights=None,
):
    """
    Maximum intensity fusion.

    Returns the NaN-ignoring maximum over the views (axis 0). The blending
    and fusion weights are accepted only to match the fusion function
    signature and are not used.
    """
    views_maximum = np.nanmax(transformed_views, axis=0)
    return views_maximum

# fuse the same registered tiles again, this time using the custom max fusion
tile_sims = [msi_utils.get_sim_from_msim(tile_msim) for tile_msim in msims]
fused_max = fusion.fuse(
    tile_sims,
    fusion_func=max_fusion,
    transform_key='affine_registered',
    output_spacing={'z': 10, 'y': 2, 'x': 2},
    output_chunksize=256,
    )

In [None]:
# fuse in memory (custom max fusion)
with ProgressBar():
    fused_max = fused_max.compute()

from matplotlib import pyplot as plt

# maximum projection along z (axis -3) of the first entry after squeezing
max_projection = fused_max.data.squeeze()[0].max(-3)
plt.figure()
plt.imshow(max_projection)

### Visualize the difference between the default and custom fusion

In [None]:
# absolute per-pixel difference between weighted-average and max fusion
# (cast to float first so the unsigned subtraction cannot wrap around),
# shown as a maximum projection along z of the first entry after squeezing
fusion_difference = np.abs(fused.data.astype(float) - fused_max.data)
plt.figure()
plt.imshow(fusion_difference.squeeze()[0].max(-3))