# SOFIMA alignment of FAST-EM data with arbitrary dataset sizes
Author: Arent Kievits  
Date: 29-11-2024

- This notebook serves as an example of using [SOFIMA](https://github.com/google-research/sofima?tab=readme-ov-file) to elastically align a FAST-EM dataset stored in WebKnossos.
- It aligns the data in chunks, so that the dataset size is not limited by the available RAM.
- To speed up data loading and facilitate writing of chunks of data, this notebook uses **local import only**.
- The software assumes a CUDA-capable GPU, but can also run on CPU.

In [2]:
import webknossos as wk
import numpy as np
from fastem_sofima.importo import import_wk_dataset_local
from fastem_sofima.utils import define_bbox_chunks
import os

# Choose which GPU(s) CUDA/JAX is allowed to see via CUDA_VISIBLE_DEVICES.
# Use a single index ("0", "1", "2", "3") or a comma-separated list ("0,1")
# for several GPUs; indices are the ones reported by `nvidia-smi`.
# It is set here, before JAX is imported in a later cell.
visible_gpus = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = visible_gpus

In [5]:
# WebKnossos settings
# Root of the mounted storage. Set the SOFIMA_MOUNT environment variable to point
# to your own mount instead of editing this hard-coded default.
path_2_mount = os.environ.get("SOFIMA_MOUNT", "/home/ajkievits/sonic")
dataset_name = "20231107_MCF7_UAC_test"  # Dataset name as in WebKnossos
organization_id = "hoogenboom-group"  # WebKnossos organization
dir_path = f"{path_2_mount}/webknossos/binaryData/{organization_id}/{dataset_name}"  # Local path to the dataset directory
layer = "postcorrection_rigid_scaled"  # Source EM layer to align
himag = wk.Mag("1-1-1")  # Full-resolution mag, used for the final warp

## 1. Define bounding box from data
The alignment is only performed on the overlapping part of the dataset: parts of sections that are not continuous throughout the volume are cropped away. We apply a series of image processing operations (minimum projection, blurring, thresholding, dilation and labeling) to define the bounding box of the continuous part only, which is then used for SOFIMA.

In [7]:
# Load the source layer at a strongly downsampled mag (128-128-1). This array is
# only used to locate the region that contains data in every section.
MAG_low = wk.Mag("128-128-1")
dataset, voxel_size = import_wk_dataset_local(dir_path=dir_path, MAG=MAG_low)
EM = dataset.get_layer(layer)  # EM data layer
mag_view = EM.get_mag(MAG_low)  # MagView at MAG_low
# No bounding box is passed, so the whole MagView is read (from local storage).
data_bbox = mag_view.read()

In [None]:
# Presumably (channels, x, y, z), the WebKnossos read order used throughout this notebook.
np.shape(data_bbox)

In [None]:
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu, try_all_threshold
from skimage.measure import label, regionprops
import matplotlib.patches as mpatches
from scipy.ndimage import gaussian_filter, binary_dilation

In [None]:
# Find the region that contains data in every section:
# 1. minimum projection along z (axis 3 of the (c, x, y, z) array): a pixel is only
#    bright if it is bright in all sections,
# 2. blur, Otsu-threshold and dilate to get a clean foreground mask,
# 3. label connected components and take the bounding box of the largest one.
data_min_proj = np.min(data_bbox, axis=3)
blurred = gaussian_filter(data_min_proj[0, :, :], sigma=3)
thres = threshold_otsu(blurred)
foreground = blurred > thres  # named `foreground` to avoid shadowing the builtin `object`
foreground = binary_dilation(foreground, np.ones((7, 7)))
labeled_image = label(foreground)
# Sort by area so props[0] is the largest connected component, not simply the first
# label in raster order (which can be a small speck of debris).
props = sorted(regionprops(labeled_image), key=lambda region: region.area, reverse=True)
if not props:
    raise ValueError("No foreground found in the minimum projection; check the threshold.")
minr, minc, maxr, maxc = props[0].bbox  # (min_row, min_col, max_row, max_col) at MAG_low

fig, ax = plt.subplots(2, 3)
ax[0][0].imshow(data_bbox[0, :, :, 0], cmap=plt.cm.Greys_r)
ax[0][1].imshow(data_min_proj[0, :, :], cmap=plt.cm.Greys_r)
ax[0][2].imshow(blurred, cmap=plt.cm.Greys_r)
ax[1][0].imshow(foreground, cmap=plt.cm.Greys_r)
ax[1][1].imshow(labeled_image)
ax[1][2].imshow(data_min_proj[0, :, :], cmap=plt.cm.Greys_r)

ax[0][0].set_title("z=0")
ax[0][1].set_title("minimum projection")
ax[0][2].set_title("blurred")
ax[1][0].set_title("thresholded")
ax[1][1].set_title("labeled")
ax[1][2].set_title("bounding box")
# matplotlib Rectangle takes (column, row) = (horizontal, vertical) display coordinates.
rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                          fill=False, edgecolor='red', linewidth=0.5)
ax[1][2].add_patch(rect)
fig.tight_layout()

