
# Main Text Figures/Tables (One-Click)

This notebook generates the **6 main-text items** for the manuscript:

1. **Figure 1** workflow (Input -> Inference -> Bootstrap -> Report)
2. **Figure 2** demographic curves for 4 models (500 Mb simulations): True vs Rust vs C
3. **Table 1** RMSE(log10 Ne) on 4 models
4. **Figure 3** single-threaded runtime and peak RSS (10 repeats, mean ± SD)
5. **Table 2** numeric summary of Figure 3
6. **Figure 4** two-panel bootstrap comparison (zigzag + bottleneck, Rust vs C)

Outputs are written to:
- `experiment/runs/main_text/figures` (`.png`, `.svg`, `.pdf`)
- `experiment/runs/main_text/tables` (`.csv`, `.tsv`, `.md`)
- `experiment/runs/main_text/logs` (command logs)

> Fairness note: all Rust commands in this notebook explicitly pass `--smooth-lambda 0`; the only exception is variant B of the S1 smoothing ablation.


In [None]:

from __future__ import annotations

import json
import math
import os
import platform
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.patheffects as pe
from matplotlib.ticker import FuncFormatter, MaxNLocator, ScalarFormatter
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np
import pandas as pd

try:
    from IPython.display import display
except Exception:
    # Outside IPython/Jupyter: degrade rich display to plain printing.
    def display(x):
        print(x)

# psutil is optional; without it run_cmd cannot sample memory and reports NaN peak RSS.
HAS_PSUTIL = False
try:
    import psutil
    HAS_PSUTIL = True
except Exception:
    pass


# Shared palette: one color per curve/tool plus neutral grid/axis/text tones.
COLORS = dict(
    true="#3B82F6",
    rust="#E15759",
    rust_alt="#F28E2B",
    c="#59A14F",
    mhs="#4E79A7",
    vcf="#B6992D",
    grid="#DCE3EC",
    axis="#425466",
    text="#1F2A37",
)


def setup_publication_style():
    """Install the notebook-wide matplotlib rcParams used by every exported figure.

    Sets a light panel background with a soft grid, frameless legends, DejaVu Sans
    text and outward ticks. Font-embedding settings keep text as text in exports:
    Type 42 (TrueType) in PDF/PS and unconverted text in SVG.
    """
    surfaces = {
        "figure.facecolor": "white",
        "axes.facecolor": "#F9FBFD",
        "savefig.facecolor": "white",
    }
    frame = {
        "axes.edgecolor": "#B8C4D6",
        "axes.linewidth": 1.0,
        "axes.labelcolor": COLORS["axis"],
        "axes.titlecolor": COLORS["text"],
        "axes.titleweight": "semibold",
    }
    grid = {
        "grid.color": COLORS["grid"],
        "grid.alpha": 0.9,
        "grid.linestyle": "-",
        "grid.linewidth": 0.8,
    }
    typography = {
        "legend.frameon": False,
        "legend.fontsize": 10,
        "legend.title_fontsize": 10,
        "font.size": 10.5,
        "font.family": "DejaVu Sans",
    }
    # x and y ticks share identical settings.
    ticks = {}
    for axis_name in ("x", "y"):
        ticks[f"{axis_name}tick.color"] = COLORS["axis"]
        ticks[f"{axis_name}tick.direction"] = "out"
        ticks[f"{axis_name}tick.major.width"] = 1.0
        ticks[f"{axis_name}tick.minor.width"] = 0.8
    font_embedding = {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "svg.fonttype": "none",
    }
    plt.rcParams.update({**surfaces, **frame, **grid, **typography, **ticks, **font_embedding})


def panel_labels(axes, labels="ABCDEFGHIJKLMNOPQRSTUVWXYZ"):
    """Stamp a bold panel letter in the top-left corner of each axes.

    ``axes[i]`` receives ``labels[i]``; more axes than labels raises IndexError.
    Placement uses axes-fraction coordinates, so it does not depend on data limits.
    """
    for idx, ax in enumerate(axes):
        letter = labels[idx]
        ax.text(
            0.01, 0.99, letter,
            transform=ax.transAxes,
            ha="left", va="top",
            fontsize=12, fontweight="bold", color=COLORS["text"],
            bbox=dict(boxstyle="round,pad=0.18", fc="white", ec="#D0D8E5", lw=0.8),
        )


def stylize_axis(ax, *, xlog: bool = False):
    """Apply the shared panel look to ``ax``.

    Optionally switches x to log scale. Draws major and fainter minor grids, hides
    the top and right spines, and tints the remaining spines. The y tick labels
    are formatted as thousands-separated integers.
    """
    if xlog:
        ax.set_xscale("log")
    ax.grid(True, which="major")
    ax.grid(True, which="minor", alpha=0.35, linewidth=0.55)
    for hidden in ("top", "right"):
        ax.spines[hidden].set_visible(False)
    for tinted in ("left", "bottom"):
        ax.spines[tinted].set_color("#9AA7B6")
    ax.tick_params(axis="both", which="major", labelsize=10)
    ax.yaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f"{value:,.0f}"))


def find_repo_root(start: Path) -> Path:
    """Walk upward from ``start`` to the first directory holding both Cargo.toml and src/.

    ``start`` itself is checked first. Raises RuntimeError if no ancestor qualifies.
    """
    root = next(
        (
            candidate
            for candidate in (start, *start.parents)
            if (candidate / "Cargo.toml").exists() and (candidate / "src").exists()
        ),
        None,
    )
    if root is None:
        raise RuntimeError(f"Cannot locate psmc-rs root from {start}")
    return root


# Every artifact of this notebook lives under experiment/runs/main_text inside the
# repository found from the current working directory.
ROOT = find_repo_root(Path.cwd().resolve())
RUN_DIR = ROOT.joinpath("experiment", "runs", "main_text")
INPUT_DIR = RUN_DIR.joinpath("inputs")      # simulated .psmcfa inputs
OUTPUT_DIR = RUN_DIR.joinpath("outputs")    # main Rust / C results
PERF_DIR = RUN_DIR.joinpath("perf")         # per-repeat timing outputs
BOOT_DIR = RUN_DIR.joinpath("bootstrap")
FIG_DIR = RUN_DIR.joinpath("figures")
TABLE_DIR = RUN_DIR.joinpath("tables")
LOG_DIR = RUN_DIR.joinpath("logs")

for d in (RUN_DIR, INPUT_DIR, OUTPUT_DIR, PERF_DIR, BOOT_DIR, FIG_DIR, TABLE_DIR, LOG_DIR):
    d.mkdir(parents=True, exist_ok=True)

# External executables/scripts, each overridable by an environment variable of the
# same name. Defaults expect the C psmc checkout ("psmc-master") next to this repo.
PSMC_RS_BIN = Path(os.getenv("PSMC_RS_BIN", str(ROOT.joinpath("target", "release", "psmc-rs"))))
C_PSMC_BIN = Path(os.getenv("C_PSMC_BIN", str(ROOT.parent.joinpath("psmc-master", "psmc"))))
C_UTILS_DIR = Path(os.getenv("C_UTILS_DIR", str(ROOT.parent.joinpath("psmc-master", "utils"))))
SPLITFA_BIN = Path(os.getenv("SPLITFA_BIN", str(C_UTILS_DIR.joinpath("splitfa"))))
SIM_SCRIPT = Path(os.getenv("SIM_SCRIPT", str(ROOT.joinpath("experiment", "scripts", "simulate_msprime_to_psmcfa.py"))))

# Inference and scaling settings consumed by run_rust / run_c and the curve converters.
MU = float(os.getenv("MU", "2.5e-8"))
GEN_YEARS = float(os.getenv("GEN_YEARS", "25"))
BIN_SIZE = int(os.getenv("BIN_SIZE", "100"))
N_ITER = int(os.getenv("N_ITER", "20"))
T_MAX = float(os.getenv("T_MAX", "15"))
N_STEPS = int(os.getenv("N_STEPS", "64"))
PATTERN = os.getenv("PATTERN", "4+25*2+4+6")
RHO_T_RATIO = int(os.getenv("RHO_T_RATIO", "5"))
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "300000"))
RUST_THREADS = int(os.getenv("RUST_THREADS", "1"))
ALPHA_CACHE_MB = int(os.getenv("PSMC_ALPHA_CACHE_MB", "2048"))  # exported to psmc-rs as PSMC_ALPHA_CACHE_MB

# Simulation settings; SIM_RECOMB is None unless a non-blank value is provided.
SIM_LENGTH_BP = int(os.getenv("SIM_LENGTH_BP", "500000000"))
SIM_WINDOW_BP = int(os.getenv("SIM_WINDOW_BP", "100"))
SIM_MUTATION = float(os.getenv("SIM_MUTATION", str(MU)))
SIM_RECOMB = os.getenv("SIM_RECOMB", "").strip() or None

