## Example workflow: registering and fusing multi-view light-sheet data

Notes:
- install napari-stitcher to enable the visualization steps (and set `VISUALIZE_USING_NAPARI = True` in the imports cell)
- registration: so far, only translation registration is performed
- fusion: only plain linear blending is currently supported
- general:
  - this is an early, provisional workflow; its API will change and be simplified
  - documentation will follow

In [None]:
# imports

import os
import numpy as np
from pathlib import Path
from tqdm import tqdm
import dask.diagnostics, tempfile

from multiview_stitcher import msi_utils, spatial_image_utils

%matplotlib notebook

# Interactive visualization needs napari plus the napari-stitcher package
# (which provides `viewer_utils`). Keep this False to run without a GUI;
# every visualization cell below is guarded by this flag.
VISUALIZE_USING_NAPARI = False

if VISUALIZE_USING_NAPARI:
    # Imported only on demand so the workflow runs without napari installed.
    import napari
    from napari_stitcher import viewer_utils

In [None]:
# Start a dask cluster: one local worker process with a client attached to it.

from distributed import Client, LocalCluster

lc = LocalCluster(
    n_workers=1,
    # None: let LocalCluster choose the thread count (presumably from the
    # available cores -- confirm against the distributed docs).
    threads_per_worker=None,
)
client = Client(lc)

# NOTE(review): dask.diagnostics.ProgressBar (used in later cells) hooks dask's
# local schedulers; while this distributed Client is active those progress bars
# may show nothing -- consider distributed.progress or the dashboard instead.

# Last expression of the cell: the notebook renders the client summary.
client


## Specify input data

In [None]:
base_dir = '../image-datasets/multi-view/old_mDSLM_classical_4_angles_10x_0.3NA_detection'

# One .tif file per view. Sorting by file name orders the views by angle,
# assuming the angle is encoded in the names in lexicographically sortable form.
# (All paths share `base_dir` as a prefix, so sorting names equals sorting full paths.)
tif_names = sorted(name for name in os.listdir(base_dir) if name.endswith('.tif'))
filenames = [Path(os.path.join(base_dir, name)) for name in tif_names]

print('Files:')
print('\n'.join(path.name for path in filenames))


## Convert input data to OME-Zarr

In [None]:
from multiview_stitcher import io

# Set to True to regenerate the OME-Zarr stores even if they already exist.
overwrite = False

msims = []
for tif_path in tqdm(filenames):
    zarr_path = tif_path.with_suffix('.zarr')

    # Convert each TIFF into a multiscale store next to it, once; later runs
    # reuse the existing store unless `overwrite` is set.
    if overwrite or not os.path.exists(zarr_path):
        sim = io.read_tiff_into_spatial_xarray(
            tif_path,
            # physical pixel spacing per axis (units not stated here)
            scale={'z': 2.58, 'y': 0.645, 'x': 0.645},
        )
        # scale_factors=None: let msi_utils choose the pyramid levels
        msim = msi_utils.get_msim_from_sim(sim, scale_factors=None)
        msim.to_zarr(zarr_path)

    # Always (re)load from the store so each view is backed by its OME-Zarr file.
    msims.append(msi_utils.multiscale_spatial_image_from_zarr(zarr_path))

# Alternative using msi_utils.get_store_decorator, kept for reference:
# msims = []
# for filename in tqdm(filenames):
#     msim = msi_utils.get_store_decorator(
#         filename.with_suffix('.zarr'),
#         store_overwrite=False)(
#             msi_utils.get_msim_from_sim)(
#                 io.read_tiff_into_spatial_xarray(
#                     filename,
#                     scale={'z': 2.58, 'y': 0.645, 'x': 0.645}
#                 ))
#     msims.append(msim)


## Set initial estimates of the transformations

In [None]:
from multiview_stitcher import param_utils

# View i is rotated by -i * 90 degrees about its own center, around the axis
# given by `direction` (the x axis, if the vector is ordered z, y, x -- TODO confirm).
for view_index, msim in enumerate(msims):
    angle = -np.pi / 2 * view_index
    center = spatial_image_utils.get_center_of_sim(
        msim['scale0/image'], transform_key=None)

    affine = param_utils.affine_from_rotation(
        angle,
        point=center,
        direction=[0, 0, 1],
    )

    # affine[None] adds a leading axis: the data has a single time point.
    msi_utils.set_affine_transform(msim, affine[None], transform_key='affine_manual')


### Visualize pre-registered views

In [None]:
if VISUALIZE_USING_NAPARI:
    # Show all views in 3D, placed using only the manual (pre-registration) transforms.
    viewer = napari.Viewer(ndisplay=3)

    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims,
        transform_key='affine_manual',
        n_colors=4,
        contrast_limits=[0, 1000],
    )
    viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

    viewer.reset_view()
    