In [None]:
# Bounding box of the selected region at MAG_low: (min_row, min_col, max_row, max_col).
(minr, minc, maxr, maxc)

In [None]:
# Define the WebKnossos bounding box of the continuous data region, then convert it
# from MAG_low (128-128-1) to Mag(1).
# data_bbox is (c, x, y, z), so rows of the 2-D projection are x and columns are y:
# regionprops' (min_row, min_col) therefore maps to (x0, y0), not to (y0, x0).
# NOTE(review): coordinates are relative to the start of the array returned by
# mag_view.read(), and z starts at 0; this assumes the layer bounding box starts at
# the origin. Confirm for datasets with an offset.
bbox_cropped = wk.BoundingBox((minr, minc, 0),
                              (maxr - minr, maxc - minc, data_bbox.shape[-1]))  # ((x0, y0, z0), (x_size, y_size, z_size))
bbox_cropped = bbox_cropped.from_mag_to_mag1(from_mag=MAG_low)
bbox_cropped

## 2. Read the data for the flow field computations
The data at reduced zoom levels is downsampled and therefore much smaller, so it can be loaded in one chunk with a decent amount of RAM (except for very large datasets!).

In [None]:
## Mags to use for SOFIMA
# We have found that a combination of 4-4-1 (16 nm/px), 8-8-1 (32 nm/px) and 16-16-1 (64 nm/px) works best for FAST-EM data
# Increasing the data resolution to 2-2-1 or 1-1-1 does not lead to better performance.
# Decreasing the data resolution to 32-32-1 or larger does not add any information
MAG = [wk.Mag("4-4-1"), wk.Mag("8-8-1"), wk.Mag("16-16-1")]

EM = dataset.get_layer(layer)  # EM data layer
data_zooms = {}
for zoom_mag in MAG:
    # Snap the cropped region to this mag's voxel grid (still expressed in Mag(1)).
    zoom_bbox = bbox_cropped.align_with_mag(zoom_mag)
    # "absolute_offset" and "size" are in Mag(1)!
    zoom_view = EM.get_mag(zoom_mag).get_view(absolute_offset=zoom_bbox.topleft, size=zoom_bbox.size)
    # Read from local storage and reorder (c, x, y, z) -> (x, y, z, c).
    data_zooms[f"unaligned_{zoom_mag.x}x"] = np.transpose(zoom_view.read(), (1, 2, 3, 0))

In [None]:
# (x, y, z, channels) after the transpose in the loading cell.
np.shape(data_zooms["unaligned_4x"])

# 3. Flow field estimation

First, we calculate the flow fields between the current section and the directly preceding section. Flow fields can also be computed between pairs of sections that are not directly adjacent. This is useful if sections are incomplete or missing, but is not something we have to worry about in this demo.

In a distributed environment, this step would be done with the `EstimateFlow` processor.


On a V100, the expected time for the flow calculation over a single 5000x5000 section with the settings below is ~0.6 s. The patch (160) and stride (40) sizes are set to conservative values which work for most synaptic-resolution EM volumes (i.e. at an in-plane resolution of ~10 nm/px).


In [None]:
from concurrent import futures
import time

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from connectomics.common import bounding_box
from sofima import flow_field
from sofima import flow_utils
from sofima import map_utils
from sofima import mesh
from sofima import warp
from tqdm.notebook import tqdm

In [None]:
# Check whether JAX runs on a GPU. The notebook can also run on CPU, so only warn
# (instead of failing) when no GPU is found. If you're using Google Colab, use
# "Edit >> Notebook settings" and set "Hardware accelerator" to "GPU".
import warnings

_jax_platform = jax.devices()[0].platform
if _jax_platform != 'gpu':
    warnings.warn(f"JAX is not using a GPU (platform: {_jax_platform!r}); "
                  "the alignment will run on CPU and be considerably slower.")