# Benchmark and bootstrap settings.
PERF_REPEATS = int(os.getenv("PERF_REPEATS", "10"))
BOOTSTRAP_REPS = int(os.getenv("BOOTSTRAP_REPS", "100"))
BOOTSTRAP_ITERS = int(os.getenv("BOOTSTRAP_ITERS", str(N_ITER)))
BOOTSTRAP_BLOCK_SIZE = int(os.getenv("BOOTSTRAP_BLOCK_SIZE", "50000"))
BOOTSTRAP_SEED = int(os.getenv("BOOTSTRAP_SEED", "42"))
BOOT_MODELS = ["zigzag", "bottleneck"]

# Cache overrides: set to "1" to recompute instead of reusing existing outputs.
FORCE_SIM = os.getenv("FORCE_SIM", "0") == "1"
FORCE_RUN = os.getenv("FORCE_RUN", "0") == "1"
FORCE_PERF = os.getenv("FORCE_PERF", "0") == "1"
FORCE_BOOTSTRAP = os.getenv("FORCE_BOOTSTRAP", "0") == "1"

# Simulation scenarios. For "ms_piecewise" truths each event is (t, ratio): from
# time t (in units of 4 * ne0 generations) Ne becomes ratio * ne0, as interpreted
# by true_curve_ms_piecewise.
MODELS: Dict[str, Dict] = dict(
    constant=dict(
        title="Constant",
        sim_model="constant",
        sim_ne=10_000.0,
        sim_seed=42,
        true_kind="constant",
        true_params=dict(ne=10_000.0),
    ),
    bottleneck=dict(
        title="Bottleneck",
        sim_model="bottleneck",
        sim_ne=20_000.0,
        sim_seed=43,
        true_kind="ms_piecewise",
        true_params=dict(
            ne0=20_000.0,
            events=[(0.01, 0.05), (0.015, 0.5), (0.05, 0.25), (0.5, 0.5)],
        ),
    ),
    expansion=dict(
        title="Expansion",
        sim_model="expansion",
        sim_ne=10_000.0,
        sim_seed=44,
        true_kind="ms_piecewise",
        true_params=dict(
            ne0=10_000.0,
            events=[(0.01, 0.1), (0.06, 1.0), (0.2, 0.5), (1.0, 1.0), (2.0, 2.0)],
        ),
    ),
    zigzag=dict(
        title="Zigzag",
        sim_model="sim2_zigzag",
        sim_ne=1_000.0,
        sim_seed=45,
        true_kind="ms_piecewise",
        true_params=dict(
            ne0=1_000.0,
            events=[(0.1, 5.0), (0.6, 20.0), (2.0, 5.0), (10.0, 10.0), (20.0, 5.0)],
        ),
    ),
)
# Iteration order for run loops and figure panels.
MODEL_ORDER = ["constant", "bottleneck", "expansion", "zigzag"]

setup_publication_style()

# Echo the resolved configuration so the notebook output records what each run used.
print(
    "\n".join(
        f"{label}={value}"
        for label, value in (
            ("ROOT", ROOT),
            ("RUN_DIR", RUN_DIR),
            ("PSMC_RS_BIN", PSMC_RS_BIN),
            ("C_PSMC_BIN", C_PSMC_BIN),
            ("SPLITFA_BIN", SPLITFA_BIN),
            ("SIM_SCRIPT", SIM_SCRIPT),
        )
    )
)
print(f"SIM_LENGTH_BP={SIM_LENGTH_BP:,}, SIM_WINDOW_BP={SIM_WINDOW_BP}")
print(f"N_ITER={N_ITER}, PERF_REPEATS={PERF_REPEATS}, BOOTSTRAP_REPS={BOOTSTRAP_REPS}")
print(f"FORCE_SIM={FORCE_SIM}, FORCE_RUN={FORCE_RUN}, FORCE_PERF={FORCE_PERF}, FORCE_BOOTSTRAP={FORCE_BOOTSTRAP}")


In [None]:

def save_figure_multi(fig, stem: str):
    """Save ``fig`` as ``FIG_DIR/<stem>.png``, ``.svg`` and ``.pdf``, then print the paths.

    Only the PNG gets an explicit dpi (320); the vector formats pass ``dpi=None``.
    Every file is cropped with ``bbox_inches="tight"`` and tagged with a Creator entry.
    """
    written_paths = []
    for suffix in ("png", "svg", "pdf"):
        target = FIG_DIR / f"{stem}.{suffix}"
        raster_dpi = 320 if suffix == "png" else None
        fig.savefig(
            target,
            dpi=raster_dpi,
            bbox_inches="tight",
            metadata={"Creator": "psmc-rs experiment/notebooks/main_text_6items.ipynb"},
        )
        written_paths.append(target)
    print("saved figure:", ", ".join(map(str, written_paths)))


def save_table_multi(df: pd.DataFrame, stem: str):
    csv_p = TABLE_DIR / f"{stem}.csv"
    tsv_p = TABLE_DIR / f"{stem}.tsv"
    md_p = TABLE_DIR / f"{stem}.md"
    df.to_csv(csv_p, index=False)
    df.to_csv(tsv_p, index=False, sep="	")
    md_p.write_text(df.to_markdown(index=False) + "\n", encoding="utf-8")
    print(f"saved table: {csv_p}, {tsv_p}, {md_p}")


def run_cmd(cmd: List[str], cwd: Optional[Path] = None, check: bool = True, env: Optional[dict] = None, stdout_path: Optional[Path] = None):
    """Run ``cmd`` to completion, measure wall time and peak RSS, and log the result.

    Each call appends one JSON record to ``LOG_DIR / "commands.jsonl"``. The record
    holds the quoted command line, return code, captured stdout/stderr, wall seconds
    and peak RSS in MiB.

    Args:
        cmd: argv list, executed without a shell.
        cwd: working directory for the child (None = inherit).
        check: if true, a non-zero exit prints the captured output and raises
            RuntimeError (after the record has been logged).
        env: full environment for the child (None = inherit).
        stdout_path: if given, the child's stdout goes to this file (parent dirs
            are created) and the logged ``stdout`` is "".

    Returns:
        The logged record dict. ``peak_rss_mb`` is NaN when psutil is unavailable.

    Peak RSS is sampled about every 20 ms over the process and all its
    descendants, so very short-lived peaks can be missed.
    """
    import threading  # local: only run_cmd needs background pipe readers

    def _drain(stream, sink: List[str], failures: List[BaseException]) -> None:
        # Read a pipe to EOF on a helper thread. Otherwise a child that writes more
        # than the OS pipe buffer blocks forever while the polling loop below
        # waits for it to exit.
        try:
            sink.append(stream.read())
        except BaseException as exc:  # re-raised on the main thread after join
            failures.append(exc)
        finally:
            stream.close()

    t0 = time.perf_counter()
    stdout_fh = None
    if stdout_path is not None:
        stdout_path.parent.mkdir(parents=True, exist_ok=True)
        stdout_fh = open(stdout_path, "w", encoding="utf-8")

    out_chunks: List[str] = []
    err_chunks: List[str] = []
    failures: List[BaseException] = []
    peak_bytes = 0
    ps_proc = None
    try:
        proc = subprocess.Popen(
            cmd,
            cwd=str(cwd) if cwd is not None else None,
            stdout=stdout_fh if stdout_fh is not None else subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            env=env,
        )
        try:
            readers = [threading.Thread(target=_drain, args=(proc.stderr, err_chunks, failures), daemon=True)]
            if proc.stdout is not None:  # None when stdout is redirected to a file
                readers.append(threading.Thread(target=_drain, args=(proc.stdout, out_chunks, failures), daemon=True))
            for reader in readers:
                reader.start()

            if HAS_PSUTIL:
                try:
                    ps_proc = psutil.Process(proc.pid)
                except Exception:
                    ps_proc = None  # the process may already have exited

            while proc.poll() is None:
                if ps_proc is not None:
                    try:
                        rss = ps_proc.memory_info().rss
                        for ch in ps_proc.children(recursive=True):
                            try:
                                rss += ch.memory_info().rss
                            except Exception:
                                pass  # child exited between listing and sampling
                        peak_bytes = max(peak_bytes, rss)
                    except Exception:
                        pass  # best-effort sampling; the process may have just exited
                time.sleep(0.02)

            for reader in readers:
                reader.join()
            if failures:
                raise failures[0]
        except BaseException:
            # e.g. KeyboardInterrupt in the notebook: do not leave an orphaned child running.
            proc.kill()
            proc.wait()
            raise
    finally:
        if stdout_fh is not None:
            stdout_fh.close()

    out = "".join(out_chunks) if stdout_fh is None else ""
    err = "".join(err_chunks)

    dt = time.perf_counter() - t0
    peak_rss_mb = float("nan")
    if HAS_PSUTIL and ps_proc is not None:
        peak_rss_mb = peak_bytes / (1024**2)

    rec = {
        "cmd": " ".join(shlex.quote(x) for x in cmd),
        "returncode": proc.returncode,
        "stdout": out,
        "stderr": err,
        "wall_sec": dt,
        "peak_rss_mb": peak_rss_mb,
    }

    log_path = LOG_DIR / "commands.jsonl"
    with log_path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    if check and proc.returncode != 0:
        print(rec["cmd"])
        print("--- stdout ---")
        print(out)
        print("--- stderr ---")
        print(err)
        raise RuntimeError(f"command failed: rc={proc.returncode}")
    return rec


