In [1]:
import functools as ft
import jax
import jax.numpy as jnp
import numpy as np
import time
import tensorstore as ts

from ng_link import NgState, link_utils

import zarr_io
import coarse_registration
import fine_registration
import fusion

from sofima import stitch_rigid, flow_utils, stitch_elastic, mesh, map_utils
from sofima.processor import warp   # tensorflow dependency, very weird

from connectomics.common import bounding_box
from connectomics.common import box_generator
from connectomics.volume import subvolume

Automation Domain: 

Data shape inconsistencies: tiles do not all have the same shape (e.g. 3543 vs 3551 in Z).
Run on the full dataset to reproduce this.
Fixing it requires running `.resize(shape).result()` as a preprocessing step, which requires write permissions.

- Coarse Registration: 
    - Registration based on peak sharpness

- Elastic Registration: 
    - Channel Parameter

(Worth understanding and modifying this part.)
- Fusion:
    - Image reading for coarse registration and fusion uses hardcoded paths.
      These need to point to the database instead when running on Code Ocean.
    - Calculate the fused image boundaries from the image size and the coarse offsets.
    - Calculate the fusion offset.
    - Debug the out-of-bounds (OOB) error.
        |-- consider passing in TensorStore views cropped to the minimum shape across all tiles.
    - The tile mesh path and tile pattern path inputs are awkward to work with.

Writes can only be made to a GCP bucket, not S3.
This affects the fusion writer.