In [None]:
# Both of the settings below are expressed in pixels.
patch_size = 160  # XY spatial context used for flow field estimation
stride = 40  # XY distance between centers of adjacent patches.

def _compute_flow(volume, batch_size=256):
  """Estimates flow fields between every pair of consecutive sections.

  Args:
    volume: array indexed as volume[x, y, z, channel]; only channel 0 is used,
      and each section is transposed to (y, x) before flow estimation.
    batch_size: number of patches per batch passed to `flow_field`. It affects
      efficiency but not the result: it has to be large enough to fully utilize
      the available GPU capacity, but small enough for a batch to fit in GPU RAM.

  Returns:
    A list with volume.shape[2] - 1 entries. Entry z is
    `mfc.flow_field(section z, section z + 1, ...)`, computed with the
    module-level `patch_size` and `stride`.
  """
  mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
  flows = []
  prev = volume[..., 0, 0].T

  fs = []
  with futures.ThreadPoolExecutor() as tpe:
    # Prefetch the next sections to memory so that we don't have to wait for them
    # to load when the GPU becomes available.
    for z in range(1, volume.shape[2]):
      fs.append(tpe.submit(lambda z=z: volume[..., z, 0].T))

    for z in tqdm(range(0, volume.shape[2] - 1)):
      curr = fs[z].result()  # section z + 1
      flows.append(mfc.flow_field(prev, curr, (patch_size, patch_size),
                                  (stride, stride), batch_size=batch_size))
      prev = curr

  return flows

Compute flows at every selected zoom level (4x, 8x and 16x reduced in-plane resolution). The lower-resolution flows have reduced precision, but are helpful for providing approximate flow vectors in places where the highest-resolution flow might be impossible to estimate, e.g. in the interior of cell bodies or blood vessels.


In [None]:
# Compute the flow fields at every zoom level; keys look like "flows4x".
# Number of flow-grid cells lost at each image edge for lack of patch context.
pad = patch_size // 2 // stride
flows = {}
for zoom_key, zoom_mag in zip(data_zooms, MAG):
    # Per-section list -> array [z, channels, y, x] -> [channels, z, y, x].
    flow_czyx = np.transpose(np.array(_compute_flow(data_zooms[zoom_key])), [1, 0, 2, 3])
    # Pad the edges of the images, where there is insufficient context to estimate flow, with NaN.
    flows[f"flows{zoom_mag.x}x"] = np.pad(flow_czyx, [[0, 0], [0, 0], [pad, pad], [pad, pad]],
                                          constant_values=np.nan)

The flow fields generated in the previous step are 4-channel arrays, where the first two channels store the XY components of the flow vector, and the two remaining channels are measures of estimation quality (see `sofima.flow_field._batched_peaks` for more info).

We now remove uncertain flow estimates by replacing them with NaNs. Below, the flow arrays from all zoom levels are then merged into a single flow field at the highest resolution used for the flow computation (4x). In a distributed environment, this step would be done with the `ReconcileAndFilterFlows` processor.

In [None]:
# Filter out unreliable flow estimates at every zoom level, using the same thresholds for all levels.
clean_kwargs = dict(min_peak_ratio=1.2,
                    min_peak_sharpness=1.2,
                    max_magnitude=40,
                    max_deviation=10)
flows_clean = {
    f"{key}_clean": flow_utils.clean_flow(flow_array, **clean_kwargs)
    for key, flow_array in flows.items()
}

Plot the horizontal component of the flow vector, before (left) and after (right) filtering. White blobs indicate areas where uncertain flow estimates were removed.

In [None]:
z = 20  # section index; must be smaller than the number of flow fields (sections - 1)

nrows = len(flows_clean)
# squeeze=False keeps `axes` 2-D (nrows x 2) even for a single zoom level; otherwise
# each `ax` in the loop would be a single Axes and `ax[0]` would fail.
f, axes = plt.subplots(nrows, 2, figsize=(6.5, 3*nrows), squeeze=False)
vmin, vmax = -10, 10
for ax, flow, flow_clean in zip(axes, flows, flows_clean):
    # Left: raw flow, right: cleaned flow (horizontal component, shared color range).
    cax1 = ax[0].imshow(flows[flow][0, z, ...], cmap=plt.cm.RdBu, vmin=vmin, vmax=vmax)
    ax[0].set_title(flow)
    ax[1].imshow(flows_clean[flow_clean][0, z, ...], cmap=plt.cm.RdBu, vmin=vmin, vmax=vmax)
    ax[1].set_title(flow_clean)
    cbar1 = f.colorbar(cax1)