def ensure_tools() -> bool:
    """Verify the simulator and psmc-rs binary are present, building psmc-rs if needed.

    Returns True when the reference C ``psmc`` binary also exists. A missing C
    binary only prints a warning, so Rust-only runs can proceed.

    Raises FileNotFoundError if SIM_SCRIPT is missing, or if PSMC_RS_BIN is still
    absent after ``cargo build --release``.
    """
    if not SIM_SCRIPT.exists():
        raise FileNotFoundError(f"simulation script not found: {SIM_SCRIPT}")

    if not PSMC_RS_BIN.exists():
        print("[build] cargo build --release")
        run_cmd(["cargo", "build", "--release"], cwd=ROOT)
    if not PSMC_RS_BIN.exists():
        raise FileNotFoundError(f"psmc-rs binary not found: {PSMC_RS_BIN}")

    if C_PSMC_BIN.exists():
        return True
    print(f"[warn] C binary missing: {C_PSMC_BIN}")
    return False


def ensure_splitfa() -> Path:
    """Return SPLITFA_BIN, compiling it from ``C_UTILS_DIR/splitfa.c`` if it is absent.

    The compiler runs with C_UTILS_DIR as its working directory, so ``-I..``
    resolves to that directory's parent.

    Raises FileNotFoundError if splitfa.c is missing, and RuntimeError if the
    build finishes without producing the binary.
    """
    if not SPLITFA_BIN.exists():
        source = C_UTILS_DIR / "splitfa.c"
        if not source.exists():
            raise FileNotFoundError(f"splitfa.c not found: {source}")
        print(f"[build] cc -O3 -I.. -o {SPLITFA_BIN} {source} -lm -lz")
        compile_cmd = ["cc", "-O3", "-I..", "-o", str(SPLITFA_BIN), str(source), "-lm", "-lz"]
        run_cmd(compile_cmd, cwd=C_UTILS_DIR)
        if not SPLITFA_BIN.exists():
            raise RuntimeError("failed to build splitfa")
    return SPLITFA_BIN


def model_input_path(model_key: str) -> Path:
    """Path of the simulated ``.psmcfa`` input for ``model_key`` inside INPUT_DIR."""
    return INPUT_DIR.joinpath("{}.psmcfa".format(model_key))


def rust_main_json_path(model_key: str) -> Path:
    """Path of the main psmc-rs JSON result for ``model_key`` inside OUTPUT_DIR."""
    return OUTPUT_DIR.joinpath("{}.rust.main.json".format(model_key))


def c_main_psmc_path(model_key: str) -> Path:
    """Path of the main C psmc ``.psmc`` output for ``model_key`` inside OUTPUT_DIR."""
    return OUTPUT_DIR.joinpath("{}.c.main.psmc".format(model_key))


def simulate_inputs(force: bool = False):
    """Create one simulated ``.psmcfa`` input per model in MODEL_ORDER via SIM_SCRIPT.

    Existing inputs are reused unless ``force`` is true. ``--recomb`` is passed
    only when SIM_RECOMB is set, and ``--ne`` only when the model spec has a truthy
    ``sim_ne``. Each simulation runs through run_cmd from ROOT.
    """
    for key in MODEL_ORDER:
        spec = MODELS[key]
        out_path = model_input_path(key)
        if out_path.exists() and not force:
            continue

        options = [
            ("--model", spec["sim_model"]),
            ("--out", str(out_path)),
            ("--length", str(SIM_LENGTH_BP)),
            ("--window", str(SIM_WINDOW_BP)),
            ("--mutation", str(SIM_MUTATION)),
            ("--seed", str(spec["sim_seed"])),
        ]
        if SIM_RECOMB:
            options.append(("--recomb", str(SIM_RECOMB)))
        if spec.get("sim_ne"):
            options.append(("--ne", str(spec["sim_ne"])))

        cmd = [sys.executable, str(SIM_SCRIPT)] + [token for pair in options for token in pair]
        print("[simulate]", key)
        run_cmd(cmd, cwd=ROOT)


def run_rust(input_psmcfa: Path, output_json: Path, n_iter: int = N_ITER, extra: Optional[List[str]] = None):
    """Run psmc-rs on ``input_psmcfa`` for ``n_iter`` iterations and write ``output_json``.

    Always passes T_MAX, N_STEPS, PATTERN, MU, BATCH_SIZE and RUST_THREADS. Smoothing
    is disabled with ``--smooth-lambda 0`` and progress output is suppressed. Any
    ``extra`` arguments are appended last. The child inherits the current
    environment plus ``PSMC_ALPHA_CACHE_MB``.

    Returns the run_cmd record; a non-zero exit raises RuntimeError.
    """
    positional = [str(PSMC_RS_BIN), str(input_psmcfa), str(output_json), str(n_iter)]
    options = [
        "--t-max", str(T_MAX),
        "--n-steps", str(N_STEPS),
        "--pattern", PATTERN,
        "--mu", str(MU),
        "--smooth-lambda", "0",
        "--batch-size", str(BATCH_SIZE),
        "--threads", str(RUST_THREADS),
        "--no-progress",
    ]
    cmd = positional + options + list(extra or [])

    child_env = dict(os.environ)
    child_env["PSMC_ALPHA_CACHE_MB"] = str(ALPHA_CACHE_MB)
    return run_cmd(cmd, cwd=ROOT, env=child_env)


def run_c(input_psmcfa: Path, output_psmc: Path, n_iter: int = N_ITER, bootstrap: bool = False):
    """Run the reference C psmc with the same iteration count, t_max and pattern as run_rust.

    The argv is ``-N{n_iter} -t{T_MAX} -r{RHO_T_RATIO} -p PATTERN``, plus ``-b``
    when ``bootstrap`` is true, then ``-o output_psmc input_psmcfa``. Returns the
    run_cmd record.
    """
    flags = [f"-N{n_iter}", f"-t{T_MAX}", f"-r{RHO_T_RATIO}", "-p", PATTERN]
    if bootstrap:
        flags = flags + ["-b"]
    return run_cmd([str(C_PSMC_BIN), *flags, "-o", str(output_psmc), str(input_psmcfa)], cwd=ROOT)


In [None]:

def parse_pattern_spec(pattern):
    """Parse a pattern such as ``"4+25*2+4+6"`` into ``(count, width)`` pairs.

    ``a*b`` yields ``(a, b)``: ``a`` grouped values, each spread over ``b`` atomic
    intervals (see expand_lam). A bare ``b`` is shorthand for ``1*b``. Empty terms
    are ignored. Returns None when ``pattern`` is None or contains no terms.
    """
    if pattern is None:
        return None
    spec = []
    for token in (piece.strip() for piece in str(pattern).split("+")):
        if not token:
            continue
        count, star, width = token.partition("*")
        if star:
            spec.append((int(count.strip()), int(width.strip())))
        else:
            spec.append((1, int(token)))
    return spec or None


def parse_pattern_spec_legacy(pattern):
    """Parse a pattern using the older reading of its terms.

    ``a*b`` still yields ``(a, b)``, but a bare ``a`` yields ``(a, 1)``: ``a``
    single-width values rather than one value ``a`` intervals wide. Empty terms
    are ignored. Returns None when ``pattern`` is None or contains no terms.
    """
    if pattern is None:
        return None
    tokens = [raw.strip() for raw in str(pattern).split("+")]
    pairs = []
    for tok in filter(None, tokens):
        if "*" in tok:
            head, tail = tok.split("*", 1)
            pairs.append((int(head.strip()), int(tail.strip())))
        else:
            pairs.append((int(tok), 1))
    return pairs or None


