In [1]:
import functools as ft
import jax
import jax.numpy as jnp
import numpy as np

import zarr_io
import fine_registration

from sofima import stitch_elastic, flow_utils, mesh  # SOFIMA modules used below for array aggregation, flow cleaning and mesh relaxation

In [2]:
# Load the arrays saved by the coarse-registration stage.
# np.load on an .npz archive returns a lazily-read NpzFile that keeps the
# underlying file open; the `with` block closes it once the arrays have been
# pulled out, so no file descriptor is leaked for the rest of the session.
with np.load('coarse_results.npz') as data:
    cx = data['cx']             # passed to compute_flow_map3d for axis=0
    cy = data['cy']             # passed to compute_flow_map3d for axis=1
    coarse_mesh = data['mesh']  # passed to stitch_elastic.aggregate_arrays

In [3]:
# Open the two preprocessed tiles from the bucket.
bucket = 'sofima-test-bucket'
path_0 = 'preprocessed_0.zarr'
path_1 = 'preprocessed_1.zarr'
tile_0 = zarr_io.open_zarr(bucket, path_0)
tile_1 = zarr_io.open_zarr(bucket, path_1)

# Existing data structure: tile indices arranged as a 2x1 grid.
tile_layout = np.array([[0],
                        [1]])

# Tile index -> key used in `tile_map`.  This adopts the reverse basis
# relative to `tile_layout`'s indexing: tile 1 sits at (0, 1) rather than
# (1, 0).  This is the single definition of the mapping; `tile_map` is
# derived from it so the two cannot drift apart.
idx_to_coord = {0: (0, 0), 1: (0, 1)}

# Volumes must be loaded in 4 dimensions ("1zyx" shape), hence the index 0 on
# the leading axis of each opened array.
# (An earlier experiment resized tile_0 to exclusive_max=(1, 3544, 576, 576).)
tile_volumes = [tile_0[0, :, :, :, :], tile_1[0, :, :, :, :]]

# Each volume is wrapped in a SyncAdapter together with its tile index.
tile_map = {
    coord: fine_registration.SyncAdapter(tile_volumes[idx], idx)
    for idx, coord in idx_to_coord.items()
}

I0000 00:00:1684534367.550418 2125426 gcs_resource.cc:102] Using default AdmissionQueue with limit 32
I0000 00:00:1684534367.552101 2125695 google_auth_provider.cc:179] Running on GCE, using service account 895865026362-compute@developer.gserviceaccount.com


In [4]:
# Fine registration, step 1: compute patch flows between tiles along each axis.
stride = (20, 20, 20)
tile_size_xyz = (576, 576, 3544)  # given in xyz order, even though the tiles themselves are 1zyx

# The two calls differ only in the coarse offsets and the axis index, so the
# shared arguments are bound once.
compute_flow = ft.partial(
    fine_registration.compute_flow_map3d,
    tile_map,
    tile_size_xyz,
    stride=stride,
    patch_size=(80, 80, 80),
)
flow_x, offsets_x = compute_flow(cx, axis=0)
flow_y, offsets_y = compute_flow(cy, axis=1)

# NOTE(review): the recorded output of this cell ends in an XlaRuntimeError
# (CUDNN_STATUS_ALLOC_FAILED while autotuning a convolution over a
# f32[16,1,159,159,159] input), which looks like GPU memory exhaustion --
# TODO confirm and, if so, reduce the per-call memory footprint before re-running.
# NOTE(review): flow_x / flow_y are iterated with .items() downstream, so they
# are mappings; np.savez presumably stores them as pickled object arrays that
# need allow_pickle=True to load -- confirm.
np.savez_compressed(
    'flow_results.npz',
    flow_x=flow_x,
    flow_y=flow_y,
    offsets_x=offsets_x,
    offsets_y=offsets_y,
)

(slice(None, None, None), slice(0, 3544, None), slice(280, 576, None), slice(0, 576, None))
(slice(None, None, None), slice(0, 3544, None), slice(0, 296, None), slice(0, 576, None))


2023-05-19 22:12:56.644616: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:850] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-05-19 22:12:56.644654: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:853] Conv: (f32[16,5,159,159,159]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[16,1,159,159,159]{4,3,2,1,0}, f32[5,1,5,1,1]{4,3,2,1,0}), window={size=5x1x1 pad=2_2x0_0x0_0}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convForward", backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"


XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv.3 = (f32[16,5,159,159,159]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[16,1,159,159,159]{4,3,2,1,0} %bitcast.250, f32[5,1,5,1,1]{4,3,2,1,0} %bitcast.258), window={size=5x1x1 pad=2_2x0_0x0_0}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convForward", metadata={op_name="jit(batched_xcorr_peaks)/jit(main)/conv_general_dilated[window_strides=(1, 1, 1) padding=((2, 2), (0, 0), (0, 0)) lhs_dilation=(1, 1, 1) rhs_dilation=(1, 1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3, 4), rhs_spec=(0, 1, 2, 3, 4), out_spec=(0, 1, 2, 3, 4)) feature_group_count=1 batch_group_count=1 precision=None preferred_element_type=None]" source_file="/home/jonathan.wong/sofima-testing/fine_registration.py" source_line=149}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: INTERNAL: All algorithms tried for %cudnn-conv.3 = (f32[16,5,159,159,159]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[16,1,159,159,159]{4,3,2,1,0} %bitcast.250, f32[5,1,5,1,1]{4,3,2,1,0} %bitcast.258), window={size=5x1x1 pad=2_2x0_0x0_0}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convForward", metadata={op_name="jit(batched_xcorr_peaks)/jit(main)/conv_general_dilated[window_strides=(1, 1, 1) padding=((2, 2), (0, 0), (0, 0)) lhs_dilation=(1, 1, 1) rhs_dilation=(1, 1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3, 4), rhs_spec=(0, 1, 2, 3, 4), out_spec=(0, 1, 2, 3, 4)) feature_group_count=1 batch_group_count=1 precision=None preferred_element_type=None]" source_file="/home/jonathan.wong/sofima-testing/fine_registration.py" source_line=149}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng48{k2=15,k6=0,k13=1,k14=0,k22=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng48{k2=2,k6=0,k13=1,k14=0,k22=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{k2=0,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng36{k2=4,k13=0,k14=2,k18=0,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng36{k2=0,k13=2,k14=3,k18=0,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng36{k2=1,k13=0,k14=4,k18=0,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng38{k2=8,k13=1,k14=4,k18=0,k22=0,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng36{k2=5,k13=1,k14=3,k18=1,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng36{k2=5,k13=1,k14=3,k18=0,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng36{k2=7,k13=0,k14=4,k18=0,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng38{k2=0,k13=2,k14=3,k18=1,k22=0,k23=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

In [None]:
# Fine registration, step 2: filter the raw patch flows.

# Pass 1 -- run flow_utils.clean_flow on every flow field (3-D flows, dim=3).
# NOTE(review): max_magnitude=0 presumably disables the magnitude filter --
# confirm against flow_utils.clean_flow.
kwargs = dict(min_peak_ratio=1.4, min_peak_sharpness=1.4, max_deviation=5,
              max_magnitude=0, dim=3)
fine_x = {key: flow_utils.clean_flow(flow, **kwargs) for key, flow in flow_x.items()}
fine_y = {key: flow_utils.clean_flow(flow, **kwargs) for key, flow in flow_y.items()}

# Pass 2 -- each cleaned flow is handed to reconcile_flows as a one-element list.
# NOTE(review): the -1 values presumably switch off the gradient / deviation
# checks -- confirm against flow_utils.reconcile_flows.
kwargs = dict(min_patch_size=10, max_gradient=-1, max_deviation=-1)
fine_x = {key: flow_utils.reconcile_flows([flow], **kwargs) for key, flow in fine_x.items()}
fine_y = {key: flow_utils.reconcile_flows([flow], **kwargs) for key, flow in fine_y.items()}


In [None]:
# Fine registration, step 3: update the mesh (turn the coarse tile mesh into a
# fine patch mesh).

# Per-axis bundle: (coarse offsets with index 0 taken on axis 1, filtered
# flows, patch offsets).
data_x = (cx[:, 0, ...], fine_x, offsets_x)
data_y = (cy[:, 0, ...], fine_y, offsets_y)

tile_keys = list(tile_map)             # tile_map keys, in insertion order
tile_shape_zyx = tile_size_xyz[::-1]   # aggregate_arrays is given the reversed (zyx) shape

fx, fy, init_x, nbors, key_to_idx = stitch_elastic.aggregate_arrays(
    data_x,
    data_y,
    tile_keys,
    coarse_mesh[:, 0, ...],
    stride=stride,
    tile_shape=tile_shape_zyx,
)

@jax.jit
def prev_fn(x):
  """Evaluate compute_target_mesh for every entry of `nbors` at mesh state `x`.

  The call is vmapped over the leading axis of the module-level `nbors`, with
  `fx`, `fy` and `stride` taken from the enclosing notebook scope.  The first
  two axes of the 5-D vmapped result are then swapped before returning.
  """
  per_neighbor = jax.vmap(
      lambda nbor: stitch_elastic.compute_target_mesh(
          nbor, x=x, fx=fx, fy=fy, stride=stride)
  )(nbors)
  return jnp.transpose(per_neighbor, (1, 0, 2, 3, 4))

# Integration settings for the mesh relaxation; `stride` matches the one used
# when computing the patch flows.
config = mesh.IntegrationConfig(
    dt=0.001,
    gamma=0.0,
    k0=0.01,
    k=0.1,
    stride=stride,
    num_iters=1000,
    max_iters=20000,
    stop_v_max=0.001,
    dt_max=100,
    prefer_orig_order=False,
    start_cap=0.1,
    final_cap=10.0,
    remove_drift=True,
)

# Relax the aggregated mesh with the 3-D elastic force, using prev_fn as the
# `prev_fn` callback.
x, ekin, t = mesh.relax_mesh(
    init_x,
    None,
    config,
    prev_fn=prev_fn,
    mesh_force=mesh.elastic_mesh_3d,
)

# `x` is the solved patch mesh(es).
# NOTE(review): if key_to_idx is a plain mapping, np.savez will presumably
# pickle it, and loading will then need allow_pickle=True -- confirm.
np.savez_compressed('solved_mesh_st20.npz', x=x, key_to_idx=key_to_idx)