Upsample the lower-resolution flows to the grid of the highest-resolution flow, and rescale their flow vectors accordingly.

In [None]:
from scipy import interpolate

In [None]:
# Upsample every lower-resolution cleaned flow onto the grid of the highest-resolution one.
# The scale factor differs per level (4/8 = 0.5 for 8x, 4/16 = 0.25 for 16x), so derive it
# from the mags instead of using a single hard-coded 0.5, which is only right for 8x.
clean_flows = list(flows_clean.keys())
highest_flow = clean_flows[0]
box_high = bounding_box.BoundingBox(start=(0, 0, 0),
                                    size=(flows_clean[highest_flow].shape[-1],
                                          flows_clean[highest_flow].shape[-2], 1))
combine_flows = (flows_clean[highest_flow],)

# flows_clean was built in the same order as MAG, so its keys pair up with MAG[1:].
for f, f_mag in zip(clean_flows[1:], MAG[1:]):
  scale = MAG[0].x / f_mag.x  # < 1: size of this level's pixels relative to the highest level
  f_upsampled = np.zeros_like(flows_clean[highest_flow])
  box_f = bounding_box.BoundingBox(start=(0, 0, 0),
                                   size=(flows_clean[f].shape[-1], flows_clean[f].shape[-2], 1))

  for z in tqdm(range(flows_clean[f].shape[1])):  # loop through z
    # Upsample the grid by 1/scale and express the flow vectors in highest-level pixels.
    resampled = map_utils.resample_map(
        flows_clean[f][:, z:z + 1, ...],
        box_f, box_high, 1 / scale, 1)
    f_upsampled[:, z:z + 1, ...] = resampled / scale
  combine_flows += (f_upsampled,)
    

In [None]:
# Reconcile the per-level flows (all on the highest-resolution grid) into one flow field.
final_flow = flow_utils.reconcile_flows(
    combine_flows,
    max_gradient=0,
    max_deviation=20,
    min_patch_size=400,
)

In [None]:
z = 20  # z index

# Left to right: input data, each cleaned flow (before upsampling), final reconciled flow.
n_panels = len(combine_flows) + 2
f, ax = plt.subplots(1, n_panels, figsize=(10, 2.5))
data_full = next(iter(data_zooms.values()))  # first (highest-resolution) zoom level
ax[0].imshow(data_full[:, :, z, 0], cmap=plt.cm.Greys_r)
ax[0].set_title("input data")
for panel, (name, flow_array) in enumerate(flows_clean.items(), start=1):
    ax[panel].imshow(flow_array[0, z, ...].T, cmap=plt.cm.RdBu, vmin=-10, vmax=10)
    ax[panel].set_title(name)
ax[-1].imshow(final_flow[0, z, ...].T, cmap=plt.cm.RdBu, vmin=-10, vmax=10)
ax[-1].set_title("final flow")

# 4. Mesh optimization

We use an elastic mesh optimizer to find a configuration of the imagery that is compatible with the estimated flow field and preserves the original geometry as much as possible.

The optimization proceeds sequentially, section by section. In a distributed environment, this computation can be parallelized across the plane (by independently solving overlapping XY tiles), as well as split into blocks along the Z axis. This makes it possible to scale this process to arbitrarily large volumes. For simplicity, here we solve the complete stack in one go.

In [None]:
# Elastic mesh relaxation settings (see sofima.mesh.IntegrationConfig); the mesh
# uses the same XY stride as the flow fields.
config = mesh.IntegrationConfig(
    dt=0.001,
    gamma=0.0,
    k0=0.01,
    k=0.1,
    stride=(stride, stride),
    num_iters=1000,
    max_iters=100000,
    stop_v_max=0.005,
    dt_max=1000,
    start_cap=0.01,
    final_cap=10,
    prefer_orig_order=True,
)

In [None]:
# Relax the mesh section by section. Section 0 is the reference (all-zero map). Every
# following section is relaxed starting from zero, against its flow composed with the
# previous solution.
origin = jnp.array([0., 0.])
solved = [np.zeros_like(final_flow[:, 0:1, ...])]