def expand_lam(lam_grouped, n_steps, pattern_spec, pattern_raw=None):
    """Expand grouped lambda values to one value per atomic interval (``n_steps + 1`` total).

    Three layouts are tried in order:
      1. ``pattern_spec`` is None: the values must already number ``n_steps + 1``.
      2. ``len(lam_grouped)`` equals the total count in ``pattern_spec``: each
         value is repeated by its group width.
      3. The legacy reading of ``pattern_raw``, which carries one extra trailing
         value that is appended as-is.

    Raises ValueError when no layout produces exactly ``n_steps + 1`` values.
    """
    values = [float(v) for v in lam_grouped]
    target = n_steps + 1

    def _spread(pairs):
        # Walk the values in order; each (count, width) pair consumes `count`
        # values and repeats each of them `width` times.
        expanded, cursor = [], 0
        for count, width in pairs:
            for _ in range(count):
                expanded += [values[cursor]] * width
                cursor += 1
        return expanded

    if pattern_spec is None:
        if len(values) != target:
            raise ValueError(f"lam length {len(values)} != n_steps+1 ({target})")
        return values

    if len(values) == sum(count for count, _ in pattern_spec):
        expanded = _spread(pattern_spec)
        if len(expanded) != target:
            raise ValueError(f"expanded lam length {len(expanded)} != n_steps+1 ({target})")
        return expanded

    legacy = parse_pattern_spec_legacy(pattern_raw)
    if legacy is not None and len(values) == sum(ts for ts, _ in legacy) + 1:
        expanded = _spread(legacy) + [values[-1]]
        if len(expanded) != target:
            raise ValueError(f"expanded legacy lam length {len(expanded)} != n_steps+1 ({target})")
        return expanded

    raise ValueError("grouped lam length mismatch with pattern")


def compute_t_grid(n_steps: int, t_max: float, alpha: float = 0.1):
    """Return the ``n_steps + 1`` PSMC interval boundaries as a float array.

    For k < n_steps, t_k = alpha * (exp(beta * k) - 1) with
    beta = log(1 + t_max / alpha) / n_steps. The final boundary is set to
    ``t_max`` exactly, so spacing grows roughly exponentially from 0.
    """
    beta = math.log(1 + t_max / alpha) / n_steps

    def boundary(k: int) -> float:
        return alpha * (math.exp(beta * k) - 1.0)

    edges = list(map(boundary, range(n_steps)))
    edges.append(float(t_max))
    return np.array(edges, dtype=float)


def curve_from_json(path: Path):
    """Load a psmc-rs JSON result and convert it to a (years, Ne) step curve.

    Scaling: N0 = theta / (4 * mu * BIN_SIZE), where mu is taken from the JSON
    when present and MU otherwise. Years are 2 * N0 * t * GEN_YEARS and
    Ne = N0 * lambda. A final point at 1e8 years repeats the last Ne, so the
    curve covers the plotting range.
    """
    params = json.loads(path.read_text())
    theta = float(params["theta"])
    mu = float(params.get("mu", MU))
    n_steps = int(params["n_steps"])
    t_max = float(params["t_max"])
    pattern_raw = params.get("pattern")
    lam = np.asarray(
        expand_lam(params["lam"], n_steps, parse_pattern_spec(pattern_raw), pattern_raw),
        dtype=float,
    )

    n0 = theta / (4.0 * mu * float(BIN_SIZE))
    years = compute_t_grid(n_steps, t_max) * 2.0 * float(GEN_YEARS) * n0
    ne = lam * n0
    years = np.concatenate([years, [1e8]])
    ne = np.concatenate([ne, [ne[-1]]])
    return np.asarray(years, dtype=float), np.asarray(ne, dtype=float)


def load_c_curve(psmc_path: Path):
    """Parse C psmc output and return its last complete iteration as a (years, Ne) curve.

    The file is split into blocks at ``RD`` lines. Within a block:
      - ``TR`` gives theta (first value) and rho;
      - ``PA`` must be present;
      - each ``RS`` row provides (k, t_k, lambda_k).
    The last block that has TR, PA and RS lines is used.

    Scaling uses the module-level MU: N0 = theta / (4 * MU * BIN_SIZE),
    years = 2 * N0 * t_k * GEN_YEARS, and Ne = N0 * lambda_k. A final point at
    1e8 years repeats the last Ne.

    Returns None if the file is missing or no block is complete.
    """
    if not psmc_path.exists():
        return None

    blocks = []
    for line in psmc_path.read_text().splitlines():
        if line.startswith("RD\t"):
            blocks.append({"tr": None, "pa": None, "rs": []})
            continue
        if not blocks:
            continue  # anything before the first RD line is ignored
        current = blocks[-1]
        if line.startswith("TR\t"):
            _, th, rh = line.split("\t")[:3]
            current["tr"] = (float(th), float(rh))
        elif line.startswith("PA\t"):
            current["pa"] = line
        elif line.startswith("RS\t"):
            fields = line.split("\t")
            current["rs"].append((int(fields[1]), float(fields[2]), float(fields[3])))

    best = next(
        (blk for blk in reversed(blocks) if blk["pa"] and blk["tr"] is not None and blk["rs"]),
        None,
    )
    if best is None:
        return None

    n0 = best["tr"][0] / (4.0 * float(MU) * float(BIN_SIZE))
    years = [2.0 * n0 * t_k * float(GEN_YEARS) for _k, t_k, _lam in best["rs"]]
    ne = [n0 * lam_k for _k, _t, lam_k in best["rs"]]
    years.append(1e8)
    ne.append(ne[-1])
    return np.asarray(years, dtype=float), np.asarray(ne, dtype=float)


def true_curve_constant(ne: float):
    """Flat reference curve: Ne equals ``ne`` from 1e3 to 1e8 years."""
    span = np.asarray([1e3, 1e8], dtype=float)
    level = np.asarray([ne] * 2, dtype=float)
    return span, level


def true_curve_ms_piecewise(ne0: float, events: List[Tuple[float, float]]):
    """Build the true step curve for an ms-style piecewise-constant history.

    Each event ``(t, ratio)`` sets Ne to ``ratio * ne0`` from time ``t`` onward,
    with ``t`` in units of 4 * ne0 generations. Times are converted to years with
    GEN_YEARS and clamped to at least 1e3. The curve starts at (1e3, ne0) and is
    extended flat to 1e8 years.
    """
    ordered = sorted(events, key=lambda ev: ev[0])
    years = [1e3] + [max(1e3, (t * 4.0 * ne0) * GEN_YEARS) for t, _ in ordered]
    sizes = [ne0] + [ratio * ne0 for _, ratio in ordered]
    years.append(1e8)
    sizes.append(sizes[-1])
    return np.asarray(years, dtype=float), np.asarray(sizes, dtype=float)


def true_curve_for_model(model_key: str):
    """Return the true (years, Ne) curve for ``model_key`` as described by MODELS.

    ``true_kind`` selects the builder and ``true_params`` supplies its keyword
    arguments. Raises ValueError for an unrecognised kind.
    """
    spec = MODELS[model_key]
    builders = {
        "constant": true_curve_constant,
        "ms_piecewise": true_curve_ms_piecewise,
    }
    builder = builders.get(spec["true_kind"])
    if builder is None:
        raise ValueError("unknown true_kind")
    return builder(**spec["true_params"])


def step_value(xs: np.ndarray, ys: np.ndarray, xq: float) -> float:
    """Evaluate a right-continuous step function at ``xq``.

    Returns ``ys[i]`` for the last breakpoint with ``xs[i] <= xq``. The index is
    clamped to ``[0, len(ys) - 1]``, so queries left of ``xs[0]`` return
    ``ys[0]``. ``xs`` must be sorted ascending.
    """
    pos = int(np.searchsorted(xs, xq, side="right")) - 1
    last = len(ys) - 1
    if pos > last:
        pos = last
    if pos < 0:
        pos = 0
    return float(ys[pos])


def rmse_log10(true_curve, est_curve, x_min=1e3, x_max=1e8, n=400):
    """Root-mean-square difference in log10(Ne) between two (x, y) step curves.

    Both curves are sampled at ``n`` log-spaced points in [x_min, x_max] using the
    same right-continuous, index-clamped lookup as step_value. Samples are
    floored at 1e-12 before taking log10.
    """
    tx, ty = true_curve
    ex, ey = est_curve
    grid = np.geomspace(x_min, x_max, n)

    def sample_log10(xs, ys):
        ys = np.asarray(ys)
        pos = np.searchsorted(xs, grid, side="right") - 1
        pos = np.maximum(0, np.minimum(pos, len(ys) - 1))
        return np.log10(np.maximum(ys[pos].astype(float), 1e-12))

    gap = sample_log10(tx, ty) - sample_log10(ex, ey)
    return float(np.sqrt(np.mean(gap ** 2)))


In [None]:

