In [1]:
import functools as ft
import jax
import jax.numpy as jnp
import numpy as np
import time
import tensorstore as ts

from ng_link import NgState, link_utils

import zarr_io
import coarse_registration
import fine_registration
import fusion

from sofima import stitch_rigid, flow_utils, stitch_elastic, mesh, map_utils
from sofima.processor import warp   # tensorflow dependency, very weird

from connectomics.common import bounding_box
from connectomics.common import box_generator
from connectomics.volume import subvolume

In [2]:
# Application Inputs
# Changing to two tiles

READ_BUCKET = 'aind-open-data'       # bucket the raw tiles are opened from
WRITE_BUCKET = 'sofima-test-bucket'  # bucket the fused output is written to

DATASET = 'diSPIM_647459_2022-12-07_00-00-00/diSPIM.zarr'
# Downsampling exponent (factor 2**DOWNSAMPLE_EXP); also used as the sub-path
# of the zarr level opened for each tile.
DOWNSAMPLE_EXP = 2

# Tile grid indexed [row][col]; each entry is a tile id.
tile_layout = np.array([[1], [0]])

In [3]:
# Coarse Registration Data Loading

# tile_volumes[i] is the volume of tile i (only the 405 channel is loaded).
# tile_paths[c][i] is the dataset-relative path of tile i for channels[c].
vox_sizes_xyz = [0.298, 0.298, 0.176]  # In um
channels = [405, 488, 561, 638]
tile_volumes = []
tile_paths = []
for channel in channels:
    c_paths = []
    for i in range(0, tile_layout.shape[0]):  # + 1 Temporary hack
        # Tile index is zero-padded to 4 digits (0 -> "0000", 12 -> "0012").
        path = f"tile_X_{i:04d}_Y_0000_Z_0000_CH_0{channel}_cam1.zarr"
        c_paths.append(DATASET + '/' + path)

        if channel == 405:   # Just selecting one
            tile = zarr_io.open_zarr_s3(READ_BUCKET,
                                        DATASET + f'/{path}/{DOWNSAMPLE_EXP}')
            # Reverse the 5-D axis order (spatial axes first, xyz) and take
            # index 0 of the two trailing (originally leading) axes.
            tile_xyz = tile.T[:, :, :, 0, 0]
            print(path)
            print(tile_xyz.shape)
            tile_volumes.append(tile_xyz)
    tile_paths.append(c_paths)

tile_X_0000_Y_0000_Z_0000_CH_0405_cam1.zarr
(576, 576, 3543)
tile_X_0001_Y_0000_Z_0000_CH_0405_cam1.zarr
(576, 576, 3551)


In [4]:
# Coarse Registration
# Pairwise offsets between neighboring tiles (cx / cy are later read for the
# x / y overlaps), then a coarse mesh solved from them with the 3D tile mesh model.
# NOTE(review): the meaning of the positional `False` flag is not visible here.
cx, cy = coarse_registration.compute_coarse_offsets(
    tile_layout, tile_volumes, False
)
coarse_mesh = stitch_rigid.optimize_coarse_mesh(
    cx, cy, mesh_fn=stitch_rigid.elastic_tile_mesh_3d
)
# np.savez_compressed('coarse_results_6_cam1.npz', mesh=coarse_mesh, cx=cx, cy=cy)

Top Id: 1, Bottom Id: 0
Top: (0, 0), Bot: (1, 0) [ -1. 285.  -4.]


In [5]:
# Fusing the downsampled images
# Mesh node spacing; ordered zyx (mesh_shape divides the zyx tile size by it).
downsampled_stride = (20,) * 3
# Tile extent in voxels, xyz order (tile_volumes hold xyz-ordered volumes).
downsampled_tile_size_xyz = np.asarray(tile_volumes[0].shape)

# original_stride = tuple(np.array(downsampled_stride) * 2**DOWNSAMPLE_EXP)
# original_tile_size_xyz = downsampled_tile_size_xyz * 2**DOWNSAMPLE_EXP

