## Example 2D stitching workflow

This notebook uses the 2D example dataset provided by BigStitcher: https://imagej.net/plugins/bigstitcher/#example-datasets

Notes:
- Visualization is optional: to enable the napari cells, install `napari-stitcher` and set `VISUALIZE_USING_NAPARI = True` in the first code cell.

In [None]:
import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
import dask.diagnostics, tempfile

from multiview_stitcher import msi_utils, spatial_image_utils

# NOTE(review): no plotting call appears in the cells visible here — confirm
# that this matplotlib backend magic is still needed.
%matplotlib notebook

# the package napari-stitcher is required for visualization with napari
# Every napari cell below is guarded by this flag; leave it False to skip them.
VISUALIZE_USING_NAPARI = False

if VISUALIZE_USING_NAPARI:
    # imported only on demand, so these packages are not needed when the flag is False
    import napari
    from napari_stitcher import viewer_utils

In [None]:
# Spin up a local dask cluster and connect a client to it

from distributed import Client, LocalCluster

lc = LocalCluster(
    n_workers=1,
    threads_per_worker=None,
)
client = Client(lc)

# keep the client as the last expression so the notebook displays it
client

## Download and specify input data

In [None]:
url = "https://preibischlab.mdc-berlin.de/BigStitcher/Grid_2d.zip"

# The working directory is named after the archive without its ".zip"
# extension ("Grid_2d.zip" -> "./Grid_2d").
base_dir = './%s' %os.path.basename(url)[:-4]
os.makedirs(base_dir, exist_ok=True)

zip_filepath = os.path.join(base_dir, os.path.basename(url))

# download (skipped when the archive is already present)
if not os.path.exists(zip_filepath):
    from urllib.request import urlretrieve
    # Download to a temporary name and rename only on success. An interrupted
    # download would otherwise leave a truncated archive at zip_filepath,
    # which later runs would mistake for a complete one and fail to unzip.
    partial_filepath = zip_filepath + '.part'
    urlretrieve(url, partial_filepath)
    os.replace(partial_filepath, zip_filepath)

# unzip (the archive contents are extracted into base_dir on every run)
import zipfile
with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
    zip_ref.extractall(base_dir)

In [None]:
# Collect the input tiles: entries of base_dir named MAX*.tif, in sorted order
tile_names = sorted(
    entry for entry in os.listdir(base_dir)
    if entry.startswith('MAX') and entry.endswith('.tif')
)

filenames = [Path(os.path.join(base_dir, name)) for name in tile_names]

print('Files:')
print('\n'.join(path.name for path in filenames))


## Convert input data to OME-Zarr

In [None]:
from multiview_stitcher import io

# when True, the zarr store is rewritten even if it already exists
overwrite = True

msims = []
for filename in tqdm(filenames):
    # the store sits next to the tif: same name with a '.zarr' suffix
    store_path = filename.with_suffix('.zarr')

    needs_conversion = overwrite or not os.path.exists(store_path)
    if needs_conversion:
        sim = io.read_tiff_into_spatial_xarray(filename, scale={'y': 1, 'x': 1})
        # scale_factors=None: choose scale factors automatically
        msim = msi_utils.get_msim_from_sim(sim, scale_factors=None)
        msim.to_zarr(store_path)

    # whether freshly written or reused, continue with the image read back from the store
    msim = msi_utils.multiscale_spatial_image_from_zarr(store_path)
    msims.append(msim)

## Set estimate of initial transformations

In [None]:
from multiview_stitcher import param_utils

overlap = 0.1

# use prior knowledge: tiles are laid out row by row, two tiles per row
# (a 2*3 grid for six tiles)
for tile_index, msim in enumerate(msims):

    # row (y) and column (x) position of this tile on the grid
    y_index, x_index = divmod(tile_index, 2)

    # tile size per axis, taken as twice the tile's center coordinate
    tile_sim = msi_utils.get_sim_from_msim(msim)
    y_extent, x_extent = spatial_image_utils.get_center_of_sim(tile_sim) * 2

    # neighbouring tiles are offset by one tile extent minus the assumed overlap
    step = 1 - overlap
    affine = param_utils.affine_from_translation(
        [y_index * step * y_extent, x_index * step * x_extent]
    )

    msi_utils.set_affine_transform(
        msim,
        affine[None],  # add a leading axis: one time point
        transform_key='affine_manual',
    )


### Visualize pre-registered views