def make_figure1_workflow():
    """Draw Figure 1: four rounded stage boxes joined by arrows, then save and show it.

    The stages are Input -> Inference -> Bootstrap -> Report. All geometry is in
    axes-fraction coordinates on one frameless axis. The figure is saved via
    save_figure_multi as ``figure_1_workflow``.
    """
    fig, ax = plt.subplots(figsize=(12.6, 4.0), dpi=220)
    ax.axis("off")

    # (left, bottom, width, height, title, subtitle). The last box is narrower,
    # presumably so it stays inside the axes.
    stages = [
        (0.045, 0.20, 0.21, 0.58, "Input", "psmcfa / mhs / vcf"),
        (0.295, 0.20, 0.21, 0.58, "Inference", "E-step + M-step (EM)"),
        (0.545, 0.20, 0.21, 0.58, "Bootstrap", "block resampling + CI"),
        (0.795, 0.20, 0.16, 0.58, "Report", "JSON + HTML + TMRCA"),
    ]
    # (from_x, to_x): right edge of one box to the left edge of the next.
    connectors = [(0.255, 0.295), (0.505, 0.545), (0.755, 0.795)]

    for left, bottom, width, height, title, subtitle in stages:
        ax.add_patch(
            mpatches.FancyBboxPatch(
                (left, bottom), width, height,
                boxstyle="round,pad=0.02,rounding_size=0.025",
                linewidth=1.4,
                edgecolor="#87A0BC",
                facecolor="#F7FAFE",
                transform=ax.transAxes,
            )
        )
        center_x = left + width / 2
        # Title in the upper part of the box, subtitle below it.
        text_rows = (
            (0.64, title, {"fontsize": 12.5, "weight": "bold", "color": COLORS["text"]}),
            (0.36, subtitle, {"fontsize": 10.3, "color": COLORS["axis"]}),
        )
        for frac, label, font in text_rows:
            ax.text(
                center_x,
                bottom + height * frac,
                label,
                ha="center",
                va="center",
                transform=ax.transAxes,
                **font,
            )

    for start, end in connectors:
        # 0.008 inset keeps arrow tips off the box borders.
        ax.annotate(
            "",
            xy=(end - 0.008, 0.49),
            xytext=(start + 0.008, 0.49),
            xycoords=ax.transAxes,
            arrowprops={"arrowstyle": "-|>", "lw": 2.0, "color": "#5C7EA7"},
        )

    ax.set_title("Figure 1. PSMC-RS End-to-End Workflow", fontsize=15.5, fontweight="bold", pad=10)
    save_figure_multi(fig, "figure_1_workflow")
    plt.show()


def _step_with_outline(ax, x, y, *, color, lw=2.2, ls="-", label=None, zorder=3, alpha=1.0):
    """Draw a ``where="post"`` step line with a white halo and return the Line2D.

    The halo is a stroke ``lw + 1.6`` wide drawn under the line, which keeps
    overlapping curves distinguishable.
    """
    drawn = ax.step(x, y, where="post", color=color, lw=lw, ls=ls, label=label, zorder=zorder, alpha=alpha)
    line = drawn[0]
    halo = pe.Stroke(linewidth=lw + 1.6, foreground="white", alpha=0.92)
    line.set_path_effects([halo, pe.Normal()])
    return line


def _plot_step_change_points(ax, x, y, color, zorder=4):
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) < 3 or len(y) < 3:
        return
    delta = np.abs(np.diff(y))
    idx = np.where(delta > 1e-9)[0]
    if len(idx) == 0:
        return
    xp = x[idx + 1]
    yp = y[idx + 1]
    ax.scatter(
        xp,
        yp,
        s=15,
        facecolor="white",
        edgecolor=color,
        linewidth=0.9,
        alpha=0.86,
        zorder=zorder,
    )


def _compute_log10_curve(curve, x_grid):
    """Sample an (x, y) step curve at each point of ``x_grid`` and return log10 of the values.

    Lookups use step_value; values are floored at 1e-12 before the log.
    """
    xs, ys = curve
    sampled = np.fromiter((max(step_value(xs, ys, q), 1e-12) for q in x_grid), dtype=float)
    return np.log10(sampled)


def run_figure2_and_table1(has_c: bool, force: bool = False):
    """Produce Table 1 (RMSE of log10 Ne) and Figure 2 (true vs Rust vs C curves).

    For every model in MODEL_ORDER, psmc-rs is run on the simulated input unless
    its main output already exists and ``force`` is false. The C psmc is run the
    same way when ``has_c`` is true. Each estimate is scored against the model's
    true curve with ``rmse_log10``. The table is saved as ``table_1_rmse``, with
    rows sorted alphabetically by model name rather than in MODEL_ORDER.

    Figure layout is a 2x2 grid. Each cell holds:
      - a main step plot of Ne versus years;
      - a lower strip of log10(Ne) differences;
      - an inset zoomed on the point of largest |Rust - C| difference
        (|Rust - True| when no C curve is available).
    The figure is saved as ``figure_2_true_vs_rust_vs_c``.

    Returns:
        (table1, curves), where ``curves[model]`` has keys "true", "rust" and "c".
        Each value is an (x, y) array pair; "c" is None when the C curve is
        unavailable.
    """
    rows = []
    curves = {}

    # Stage 1: run (or reuse) each tool per model and score it against the truth.
    for key in MODEL_ORDER:
        inp = model_input_path(key)
        rust_out = rust_main_json_path(key)
        c_out = c_main_psmc_path(key)

        if force or not rust_out.exists():
            print(f"[run rust] {key}")
            run_rust(inp, rust_out, n_iter=N_ITER)

        if has_c and (force or not c_out.exists()):
            print(f"[run c] {key}")
            run_c(inp, c_out, n_iter=N_ITER)

        true_curve = true_curve_for_model(key)
        rust_curve = curve_from_json(rust_out)
        c_curve = load_c_curve(c_out) if has_c else None

        curves[key] = {
            "true": true_curve,
            "rust": rust_curve,
            "c": c_curve,
        }

        row = {
            "model": key,
            "rmse_log10_ne_rust": rmse_log10(true_curve, rust_curve),
            "rmse_log10_ne_c": rmse_log10(true_curve, c_curve) if c_curve is not None else np.nan,
        }
        rows.append(row)

    # NOTE(review): sorting by "model" orders rows alphabetically, not by MODEL_ORDER as in the figure.
    table1 = pd.DataFrame(rows).sort_values("model")
    save_table_multi(table1, "table_1_rmse")

    # Stage 2: a 2x2 grid; each cell is split into a main plot (top) and a difference strip (bottom).
    fig = plt.figure(figsize=(14.8, 12.4), dpi=220, constrained_layout=True)
    outer = fig.add_gridspec(2, 2, wspace=0.10, hspace=0.18)

    main_axes = []
    diff_axes = []

    for i, key in enumerate(MODEL_ORDER):
        rr, cc = divmod(i, 2)
        inner = outer[rr, cc].subgridspec(2, 1, height_ratios=[4.0, 1.35], hspace=0.05)
        ax = fig.add_subplot(inner[0])
        axd = fig.add_subplot(inner[1], sharex=ax)
        main_axes.append(ax)
        diff_axes.append(axd)

        cset = curves[key]
        tx, ty = cset["true"]
        rx, ry = cset["rust"]

        # Draw order: True (dashed, bottom), then C, then Rust on top.
        _step_with_outline(ax, tx, ty, color=COLORS["true"], lw=2.35, ls=(0, (5, 3)), label="True", zorder=1, alpha=0.98)

        # y_max tracks the tallest drawn curve so the linear y-axis fits all of them.
        y_max = max(np.max(ty), np.max(ry))
        if has_c and cset["c"] is not None:
            cx, cy = cset["c"]
            _step_with_outline(ax, cx, cy, color=COLORS["c"], lw=2.15, label="C", zorder=2, alpha=0.95)
            _plot_step_change_points(ax, cx, cy, COLORS["c"], zorder=3)
            y_max = max(y_max, np.max(cy))

        _step_with_outline(ax, rx, ry, color=COLORS["rust"], lw=2.2, label="Rust", zorder=4, alpha=0.98)
        _plot_step_change_points(ax, rx, ry, COLORS["rust"], zorder=5)

        ax.set_title(MODELS[key]["title"], fontsize=12.6, pad=8)
        stylize_axis(ax, xlog=True)
        ax.set_xlim(1e3, 1e8)
        ax.set_ylim(0, y_max * 1.18)
        ax.set_ylabel("Effective population size (Ne)")
        ax.tick_params(labelbottom=False)

        # Lower strip: log10 differences so overlapped curves remain readable.
        xg = np.geomspace(1e3, 1e8, 900)
        l_true = _compute_log10_curve((tx, ty), xg)
        l_rust = _compute_log10_curve((rx, ry), xg)
        d_rust = l_rust - l_true

        axd.axhline(0.0, color="#6B7280", lw=1.0, ls=(0, (4, 2)), zorder=1)
        axd.fill_between(xg, 0.0, d_rust, where=d_rust >= 0, color=COLORS["rust"], alpha=0.14, linewidth=0)
        axd.fill_between(xg, 0.0, d_rust, where=d_rust < 0, color=COLORS["rust"], alpha=0.10, linewidth=0)
        axd.plot(xg, d_rust, lw=1.9, color=COLORS["rust"], label="Rust - True", zorder=3)

        # `focus` decides where the inset zooms: |Rust - True| by default,
        # replaced by |Rust - C| when a C curve exists.
        max_abs = float(np.max(np.abs(d_rust)))
        focus = np.abs(d_rust)

        if has_c and cset["c"] is not None:
            cx, cy = cset["c"]
            l_c = _compute_log10_curve((cx, cy), xg)
            d_c = l_c - l_true
            d_rc = l_rust - l_c
            axd.plot(xg, d_c, lw=1.8, color=COLORS["c"], label="C - True", zorder=3)
            axd.plot(xg, d_rc, lw=1.2, ls=(0, (3, 2)), color="#374151", label="Rust - C", zorder=2)
            max_abs = max(max_abs, float(np.max(np.abs(d_c))), float(np.max(np.abs(d_rc))))
            focus = np.abs(d_rc)

        stylize_axis(axd, xlog=True)
        # Override stylize_axis's integer y formatter with signed two-decimal ticks.
        axd.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:+.2f}"))
        # Symmetric limits: 1.3x the largest deviation, clamped to [0.03, 0.65].
        lim = max(0.03, min(0.65, max_abs * 1.30))
        axd.set_ylim(-lim, lim)
        axd.set_xlabel(f"Years (g={GEN_YEARS}, mu={MU:.1e})")
        axd.set_ylabel("Delta log10(Ne)", fontsize=9)
        axd.tick_params(axis="both", labelsize=8.4)

        # Inset around the largest value of `focus` (Rust-vs-C, or Rust-vs-True
        # without C), spanning a factor of 2.4 in years on each side, clipped to [1e3, 1e8].
        if np.any(np.isfinite(focus)):
            peak_idx = int(np.nanargmax(focus))
            x_center = float(xg[peak_idx])
            x_left = max(1e3, x_center / 2.4)
            x_right = min(1e8, x_center * 2.4)
            if x_right > x_left * 1.05:
                axins = inset_axes(ax, width="42%", height="39%", loc="lower right", borderpad=0.9)
                _step_with_outline(axins, tx, ty, color=COLORS["true"], lw=1.35, ls=(0, (5, 3)), zorder=1)
                if has_c and cset["c"] is not None:
                    # cx, cy were bound above for this same model.
                    _step_with_outline(axins, cx, cy, color=COLORS["c"], lw=1.25, zorder=2)
                _step_with_outline(axins, rx, ry, color=COLORS["rust"], lw=1.3, zorder=3)
                axins.set_xscale("log")
                axins.set_xlim(x_left, x_right)
                # NOTE(review): the y-range uses curve values only at the two window
                # edges (padded by 12%), so levels strictly inside the window can be clipped.
                iy = [
                    step_value(tx, ty, x_left), step_value(tx, ty, x_right),
                    step_value(rx, ry, x_left), step_value(rx, ry, x_right),
                ]
                if has_c and cset["c"] is not None:
                    iy += [step_value(cx, cy, x_left), step_value(cx, cy, x_right)]
                ymin = max(0.0, min(iy) * 0.88)
                ymax = max(iy) * 1.12
                if ymax > ymin:
                    axins.set_ylim(ymin, ymax)
                axins.grid(True, which="major", alpha=0.25)
                axins.tick_params(labelleft=False, labelbottom=False, length=0)
                for spine in axins.spines.values():
                    spine.set_edgecolor("#B8C4D6")
                    spine.set_linewidth(0.9)

        # Annotate each panel with its RMSE values from Table 1.
        rr_row = table1[table1["model"] == key].iloc[0]
        rmse_text = f"RMSE Rust={rr_row['rmse_log10_ne_rust']:.3f}"
        if has_c and not np.isnan(rr_row["rmse_log10_ne_c"]):
            rmse_text += f"\nRMSE C={rr_row['rmse_log10_ne_c']:.3f}"
        ax.text(
            0.985,
            0.975,
            rmse_text,
            transform=ax.transAxes,
            ha="right",
            va="top",
            fontsize=8.1,
            color=COLORS["axis"],
            bbox={"boxstyle": "round,pad=0.22", "fc": "white", "ec": "#D2DAE6", "lw": 0.8, "alpha": 0.93},
        )

    # One shared legend for the main curves (taken from the first panel) above the grid;
    # the difference-strip legend is shown only on the first strip.
    panel_labels(main_axes)
    h_main, l_main = main_axes[0].get_legend_handles_labels()
    fig.legend(h_main, l_main, loc="upper center", ncol=3 if has_c else 2, bbox_to_anchor=(0.5, 1.035))
    if diff_axes:
        h_diff, l_diff = diff_axes[0].get_legend_handles_labels()
        diff_axes[0].legend(h_diff, l_diff, loc="lower left", fontsize=8.0)

    fig.suptitle("Figure 2. 500Mb Demographic Recovery: True vs Rust vs C", fontsize=15.5, fontweight="bold", y=1.055)
    save_figure_multi(fig, "figure_2_true_vs_rust_vs_c")
    plt.show()

    return table1, curves


