# exaSPIM example workflow

## Intro

This notebook provides a proof of principle for processing an exaSPIM dataset consisting of:
- 15 tiles (5x3 grid)
- 2 channels
- Full data size ~ 150 TB

The data is made available by the Allen Institute for Neural Dynamics and can be browsed at https://open.quiltdata.com/b/aind-open-data/tree/?prefix=exa.

## Setup

1. Set up Python in your preferred way (e.g. conda, venv, pipenv, poetry)
1. Make sure `multiview-stitcher >= 0.1.35` is installed: `pip install "multiview-stitcher>=0.1.35"`
1. Optionally, install `ray` for parallelising fusion on top of dask: `pip install "ray[default]"`
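To check the installed version from within Python, a small sketch using the standard library's `importlib.metadata` could look like this. `version_tuple` and `check_min_version` are hypothetical helpers, not part of multiview-stitcher. The version parsing is deliberately simplified: only the leading digits of each dot-separated component are compared, so pre-release suffixes are ignored. For full PEP 440 semantics, `packaging.version.Version` is the more robust choice.

```python
from importlib.metadata import PackageNotFoundError, version


def version_tuple(v):
    """Turn a version string like '0.1.35' into (0, 1, 35).

    Simplified parsing: only the leading digits of each component are kept,
    so e.g. '0.1.36.dev3' and '0.1.36rc1' both become (0, 1, 36).
    """
    parts = []
    for component in v.split("."):
        digits = ""
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component (e.g. 'dev3')
        parts.append(int(digits))
    return tuple(parts)


def check_min_version(package, minimum):
    """Return True if `package` is installed with a version >= `minimum`."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package} is not installed")
        return False
    ok = version_tuple(installed) >= version_tuple(minimum)
    print(f"{package} {installed} (required >= {minimum}): {'OK' if ok else 'too old'}")
    return ok


check_min_version("multiview-stitcher", "0.1.35")
```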

In [None]:
# imports

import os
import numpy as np
import xarray as xr
import pandas as pd

import ngff_zarr

from multiview_stitcher import spatial_image_utils as si_utils
from multiview_stitcher import (
    registration,
    fusion,
    param_utils,
    msi_utils,
    misc_utils,
    vis_utils,
    ngff_utils,
)

%matplotlib ipympl

## Define input paths

All data is stored in the cloud in OME-Zarr format and can be accessed directly via HTTPS URLs. Here we define these URLs.

In [None]:
bucket_url = "https://aind-open-data.s3.amazonaws.com"

base_path = "exaSPIM_674185_2023-10-02_14-06-36"
first_tile_path = "exaSPIM.zarr/tile_x_0000_y_0000_z_0000_ch_488.zarr"
metadata_path = "exaSPIM_acquisition.json"

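# tile grid and channels to process
n_tiles_x = 5
n_tiles_y = 3
# for a quick test, use a smaller subset of tiles:
# n_tiles_x = 2
# n_tiles_y = 1
channels = ["561"]  # e.g. add "488" to also process the other channel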

## Read data and metadata

Note that no pixel data is read at this stage, only metadata.
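To give an idea of what this metadata contains, the following sketch extracts per-scale spacing and origin from an OME-NGFF `multiscales` attribute dictionary. It assumes the OME-NGFF 0.4 layout (per-dataset `coordinateTransformations` with a `scale` and an optional `translation` entry) and ignores multiscales-level transformations. `summarize_multiscales` is a hypothetical helper and `example_zattrs` holds illustrative values, not the actual metadata of this dataset. In the notebook, `ngff_zarr` parses this metadata for you; the sketch only shows what is being read.

```python
def summarize_multiscales(zattrs):
    """Per-scale spatial spacing and origin from OME-NGFF 0.4-style attributes (sketch)."""
    ms = zattrs["multiscales"][0]
    axis_names = [ax["name"] for ax in ms["axes"]]
    spatial = [i for i, ax in enumerate(ms["axes"]) if ax.get("type") == "space"]

    levels = []
    for dataset in ms["datasets"]:
        # defaults in case a transformation type is missing
        scale = [1.0] * len(axis_names)
        translation = [0.0] * len(axis_names)
        for t in dataset["coordinateTransformations"]:
            if t["type"] == "scale":
                scale = t["scale"]
            elif t["type"] == "translation":
                translation = t["translation"]
        levels.append({
            "path": dataset["path"],
            "spacing": {axis_names[i]: scale[i] for i in spatial},
            "origin": {axis_names[i]: translation[i] for i in spatial},
        })
    return levels


# illustrative metadata (hypothetical values, NOT read from the exaSPIM dataset)
example_zattrs = {
    "multiscales": [{
        "version": "0.4",
        "axes": [
            {"name": "t", "type": "time"},
            {"name": "c", "type": "channel"},
            {"name": "z", "type": "space", "unit": "micrometer"},
            {"name": "y", "type": "space", "unit": "micrometer"},
            {"name": "x", "type": "space", "unit": "micrometer"},
        ],
        "datasets": [
            {"path": "0", "coordinateTransformations": [
                {"type": "scale", "scale": [1.0, 1.0, 1.0, 0.75, 0.75]},
                {"type": "translation", "translation": [0.0, 0.0, 0.0, 0.0, 0.0]}]},
            {"path": "1", "coordinateTransformations": [
                {"type": "scale", "scale": [1.0, 1.0, 2.0, 1.5, 1.5]},
                {"type": "translation", "translation": [0.0, 0.0, 0.0, 0.0, 0.0]}]},
        ],
    }]
}

for level in summarize_multiscales(example_zattrs):
    print(level)
```

In this illustrative example both scales have the same translation even though the pixels of the second scale are twice as large; this is the kind of inconsistency addressed in the next section.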

In [None]:
# load OME-Zarr multiscale images (only metadata is read here)

df = []
for ix in range(n_tiles_x):
    for iy in range(n_tiles_y):
        for ch in channels:
            # build the tile path from the path of the first tile
            tile_path = first_tile_path.replace(
                "tile_x_0000_y_0000", f"tile_x_{ix:04d}_y_{iy:04d}"
            ).replace("ch_488", f"ch_{ch}")
            # URLs always use forward slashes, so avoid os.path.join here
            # (it would insert backslashes on Windows)
            file_url = f"{bucket_url}/{base_path}/{tile_path}"
            print(f"Loading tile x={ix}, y={iy}, ch={ch}... filepath={tile_path}")
            msim = ngff_utils.ngff_multiscales_to_msim(
                ngff_zarr.from_ngff_zarr(file_url),
                transform_key='ome-zarr',
            )
            df.append({
                'ix': ix,
                'iy': iy,
                'filename': os.path.basename(tile_path),
                'file_url': file_url,
                'ch': ch,
                'msim': msim,
            })

# combine everything into a dataframe
df = pd.DataFrame(df)
df

## Correct origins for multiscale data

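The OME-Zarr files contain multiscale data. However, the origins of the lower-resolution scales are not set correctly in the metadata. We correct them here: along each spatial dimension, every downsampled scale is shifted so that the center of its first pixel lies at `origin0 + (f - 1) / 2 * spacing0`, where `origin0` and `spacing0` belong to the highest-resolution scale and `f` is the (integer) downsampling factor relative to it.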
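The formula used by `correct_origins` can be checked on a single axis. The sketch below assumes that each level is produced by averaging non-overlapping blocks of `f` pixels, so the first pixel of a downsampled level represents the mean position of the first `f` full-resolution pixel centers. `corrected_origin` is a hypothetical helper and the axis values (1024 pixels, 1.0 µm spacing) are made up for illustration.

```python
import numpy as np


def corrected_origin(origin0, spacing0, shape0, shape_level):
    """First pixel center of a downsampled level (same formula as `correct_origins`)."""
    factor = np.round(shape0 / shape_level)  # integer downsampling factor
    return origin0 + (factor - 1) / 2 * spacing0


# hypothetical 1D axis: 1024 pixels at 1.0 µm spacing, origin at 0.0 µm
origin0, spacing0, shape0 = 0.0, 1.0, 1024

for level in range(4):
    shape_level = shape0 // 2**level
    print(level, shape_level, corrected_origin(origin0, spacing0, shape0, shape_level))
# → origins 0.0, 0.5, 1.5, 3.5 for levels 0 to 3
```

For level 3 (`f = 8`), the corrected origin 3.5 equals the mean of the first eight pixel centers 0, 1, ..., 7, which is exactly the block-averaging interpretation.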

In [None]:
def correct_origins(msim):
    """Shift the origin of each downsampled scale so that it is consistent with scale0."""
    sks = msi_utils.get_sorted_scale_keys(msim)
    sdims = msi_utils.get_spatial_dims(msim)

    # reference (highest-resolution) scale
    sim0 = msi_utils.get_sim_from_msim(msim, sks[0])
    spacing0 = si_utils.get_spacing_from_sim(sim0)
    origin0 = si_utils.get_origin_from_sim(sim0)
    shape0 = {dim: len(sim0.coords[dim]) for dim in sdims}

    def correct_ds(ds):
        # nodes without data variables (e.g. the root) are left untouched
        if len(ds.data_vars) == 0:
            return ds
        image = ds.image
        new_coords = {}
        for dim in sdims:
            # integer downsampling factor of this scale relative to scale0
            factor = np.round(shape0[dim] / len(image.coords[dim]))
            # center of the first pixel after averaging blocks of `factor` pixels
            new_origin = origin0[dim] + (factor - 1) / 2 * spacing0[dim]
            new_coords[dim] = image.coords[dim] - image.coords[dim].values[0] + new_origin
        # keep all other data variables (e.g. transforms) unchanged
        return xr.Dataset(
            {'image': image.assign_coords(new_coords)}
            | {t: ds.data_vars[t] for t in ds.data_vars if t != 'image'}
        )

    return msim.map_over_datasets(correct_ds)

df['msim'] = df['msim'].apply(correct_origins)

## Visualize tile configuration

In [None]:
# visualize the tile configuration and check that it is set up properly
vis_utils.plot_positions(
    df["msim"].tolist(), transform_key='ome-zarr'
)

In [None]:
vis_utils.view_neuroglancer(
    df['file_url'].tolist(),
    sims=[msi_utils.get_sim_from_msim(msim) for msim in df['msim']],
    transform_key='ome-zarr',
    contrast_limits=(0, 150),
    single_layer=True,
)

## Registration

Here, we use phase-correlation-based registration to register the tiles.

Note that this step currently requires that at least the overlap region of two neighboring tiles fits into memory.

To achieve this, you can:
- choose a suitable scale to register on (see `reg_res_level` below)
- advanced: define a [custom registration function](https://multiview-stitcher.github.io/multiview-stitcher/main/extension_api_pairwise_registration/)
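Choosing a scale can be approached with a back-of-the-envelope estimate. The sketch below uses the rule of thumb stated in the registration cell (required memory is roughly `2 * n_parallel_pairwise_regs * overlap_data_size`) and returns the finest resolution level that fits into a given memory budget. It assumes that every level is downsampled by 2 along all spatial axes and that level `n` corresponds to the `reg_res_level` value `n`; `overlap_nbytes` and `choose_reg_res_level` are hypothetical helpers, and the tile shape, overlap fraction and budget are made-up example numbers. Compare the assumed shapes with the data sizes printed in the next cell before relying on the result.

```python
import math


def overlap_nbytes(tile_shape, overlap_fraction, axis, itemsize, level, factor=2):
    """Approximate size (bytes) of the overlap between two neighboring tiles.

    Assumes every resolution level is downsampled by `factor` along all
    spatial axes and that neighboring tiles overlap by `overlap_fraction`
    of the tile extent along `axis`.
    """
    shape = [max(1, s // factor**level) for s in tile_shape]
    shape[axis] = max(1, math.ceil(shape[axis] * overlap_fraction))
    return math.prod(shape) * itemsize


def choose_reg_res_level(tile_shape, overlap_fraction, axis, itemsize,
                         budget_bytes, n_parallel=2, max_level=8):
    """Finest level whose estimated memory use fits into `budget_bytes`."""
    for level in range(max_level + 1):
        overlap = overlap_nbytes(tile_shape, overlap_fraction, axis, itemsize, level)
        required = 2 * n_parallel * overlap  # rule of thumb from the registration cell
        if required <= budget_bytes:
            return level, required
    return None, None  # nothing fits, even at the coarsest level checked


# hypothetical numbers: (z, y, x) tile shape, uint16 data (2 bytes per voxel),
# 10% overlap along x, 16 GB memory budget
level, required = choose_reg_res_level(
    tile_shape=(2048, 4096, 4096), overlap_fraction=0.1, axis=2,
    itemsize=2, budget_bytes=16e9,
)
print(f"finest level that fits: {level} (~{required / 1e9:.1f} GB)")
# → finest level that fits: 1 (~3.4 GB)
```

Lowering `n_parallel` halves the estimate per step, so the same budget can sometimes allow a finer level at the cost of slower registration.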

In [None]:
# print data sizes for the different scales
print("Data sizes for different scales (first tile):")
msim = df['msim'][0]
print('Dimensions:', msim['scale0/image'].dims)
for scale in msi_utils.get_sorted_scale_keys(msim):
    image = msim[scale].image
    print(f"Scale {scale}: shape {image.shape}, {image.nbytes / 1e9:.2f} GB (uncompressed)")

In [None]:
import dask.diagnostics

with dask.diagnostics.ProgressBar():
    registration.register(
        df['msim'].tolist(),
        transform_key='ome-zarr',
        new_transform_key='phase_corr',  # the obtained transforms are stored under this key
        reg_channel_index=0,
        # registration_binning={'z': 1, 'y': 1, 'x': 1},
        reg_res_level=5,  # resolution level used for registration (see the data sizes above)
        # trade-off between speed and memory requirements
        # (estimate of required memory: 2 * n_parallel_pairwise_regs * overlap_data_size)
        n_parallel_pairwise_regs=2,
    )

In [None]:
# visualize obtained tile configuration after registration
# (this doesn't show image data so we can mostly use it to
# get an idea of the corrected tile layout)

vis_utils.plot_positions(
    df["msim"].tolist(), transform_key='phase_corr'
)

## Visualize registration result

In [None]:
vis_utils.view_neuroglancer(
    df['file_url'].tolist(),
    sims=[msi_utils.get_sim_from_msim(msim) for msim in df['msim']],
    transform_key='phase_corr',
    single_layer=True, # setting this to true can improve neuroglancer performance
)

## Fusion

We save the output to a (local) multiscale OME-Zarr file.

A progress bar shows an estimate of the remaining processing time. Fusion can be performed on any of the available scales, and the shape and offset of the output can also be specified (see e.g. the commented-out `output_shape` option below).
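Before starting a large fusion, it can help to estimate how much work and disk space the output requires. The sketch below counts the output chunks for a given `output_shape` and `output_chunksize`, the uncompressed size of the highest-resolution output level, and, assuming `n_batch` is the number of chunk fusions submitted at a time (as the comment in the fusion cell describes), the resulting number of batches. `fusion_output_summary` is a hypothetical helper; the actual runtime depends on many other factors such as I/O and the number of overlapping tiles per chunk.

```python
import math


def fusion_output_summary(output_shape, chunksize, itemsize, n_batch=None):
    """Estimate chunk count, uncompressed size (bytes) and batch count of a fused output.

    Only the highest-resolution level of the output is counted.
    """
    n_chunks = math.prod(math.ceil(output_shape[d] / chunksize[d]) for d in output_shape)
    nbytes = math.prod(output_shape.values()) * itemsize
    n_batches = math.ceil(n_chunks / n_batch) if n_batch else None
    return n_chunks, nbytes, n_batches


# example using the test output shape from the fusion cell and 2-byte (uint16) voxels
n_chunks, nbytes, n_batches = fusion_output_summary(
    output_shape={'z': 500, 'y': 500, 'x': 500},
    chunksize={'z': 256, 'y': 256, 'x': 256},
    itemsize=2,
    n_batch=4,
)
print(n_chunks, nbytes / 1e9, n_batches)  # → 8 0.25 2
```

For real data, the item size can be taken from the input, e.g. `sims[0].dtype.itemsize`. If the lower-resolution levels of the output are downsampled by 2 along all three axes, they add roughly another 1/7 of the full-resolution size (1/8 + 1/64 + ...).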

In [None]:
# Define output Zarr URL
# make sure this path does not point to data you want to keep
# (existing output is overwritten, see `overwrite=True` below)
output_zarr_url = "fused_exa.zarr"

# define which transform key to use for fusion
# fusion_transform_key = 'ome-zarr' # fuse without alignment
fusion_transform_key = 'phase_corr' # fuse with alignment

msims = df['msim'].tolist()
sims = [
    msi_utils.get_sim_from_msim(
        msim,
        # set the scale to be used for fusion
        # scale='scale0',  # full resolution
        scale='scale5',
    )
    for msim in msims
]

fused = fusion.fuse_to_multiscale_ome_zarr(
    fuse_kwargs={
        "sims": sims,
        "transform_key": fusion_transform_key,
        "output_chunksize": {dim: 256 for dim in ['z', 'y', 'x']},
        # "output_shape": {'z': 500, 'y': 500, 'x': 500}, # option to test smaller output
        "blending_widths": {"z": 1000, "y": 1000, "x": 1000},
    },
    output_zarr_url=output_zarr_url,
    overwrite=True, # whether to overwrite existing output Zarr (if it exists)
    # optionally, we can use ray for parallelization (`pip install "ray[default]"`)
    # batch_func=misc_utils.process_batch_using_ray,
    # n_batch=4, # number of chunk fusions to schedule / submit at a time
    # batch_func_kwargs={
    #     'num_cpus': 4 # number of processes for parallel processing to use with ray
    #     },
)

## Visualize the fused dataset

In [None]:
# interrupt the notebook cell to stop the viewer
vis_utils.view_neuroglancer(
    sims=[fused],
    ome_zarr_paths=[output_zarr_url],
    channel_coord=0,
    transform_key=fusion_transform_key,
)