In [None]:
# Licensed under the Apache License, Version 2.0. See LICENSE file in the project root for license information.

import numpy as np
import warnings, glob
warnings.filterwarnings("ignore")
import cupy as cp 
import torch 
import stitching   as st 
import utility     as ut
import distortion  as ds 

# Clean GPU memory before the pipeline starts: release PyTorch's cached
# allocator blocks and CuPy's default device memory pool. The public accessor
# cp.get_default_memory_pool() is used (as in the undistortion loop) rather
# than the private cp._default_memory_pool attribute.
torch.cuda.empty_cache()
cp.get_default_memory_pool().free_all_blocks()

# load data
# np.load on an .npz archive returns an NpzFile holding an open file handle;
# using it as a context manager closes that handle once the arrays below have
# been read into memory.
with np.load('data.npz') as data:
    tiles         = data['tiles']              # tiles in     [t, c, y, x] float32 format
    positions_ini = data['positions_ini']      # positions in [t, 2] format yx global position of each tile

# Sharpen every tile before any registration step.
tiles = st.sharpen_tiles(tiles)

# Chromatic correction: presumably estimates per-tile inter-channel shifts,
# then a RANSAC-based averaging step that also reports outlier tiles (those
# outliers are passed on to the distortion fit later in this cell).
raw_chromo_shifts = st.compute_chromatic(tiles)
chromo_shifts = np.array(raw_chromo_shifts)
chromo_shifts_mean, _, outlier_tiles = st.compute_ransac_average_shifts(chromo_shifts)
# NOTE(review): chromo_shifts_mean is not used in this cell; chrom_correct is
# given the per-tile shifts, not the robust average — confirm this is intended.
tiles_chromo_correct = st.chrom_correct(tiles, chromo_shifts)

# Pre-stitching: compute preliminary tile positions from positions_ini using
# the chromatically corrected tiles, before any distortion correction.
positions_pre = st.pr_stitching(tiles_chromo_correct, positions_ini)

# distortion correction and stitching
# Fit per-channel distortion coefficients (k1) and final tile positions on the
# chromatically corrected tiles; outlier_tiles comes from the RANSAC step above
# (presumably excluded from the fit — confirm in stitching module).
positions, k1color_s = st.compute_k1_recursive_colors(tiles_chromo_correct, positions_pre, outlier_tiles)

# Undistort channel by channel on the GPU. The input MUST be the chromatically
# corrected tiles: every channel of tiles_correct is overwritten below, so
# undistorting the raw `tiles` would silently discard the chromatic correction
# and apply k1 values fitted on corrected tiles to uncorrected ones.
tiles_correct = np.empty_like(tiles_chromo_correct)  # fully overwritten by the loop
nchannels = tiles_chromo_correct.shape[1]
for ch in range(nchannels):
    # k1color_s is indexed [:, ch]: assumed one coefficient per tile per channel — TODO confirm
    tiles_cupy = ds.undistort_tiles_batched(tiles_chromo_correct[:, ch], k1color_s[:, ch])
    tiles_correct[:, ch] = cp.asnumpy(tiles_cupy)
    # Drop the device array and return pooled blocks before the next channel
    # to keep peak GPU memory bounded.
    del tiles_cupy
    cp.get_default_memory_pool().free_all_blocks()

# Place the corrected tiles at their final positions to build one canvas.
canvas = ut.assemble_canvas(tiles_correct, positions)

# Write the canvas to disk with channel-first (C, Y, X) axis order;
# normalization=False presumably keeps raw intensities — confirm in utility.
f = 'canvas.tif'
ut.save_fiji(
    canvas,
    f,
    dimension_order='CYX',
    normalization=False,
)