In [None]:

def _load_cached_perf_metrics(raw_csv: Path) -> Dict[Tuple[str, str, int], Tuple[float, float]]:
    """Read ``(model, tool, rep) -> (wall_sec, peak_rss_mb)`` from a previous raw-repeats CSV.

    Returns an empty dict when the file is missing, unreadable, or lacks the expected
    columns, so callers fall back to NaN for runs whose measurements are unknown.
    """
    if not raw_csv.exists():
        return {}
    try:
        prev = pd.read_csv(raw_csv)
    except (OSError, ValueError):  # pandas EmptyDataError / ParserError subclass ValueError
        return {}
    cols = ["model", "tool", "rep", "wall_sec", "peak_rss_mb"]
    if not set(cols).issubset(prev.columns):
        return {}
    prev = prev.dropna(subset=["model", "tool", "rep"])
    return {
        (str(m), str(t), int(r)): (float(w), float(s))
        for m, t, r, w, s in prev[cols].itertuples(index=False, name=None)
    }


def run_figure3_and_table2(has_c: bool, force: bool = False):
    """Benchmark Rust (and optionally C) per model/repeat; write Table 2 and draw Figure 3.

    For every repeat ``rep`` in ``1..PERF_REPEATS`` and every model in ``MODEL_ORDER`` the
    Rust tool (and the C tool when ``has_c``) is run on ``model_input_path(model)`` with
    ``N_ITER`` iterations, writing into ``PERF_DIR``. Raw per-run numbers go to
    ``TABLE_DIR / "perf_raw_repeats.csv"``; the per-(model, tool) mean/SD summary is saved
    via ``save_table_multi(..., "table_2_runtime_memory")`` and plotted as two bar panels
    (runtime, peak RSS) with SD error bars.

    When a run's output file already exists and ``force`` is False, the run is skipped and
    its wall time / peak RSS are taken from the previously written raw CSV (matched on
    model, tool and rep). If no such record exists, the values are NaN and a warning is
    printed.

    Args:
        has_c: include the C tool in the benchmark, table and figure.
        force: rerun every benchmark even if its output file exists.

    Returns:
        DataFrame with one row per (model, tool) and columns ``wall_sec_mean``,
        ``wall_sec_std``, ``peak_rss_mb_mean``, ``peak_rss_mb_std``.
    """
    raw_csv = TABLE_DIR / "perf_raw_repeats.csv"
    # Measurements from an earlier execution. Without this, cached runs would be recorded
    # as NaN and the re-written raw CSV / summary would lose the real numbers.
    cached_metrics = {} if force else _load_cached_perf_metrics(raw_csv)

    rows = []
    for rep in range(1, PERF_REPEATS + 1):
        for key in MODEL_ORDER:
            inp = model_input_path(key)
            jobs = [("rust", PERF_DIR / f"{key}.rust.rep{rep:02d}.json", run_rust)]
            if has_c:
                jobs.append(("c", PERF_DIR / f"{key}.c.rep{rep:02d}.psmc", run_c))
            for tool, out_path, runner in jobs:
                if force or not out_path.exists():
                    res = runner(inp, out_path, n_iter=N_ITER)
                    wall_sec, peak_rss_mb = res["wall_sec"], res["peak_rss_mb"]
                else:
                    wall_sec, peak_rss_mb = cached_metrics.get((str(key), tool, rep), (np.nan, np.nan))
                rows.append({
                    "model": key,
                    "tool": tool,
                    "rep": rep,
                    "wall_sec": wall_sec,
                    "peak_rss_mb": peak_rss_mb,
                })

    perf_raw = pd.DataFrame(rows, columns=["model", "tool", "rep", "wall_sec", "peak_rss_mb"])
    perf_raw.to_csv(raw_csv, index=False)

    n_unknown = int(perf_raw[["wall_sec", "peak_rss_mb"]].isna().any(axis=1).sum())
    if n_unknown:
        print(
            f"[warn] Figure 3: {n_unknown}/{len(perf_raw)} runs reused cached output without "
            "recorded metrics (NaN); rerun with force=True to re-measure"
        )

    perf_summary = (
        perf_raw
        .groupby(["model", "tool"], as_index=False)
        .agg(
            wall_sec_mean=("wall_sec", "mean"),
            wall_sec_std=("wall_sec", "std"),
            peak_rss_mb_mean=("peak_rss_mb", "mean"),
            peak_rss_mb_std=("peak_rss_mb", "std"),
        )
        .sort_values(["model", "tool"])
    )

    save_table_multi(perf_summary, "table_2_runtime_memory")

    fig, axes = plt.subplots(1, 2, figsize=(13.8, 5.2), dpi=220, constrained_layout=True)
    x = np.arange(len(MODEL_ORDER))
    width = 0.36
    # (tool key in perf_summary / COLORS, legend label, bar offset from the tick)
    bar_specs = [("rust", "Rust", -width / 2)]
    if has_c:
        bar_specs.append(("c", "C", width / 2))

    for ax, y_col, e_col, title in [
        (axes[0], "wall_sec_mean", "wall_sec_std", "Runtime (s)"),
        (axes[1], "peak_rss_mb_mean", "peak_rss_mb_std", "Peak RSS (MB)"),
    ]:
        for tool, label, offset in bar_specs:
            tool_stats = perf_summary[perf_summary["tool"] == tool].set_index("model")
            heights = [float(tool_stats.at[m, y_col]) if m in tool_stats.index else np.nan for m in MODEL_ORDER]
            errors = [float(tool_stats.at[m, e_col]) if m in tool_stats.index else np.nan for m in MODEL_ORDER]
            ax.bar(
                x + offset,
                heights,
                width=width,
                yerr=errors,
                capsize=3,
                color=COLORS[tool],
                alpha=0.9,
                edgecolor="white",
                linewidth=0.9,
                label=label,
            )

        ax.set_xticks(x)
        ax.set_xticklabels(MODEL_ORDER, rotation=15)
        ax.set_title(title, fontsize=12.5, pad=8)
        stylize_axis(ax, xlog=False)
        ax.grid(axis="x", alpha=0.0)

    axes[0].legend(frameon=False)
    panel_labels(axes)
    fig.suptitle(f"Figure 3. Single-thread Performance (n={PERF_REPEATS} repeats, mean±SD)", fontsize=15.5, fontweight="bold", y=1.05)
    save_figure_multi(fig, "figure_3_runtime_peakrss_singlethread")
    plt.show()

    return perf_summary