class StitchAndRender3dTiles(fusion.StitchAndRender3dTiles):
  """Renderer that opens each tile from S3 on first use and caches it.

  Tiles are read from `DATASET` in `READ_BUCKET` at level `DOWNSAMPLE_EXP`,
  the same level used to build `tile_volumes` (and hence the tile size and
  mesh shape), for channel `tile_channel`.
  """
  # NOTE(review): class-level dict, so it is shared by every instance of this
  # class until the cell defining the class is re-run.
  cache = {}
  # Channel whose tiles are rendered (the channel loaded for coarse registration).
  tile_channel = 405

  def _open_tile_volume(self, tile_id: int):
    """Return a cached SyncAdapter over tile `tile_id` in zyx axis layout."""
    if tile_id in self.cache:
      return self.cache[tile_id]

    # Zero-pad both indices to 4 digits; valid for any tile_id (not just < 10).
    path = (f"tile_X_{tile_id:04d}_Y_0000_Z_0000_"
            f"CH_{self.tile_channel:04d}_cam1.zarr")
    tile = zarr_io.open_zarr_s3(READ_BUCKET,
                                DATASET + f'/{path}/{DOWNSAMPLE_EXP}')
    tile = tile[0, 0, :, :, :]  # convert to zyx axis layout (diff from coarse and fine)
    self.cache[tile_id] = fine_registration.SyncAdapter(tile)
    return self.cache[tile_id]

