# mesoSPIM example workflow

## Intro

This notebook stitches an example mesoSPIM dataset comprised of:
- four tiles (2x2 grid)
- two channels
- two arms (illuminations from opposite sides)

The data is available via Globus [here](https://app.globus.org/file-manager?origin_id=b91c61e8-5611-4970-bf4c-9b0d392e5c3c&origin_path=%2FReussMouseBrain-2x2Tiles-2Ch-2Arms%2F). The dataset was shared by Nikita Vladimirov (see image.sc thread [here](https://forum.image.sc/t/are-there-publicly-available-mesospim-datasets-unprocessed/116020/14)).

## Setup

1. Download the dataset to your local machine.
1. Set up Python in your preferred way (e.g. conda, venv, pipenv, poetry, etc.)
1. Make sure `multiview-stitcher >= 0.1.35` is installed: `pip install "multiview-stitcher>=0.1.35"`
1. The dataset is in BigStitcher OME-Zarr format. For reading the associated BigStitcher XML file, we use the `pydantic-bigstitcher` package. Install it via pip: `pip install pydantic-bigstitcher`
1. Optionally, install `ray` for parallelising fusion on top of dask: `pip install "ray[default]"`

## Indicate path to BigStitcher XML file

Note: Here we're reading the positional / transform metadata from a predefined metadata file. Check out the other multiview-stitcher example notebooks to see how this can be set manually.

In [None]:
# NOTE: machine-specific path — replace it with the location of the
# BigStitcher XML file of the downloaded dataset on your system
xml_path = "/Users/albertm/software/multiview-stitcher/image-datasets/mesospim/OME-ZARR/2x2-tiles_ome-zarr.xml"

In [None]:
from pathlib import Path
import os
import dask.array as da
import numpy as np

import pydantic_bigstitcher as pbs

from multiview_stitcher import spatial_image_utils as si_utils
from multiview_stitcher import (
    registration,
    fusion,
    param_utils,
    msi_utils,
    misc_utils,
    vis_utils,
    ngff_utils,
)

# interactive visualization in jupyter notebooks
%matplotlib ipympl

# The mesoSPIM example data contains a "discrete" parameter in its axes metadata.
# This parameter is not part of the NGFF spec and causes ngff-zarr to throw an error,
# so we monkey patch `ngff_zarr.Axis.__init__` to drop it before the original
# initializer runs.
import functools

import ngff_zarr


def _make_axis_init_ignoring_discrete(wrapped_init):
    """Return a wrapper around `wrapped_init` that discards a "discrete" keyword.

    The initializer to call is captured in a closure instead of being looked
    up as a global at call time. A global lookup would resolve to the patched
    initializer itself once this cell is executed a second time, which ends in
    a RecursionError as soon as an `Axis` is created.
    """

    @functools.wraps(wrapped_init)
    def _init_ignoring_discrete(self, *args, **kwargs):
        kwargs.pop("discrete", None)  # ignore the unexpected "discrete" kwarg
        return wrapped_init(self, *args, **kwargs)

    # marker used below to recognise an already patched initializer
    _init_ignoring_discrete._ignores_discrete = True
    return _init_ignoring_discrete


# Patch only once so that re-running this cell is a no-op
if not getattr(ngff_zarr.Axis.__init__, "_ignores_discrete", False):
    _original_init = ngff_zarr.Axis.__init__
    ngff_zarr.Axis.__init__ = _make_axis_init_ignoring_discrete(_original_init)

In [None]:
# The following functions help
# - reading the transform metadata from BigStitcher XML files and
# - converting them to the xparams format used in multiview-stitcher

def transform_to_xparams(transform):
    """Convert a parsed BigStitcher transform into xparams.

    The entries of `transform.affine` are read in row-major order over the
    axes z, y, x, followed by the z, y, x entries of `transform.translation`.
    This flat parameter vector is turned into an affine matrix and then into
    its labelled (xarray) representation using `param_utils`.
    """
    axes = ('z', 'y', 'x')
    linear_part = transform.affine
    shift_part = transform.translation

    flat_params = [linear_part[row][col] for row in axes for col in axes]
    flat_params += [shift_part[axis] for axis in axes]

    return param_utils.affine_to_xaffine(
        param_utils.affine_from_linear_affine(np.array(flat_params))
    )

def view_transforms_to_xparams(view_transforms, n_transforms_to_consider=None):
    """Compose the leading view transforms into a single xparams object.

    Only the first `n_transforms_to_consider` entries of `view_transforms`
    are used (all of them if it is None). Each entry is converted with
    `transform_to_xparams` and the results are multiplied from left to right
    using `param_utils.matmul_xparams`.
    """
    selected = view_transforms[:n_transforms_to_consider]

    per_transform = []
    for view_transform in selected:
        per_transform.append(
            transform_to_xparams(view_transform.to_transform().transform)
        )

    composed = per_transform[0]
    for following in per_transform[1:]:
        composed = param_utils.matmul_xparams(composed, following)

    return composed

## Read the data and define the initial tile configuration

Note that no actual pixel data is loaded yet. This only reads the metadata and sets up the initial tile configuration.

Mini glossary:
- `msim`: Multi-scale image: xr.DataTree containing multiple xr.DataArray at different resolution levels
- `xparams`: affine parameters represented as an `xarray.DataArray`, which allows labeling the axes

In [None]:
import xarray as xr

xml_path = Path(xml_path)  # defined above

# Use pydantic-bigstitcher to parse the XML file.
# `Path.read_text` opens, reads and closes the file, so no file handle is
# left open (a bare `open(...).read()` never closes it explicitly).
sd = pbs.SpimData2.from_xml(xml_path.read_text())

# zarr groups listed in the image loader section of the XML (one per view)
zgroups = sd.sequence_description.image_loader.zgroups.elements

# list of per-view records; it is converted into a pandas DataFrame further below
df = []
# for each zarr group (i.e. each view), read the metadata and create a spatial image
for izg, zgroup in enumerate(zgroups[:]):
    
    # read basic metadata from XML file
    filepath = xml_path.parent / sd.sequence_description.image_loader.zarr.path / zgroup.path
    # NOTE(review): view setups and view registrations are indexed with the position
    # of the zarr group, i.e. all three lists are assumed to share the same order — confirm
    view_setup = sd.sequence_description.view_setups.elements[izg]
    ch = view_setup.attributes.channel
    ill = view_setup.attributes.illumination
    tile = view_setup.attributes.tile
    angle = view_setup.attributes.angle
    
    # the voxel size is a space-separated string; its values are assigned to x, y, z in this order
    spacing = {dim: float(v)
               for dim, v in zip(["x", "y", "z"], view_setup.voxel_size.size.split(" "))}

    # only the first view transform: stored under the 'ome-zarr' transform key below
    xparams_ome_zarr = view_transforms_to_xparams(
        sd.view_registrations.elements[izg].view_transforms, n_transforms_to_consider=1)

    # In this dataset the second transform in the XML file defines the anisotropic pixel spacing,
    # which is not contained in the OME-Zarr metadata.
    # We need this only for correct visualization in neuroglancer (which reads the spacing
    # from OME-Zarr metadata)
    # all view transforms composed: stored under the 'xml' transform key below
    xparams_xml = view_transforms_to_xparams(
        sd.view_registrations.elements[izg].view_transforms, n_transforms_to_consider=None)

    # the params appear to be in pixel units: scale the entries labelled "1"
    # (presumably the translation column — TODO confirm) by the voxel spacing
    for idim, dim in enumerate(['z', 'y', 'x']):
        xparams_ome_zarr.loc[dim, "1"] *= spacing[dim]
        xparams_xml.loc[dim, "1"] *= spacing[dim]

    # build the multi-scale image from the OME-Zarr group of this view
    msim = ngff_utils.ngff_multiscales_to_msim(
        ngff_zarr.from_ngff_zarr(filepath),
        transform_key='ome-zarr',
        )
    
    # For every dataset that has data variables: multiply the z, y, x coordinates of the
    # image by the voxel spacing and set its 'c' coordinate to the channel of this view.
    # All other data variables are carried over unchanged; empty datasets are left as they are.
    msim = msim.map_over_datasets(lambda ds: xr.Dataset(
        {'image': ds.image.assign_coords(
                    {dim: ds.image.coords[dim] * spacing[dim]
                            for dim in ['z', 'y', 'x']} | \
                    {'c': [ch]}
                    )} | \
        {t: ds.data_vars[t] for t in ds.data_vars if t != 'image'})#.assign_coords({'v': izg})
        if len(ds.data_vars) > 0 else ds)
    
    # attach both transforms to the msim under their respective keys
    msi_utils.set_affine_transform(msim, xparams_ome_zarr, transform_key='ome-zarr')
    msi_utils.set_affine_transform(msim, xparams_xml, transform_key='xml')

    # one record per view; the attribute values are stored as integers
    df.append(
        {
            'msim': msim,
            'ch': int(ch),
            'ill': int(ill),
            'tile': int(tile),
            'angle': int(angle),
            'filepath': filepath,
        }
    )


## Create a pandas DataFrame describing the referenced image stacks

Having the data in a pandas DataFrame makes it easy to filter and select subsets (e.g. channels, tiles, illuminations) of the data.

**Note**: Typically in multiview-stitcher workflows we combine the channels of a given tile within the same "spatial-image" or "multiscale-spatial-image". However, here we're going to keep the channels separate because this is more convenient when using neuroglancer for visualization (since neuroglancer reads the data directly from the files). An alternative would have been to resave the data as OME-Zarr with combined channels.

In [None]:
import pandas as pd
# turn the list of per-view records into a DataFrame
# (columns: msim, ch, ill, tile, angle, filepath)
df = pd.DataFrame(df)
df  # last expression of the cell: displays the table

## Visualize the tile configuration

In [None]:
# Plot the view configuration described by the 'xml' transforms

# Keep only channel 0 and illumination 0, otherwise the plot gets too crowded
df_vis = df.loc[(df['ch'] == 0) & (df['ill'] == 0)]

vis_utils.plot_positions(df_vis['msim'].tolist(), transform_key='xml')

## Examine the multi-scale image data

Let's look at the data sizes for the different scales. This will help us to decide which scale to use for illumination selection and registration.

In [None]:
# print data sizes for the different scales
print("Data sizes for different scales (first tile):")
# msim stored under index label 0 of the DataFrame
msim = df['msim'][0]
print('Dimensions:', msim['scale0/image'].dims)
# iterate over the keys of the msim (the resolution levels, e.g. 'scale0')
for scale in msim:
    print(f"Scale {scale}: {msim[scale].image.shape}")

## Visualize the input data using neuroglancer

A browser window should open. When finished viewing, the Jupyter Notebook kernel needs to be interrupted. This is because the cell below starts a web server that serves the data to neuroglancer.

In [None]:
# Choose which illumination and channel to visualize
# It's possible to view all of these, but the neuroglancer view will become crowded.
# This neuroglancer interaction is to be improved in the future
# (and is already better when working with OME-Zarr files that contain all channels)
df_vis = df[(df.ch == 1) & (df.ill == 0)]

# NOTE(review): reloading vis_utils looks like a development leftover —
# confirm whether it is still needed
import importlib
importlib.reload(vis_utils)

# the viewer gets the paths of the OME-Zarr files together with the matching
# images, from which the transforms stored under `transform_key` are taken
vis_utils.view_neuroglancer(
    ome_zarr_paths=[str(fp) for fp in df_vis['filepath'].tolist()],
    sims=[msi_utils.get_sim_from_msim(msim) for msim in df_vis['msim']],
    transform_key='xml',
    # transform_key='ome-zarr',
    contrast_limits=(0, 500),
    single_layer=True,
)

## Illumination selection

Here we decide on which illumination to keep by calculating gradient magnitude to find the most informative views (inspired by BigStitcher).

In [None]:
# define a function that can be applied to each group of tile, angle, ch
def select_illumination(rows, scale):
    """Return the row whose image has the largest mean gradient magnitude.

    Parameters
    ----------
    rows : pandas.DataFrame
        Rows of one group; the 'msim' column holds the candidate images.
    scale : str
        Key of the resolution level on which the score is computed
        (e.g. 'scale3').

    Returns
    -------
    The row of `rows`, selected by position, with the highest score.
    """

    def mean_gradient_magnitude(msim):
        # norm over the per-axis gradients, averaged over the whole image
        volume = msim[scale].image.data.squeeze()
        return np.mean(np.linalg.norm(np.gradient(volume), axis=0))

    scores = [mean_gradient_magnitude(msim) for msim in rows['msim']]

    # evaluate any lazy values, then pick the position of the highest score
    (scores,) = da.compute(scores)
    best_position = np.argmax(scores)

    print(f"Gradient magnitudes: {[float(gm) for gm in scores]}")
    print(f"Selecting illumination {best_position}")

    return rows.iloc[best_position]

# apply the function to each group of tile, angle, ch
# (only the 'msim' and 'ill' columns are passed to the function)
dfi = df.groupby(['tile', 'angle', 'ch'])[['msim', 'ill']].apply(
    select_illumination,
    scale='scale3' # resolution level on which to calculate gradient magnitudes
).reset_index()

# merge with original dataframe to recover all columns
# ('msim' exists on both sides: the copy coming from `dfi` gets the suffix '_',
# the one coming from `df` keeps its name)
dfi = dfi.merge(df, on=['tile', 'angle', 'ch', 'ill'], suffixes=('_', ''))

dfi

## Registration

Here, we'll use phase correlation based registration to register the tiles.

Note that this step (currently) requires that at least the overlap of two neighboring tiles fits into memory.

How to achieve this?
- choose a suitable scale to register on
- advanced: Define a [custom registration function](https://multiview-stitcher.github.io/multiview-stitcher/main/extension_api_pairwise_registration/)

In [None]:
import dask.diagnostics

# select a channel for registration
# (this value is compared against the `ch` column of the DataFrame)
reg_channel_index = 1
df_reg = dfi[(dfi.ch == reg_channel_index)]

# select a resolution level for registration
reg_res_level = 2

with dask.diagnostics.ProgressBar():
    registration.register(
            df_reg['msim'].tolist(),
            transform_key='ome-zarr',
            new_transform_key='phase_corr_registered',
            # Not to be confused with `reg_channel_index` above: the channels are kept in
            # separate msims here (each was given a single channel coordinate when it was
            # read), so the channel to register on is at index 0 within every msim.
            reg_channel_index=0,
            # registration_binning={'z': 2, 'y': 2, 'x': 2},
            reg_res_level=reg_res_level,
            n_parallel_pairwise_regs=4, # trade-off speed vs memory requirements (estimate of required memory: 2 * n_parallel_pairwise_regs * overlap_data_size))
        )


## Visualize registration result

In [None]:
# view the stacks that were used for the registration, positioned with the
# transforms stored under the 'phase_corr_registered' key
df_vis = df_reg

vis_utils.view_neuroglancer(
    ome_zarr_paths=[str(fp) for fp in df_vis['filepath'].tolist()],
    sims=[msi_utils.get_sim_from_msim(msim) for msim in df_vis['msim']],
    transform_key='phase_corr_registered',
    contrast_limits=(0, 500),
    single_layer=False,
)

## Combine channels

Combine stacks from different channels into a single stack.

In doing so, the registration transforms obtained in the previous step are copied from the registered channel to the other channels.

In [None]:
# define a function to combine msims along a given dimension
def combine_msims_along_dim(msims, concat_kwargs=None, dim='c'):
    """Concatenate several msims into a single one along `dim`.

    For every key of the first msim, the datasets stored under that key in
    all msims are concatenated with `xr.concat(..., data_vars='different')`.
    The results are assembled into a new `xr.DataTree`. Attributes are kept
    (`keep_attrs=True`).

    Parameters
    ----------
    msims : iterable
        The msims to combine. The first one determines which keys are
        combined.
    concat_kwargs : dict, optional
        Additional keyword arguments passed on to `xr.concat`.
    dim : str
        Dimension along which to concatenate.

    Returns
    -------
    xr.DataTree
        The combined msim.
    """
    # `None` instead of `{}` as default: a dict default is created only once
    # and would be shared between all calls of the function
    concat_kwargs = {} if concat_kwargs is None else concat_kwargs

    # materialize once, so that one-shot iterables can be traversed for every key
    msims = list(msims)

    with xr.set_options(keep_attrs=True):
        return xr.DataTree.from_dict(
            {
                sk: xr.concat(
                    [msim[sk].dataset for msim in msims],
                    dim=dim,
                    data_vars='different',
                    **concat_kwargs,
                )
                for sk in msims[0].keys()
            }
        )

# merge channels
# (the msims of each group of tile and angle are combined along the 'c' dimension)
dfic = dfi.groupby(['tile', 'angle'])['msim'].apply(
    combine_msims_along_dim, dim='c'
).reset_index()

dfic

## Fuse

In [None]:
# NOTE(review): reloading fusion looks like a development leftover —
# confirm whether it is still needed
import importlib
importlib.reload(fusion)

# Define output Zarr URL
# make sure to set a non-existing / unused output path
output_zarr_url = "fused_mesospim.zarr"

# define which transform key to use for fusion
# fusion_transform_key = 'ome-zarr' # fuse without alignment
fusion_transform_key = 'phase_corr_registered' # fuse with alignment

# msims = dfic['msim'].tolist()

# NOTE(review): the stacks of the registration channel (df_reg) are fused here;
# the channel-combined alternative (dfic) is commented out above — confirm this is intended
msims = df_reg['msim'].tolist()
sims = [msi_utils.get_sim_from_msim(
    msim,
    scale='scale2' # resolution level to fuse
    ) for msim in msims]

# the fused result is written to `output_zarr_url`
fused = fusion.fuse(
    sims=sims,
    transform_key=fusion_transform_key,
    output_chunksize={dim: 256 for dim in ['z', 'y', 'x']},
    blending_widths={"z": 1000, "y": 1000, "x": 1000}, # in microns
    # overlap_in_pixels can be left to default; blending widths handle boundary smoothing
    output_zarr_url=output_zarr_url,
    zarr_options={
        "ome_zarr": True,
        # "ngff_version": "0.4",  # optional
    },
    # optionally, we can use ray for parallelization (`pip install "ray[default]"`)
    # batch_options={
    #     "batch_func": misc_utils.process_batch_using_ray,
    #     "n_batch": 4, # number of chunk fusions to schedule / submit at a time
    #     "batch_func_kwargs": {
    #         'num_cpus': 4 # number of processes for parallel processing to use with ray
    #     },
    # },
)

## Visualize fused result using neuroglancer

In [None]:
# interrupt the notebook cell to stop the viewer
# (shows the fused image from the Zarr written by the fusion step,
# using the same transform key that was used for fusing)
vis_utils.view_neuroglancer(
    sims=[fused],
    ome_zarr_paths=[output_zarr_url],
    channel_coord=1,
    transform_key=fusion_transform_key,
)