for z in tqdm(range(final_flow.shape[1])):
  prev = map_utils.compose_maps_fast(final_flow[:, z:z + 1, ...], origin, stride,
                                     solved[-1], origin, stride)
  x, e_kin, num_steps = mesh.relax_mesh(np.zeros_like(solved[0]), prev, config)
  solved.append(np.array(x))

In [None]:
# Stack the per-section solutions along z into one array.
# Guarded so that re-running this cell does not re-concatenate the already stacked
# array, which would silently change its shape.
if isinstance(solved, list):
    solved = np.concatenate(solved, axis=1)

# 5. Image warping

Image warping requires an inverse coordinate map, so compute that first. In a distributed environment, this can be done with the `InvertMap` processor.

In [None]:
# The solved map and its inverse share the grid of the highest-resolution flow.
map_box = box_high
inv_map = map_utils.invert_map(solved, map_box, map_box, stride)

In [None]:
# Presumably [channels, z, y, x], like `solved`.
np.shape(inv_map)


Upsample `inv_map` from the grid of the highest-resolution flow (`MAG[0]`) to full resolution (`himag`)

In [None]:
# inv_map lives on the MAG[0] grid; rescale it to full resolution (himag).
scale_2_full = MAG[0].x / himag.x
flow_h, flow_w = flows_clean[highest_flow].shape[-2:]
x_full = int(flow_w * scale_2_full)
y_full = int(flow_h * scale_2_full)
box_highest = bounding_box.BoundingBox(start=(0, 0, 0), size=(x_full, y_full, 1))

inv_map_full = np.zeros((*inv_map.shape[:2], y_full, x_full))
for z in tqdm(range(inv_map.shape[1])):
    # Upsample the grid by scale_2_full and convert the coordinates to himag pixels.
    inv_map_full[:, z:z + 1, ...] = map_utils.resample_map(
        inv_map[:, z:z + 1, ...], box_high, box_highest, int(scale_2_full), 1
    ) * scale_2_full

In [None]:
# Same channel and z dimensions as inv_map, with a grid scale_2_full times larger in y and x.
np.shape(inv_map_full)

### Plot elastic mesh and inverted map

In [None]:
# Left to right (one section): input data, final flow, solved mesh, inverted
# coordinate map (channel 0 of each map).
z_show = 14
f, ax = plt.subplots(1, 4, figsize=(12, 3))
panels = [
    (data_full[:, :, z_show, 0].T, plt.cm.Greys_r, "input data"),
    (final_flow[0, z_show, ...], plt.cm.RdBu, "final flow"),
    (solved[0, z_show, ...], plt.cm.RdBu, "solved mesh"),
    (inv_map[0, z_show, ...], plt.cm.RdBu, "inverted coordinate map"),
]
for i, (img, cmap, title) in enumerate(panels):
    im = ax[i].imshow(img, cmap=cmap)
    ax[i].set_title(title)
    if i > 0:  # no colorbar for the raw EM image
        f.colorbar(im, fraction=0.046, pad=0.04)
f.tight_layout()

### 5.1 Warp subvolumes
We are now ready to render the aligned volume. To reduce RAM usage, the data is read and warped in chunks.

First, make a new layer for the aligned volume

In [None]:
# Create a compressed uint8 color layer for the aligned volume, with a full-resolution
# mag (`mag`, written to by the warp loop) and the cropped bounding box.
# NOTE(review): re-running this cell presumably fails because the layer already exists.
aligned_layer_name = "postcorrection_realigned_SOFIMA_chunked"
new_em_layer = dataset.add_layer(aligned_layer_name, wk.COLOR_CATEGORY,
                                 dtype_per_layer='uint8', compress=True)
mag = new_em_layer.add_mag(himag, compress=True)
new_em_layer.bounding_box = bbox_cropped

Render the aligned volume in tiles (chunks). This keeps every padded tile below OpenCV's image-size limit of 32766 pixels per dimension.

In [None]:
# OpenCV cannot handle images larger than 32766 px per side, so every padded
# chunk (bbox_size + 2 * pad_size) has to stay within that limit.
OPENCV_MAX_SIZE = 32766
pad_size = 3000  # context (himag px) read on each side of a chunk to avoid seams from the warp
bbox_size = 26000  # chunk edge length (himag px), presumably as used by define_bbox_chunks
assert bbox_size + 2 * pad_size <= OPENCV_MAX_SIZE, "padded chunk exceeds the OpenCV size limit"

bbox_at_mag = bbox_cropped.align_with_mag(himag)

