# Generating Ground Truth Segmentation Labels with Ultrack

This notebook runs the [Ultrack](https://github.com/royerlab/ultrack) segmentation and tracking pipeline on a small subset of the **ZSNS001_tail** zebrafish embryo dataset to generate dense instance segmentation labels for U-Net training.

Ultrack's approach:
1. **Preprocess**: `detect_foreground` (binary cell mask) + `robust_invert` (boundary/contour map)
2. **Segment**: Generate multiple candidate segmentation hypotheses via hierarchical watershed
3. **Link**: Find candidate connections between segments in adjacent frames
4. **Solve**: Integer Linear Programming (ILP) to select the optimal segmentation + tracking

This yields instance labels that are temporally consistent, which makes them more reliable than simple thresholding and avoids the circularity of training a U-Net on another U-Net's predictions.
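
To make the "Solve" step concrete, here is a toy sketch of choosing among overlapping segmentation hypotheses. It is not Ultrack's ILP formulation: the candidate names, pixel sets, and weights are made up for illustration, and a brute-force search stands in for a real solver.

```python
from itertools import combinations


def best_disjoint_subset(candidates):
    """Pick the non-overlapping candidates with the highest total weight.

    candidates maps a name to (set_of_pixel_ids, weight).
    Brute force over all subsets; only suitable for a handful of candidates.
    """
    names = sorted(candidates)
    best_names, best_total = (), 0.0
    for r in range(1, len(names) + 1):
        for combo in combinations(names, r):
            pixels = [candidates[n][0] for n in combo]
            # Candidates overlap if their union is smaller than the summed sizes
            if len(set().union(*pixels)) != sum(len(p) for p in pixels):
                continue
            total = sum(candidates[n][1] for n in combo)
            if total > best_total:
                best_names, best_total = combo, total
    return best_names, best_total


# Hypothetical hierarchy: one large segment versus its two halves
toy = {
    "parent": ({0, 1, 2, 3}, 1.0),
    "child_a": ({0, 1}, 0.7),
    "child_b": ({2, 3}, 0.6),
}
chosen, score = best_disjoint_subset(toy)
print(chosen, round(score, 3))
```

Here the two children together outscore the parent, so the split hypothesis wins. A real solver makes this choice jointly with the frame-to-frame links.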

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader

from ultrack import MainConfig, Tracker
from ultrack.imgproc import robust_invert, detect_foreground

## 1. Load the ZSNS001_tail dataset

Same lazy loading as notebook 01. We'll select a small subset of timepoints to keep compute manageable.
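
As a small sketch of the subset selection, the helper below picks evenly spaced frame indices and reports the gap between them. The function name `pick_timepoints` is hypothetical; it only wraps `np.linspace` and removes duplicates, which appear when more frames are requested than exist.

```python
import numpy as np


def pick_timepoints(n_total, n_pick):
    """Return evenly spaced, de-duplicated frame indices and the gaps between them."""
    idx = np.unique(np.linspace(0, n_total - 1, n_pick, dtype=int))
    gaps = np.diff(idx)  # number of frames skipped between consecutive picks
    return idx, gaps


idx, gaps = pick_timepoints(n_total=100, n_pick=5)
print("indices:", idx)
print("gaps:", gaps)
```

Large gaps mean cells can move far between the selected frames, which is worth keeping in mind when choosing the linking distance later on.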

In [None]:
URL = "https://public.czbiohub.org/royerlab/zebrahub/imaging/single-objective/ZSNS001_tail.ome.zarr"

reader = Reader(parse_url(URL))
nodes = list(reader())
dask_data = nodes[0].data

data = dask_data[0]  # full resolution
print(f"Full dataset shape (T, C, Z, Y, X): {data.shape}")
print(f"Dtype: {data.dtype}")

n_t, n_c, n_z, n_y, n_x = data.shape
voxel_size = (1.24, 0.439, 0.439)  # Z, Y, X in micrometers

In [None]:
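# Select a small subset of timepoints (evenly spaced across development).
# Ultrack links segments between adjacent frames of the stack, so widely spaced
# timepoints may not link well; pick consecutive frames if tracks matter.
N_TIMEPOINTS = 5  # start small; increase once you've verified it works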
selected_times = np.linspace(0, n_t - 1, N_TIMEPOINTS, dtype=int)
print(f"Selected timepoints ({N_TIMEPOINTS}): {selected_times}")
print(f"Volume per timepoint: {n_z} x {n_y} x {n_x} = {n_z * n_y * n_x / 1e6:.1f}M voxels")

## 2. Preprocessing

Ultrack requires two inputs:
- **Foreground map**: probability/binary mask of where cells are (from `detect_foreground`)
- **Contour/edge map**: boundary strength between cells (from `robust_invert`)

We preprocess each selected timepoint and stack them into arrays.
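
Because the stacks are held in memory, it can help to estimate their size before loading. The sketch below is a rough calculation only: the helper name, the example shape, and the dtypes (uint16 raw data, float32 foreground and contour maps) are assumptions, so substitute the values printed by the loading cell.

```python
import numpy as np


def stack_memory_gib(n_t, n_z, n_y, n_x, dtypes):
    """Approximate memory (GiB) of one (T, Z, Y, X) array per dtype in `dtypes`."""
    voxels = n_t * n_z * n_y * n_x
    total_bytes = sum(voxels * np.dtype(d).itemsize for d in dtypes)
    return total_bytes / 2**30


# Hypothetical shape and dtypes, for illustration only
est = stack_memory_gib(5, 100, 1024, 1024, dtypes=("uint16", "float32", "float32"))
print(f"Estimated memory for raw + foreground + contour stacks: {est:.2f} GiB")
```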

In [None]:
# Load and preprocess the selected timepoints
foreground_list = []
contour_list = []
raw_list = []

for i, t in enumerate(selected_times):
    print(f"Processing timepoint {t} ({i+1}/{len(selected_times)})...", end=" ")
    
    # Load the full 3D volume for this timepoint
    volume = data[t, 0].compute()  # shape: (Z, Y, X)
    raw_list.append(volume)
    
    # Detect foreground (binary cell mask)
    fg = detect_foreground(volume, voxel_size=voxel_size)
    foreground_list.append(fg.astype(np.float32))
    
    # Robust inversion (boundary/contour map)
    contours = robust_invert(volume, voxel_size=voxel_size)
    contour_list.append(contours)
    
    print(f"foreground coverage: {fg.mean()*100:.1f}%")

# Stack into (T, Z, Y, X) arrays
foreground_stack = np.stack(foreground_list)
contour_stack = np.stack(contour_list)
raw_stack = np.stack(raw_list)

print(f"\nForeground shape: {foreground_stack.shape}")
print(f"Contour shape: {contour_stack.shape}")
print(f"Raw shape: {raw_stack.shape}")

In [None]:
# Visualize preprocessing for the first timepoint
t_vis = 0
mid_z = n_z // 2

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

raw_slice = raw_stack[t_vis, mid_z]
vmin, vmax = np.percentile(raw_slice, [1, 99.5])

axes[0].imshow(raw_slice, cmap="gray", vmin=vmin, vmax=vmax)
axes[0].set_title(f"Raw image (t={selected_times[t_vis]}, z={mid_z})")
axes[0].axis("off")

axes[1].imshow(foreground_stack[t_vis, mid_z], cmap="gray")
axes[1].set_title("Detected foreground")
axes[1].axis("off")

axes[2].imshow(contour_stack[t_vis, mid_z], cmap="gray")
axes[2].set_title("Contour map (robust_invert)")
axes[2].axis("off")

plt.suptitle("Ultrack preprocessing", fontsize=14)
plt.tight_layout()
plt.show()

## 3. Run the Ultrack pipeline

Configure the tracker with parameters from the [Zebrahub example](https://github.com/royerlab/ultrack/tree/main/examples/zebrahub) and run segmentation + linking + ILP solving.
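
One way to sanity-check the size limits is to convert them to a physical diameter. This sketch assumes `min_area` and `max_area` count voxels for 3D data and treats each segment as a sphere, which is only a rough approximation; the helper name is hypothetical.

```python
import math


def equivalent_diameter_um(n_voxels, voxel_size):
    """Diameter (µm) of a sphere with the same volume as `n_voxels` voxels."""
    voxel_volume = math.prod(voxel_size)  # µm^3 per voxel
    volume = n_voxels * voxel_volume
    return (6.0 * volume / math.pi) ** (1.0 / 3.0)


voxel_size = (1.24, 0.439, 0.439)  # Z, Y, X in micrometers, as used in this notebook
for n_voxels in (500, 10_000):
    d = equivalent_diameter_um(n_voxels, voxel_size)
    print(f"{n_voxels:>6} voxels ~ sphere of diameter {d:.1f} µm")
```

If the resulting diameters look implausible for the cells in the image, adjust the area limits before running the pipeline.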

In [None]:
# Configure Ultrack (parameters from the zebrahub example)
config = MainConfig()

# Segmentation parameters
config.segmentation_config.threshold = 0.5
config.segmentation_config.min_area = 500
config.segmentation_config.max_area = 10_000
config.segmentation_config.n_workers = 1

# Linking parameters
config.linking_config.max_distance = 5.0
config.linking_config.max_neighbors = 5
config.linking_config.n_workers = 1

# Tracking/solving parameters
config.tracking_config.appear_weight = -0.002
config.tracking_config.disappear_weight = -0.01
config.tracking_config.division_weight = -0.001
config.tracking_config.window_size = max(N_TIMEPOINTS, 5)

print("Configuration:")
print(f"  Segmentation: min_area={config.segmentation_config.min_area}, "
      f"max_area={config.segmentation_config.max_area}, "
      f"threshold={config.segmentation_config.threshold}")
print(f"  Linking: max_distance={config.linking_config.max_distance}, "
      f"max_neighbors={config.linking_config.max_neighbors}")
print(f"  Tracking: window_size={config.tracking_config.window_size}")

In [None]:
# Run the full pipeline
tracker = Tracker(config)

print("Running Ultrack pipeline...")
print("  This may take several minutes per timepoint.")
print(f"  Processing {N_TIMEPOINTS} timepoints of shape ({n_z}, {n_y}, {n_x})\n")

tracker.track(
    foreground=foreground_stack,
    edges=contour_stack,
)

print("\nPipeline complete!")

## 4. Export dense instance labels

Extract the dense (T, Z, Y, X) label array where each voxel is assigned its instance/track ID.
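
Since each voxel carries a track ID, a quick consistency check is to compare the IDs present in consecutive frames. The sketch below runs on a toy array; `compare_frames` is a hypothetical helper and assumes label 0 is background.

```python
import numpy as np


def compare_frames(labels_prev, labels_next):
    """Return track IDs that persist, disappear, and appear between two frames."""
    prev_ids = np.setdiff1d(np.unique(labels_prev), [0])
    next_ids = np.setdiff1d(np.unique(labels_next), [0])
    persisting = np.intersect1d(prev_ids, next_ids)
    disappeared = np.setdiff1d(prev_ids, next_ids)
    appeared = np.setdiff1d(next_ids, prev_ids)
    return persisting, disappeared, appeared


# Toy (T, Y, X) label stack: track 1 persists, track 2 ends, track 3 starts
toy = np.zeros((2, 4, 4), dtype=np.int32)
toy[0, :2, :2] = 1
toy[0, 2:, 2:] = 2
toy[1, :2, 1:3] = 1
toy[1, 2:, :2] = 3

persisting, disappeared, appeared = compare_frames(toy[0], toy[1])
print("persisting:", persisting, "disappeared:", disappeared, "appeared:", appeared)
```

On real output, a very low share of persisting IDs between frames would suggest that linking failed, for example because the frames are too far apart in time.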

In [None]:
# Export to dense label array
OUTPUT_DIR = Path("../data/ground_truth")
LABELS_PATH = OUTPUT_DIR / "labels.zarr"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Export every tracked timepoint; the array shape is taken from the tracking results
labels = tracker.to_zarr(
    store_or_path=str(LABELS_PATH),
    overwrite=True,
)

print(f"Labels shape: {labels.shape}")
print(f"Labels dtype: {labels.dtype}")
print(f"Saved to: {LABELS_PATH.resolve()}")

# Count instances per timepoint (label 0 is background and is not counted)
for t in range(labels.shape[0]):
    n_instances = np.count_nonzero(np.unique(labels[t]))
    print(f"  t={selected_times[t]}: {n_instances} cell instances")

In [None]:
# Visualize: raw image + instance labels for each timepoint
n_show = min(N_TIMEPOINTS, 5)
fig, axes = plt.subplots(2, n_show, figsize=(5 * n_show, 10))
if n_show == 1:
    axes = axes.reshape(2, 1)

# Random colormap for instance labels (built once so colors match across timepoints)
max_label = max(int(labels[t, mid_z].max()) for t in range(n_show))
colors = np.random.default_rng(42).random((max(max_label + 1, 2), 3))
colors[0] = 0  # background = black
rand_cmap = ListedColormap(colors)

for i in range(n_show):
    raw_slice = raw_stack[i, mid_z]
    label_slice = labels[i, mid_z]
    vmin, vmax = np.percentile(raw_slice, [1, 99.5])
    
    axes[0, i].imshow(raw_slice, cmap="gray", vmin=vmin, vmax=vmax)
    axes[0, i].set_title(f"t={selected_times[i]}")
    
    # Fixed vmin/vmax so a given track ID maps to the same color in every panel
    axes[1, i].imshow(label_slice, cmap=rand_cmap, interpolation="nearest",
                      vmin=0, vmax=max(max_label, 1))
    n_inst = np.count_nonzero(np.unique(label_slice))
    axes[1, i].set_title(f"{n_inst} instances")

# Hide ticks but keep the axes on so the row labels below stay visible
for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])

axes[0, 0].set_ylabel("Raw", fontsize=12)
axes[1, 0].set_ylabel("Instance labels", fontsize=12)
plt.suptitle(f"Ultrack instance segmentation (z={mid_z})", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Overlay: raw image with label boundaries
from skimage.segmentation import find_boundaries

t_vis = 0
raw_slice = raw_stack[t_vis, mid_z]
label_slice = labels[t_vis, mid_z]
vmin, vmax = np.percentile(raw_slice, [1, 99.5])

boundaries = find_boundaries(label_slice, mode="outer")

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(raw_slice, cmap="gray", vmin=vmin, vmax=vmax)
ax.imshow(np.ma.masked_where(~boundaries, boundaries), cmap="spring", alpha=0.8)
ax.set_title(f"Raw + Ultrack boundaries (t={selected_times[t_vis]}, z={mid_z})")
ax.axis("off")
plt.tight_layout()
plt.show()

## 5. Extract paired training patches

Create (image, mask) pairs for U-Net training. Each pair is a 256x256 XY patch with:
- `images/`: float32 normalized to [0, 1]
- `masks/`: int32 instance label map (0 = background)
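
The two core operations of patch extraction can be sketched on synthetic data. The helper names below are hypothetical, the percentiles mirror the ones used in this notebook, and the image is random noise rather than real data.

```python
import numpy as np


def percentile_normalize(img, lo_pct=0.1, hi_pct=99.9):
    """Rescale intensities to [0, 1] using robust percentiles, returned as float32."""
    img = img.astype(np.float32)
    lo, hi = np.percentile(img, [lo_pct, hi_pct])
    scaled = np.clip((img - lo) / max(hi - lo, 1e-8), 0.0, 1.0)
    return scaled.astype(np.float32)


def random_patch_origin(rng, shape, patch):
    """Sample a top-left corner so that a patch x patch window fits inside `shape`."""
    n_y, n_x = shape
    if n_y < patch or n_x < patch:
        raise ValueError("patch is larger than the image")
    # The upper bound of rng.integers is exclusive, hence the +1
    y0 = int(rng.integers(0, n_y - patch + 1))
    x0 = int(rng.integers(0, n_x - patch + 1))
    return y0, x0


rng = np.random.default_rng(0)
image = rng.integers(0, 4096, size=(300, 280)).astype(np.uint16)  # synthetic slice
norm = percentile_normalize(image)
y0, x0 = random_patch_origin(rng, image.shape, 256)
patch = norm[y0:y0 + 256, x0:x0 + 256]
print(patch.shape, patch.dtype, float(patch.min()), float(patch.max()))
```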

In [None]:
PATCH_SIZE = 256
PATCHES_PER_SLICE = 4
MIN_FOREGROUND = 0.1  # skip patches with <10% labeled pixels

IMG_DIR = OUTPUT_DIR / "images"
MASK_DIR = OUTPUT_DIR / "masks"
IMG_DIR.mkdir(parents=True, exist_ok=True)
MASK_DIR.mkdir(parents=True, exist_ok=True)

rng = np.random.default_rng(42)
patch_count = 0

# Sample z-slices from the middle 50% of the volume (edges are often empty)
z_start = n_z // 4
z_end = 3 * n_z // 4
sample_zs = np.linspace(z_start, z_end, 5, dtype=int)

for t_idx in range(N_TIMEPOINTS):
    for z in sample_zs:
        raw_slice = raw_stack[t_idx, z].astype(np.float32)
        
        # Quantile-normalize to [0, 1]
        lo, hi = np.percentile(raw_slice, [0.1, 99.9])
        img_norm = np.clip((raw_slice - lo) / max(hi - lo, 1e-8), 0, 1)
        
        label_slice = np.array(labels[t_idx, z])
        
        for p in range(PATCHES_PER_SLICE):
            # rng.integers excludes the upper bound; +1 allows patches flush with the far edge
            y0 = rng.integers(0, n_y - PATCH_SIZE + 1)
            x0 = rng.integers(0, n_x - PATCH_SIZE + 1)
            
            img_patch = img_norm[y0:y0+PATCH_SIZE, x0:x0+PATCH_SIZE]
            mask_patch = label_slice[y0:y0+PATCH_SIZE, x0:x0+PATCH_SIZE]
            
            # Skip patches with too little signal
            fg_frac = (mask_patch > 0).mean()
            if fg_frac < MIN_FOREGROUND:
                continue
            
            t_real = selected_times[t_idx]
            fname = f"t{t_real:04d}_z{z:03d}_y{y0:04d}_x{x0:04d}"
            np.save(IMG_DIR / f"{fname}_img.npy", img_patch.astype(np.float32))
            np.save(MASK_DIR / f"{fname}_mask.npy", mask_patch.astype(np.int32))
            patch_count += 1

print(f"Saved {patch_count} training pairs")
print(f"  Images: {IMG_DIR.resolve()}")
print(f"  Masks:  {MASK_DIR.resolve()}")

In [None]:
# Visualize saved training pairs
saved_imgs = sorted(IMG_DIR.glob("*_img.npy"))
n_show = min(6, len(saved_imgs))

fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8))
if n_show == 1:
    axes = axes.reshape(2, 1)

for i in range(n_show):
    img = np.load(saved_imgs[i])
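    # Masks live in MASK_DIR, not next to the images, so rebuild the path from the file name
    mask = np.load(MASK_DIR / saved_imgs[i].name.replace("_img.npy", "_mask.npy"))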
    
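    axes[0, i].imshow(img, cmap="gray", vmin=0, vmax=1)
    axes[0, i].set_title(saved_imgs[i].stem.replace("_img", ""), fontsize=7)
    
    # Remap track IDs to consecutive integers so each cell gets a distinct color
    ids, display = np.unique(mask, return_inverse=True)
    display = display.reshape(mask.shape)
    n_cells = np.count_nonzero(ids)
    colors = np.random.default_rng(42).random((max(len(ids), 2), 3))
    if ids[0] == 0:
        colors[0] = 0  # background = black
    axes[1, i].imshow(display, cmap=ListedColormap(colors), interpolation="nearest",
                      vmin=0, vmax=max(len(ids) - 1, 1))
    axes[1, i].set_title(f"{n_cells} cells", fontsize=8)

# Hide ticks but keep the axes on so the row labels stay visible
for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])

axes[0, 0].set_ylabel("Raw (normalized)", fontsize=10)
axes[1, 0].set_ylabel("Instance mask", fontsize=10)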
plt.suptitle("Training patch pairs", fontsize=14)
plt.tight_layout()
plt.show()

## 6. Label quality assessment

In [None]:
from skimage.measure import regionprops

all_areas = []
cells_per_patch = []

for mask_path in sorted(MASK_DIR.glob("*_mask.npy")):
    mask = np.load(mask_path)
    props = regionprops(mask)
    areas = [p.area for p in props]
    all_areas.extend(areas)
    cells_per_patch.append(len(props))

all_areas = np.array(all_areas)

print(f"Total patches: {len(cells_per_patch)}")
print(f"Total cell instances: {len(all_areas)}")
print(f"Cells per patch: {np.mean(cells_per_patch):.1f} +/- {np.std(cells_per_patch):.1f}")
print(f"Cell area (pixels): median={np.median(all_areas):.0f}, "
      f"mean={np.mean(all_areas):.0f}, std={np.std(all_areas):.0f}")
print(f"Cell area range: [{all_areas.min()}, {all_areas.max()}]")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].hist(all_areas, bins=50, edgecolor="black")
axes[0].set_xlabel("Cell area (pixels)")
axes[0].set_ylabel("Count")
axes[0].set_title("Cell size distribution")

axes[1].hist(cells_per_patch, bins=range(0, max(cells_per_patch) + 2), edgecolor="black")
axes[1].set_xlabel("Number of cells")
axes[1].set_ylabel("Number of patches")
axes[1].set_title("Cells per patch")

plt.tight_layout()
plt.show()

## Next steps

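- **Scale up**: Increase `N_TIMEPOINTS` to generate more training data. This notebook holds every selected volume in memory (`raw_stack`, `foreground_stack`, `contour_stack`), so RAM use grows linearly with the number of timepoints; for larger runs, write the preprocessed volumes to on-disk zarr arrays instead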
- **Manual inspection**: Review a sample of patches in napari or QuPath to verify label quality
- **Data augmentation**: Plan flips, rotations, intensity jitter for training
- **Train/val split**: Hold out entire timepoints for validation to test generalization across developmental stages
- **U-Net training**: Build a PyTorch DataLoader from these patches and train a segmentation model
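
A timepoint-based train/val split can be sketched from the patch file names alone. The helper below is hypothetical and assumes the `t{time}_z{z}_y{y}_x{x}` naming used when the patches were saved; the file names and the held-out timepoint are made up.

```python
import re


def split_by_timepoint(filenames, val_timepoints):
    """Assign each patch file to 'train' or 'val' based on its t#### prefix."""
    pattern = re.compile(r"^t(\d+)_")
    split = {"train": [], "val": []}
    for name in filenames:
        match = pattern.match(name)
        if match is None:
            raise ValueError(f"unexpected file name: {name}")
        timepoint = int(match.group(1))
        key = "val" if timepoint in val_timepoints else "train"
        split[key].append(name)
    return split


# Hypothetical file names following the notebook's naming scheme
names = [
    "t0000_z010_y0100_x0200_img.npy",
    "t0050_z020_y0000_x0300_img.npy",
    "t0099_z030_y0400_x0100_img.npy",
]
split = split_by_timepoint(names, val_timepoints={99})
print(split)
```

Splitting on the timepoint rather than on individual patches keeps overlapping patches from the same volume out of both sets at once.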