
# CircuitSeeker multimodal registration (clean)

This notebook is a cleaned-up, end-to-end multimodal registration workflow built on **CircuitSeeker**. It aligns a two-photon stack (moving) to a confocal stack (fixed).

## Workflow
1. **Inputs & obvious corrections**
   - Define a single global identifier: `run_id = "<exp_id>_fish<fish>"` for naming outputs
   - Load fixed/moving TIFF stacks and convert axis order to **(X, Y, Z)**
   - Pad **±20 µm** in Z (recommended)

2. **Foreground detection**
   - Coarse brain masks for fixed and moving

3. **Moments initialization**
   - Principal-axes alignment (`modes`)

4. **Global alignment**
   - rigid → affine → deform (B-spline)

5. **Wiggle refinement**
   - Nested piecewise alignment (runs locally via ClusterWrap `local_cluster`)

6. **Invert transforms + sanity check**
   - Save inverse transforms and warp fixed → moving space

## Notes
- **Axis order**: TIFF stacks are assumed **(Z, Y, X)** and are converted to **(X, Y, Z)** internally.
- **Spacing units**: microns (µm), order **(X, Y, Z)**.
- All outputs are written under `out_path` with the prefix: **`<exp_id>_fish<fish>_...`**
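The sketch below illustrates the axis convention. It assumes only NumPy. `transpose(2, 1, 0)` reverses the axes, so the voxel at TIFF index `[z, y, x]` becomes `[x, y, z]`, and the spacing vector must be listed in the same (X, Y, Z) order. The shapes and spacings are the ones set and printed later in this notebook. `physical_extent_um` is a hypothetical helper.

```python
import numpy as np

def to_xyz(vol_zyx: np.ndarray) -> np.ndarray:
    """(Z, Y, X) -> (X, Y, Z); transpose(2, 1, 0) is its own inverse."""
    return vol_zyx.transpose(2, 1, 0)

def physical_extent_um(shape_xyz, spacing_xyz):
    """Hypothetical helper: field of view in µm along (X, Y, Z)."""
    return np.asarray(shape_xyz, dtype=float) * np.asarray(spacing_xyz, dtype=float)

# Index mapping on a toy stack with Z=2, Y=3, X=4
vol_zyx = np.arange(2 * 3 * 4).reshape(2, 3, 4)
vol_xyz = to_xyz(vol_zyx)
assert vol_xyz.shape == (4, 3, 2)
assert vol_xyz[3, 1, 0] == vol_zyx[0, 1, 3]
assert np.array_equal(to_xyz(vol_xyz), vol_zyx)  # applying it twice restores (Z, Y, X)

# Original (unpadded) shapes from this run, paired with the spacings defined below
fix_extent = physical_extent_um((1024, 1024, 37), (1.1, 1.1, 1.0))
mov_extent = physical_extent_um((799, 799, 56), (1.0, 1.0, 2.0))
print("fixed FOV (µm):", fix_extent)
print("moving FOV (µm):", mov_extent)
```

Comparing the two fields of view (about 1126 × 1126 × 37 µm fixed vs 799 × 799 × 112 µm moving) gives a quick sense of how much of each stack can overlap after alignment.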


In [6]:

# Imports (clean)
from pathlib import Path
import json
import numpy as np
import nrrd

from tifffile import imread, imwrite
from scipy.ndimage import zoom, binary_closing, binary_dilation

from CircuitSeeker import level_set
from CircuitSeeker.axisalign import principal_axes, align_modes
from CircuitSeeker.align import alignment_pipeline, nested_distributed_piecewise_alignment_pipeline
from CircuitSeeker.transform import apply_transform, invert_displacement_vector_field


In [46]:

# =====================
# Experiment parameters
# =====================

# Global identifiers used for ALL output names
exp_id = "exp_001"     # e.g. "001" or "exp1_110425"
fish = 2           # integer

run_id = f"{exp_id}_fish{fish}"

# Spacing (µm), order: (X, Y, Z)
mov_spacing = np.array([1.0, 1.0, 2.0], dtype=float)
fix_spacing = np.array([1.1, 1.1, 1.0], dtype=float)

# Wiggle random search iterations
randomiter = 25

# Z padding added to each end of the stack (µm)
pad_um = 20.0

# =====================
# Paths (edit these)
# =====================

mov_path = Path(
    "/Users/jonathanboulanger-weill/Harvard University Dropbox/"
    "Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/"
    "data/exp1_110425/2p_stacks/"
    "2025-10-13_16-04-47_fish002_setup1_arena0_MW_preprocessed_data_repeat00_tile000_950nm_0_flippedxz.tif"
)

fix_path = Path(
    "/Users/jonathanboulanger-weill/Harvard University Dropbox/"
    "Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/"
    "data/exp1_110425/oct_confocal_stacks/benchmark_data/fish2/prealigned/"
    "20x_4us_1um_DAPI_GFP488_RFP594_fish2_s1_montaged_MattesMI_GCaMP_ch1.tif"
)

out_path = Path(
    "/Users/jonathanboulanger-weill/Harvard University Dropbox/"
    "Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/"
    "data/exp1_110425/CircuitSeeker_output"
)
out_path.mkdir(parents=True, exist_ok=True)

intermediates_path = out_path / "intermediates"
intermediates_path.mkdir(parents=True, exist_ok=True)

print("Fixed stack:\n", fix_path)
print("\nMoving stack:\n", mov_path)
print("\nOutput folder:\n", out_path)


Fixed stack:
 /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/oct_confocal_stacks/benchmark_data/fish2/prealigned/20x_4us_1um_DAPI_GFP488_RFP594_fish2_s1_montaged_MattesMI_GCaMP_ch1.tif

Moving stack:
 /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/2p_stacks/2025-10-13_16-04-47_fish002_setup1_arena0_MW_preprocessed_data_repeat00_tile000_950nm_0_flippedxz.tif

Output folder:
 /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/CircuitSeeker_output


In [47]:

# =====================
# Helpers
# =====================

def save_path(tag: str, ext: str) -> Path:
    """Standard output path: <out_path>/<run_id>_<tag>.<ext>"""
    return out_path / f"{run_id}_{tag}.{ext}"

def write_meta(**kwargs):
    meta = dict(run_id=run_id, exp_id=exp_id, fish=int(fish), **kwargs)
    p = save_path("metadata", "json")
    with open(p, "w") as f:
        json.dump(meta, f, indent=2, sort_keys=True)
    print("[save]", p)

def to_xyz(vol_zyx: np.ndarray) -> np.ndarray:
    """Convert TIFF order (Z,Y,X) -> (X,Y,Z)."""
    if vol_zyx.ndim != 3:
        raise ValueError(f"Expected 3D TIFF stack (Z,Y,X). Got {vol_zyx.shape}")
    return vol_zyx.transpose(2, 1, 0)

def to_zyx(vol_xyz: np.ndarray) -> np.ndarray:
    """Convert (X,Y,Z) -> TIFF order (Z,Y,X)."""
    if vol_xyz.ndim != 3:
        raise ValueError(f"Expected 3D volume (X,Y,Z). Got {vol_xyz.shape}")
    return vol_xyz.transpose(2, 1, 0)

def pad_z_um(vol_xyz: np.ndarray, spacing_z_um: float, pad_um: float, fill=0):
    """Pad in Z by pad_um (µm) on each side. Volume is (X,Y,Z)."""
    n = int(np.round(float(pad_um) / float(spacing_z_um)))
    if n <= 0:
        return vol_xyz, 0
    vol_p = np.pad(vol_xyz, ((0, 0), (0, 0), (n, n)), mode="constant", constant_values=fill)
    return vol_p, n

def brain_mask(vol_xyz: np.ndarray, spacing_xyz: np.ndarray, lambda2: float):
    """Coarse brain mask: level_set on a 2x-decimated volume, then nearest-neighbour upsampling, closing, and a 2-voxel dilation."""
    vol_skip = vol_xyz[::2, ::2, ::2]
    skip_spacing = spacing_xyz * np.array([2, 2, 2], dtype=float)

    mask_small = level_set.brain_detection(
        vol_skip, skip_spacing,
        mask_smoothing=2,
        iterations=[80, 40, 10],
        smooth_sigmas=[12, 6, 3],
        lambda2=lambda2,
    )

    mask = zoom(mask_small, np.array(vol_xyz.shape) / np.array(vol_skip.shape), order=0)
    mask = binary_closing(mask, np.ones((5, 5, 5))).astype(np.uint8)
    mask = binary_dilation(mask, np.ones((5, 5, 5))).astype(np.uint8)
    return mask
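`brain_mask` runs the level set on every second voxel. It then zooms the result back using the factor `shape / skipped_shape`. The NumPy/SciPy sketch below works on a small toy volume. It checks that this factor restores odd-length axes exactly. It also shows the dilation radius of the 5×5×5 structuring element. It does not call CircuitSeeker: the box "mask" is only a stand-in for the `level_set.brain_detection` output.

```python
import numpy as np
from scipy.ndimage import zoom, binary_dilation

shape = (9, 8, 7)                          # toy (X, Y, Z) shape with odd axes
vol = np.zeros(shape, dtype=np.uint8)
small = vol[::2, ::2, ::2]                 # ceil(n / 2) voxels per axis -> (5, 4, 4)
factors = np.array(vol.shape) / np.array(small.shape)

small_mask = np.zeros(small.shape, dtype=np.uint8)
small_mask[1:4, 1:3, 1:3] = 1              # stand-in for the coarse level-set mask

# order=0 (nearest neighbour) keeps the upsampled mask binary
mask = zoom(small_mask, factors, order=0)
print("decimated:", small.shape, "zoom factors:", factors, "restored:", mask.shape)

# One dilation with a 5x5x5 cube grows a mask by 2 voxels along each axis
point = np.zeros((11, 11, 11), dtype=bool)
point[5, 5, 5] = True
grown = binary_dilation(point, np.ones((5, 5, 5)))
print("dilated extent:", np.argwhere(grown).min(axis=0), np.argwhere(grown).max(axis=0))
```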



## 1) Inputs and obvious corrections

Load the fixed and moving stacks, convert them to **(X, Y, Z)**, and pad each by `pad_um` (**20 µm**) at both ends of Z.
The padded stacks are also saved as TIFFs (Z, Y, X order) for QC.


In [48]:

# Load TIFFs (assumed Z,Y,X) -> convert to X,Y,Z
fix = to_xyz(imread(fix_path))
mov = to_xyz(imread(mov_path))

print("Original shapes (X,Y,Z):")
print("  fix:", fix.shape)
print("  mov:", mov.shape)

# Pad in Z (±pad_um)
fix, fix_pad_slices = pad_z_um(fix, fix_spacing[2], pad_um, fill=0)
mov, mov_pad_slices = pad_z_um(mov, mov_spacing[2], pad_um, fill=0)

print(f"Padding applied: fix +/- {fix_pad_slices} slices, mov +/- {mov_pad_slices} slices")
print("Padded shapes (X,Y,Z):")
print("  fix:", fix.shape)
print("  mov:", mov.shape)

# Save padded stacks (TIFF order Z,Y,X)
fix_padded_path = save_path(f"fixed_pad{pad_um:g}um", "tif")  # tag follows pad_um ("pad20um" for 20.0)
mov_padded_path = save_path(f"moving_pad{pad_um:g}um", "tif")

imwrite(fix_padded_path, to_zyx(fix), bigtiff=True)
imwrite(mov_padded_path, to_zyx(mov), bigtiff=True)

print("[save] padded stacks:")
print("  fix:", fix_padded_path)
print("  mov:", mov_padded_path)

write_meta(
    fixed_path=str(fix_path),
    moving_path=str(mov_path),
    fix_spacing_um=list(map(float, fix_spacing)),
    mov_spacing_um=list(map(float, mov_spacing)),
    pad_um=float(pad_um),
    fix_pad_slices=int(fix_pad_slices),
    mov_pad_slices=int(mov_pad_slices),
)


Original shapes (X,Y,Z):
  fix: (1024, 1024, 37)
  mov: (799, 799, 56)
Padding applied: fix +/- 20 slices, mov +/- 10 slices
Padded shapes (X,Y,Z):
  fix: (1024, 1024, 77)
  mov: (799, 799, 76)
[save] padded stacks:
  fix: /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/CircuitSeeker_output/exp_001_fish2_fixed_pad20um.tif
  mov: /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/CircuitSeeker_output/exp_001_fish2_moving_pad20um.tif
[save] /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/CircuitSeeker_output/exp_001_fish2_metadata.json
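The slice counts printed above follow from `round(pad_um / dz)`. The NumPy-only sketch below reproduces them. It also adds a hypothetical `unpad_z` helper for mapping results on the padded grid back to the original Z range.

```python
import numpy as np

def z_pad_slices(pad_um: float, spacing_z_um: float) -> int:
    """Same rounding as pad_z_um: number of slices added to each end of Z."""
    return int(np.round(float(pad_um) / float(spacing_z_um)))

def unpad_z(vol_xyz: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical helper: drop the n padded slices from both ends of Z."""
    return vol_xyz[:, :, n:vol_xyz.shape[2] - n] if n > 0 else vol_xyz

# Reproduce the counts and padded Z sizes printed above
n_fix = z_pad_slices(20.0, 1.0)   # fixed Z spacing 1.0 µm
n_mov = z_pad_slices(20.0, 2.0)   # moving Z spacing 2.0 µm
print(n_fix, 37 + 2 * n_fix, n_mov, 56 + 2 * n_mov)

# The padding actually applied is n * dz, which equals pad_um only when dz divides it
print("applied with dz = 1.5 µm:", z_pad_slices(20.0, 1.5) * 1.5, "µm")

# Round trip on a toy volume: original slice k sits at index k + n after padding
vol = np.random.default_rng(0).random((4, 3, 5))
padded = np.pad(vol, ((0, 0), (0, 0), (n_mov, n_mov)))
assert np.array_equal(padded[:, :, 2 + n_mov], vol[:, :, 2])
assert np.array_equal(unpad_z(padded, n_mov), vol)
```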



## 2) Foreground masks (fixed + moving)


In [49]:

fix_mask = brain_mask(fix, fix_spacing, lambda2=32.0)
mov_mask = brain_mask(mov, mov_spacing, lambda2=64.0)

fix_mask_path = save_path("fix_mask", "nrrd")
mov_mask_path = save_path("mov_mask", "nrrd")

nrrd.write(str(fix_mask_path), fix_mask)
nrrd.write(str(mov_mask_path), mov_mask)

print("[save] masks:")
print("  fix:", fix_mask_path)
print("  mov:", mov_mask_path)


[save] masks:
  fix: /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/CircuitSeeker_output/exp_001_fish2_fix_mask.nrrd
  mov: /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/CircuitSeeker_output/exp_001_fish2_mov_mask.nrrd
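The masks feed into the moments initialization next. Before that, a quick numeric summary can catch an empty or runaway level set. `mask_summary` is a hypothetical NumPy-only helper. Like the rest of this notebook, it assumes (X, Y, Z) masks with spacing in µm.

```python
import numpy as np

def mask_summary(mask, spacing):
    """Foreground fraction, volume (µm³) and bounding-box extent (µm) of a binary (X, Y, Z) mask."""
    spacing = np.asarray(spacing, dtype=float)
    fg = np.asarray(mask) > 0
    n = int(fg.sum())
    if n == 0:
        return {"fraction": 0.0, "volume_um3": 0.0, "bbox_um": (0.0, 0.0, 0.0)}
    idx = np.argwhere(fg)                                  # (N, 3) voxel indices
    extent_vox = idx.max(axis=0) - idx.min(axis=0) + 1     # inclusive bounding box
    return {
        "fraction": n / fg.size,
        "volume_um3": n * float(np.prod(spacing)),
        "bbox_um": tuple(float(v) for v in extent_vox * spacing),
    }

# Toy example with the fixed-stack spacing
toy = np.zeros((10, 10, 10), dtype=np.uint8)
toy[2:6, 0:10, 5:7] = 1
summary = mask_summary(toy, (1.1, 1.1, 1.0))
print(summary)
```

On the real data, `mask_summary(fix_mask, fix_spacing)` and `mask_summary(mov_mask, mov_spacing)` should report broadly similar brain volumes if both masks captured the same tissue.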



## 3) Moments initialization (modes)


In [50]:

fix_mean, fix_evals, fix_evecs = principal_axes(fix_mask, fix_spacing)
mov_mean, mov_evals, mov_evecs = principal_axes(mov_mask, mov_spacing)

modes = align_modes(fix_mean, fix_evecs, mov_mean, mov_evecs)

modes_aligned = apply_transform(
    fix, mov,
    fix_spacing, mov_spacing,
    transform_list=[modes],
)

np.savetxt(save_path("modes", "mat"), modes)
nrrd.write(str(save_path("modes_aligned", "nrrd")), modes_aligned, compression_level=2)

print("[save] modes:", save_path("modes", "mat"))


[save] modes: /Users/jonathanboulanger-weill/Harvard University Dropbox/Jonathan Boulanger-Weill/Projects/calcium-spatial-transcriptomics-align/data/exp1_110425/CircuitSeeker_output/exp_001_fish2_modes.mat
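For intuition, the moments initialization can be reproduced in a few lines of NumPy. This sketches the general technique, not CircuitSeeker's `principal_axes`/`align_modes` code. It makes two assumptions:
- The 4×4 matrix maps fixed-space points (µm) to moving-space points, as in ITK-style resampling.
- The eigenvector sign ambiguity is resolved only to the extent of avoiding reflections, so 180° flips about an axis remain possible.

`principal_axes_np` and `modes_affine` are hypothetical names.

```python
import numpy as np

def principal_axes_np(mask, spacing):
    """Centroid, eigenvalues (descending) and eigenvectors (columns) of a binary mask, in µm."""
    coords = np.argwhere(mask > 0) * np.asarray(spacing, dtype=float)  # (N, 3) physical coords
    mean = coords.mean(axis=0)
    cov = np.cov(coords - mean, rowvar=False)
    evals, evecs = np.linalg.eigh(cov)            # ascending order
    order = np.argsort(evals)[::-1]               # largest variance first
    return mean, evals[order], evecs[:, order]

def modes_affine(fix_mean, fix_evecs, mov_mean, mov_evecs):
    """4x4 matrix taking fixed-space points (µm) to moving-space points (µm)."""
    R = mov_evecs @ fix_evecs.T                   # rotates fixed axes onto moving axes
    if np.linalg.det(R) < 0:                      # eigenvector signs are arbitrary; avoid reflections
        mov_evecs = mov_evecs.copy()
        mov_evecs[:, -1] *= -1
        R = mov_evecs @ fix_evecs.T
    A = np.eye(4)
    A[:3, :3] = R
    A[:3, 3] = mov_mean - R @ fix_mean            # centroid maps onto centroid
    return A

# Toy example: the same box, long along X in "fixed" and long along Y in "moving"
fix_toy = np.zeros((40, 20, 10), dtype=np.uint8)
fix_toy[5:35, 8:14, 3:7] = 1
mov_toy = np.ascontiguousarray(fix_toy.transpose(1, 0, 2))
spacing = (1.0, 1.0, 2.0)
fm, fe, fv = principal_axes_np(fix_toy, spacing)
mm, me, mv = principal_axes_np(mov_toy, spacing)
A = modes_affine(fm, fv, mm, mv)
print(np.round(A, 3))
```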



## 4) Global alignment (rigid → affine → deform)


In [51]:

affine, deform = alignment_pipeline(
    fix, mov, fix_spacing, mov_spacing,
    steps=["rigid", "affine", "deform"],
    initial_transform=modes,
    alignment_spacing=2.0,
    shrink_factors=[2],
    smooth_sigmas=[2.0],
    iterations=400,
    deform_kwargs={
        "control_point_spacing": 10.0,
        "control_point_levels": [1, 2, 4, 8, 16, 32, 64],
    },
)

# alignment_pipeline returns (affine, (params, field))
deform_field = deform[1]

affine_aligned = apply_transform(fix, mov, fix_spacing, mov_spacing, transform_list=[affine])
deform_aligned = apply_transform(fix, mov, fix_spacing, mov_spacing, transform_list=[affine, deform_field])

np.savetxt(save_path("affine", "mat"), affine)
nrrd.write(str(save_path("deform", "nrrd")), deform_field, compression_level=2)
nrrd.write(str(save_path("affine_aligned", "nrrd")), affine_aligned, compression_level=2)
nrrd.write(str(save_path("deform_aligned", "nrrd")), deform_aligned, compression_level=2)

print("[save] global alignment outputs under:", out_path)


LEVEL:  0  ITERATION:  0  METRIC:  -0.17726983213794648
LEVEL:  0  ITERATION:  1  METRIC:  -0.17787422987673498
LEVEL:  0  ITERATION:  2  METRIC:  -0.17831171513719746
LEVEL:  0  ITERATION:  3  METRIC:  -0.17867152629561708
LEVEL:  0  ITERATION:  4  METRIC:  -0.1789376555783284
LEVEL:  0  ITERATION:  5  METRIC:  -0.17907158765975156
LEVEL:  0  ITERATION:  6  METRIC:  -0.1793485053423143
LEVEL:  0  ITERATION:  7  METRIC:  -0.17949994764073604
LEVEL:  0  ITERATION:  8  METRIC:  -0.17955921088355345
LEVEL:  0  ITERATION:  9  METRIC:  -0.17958875364846028
LEVEL:  0  ITERATION:  10  METRIC:  -0.17960392724896887
LEVEL:  0  ITERATION:  11  METRIC:  -0.1796799970299679
LEVEL:  0  ITERATION:  12  METRIC:  -0.17975728752058293
LEVEL:  0  ITERATION:  13  METRIC:  -0.17983628866716758
LEVEL:  0  ITERATION:  14  METRIC:  -0.179868870646079
LEVEL:  0  ITERATION:  15  METRIC:  -0.17988743866785314
LEVEL:  0  ITERATION:  16  METRIC:  -0.17997135725587618
LEVEL:  0  ITERATION:  17  METRIC:  -0.1799677
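A common QC for a deformable result is the Jacobian determinant of x ↦ x + u(x). Values ≤ 0 indicate folding. The NumPy sketch below assumes, hypothetically, that `deform_field` is a dense displacement field of shape (X, Y, Z, 3), in µm, on the fixed grid. Check `deform_field.shape` before relying on it.

```python
import numpy as np

def jacobian_determinant(field, spacing):
    """det(I + grad u) at every voxel for a displacement field of shape (X, Y, Z, 3) in µm."""
    spacing = [float(s) for s in spacing]
    # grads[i][j] = d u_i / d x_j (central differences inside, one-sided at the edges)
    grads = [np.gradient(field[..., i], *spacing) for i in range(3)]
    J = np.empty(field.shape[:3] + (3, 3))
    for i in range(3):
        for j in range(3):
            J[..., i, j] = grads[i][j] + (1.0 if i == j else 0.0)
    return np.linalg.det(J)

# Toy fields: zero displacement, and a 10% stretch along X
shape = (6, 5, 4)
spacing = (1.1, 1.1, 1.0)
zero = np.zeros(shape + (3,))
x = np.arange(shape[0]) * spacing[0]
stretch = np.zeros(shape + (3,))
stretch[..., 0] = 0.1 * x[:, None, None]   # u_x = 0.1 * x
print(jacobian_determinant(zero, spacing).min(), jacobian_determinant(stretch, spacing).mean())
```

Under that shape assumption, `(jacobian_determinant(deform_field, fix_spacing) <= 0).mean()` gives the folded fraction. The same check applies to `wiggle`.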


## 5) Wiggle refinement (nested piecewise alignment)
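The `random_kwargs` in the next cell bound a random search over small affine perturbations. The sketch illustrates what such bounds can mean. The sampling scheme is an assumption for illustration, and it is not necessarily how CircuitSeeker draws its candidates:
- uniform draws;
- Euler angles;
- per-axis scale in [1/max_scale, max_scale];
- upper-triangular shear;
- rotation about a block center.

`random_affine`, `rotation_xyz` and the block center are hypothetical.

```python
import numpy as np

def rotation_xyz(ax, ay, az):
    """Rotation matrix Rz @ Ry @ Rx from Euler angles in radians."""
    cx, sx = np.cos(ax), np.sin(ax)
    cy, sy = np.cos(ay), np.sin(ay)
    cz, sz = np.cos(az), np.sin(az)
    rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    return rz @ ry @ rx

def random_affine(rng, center, max_translation, max_rotation, max_scale, max_shear):
    """Hypothetical sampler: one 4x4 affine about `center` (µm) within the given bounds."""
    t = rng.uniform(-max_translation, max_translation, 3)
    rot = rotation_xyz(*rng.uniform(-max_rotation, max_rotation, 3))
    scale = np.diag(rng.uniform(1.0 / max_scale, max_scale, 3))
    shear = np.eye(3)
    shear[0, 1], shear[0, 2], shear[1, 2] = rng.uniform(-max_shear, max_shear, 3)
    m = rot @ shear @ scale
    a = np.eye(4)
    a[:3, :3] = m
    a[:3, 3] = center - m @ center + t   # rotate/scale/shear about center, then translate
    return a

rng = np.random.default_rng(0)
center = np.array([17.6, 17.6, 19.5])    # hypothetical block center in µm
candidates = [
    random_affine(rng, center, 10.0, 10.0 * np.pi / 180.0, 1.10, 0.10)
    for _ in range(25)                     # matches randomiter = 25
]
print(len(candidates), "candidates")
```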


In [None]:

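# One nesting level; the entry is round(fix.shape / 32) per axis, i.e. (32, 32, 2) for the padded fixed stack
block_schedule = [[tuple(np.maximum(1, np.round(np.array(fix.shape) / 32).astype(int)))]]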

parameter_schedule = [{
    "random_kwargs": {
        "max_translation": 10.0,
        "max_rotation": 10.0 * np.pi / 180.0,
        "max_scale": 1.10,
        "max_shear": 0.10,
        "random_iterations": int(randomiter),
        "affine_align_best": 10,
        "iterations": 24,
    },
    "affine_kwargs": {},
    "deform_kwargs": {},
}]

wiggle = nested_distributed_piecewise_alignment_pipeline(
    fix, mov, fix_spacing, mov_spacing,
    block_schedule,
    parameter_schedule=parameter_schedule,
    initial_transform_list=[affine, deform_field],
    fix_mask=fix_mask,
    mov_mask=mov_mask,
    steps=["random", "affine", "deform"],
    bins=256,
    shrink_factors=[1],
    smooth_sigmas=[6.0],
    iterations=400,
    learning_rate=0.1,
    max_step=0.1,
    estimate_learning_rate="never",
    callback=lambda irm: None,
    intermediates_path=str(intermediates_path),
    cluster_kwargs={"cluster_type": "local_cluster", "n_workers": 4},
)

wiggled = apply_transform(
    fix, mov, fix_spacing, mov_spacing,
    transform_list=[affine, deform_field, wiggle],
)

nrrd.write(str(save_path("wiggle", "nrrd")), wiggle, compression_level=2)
nrrd.write(str(save_path("wiggled_aligned", "nrrd")), wiggled, compression_level=2)

print("[save] wiggle outputs:")
print("  wiggle:", save_path("wiggle", "nrrd"))
print("  wiggled:", save_path("wiggled_aligned", "nrrd"))






Random search failed due to ITK exception:
 Exception thrown in SimpleITK ImageRegistrationMethod_MetricEvaluate: /Users/runner/work/SimpleITK/SimpleITK/bld/ITK-prefix/include/ITK-5.4/itkRecursiveSeparableImageFilter.hxx:226:
ITK ERROR: RecursiveGaussianImageFilter(0x7fdb3a205cd0): The number of pixels along direction 0 is less than 4. This filter requires a minimum of four pixels along the dimension to be processed.
Returning default
Registration failed due to ITK exception:
 Exception thrown in SimpleITK ImageRegistrationMethod_MetricEvaluate: /Users/runner/work/SimpleITK/SimpleITK/bld/ITK-prefix/include/ITK-5.4/itkRecursiveSeparableImageFilter.hxx:226:
ITK ERROR: RecursiveGaussianImageFilter(0x7fdb605b88f0): The number of pixels along direction 0 is less than 4. This filter requires a minimum of four pixels along the dimension to be processed.

Returning default
Random search failed due to ITK exception:
 Exception thrown in SimpleITK ImageRegistrationMethod_MetricEvaluate: /Use

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


Optimization failed to improve metric
initial value: -0.14547586555223802
final value: -0.14464107533904086
Returning default
Optimization failed to improve metric
initial value: -0.13445044394062908
final value: -0.13429971836477728
Returning default
Optimization failed to improve metric
initial value: -0.11999284095962698
final value: -0.11915526706025525
Returning default
Optimization failed to improve metric
initial value: -0.11809635379324994
final value: -0.11661290070884052
Returning default
Optimization failed to improve metric
initial value: -0.11358096430560893
final value: -0.11216713235891511
Returning default
Optimization failed to improve metric
initial value: -0.13040410052807355
final value: -0.1303272123044762
Returning default
Optimization failed to improve metric
initial value: -0.11223508296290426
final value: -0.11153720909027341
Returning default
Optimization failed to improve metric
initial value: -0.13015773044365375
final value: -0.13014801404702694
Returning d
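The ITK errors above report that `RecursiveGaussianImageFilter` needs at least four pixels along each processed dimension. A pre-flight check of the per-block extent can flag schedules that produce blocks this thin. The log does not say which numpy axis ITK's "direction 0" corresponds to, so this sketch checks every axis. It also evaluates both readings of a schedule entry (number of blocks per axis, or block size in voxels), because which one CircuitSeeker uses is an assumption here. The helper names are hypothetical.

```python
import numpy as np

MIN_PIXELS = 4  # minimum per-dimension length quoted in the ITK error above

def block_vox_if_counts(shape, entry):
    """Voxels per block if `entry` counts blocks along each axis (overlap ignored)."""
    return np.ceil(np.asarray(shape) / np.asarray(entry)).astype(int)

def block_vox_if_sizes(entry):
    """Voxels per block if `entry` is already a block size in voxels."""
    return np.asarray(entry, dtype=int)

def thin_axes(block_vox, min_pixels=MIN_PIXELS):
    """Axes (0=X, 1=Y, 2=Z in this notebook's order) shorter than min_pixels voxels."""
    return [int(i) for i in np.flatnonzero(np.asarray(block_vox) < min_pixels)]

# Same arithmetic as the block_schedule cell, on the padded fixed shape printed earlier
fix_shape = (1024, 1024, 77)
entry = tuple(int(v) for v in np.maximum(1, np.round(np.array(fix_shape) / 32).astype(int)))
as_counts = block_vox_if_counts(fix_shape, entry)
as_sizes = block_vox_if_sizes(entry)
print("entry:", entry)
print("as counts:", as_counts.tolist(), "thin axes:", thin_axes(as_counts))
print("as sizes: ", as_sizes.tolist(), "thin axes:", thin_axes(as_sizes))
```

Under the counts reading every axis clears the limit. Under the sizes reading, Z is only two voxels thick. The moving stack's field of view is also smaller than the fixed one, so moving-side crops near the edges may be thin as well. This check is a starting point, not a diagnosis.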


## 6) Invert transforms + sanity check


In [None]:

affine_inv = np.linalg.inv(affine)
np.savetxt(save_path("affine_inv", "mat"), affine_inv)

deform_inv = invert_displacement_vector_field(deform_field, fix_spacing)
nrrd.write(str(save_path("deform_inv", "nrrd")), deform_inv, compression_level=2)

wiggle_inv = invert_displacement_vector_field(wiggle, fix_spacing)
nrrd.write(str(save_path("wiggle_inv", "nrrd")), wiggle_inv, compression_level=2)

# sanity check: warp fixed -> moving space
fix_to_mov = apply_transform(
    mov, fix, mov_spacing, fix_spacing,
    transform_list=[wiggle_inv, deform_inv, affine_inv],
    transform_spacing=fix_spacing,
)
nrrd.write(str(save_path("fix_warped_to_moving", "nrrd")), fix_to_mov, compression_level=2)

print("[save] inverse transforms + sanity warp written under:", out_path)
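Each inverted field can be checked directly before trusting `fix_warped_to_moving`. If y = x + u(x) and the inverse is y ↦ y + v(y), then u(x) + v(x + u(x)) should be close to zero. The sketch below uses SciPy's `map_coordinates`. It assumes, hypothetically, that the forward and inverse fields are both (X, Y, Z, 3) displacement arrays in µm, sampled on the same grid with spacing `fix_spacing`. The affine part can be checked with `np.allclose(affine @ affine_inv, np.eye(4))`.

```python
import numpy as np
from scipy.ndimage import map_coordinates

def inverse_consistency_error(field, field_inv, spacing):
    """|u(x) + v(x + u(x))| in µm at every voxel of an (X, Y, Z, 3) displacement field."""
    spacing = np.asarray(spacing, dtype=float)
    grid = np.indices(field.shape[:3]).astype(float)            # (3, X, Y, Z) voxel indices
    # x + u(x), with the displacement converted from µm to voxel units
    warped = grid + np.moveaxis(field, -1, 0) / spacing[:, None, None, None]
    v_at_warped = np.stack([
        map_coordinates(field_inv[..., i], warped, order=1, mode="nearest")  # trilinear lookup
        for i in range(3)
    ], axis=-1)
    return np.linalg.norm(field + v_at_warped, axis=-1)

# Toy check: a 2 µm shift along X and its exact inverse
u = np.zeros((6, 5, 4, 3))
u[..., 0] = 2.0
err_good = inverse_consistency_error(u, -u, (1.1, 1.1, 1.0))
err_bad = inverse_consistency_error(u, np.zeros_like(u), (1.1, 1.1, 1.0))
print(err_good.max(), err_bad.mean())
```

Under the shape assumption, `inverse_consistency_error(wiggle, wiggle_inv, fix_spacing)` summarizes how well the wiggle field was inverted. Large values flag regions where the sanity warp is unreliable.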