In [None]:

def compute_ci_from_curves(curves: List[Tuple[np.ndarray, np.ndarray]], x_grid: np.ndarray):
    """Pointwise 2.5% / 50% / 97.5% quantiles of several step curves on a common grid.

    Each curve ``(x, y)`` is evaluated at every point of ``x_grid`` with ``step_value``;
    quantiles are then taken across curves (axis 0), independently at each grid point.

    Args:
        curves: non-empty list of ``(x, y)`` step curves (e.g. bootstrap replicates).
        x_grid: points at which every curve is evaluated.

    Returns:
        ``(q025, q500, q975)``, each an array with one value per ``x_grid`` point.

    Raises:
        ValueError: if ``curves`` is empty; quantiles across zero curves are undefined,
            and numpy would otherwise fail with an obscure error or return wrongly
            shaped scalars.
    """
    if len(curves) == 0:
        raise ValueError("compute_ci_from_curves: need at least one curve to compute quantiles")
    # Shape (n_curves, len(x_grid)).
    vals = np.asarray([[step_value(x, y, qx) for qx in x_grid] for x, y in curves], dtype=float)
    q025, q500, q975 = np.quantile(vals, [0.025, 0.5, 0.975], axis=0)
    return q025, q500, q975


def run_c_bootstrap_replicates(model_key: str, reps: int, iters: int, force: bool = False):
    """Run (or reuse) ``reps`` C bootstrap replicates for one model.

    The model input (``model_input_path(model_key)``) is first passed through the
    executable returned by ``ensure_splitfa()``. Its stdout is stored as
    ``BOOT_DIR/<model>/<model>.split.psmcfa``. That split input is then fed to ``run_c``
    with ``bootstrap=True`` once per replicate, writing
    ``BOOT_DIR/<model>/c_boot/replicate_NNN.psmc``.

    Existing split input / replicate files are reused unless ``force`` is True.

    Args:
        model_key: model identifier used for input lookup and directory naming.
        reps: number of replicates (files are numbered 1..reps).
        iters: iteration count passed to ``run_c`` as ``n_iter``.
        force: regenerate the split input and every replicate.

    Returns:
        List of the ``reps`` replicate output paths in replicate order, whether freshly
        computed or reused from disk.
    """
    splitfa = ensure_splitfa()
    boot_root = BOOT_DIR / model_key
    replicate_dir = boot_root / "c_boot"
    replicate_dir.mkdir(parents=True, exist_ok=True)

    split_input = boot_root / f"{model_key}.split.psmcfa"
    needs_split = force or not split_input.exists()
    if needs_split:
        run_cmd([str(splitfa), str(model_input_path(model_key))], cwd=ROOT, stdout_path=split_input)

    replicate_paths = [replicate_dir / f"replicate_{idx:03d}.psmc" for idx in range(1, reps + 1)]
    for replicate_path in replicate_paths:
        # Existence is checked right before each run, so a partially completed earlier
        # batch resumes where it stopped.
        if force or not replicate_path.exists():
            run_c(split_input, replicate_path, n_iter=iters, bootstrap=True)
    return replicate_paths


def run_rust_bootstrap(model_key: str, reps: int, iters: int, force: bool = False):
    """Run (or reuse) the Rust main fit plus its built-in bootstrap for one model.

    The Rust tool is invoked once on ``model_input_path(model_key)`` with ``N_ITER``
    iterations and the ``--bootstrap*`` flags. The main result goes to
    ``BOOT_DIR/<model>/rust_bootstrap_main.json``, and the bootstrap directory is
    ``BOOT_DIR/<model>/rust_boot``. The run is skipped when both the main JSON and
    ``rust_boot/summary.tsv`` already exist, unless ``force`` is True.

    Args:
        model_key: model identifier used for input lookup and directory naming.
        reps: value of ``--bootstrap``.
        iters: value of ``--bootstrap-iters``.
        force: rerun even if cached outputs exist.

    Returns:
        ``(main_json_path, summary)`` where ``summary`` is ``rust_boot/summary.tsv`` read
        as a tab-separated DataFrame.
    """
    boot_root = BOOT_DIR / model_key
    replicate_dir = boot_root / "rust_boot"
    main_json = boot_root / "rust_bootstrap_main.json"
    replicate_dir.mkdir(parents=True, exist_ok=True)

    summary_path = replicate_dir / "summary.tsv"
    cache_ok = summary_path.exists() and main_json.exists()
    if force or not cache_ok:
        # Flag -> value; flattened below into ["--flag", "value", ...] in this order.
        bootstrap_flags = {
            "--bootstrap": reps,
            "--bootstrap-iters": iters,
            "--bootstrap-block-size": BOOTSTRAP_BLOCK_SIZE,
            "--bootstrap-seed": BOOTSTRAP_SEED,
            "--bootstrap-dir": replicate_dir,
        }
        extra = [token for flag, value in bootstrap_flags.items() for token in (flag, str(value))]
        run_rust(model_input_path(model_key), main_json, n_iter=N_ITER, extra=extra)

    return main_json, pd.read_csv(summary_path, sep="\t")


def _values_on_grid(curve, x_grid: np.ndarray) -> np.ndarray:
    """Evaluate the step curve ``curve = (x, y)`` at every point of ``x_grid`` via ``step_value``."""
    return np.asarray([step_value(curve[0], curve[1], qx) for qx in x_grid], dtype=float)


