In [1]:
import functools as ft
import jax
import jax.numpy as jnp
import numpy as np

import zarr_io
import fine_registration

from sofima import stitch_elastic, flow_utils, mesh  # Other stuffs

In [2]:
# Load the coarse-registration results (per-axis coarse offsets and the coarse mesh).
# np.load on an .npz returns an NpzFile that keeps the underlying file open until it
# is closed; reading the arrays inside a `with` block closes the handle afterwards.
# Each `coarse_results[...]` access materializes a regular ndarray, so the arrays
# stay valid after the file is closed.
with np.load('coarse_results_6.npz') as coarse_results:
    cx = coarse_results['cx']
    cy = coarse_results['cy']
    coarse_mesh = coarse_results['mesh']

In [3]:
# Load raw tiles from the bucket. The tiles differ only in their X index
# (0001..0006); Y and Z are fixed at 0000, channel 0405, camera 0.
bucket = 'sofima-test-bucket'
downsampling_exp = 2  # sub-array (pyramid level) opened inside each .zarr store
num_tiles = 6

tile_paths = [
    f'tile_X_{x_idx:04d}_Y_0000_Z_0000_CH_0405_cam0.zarr/{downsampling_exp}'
    for x_idx in range(1, num_tiles + 1)
]
tiles = [zarr_io.open_zarr(bucket, path) for path in tile_paths]

# Existing data structures:
# tile_layout[i, 0] == i, i.e. the tiles form a single column of `num_tiles` rows.
tile_layout = np.arange(num_tiles).reshape(num_tiles, 1)

# Select index 0 along the leading axis of each 5-D array, leaving a 4-D volume
# (downstream code indexes these as 1zyx). NOTE(review): the leading axis is
# presumably time in a tczyx layout — TODO confirm.
tile_volumes = [tile[0, :, :, :, :] for tile in tiles]

# Wrap each volume in a SyncAdapter. Keys are (0, i) while tile_layout stores tile i
# at [i, 0], i.e. the keys index the layout in reversed (column, row) order.
# Insertion order (0, 0), (0, 1), ... matters: later cells use list(tile_map.keys()).
tile_map = {(0, i): fine_registration.SyncAdapter(volume)
            for i, volume in enumerate(tile_volumes)}


I0000 00:00:1684538502.052938   29644 gcs_resource.cc:102] Using default AdmissionQueue with limit 32
I0000 00:00:1684538502.055383   29690 google_auth_provider.cc:179] Running on GCE, using service account 895865026362-compute@developer.gserviceaccount.com


In [4]:
# Fine registration: compute patch-wise flow fields between neighboring tiles
# along axis 0 (using cx) and axis 1 (using cy), then persist them.
stride = (20, 20, 20)
# Tile size is given as (x, y, z), while the volumes themselves are indexed as 1zyx.
tile_size_xyz = (576, 576, 3544)
patch_size = (80, 80, 80)

flow_x, offsets_x = fine_registration.compute_flow_map3d(
    tile_map, tile_size_xyz, cx,
    axis=0, stride=stride, patch_size=patch_size)

flow_y, offsets_y = fine_registration.compute_flow_map3d(
    tile_map, tile_size_xyz, cy,
    axis=1, stride=stride, patch_size=patch_size)

np.savez_compressed(
    'flow_results_6.npz',
    flow_x=flow_x, flow_y=flow_y,
    offsets_x=offsets_x, offsets_y=offsets_y,
)

(slice(None, None, None), slice(0, 3544, None), slice(280, 576, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(0, 296, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(280, 576, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(0, 296, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(280, 576, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(0, 296, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(280, 576, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(0, 296, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(280, 576, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(0, 296, None), slice(0, 576, None))


In [5]:
# Fine registration, filter patch flows. Each stage gets its own config dict and
# output names, so the intermediate cleaned flows stay inspectable.

# Stage 1: filter each patch-flow field by peak ratio / sharpness and deviation.
# NOTE(review): max_magnitude=0 presumably disables the magnitude filter — TODO confirm.
clean_flow_kwargs = {"min_peak_ratio": 1.4, "min_peak_sharpness": 1.4,
                     "max_deviation": 5, "max_magnitude": 0, "dim": 3}
clean_x = {k: flow_utils.clean_flow(v, **clean_flow_kwargs) for k, v in flow_x.items()}
clean_y = {k: flow_utils.clean_flow(v, **clean_flow_kwargs) for k, v in flow_y.items()}

# Stage 2: reconcile each cleaned field (a single candidate flow per tile pair here).
# NOTE(review): -1 presumably disables the gradient/deviation checks — TODO confirm.
reconcile_kwargs = {"min_patch_size": 10, "max_gradient": -1, "max_deviation": -1}
fine_x = {k: flow_utils.reconcile_flows([v], **reconcile_kwargs) for k, v in clean_x.items()}
fine_y = {k: flow_utils.reconcile_flows([v], **reconcile_kwargs) for k, v in clean_y.items()}

In [6]:
# Fine registration, update mesh: convert the coarse per-tile mesh into a fine
# per-patch mesh. Each per-axis tuple bundles the coarse offsets (index 0 along
# axis 1), the reconciled fine flows, and the patch offsets.
data_x = (cx[:, 0, ...], fine_x, offsets_x)
data_y = (cy[:, 0, ...], fine_y, offsets_y)

fx, fy, init_x, nbors, key_to_idx = stitch_elastic.aggregate_arrays(
    data_x,
    data_y,
    list(tile_map),                   # tile keys, in tile_map insertion order
    coarse_mesh[:, 0, ...],
    stride=stride,
    tile_shape=tile_size_xyz[::-1],   # reversed to (z, y, x)
)

@jax.jit
def prev_fn(x):
  """Computes target meshes for the current mesh state `x`, one per neighbor entry.

  Maps stitch_elastic.compute_target_mesh over `nbors` and then swaps the first
  two axes of the stacked (rank-5) result.

  Reads `fx`, `fy`, `nbors` and `stride` as notebook globals. Under jax.jit these
  are captured when the function is first traced, so re-run this cell if any of
  them change.
  """
  per_neighbor = ft.partial(stitch_elastic.compute_target_mesh,
                            x=x, fx=fx, fy=fy, stride=stride)
  targets = jax.vmap(per_neighbor)(nbors)
  # vmap stacks its outputs along a new leading axis; move it to position 1.
  return jnp.transpose(targets, (1, 0, 2, 3, 4))

# Integration settings for the elastic mesh relaxation.
config = mesh.IntegrationConfig(
    dt=0.001, gamma=0.0, k0=0.01, k=0.1, stride=stride,
    num_iters=1000, max_iters=20000, stop_v_max=0.001, dt_max=100,
    prefer_orig_order=False,
    start_cap=0.1, final_cap=10.0,
    remove_drift=True,
)

x, ekin, t = mesh.relax_mesh(
    init_x, None, config,
    prev_fn=prev_fn,
    mesh_force=mesh.elastic_mesh_3d,
)

# `x` is the solved patch mesh(es). NOTE(review): key_to_idx is presumably a dict;
# if so, savez stores it as a pickled object array and reading it back requires
# np.load(..., allow_pickle=True) — TODO confirm.
np.savez_compressed('solved_mesh_st20_6.npz', x=x, key_to_idx=key_to_idx)