### Register views

In [None]:
from multiview_stitcher import registration

# Register the views, starting from the manual transforms and using channel 0.
# `registration_binning` gives a per-axis binning factor (presumably the data is
# downsampled by these factors before registration -- confirm).
with dask.diagnostics.ProgressBar():
    params = registration.register(
        msims,
        reg_channel_index=0,
        transform_key='affine_manual',
        registration_binning={'z': 2, 'y': 8, 'x': 8},
    )

# Store each result under a new transform key, with 'affine_manual' as its base
# transform (presumably composed with it -- see msi_utils.set_affine_transform).
for msim, param in zip(msims, params):
    msi_utils.set_affine_transform(
        msim,
        param,
        transform_key='affine_registered',
        base_transform_key='affine_manual',
    )

### Visualize registration

In [None]:
if VISUALIZE_USING_NAPARI:
    # Overlay the views before and after registration for visual comparison;
    # the layers within each set are added with do_link_layers=True.
    viewer = napari.Viewer(ndisplay=3)

    layer_sets = []
    for key, prefix in (('affine_manual', 'pre-registered view'),
                        ('affine_registered', 'registered view')):
        lds = viewer_utils.create_image_layer_tuples_from_msims(
            msims,
            transform_key=key,
            n_colors=4,
            name_prefix=prefix,
            contrast_limits=[0, 1000],
        )
        layer_sets.append(
            viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds, do_link_layers=True))

    mlayers, rlayers = layer_sets

### Fuse views (linear blending)

In [None]:
from multiview_stitcher import fusion
import dask.array as da

sims = [msi_utils.get_sim_from_msim(msim) for msim in msims]

# Scratch space for the fused outputs. NOTE: a TemporaryDirectory and its
# contents are deleted when `tmpdir` is cleaned up (explicit cleanup(), object
# destruction or interpreter exit), so the stores written below do not persist
# beyond this session -- export the result (e.g. to TIFF) if you need to keep it.
tmpdir = tempfile.TemporaryDirectory()

# Build the (lazy) fused image on an isotropic grid (spacing 10 along z, y and x)
# using each view's 'affine_registered' transform.
fused = fusion.fuse(
    sims,
    transform_key='affine_registered',
    output_spacing={dim: 10 for dim in ['z', 'y', 'x']},
    output_chunksize=128,
)

# NOTE(review): with a distributed Client active, dask.diagnostics.ProgressBar
# may display nothing; consider distributed.progress or the dashboard.
print('Fusing views...')
with dask.diagnostics.ProgressBar():
    # Compute the fusion once and write it to zarr; return_stored=True replaces
    # the lazy graph with an array that reads back from the store, so building
    # the multiscale pyramid below does not recompute the fusion.
    fused.data = da.to_zarr(
        fused.data,
        os.path.join(tmpdir.name, 'fused_sim.zarr'),
        overwrite=True, return_stored=True, compute=True)

print('Creating multiscale output OME-Zarr...')
fused_path = os.path.join(tmpdir.name, 'fused.zarr')
with dask.diagnostics.ProgressBar():
    # scale_factors=None: let msi_utils choose the pyramid levels
    mfused = msi_utils.get_msim_from_sim(fused, scale_factors=None)
    mfused.to_zarr(fused_path)

# Re-open from disk so downstream cells read the stored pyramid.
mfused = msi_utils.multiscale_spatial_image_from_zarr(fused_path)

### Visualize fusion in napari

In [None]:
if VISUALIZE_USING_NAPARI:
    # Compare the registered input views with the fused result in one 3D viewer.
    viewer = napari.Viewer(ndisplay=3)

    # Registered input views (layers added unlinked).
    lds = viewer_utils.create_image_layer_tuples_from_msims(
        msims,
        transform_key='affine_registered',
        n_colors=4,
        name_prefix='registered view',
        contrast_limits=[0, 1000],
    )
    rlayers = viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds, do_link_layers=False)

    # Fused image. NOTE(review): assumes the fused output carries an
    # 'affine_registered' transform (it was fused with that transform_key) -- confirm.
    lds = viewer_utils.create_image_layer_tuples_from_msim(
        mfused,
        transform_key='affine_registered',
        name_prefix='fused',
        contrast_limits=[0, 1000],
    )
    viewer_utils.add_image_layer_tuples_to_viewer(viewer, lds)

In [None]:
# Export the fused image, already saved as OME-Zarr by the fusion cell, to
# 'fused.tif' (relative to the current working directory).

from multiview_stitcher import io

with dask.diagnostics.ProgressBar():
    fused_sim = msi_utils.get_sim_from_msim(mfused)
    io.save_sim_as_tif('fused.tif', fused_sim)