(Some annoying bugs, that's okay)

In [13]:
# Ng Link Stuff
def _zyx_vector_to_3x4(zyx_vector: np.ndarray):
    output = np.zeros((3, 4))
    
    # Set identity
    output[0, 0] = 1
    output[1, 1] = 1
    output[2, 2] = 1

    # Set translation vector
    output[0, 3] = zyx_vector[2]
    output[1, 3] = zyx_vector[1]
    output[2, 3] = zyx_vector[0]
    return output

def convert_matrix_3x4_to_5x6(matrix_3x4: np.ndarray) -> np.ndarray:
    """Embed a 3x4 XYZ affine into a 5x6 affine over (t, c', z, y, x).

    The spatial block is reordered from XYZ to ZYX (rows 0/2 and columns 0/2
    swapped) and placed in the bottom-right corner; the two leading axes keep
    an identity mapping.

    Args:
        matrix_3x4: Affine ``[A | t]`` whose rows and first three columns are
            in XYZ order.

    Returns:
        A (5, 6) float64 matrix.
    """
    # Initialize. float64, not float16: float16 has an 11-bit significand, so
    # translations above 2048 voxels get rounded and values above 65504
    # overflow to inf.
    matrix_5x6 = np.zeros((5, 6), dtype=np.float64)
    np.fill_diagonal(matrix_5x6, 1)

    # Reorder XYZ -> ZYX: swap rows 0 and 2, then columns 0 and 2.
    # np.array(...) copies, so the caller's matrix is never modified.
    patch = np.array(matrix_3x4, dtype=np.float64)
    patch[[0, 2], :] = patch[[2, 0], :]
    patch[:, [0, 2]] = patch[:, [2, 0]]

    # Rows 2..4 / columns 2..4 are the spatial axes; column 5 is the translation.
    matrix_5x6[2:5, 2:6] = patch

    return matrix_5x6

def apply_deskewing(matrix_3x4: np.ndarray, theta: float = -45) -> np.ndarray:
    """Left-multiply an XYZ affine by a shear that feeds X into Z.

    The shear adds ``tan(theta) * x`` to the z output row; with the default
    ``theta = -45`` degrees that factor is ``tan(-45 deg)``, i.e. about -1.

    Args:
        matrix_3x4: Affine whose rows are in XYZ order.
        theta: Shear angle in degrees.

    Returns:
        ``shear @ matrix_3x4`` as a new (3, 4) array.
    """
    shear = np.eye(3)
    shear[2, 0] = np.tan(np.deg2rad(theta))
    return shear @ matrix_3x4

# The same coarse registration is applied to every channel, because channel
# co-registration is extra effort for little gain: the same tile layout and
# coarse mesh are used for each nested list of paths.
def create_ng_link(tile_paths: list[list[str]], 
                   tile_layout: np.ndarray, 
                   coarse_mesh: np.ndarray, 
                   vox_sizes_xyz: np.ndarray, 
                   channels: list[int],
                   max_dr: int = 200,
                   opacity: float = 1.0, 
                   deskew_angle: int = -45,
                   blend: str = "default",
                   output_json_path: str = ".") -> dict:
    """Write a Neuroglancer state JSON showing every tile at its coarse position.

    One image layer is created per channel. Each tile of a layer is placed with
    the deskewed coarse translation of its grid cell.

    Args:
        tile_paths: One list of dataset-relative tile paths per channel, indexed
            by the tile ids stored in ``tile_layout``.
        tile_layout: 2D grid of tile ids.
        coarse_mesh: Coarse tile offsets indexed ``[component, 0, row, col]``
            with the same (row, col) as ``tile_layout``. The components are
            presumably ZYX (they are passed on as ``tr_zyx``).
        vox_sizes_xyz: Voxel size in microns for x, y, z.
        channels: Wavelength of each channel, parallel to ``tile_paths``. Only
            as many layers as the shorter of the two sequences are built (zip).
        max_dr: Upper bound of the display intensity range.
        opacity: Layer opacity.
        deskew_angle: Shear angle in degrees passed to ``apply_deskewing``.
        blend: Neuroglancer blend mode for each layer.
        output_json_path: Where ``NgState`` saves the JSON state.

    Returns:
        The ``input_config`` dict handed to ``NgState``. NgState appears to
        rewrite its layer sources in place (urls become ``zarr://...`` with a
        ``transform`` entry), so the returned dict reflects that.
    """
    # Generate input config
    layers = []  # Neuroglancer tabs, one per channel
    input_config = {
        "dimensions": {
            "x": {"voxel_size": vox_sizes_xyz[0], "unit": "microns"},
            "y": {"voxel_size": vox_sizes_xyz[1], "unit": "microns"},
            "z": {"voxel_size": vox_sizes_xyz[2], "unit": "microns"},
            "c'": {"voxel_size": 1, "unit": ""},
            "t": {"voxel_size": 0.001, "unit": "seconds"},
        },
        "layers": layers,
        "showScaleBar": False,
        "showAxisLines": False,
    }

    for channel_tile_paths, channel in zip(tile_paths, channels):
        hex_val: int = link_utils.wavelength_to_hex(channel)
        # Zero-pad to six digits: hex(0x00ff92)[2:] would give "ff92", which
        # is not a valid color.
        hex_str = f"#{hex_val:06x}"

        sources = []  # Tiles within this tab
        layers.append(
            {
                "type": "image",  # Optional
                "source": sources,
                "channel": 0,  # Optional
                "shaderControls": {
                    "normalized": {"range": [0, max_dr]}
                },  # Optional  # Exaspim has low HDR
                "shader": {
                    "color": hex_str,
                    "emitter": "RGB",
                    "vec": "vec3",
                },
                "visible": True,  # Optional
                "opacity": opacity,
                "name": f"CH_{channel}",
                "blend": blend,
            }
        )

        for row in range(tile_layout.shape[0]):
            for col in range(tile_layout.shape[1]):
                tile_id = tile_layout[row, col]
                tr_zyx = coarse_mesh[:, 0, row, col]
                transform = convert_matrix_3x4_to_5x6(
                    apply_deskewing(_zyx_vector_to_3x4(tr_zyx), deskew_angle)
                )
                sources.append(
                    {
                        "url": f"s3://aind-open-data/{channel_tile_paths[tile_id]}",
                        "transform_matrix": transform.tolist(),
                    }
                )

    # Generate the link
    neuroglancer_link = NgState(
        input_config=input_config,
        mount_service="s3",
        bucket_path="aind-open-data",
        output_json=output_json_path,
    )
    neuroglancer_link.save_state_as_json()

    return input_config

In [5]:
# Application Inputs
# Two-by-two tile subset of the diSPIM dataset.

READ_BUCKET = 'aind-open-data'
WRITE_BUCKET = 'sofima-test-bucket'

DATASET = 'diSPIM_624852_2023-06-03_10-11-33/diSPIM.zarr'
DOWNSAMPLE_EXP = 2

# Tile grid coordinates (x, y):
#   (0, 0), (1, 0), (2, 0)
#   (0, 1), (1, 1), (2, 1)
# Other layouts tried: np.array([[0, 1, 2], [3, 4, 5]]), np.array([[0, 1, 2]]),
# np.array([[0], [1]]).
tile_layout = np.array([[0, 1], 
                        [2, 3]])

vox_sizes_xyz = [0.298, 0.298, 0.176]  # In um
channels = [405, 488, 561, 594, 638]
REGISTRATION_CHANNEL = 488  # Only this channel's volumes are loaded for registration
tile_volumes = []
tile_paths = []
for channel in channels: 
    c_paths = []
    # NOTE(review): the outer loop is over x, so paths are ordered
    # index = x * n_rows + y; confirm this matches the tile ids in tile_layout.
    for x in range(tile_layout.shape[1]):  # neuroglancer basis
        for y in range(tile_layout.shape[0]):
            # Four-digit zero padding keeps indices >= 10 on the file naming scheme.
            path = f'624852.R1.W1_try_2_X_{x:04d}_Y_{y:04d}_Z_0000_ch_{channel}.zarr'
            c_paths.append(DATASET + '/' + path)

            if channel == REGISTRATION_CHANNEL:
                tile = zarr_io.open_zarr_s3(READ_BUCKET, 
                                            DATASET + f'/{path}/{DOWNSAMPLE_EXP}')
                # .T reverses the axis order; dropping the last two axes leaves
                # XYZ (assumes the stored array is (t, c, z, y, x) -- confirm).
                tile_xyz = tile.T[:, :, :, 0, 0]
                print(path)
                print(tile_xyz.shape)
                tile_volumes.append(tile_xyz)
    tile_paths.append(c_paths)


624852.R1.W1_try_2_X_0000_Y_0000_Z_0000_ch_488.zarr
(576, 576, 781)
624852.R1.W1_try_2_X_0000_Y_0001_Z_0000_ch_488.zarr
(576, 576, 781)
624852.R1.W1_try_2_X_0001_Y_0000_Z_0000_ch_488.zarr
(576, 576, 781)
624852.R1.W1_try_2_X_0001_Y_0001_Z_0000_ch_488.zarr
(576, 576, 781)


In [None]:
# Coarse Registration Data Loading
# One scanning axis: tiles vary only in X (Y and Z indices are 0000).

# tile_volumes maps tile index -> XYZ volume.
vox_sizes_xyz = [0.298, 0.298, 0.176]  # In um
channels = [405, 488, 561, 638]
tile_volumes = []
tile_paths = []
# NOTE(review): DATASET is whatever an earlier cell last set; these 657584_*
# tiles must exist under it.
# Naming schemes used by other acquisitions of this series:
#   "tile_X_00{i}_Y_0000_Z_0000_CH_0{channel}_cam1.zarr"
#   "657584_round1__X_00{i}_Y_0000_Z_0000_ch_{channel}.zarr" -- Kevin's dataset
#   deviates from convention: tile 13 is missing and tiles >= 14 use the
#   "657584_round1_take_3_" prefix.
for channel in channels:
    c_paths = []
    for i in range(tile_layout.shape[0]):
        # Four-digit zero padding ("X_0003", "X_0012") without rebinding `i`.
        path = f'657584_trex_round2_X_{i:04d}_Y_0000_Z_0000_ch_{channel}.zarr'
        c_paths.append(DATASET + '/' + path)

        if channel == 488:   # Just selecting one
            tile = zarr_io.open_zarr_s3(READ_BUCKET, 
                                        DATASET + f'/{path}/{DOWNSAMPLE_EXP}')
            tile_xyz = tile.T[:, :, :, 0, 0]
            print(path)
            print(tile_xyz.shape)
            tile_volumes.append(tile_xyz)
    tile_paths.append(c_paths)

In [6]:
# Coarse Registration: neighbor offsets (cx: left/right, cy: top/bottom pairs,
# per the printed log), then a rigid solve for one translation per tile.
cx, cy = coarse_registration.compute_coarse_offsets(tile_layout, tile_volumes)
coarse_mesh = stitch_rigid.optimize_coarse_mesh(
    cx,
    cy,
    mesh_fn=stitch_rigid.elastic_tile_mesh_3d,
)
# To cache: np.savez_compressed('coarse_results_6_cam1.npz', mesh=coarse_mesh, cx=cx, cy=cy)

Left Id: 0, Right Id: 1
Left: (0, 0), Right: (0, 1) [ -1.  -1. 365.]
Left Id: 2, Right Id: 3
Left: (1, 0), Right: (1, 1) [ -1.  -1. 365.]
Top Id: 0, Bottom Id: 2
Top: (0, 0), Bot: (1, 0) [ -1. 275.   0.]
Top Id: 1, Bottom Id: 3
Top: (0, 1), Bot: (1, 1) [ -1. 275.   0.]


In [14]:
# Ng Link: save a Neuroglancer state for the coarse registration and display
# the config it was built from.
# NOTE(review): output_json_path is an absolute, user-specific directory.
create_ng_link(
    tile_paths=tile_paths,
    tile_layout=tile_layout,
    coarse_mesh=coarse_mesh,
    vox_sizes_xyz=vox_sizes_xyz,
    channels=channels,
    max_dr=400,
    output_json_path='/home/jonathan.wong/sofima-testing',
)



{'dimensions': {'x': {'voxel_size': 0.298, 'unit': 'microns'},
  'y': {'voxel_size': 0.298, 'unit': 'microns'},
  'z': {'voxel_size': 0.176, 'unit': 'microns'},
  "c'": {'voxel_size': 1, 'unit': ''},
  't': {'voxel_size': 0.001, 'unit': 'seconds'}},
 'layers': [{'type': 'image',
   'source': [{'url': 'zarr://s3://aind-open-data/diSPIM_624852_2023-06-03_10-11-33/diSPIM.zarr/624852.R1.W1_try_2_X_0000_Y_0000_Z_0000_ch_405.zarr',
     'transform': {'matrix': [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
       [0.0, 0.0, 1.0, 0.0, -1.0, 183.5],
       [0.0, 0.0, 0.0, 1.0, 0.0, -137.0],
       [0.0, 0.0, 0.0, 0.0, 1.0, -182.5]],
      'outputDimensions': {'t': [0.001, 's'],
       "c'": [1, ''],
       'z': [1.76e-07, 'm'],
       'y': [2.98e-07, 'm'],
       'x': [2.98e-07, 'm']}}},
    {'url': 'zarr://s3://aind-open-data/diSPIM_624852_2023-06-03_10-11-33/diSPIM.zarr/624852.R1.W1_try_2_X_0000_Y_0001_Z_0000_ch_405.zarr',
     'transform': {'matrix': [[1.0, 0.0, 0.0

In [None]:
# NEW ADDITION
# Turns out I need to fully understand/standardize all inputs to 
# actually get things working downstream.

# 'tile_map' required for elastic registration and 
# skipping elastic registration may be derived from 
# tile layout and tile_volumes directly. 

# MAKE SURE TO FOLLOW THE CARTESIAN COORDINATE SYSTEM w/ flipped y. 
# That is the mindset in which this entire codebase is built. 
# Makes for a terrible developer experience, but at least it is consistent. 


# Might as well throw in the other standardizations now too.
# (Resizing)
# (Dataloading)
# How to approach this?

# Dataloading idea: 
# I: 
# |-- Dataset path, tile name, downsample_exp -> fill tile volumes
# tile_volumes is indexed in the order the tile paths are fed in --
# so, implicitly, the user is defining the index. 
# I:
# |-- Tile layout is tightly coupled with these tile layout indices, 
# and is required for data loading tasks.  

# Methods: 
# - Coarse Reg needs tile layout and formatted tile volumes.
# - Elastic Reg needs formatted SyncAdapter tile volumes, and tile_map. 
# - Fusion needs formatted SyncAdapter tile volumes. 

# Final:
# No Dataloader -- everything inside of ZarrStitcher constructor
# ZarrStitcher constructor takes the inputs above and standardizes the image volume sizes
# inside homogeneous TensorStores.

# Zarr Stitcher methods: 
# - run_{coarse_registration, elastic_registration, fusion}
# Each method has its own unique preparation (resizing)
# as well as a standardized input/output signature. 
# Also within the Zarr file: overloading of warp.StitchAndRender3dTiles...
# Will have a utility to pass tile volumes to use inside open_tile_volumes. 

# (Will actually need this object to run fusion in an independent process)
# That whole tmux thing

# Great, this should streamline everything

# Fine registration tile map preparation:
# (Or can place in constructor as well for fusion reference)
# (REFER TO FUSION CODE CONSTRUCTOR, has everything you need)

In [5]:
# NOTE: SKIPPING

# Fine Registration Data Loading
tile_volumes = []
for i in range(tile_layout.shape[0]):
    path = f"tile_X_{i:04d}_Y_0000_Z_0000_CH_0405_cam1.zarr"
    tile = zarr_io.open_zarr_s3(READ_BUCKET, 
                                DATASET + f'/{path}/{DOWNSAMPLE_EXP}')
    print(path)
    print(tile[0, :, :, :, :].shape)
    tile_volumes.append(tile[0, :, :, :, :])

# Tiles can differ slightly in Z (e.g. 3543 vs 3551), and fine registration
# needs identically shaped tiles, so crop every tile to the smallest Z extent.
# Axis 1 is Z in these (1, z, y, x) views. (Previously only key (0, 1) was
# cropped, using whatever index the loop variable last held.)
min_z = min(vol.shape[1] for vol in tile_volumes)
tile_map = {
    (0, idx): fine_registration.SyncAdapter(vol[:, :min_z, :, :])
    for idx, vol in enumerate(tile_volumes)
}

for key, tile in tile_map.items(): 
    print(tile.shape)

tile_X_0000_Y_0000_Z_0000_CH_0405_cam1.zarr
(1, 3543, 576, 576)
tile_X_0001_Y_0000_Z_0000_CH_0405_cam1.zarr
(1, 3551, 576, 576)
(1, 3543, 576, 576)
(1, 3543, 576, 576)


In [6]:
# NOTE: SKIPPING

# Inconsistent size is resolved, great. 

# Fine Registration: flow maps between neighboring tiles along x (axis=0, cx)
# and y (axis=1, cy), flow filtering, then an elastic mesh relaxation seeded
# from the coarse mesh. The solved mesh is saved to `tile_mesh_path`.
# NOTE(review): this cell needs the sofima `mesh` module; it breaks if a cell
# that binds a global named `mesh` ran earlier in the kernel.
stride = 20, 20, 20
# Hardcoded tile size; must match the Z crop applied when tile_map was built.
tile_size_xyz = (576, 576, 3543)  # Yet it expects the tiles as 1zyx...
flow_x, offsets_x = fine_registration.compute_flow_map3d(tile_map,
                                                        tile_size_xyz, cx, axis=0,
                                                        stride=stride,
                                                        patch_size=(80, 80, 80))

flow_y, offsets_y = fine_registration.compute_flow_map3d(tile_map,
                                                        tile_size_xyz, cy, axis=1,
                                                        stride=stride,
                                                        patch_size=(80, 80, 80))

# np.savez_compressed('flow_results_6_cam1.npz', flow_x=flow_x, flow_y=flow_y, offsets_x=offsets_x, offsets_y=offsets_y)

# Fine Registration, filter patch flows.
# Thresholds for flow_utils.clean_flow -- see the sofima docs for their meaning.
kwargs = {"min_peak_ratio": 1.4, "min_peak_sharpness": 1.4, "max_deviation": 5, "max_magnitude": 0, "dim": 3}
fine_x = {k: flow_utils.clean_flow(v, **kwargs) for k, v in flow_x.items()}
fine_y = {k: flow_utils.clean_flow(v, **kwargs) for k, v in flow_y.items()}

# A separate option set for flow_utils.reconcile_flows (each call gets a single flow).
kwargs = {"min_patch_size": 10, "max_gradient": -1, "max_deviation": -1}
fine_x = {k: flow_utils.reconcile_flows([v], **kwargs) for k, v in fine_x.items()}
fine_y = {k: flow_utils.reconcile_flows([v], **kwargs) for k, v in fine_y.items()}


# Fine Registration, update mesh (convert coarse tile mesh into fine patch mesh)
data_x = (cx[:, 0, ...], fine_x, offsets_x)
data_y = (cy[:, 0, ...], fine_y, offsets_y)

# key_to_idx maps tile_map keys to their index in the aggregated arrays.
fx, fy, init_x, nbors, key_to_idx = stitch_elastic.aggregate_arrays(
    data_x, data_y, list(tile_map.keys()),
    coarse_mesh[:, 0, ...], stride=stride, tile_shape=tile_size_xyz[::-1])

@jax.jit
def prev_fn(x):
  """Target mesh used by mesh.relax_mesh, computed for every entry of `nbors`.

  `x` is presumably the current mesh state supplied by relax_mesh; inside this
  function it shadows the notebook-global `x`.
  """
  target_fn = ft.partial(stitch_elastic.compute_target_mesh, x=x, fx=fx, fy=fy, stride=stride)
  x = jax.vmap(target_fn)(nbors)
  # Swap the two leading axes of the vmapped result
  # (presumably to put the vector component first -- TODO confirm).
  return jnp.transpose(x, [1, 0, 2, 3, 4])

config = mesh.IntegrationConfig(dt=0.001, gamma=0., k0=0.01, k=0.1, stride=stride,
                                num_iters=1000, max_iters=20000, stop_v_max=0.001,
                                dt_max=100, prefer_orig_order=False,
                                start_cap=0.1, final_cap=10., remove_drift=True)

x, ekin, t = mesh.relax_mesh(init_x, None, config, prev_fn=prev_fn, mesh_force=mesh.elastic_mesh_3d)

tile_mesh_path = 'solved_mesh.npz'
np.savez_compressed(tile_mesh_path, x=x, key_to_idx=key_to_idx)  # This 'x' is the solved patch mesh(es).

(slice(None, None, None), slice(0, 3543, None), slice(280, 576, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3543, None), slice(0, 296, None), slice(0, 576, None))


### Fine registration bugs out

Might as well resolve this first. 


In [7]:
# Reproducing the original fusion for comparison.
# Elastic registration must have been run beforehand (it writes solved_mesh.npz).

# Data location
READ_BUCKET = 'aind-open-data'
WRITE_BUCKET = 'sofima-test-bucket'
DATASET = 'diSPIM_647459_2022-12-07_00-00-00/diSPIM.zarr'

# Tile id stored at each [row, col] of the grid.
tile_layout = np.array([[1], [0]])

# Elastic-registration output and the patch stride it was solved with.
tile_mesh_path = 'solved_mesh.npz'
stride = (20, 20, 20)

class StitchAndRender3dTiles(warp.StitchAndRender3dTiles):
  """warp renderer that reads each tile from the zarr at downsample level 2.

  Opened tiles are memoized in ``cache``, keyed by tile id.
  """

  # NOTE: class attribute, so it is shared by every instance of this class.
  cache = {}

  def _open_tile_volume(self, tile_id: int):
    """Return the ZYX volume for ``tile_id`` wrapped in a SyncAdapter."""
    if tile_id in self.cache:
      return self.cache[tile_id]

    # Two-digit zero padding ("tile_X_0001", ..., "tile_X_0012") gives every
    # tile id a file name, including ids >= 10.
    path = f"tile_X_00{tile_id:02d}_Y_0000_Z_0000_CH_0405_cam1.zarr"
    tile = zarr_io.open_zarr_s3(READ_BUCKET, 
                                DATASET + f'/{path}/{2}')
    tile = tile[0, 0, :, :, :]  # convert to zyx axis layout (diff from coarse and fine)
    self.cache[tile_id] = fine_registration.SyncAdapter(tile)
    return self.cache[tile_id]

FUSED_PATH = 'tmp.zarr'

# NOTE(review): hardcoded output shape for this two-tile dataset; axes 2..4 are
# z, y, x (axes 0..1 presumably t, c).
fused_shape = [1, 1, 3543, 850, 576] # Hardcoded
offset = (0, 0, 0)
ds_out = zarr_io.write_zarr(WRITE_BUCKET, fused_shape, FUSED_PATH)
renderer = StitchAndRender3dTiles(
    tile_map=tile_layout,
    tile_mesh_path=tile_mesh_path,
    tile_pattern_path="",
    stride=stride,
    offset=offset,
    parallelism=8
)

# Render the output volume in 512^3 sub-boxes.
box = bounding_box.BoundingBox(start=(0, 0, 0), size=ds_out.shape[4:1:-1])  # Needs xyz 
gen = box_generator.BoxGenerator(box, (512, 512, 512), (0, 0, 0), True) # These are xyz
renderer.set_effective_subvol_and_overlap((512, 512, 512), (0, 0, 0))
for i, sub_box in enumerate(gen.boxes):
    t_start = time.time()

    # Feed in an empty subvol, with dimensions of sub_box. 
    inp_subvol = subvolume.Subvolume(np.zeros(sub_box.size[::-1], dtype=np.uint16)[None, ...], sub_box)
    ret_subvol = renderer.process(inp_subvol)  # czyx

    t_render = time.time()

    # ret_subvol is a 4D CZYX volume; write its only channel at index (0, 0).
    # Named zyx_slice/out_index rather than `slice`, so the builtin `slice`
    # stays usable in the rest of the kernel.
    zyx_slice = ret_subvol.bbox.to_slice3d()
    out_index = (0, 0, zyx_slice[0], zyx_slice[1], zyx_slice[2])
    ds_out[out_index].write(ret_subvol.data[0, ...]).result()
    
    t_write = time.time()
    
    print('box {i}: {t1:0.2f} render  {t2:0.2f} write'.format(i=i, t1=t_render - t_start, t2=t_write - t_render))

# Baseline takes: X minutes
# Intended automated offset: 

I0000 00:00:1687293129.124447   25384 gcs_resource.cc:102] Using default AdmissionQueue with limit 32
I0000 00:00:1687293129.155708   28719 google_auth_provider.cc:179] Running on GCE, using service account 895865026362-compute@developer.gserviceaccount.com


(slice(0, 542, None), slice(0, 392, None), slice(0, 543, None))
(slice(0, 543, None), slice(0, 109, None), slice(0, 543, None))
box 0: 41.01 render  0.37 write
(slice(0, 542, None), slice(0, 392, None), slice(19, 576, None))
(slice(0, 543, None), slice(0, 109, None), slice(21, 576, None))
box 1: 20.31 render  0.60 write
(slice(0, 542, None), slice(129, 572, None), slice(0, 543, None))
(slice(0, 543, None), slice(0, 450, None), slice(0, 543, None))
box 2: 37.06 render  0.71 write
(slice(0, 542, None), slice(129, 572, None), slice(19, 576, None))
(slice(0, 543, None), slice(0, 450, None), slice(20, 576, None))
box 3: 37.21 render  0.91 write
(slice(459, 1042, None), slice(0, 392, None), slice(0, 542, None))
(slice(459, 1043, None), slice(0, 117, None), slice(0, 542, None))
box 4: 21.32 render  0.49 write
(slice(459, 1042, None), slice(0, 392, None), slice(17, 576, None))
(slice(459, 1043, None), slice(0, 117, None), slice(20, 576, None))
box 5: 21.00 render  0.68 write
(slice(459, 1042, 

ValueError: OUT_OF_RANGE: Propagated bounds [0, 3543), with size=3543, for dimension 0 are incompatible with existing bounds [2978, 3551), with size=573. [transform='Rank 3 -> 5 index space transform:   Input domain:     0: [2978, 3551)     1: [0, 110)     2: [0, 547)   Output index maps:     out[0] = 0     out[1] = 0     out[2] = 0 + 1 * in[0]     out[3] = 0 + 1 * in[1]     out[4] = 0 + 1 * in[2] '] [domain='{origin={0, 0, 0, 0, 0}, shape={1, 1, 3543, 576, 576}}']

(Fix this SyncAdapter preparation for fusion too.)
Some data standardization step would be good.


Want to check: 
- fused offset = 150 at 4x downsampled resolution
- fused shape = [1, 1, 3551, 850, 576]
- key_to_mesh_index <-> key_to_idx
- renderer.cache

In [8]:
# Inspect the renderer's tile-key -> mesh-index mapping.
renderer._key_to_idx

{(0, 0): 1, (0, 1): 0}

In [None]:
# Shapes of the tiles the renderer has opened so far (cache is keyed by tile id).
[(tile_id, tstore.shape) for tile_id, tstore in renderer.cache.items()]

In [None]:
# `x` is the solved patch mesh returned by mesh.relax_mesh in the fine-registration cell.
print(x.shape) 

In [None]:
# key_to_idx comes from stitch_elastic.aggregate_arrays in the fine-registration cell.
print(key_to_idx)   # Different from the renderer's map, as expected.
# Hardcoded mesh-index -> tile-key map for the two-tile case; the fusion cells rebuild this global.
mesh_index_to_key={0: (0, 0), 1: (0, 1)} 

In [9]:
renderer._tile_boxes  # tile index -> (out_box, tg_box); TODO: understand tg_box, moving on for now

{0: (BoundingBox(start=(-20, 140, -20), size=(600, 580, 3580), is_border_start=(False, False, False), is_border_end=(False, False, False)),
  BoundingBox(start=(-1, 7, -1), size=(30, 29, 179), is_border_start=(False, False, False), is_border_end=(False, False, False))),
 1: (BoundingBox(start=(-20, 416, -20), size=(600, 580, 3580), is_border_start=(False, False, False), is_border_end=(False, False, False)),
  BoundingBox(start=(-1, -8, -1), size=(30, 29, 179), is_border_start=(False, False, False), is_border_end=(False, False, False)))}

In [10]:
renderer._tile_idx_to_xy  # tile index -> (x, y) grid position. So the order matters?

{0: (0, 0), 1: (0, 1)}

In [None]:
# Presumably the patch stride passed to the renderer's constructor.
renderer._stride

In [None]:
# Presumably the output offset passed to the renderer's constructor.
renderer._offset

In [None]:
# Scratch arithmetic relating tg_box (mesh-node units) to out_box (pixels):
# multiplying by the stride of 20 turns (30, 29, 179) into the (600, 580, 3580)
# out_box size shown by renderer._tile_boxes.
# NOTE: only the last expression of this cell is displayed.
np.array((30, 29, 179)) * 20

(np.array((-1, 7, -1)) * 20)
# The scalar 576 broadcasts to every component; the tuple adds element-wise.
(np.array((-1, -8, -1)) * 20) + 576 + (0, 0, 0)  # How did this only add component wise?

# renderer.cache[0].shape[-2]

# renderer._tile_idx_to_xy[0]
# Okay I see, defining componentwise

# NEXT UP
Step through the cells below to find where the state mismatches.

In [None]:
# Open one tile at full resolution (level 0) to compare its shape with the
# downsampled level. NOTE: this rebinds the notebook-global `path`.
i = 2
path = f"tile_X_000{i}_Y_0000_Z_0000_CH_0405_cam1.zarr"
original_res_tile = zarr_io.open_zarr_s3(READ_BUCKET, 
                            DATASET + f'/{path}/{0}')
original_res_tile.shape


In [None]:
# Fusing the downsampled images.
# NOTE(review): uses the tile_layout currently in the kernel. An alternative
# tried here was np.array([[2], [1]]), with tile ids starting at 1 (a hack).

# Size of one downsampled tile, taken from the first loaded volume. Presumably
# XYZ (as loaded by the coarse-registration cell) -- confirm which cell last
# set tile_volumes.
downsampled_tile_size_xyz = np.array(tile_volumes[0].shape)
# Patch stride, in the ZYX order the fusion code below indexes it with.
downsampled_stride = tuple([20] * 3)

class StitchAndRender3dTiles(fusion.StitchAndRender3dTiles):
  """fusion renderer that reads each tile from the zarr at downsample level 2.

  Opened tiles are memoized in ``cache``, keyed by tile id.
  """

  # NOTE: class attribute, so it is shared by every instance of this class.
  cache = {}

  def _open_tile_volume(self, tile_id: int):
    """Return the ZYX volume for ``tile_id`` wrapped in a SyncAdapter."""
    if tile_id in self.cache:
      return self.cache[tile_id]

    # Two-digit zero padding ("tile_X_0001", ..., "tile_X_0012") gives every
    # tile id a file name, including ids >= 10.
    path = f"tile_X_00{tile_id:02d}_Y_0000_Z_0000_CH_0405_cam1.zarr"
    tile = zarr_io.open_zarr_s3(READ_BUCKET, 
                                DATASET + f'/{path}/{2}')
    tile = tile[0, 0, :, :, :]  # convert to zyx axis layout (diff from coarse and fine)
    self.cache[tile_id] = fine_registration.SyncAdapter(tile)
    return self.cache[tile_id]

# Fusion input: a coarse-only tile mesh. Every node of a tile's mesh holds that
# tile's coarse offset.
dim = len(downsampled_stride)
mesh_shape = (np.array(downsampled_tile_size_xyz[::-1]) // downsampled_stride).tolist()
downsampled_mesh = np.zeros([dim, len(tile_volumes)] + mesh_shape, dtype=np.float32)

# Dense, row-major mesh index over tile_layout; keys are (x=column, y=row).
# The loop variables are col/row_idx rather than x/y, so the notebook-global
# `x` (the solved patch mesh) is not clobbered.
mesh_index_to_key = {}
index = 0
for row_idx, row in enumerate(tile_layout):
  for col, _ in enumerate(row):
    mesh_index_to_key[index] = (col, row_idx)
    index += 1
key_to_mesh_index = {v: k for k, v in mesh_index_to_key.items()}

print(mesh_index_to_key)

# coarse_mesh is indexed (component, 0, y, x); broadcast each tile's offset
# over all of its mesh nodes.
for ind, (tx, ty) in mesh_index_to_key.items():
  downsampled_mesh[:, ind, ...] = coarse_mesh[:, 0, ty, tx].reshape(
      (dim,) + (1,) * dim)

# Fusion input: 'fused_shape'. Missing neighbor offsets are NaN. Zero them in
# local copies instead of overwriting the shared cx/cy in place, which other
# cells still read.
cx_filled = np.where(np.isnan(cx), 0, cx)
cy_filled = np.where(np.isnan(cy), 0, cy)
# NOTE(review): component 2 of cx is divided by the Y tile size, and component
# 1 of cy by the X tile size; confirm these axes (this cell was flagged as not
# producing the right fusion yet).
x_overlap = cx_filled[2, 0, 0, 0] / downsampled_tile_size_xyz[1]
y_overlap = cy_filled[1, 0, 0, 0] / downsampled_tile_size_xyz[0]
y_shape, x_shape = cx.shape[2], cx.shape[3]

fused_x = downsampled_tile_size_xyz[0] * (1 + ((x_shape - 1) * (1 - x_overlap)))
fused_y = downsampled_tile_size_xyz[1] * (1 + ((y_shape - 1) * (1 - y_overlap)))
fused_z = downsampled_tile_size_xyz[2]
# Array shapes must be integers; round up so no tile is clipped.
fused_shape = [1, 1, int(fused_z), int(np.ceil(fused_y)), int(np.ceil(fused_x))]

# Fusion input: Output path
FUSED_PATH = 'downsample_res_2_tiles_refactor.zarr'

# Fusion input: crop offset = componentwise minimum start over all tile output boxes.
start = np.array([np.inf, np.inf, np.inf])

# map_box spans the mesh grid in (x, y, z) node units.
map_box = bounding_box.BoundingBox(
  start=(0, 0, 0),
  size=downsampled_mesh.shape[2:][::-1],
)

for i in range(len(tile_volumes)):
  tx, ty = mesh_index_to_key[i]
  # Not named `mesh`: that would shadow the sofima `mesh` module kernel-wide.
  tile_coord_map = downsampled_mesh[:, i, ...]
  tg_box = map_utils.outer_box(tile_coord_map, map_box, downsampled_stride)

  out_box = bounding_box.BoundingBox(
    start=(
      tg_box.start[0] * downsampled_stride[2] + tx * downsampled_tile_size_xyz[0],
      tg_box.start[1] * downsampled_stride[1] + ty * downsampled_tile_size_xyz[1],
      tg_box.start[2] * downsampled_stride[0],
    ),
    size=(
      tg_box.size[0] * downsampled_stride[2],
      tg_box.size[1] * downsampled_stride[1],
      tg_box.size[2] * downsampled_stride[0],
    )
  )
  # Accumulated inside the loop so every tile contributes, not just the last one.
  start = np.minimum(start, out_box.start)
offset = -start

# Fusion time:
ds_out = zarr_io.write_zarr(WRITE_BUCKET, fused_shape, FUSED_PATH)
# NOTE(review): `-offset` equals `start`; confirm which sign
# StitchAndRender3dTiles expects for `offset`.
renderer = StitchAndRender3dTiles(
    tile_map=tile_layout,
    tile_mesh=downsampled_mesh,
    key_to_mesh_index=key_to_mesh_index,
    stride=downsampled_stride,
    offset=-offset,
    parallelism=8
)

box = bounding_box.BoundingBox(start=(0, 0, 0), size=ds_out.shape[4:1:-1])  # Needs xyz 
gen = box_generator.BoxGenerator(box, (512, 512, 512), (0, 0, 0), True) # These are xyz
renderer.set_effective_subvol_and_overlap((512, 512, 512), (0, 0, 0))
for i, sub_box in enumerate(gen.boxes):
    t_start = time.time()

    # Feed in an empty subvol, with dimensions of sub_box. 
    inp_subvol = subvolume.Subvolume(np.zeros(sub_box.size[::-1], dtype=np.uint16)[None, ...], sub_box)
    ret_subvol = renderer.process(inp_subvol)  # czyx

    t_render = time.time()

    # ret_subvol is a 4D CZYX volume; write its only channel at index (0, 0).
    # Named zyx_slice/out_index so the builtin `slice` is not shadowed.
    zyx_slice = ret_subvol.bbox.to_slice3d()
    out_index = (0, 0, zyx_slice[0], zyx_slice[1], zyx_slice[2])
    ds_out[out_index].write(ret_subvol.data[0, ...]).result()
    
    t_write = time.time()
    
    print('box {i}: {t1:0.2f} render  {t2:0.2f} write'.format(i=i, t1=t_render - t_start, t2=t_write - t_render))

# TODO: the fused output of this cell is still wrong; under investigation.

In [None]:
# Let's see if I refactored this correctly...
# Okay, may not be a performance issue-- 
# struggling to fuse downsampled images now. 



In [None]:
# Fusion at ORIGINAL resolution (pyramid level 0).
# Sizes and strides measured on the downsampled tiles are scaled up by
# 2**DOWNSAMPLE_EXP.
# TODO: it is unclear why the tile layout needs a separate data structure; investigate later.

tile_layout = np.array([[2], [1]])  # NOTE: Id starts at 1 is still a hack

downsampled_tile_size_xyz = np.array(tile_volumes[0].shape)
downsampled_stride = np.array((20, 20, 20))

original_tile_size_xyz = downsampled_tile_size_xyz * 2**DOWNSAMPLE_EXP
original_stride = tuple(downsampled_stride * 2**DOWNSAMPLE_EXP)

class StitchAndRender3dTiles(fusion.StitchAndRender3dTiles):
  """fusion renderer that reads each tile at original resolution (level 0).

  Opened tiles are memoized in ``cache``, keyed by tile id.
  """

  # NOTE: class attribute, so it is shared by every instance of this class.
  cache = {}

  def _open_tile_volume(self, tile_id: int):
    """Return the full-resolution ZYX volume for ``tile_id`` as a SyncAdapter."""
    if tile_id in self.cache:
      return self.cache[tile_id]

    # The path is built from tile_id (same naming as the level-2 renderers).
    # Reading a notebook-global `path` here would load the same file for
    # every tile id.
    path = f"tile_X_00{tile_id:02d}_Y_0000_Z_0000_CH_0405_cam1.zarr"
    tile = zarr_io.open_zarr_s3(READ_BUCKET, 
                                DATASET + f'/{path}/{0}')  # NOTE: Original resolution
    tile = tile[0, 0, :, :, :]  # convert to zyx axis layout (diff from coarse and fine)
    self.cache[tile_id] = fine_registration.SyncAdapter(tile)
    return self.cache[tile_id]

# Fusion input: grid of offsets 'tile_mesh' and the dense linear index
# 'mesh_index_to_key' / 'key_to_mesh_index'.
dim = len(downsampled_stride)
# One mesh node per `original_stride` full-resolution pixels. Dividing the
# full-resolution size by the *downsampled* stride would make each tile's mesh
# 2**DOWNSAMPLE_EXP times too large per axis once the renderer applies
# original_stride, so every tile would claim a much larger extent.
mesh_shape = (np.array(original_tile_size_xyz[::-1]) // np.array(original_stride)).tolist()
downsampled_mesh = np.zeros([dim, len(tile_volumes)] + mesh_shape, dtype=np.float32)

# NOTE: order of coarse offsets in init_x is arbitrary-- 
# see stitch_elastic.aggregate_arrays for evidence. 
# Because order is arbitrary, should generalize to exaspim. 
# coarse_mesh shape: (3, 1, y, x)

# Dense, row-major mesh index over tile_layout; keys are (x=column, y=row).
# The loop variables are col/row_idx rather than x/y, so the notebook-global
# `x` (the solved patch mesh) is not clobbered.
mesh_index_to_key = {}
index = 0
for row_idx, row in enumerate(tile_layout):
  for col, _ in enumerate(row):
    mesh_index_to_key[index] = (col, row_idx)
    index += 1
key_to_mesh_index = {v: k for k, v in mesh_index_to_key.items()}

print(mesh_index_to_key)

# Broadcast each tile's coarse offset over all of its mesh nodes, then scale
# offsets measured at the downsampled level up to full-resolution pixels.
for ind, (tx, ty) in mesh_index_to_key.items():
  downsampled_mesh[:, ind, ...] = coarse_mesh[:, 0, ty, tx].reshape(
      (dim,) + (1,) * dim)
tile_mesh = downsampled_mesh * 2**DOWNSAMPLE_EXP

# Fusion input: 'fused_shape'. Missing neighbor offsets are NaN. Zero them in
# local copies instead of overwriting the shared cx/cy in place.
cx_filled = np.where(np.isnan(cx), 0, cx)
cy_filled = np.where(np.isnan(cy), 0, cy)
# NOTE(review): overlap ratios are measured on downsampled tiles. Component 2
# of cx is divided by the Y tile size, and component 1 of cy by the X size;
# confirm these axes.
x_overlap = cx_filled[2, 0, 0, 0] / downsampled_tile_size_xyz[1]
y_overlap = cy_filled[1, 0, 0, 0] / downsampled_tile_size_xyz[0]
y_shape, x_shape = cx.shape[2], cx.shape[3]

fused_x = original_tile_size_xyz[0] * (1 + ((x_shape - 1) * (1 - x_overlap)))
fused_y = original_tile_size_xyz[1] * (1 + ((y_shape - 1) * (1 - y_overlap)))
fused_z = original_tile_size_xyz[2]
# Array shapes must be integers; round up so no tile is clipped.
fused_shape = [1, 1, int(fused_z), int(np.ceil(fused_y)), int(np.ceil(fused_x))]

# Fusion input: Output path
FUSED_PATH = 'full_res_2_tiles.zarr'

# Fusion input: crop offset = componentwise minimum start over all tile output boxes.
start = np.array([np.inf, np.inf, np.inf])

map_box = bounding_box.BoundingBox(
  start=(0, 0, 0),
  size=tile_mesh.shape[2:][::-1],
) # NOTE: Using stride length of full resolution mesh

for i in range(len(tile_volumes)):
  tx, ty = mesh_index_to_key[i]
  # Not named `mesh`: that would shadow the sofima `mesh` module kernel-wide.
  tile_coord_map = tile_mesh[:, i, ...]
  tg_box = map_utils.outer_box(tile_coord_map, map_box, original_stride)

  out_box = bounding_box.BoundingBox(
    start=(
      tg_box.start[0] * original_stride[2] + tx * original_tile_size_xyz[0],
      tg_box.start[1] * original_stride[1] + ty * original_tile_size_xyz[1],
      tg_box.start[2] * original_stride[0],
    ),
    size=(
      tg_box.size[0] * original_stride[2],
      tg_box.size[1] * original_stride[1],
      tg_box.size[2] * original_stride[0],
    )
  )
  # Accumulated inside the loop so every tile contributes, not just the last one.
  start = np.minimum(start, out_box.start)
offset = -start

# Fusion time:
ds_out = zarr_io.write_zarr(WRITE_BUCKET, fused_shape, FUSED_PATH)
# NOTE(review): `-offset` equals `start`; confirm which sign
# StitchAndRender3dTiles expects for `offset`.
renderer = StitchAndRender3dTiles(
    tile_map=tile_layout,
    tile_mesh=tile_mesh,
    key_to_mesh_index=key_to_mesh_index,
    stride=original_stride,
    offset=-offset,
    parallelism=8
)

box = bounding_box.BoundingBox(start=(0, 0, 0), size=ds_out.shape[4:1:-1])  # Needs xyz 
gen = box_generator.BoxGenerator(box, (512, 512, 512), (0, 0, 0), True) # These are xyz
renderer.set_effective_subvol_and_overlap((512, 512, 512), (0, 0, 0))
for i, sub_box in enumerate(gen.boxes):
    t_start = time.time()

    # Feed in an empty subvol, with dimensions of sub_box. 
    inp_subvol = subvolume.Subvolume(np.zeros(sub_box.size[::-1], dtype=np.uint16)[None, ...], sub_box)
    ret_subvol = renderer.process(inp_subvol)  # czyx

    t_render = time.time()

    # ret_subvol is a 4D CZYX volume; write its only channel at index (0, 0).
    # Named zyx_slice/out_index so the builtin `slice` is not shadowed.
    zyx_slice = ret_subvol.bbox.to_slice3d()
    out_index = (0, 0, zyx_slice[0], zyx_slice[1], zyx_slice[2])
    ds_out[out_index].write(ret_subvol.data[0, ...]).result()
    
    t_write = time.time()
    
    print('box {i}: {t1:0.2f} render  {t2:0.2f} write'.format(i=i, t1=t_render - t_start, t2=t_write - t_render))

In [None]:
# Compare the coarse-only fusion mesh with the coarse registration mesh.
# Both shapes are shown as one tuple: a bare expression that is not the last
# line of a cell is evaluated and then discarded.
downsampled_mesh.shape, coarse_mesh.shape

# Shapes!
# It's doing something, we'll see what happens
# Getting stuck on box 20?
# Or maybe it's a larger resolution, which is 4^3 = 64 times larger. 

# yikes, this is running serially. 
# I can try converting to a process pool...
# Or convert to kornia... 

# Let's try a smaller scale to see if I adapted everything correctly...




In [None]:
# Rechunk a fused volume into [1, 1, 128, 128, 128] chunks in a new zarr on GCS.
bucket = 'sofima-test-bucket'
in_path = 'fused_6_y_offset_cam1_2.zarr'
out_path = 'fused_cam1_y_offset_rechunked.zarr'
fused = zarr_io.open_zarr(bucket, in_path)

rechunked = ts.open({
    'driver': 'zarr',
    'kvstore': {
        'driver': 'gcs',
        'bucket': bucket,
    },
    'path': out_path,
    'dtype': 'uint16',
    'create': True,
    'delete_existing': True,  # NOTE: replaces anything already stored at out_path
    'metadata': {
        'zarr_format': 2,
        'shape': fused.shape,
        'chunks': [1, 1, 128, 128, 128],
        'dtype': '<u2',
        'order': 'C',
        'fill_value': 0,
        'filters': None,
        'dimension_separator': '/',
        'compressor': {
            'id': 'blosc',
            'cname': 'zstd',
            'clevel': 1,
            'shuffle': 1,
            'blocksize': 0,
        },
    },
}).result()

# Copy the whole source volume into the rechunked store.
rechunked[:, :, :, :, :].write(fused[:, :, :, :, :]).result()

In [None]:
# Componentwise minimum start over every tile's output box.
# `_tile_boxes` is set on the renderer *instance* and maps tile index ->
# (out_box, tg_box) (it is displayed in that form earlier in this notebook;
# assumes this renderer exposes it the same way). np.minimum needs two
# operands, so reduce with np.min(..., axis=0) instead.
start = np.min([out_box.start for out_box, _ in renderer._tile_boxes.values()], axis=0)
