# CPU proof-of-concept connectomics pipeline on a small public *C. elegans* EM cutout.
Connectomes are maps of synaptic wiring that help explain how nervous systems produce behaviour. The *C. elegans* connectome showed that a complete wiring diagram is both achievable and biologically useful, and later reconstructions improved its coverage and accuracy and added developmental and sex-specific insights. However, volume EM produces enormous image stacks, and turning them into neuron and synapse graphs requires robust image processing, machine learning, and long-term computing infrastructure. Modern pipelines combine ML with human proofreading platforms, but running the full workflow end to end is often out of reach on modest hardware. To bridge this gap, we built a CPU-friendly proof of concept on a small BossDB cutout. Each stage is implemented in a transparent, reproducible way and produces diagnostic figures and intermediate outputs that humans can refine and that can later be upgraded to deep-learning methods.
## Data source: BossDB (downloads only a small cutout).
## Default dataset URI: bossdb://mulcahy2022/1h_L1/em

## Pipeline steps:
1) Download cutout stack (Z,Y,X) as numpy
2) Preprocess: denoise + normalize + CLAHE
3) Align slices (translation-only registration via phase correlation)
4) Segment neurites (2D watershed per slice + simple 3D stitching)
5) Detect synapse candidates (heuristic, blob-like dark features near inter-neurite boundaries)
6) Build wiring graph (NetworkX) + basic stats
7) Save publishable figures (PDF/PNG)

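## New flags for small machines:
- `--preset safe8gb`: uses a conservative 192x192x48 cutout (a `medium` preset, 256x256x48, is also available)
- `--save_minimal`: skips writing the raw and preprocessed volumes; the aligned volume, labels, tables, figures, and run log are still saved
- `--aligned_dtype float16`: stores the aligned volume at half the disk size of the default float32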

In [6]:
import pkgutil
import sys

# Report which interpreter this kernel runs under, then whether the optional
# `intern` BossDB client is importable from it.
print(sys.executable)
print(any(module_info.name == "intern" for module_info in pkgutil.iter_modules()))

/opt/anaconda3/bin/python
True


In [12]:


import argparse
import json
import os
import sqlite3
from contextlib import closing
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm

from skimage import exposure
from skimage.filters import gaussian, sobel
from skimage.morphology import remove_small_objects, binary_opening, disk
from skimage.segmentation import watershed, find_boundaries
from skimage.feature import peak_local_max, blob_log
from skimage.registration import phase_cross_correlation
from skimage.transform import warp
from skimage.transform._warps import SimilarityTransform
from scipy import ndimage as ndi

# `intern` is optional: default to None so the rest of the module still imports,
# and download_bossdb_cutout raises a clear RuntimeError when it is missing.
boss_array = None
try:
    from intern import array as boss_array
except Exception:
    # Any import-time failure (package absent or broken install) leaves
    # boss_array as None.
    pass



# Config

@dataclass
class RunConfig:
    """Settings for one pipeline run, assembled from the CLI in main().

    Field order is part of the positional constructor signature; keep it stable.
    """

    # BossDB URI passed to intern, e.g. "bossdb://collection/experiment/channel"
    dataset_uri: str
    # Root output directory (subfolders are created by ensure_dirs)
    outdir: Path
    # Cutout size in voxels, ordered (X, Y, Z)
    cutout_xyz: Tuple[int, int, int]
    # Seed for cutout jitter and simulated misalignment
    seed: int = 7
    # When True, random per-slice translations are applied before alignment
    simulate_misalignment: bool = True
    # When True, raw and preprocessed volumes are not written to disk
    save_minimal: bool = False
    # Storage dtype name for the aligned volume ("float32" or "float16")
    aligned_dtype: str = "float32"


# I/O helpers

def ensure_dirs(outdir: Path) -> Dict[str, Path]:
    """Create the output directory tree and return a name -> path mapping.

    Keys are "root" (``outdir`` itself) followed by "data", "figures",
    "tables", "cache" and "runs" (subdirectories of ``outdir``). Every
    directory exists when the function returns; existing ones are reused.
    """
    outdir.mkdir(parents=True, exist_ok=True)

    layout: Dict[str, Path] = {"root": outdir}
    layout.update({name: outdir / name for name in ("data", "figures", "tables", "cache", "runs")})

    for folder in layout.values():
        folder.mkdir(parents=True, exist_ok=True)
    return layout


def init_db(db_path: Path) -> None:
    """Create the ``runs`` bookkeeping table in the SQLite file at ``db_path``.

    Idempotent: ``CREATE TABLE IF NOT EXISTS`` leaves an existing table and its
    rows untouched. The connection is always closed, even if the statement fails.
    """
    # closing() guarantees con.close(); the inner ``con`` context commits on
    # success and rolls back on error (it does not close the connection itself).
    with closing(sqlite3.connect(db_path)) as con, con:
        con.execute(
            """
            CREATE TABLE IF NOT EXISTS runs (
                run_id TEXT PRIMARY KEY,
                dataset_uri TEXT,
                cutout_x INTEGER,
                cutout_y INTEGER,
                cutout_z INTEGER,
                created_at TEXT,
                meta_json TEXT
            );
            """
        )


def log_run(db_path: Path, run_id: str, cfg: RunConfig, meta: Dict) -> None:
    """Insert (or overwrite, keyed by ``run_id``) one row in the ``runs`` table.

    ``created_at`` is filled by SQLite's ``datetime('now')`` and ``meta`` is
    stored as indented JSON text. The table is expected to exist (see init_db).
    The write is committed on success, rolled back on error, and the connection
    is always closed.
    """
    row = (
        run_id,
        cfg.dataset_uri,
        cfg.cutout_xyz[0],
        cfg.cutout_xyz[1],
        cfg.cutout_xyz[2],
        json.dumps(meta, indent=2),
    )
    # closing() guarantees con.close(); the inner ``con`` context handles commit/rollback.
    with closing(sqlite3.connect(db_path)) as con, con:
        con.execute(
            """
            INSERT OR REPLACE INTO runs(run_id, dataset_uri, cutout_x, cutout_y, cutout_z, created_at, meta_json)
            VALUES (?, ?, ?, ?, ?, datetime('now'), ?);
            """,
            row,
        )


def save_json(path: Path, obj: Dict) -> None:
    """Serialise ``obj`` as 2-space-indented JSON, overwriting ``path``."""
    text = json.dumps(obj, indent=2)
    path.write_text(text)


def save_figure(fig, path_png: Path, path_pdf: Path, dpi: int = 300) -> None:
    """Write ``fig`` as a PNG (at ``dpi``) and a PDF, then close it.

    The figure is closed in a ``finally`` so that a failed layout or save does
    not leave an open pyplot figure behind; the exception still propagates.
    """
    try:
        fig.tight_layout()
        fig.savefig(path_png, dpi=dpi, bbox_inches="tight")
        fig.savefig(path_pdf, bbox_inches="tight")
    finally:
        plt.close(fig)


def sizeof_gb(n_bytes: int) -> float:
    """Convert a byte count to binary gigabytes (units of 1024**3 bytes)."""
    bytes_per_gib = float(1 << 30)
    return n_bytes / bytes_per_gib


def estimate_volume_bytes(xyz: Tuple[int, int, int], dtype: str) -> int:
    """Return the in-memory size in bytes of a dense volume of ``dtype``.

    ``xyz`` is (X, Y, Z); the axis order does not affect the byte count, so the
    result is equally valid for the (Z, Y, X) arrays used elsewhere.
    """
    size_x, size_y, size_z = xyz
    itemsize = np.dtype(dtype).itemsize
    return size_x * size_y * size_z * itemsize


# 1) Download data (small cutout)

def _center_jitter(rng: np.random.Generator, extent: int, frac: float = 0.05) -> int:
    """Draw a random integer offset in ``[-span, span)`` with ``span = int(extent * frac)``.

    Returns 0 without consuming randomness when ``span`` is 0 (axes shorter than
    ``1 / frac`` voxels), where ``rng.integers(-span, span)`` would raise
    because the interval is empty.
    """
    span = int(extent * frac)
    if span <= 0:
        return 0
    return int(rng.integers(-span, span))


def download_bossdb_cutout(dataset_uri: str, cutout_xyz: Tuple[int, int, int], seed: int) -> np.ndarray:
    """
    Download a small cutout from a BossDB dataset using intern.

    The cutout is centred on the middle of the dataset, offset on each axis by
    a seeded random jitter of up to 5% of that axis, and clamped so it starts
    inside the dataset. If a requested size exceeds the dataset extent along an
    axis, the slice is truncated to that extent.

    Args:
        dataset_uri: BossDB URI understood by ``intern.array``.
        cutout_xyz: cutout size in voxels, ordered (X, Y, Z).
        seed: seed for the placement jitter.

    Returns:
        vol_zyx: uint8 volume shaped (Z, Y, X). Non-uint8 data is rescaled to
        0..255 with ``exposure.rescale_intensity``.

    Raises:
        RuntimeError: if the optional ``intern`` package could not be imported.
    """
    if boss_array is None:
        raise RuntimeError("Could not import 'intern'. Install it with: pip install intern")

    em = boss_array(dataset_uri)

    # intern uses Z,Y,X indexing, em.shape is (Z,Y,X)
    zyx_shape = em.shape
    Z, Y, X = int(zyx_shape[0]), int(zyx_shape[1]), int(zyx_shape[2])

    cx, cy, cz = X // 2, Y // 2, Z // 2
    sx, sy, sz = cutout_xyz
    rng = np.random.default_rng(seed)

    # pick a random center near the middle (avoid empty edges); integer spans
    # reproduce numpy's truncation of the former float bounds exactly, and
    # small axes (< 20 voxels) get zero jitter instead of a ValueError
    jitter_x = _center_jitter(rng, X)
    jitter_y = _center_jitter(rng, Y)
    jitter_z = _center_jitter(rng, Z)

    x0 = max(0, min(X - sx, cx - sx // 2 + jitter_x))
    y0 = max(0, min(Y - sy, cy - sy // 2 + jitter_y))
    z0 = max(0, min(Z - sz, cz - sz // 2 + jitter_z))

    x1, y1, z1 = x0 + sx, y0 + sy, z0 + sz

    vol = em[z0:z1, y0:y1, x0:x1]

    if vol.dtype != np.uint8:
        vol = exposure.rescale_intensity(vol, out_range=(0, 255)).astype(np.uint8)

    return vol


# 2) Preprocess (denoise + normalize)

def preprocess_volume(vol_zyx_u8: np.ndarray) -> np.ndarray:
    """
    Denoise and contrast-normalize an EM stack one Z slice at a time.

    Each slice is scaled from 0..255 to 0..1, smoothed with a Gaussian
    (sigma=0.7), then equalized with CLAHE (clip_limit=0.01).

    Returns:
        float32 array with the same shape as the input.
    """
    result = np.empty_like(vol_zyx_u8, dtype=np.float32)
    for z, raw_slice in enumerate(vol_zyx_u8):
        scaled = raw_slice.astype(np.float32) / 255.0
        smoothed = gaussian(scaled, sigma=0.7, preserve_range=True)
        equalized = exposure.equalize_adapthist(smoothed, clip_limit=0.01)
        result[z] = equalized.astype(np.float32)
    return result


# 3) Align slices (rigid translation)

def simulate_misalignment(vol: np.ndarray, seed: int) -> np.ndarray:
    """
    Return a copy of ``vol`` with every slice translated by a small random offset.

    Offsets are drawn per slice from N(0, 1) pixels (dy first, then dx, which
    fixes the random stream for a given seed). Edges are padded by replicating
    border pixels. The input is not modified.
    """
    rng = np.random.default_rng(seed)
    shifted = np.empty_like(vol)
    for z, plane in enumerate(vol):
        offset_y = float(rng.normal(0, 1.0))
        offset_x = float(rng.normal(0, 1.0))
        # SimilarityTransform takes (x, y); warping with its inverse moves content by +offset.
        moved = warp(
            plane,
            SimilarityTransform(translation=(offset_x, offset_y)).inverse,
            preserve_range=True,
            mode="edge",
        )
        shifted[z] = moved.astype(vol.dtype)
    return shifted


def align_slices_translation(vol: np.ndarray) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Align each slice to the previous (already aligned) slice using phase
    cross-correlation, translation only.

    Returns:
        aligned: array with the same shape and dtype as ``vol``.
        shifts: one (dy, dx) per slice in pixels. Slice 0 is the reference and
            gets (0.0, 0.0); every other entry is the shift reported by
            ``phase_cross_correlation`` for registering slice z onto the
            aligned slice z-1, i.e. the correction that was applied.
        An empty input yields an empty volume and an empty shift list.
    """
    Z = vol.shape[0]
    aligned = np.empty_like(vol)
    if Z == 0:
        return aligned, []

    aligned[0] = vol[0]
    shifts = [(0.0, 0.0)]

    ref = vol[0]
    for z in range(1, Z):
        shift, _, _ = phase_cross_correlation(ref, vol[z], upsample_factor=10)
        dy, dx = float(shift[0]), float(shift[1])
        # The returned shift is what must be applied to the moving slice to
        # register it, so its content has to move by +(dy, dx). warp() with
        # ``tform.inverse`` moves content by +translation, and
        # SimilarityTransform takes translation as (x, y) -> (dx, dy).
        tform = SimilarityTransform(translation=(dx, dy))
        aligned[z] = warp(vol[z], tform.inverse, preserve_range=True, mode="edge").astype(vol.dtype)
        shifts.append((dy, dx))
        ref = aligned[z]
    return aligned, shifts


# 4) Segmentation (watershed baseline + 3D stitching)

def relabel_compact(lbl: np.ndarray) -> np.ndarray:
    """
    Renumber the labels in ``lbl`` to consecutive IDs starting at 1.

    0 stays background. Each positive label becomes 1 + its rank among the
    sorted non-zero labels, so relative order is preserved. Labels are expected
    to be non-negative.

    Returns:
        int32 array with the same shape as ``lbl``.
    """
    ids = np.unique(lbl)          # sorted
    ids = ids[ids != 0]
    out = lbl.copy()
    mask = out > 0
    if mask.any():
        # Every positive value is present in ``ids``, so searchsorted gives its
        # exact rank; this is a vectorized version of a {old: rank+1} dict lookup.
        out[mask] = np.searchsorted(ids, out[mask]) + 1
    return out.astype(np.int32)


def stitch_labels_by_overlap(labels_zyx: np.ndarray, min_overlap: int = 200) -> np.ndarray:
    """
    Stitch 2D labels into consistent 3D IDs by overlapping area across adjacent slices.

    Strategy:
    - make all slice labels unique by offsetting each slice by the running max
    - union IDs in adjacent slices whose overlap is at least ``min_overlap`` pixels
      (union-find; a segment may merge with several neighbours)
    - replace every ID by its union root, then compact IDs to 1..N

    Args:
        labels_zyx: (Z, Y, X) array of per-slice labels, 0 = background.
        min_overlap: minimum shared pixel count for two segments to merge.

    Returns:
        int32 (Z, Y, X) label volume with compact IDs (see relabel_compact).
    """
    Z = labels_zyx.shape[0]
    out = labels_zyx.copy()

    # 1) Offset each slice so IDs never collide across slices.
    current_max = 0
    for z in range(Z):
        sl = out[z]  # view into ``out``; edited in place
        sl_nonzero = sl > 0
        if sl_nonzero.any():
            sl[sl_nonzero] += current_max
            current_max = int(sl.max())

    # 2) Union-find over the offset IDs (created once, not per slice).
    parent: Dict[int, int] = {}

    def find(a: int) -> int:
        # Find root
        root = a
        while parent.get(root, root) != root:
            root = parent.get(root, root)

        # Path compression (safe)
        while parent.get(a, a) != a:
            p = parent.get(a, a)
            parent[a] = root
            a = p

        return root

    def union(a: int, b: int) -> None:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    for z in tqdm(range(Z - 1), desc="Stitching 3D IDs"):
        a = out[z]
        b = out[z + 1]
        mask = a > 0
        if not mask.any():
            continue

        pairs = np.stack([a[mask], b[mask]], axis=1)
        uniq, counts = np.unique(pairs, axis=0, return_counts=True)
        for (aid, bid), cnt in zip(uniq, counts):
            if bid == 0:
                continue
            if cnt >= min_overlap:
                union(int(aid), int(bid))

    # 3) Map every voxel to its root: resolve each distinct ID once, then
    #    scatter back via the inverse index instead of calling find() per voxel.
    flat = out.reshape(-1)
    nonzero = flat > 0
    if nonzero.any():
        uniq_ids, inverse = np.unique(flat[nonzero], return_inverse=True)
        roots = np.fromiter((find(int(u)) for u in uniq_ids), dtype=np.int32, count=uniq_ids.size)
        flat[nonzero] = roots[inverse.reshape(-1)]

    out = flat.reshape(out.shape)
    out = relabel_compact(out)
    return out


def segment_neurites_watershed(
    vol: np.ndarray,
    *,
    fg_quantile: float = 0.55,
    min_distance: int = 8,
    min_overlap: int = 200,
) -> np.ndarray:
    """
    Baseline neurite segmentation (proof-of-concept):
    - foreground proxy per slice: pixels brighter than the ``fg_quantile``
      quantile, cleaned by a radius-1 binary opening and removal of objects
      smaller than 128 px
    - watershed on the Sobel gradient, seeded from distance-transform peaks at
      least ``min_distance`` px apart
    - slice-wise labels stitched to 3D IDs by overlap (``min_overlap`` px)

    Slices whose cleaned foreground has fewer than 256 px, or that yield no
    seeds, stay background (0).

    Args:
        vol: (Z, Y, X) image volume (the pipeline passes float32 in [0, 1]).
        fg_quantile, min_distance, min_overlap: keyword-only tuning knobs; the
            defaults reproduce the previously hard-coded values.

    Returns:
        int32 (Z, Y, X) label volume, 0 = background.
    """
    Z, Y, X = vol.shape
    labels = np.zeros((Z, Y, X), dtype=np.int32)

    for z in tqdm(range(Z), desc="Segmenting slices"):
        img = vol[z]

        thr = np.quantile(img, fg_quantile)
        fg = img > thr
        fg = binary_opening(fg, disk(1))
        fg = remove_small_objects(fg, min_size=128)

        if fg.sum() < 256:
            continue

        dist = ndi.distance_transform_edt(fg)
        coords = peak_local_max(dist, min_distance=min_distance, labels=fg, exclude_border=False)
        if len(coords) == 0:
            continue

        # One seed pixel per peak; ndi.label then merges seeds that touch.
        markers = np.zeros_like(img, dtype=np.int32)
        markers[coords[:, 0], coords[:, 1]] = np.arange(1, len(coords) + 1, dtype=np.int32)
        markers = ndi.label(markers > 0)[0]
        if markers.max() == 0:
            continue

        elevation = sobel(img)
        seg = watershed(elevation, markers=markers, mask=fg)
        labels[z] = seg.astype(np.int32)

    return stitch_labels_by_overlap(labels, min_overlap=min_overlap)


# 5) Synapse candidate detection (heuristic)

def detect_synapse_candidates(vol: np.ndarray, labels: np.ndarray) -> pd.DataFrame:
    """
    Proof-of-concept synapse candidate detection:
    - boundaries between segments
    - dark blob-like features near boundaries (blob_log on inverted intensity)
    - assign candidate to (pre_id, post_id) based on local neighborhood labels
    Direction is heuristic, used only to construct a directed graph.

    For each blob within 3 px of a segment boundary, the two most frequent
    labels in its 7x7 window form the pair. The one whose pixels are darker on
    average in the window is called "pre", and that mean darkness is the score.

    Returns:
        DataFrame with columns z, y, x, pre_id, post_id, score, sorted by
        descending score. Within each 4x4-pixel bin of a slice only the
        highest-scoring row per (pre_id, post_id) is kept.
    """
    columns = ["z", "y", "x", "pre_id", "post_id", "score"]
    rows = []
    Z = vol.shape[0]

    for z in tqdm(range(Z), desc="Detecting synapse candidates"):
        img = vol[z]
        lbl = labels[z]

        # A candidate needs two distinct segments in the slice; counting
        # distinct IDs (not just the max ID) skips single-segment slices early.
        if np.unique(lbl[lbl > 0]).size < 2:
            continue

        bnd = find_boundaries(lbl, mode="outer")
        if bnd.sum() < 64:
            continue

        img_n = (img - img.min()) / (img.max() - img.min() + 1e-6)
        inv = 1.0 - img_n

        blobs = blob_log(inv, min_sigma=1, max_sigma=4, num_sigma=6, threshold=0.08)

        for (y, x, s) in blobs:
            y = int(round(y))
            x = int(round(x))
            if y < 2 or x < 2 or y >= img.shape[0] - 2 or x >= img.shape[1] - 2:
                continue

            y0, y1 = max(0, y - 3), min(img.shape[0], y + 4)
            x0, x1 = max(0, x - 3), min(img.shape[1], x + 4)
            if bnd[y0:y1, x0:x1].sum() == 0:
                continue

            neigh = lbl[y0:y1, x0:x1]
            ids = neigh[neigh > 0]
            if ids.size < 10:
                continue
            uniq, cnt = np.unique(ids, return_counts=True)
            if uniq.size < 2:
                continue

            order = np.argsort(-cnt)
            a = int(uniq[order[0]])
            b = int(uniq[order[1]])

            # a and b both come from ``neigh``, so each mask is non-empty.
            win = inv[y0:y1, x0:x1]
            mean_a = float(win[neigh == a].mean())
            mean_b = float(win[neigh == b].mean())

            if mean_a >= mean_b:
                pre_id, post_id, score = a, b, mean_a
            else:
                pre_id, post_id, score = b, a, mean_b

            rows.append({"z": z, "y": y, "x": x, "pre_id": pre_id, "post_id": post_id, "score": score})

    if not rows:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(rows).sort_values("score", ascending=False).reset_index(drop=True)

    # de-duplicate close detections: same slice, same 4x4 bin, same pair;
    # keep="first" retains the highest score because df is score-sorted
    binned = df.assign(_yb=df["y"] // 4, _xb=df["x"] // 4)
    keep = ~binned.duplicated(subset=["z", "_yb", "_xb", "pre_id", "post_id"], keep="first")
    return df[keep].reset_index(drop=True)


# 6) Graph

def build_graph(syn_df: pd.DataFrame, min_score: float = 0.10) -> nx.DiGraph:
    """
    Build a directed wiring graph from synapse candidates.

    Each candidate with ``score >= min_score`` adds 1.0 to the weight of edge
    pre_id -> post_id; self-pairs (pre_id == post_id) are dropped. Nodes and
    edges are inserted in order of each pair's first appearance in ``syn_df``.
    """
    G = nx.DiGraph()
    if syn_df.empty:
        return G

    use = syn_df.loc[syn_df["score"] >= min_score, ["pre_id", "post_id"]].astype(int)
    use = use[use["pre_id"] != use["post_id"]]
    if use.empty:
        return G

    # sort=False keeps groups in first-appearance order, matching a row-by-row build.
    counts = use.groupby(["pre_id", "post_id"], sort=False).size()
    for (pre, post), n in counts.items():
        G.add_edge(int(pre), int(post), weight=float(n))
    return G


def graph_stats(G: nx.DiGraph) -> pd.DataFrame:
    """
    Per-node weighted degree and PageRank table for a connectivity graph.

    Columns: node, in_degree, out_degree (sums of incident edge "weight"),
    pagerank. Rows are sorted by descending pagerank. An empty graph yields an
    empty frame with those columns; a graph with nodes but no edges gets
    pagerank 0.0 everywhere.
    """
    node_list = list(G.nodes())
    if not node_list:
        return pd.DataFrame(columns=["node", "in_degree", "out_degree", "pagerank"])

    weighted_in = dict(G.in_degree(weight="weight"))
    weighted_out = dict(G.out_degree(weight="weight"))
    if G.number_of_edges() > 0:
        rank = nx.pagerank(G, weight="weight")
    else:
        rank = {n: 0.0 for n in node_list}

    table = pd.DataFrame(
        {
            "node": node_list,
            "in_degree": [weighted_in.get(n, 0.0) for n in node_list],
            "out_degree": [weighted_out.get(n, 0.0) for n in node_list],
            "pagerank": [rank.get(n, 0.0) for n in node_list],
        }
    )
    table = table.sort_values("pagerank", ascending=False)
    return table.reset_index(drop=True)


# 7) Figures

def plot_overview(vol_raw_u8: np.ndarray,
                  vol_proc_f: np.ndarray,
                  vol_aligned: np.ndarray,
                  labels: np.ndarray,
                  syn_df: pd.DataFrame,
                  out_fig_dir: Path) -> None:
    """
    Save two mid-Z figures to ``out_fig_dir`` (PNG + PDF each).

    Figure 1 is a 2x2 panel: raw, preprocessed, aligned, and label image.
    Figure 2 is the aligned slice with that slice's synapse candidates overlaid.
    """
    mid = vol_raw_u8.shape[0] // 2
    aligned_slice = vol_aligned[mid]
    label_slice = labels[mid]

    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    panels = [
        (ax[0, 0], vol_raw_u8[mid], "gray", "Raw EM slice (mid-Z)"),
        (ax[0, 1], vol_proc_f[mid], "gray", "Preprocessed slice (CLAHE + denoise)"),
        (ax[1, 0], aligned_slice, "gray", "Aligned slice (translation)"),
        (ax[1, 1], label_slice, "nipy_spectral", f"Neurite segments (IDs: {int(label_slice.max())})"),
    ]
    for panel_ax, image, cmap, title in panels:
        panel_ax.imshow(image, cmap=cmap)
        panel_ax.set_title(title)
        panel_ax.axis("off")

    save_figure(
        fig,
        out_fig_dir / "Figure1_EM_preprocess_align_segment.png",
        out_fig_dir / "Figure1_EM_preprocess_align_segment.pdf",
    )

    fig2, ax2 = plt.subplots(1, 1, figsize=(6, 6))
    ax2.imshow(aligned_slice, cmap="gray")
    if syn_df.empty:
        ax2.set_title("Synapse candidates (none found in mid-Z)")
    else:
        on_mid = syn_df[syn_df["z"] == mid]
        ax2.scatter(on_mid["x"], on_mid["y"], s=10)
        ax2.set_title(f"Synapse candidates (mid-Z), n={len(on_mid)}")
    ax2.axis("off")

    save_figure(
        fig2,
        out_fig_dir / "Figure2_synapse_candidates_overlay.png",
        out_fig_dir / "Figure2_synapse_candidates_overlay.pdf",
    )


def plot_graph(G: nx.DiGraph, stats_df: pd.DataFrame, out_fig_dir: Path) -> None:
    """
    Save a spring-layout drawing of ``G`` as Figure3_graph.{png,pdf}.

    Edge widths scale with edge "weight" (0.5 to ~2.5). Only the first 10 nodes
    of ``stats_df`` (already sorted by PageRank) are labelled. An empty graph
    produces a blank figure titled as empty.
    """
    png_path = out_fig_dir / "Figure3_graph.png"
    pdf_path = out_fig_dir / "Figure3_graph.pdf"

    fig, ax = plt.subplots(1, 1, figsize=(8, 7))
    ax.axis("off")

    n_nodes = G.number_of_nodes()
    if n_nodes == 0:
        ax.set_title("Connectivity graph (empty)")
        save_figure(fig, png_path, pdf_path)
        return

    pos = nx.spring_layout(G, seed=7, k=0.8 / np.sqrt(max(n_nodes, 1)))
    edge_weights = np.array([attrs["weight"] for _, _, attrs in G.edges(data=True)])
    widths = 0.5 + 2.0 * (edge_weights / (edge_weights.max() + 1e-6))

    nx.draw_networkx_edges(G, pos, ax=ax, width=widths, alpha=0.6, arrows=True, arrowsize=10)
    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=200, alpha=0.9)

    top_nodes = [] if stats_df.empty else stats_df.head(10)["node"].tolist()
    nx.draw_networkx_labels(G, pos, labels={n: str(n) for n in top_nodes}, font_size=8, ax=ax)

    ax.set_title(f"Connectivity graph, nodes={n_nodes}, edges={G.number_of_edges()}")
    save_figure(fig, png_path, pdf_path)


def plot_alignment_shifts(shifts: List[Tuple[float, float]], out_fig_dir: Path) -> None:
    """
    Plot per-slice (dy, dx) alignment shifts against slice index and save as
    Figure4_alignment_shifts.{png,pdf}. Does nothing when ``shifts`` is empty.
    """
    if not shifts:
        return

    dy_values, dx_values = zip(*shifts)
    slice_index = np.arange(len(shifts))

    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    for series, name in ((dy_values, "dy"), (dx_values, "dx")):
        ax.plot(slice_index, series, label=name)
    ax.set_xlabel("Slice (Z)")
    ax.set_ylabel("Estimated shift (pixels)")
    ax.set_title("Estimated per-slice alignment shifts")
    ax.legend()

    save_figure(fig, out_fig_dir / "Figure4_alignment_shifts.png", out_fig_dir / "Figure4_alignment_shifts.pdf")


# CLI

def parse_args():
    """
    Parse pipeline CLI options from ``sys.argv``.

    Unknown arguments are ignored (via ``parse_known_args``) so the function
    also works inside Jupyter, whose kernel passes ``-f <kernel.json>``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--outdir", type=str, default="outputs", help="Output directory")
    add("--dataset", type=str, default="bossdb://mulcahy2022/1h_L1/em", help="BossDB dataset URI")
    add("--preset", type=str, default="none",
        choices=["none", "safe8gb", "medium"],
        help="Convenience preset for cutout size")
    add("--cutout", type=int, nargs=3, default=[256, 256, 64], metavar=("X", "Y", "Z"),
        help="Cutout size in voxels (X Y Z), ignored if preset != none")
    add("--save_minimal", action="store_true",
        help="Save only aligned volume, labels, tables, figures, keeps disk usage low")
    add("--aligned_dtype", type=str, default="float32", choices=["float16", "float32"],
        help="Storage dtype for aligned volume, float16 is smaller")
    add("--no-sim-misalignment", action="store_true",
        help="Disable misalignment simulation step")
    add("--seed", type=int, default=7, help="Random seed")

    known, _unrecognized = parser.parse_known_args()
    return known


def preset_cutout(preset: str) -> Optional[Tuple[int, int, int]]:
    """Return the (X, Y, Z) cutout size for a named preset.

    "safe8gb" -> (192, 192, 48) and "medium" -> (256, 256, 48). Any other name,
    including "none", returns None, meaning the explicit ``--cutout`` applies.
    """
    presets = {
        "safe8gb": (192, 192, 48),
        "medium": (256, 256, 48),
    }
    return presets.get(preset)


def main():
    """
    Run the whole pipeline from CLI arguments: download, preprocess, align,
    segment, detect synapse candidates, build the graph, save figures and log
    the run.

    All outputs go under ``--outdir`` (see ensure_dirs). With ``--save_minimal``
    the raw and preprocessed volumes are not written to disk.
    """
    args = parse_args()

    cutout_xyz = tuple(int(x) for x in args.cutout)
    if args.preset != "none":
        cutout_xyz = preset_cutout(args.preset)

    cfg = RunConfig(
        dataset_uri=args.dataset,
        outdir=Path(args.outdir),
        cutout_xyz=cutout_xyz,
        seed=int(args.seed),
        simulate_misalignment=(not args.no_sim_misalignment),
        save_minimal=bool(args.save_minimal),
        aligned_dtype=str(args.aligned_dtype),
    )

    paths = ensure_dirs(cfg.outdir)
    db_path = paths["runs"] / "pipeline_runs.sqlite"
    init_db(db_path)

    run_id = f"run_{cfg.seed}_{cfg.cutout_xyz[0]}x{cfg.cutout_xyz[1]}x{cfg.cutout_xyz[2]}_{cfg.aligned_dtype}"
    meta = {
        "run_id": run_id,
        "dataset_uri": cfg.dataset_uri,
        "cutout_xyz": cfg.cutout_xyz,
        "simulate_misalignment": cfg.simulate_misalignment,
        "save_minimal": cfg.save_minimal,
        "aligned_dtype": cfg.aligned_dtype,
        "env_OMP_NUM_THREADS": os.environ.get("OMP_NUM_THREADS", ""),
        "env_MKL_NUM_THREADS": os.environ.get("MKL_NUM_THREADS", ""),
    }
    save_json(paths["runs"] / f"{run_id}_config.json", meta)

    # sanity estimates
    raw_bytes = estimate_volume_bytes(cfg.cutout_xyz, "uint8")
    pre_bytes = estimate_volume_bytes(cfg.cutout_xyz, "float32")
    ali_bytes = estimate_volume_bytes(cfg.cutout_xyz, cfg.aligned_dtype)
    print("\nPlanned run settings")
    print(f"Dataset URI, {cfg.dataset_uri}")
    print(f"Cutout (X,Y,Z), {cfg.cutout_xyz}")
    print(f"Aligned dtype, {cfg.aligned_dtype}")
    print(f"save_minimal, {cfg.save_minimal}")
    print("\nApproximate single-volume sizes")
    print(f"Raw uint8 volume, {sizeof_gb(raw_bytes):.3f} GB")
    print(f"Preprocessed float32 volume, {sizeof_gb(pre_bytes):.3f} GB")
    print(f"Aligned {cfg.aligned_dtype} volume, {sizeof_gb(ali_bytes):.3f} GB\n")

    # 1) Download
    print(f"[1/8] Downloading cutout from: {cfg.dataset_uri}")
    vol_u8 = download_bossdb_cutout(cfg.dataset_uri, cfg.cutout_xyz, cfg.seed)
    if not cfg.save_minimal:
        np.save(paths["data"] / "em_cutout_u8_zyx.npy", vol_u8)

    # 2) Preprocess
    print("[2/8] Preprocessing (denoise + CLAHE)")
    vol_f = preprocess_volume(vol_u8)
    if not cfg.save_minimal:
        # copy=False: no duplicate buffer when vol_f is already float32
        np.save(paths["data"] / "em_preprocessed_f32_zyx.npy", vol_f.astype(np.float32, copy=False))

    # 3) Align
    print("[3/8] Alignment")
    # simulate_misalignment and align_slices_translation write into new arrays
    # and leave their input untouched, so vol_f needs no defensive copy.
    vol_for_align = vol_f
    if cfg.simulate_misalignment:
        print("Simulating misalignment, then correcting it")
        vol_for_align = simulate_misalignment(vol_for_align, cfg.seed)

    aligned, shifts = align_slices_translation(vol_for_align)
    # Drop the simulated (misaligned) copy before segmentation to lower peak
    # memory; vol_f is still referenced and is needed for the overview figure.
    del vol_for_align

    # store aligned in chosen dtype (no extra copy when it already matches)
    store_dtype = np.float16 if cfg.aligned_dtype == "float16" else np.float32
    aligned_path = paths["data"] / f"em_aligned_{cfg.aligned_dtype}_zyx.npy"
    np.save(aligned_path, aligned.astype(store_dtype, copy=False))
    save_json(paths["tables"] / "alignment_shifts.json", {"shifts": shifts})

    # 4) Segmentation
    print("[4/8] Segmentation (watershed baseline + 3D stitching)")
    # use float32 for computation to avoid numerical issues, even if stored float16
    aligned_f32 = aligned.astype(np.float32, copy=False)
    labels = segment_neurites_watershed(aligned_f32)
    np.save(paths["data"] / "neurite_labels_i32_zyx.npy", labels)

    # 5) Synapse candidates
    print("[5/8] Synapse candidate detection (heuristic)")
    syn_df = detect_synapse_candidates(aligned_f32, labels)
    syn_df.to_csv(paths["tables"] / "synapse_candidates.csv", index=False)

    # 6) Graph
    print("[6/8] Build wiring graph")
    G = build_graph(syn_df, min_score=0.10)
    stats_df = graph_stats(G)
    stats_df.to_csv(paths["tables"] / "graph_node_stats.csv", index=False)

    # 7) Figures
    print("[7/8] Saving figures")
    # For figures, use the float32 aligned slice for display quality
    plot_overview(vol_u8, vol_f, aligned_f32, labels, syn_df, paths["figures"])
    plot_graph(G, stats_df, paths["figures"])
    plot_alignment_shifts(shifts, paths["figures"])

    # 8) Log run
    print("[8/8] Logging run")
    log_run(db_path, run_id, cfg, meta)

    print("\nDone.")
    print(f"Outputs written to, {cfg.outdir.resolve()}")
    print(f"Figures, {paths['figures'].resolve()}")
    print(f"Tables, {paths['tables'].resolve()}")
    print(f"Aligned volume saved as, {aligned_path.resolve()}")


# Entry point. A Jupyter cell also runs with __name__ == "__main__", so the
# pipeline executes when the cell runs; parse_args() uses parse_known_args, so
# the kernel's own "-f <kernel.json>" argument is ignored.
if __name__ == "__main__":
    main()


Planned run settings
Dataset URI, bossdb://mulcahy2022/1h_L1/em
Cutout (X,Y,Z), (256, 256, 64)
Aligned dtype, float32
save_minimal, False

Approximate single-volume sizes
Raw uint8 volume, 0.004 GB
Preprocessed float32 volume, 0.016 GB
Aligned float32 volume, 0.016 GB

[1/8] Downloading cutout from: bossdb://mulcahy2022/1h_L1/em




[2/8] Preprocessing (denoise + CLAHE)
[3/8] Alignment
Simulating misalignment, then correcting it
[4/8] Segmentation (watershed baseline + 3D stitching)


Segmenting slices: 100%|██████████| 64/64 [00:00<00:00, 70.71it/s]
Stitching 3D IDs: 100%|██████████| 63/63 [00:00<00:00, 92.60it/s]


[5/8] Synapse candidate detection (heuristic)


Detecting synapse candidates: 100%|██████████| 64/64 [00:05<00:00, 11.76it/s]


[6/8] Build wiring graph
[7/8] Saving figures
[8/8] Logging run

Done.
Outputs written to, /Users/petalc01/Connectomics C.elegans EM cutout/outputs
Figures, /Users/petalc01/Connectomics C.elegans EM cutout/outputs/figures
Tables, /Users/petalc01/Connectomics C.elegans EM cutout/outputs/tables
Aligned volume saved as, outputs/data/em_aligned_float32_zyx.npy
