Running the evaluation script and working out the argument shapes/types, which are not documented in the Python code.

NOTE: aind-registration-evaluation must be pip-installable for production. 


In [1]:
from eval_reg import utils

import functools as ft
import jax
import jax.numpy as jnp
import numpy as np
import time
import tensorstore as ts

from ng_link import NgState, link_utils

import zarr_io
import coarse_registration

In [6]:
# Bucket names, dataset path and the numeric path component used when opening each tile
READ_BUCKET = 'aind-open-data'
WRITE_BUCKET = 'sofima-test-bucket'

DATASET = 'diSPIM_645736_2022-12-13_02-21-00/diSPIM.zarr'
DOWNSAMPLE_EXP = 2

# Two-row, one-column layout holding tile ids 0 and 1
tile_layout = np.array([[0], [1]])
tile_paths = [
    'tile_X_0001_Y_0000_Z_0000_CH_0405_cam1.zarr',
    'tile_X_0002_Y_0000_Z_0000_CH_0405_cam1.zarr',
]

# Open every tile, transpose it, and take index 0 on the last two axes of the
# transposed array so that each entry is a 3-D volume.
tile_volumes = [
    zarr_io.open_zarr_s3(READ_BUCKET, f'{DATASET}/{tile_name}/{DOWNSAMPLE_EXP}').T[:, :, :, 0, 0]
    for tile_name in tile_paths
]

tile_1, tile_2 = tile_volumes[0], tile_volumes[1]

In [8]:
# Find the offset between these two tiles.
# The call prints one offset vector per neighbouring tile pair (see the cell output).
# NOTE(review): the layout/shape of the returned `cx` and `cy` is not visible here — confirm
# against coarse_registration before indexing into them.
cx, cy = coarse_registration.compute_coarse_offsets(tile_layout, tile_volumes)

Top Id: 0, Bottom Id: 1
Top: (0, 0), Bot: (1, 0) [ -1. 325.   0.]


In [5]:
# Inspect the shape of the first tile after the transpose and the indexing of the two trailing axes.
tile_1.shape

(576, 576, 5966)

In [14]:
# Walk through the evaluation script step by step.

# 3x4 integer transform: identity in the first three columns, translation in the last.
# NOTE(review): the translation (0, 325, -1) looks like the coarse offset printed earlier
# ([-1, 325, 0]) with its axis order reversed — TODO confirm the axis convention.
transform = np.eye(3, 4, dtype=int)
transform[:, -1] = (0, 325, -1)

bounds_1, bounds_2 = utils.calculate_bounds(tile_1.shape, tile_2.shape, transform)

for bounds in (bounds_1, bounds_2):
    print(bounds)


[[   0    0    0]
 [ 576  576 5966]]
[[  0 325]
 [576 901]]


In [None]:
# NOTE(review): this cell appears to be pasted from a method of the evaluation script.
# It reads `self.args`, `image_1_shape`, `image_1_data`, `image_2_data`, `LOGGER` and
# `ImageMetricsFactory`, none of which are defined or imported in the notebook cells
# visible above, so it will not run as-is on a fresh kernel.

# Sample points in overlapping bounds
points = utils.sample_points_in_overlap(
    bounds_1=bounds_1,
    bounds_2=bounds_2,
    numpoints=self.args["sampling_info"]["numpoints"],
    sample_type=self.args["sampling_info"]["sampling_type"],
    image_shape=image_1_shape,
)

# print("Points: ", points)

# Points that fit in window based on a window size
pruned_points = utils.prune_points_to_fit_window(
    image_1_shape, points, self.args["window_size"]
)

# Number of sampled points dropped by the window pruning step
discarded_points_window = points.shape[0] - pruned_points.shape[0]
LOGGER.info(
    f"""Number of discarded points when prunning
    points to window: {discarded_points_window}""",
)

# calculate metrics per images
metric_per_point = []

metric_calculator = ImageMetricsFactory().create(
    image_1_data,
    image_2_data,
    self.args["metric"],
    self.args["window_size"],
)

# Points for which a metric was actually recorded (parallel to metric_per_point)
selected_pruned_points = []

for pruned_point in pruned_points:

    met = metric_calculator.calculate_metrics(
        point=pruned_point, transform=transform
    )

    # Only truthy results are kept.
    # NOTE(review): if the metric is a plain number, a value of exactly 0 would also be
    # skipped here — confirm that is intended.
    if met:
        selected_pruned_points.append(pruned_point)
        metric_per_point.append(met)

# compute statistics
metric = self.args["metric"]
computed_points = len(metric_per_point)

# Points that survived pruning but produced no metric
# (algebraically pruned_points.shape[0] - computed_points); not used in the lines visible here.
dscrd_pts = points.shape[0] - discarded_points_window - computed_points
# Summary text; it is built but not printed or logged in the lines visible here.
# NOTE(review): np.mean / np.std of an empty list yield nan with a RuntimeWarning.
message = f"""Computed metric: {metric}
\nMean: {np.mean(metric_per_point)}
\nStd: {np.std(metric_per_point)}
"""