def run_figure4_bootstrap(has_c: bool, force: bool = False):
    """Figure 4: Rust vs C bootstrap 95% CIs for each model in ``BOOT_MODELS``.

    For each model this runs or reuses three things:

    - the Rust bootstrap (``run_rust_bootstrap``);
    - the C main fit (``c_main_psmc_path``);
    - the C bootstrap replicates (``run_c_bootstrap_replicates``).

    Everything is evaluated on the Rust summary's ``x_years`` grid. The mean/median CI
    widths (linear Ne units) are saved as ``table_bootstrap_ci_width``. One figure
    column is drawn per model: CI bands plus True / C main / Rust main curves on top,
    and log10 deviations from truth (and Rust - C) underneath.

    Args:
        has_c: the figure compares against C; when False it is skipped.
        force: recompute every bootstrap / main run even if outputs exist.

    Returns:
        DataFrame of CI widths per (model, tool), or None when skipped.

    Raises:
        RuntimeError: if the C main fit cannot be loaded, or none of a model's C
            bootstrap replicates can be loaded.
    """
    if not has_c:
        print("[skip] Figure 4 requires C binary for Rust vs C CI comparison")
        return None
    boot_models = list(BOOT_MODELS)
    if not boot_models:
        print("[skip] Figure 4: BOOT_MODELS is empty")
        return None

    panel_data = {}
    width_rows = []

    for key in boot_models:
        print(f"[bootstrap] {key}")
        _, rust_summary = run_rust_bootstrap(key, reps=BOOTSTRAP_REPS, iters=BOOTSTRAP_ITERS, force=force)

        c_main = c_main_psmc_path(key)
        if force or not c_main.exists():
            run_c(model_input_path(key), c_main, n_iter=N_ITER)

        c_rep_paths = run_c_bootstrap_replicates(key, reps=BOOTSTRAP_REPS, iters=BOOTSTRAP_ITERS, force=force)
        # load_c_curve may return None for unusable files; such replicates are dropped.
        c_rep_curves = [cc for cc in (load_c_curve(pp) for pp in c_rep_paths) if cc is not None]
        if not c_rep_curves:
            raise RuntimeError(
                f"[bootstrap] {key}: none of the {len(c_rep_paths)} C bootstrap replicates could be loaded"
            )
        n_dropped = len(c_rep_paths) - len(c_rep_curves)
        if n_dropped:
            print(f"[warn] {key}: {n_dropped}/{len(c_rep_paths)} C bootstrap replicates could not be loaded")

        c_main_curve = load_c_curve(c_main)
        if c_main_curve is None:
            raise RuntimeError(f"[bootstrap] {key}: could not load C main result {c_main}")

        x_grid = rust_summary["x_years"].to_numpy(dtype=float)

        c_q025, c_q500, c_q975 = compute_ci_from_curves(c_rep_curves, x_grid)
        c_main_vals = _values_on_grid(c_main_curve, x_grid)

        rust_main_vals = rust_summary["ne_main"].to_numpy(dtype=float)
        rust_q025 = rust_summary["ne_q025"].to_numpy(dtype=float)
        rust_q500 = rust_summary["ne_q500"].to_numpy(dtype=float)
        rust_q975 = rust_summary["ne_q975"].to_numpy(dtype=float)

        true_vals = _values_on_grid(true_curve_for_model(key), x_grid)

        for tool, lo, hi in (("rust", rust_q025, rust_q975), ("c", c_q025, c_q975)):
            width_rows.append({
                "model": key,
                "tool": tool,
                "ci_width_mean": float(np.mean(hi - lo)),
                "ci_width_median": float(np.median(hi - lo)),
            })

        panel_data[key] = {
            "x": x_grid,
            "true": true_vals,
            "rust_main": rust_main_vals,
            "rust_q025": rust_q025,
            "rust_q500": rust_q500,
            "rust_q975": rust_q975,
            "c_main": c_main_vals,
            "c_q025": c_q025,
            "c_q500": c_q500,
            "c_q975": c_q975,
        }

    width_df = pd.DataFrame(width_rows).sort_values(["model", "tool"])
    save_table_multi(width_df, "table_bootstrap_ci_width")

    # One column per bootstrap model (previously hard-coded to 2 columns); the width
    # scales so that two models keep the original 14.5-inch figure.
    n_models = len(boot_models)
    fig = plt.figure(figsize=(7.25 * n_models, 6.8), dpi=220, constrained_layout=True)
    outer = fig.add_gridspec(1, n_models, wspace=0.12)
    main_axes = []
    diff_axes = []

    for i, key in enumerate(boot_models):
        inner = outer[0, i].subgridspec(2, 1, height_ratios=[4.0, 1.35], hspace=0.06)
        ax = fig.add_subplot(inner[0])
        axd = fig.add_subplot(inner[1], sharex=ax)
        main_axes.append(ax)
        diff_axes.append(axd)

        d = panel_data[key]

        ax.fill_between(
            d["x"], d["rust_q025"], d["rust_q975"], step="post", alpha=0.16, color=COLORS["rust"], label="Rust 95% CI"
        )
        ax.fill_between(
            d["x"], d["c_q025"], d["c_q975"], step="post", alpha=0.16, color=COLORS["c"], label="C 95% CI"
        )

        _step_with_outline(ax, d["x"], d["true"], color=COLORS["true"], ls=(0, (5, 3)), lw=2.2, label="True", zorder=1)
        _step_with_outline(ax, d["x"], d["c_main"], color=COLORS["c"], lw=2.05, label="C main", zorder=2)
        _step_with_outline(ax, d["x"], d["rust_main"], color=COLORS["rust"], lw=2.1, label="Rust main", zorder=3)

        stylize_axis(ax, xlog=True)
        ax.set_xlim(1e3, 1e8)
        # Ignore NaN/inf so a single bad grid point cannot make set_ylim raise.
        upper = np.concatenate([d["true"], d["rust_q975"], d["c_q975"]])
        upper = upper[np.isfinite(upper)]
        if upper.size and upper.max() > 0:
            ax.set_ylim(0, float(upper.max()) * 1.14)
        ax.set_title(MODELS[key]["title"], fontsize=12.5, pad=8)
        ax.set_ylabel("Effective population size (Ne)")
        ax.tick_params(labelbottom=False)

        # Clamp before log10 so zero Ne does not produce -inf.
        l_true = np.log10(np.maximum(d["true"], 1e-12))
        l_rust = np.log10(np.maximum(d["rust_main"], 1e-12))
        l_c = np.log10(np.maximum(d["c_main"], 1e-12))
        d_r = l_rust - l_true
        d_c = l_c - l_true
        d_rc = l_rust - l_c

        axd.axhline(0.0, color="#6B7280", lw=1.0, ls=(0, (4, 2)))
        axd.plot(d["x"], d_r, lw=1.8, color=COLORS["rust"], label="Rust - True")
        axd.plot(d["x"], d_c, lw=1.8, color=COLORS["c"], label="C - True")
        axd.plot(d["x"], d_rc, lw=1.2, ls=(0, (3, 2)), color="#374151", label="Rust - C")
        stylize_axis(axd, xlog=True)
        axd.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:+.2f}"))
        # Symmetric deviation range, clamped to [0.03, 0.50] log10 units.
        dev = np.abs(np.concatenate([d_r, d_c, d_rc]))
        dev = dev[np.isfinite(dev)]
        peak = float(dev.max()) if dev.size else 0.0
        lim = max(0.03, min(0.50, peak * 1.25))
        axd.set_ylim(-lim, lim)
        axd.set_xlabel(f"Years (g={GEN_YEARS}, mu={MU:.1e})")
        axd.set_ylabel("Delta log10(Ne)", fontsize=9)
        axd.tick_params(axis="both", labelsize=8.3)

    panel_labels(main_axes)
    h_main, l_main = main_axes[0].get_legend_handles_labels()
    fig.legend(h_main, l_main, loc="upper center", ncol=5, bbox_to_anchor=(0.5, 1.04))
    h_diff, l_diff = diff_axes[0].get_legend_handles_labels()
    diff_axes[0].legend(h_diff, l_diff, loc="lower left", fontsize=7.9)

    fig.suptitle(
        f"Figure 4. Bootstrap 95% CI (Rust vs C, {BOOTSTRAP_REPS} replicates): zigzag + bottleneck",
        fontsize=14.7,
        fontweight="bold",
        y=1.07,
    )
    save_figure_multi(fig, "figure_4_bootstrap_zigzag_bottleneck")
    plt.show()

    return width_df


In [None]:

def run_all_main_text():
    """Regenerate every main-text figure and table in order (steps 1-6).

    The steps are:

    1. Check the tools and report whether the C binary is available.
    2. Simulate inputs and draw Figure 1.
    3. Run Figure 2/Table 1, Figure 3/Table 2 and Figure 4, displaying each returned table.
    4. List every file in ``FIG_DIR`` and ``TABLE_DIR`` together with the command log path.

    Recompute behaviour is controlled by ``FORCE_SIM``, ``FORCE_RUN``, ``FORCE_PERF``
    and ``FORCE_BOOTSTRAP``.
    """
    has_c = ensure_tools()
    print(f"HAS_C={has_c}")

    print("\n[1/6] simulate inputs")
    simulate_inputs(force=FORCE_SIM)

    print("\n[2/6] Figure 1")
    make_figure1_workflow()

    print("\n[3/6] Figure 2 + Table 1")
    table1, _curves = run_figure2_and_table1(has_c=has_c, force=FORCE_RUN)
    display(table1)

    print("\n[4/6] Figure 3 + Table 2")
    display(run_figure3_and_table2(has_c=has_c, force=FORCE_PERF))

    print("\n[5/6] Figure 4")
    ci_widths = run_figure4_bootstrap(has_c=has_c, force=FORCE_BOOTSTRAP)
    if ci_widths is not None:
        display(ci_widths)

    print("\n[6/6] output summary")
    # Both directories are listed before anything is printed.
    listings = (
        ("Figures:", sorted(FIG_DIR.glob("*"))),
        ("Tables:", sorted(TABLE_DIR.glob("*"))),
    )
    for heading, found in listings:
        print(heading)
        for path in found:
            print(" -", path)
    print("Command log:", LOG_DIR / "commands.jsonl")


# One-click entry point: executing this cell regenerates all main-text outputs.
run_all_main_text()