# Per-tile downsampled mesh, shape [vector component, tile, z, y, x].
dim = len(downsampled_stride)
mesh_shape = (np.array(downsampled_tile_size_xyz[::-1]) // downsampled_stride).tolist()
downsampled_mesh = np.zeros([dim, len(tile_volumes)] + mesh_shape, dtype=np.float32)

# Mesh index <-> tile grid key. Keys are (x, y) = (column, row) in tile_layout,
# the order the mesh-initialization and offset loops unpack them in (`tx, ty`)
# and the order coarse_mesh[:, 0, ty, tx] expects.
mesh_index_to_key = dict(enumerate(
    (x, y)
    for y, row in enumerate(tile_layout)
    for x, _ in enumerate(row)
))
key_to_mesh_index = {v: k for k, v in mesh_index_to_key.items()}

print(f'{mesh_index_to_key=}')
print(f'{key_to_mesh_index=}')

# TODO: standardize the tile-coordinate convention, which is currently built
# ad hoc here, and consider factoring out the mesh initialization (as in
# stitch_elastic).


# Seed every node of each tile's mesh with that tile's coarse offset.
# NOTE(review): assumes mesh_index_to_key values are (x, y), since coarse_mesh
# is indexed [component, 0, y, x] here; confirm the two conventions agree.
for ind, (tx, ty) in mesh_index_to_key.items():
  tile_offset = coarse_mesh[:, 0, ty, tx]
  downsampled_mesh[:, ind, ...] = tile_offset.reshape((dim,) + (1,) * dim)


# Estimate the fused volume size from the first neighbor offsets.
# NaN offsets (no neighbor in that direction) are zeroed in place.
cx[np.isnan(cx)] = 0
cy[np.isnan(cy)] = 0
# Each overlap is normalized by the tile extent of the axis it is multiplied
# by below (x -> size_xyz[0], y -> size_xyz[1]).
# NOTE(review): x_overlap reads component 2 of cx while y_overlap reads
# component 1 of cy; confirm which component of cx holds the x offset.
x_overlap = cx[2, 0, 0, 0] / downsampled_tile_size_xyz[0]
y_overlap = cy[1, 0, 0, 0] / downsampled_tile_size_xyz[1]
y_shape, x_shape = cx.shape[2], cx.shape[3]

fused_x = downsampled_tile_size_xyz[0] * (1 + ((x_shape - 1) * (1 - x_overlap)))
fused_y = downsampled_tile_size_xyz[1] * (1 + ((y_shape - 1) * (1 - y_overlap)))
fused_z = downsampled_tile_size_xyz[2]
fused_shape = [1, 1, int(fused_z), int(fused_y), int(fused_x)]  # tczyx
print(f'{fused_shape=}')

# Fusion input: Output path
FUSED_PATH = 'downsample_res_2_tiles_refactor_tmp.zarr'

# Fusion input: Crop offset
# Element-wise minimum (xyz) of every tile's output-box start; its negation is
# the candidate render offset (presumably to move tiles out of negative coords).
start = np.array([np.inf, np.inf, np.inf])
map_box = bounding_box.BoundingBox(
  start=(0, 0, 0),
  size=downsampled_mesh.shape[2:][::-1],
)
for i in range(0, len(tile_volumes)):
  tx, ty = mesh_index_to_key[i]
  # Named tile_mesh so the sofima `mesh` module imported at the top is not shadowed.
  tile_mesh = downsampled_mesh[:, i, ...]
  tg_box = map_utils.outer_box(tile_mesh, map_box, downsampled_stride)

  # tg_box is in mesh-node units (xyz); scale by the stride (zyx order) and add
  # the tile's nominal grid position (tile index * full tile size).
  # NOTE(review): if the coarse mesh already encodes absolute tile offsets
  # (cy ~ 285 here), the tile spacing is counted twice; this is a candidate for
  # the 716-vs-expected discrepancy. Confirm the coarse_registration convention.
  out_box = bounding_box.BoundingBox(
    start=(
      tg_box.start[0] * downsampled_stride[2] + tx * downsampled_tile_size_xyz[0],
      tg_box.start[1] * downsampled_stride[1] + ty * downsampled_tile_size_xyz[1],
      tg_box.start[2] * downsampled_stride[0],
    ),
    size=(
      tg_box.size[0] * downsampled_stride[2],
      tg_box.size[1] * downsampled_stride[1],
      tg_box.size[2] * downsampled_stride[0],
    )
  )
  start = np.minimum(start, out_box.start)
  print(f'{tg_box=}')
  print(f'{out_box=}')

offset = -start
print(offset)
# NOTE(review): debug override. The offset computed above is printed but the
# renderer is given a zero offset.
# offset = (0, -160, 0)
offset = (0, 0, 0)

# Fusion time:
ds_out = zarr_io.write_zarr(WRITE_BUCKET, fused_shape, FUSED_PATH)
renderer = StitchAndRender3dTiles(
    tile_map=tile_layout,
    tile_mesh=downsampled_mesh,
    key_to_mesh_index=key_to_mesh_index,
    stride=downsampled_stride,
    offset=offset,
    parallelism=8
)

# Render in fixed-size xyz blocks with zero overlap. The same values go to the
# box generator and to the renderer so the two stay in agreement.
RENDER_BLOCK_SIZE_XYZ = (512, 512, 512)
RENDER_OVERLAP_XYZ = (0, 0, 0)
box = bounding_box.BoundingBox(start=(0, 0, 0), size=ds_out.shape[4:1:-1])  # tczyx -> xyz
gen = box_generator.BoxGenerator(box, RENDER_BLOCK_SIZE_XYZ, RENDER_OVERLAP_XYZ, True)
renderer.set_effective_subvol_and_overlap(RENDER_BLOCK_SIZE_XYZ, RENDER_OVERLAP_XYZ)
for i, sub_box in enumerate(gen.boxes):
    t_start = time.perf_counter()

    # Feed in an empty subvol, with dimensions of sub_box.
    inp_subvol = subvolume.Subvolume(np.zeros(sub_box.size[::-1], dtype=np.uint16)[None, ...], sub_box)
    ret_subvol = renderer.process(inp_subvol)  # czyx

    t_render = time.perf_counter()

    # ret_subvol is a 4D CZYX volume; write its first channel into the 5D tczyx
    # output. Named out_slice so the builtin `slice` stays usable in later cells.
    bbox_slices = ret_subvol.bbox.to_slice3d()
    out_slice = (0, 0, bbox_slices[0], bbox_slices[1], bbox_slices[2])
    ds_out[out_slice].write(ret_subvol.data[0, ...]).result()

    t_write = time.perf_counter()

    print(f'box {i}: {t_render - t_start:0.2f} render  {t_write - t_render:0.2f} write')

mesh_index_to_key={0: (0, 0), 1: (0, 1)}
key_to_mesh_index={(0, 0): 0, (0, 1): 1}
fused_shape=[1, 1, 3543, 867, 576]
tg_box=BoundingBox(start=(0, -8, 0), size=(28, 29, 178), is_border_start=(False, False, False), is_border_end=(False, False, False))
out_box=BoundingBox(start=(0, -160, 0), size=(560, 580, 3560), is_border_start=(False, False, False), is_border_end=(False, False, False))
tg_box=BoundingBox(start=(0, 7, -1), size=(28, 29, 178), is_border_start=(False, False, False), is_border_end=(False, False, False))
out_box=BoundingBox(start=(0, 716, -20), size=(560, 580, 3560), is_border_start=(False, False, False), is_border_end=(False, False, False))
MutableArray([ -0., 160.,  20.])


I0000 00:00:1686850007.597435    7046 gcs_resource.cc:102] Using default AdmissionQueue with limit 32
I0000 00:00:1686850007.605294    7743 google_auth_provider.cc:179] Running on GCE, using service account 895865026362-compute@developer.gserviceaccount.com