In [None]:
if VISUALIZE_USING_NAPARI:

    # channel to display; None if all should be shown
    ch_coord = 0

    # open the viewer with as many displayed dimensions as the first view has
    n_display_dims = msi_utils.get_ndim(msims[0])
    viewer = napari.Viewer(ndisplay=n_display_dims)

    # show the tiles at their manually estimated positions
    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims,
        ch_coord=ch_coord,
        transform_key='affine_manual',
        n_colors=2,
        contrast_limits=[0, 500],
    )
    viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)
    viewer.reset_view()

### Register views

In [None]:
from multiview_stitcher import registration

# Register the views, starting from the manually set transforms and using
# the first channel; dask progress is reported while this runs.
with dask.diagnostics.ProgressBar():
    params = registration.register(
        msims,
        reg_channel_index=0,
        transform_key='affine_manual',
    )

# Store each resulting parameter set under its own key, using
# 'affine_manual' as the base transform.
for msim, param in zip(msims, params):
    msi_utils.set_affine_transform(
        msim,
        param,
        transform_key='affine_registered',
        base_transform_key='affine_manual',
    )

### Visualize registration

In [None]:
if VISUALIZE_USING_NAPARI:

    # open the viewer with as many displayed dimensions as the first view has
    n_display_dims = msi_utils.get_ndim(msims[0])
    viewer = napari.Viewer(ndisplay=n_display_dims)

    # channel to display; None if all should be shown
    ch_coord = 0

    # show the tiles at their registered positions
    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims,
        ch_coord=ch_coord,
        transform_key='affine_registered',
        n_colors=2,
        name_prefix='registered view',
        contrast_limits=[0, 500],
    )
    rlayers = viewer_utils.add_image_layer_tuples_to_viewer(
        viewer, lds, do_link_layers=True
    )

### Fuse views (linear blending)

In [None]:
from multiview_stitcher import fusion
import dask.array as da

# Fusion works on the spatial images extracted from each multiscale image.
sims = [msi_utils.get_sim_from_msim(msim) for msim in msims]

# Scratch directory holding the fused outputs. It must stay alive for as long
# as `fused` / `mfused` are used, because they are read back from files
# stored inside it.
tmpdir = tempfile.TemporaryDirectory()

# Set up the fusion of all views using the registered transforms.
fused = fusion.fuse(
    sims,
    transform_key='affine_registered',
    output_chunksize=256,
    )

print('Fusing views...')
with dask.diagnostics.ProgressBar():

    # Write the fused array to zarr and swap in the stored result
    # (return_stored=True), so the steps below read it back from disk.
    fused.data = da.to_zarr(
        fused.data,
        os.path.join(tmpdir.name, 'fused_sim.zarr'),
        overwrite=True, return_stored=True, compute=True)

print('Creating multiscale output OME-Zarr...')
with dask.diagnostics.ProgressBar():

    # scale_factors=None: choose scale factors automatically
    mfused = msi_utils.get_msim_from_sim(fused, scale_factors=None)

    fused_path = os.path.join(tmpdir.name, 'fused.zarr')
    mfused.to_zarr(fused_path)

# Continue with the multiscale image as read back from the store.
mfused = msi_utils.multiscale_spatial_image_from_zarr(fused_path)

### Visualize fusion in napari

In [None]:
if VISUALIZE_USING_NAPARI:

    # open the viewer with as many displayed dimensions as the first view has
    n_display_dims = msi_utils.get_ndim(msims[0])
    viewer = napari.Viewer(ndisplay=n_display_dims)

    # no channel selection is applied to the fused image
    ch_coord = None

    lds = viewer_utils.create_image_layer_tuples_from_msims(
        [mfused],
        ch_coord=ch_coord,
        transform_key='affine_registered',
        name_prefix='fused',
        contrast_limits=[0, 500],
    )

    viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

In [None]:
# stream presaved fused image to tif

from multiview_stitcher import io

# Name the output after the dataset: strip the ".zip" extension with [:-4],
# the same way base_dir is derived from the URL ("Grid_2d.zip" -> "Grid_2d").
# The previous [:4] slice kept only the first four characters ("Grid").
fused_tif_path = os.path.join(base_dir, os.path.basename(url)[:-4] + '_fused.tif')

with dask.diagnostics.ProgressBar():
    io.save_sim_as_tif(
        fused_tif_path,
        msi_utils.get_sim_from_msim(mfused))