# Get EM layer at full resolution
mag_view = EM.get_mag(himag)  # MagView
view = mag_view.get_view(absolute_offset=bbox_at_mag.topleft, size=bbox_at_mag.size)  # "absolute_offset" and "size" are in Mag(1)!

x0 = view.bounding_box.in_mag(himag).topleft.x
y0 = view.bounding_box.in_mag(himag).topleft.y

bboxes = define_bbox_chunks(view, himag, bbox_size=bbox_size)
bboxes

In [None]:
# Warp the full-resolution volume chunk by chunk and write it to the new layer.
# The SOFIMA coordinate frame has its origin at the top-left of the first chunk.
offset = bboxes[0].in_mag(himag).topleft  # Zero index
view_bottomright = bbox_at_mag.in_mag(himag).bottomright
for bbox_small in tqdm(bboxes, total=len(bboxes)):
    chunk = bbox_small.in_mag(himag)  # chunk bounds in himag voxels

    # Read a padded region so the warp does not introduce seams at chunk borders.
    # A side is only padded when the padded edge stays within bounds.
    # NOTE(review): left/top are checked against 0 (the dataset origin), while
    # right/bottom are checked against the far edge of the cropped view. Kept as is.
    pad = [pad_size if chunk.topleft[0] - pad_size >= 0 else 0,  # x0 webknossos
           pad_size if chunk.topleft[1] - pad_size >= 0 else 0,  # y0 webknossos
           pad_size if chunk.bottomright[0] + pad_size <= view_bottomright[0] else 0,  # x1 webknossos
           pad_size if chunk.bottomright[1] + pad_size <= view_bottomright[1] else 0  # y1 webknossos
           ]
    padded_x0 = chunk.topleft[0] - pad[0]
    padded_y0 = chunk.topleft[1] - pad[1]
    padded_sx = chunk.size[0] + pad[0] + pad[2]
    padded_sy = chunk.size[1] + pad[1] + pad[3]

    # SOFIMA boxes (connectomics BoundingBox, not wk.BoundingBox), in himag
    # coordinates relative to `offset`, i.e. the frame of the inverse coordinate map.
    data_box = bounding_box.BoundingBox(start=(padded_x0 - offset[0], padded_y0 - offset[1], 0),
                                        size=(padded_sx, padded_sy, 1))
    out_box = bounding_box.BoundingBox(start=(chunk.topleft[0] - offset[0], chunk.topleft[1] - offset[1], 0),
                                       size=(chunk.size[0], chunk.size[1], 1))

    # Load the padded data (WebKnossos coordinates): ((x0, y0, z0), (x_size, y_size, z_size))
    bbox_2_warp = wk.BoundingBox((padded_x0, padded_y0, 0),
                                 (padded_sx, padded_sy, chunk.size[-1]))
    data_2_warp = view.read(absolute_bounding_box=bbox_2_warp.from_mag_to_mag1(himag))  # Bounding box is in Mag(1)
    data_2_warp = np.transpose(data_2_warp, (1, 2, 3, 0))  # (x, y, z, 1)

    # Section 0 is the reference (zero displacement): crop it without warping.
    warped = [np.transpose(data_2_warp[pad[0]:pad[0] + chunk.size[0],
                                       pad[1]:pad[1] + chunk.size[1],
                                       0:1, 0], [2, 1, 0])]
    for z in tqdm(range(1, bbox_at_mag.size[2])):
        data = np.transpose(data_2_warp[:, :, z:z + 1, 0:1], [3, 2, 1, 0])
        # inv_map_full is defined on box_highest (the full-resolution map grid it
        # was resampled onto), so pass that box, not the MAG[0] grid box_high.
        warped.append(
            warp.warp_subvolume(data, data_box, inv_map_full[:, z:z + 1, ...], box_highest,
                                stride, out_box, 'lanczos', parallelism=1)[0, ...])

    warped_xyz = np.transpose(np.concatenate(warped, axis=0), [2, 1, 0])
    # Write to new layer
    mag.write(warped_xyz,
              absolute_bounding_box=bbox_small.align_with_mag(himag))

In [None]:
# Downsample the aligned layer to all zoom levels ("mags"), down to the coarsest mag
# of the source layer. The coarsest mag is selected explicitly (largest voxel
# volume) rather than relying on the iteration order of EM.mags.
MAG_low = max(EM.mags, key=lambda m: m.x * m.y * m.z)
new_em_layer.downsample(from_mag=himag,
                        coarsest_mag=MAG_low,
                        allow_overwrite=True)