In [11]:
# Debug: display the renderer's per-tile boxes (private attribute). Per the
# output, each tile id maps to a pair of BoundingBoxes, presumably
# (output box, mesh-grid box); compare with out_box / tg_box printed above.
renderer._tile_boxes

{0: (BoundingBox(start=(0, -160, 0), size=(560, 580, 3540), is_border_start=(False, False, False), is_border_end=(False, False, False)),
  BoundingBox(start=(0, -8, 0), size=(28, 29, 177), is_border_start=(False, False, False), is_border_end=(False, False, False))),
 1: (BoundingBox(start=(0, 716, 0), size=(560, 580, 3540), is_border_start=(False, False, False), is_border_end=(False, False, False)),
  BoundingBox(start=(0, 7, 0), size=(28, 29, 177), is_border_start=(False, False, False), is_border_end=(False, False, False)))}

In [15]:
# Debug: shape of cached tile 1 (zyx layout, as opened by _open_tile_volume).
renderer.cache[1].shape

(3551, 576, 576)

In [None]:
# The tg_box is flipped;
# everything else is correct.



In [None]:
# The calculated offset appears to be very wrong:
# [0, 716, 0] != [0, 150, 0]    (magnitudes)
# which affects the definition of _tile_boxes.

# Hint: the SyncAdapter is not being accessed at all.
# Getting stuck on inverting the map, and I know that
# the tile boxes are different.
# Will investigate shortly. At least I found the problem.

# Difference between tile boxes is 876, should be 276. 

# Other states are correct (maps, fused shape).
# Downsampled mesh shape:
# Expected: [3, 2, 177, 28, 28]
# Actual:   [3, 2, 177, 28, 28]  (matches)


In [18]:
# Re-run of the crop-offset computation for debugging: element-wise minimum
# (xyz) of every tile's output-box start, then its negation.
start = np.array([np.inf, np.inf, np.inf])
map_box = bounding_box.BoundingBox(
  start=(0, 0, 0),
  size=downsampled_mesh.shape[2:][::-1],
)

for i in range(0, len(tile_volumes)):
  tx, ty = mesh_index_to_key[i]
  # Named tile_mesh so the sofima `mesh` module imported at the top is not shadowed.
  tile_mesh = downsampled_mesh[:, i, ...]
  tg_box = map_utils.outer_box(tile_mesh, map_box, downsampled_stride)

  # Mesh-node units (xyz) scaled by the stride (zyx order), plus the tile's
  # nominal grid position (tile index * full tile size).
  out_box = bounding_box.BoundingBox(
    start=(
      tg_box.start[0] * downsampled_stride[2] + tx * downsampled_tile_size_xyz[0],
      tg_box.start[1] * downsampled_stride[1] + ty * downsampled_tile_size_xyz[1],
      tg_box.start[2] * downsampled_stride[0],
    ),
    size=(
      tg_box.size[0] * downsampled_stride[2],
      tg_box.size[1] * downsampled_stride[1],
      tg_box.size[2] * downsampled_stride[0],
    )
  )
  # Accumulated inside the loop so the minimum covers every tile.
  start = np.minimum(start, out_box.start)
  print(f'{tg_box=}')
  print(f'{out_box=}')

offset = -start
print(offset)


# Odd stuff.
# Defining these boxes is definitely the wrong part.

# Without `start = np.minimum(...)` inside the loop, this would only use the last out_box.
# Also, I believe the error is related to tx, ty.

# 716 is wrong!
# How the heck is the other thing getting a different box?



tg_box=BoundingBox(start=(0, -8, 0), size=(28, 29, 177), is_border_start=(False, False, False), is_border_end=(False, False, False))
out_box=BoundingBox(start=(0, -160, 0), size=(560, 580, 3540), is_border_start=(False, False, False), is_border_end=(False, False, False))
tg_box=BoundingBox(start=(0, 7, 0), size=(28, 29, 177), is_border_start=(False, False, False), is_border_end=(False, False, False))
out_box=BoundingBox(start=(0, 716, 0), size=(560, 580, 3540), is_border_start=(False, False, False), is_border_end=(False, False, False))
MutableArray([ -0., 160.,  -0.])
