## Example workflow to register multi-view light sheet data

This notebook uses the [multiview example dataset](https://drive.google.com/file/d/1VFT2APVPItBCyzrQ7dPWBNILyNh6yDKB/view?usp=sharing) provided in the [BigStitcher-Spark github repo](https://github.com/JaneliaSciComp/BigStitcher-Spark).

Notes:
- registration: this notebook uses translation and affine registration
- fusion: currently linear blending and content-based fusion are supported
- visualization:
  - neuroglancer (no installation required)
  - (optional) install [napari-stitcher](https://multiview-stitcher.github.io/napari-stitcher/main/)

In [None]:
import numpy as np
from pathlib import Path
import dask.diagnostics

from multiview_stitcher import msi_utils, spatial_image_utils

%matplotlib notebook

# the package napari-stitcher is required for visualization with napari
VISUALIZE_USING_NAPARI = False

if VISUALIZE_USING_NAPARI:
    import napari
    from napari_stitcher import viewer_utils

## Specify input data

In [None]:
# Download the example dataset
# https://drive.google.com/file/d/1VFT2APVPItBCyzrQ7dPWBNILyNh6yDKB/view?usp=sharing
# and set the path to the extracted files below

base_dir = Path('../image-datasets/bigstitcher/IP_TIFF')
filenames = list(base_dir.glob('*TL18*.tif'))

if not filenames:
    raise ValueError('No files found. Please download the example dataset and set the correct path.')

# extract the angle from file names such as spim_TL18_Angle135.tif
angles = [int(f.stem.split('Angle')[1]) for f in filenames]

# sort filenames and angles by angle
order = np.argsort(angles)
filenames = [filenames[i] for i in order]
angles = np.array(angles)[order]

print('Files:', [f.name for f in filenames])
print('Extracted angles:', angles)
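The split-based parsing above relies on the angle being the last token of the file name. A slightly more defensive sketch uses a regular expression and fails loudly on unexpected names; `angle_from_filename` and `ANGLE_PATTERN` are hypothetical helpers, not part of multiview-stitcher.

```python
import re
from pathlib import Path

# Hypothetical helper: extract the rotation angle (degrees) from names such as
# 'spim_TL18_Angle135.tif'; raise instead of silently mis-parsing.
ANGLE_PATTERN = re.compile(r"Angle(\d+)")


def angle_from_filename(path):
    match = ANGLE_PATTERN.search(Path(path).name)
    if match is None:
        raise ValueError(f"No 'Angle<number>' token in {path}")
    return int(match.group(1))


# usage: sort file names by angle in a single step
names = ["spim_TL18_Angle135.tif", "spim_TL18_Angle0.tif", "spim_TL18_Angle45.tif"]
sorted_names = sorted(names, key=angle_from_filename)
print(sorted_names)
```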


## Convert input data to OME-Zarr

First, we convert the input views to OME-Zarr format. We pass the voxel spacing stored in the ImageJ metadata explicitly via the `scale` argument (TODO: confirm).

In [None]:
from multiview_stitcher import io, ngff_utils
from multiview_stitcher import spatial_image_utils as si_utils

# read input tiff files into spatial images
sims = [
    io.read_tiff_into_spatial_xarray(
        filename,
        dims=('z', 'y', 'x'),
        scale={'z': 2., 'y': 0.7310781, 'x': 0.7310781})
    for filename in filenames
]

# write each view to OME-Zarr next to its TIFF file
for sim, filename in zip(sims, filenames):
    sim = ngff_utils.write_sim_to_ome_zarr(
        sim,
        filename.with_suffix('.ome.zarr'),
        overwrite=True
    )
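The `scale` argument attaches physical units to each axis: a voxel index i along an axis maps to origin + i * spacing. The sketch below makes that mapping explicit with NumPy; the view shape and the zero origin are hypothetical, and we assume the spacing is in micrometers.

```python
import numpy as np

# spacing per voxel as passed to read_tiff_into_spatial_xarray above (assumed µm)
spacing = {"z": 2.0, "y": 0.7310781, "x": 0.7310781}
shape = {"z": 100, "y": 1000, "x": 1000}  # hypothetical view shape
origin = {"z": 0.0, "y": 0.0, "x": 0.0}  # assumed: no translation offset

# physical coordinate of voxel i along each axis: origin + i * spacing
coords = {dim: origin[dim] + np.arange(shape[dim]) * spacing[dim] for dim in shape}

# physical distance spanned by the voxel centers along each axis
extent = {dim: coords[dim][-1] - coords[dim][0] for dim in shape}

# anisotropy: how much coarser the z sampling is than the in-plane sampling
anisotropy = spacing["z"] / spacing["x"]
print(extent, round(anisotropy, 3))
```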

## Load the input data from OME-Zarr

`ngff_utils.read_sim_from_ome_zarr` only references the data on disk, so nothing is loaded into memory at this point; pixel data are read when a computation needs them.

### Registration

For registration, we load the spatial images at resolution level 1, i.e. the first downsampled level below the full-resolution level 0.
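As a rough sketch of what this choice buys, the helpers below compute the voxel spacing at a given pyramid level and after the `registration_binning` used later in this notebook. The factor of 2 per level and per axis is an assumption for illustration (the actual factors are set when the pyramid is written), as is the idea that binning acts as plain additional downsampling.

```python
# Hypothetical downsampling factors between consecutive pyramid levels; the real
# factors are decided when the OME-Zarr pyramid is written and may differ per axis.
LEVEL_FACTORS = {"z": 2, "y": 2, "x": 2}

spacing_level0 = {"z": 2.0, "y": 0.7310781, "x": 0.7310781}  # from the conversion step


def spacing_at_level(spacing0, level, factors=LEVEL_FACTORS):
    """Voxel spacing at a pyramid level, assuming a constant factor per level."""
    return {dim: spacing0[dim] * factors[dim] ** level for dim in spacing0}


def effective_spacing(spacing, binning):
    """Spacing after extra binning, assuming binning acts as plain downsampling."""
    return {dim: spacing[dim] * binning[dim] for dim in spacing}


spacing_level1 = spacing_at_level(spacing_level0, 1)
reg_spacing = effective_spacing(spacing_level1, {"z": 1, "y": 4, "x": 4})

# factor by which the number of voxels per view shrinks from level 0 to level 1
voxel_reduction = 1
for dim in LEVEL_FACTORS:
    voxel_reduction *= LEVEL_FACTORS[dim]

print(spacing_level1, reg_spacing, voxel_reduction)
```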

In [None]:
# load spatial images for registration
sims_reg = [
    ngff_utils.read_sim_from_ome_zarr(
        filename.with_suffix('.ome.zarr'),
        resolution_level=1,
    ) for filename in filenames]

# The next two lines restrict registration to a region of each view
# (x coordinates between 800 and 1100) in which only the beads are visible.
# Comment them out to register using the full field of view.
for iview, sim in enumerate(sims_reg):
    sims_reg[iview] = si_utils.sim_sel_coords(sim, {'x': slice(800, 1100)})

# convert spatial images to multiscale spatial images which the registration step expects
msims_reg = [msi_utils.get_msim_from_sim(
    sim, scale_factors=None)
    for sim in sims_reg]
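Our reading of `si_utils.sim_sel_coords` (from its name; we have not checked its implementation) is that `{'x': slice(800, 1100)}` selects by coordinate values rather than voxel indices. The sketch below shows which voxel indices such an inclusive physical window corresponds to, under hypothetical assumptions: zero origin, 1000 voxels along x, and a level-1 spacing twice the level-0 spacing. `index_range_for_window` is an illustrative helper.

```python
import numpy as np


def index_range_for_window(n, spacing, origin, start, stop):
    """First and last voxel index whose physical coordinate lies in [start, stop].

    Mirrors an inclusive label-based slice such as xarray's .sel(x=slice(start, stop))
    on an increasing coordinate. Returns None if the window misses the image.
    """
    coords = origin + np.arange(n) * spacing
    inside = np.nonzero((coords >= start) & (coords <= stop))[0]
    if inside.size == 0:
        return None
    return int(inside[0]), int(inside[-1])


# hypothetical level-1 view: 1000 voxels along x, assumed 2x coarser than level 0
first, last = index_range_for_window(1000, 2 * 0.7310781, 0.0, 800.0, 1100.0)
print(first, last, last - first + 1)
```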

### Fusion

For fusion, we load the full-resolution data (resolution level 0); these are the images combined in the final fusion step.

In [None]:
sims_fus = [
    ngff_utils.read_sim_from_ome_zarr(
        filename.with_suffix('.ome.zarr'),
        resolution_level=0,
    ) for filename in filenames]

msims_fus = [msi_utils.get_msim_from_sim(
    sim, scale_factors=None)
    for sim in sims_fus]

## Set estimate of initial transformations

We set initial transformations estimated from the angles in the file names, assuming each angle describes a rotation about the x axis through the center of the view.

These transformations will be refined during the registration process.
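To make the initial transformation concrete: a rotation about an axis through a point p is the composition T(p) R T(-p) of homogeneous 4x4 matrices. The NumPy sketch below builds it for an axis parallel to x, with coordinates in zyx order. It is not `param_utils.affine_from_rotation`, whose sign convention may differ, and `rotation_about_x_through_point` is a hypothetical name.

```python
import numpy as np


def rotation_about_x_through_point(angle_deg, point_zyx):
    """4x4 homogeneous matrix (zyx order) rotating by angle_deg about an axis
    parallel to x that passes through point_zyx (right-handed about +x)."""
    theta = np.deg2rad(angle_deg)
    c, s = np.cos(theta), np.sin(theta)
    # a rotation about x acts on the (z, y) plane and leaves x unchanged
    rot = np.array([
        [c, s, 0.0, 0.0],
        [-s, c, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ])
    p = np.asarray(point_zyx, dtype=float)
    to_origin = np.eye(4)
    to_origin[:3, 3] = -p
    back = np.eye(4)
    back[:3, 3] = p
    # translate the center to the origin, rotate, translate back
    return back @ rot @ to_origin


center = [100.0, 365.0, 365.0]  # hypothetical view center in zyx
A = rotation_about_x_through_point(90, center)
print(A @ np.array([100.0, 375.0, 365.0, 1.0]))  # a point 10 units from the center along y
```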

In [None]:
from multiview_stitcher import param_utils

for iview in range(len(angles)):

    # define a rotation about the x axis (direction [0, 0, 1] in zyx order)
    # passing through the center of the view
    affine = param_utils.affine_from_rotation(
        angles[iview] / 180 * np.pi,
        point=spatial_image_utils.get_center_of_sim(msims_reg[iview]['scale0/image'], transform_key=None),
        direction=[0, 0, 1],
    )

    # set the transformation on the images used for registration; the fusion
    # images receive the refined transformations after registration
    msi_utils.set_affine_transform(
        msims_reg[iview],
        affine[None],  # add a leading axis for the single time point
        transform_key='affine_manual',
    )


### Visualize pre-registered views

In [None]:
if VISUALIZE_USING_NAPARI:

    viewer = napari.Viewer(ndisplay=3)
    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims_reg, transform_key='affine_manual', n_colors=4, contrast_limits=[0, 10])
    viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)
    viewer.reset_view()

### Register views

First, we apply a translation registration to the (pre-rotated) views to obtain a coarse alignment.
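The code comment below labels this step phase shift registration. As an illustration of the principle only, the sketch estimates an integer translation between two volumes by phase correlation with NumPy FFTs. It assumes periodic boundaries and full overlap; the multiview-stitcher implementation presumably also handles partial overlap, subpixel shifts and physical units, which this sketch omits.

```python
import numpy as np


def phase_correlation_shift(fixed, moving):
    """Estimate the integer shift s such that fixed ≈ np.roll(moving, s)."""
    F = np.fft.fftn(fixed)
    M = np.fft.fftn(moving)
    cross_power = F * np.conj(M)
    cross_power /= np.abs(cross_power) + 1e-12  # keep only the phase information
    correlation = np.abs(np.fft.ifftn(cross_power))
    peak = np.unravel_index(np.argmax(correlation), correlation.shape)
    # peaks in the upper half of an axis correspond to negative shifts
    shift = [p - n if p > n // 2 else p for p, n in zip(peak, correlation.shape)]
    return tuple(int(s) for s in shift)


rng = np.random.default_rng(0)
moving = rng.random((32, 32, 32))
fixed = np.roll(moving, shift=(3, -5, 7), axis=(0, 1, 2))
print(phase_correlation_shift(fixed, moving))  # (3, -5, 7)
```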

In [None]:
from multiview_stitcher import registration

with dask.diagnostics.ProgressBar():

    # translation registration (by default, pairwise phase correlation)
    params = registration.register(
        msims_reg,
        # registration_binning={'z': 2, 'y': 8, 'x': 8},
        registration_binning={'z': 1, 'y': 4, 'x': 4},
        reg_channel_index=0,
        transform_key='affine_manual',
        new_transform_key='translation_registered',
        pre_registration_pruning_method=None,
        groupwise_resolution_kwargs={
            'transform': 'translation',
        },
        n_parallel_pairwise_regs=None, # limit this to e.g. 2 to reduce memory usage
    )

    # # alternatively registration using ANTsPy
    # params = registration.register(
    #     msims_reg,
    #     registration_binning={'z': 2, 'y': 8, 'x': 8},
    #     reg_channel_index=0,
    #     transform_key='affine_manual',
    #     new_transform_key='translation_registered',
    #     pre_registration_pruning_method=None,
    #     pairwise_reg_func=registration.registration_ANTsPy,
    #     pairwise_reg_func_kwargs={
    #         'transform_types': ['Translation'],
    #     },
    #     groupwise_resolution_kwargs={
    #         'transform': 'translation',
    #     }
    # )
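The pairwise results are then combined into one transformation per view (`groupwise_resolution_kwargs`). A common way to do this for translations, shown here as an assumption rather than as multiview-stitcher's actual algorithm, is a linear least-squares fit in which each pair contributes the equation x_j - x_i = t_ij and one reference view is pinned to the origin. `resolve_translations` is a hypothetical helper.

```python
import numpy as np


def resolve_translations(n_views, pairwise, reference=0):
    """Least-squares global translations from pairwise translation estimates.

    pairwise maps (i, j) -> t_ij, the measured position of view j relative to
    view i (t_ij ~ x_j - x_i). The reference view is pinned to the origin.
    """
    pairs = list(pairwise)
    ndim = len(next(iter(pairwise.values())))
    # one equation x_j - x_i = t_ij per pair, plus x_reference = 0
    A = np.zeros((len(pairs) + 1, n_views))
    for row, (i, j) in enumerate(pairs):
        A[row, i] = -1.0
        A[row, j] = 1.0
    A[-1, reference] = 1.0
    b = np.vstack([np.asarray(pairwise[p], dtype=float) for p in pairs] + [np.zeros(ndim)])
    # each column of b (one spatial axis) is solved independently
    positions, *_ = np.linalg.lstsq(A, b, rcond=None)
    return positions


# pairwise estimates for three views (zyx); the (0, 2) estimate disagrees by 2
# in y with the path 0 -> 1 -> 2, and least squares spreads that error
pairwise = {
    (0, 1): [0.0, 10.0, 0.0],
    (1, 2): [0.0, 0.0, 20.0],
    (0, 2): [0.0, 12.0, 20.0],
}
print(resolve_translations(3, pairwise))
```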
    

### Visualize registration (translation)

In [None]:
if VISUALIZE_USING_NAPARI:
    
    viewer = napari.Viewer(ndisplay=3)

    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims_reg,
        transform_key='translation_registered',
        n_colors=100,
        name_prefix='registered view',
        contrast_limits=[0, 10],
        positional_cmaps=True,
        )
    
    rlayers = viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds, do_link_layers=False)

## Affine registration

In this step, we refine the transformations with an intensity-based registration using ANTsPy (rigid followed by affine pairwise registration).
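An affine model adds scaling and shear on top of a rigid rotation and translation. One way to check how far a resulting transform departs from rigid is to look at the singular values of its 3x3 linear part, as in the sketch below. It assumes a 4x4 homogeneous matrix; the example matrices are made up and `linear_part_singular_values` is a hypothetical helper.

```python
import numpy as np


def linear_part_singular_values(affine):
    """Singular values of the 3x3 linear part of a 4x4 homogeneous affine matrix.

    All close to 1: the transform is (nearly) rigid. Values away from 1 indicate
    the scaling or shear that an affine model can capture but a rigid one cannot.
    """
    linear = np.asarray(affine, dtype=float)[:3, :3]
    return np.linalg.svd(linear, compute_uv=False)


theta = np.deg2rad(30.0)
rigid = np.eye(4)
# rotation in the y-x plane plus a translation
rigid[1:3, 1:3] = [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
rigid[:3, 3] = [5.0, -2.0, 1.0]

# the same transform with a small anisotropic scaling folded in
affine = rigid.copy()
affine[:3, :3] = rigid[:3, :3] @ np.diag([1.0, 1.02, 0.97])

print(linear_part_singular_values(rigid))
print(linear_part_singular_values(affine))
```

For a registered view, the input could be `msi_utils.get_transform_from_msim(msims_reg[0], 'affine_registered').sel(t=0).values`, assuming that yields the 4x4 matrix.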

In [None]:
input_transform_key = 'translation_registered'
output_transform_key = 'affine_registered'

# number of registration passes; each pass starts from the transforms of the previous one
N_iter = 1
for iteration in range(N_iter):

    if iteration == N_iter - 1:
        iter_output_transform_key = output_transform_key
    else:
        iter_output_transform_key = f'{output_transform_key}_iter{iteration}'

    # pairwise rigid + affine registration with ANTsPy, resolved groupwise into affine transforms
    params = registration.register(
        msims_reg,
        # registration_binning={'z': 2, 'y': 8, 'x': 8},
        registration_binning={'z': 1, 'y': 4, 'x': 4},
        reg_channel_index=0,
        transform_key=input_transform_key,
        new_transform_key=iter_output_transform_key,
        pre_registration_pruning_method=None,
        pairwise_reg_func=registration.registration_ANTsPy,
        pairwise_reg_func_kwargs={
            # 'transform_types': ['Rigid'],
            'transform_types': ['Rigid', 'Affine'],
        },
        groupwise_resolution_kwargs={
            # 'transform': 'rigid',
            'transform': 'affine',
        },
        n_parallel_pairwise_regs=None, # limit this to e.g. 2 to reduce memory usage
    )

    input_transform_key = iter_output_transform_key
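With `N_iter > 1`, each pass reads the transform written by the previous one. The key bookkeeping can be replayed without running any registration; `iteration_transform_keys` is a hypothetical helper that mirrors the loop above.

```python
def iteration_transform_keys(input_key, output_key, n_iter):
    """(input, output) transform keys used in each pass of the registration loop."""
    keys = []
    for iteration in range(n_iter):
        # only the last pass writes to the final output key
        if iteration == n_iter - 1:
            iter_output = output_key
        else:
            iter_output = f"{output_key}_iter{iteration}"
        keys.append((input_key, iter_output))
        # the next pass starts from this pass's result
        input_key = iter_output
    return keys


for pair in iteration_transform_keys("translation_registered", "affine_registered", 3):
    print(pair)
```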

### Visualize registration (affine)

In [None]:
if VISUALIZE_USING_NAPARI:
    
    viewer = napari.Viewer(ndisplay=3)

    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims_reg,
        transform_key='affine_registered',
        n_colors=100,
        name_prefix='registered view',
        contrast_limits=[0, 10],
        positional_cmaps=True,
        )
    
    rlayers = viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds, do_link_layers=False)

## Set the obtained transformations on the images to be fused

In [None]:
fusion_transform_key = 'affine_registered'

for iview, msim in enumerate(msims_reg):
    p = msi_utils.get_transform_from_msim(msim, fusion_transform_key).sel(t=0)
    si_utils.set_sim_affine(sims_fus[iview], p, fusion_transform_key)
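Reusing transformations estimated at resolution level 1 for the level-0 images works because, as we understand multiview-stitcher's data model, the affines act on physical coordinates rather than voxel indices. The sketch below illustrates this with a made-up shift, an assumed 2x downsampling per axis and zero origins at both levels; `voxel_to_physical` is a hypothetical helper.

```python
import numpy as np


def voxel_to_physical(spacing_zyx, origin_zyx=(0.0, 0.0, 0.0)):
    """4x4 matrix mapping homogeneous voxel indices (zyx) to physical coordinates."""
    m = np.diag(list(spacing_zyx) + [1.0])
    m[:3, 3] = origin_zyx
    return m


# hypothetical physical-space affine: a pure shift
affine = np.eye(4)
affine[:3, 3] = [1.0, 2.0, 3.0]

spacing_level0 = (2.0, 0.7310781, 0.7310781)
spacing_level1 = (4.0, 1.4621562, 1.4621562)  # assumed 2x downsampling, zero origin
S0, S1 = voxel_to_physical(spacing_level0), voxel_to_physical(spacing_level1)

# one physical point, addressed by its (fractional) voxel index at each level
point = np.array([40.0, 200.0, 300.0, 1.0])
idx0 = np.linalg.solve(S0, point)
idx1 = np.linalg.solve(S1, point)

# the physical-space affine sends the point to the same place from either level
out0 = affine @ (S0 @ idx0)
out1 = affine @ (S1 @ idx1)

# whereas the equivalent voxel-space matrices differ between levels
vox0 = np.linalg.inv(S0) @ affine @ S0
vox1 = np.linalg.inv(S1) @ affine @ S1
print(out0[:3], np.allclose(vox0, vox1))
```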

## Fuse views (linear blending)
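Linear blending weights each view by how far a position lies from the view's border, ramping from 0 at the edge to 1 over the blending width, and normalizes the weights so they sum to 1 in overlaps. The 1D NumPy sketch below illustrates that idea with two hypothetical overlapping views; the exact weight profile used by `fusion.fuse` may differ (for example a smoother ramp), and `linear_ramp_weights` is an illustrative name.

```python
import numpy as np


def linear_ramp_weights(coords, start, stop, blending_width):
    """Weights of a view covering [start, stop] along one axis: 0 at (and beyond)
    the view's edges, rising linearly to 1 at blending_width inside the view."""
    dist_to_edge = np.minimum(coords - start, stop - coords)
    return np.clip(dist_to_edge / blending_width, 0.0, 1.0)


coords = np.linspace(0.0, 300.0, 301)  # physical positions along one axis
# two hypothetical views overlapping between 100 and 200
w_a = linear_ramp_weights(coords, 0.0, 200.0, 50.0)
w_b = linear_ramp_weights(coords, 100.0, 300.0, 50.0)

# normalize so that the weights sum to 1 wherever any view contributes
total = w_a + w_b
norm_a = np.divide(w_a, total, out=np.zeros_like(w_a), where=total > 0)
norm_b = np.divide(w_b, total, out=np.zeros_like(w_b), where=total > 0)

# fused intensity = weighted average of the (already transformed) views
intensity_a = np.full_like(coords, 7.0)
intensity_b = np.full_like(coords, 7.0)
fused_line = norm_a * intensity_a + norm_b * intensity_b
print(norm_a[[50, 150, 175]])
```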

In [None]:
from multiview_stitcher import fusion, weights, misc_utils

output_ome_zarr_path = "multiview_fused.ome.zarr"

# define the fused output and write it to OME-Zarr;
# 'fused' is a spatial image (xarray.DataArray) backed by a dask array
fused = fusion.fuse(
    sims=sims_fus,
    transform_key=fusion_transform_key,
    output_spacing={dim: 2 for dim in ['z', 'y', 'x']},
    blending_widths={dim: 50 for dim in ['z', 'y', 'x']},
    output_chunksize=256,
    # uncomment the next two lines for content-based fusion
    # weights_func=weights.content_based,
    # weights_func_kwargs={"sigma_1": 5, "sigma_2": 11},
    output_zarr_url=output_ome_zarr_path,
    zarr_options={
        "ome_zarr": True,
        "overwrite": True,
    },
    batch_options={
        # process batches of chunks using ray (requires the ray package)
        "batch_func": misc_utils.process_batch_using_ray,
    },
)

print('Summary of the fused image:')
fused
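The commented-out `weights.content_based` option favors, per position, the view with more local image detail. In the content-based fusion of Preibisch et al. (2008), the weight is `G_sigma2 * (I - G_sigma1 * I)**2`, where `G_sigma *` denotes Gaussian smoothing; we assume, without having checked, that `sigma_1` and `sigma_2` play these roles here. The NumPy-only sketch below computes such weights for a single 2D image with hypothetical helper names; in a fusion they would be normalized across views as in linear blending.

```python
import numpy as np


def gaussian_kernel_1d(sigma):
    """Normalized 1D Gaussian kernel truncated at 3 sigma."""
    radius = int(np.ceil(3 * sigma))
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def gaussian_blur(image, sigma):
    """Separable Gaussian blur with zero padding at the borders.
    Assumes every axis is at least as long as the kernel (2 * ceil(3 * sigma) + 1)."""
    out = np.asarray(image, dtype=float)
    kernel = gaussian_kernel_1d(sigma)
    for axis in range(out.ndim):
        out = np.apply_along_axis(np.convolve, axis, out, kernel, mode="same")
    return out


def content_weights(image, sigma_1, sigma_2):
    """Local high-frequency energy: square the difference between the image and
    its sigma_1-smoothed version, then smooth that with sigma_2."""
    image = np.asarray(image, dtype=float)
    detail = image - gaussian_blur(image, sigma_1)
    return gaussian_blur(detail ** 2, sigma_2)


rng = np.random.default_rng(1)
image = np.tile(np.linspace(0.0, 1.0, 100), (100, 1))  # smooth ramp along x
image[:, 50:] += rng.normal(0.0, 0.5, size=(100, 50))  # textured right half
w = content_weights(image, sigma_1=2, sigma_2=4)

# compare interior regions away from the image borders and the texture boundary
smooth_mean = w[20:80, 15:35].mean()
textured_mean = w[20:80, 65:85].mean()
print(smooth_mean, textured_mean)
```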

### Visualize fusion in napari

If the data is too large to visualize in napari, use neuroglancer in the cell below instead.

In [None]:
if VISUALIZE_USING_NAPARI:

    viewer = napari.Viewer(ndisplay=3)

    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims_reg, transform_key=fusion_transform_key, n_colors=4,
        name_prefix='registered view',
        contrast_limits=[0, 200]
        )

    rlayers = viewer_utils.add_image_layer_tuples_to_viewer(
        viewer, lds, do_link_layers=False)

    lds = viewer_utils.create_image_layer_tuples_from_msim(
        msi_utils.get_msim_from_sim(fused),
        transform_key=fusion_transform_key,
        name_prefix='fused',
        contrast_limits=[0, 200])

    viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

## Visualize using neuroglancer

In [None]:
# interrupt the notebook cell to stop the viewer

from multiview_stitcher import vis_utils

# visualize only the fusion result
vis_utils.view_neuroglancer(
    sims=[fused],
    ome_zarr_paths=[output_ome_zarr_path],
    channel_coord=0,
    transform_key=fusion_transform_key,
)

# # or visualize the registered views and the fusion result
# vis_utils.view_neuroglancer(
#     sims=sims_fus + [fused],
#     ome_zarr_paths=[str(filename.with_suffix('.ome.zarr')) for filename in filenames] + [output_ome_zarr_path],
#     channel_coord=0,
#     transform_key="affine_registered",
# )