# Real Data Benchmark: Rust vs C (PSMC)

This notebook runs `psmc-rs` and the original C `psmc` on four real diploid `psmcfa` samples,
then generates comparison figures (`PNG`/`SVG`/`PDF`) and summary tables for manuscript use.

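Fairness settings:
- Same core PSMC settings for Rust and C: iterations (`-N`), maximum time (`-t`), and interval pattern (`-p`). The rho/theta ratio (`-r`) is currently passed only to the C binary.
- Rust smoothing disabled (`--smooth-lambda 0`)
- Single thread (`--threads 1`) for a fair Rust-vs-C runtime comparison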


In [None]:
from __future__ import annotations

import json
import math
import os
import shlex
import subprocess
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# Use IPython's rich display inside Jupyter; fall back to plain print() when
# IPython cannot be imported (any import-time failure is treated as "unavailable").
try:
    from IPython.display import display
except Exception:
    def display(x):
        """Fallback for ``IPython.display.display``: print *x*."""
        print(x)


def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that looks like the psmc-rs checkout.

    A directory qualifies when it contains both a ``Cargo.toml`` entry and a
    ``src`` entry. *start* itself is checked first, then each parent in turn.

    Raises:
        RuntimeError: if no qualifying directory is found.
    """
    candidates = (start,) + tuple(start.parents)
    hit = next(
        (c for c in candidates if (c / 'Cargo.toml').exists() and (c / 'src').exists()),
        None,
    )
    if hit is None:
        raise RuntimeError(f'Cannot locate psmc-rs root from {start}')
    return hit


ROOT = find_repo_root(Path.cwd().resolve())

# Every artefact of this notebook lives under a single run directory.
RUN_DIR = ROOT / 'experiment' / 'runs' / 'real_data'
OUT_DIR, FIG_DIR, TAB_DIR, LOG_DIR = (
    RUN_DIR / sub for sub in ('outputs', 'figures', 'tables', 'logs')
)

for directory in (RUN_DIR, OUT_DIR, FIG_DIR, TAB_DIR, LOG_DIR):
    directory.mkdir(parents=True, exist_ok=True)

# Binary locations can be overridden through the environment. The defaults are
# a release build inside this repo and the C tool in a sibling ``psmc-master`` dir.
PSMC_RS_BIN = Path(os.environ.get('PSMC_RS_BIN', str(ROOT.joinpath('target', 'release', 'psmc-rs'))))
C_PSMC_BIN = Path(os.environ.get('C_PSMC_BIN', str(ROOT.parent.joinpath('psmc-master', 'psmc'))))

# Shared matplotlib look: white backgrounds, light slate axes and grid.
plt.rcParams.update(
    {key: 'white' for key in ('figure.facecolor', 'savefig.facecolor', 'axes.facecolor')}
)
plt.rcParams.update({
    'axes.edgecolor': '#B7C3D0',
    'axes.linewidth': 0.9,
    'axes.titleweight': 'semibold',
    'axes.labelcolor': '#334155',
    'xtick.color': '#475569',
    'ytick.color': '#475569',
    'grid.color': '#E6ECF3',
    'grid.linewidth': 0.7,
    'grid.alpha': 0.9,
})
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 10.5,
    'legend.frameon': False,
    # TrueType (Type 42) fonts in PDF/PS and real text in SVG, so text stays
    # selectable in the vector outputs.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
})

# Series colours used across all figures.
COL = dict(rust='#E15759', c='#59A14F', delta='#2563EB', axis='#334155')

for _label, _value in (
    ('ROOT:', ROOT),
    ('RUN_DIR:', RUN_DIR),
    ('PSMC_RS_BIN:', PSMC_RS_BIN),
    ('C_PSMC_BIN:', C_PSMC_BIN),
):
    print(_label, _value)


In [None]:
@dataclass
class SampleCfg:
    """Configuration for one real-data sample in the Rust-vs-C benchmark."""

    # Short identifier; used as the stem of output file names and as the sample_id column.
    sid: str
    # Path to the input ``.psmcfa`` file given to both binaries.
    file: Path
    # Species label; used as the panel title in the figures.
    species: str
    # Coarse grouping label (e.g. 'Turtle', 'Human', 'Primate'); recorded in the tables.
    group: str
    # Mutation rate used to scale theta into N0 = theta / (4 * mu * bin_size);
    # presumably per site per generation -- confirm against the value's source.
    mu: float
    # Generation time in years; converts scaled time into the 'Years' axis.
    gen_years: float


# The four real-data samples benchmarked in this notebook.
# mu and gen_years rescale PSMC output into years / Ne for the plots and RMSE.
# mu is also forwarded to psmc-rs (--mu) but not to the C binary.
# NOTE(review): the provenance of each mu / gen_years value is not recorded
# here; cite it before publication.
SAMPLES: List[SampleCfg] = [
    SampleCfg(
        sid='HLemySub1',
        file=ROOT / 'experiment' / 'real_data' / 'HLemySub1_diploid.psmcfa',
        species='HLemySub1 (likely Emydura subglobosa)',
        group='Turtle',
        mu=7.9e-9,
        gen_years=15.0,
    ),
    SampleCfg(
        sid='Papuan_highlands',
        file=ROOT / 'experiment' / 'real_data' / 'Papuan_highlands_diploid.psmcfa',
        species='Homo sapiens (Papuan Highlands)',
        group='Human',
        mu=1.25e-8,
        gen_years=25.0,
    ),
    SampleCfg(
        sid='PD_0030_A_preussi',
        file=ROOT / 'experiment' / 'real_data' / 'PD_0030_Allochrocebus_preussi.psmcfa',
        species='Allochrocebus preussi',
        group='Primate',
        mu=4.91e-9,
        gen_years=10.0,
    ),
    SampleCfg(
        sid='PD_0032_C_ascanius_schmidti',
        file=ROOT / 'experiment' / 'real_data' / 'PD_0032_Cercopithecus_ascanius_schmidti.psmcfa',
        species='Cercopithecus ascanius schmidti',
        group='Primate',
        mu=4.82e-9,
        gen_years=12.0,
    ),
]

# Matched settings for Rust vs C; each can be overridden by an environment
# variable of the same name.
def _env(name: str, default: str, cast=str):
    """Read environment variable *name* (or *default*) and convert it with *cast*."""
    return cast(os.environ.get(name, default))


N_ITER = _env('N_ITER', '20', int)
T_MAX = _env('T_MAX', '15', float)
N_STEPS = _env('N_STEPS', '64', int)
PATTERN = _env('PATTERN', '4+25*2+4+6')
RHO_T_RATIO = _env('RHO_T_RATIO', '5', int)
BATCH_SIZE = _env('BATCH_SIZE', '300000', int)
THREADS = _env('THREADS', '1', int)
SMOOTH_LAMBDA = _env('SMOOTH_LAMBDA', '0', float)
BIN_SIZE = _env('BIN_SIZE', '100', int)
# Only the exact string '1' forces re-running cached inference.
FORCE_RUN = _env('FORCE_RUN', '0') == '1'

meta = pd.DataFrame([
    dict(
        sample_id=s.sid,
        species=s.species,
        group=s.group,
        mu=s.mu,
        gen_years=s.gen_years,
        input_file=str(s.file),
        exists=s.file.exists(),
    )
    for s in SAMPLES
])
display(meta)

# Fail fast, before any binaries run, if an input file is absent.
missing = [str(cfg.file) for cfg in SAMPLES if not cfg.file.exists()]
if missing:
    raise FileNotFoundError('\n'.join(['Missing input files:', *missing]))


In [None]:
def save_figure_multi(fig, stem: str):
    """Save *fig* into FIG_DIR as ``<stem>.png``, ``<stem>.svg`` and ``<stem>.pdf``.

    The PNG is rendered at 320 dpi. The other formats pass ``dpi=None``, i.e.
    matplotlib's savefig default. All outputs use a tight bounding box.
    """
    written = []
    for suffix in ('png', 'svg', 'pdf'):
        target = FIG_DIR / f'{stem}.{suffix}'
        resolution = 320 if suffix == 'png' else None
        fig.savefig(target, dpi=resolution, bbox_inches='tight')
        written.append(target)
    print('saved figure:', ', '.join(map(str, written)))


def save_table_multi(df: pd.DataFrame, stem: str):
    """Write *df* into TAB_DIR as ``<stem>.csv``, ``<stem>.tsv`` and ``<stem>.md``.

    ``DataFrame.to_markdown`` depends on the optional ``tabulate`` package. If it
    is not installed, the CSV and TSV files are still written and the Markdown
    file is skipped with a message. Without this, the calling cell would abort
    after all inference had already finished.
    """
    csv_p = TAB_DIR / f'{stem}.csv'
    tsv_p = TAB_DIR / f'{stem}.tsv'
    md_p = TAB_DIR / f'{stem}.md'
    df.to_csv(csv_p, index=False)
    df.to_csv(tsv_p, index=False, sep='\t')
    try:
        md_text = df.to_markdown(index=False)
    except ImportError as err:
        # Optional dependency missing: keep the CSV/TSV outputs and report it.
        print(f'skipped markdown table {md_p}: {err}')
        print('saved table:', csv_p, tsv_p)
        return
    md_p.write_text(md_text + '\n', encoding='utf-8')
    print('saved table:', csv_p, tsv_p, md_p)


def run_cmd(cmd: List[str], cwd: Optional[Path] = None, check: bool = True):
    """Execute *cmd* without a shell and return a record describing the run.

    The record contains ``cmd`` (shell-quoted string), ``returncode``,
    ``wall_sec``, ``stdout`` and ``stderr``. It is appended as one JSON line to
    ``LOG_DIR/commands.jsonl`` before any failure handling, so failed runs are
    logged too.

    Raises:
        RuntimeError: when *check* is true and the exit code is non-zero
            (the command and its captured output are printed first).
    """
    workdir = None if cwd is None else str(cwd)
    start = time.perf_counter()
    proc = subprocess.run(cmd, cwd=workdir, capture_output=True, text=True)
    elapsed = time.perf_counter() - start

    rec = {
        'cmd': ' '.join(map(shlex.quote, cmd)),
        'returncode': proc.returncode,
        'wall_sec': elapsed,
        'stdout': proc.stdout,
        'stderr': proc.stderr,
    }
    line = json.dumps(rec, ensure_ascii=False)
    with open(LOG_DIR / 'commands.jsonl', 'a', encoding='utf-8') as log:
        log.write(line + '\n')

    if check and proc.returncode != 0:
        for text in (rec['cmd'], '--- stdout ---', proc.stdout, '--- stderr ---', proc.stderr):
            print(text)
        raise RuntimeError(f'command failed: rc={proc.returncode}')
    return rec


def rust_out_json(s: SampleCfg) -> Path:
    """Location of the psmc-rs JSON result for sample *s* (inside OUT_DIR)."""
    return OUT_DIR.joinpath(f'{s.sid}.rust.json')


def c_out_psmc(s: SampleCfg) -> Path:
    """Location of the C psmc ``.psmc`` result for sample *s* (inside OUT_DIR)."""
    return OUT_DIR.joinpath(f'{s.sid}.c.psmc')


def run_rust(s: SampleCfg):
    """Run psmc-rs on sample *s* with the notebook-wide settings.

    Returns ``(json_output_path, run_cmd_record)``. A non-zero exit raises
    RuntimeError via run_cmd's default ``check=True``.
    """
    out_json = rust_out_json(s)
    # Flag/value pairs, emitted in this order after the positional arguments.
    options = {
        '--t-max': str(T_MAX),
        '--n-steps': str(N_STEPS),
        '--pattern': PATTERN,
        '--mu': str(s.mu),
        '--smooth-lambda': str(SMOOTH_LAMBDA),
        '--batch-size': str(BATCH_SIZE),
        '--threads': str(THREADS),
    }
    # NOTE(review): RHO_T_RATIO is passed to the C binary as -r but not here;
    # confirm psmc-rs uses an equivalent default or add the matching flag.
    cmd = [str(PSMC_RS_BIN), str(s.file), str(out_json), str(N_ITER)]
    for flag, value in options.items():
        cmd += [flag, value]
    cmd.append('--no-progress')
    return out_json, run_cmd(cmd, cwd=ROOT)


def run_c(s: SampleCfg):
    """Run the original C psmc on sample *s* with the notebook-wide settings.

    Returns ``(psmc_output_path, run_cmd_record)``. A non-zero exit raises
    RuntimeError via run_cmd's default ``check=True``.
    """
    out_psmc = c_out_psmc(s)
    # -N/-t/-r are passed with the value attached; -p and -o as separate arguments.
    cmd = [str(C_PSMC_BIN)]
    cmd.extend(f'-{flag}{value}' for flag, value in (('N', N_ITER), ('t', T_MAX), ('r', RHO_T_RATIO)))
    cmd.extend(['-p', PATTERN, '-o', str(out_psmc), str(s.file)])
    return out_psmc, run_cmd(cmd, cwd=ROOT)


# Refuse to continue unless both binaries are present (psmc-rs is checked first).
for _label, _binary in (('psmc-rs', PSMC_RS_BIN), ('C psmc', C_PSMC_BIN)):
    if not _binary.exists():
        raise FileNotFoundError(f'{_label} binary not found: {_binary}')


In [None]:
def parse_pattern_spec(pattern):
    """Parse a C-psmc style pattern such as ``'4+25*2+4+6'``.

    Each ``+``-separated term becomes a ``(n_groups, group_width)`` pair:
    ``'a*b'`` gives ``(a, b)`` and a bare ``'b'`` gives ``(1, b)``. Blank terms
    are ignored.

    Returns the list of pairs, or None when *pattern* is None or has no terms.
    Raises ValueError for non-integer parts.
    """
    if pattern is None:
        return None
    spec = []
    terms = (chunk.strip() for chunk in str(pattern).split('+'))
    for term in filter(None, terms):
        count, star, width = term.partition('*')
        if star:
            spec.append((int(count.strip()), int(width.strip())))
        else:
            spec.append((1, int(term)))
    return spec or None


def parse_pattern_spec_legacy(pattern):
    """Parse a pattern string using the legacy convention.

    ``'a*b'`` gives ``(a, b)`` as in the C-style parser, but a bare ``'b'``
    gives ``(b, 1)`` (b single-width entries) instead of ``(1, b)``. Blank
    terms are ignored.

    Returns the list of pairs, or None when *pattern* is None or has no terms.
    """
    if pattern is None:
        return None
    spec = []
    for raw in str(pattern).split('+'):
        term = raw.strip()
        if not term:
            continue
        head, star, tail = term.partition('*')
        spec.append((int(head.strip()), int(tail.strip())) if star else (int(term), 1))
    return spec or None


def expand_lam(lam_grouped, n_steps, pattern_spec, pattern_raw=None):
    """Expand grouped lambda values to one float per time point (n_steps + 1 values).

    Accepted layouts:
      * ``pattern_spec is None``: *lam_grouped* must already hold n_steps + 1 values.
      * C style: one value per group of *pattern_spec* ``(n_groups, group_width)``
        pairs. Each value is repeated ``group_width`` times.
      * Legacy: pairs parsed from *pattern_raw* by parse_pattern_spec_legacy,
        with one value per repetition plus one trailing value appended unchanged.

    Raises ValueError when no layout matches or the expanded length is not
    n_steps + 1.
    """
    def _repeat_by(values, widths):
        # values[i] repeated widths[i] times, concatenated in order.
        out = []
        for i, w in enumerate(widths):
            out.extend([values[i]] * w)
        return out

    values = [float(v) for v in lam_grouped]
    target = n_steps + 1

    if pattern_spec is None:
        if len(values) == target:
            return values
        raise ValueError(f'lam length {len(values)} != n_steps+1 ({n_steps+1})')

    if len(values) == sum(nr for nr, _ in pattern_spec):
        widths = [gl for nr, gl in pattern_spec for _ in range(nr)]
        lam = _repeat_by(values, widths)
        if len(lam) != target:
            raise ValueError(f'expanded lam length {len(lam)} != n_steps+1 ({n_steps+1})')
        return lam

    legacy = parse_pattern_spec_legacy(pattern_raw)
    if legacy is not None and len(values) == sum(ts for ts, _ in legacy) + 1:
        widths = [gs for ts, gs in legacy for _ in range(ts)]
        lam = _repeat_by(values, widths)
        lam.append(values[-1])
        if len(lam) != target:
            raise ValueError(f'expanded legacy lam length {len(lam)} != n_steps+1 ({n_steps+1})')
        return lam

    raise ValueError('grouped lam length mismatch with pattern')


def compute_t_grid(n_steps: int, t_max: float, alpha: float = 0.1):
    """Return the n_steps + 1 PSMC time-interval boundaries as a float ndarray.

    For k < n_steps: ``t_k = alpha * (exp(beta * k) - 1)`` with
    ``beta = log(1 + t_max / alpha) / n_steps``, so the grid is exponentially
    spaced and starts at 0. The final boundary is set to exactly ``t_max``.
    """
    beta = math.log(1 + t_max / alpha) / n_steps
    boundaries = [*(alpha * (math.exp(beta * k) - 1.0) for k in range(n_steps)), float(t_max)]
    return np.asarray(boundaries, dtype=float)


def curve_from_json(path: Path, mu: float, gen_years: float, bin_size: int = BIN_SIZE):
    """Load a psmc-rs JSON result and convert it to a (years, Ne) step curve.

    Reads ``theta``, ``n_steps``, ``t_max``, ``lam`` and optional ``pattern``.
    The grouped lambdas are expanded to one value per time point.

    Scaling: ``N0 = theta / (4 * mu * bin_size)``; years = ``2 * N0 * t * gen_years``;
    Ne = ``N0 * lambda``. A final point at 1e8 years repeats the last Ne so the
    step curve spans the plotted range. If the last breakpoint is already later,
    that breakpoint is used instead, so x stays sorted for step_value's
    searchsorted.

    Returns two float ndarrays ``(years, ne)``.
    """
    params = json.loads(path.read_text(encoding='utf-8'))
    theta = float(params['theta'])
    n_steps = int(params['n_steps'])
    t_max = float(params['t_max'])
    pattern_raw = params.get('pattern')
    pattern_spec = parse_pattern_spec(pattern_raw)
    lam = np.asarray(expand_lam(params['lam'], n_steps, pattern_spec, pattern_raw), dtype=float)

    t = compute_t_grid(n_steps, t_max)
    n0 = theta / (4.0 * float(mu) * float(bin_size))
    x = t * 2.0 * float(gen_years) * n0
    y = lam * n0
    # A fixed 1e8 sentinel would break sortedness when x[-1] > 1e8.
    x_tail = max(1e8, float(x[-1]))
    x = np.append(x, x_tail)
    y = np.append(y, y[-1])
    return np.asarray(x, dtype=float), np.asarray(y, dtype=float)


def load_c_curve(psmc_path: Path, mu: float, gen_years: float, bin_size: int = BIN_SIZE):
    """Convert the last complete iteration of a C ``psmc`` output file to a (years, Ne) curve.

    The file is split into blocks, each starting at an ``RD`` line. Within a block:
      * the ``TR`` line supplies theta (first value) and rho;
      * ``RS`` lines supply ``(k, t_k, lambda_k)``;
      * a ``PA`` line must be present.
    The last block that has TR, PA and at least one RS line is used.

    Scaling: ``N0 = theta / (4 * mu * bin_size)``; years = ``2 * N0 * t_k * gen_years``;
    Ne = ``N0 * lambda_k``. A final point at 1e8 years repeats the last Ne. If the
    last breakpoint is already later, that breakpoint is used instead, so x
    stays sorted for step_value's searchsorted.

    Returns two float ndarrays ``(years, ne)``.
    Raises ValueError if no complete block is found.
    """
    blocks = []
    cur = None
    for ln in psmc_path.read_text(encoding='utf-8').splitlines():
        if ln.startswith('RD\t'):
            if cur is not None:
                blocks.append(cur)
            cur = {'tr': None, 'pa': None, 'rs': []}
        elif cur is None:
            continue  # lines before the first RD block are ignored
        elif ln.startswith('TR\t'):
            _, th, rh = ln.split('\t')[:3]
            cur['tr'] = (float(th), float(rh))
        elif ln.startswith('PA\t'):
            cur['pa'] = ln
        elif ln.startswith('RS\t'):
            fields = ln.split('\t')
            cur['rs'].append((int(fields[1]), float(fields[2]), float(fields[3])))
    if cur is not None:
        blocks.append(cur)

    best = next(
        (b for b in reversed(blocks) if b['pa'] and b['tr'] is not None and b['rs']),
        None,
    )
    if best is None:
        raise ValueError(f'cannot parse valid block in {psmc_path}')

    theta = best['tr'][0]
    n0 = theta / (4.0 * float(mu) * float(bin_size))
    xs = [2.0 * n0 * tk * float(gen_years) for _, tk, _ in best['rs']]
    ys = [n0 * lk for _, _, lk in best['rs']]
    # A fixed 1e8 sentinel would break sortedness when the last breakpoint exceeds 1e8.
    xs.append(max(1e8, xs[-1]))
    ys.append(ys[-1])
    return np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)


def step_value(xs: np.ndarray, ys: np.ndarray, xq: float) -> float:
    """Evaluate the right-continuous step function (xs, ys) at *xq*.

    ``ys[i]`` holds on ``[xs[i], xs[i+1])``. Queries left of ``xs[0]`` return
    ``ys[0]``; indices beyond ``ys`` clamp to its last element. *xs* must be
    sorted ascending (as required by ``np.searchsorted``).
    """
    pos = int(np.searchsorted(xs, xq, side='right')) - 1
    pos = min(max(pos, 0), len(ys) - 1)
    return float(ys[pos])


In [None]:
records = []
curves: Dict[str, Dict[str, np.ndarray]] = {}


def _run_if_needed(output, runner, sample):
    """Call ``runner(sample)`` unless *output* already exists (FORCE_RUN overrides).

    Returns the run_cmd record, or ``{'wall_sec': nan}`` when a cached output is reused.
    """
    if FORCE_RUN or not output.exists():
        _, rec = runner(sample)
        return rec
    return {'wall_sec': np.nan}


def _log10_ne_on_grid(xs, ys, grid):
    """log10 of the step curve (xs, ys) at every grid point, floored at 1e-12 first."""
    return np.log10(np.asarray([max(step_value(xs, ys, g), 1e-12) for g in grid]))


for s in SAMPLES:
    print(f'\n[run] {s.sid}')

    r_json, c_psmc = rust_out_json(s), c_out_psmc(s)
    # Rust first, then C; each only when forced or missing.
    rust_rec = _run_if_needed(r_json, run_rust, s)
    c_rec = _run_if_needed(c_psmc, run_c, s)

    xr, yr = curve_from_json(r_json, mu=s.mu, gen_years=s.gen_years)
    xc, yc = load_c_curve(c_psmc, mu=s.mu, gen_years=s.gen_years)
    curves[s.sid] = {'x_rust': xr, 'y_rust': yr, 'x_c': xc, 'y_c': yc}

    # Compare both curves on a shared log-spaced grid clipped to [1e3, 1e8] years.
    grid_lo = max(1e3, min(xr.min(), xc.min()))
    grid_hi = min(1e8, max(xr.max(), xc.max()))
    xg = np.geomspace(grid_lo, grid_hi, 600)
    log_gap = _log10_ne_on_grid(xr, yr, xg) - _log10_ne_on_grid(xc, yc, xg)
    rmse = float(np.sqrt(np.mean(log_gap ** 2)))

    records.append({
        'sample_id': s.sid,
        'species': s.species,
        'group': s.group,
        'mu': s.mu,
        'gen_years': s.gen_years,
        'rust_json': str(r_json),
        'c_psmc': str(c_psmc),
        'rmse_log10_ne_rust_vs_c': rmse,
        'rust_wall_sec': rust_rec['wall_sec'],
        'c_wall_sec': c_rec['wall_sec'],
    })

summary = (
    pd.DataFrame(records)
    .sort_values('sample_id')
    .reset_index(drop=True)
)
save_table_multi(summary, 'real_data_rust_vs_c_summary')
display(summary)


In [None]:
def style_axis(ax, yfmt=True):
    """Apply the shared axis look to *ax*.

    Sets a log x-scale, shows only the major grid and hides the top/right spines.
    When *yfmt* is true, y tick labels are integers with thousands separators.
    """
    ax.set_xscale('log')
    ax.grid(True, which='major')
    ax.grid(False, which='minor')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    if not yfmt:
        return
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:,.0f}'))


# 2x2 grid: one panel per sample overlaying the C and Rust Ne step curves.
# NOTE(review): assumes exactly four samples; a fifth would index past `axes`.
fig, axes = plt.subplots(2, 2, figsize=(14, 9.2), dpi=250, constrained_layout=True)
axes = axes.flatten()

for i, s in enumerate(SAMPLES):
    ax = axes[i]
    c = curves[s.sid]

    # C is drawn first (zorder 2) so Rust (zorder 3) sits on top where they overlap;
    # this order also fixes the legend order (C, then Rust).
    ax.step(c['x_c'], c['y_c'], where='post', color=COL['c'], lw=2.0, label='C', zorder=2)
    ax.step(c['x_rust'], c['y_rust'], where='post', color=COL['rust'], lw=2.0, label='Rust', zorder=3)

    style_axis(ax)
    ax.set_xlim(1e3, 1e8)
    # 12% headroom above the larger of the two curve maxima.
    ymax = max(float(np.max(c['y_c'])), float(np.max(c['y_rust'])))
    ax.set_ylim(0, ymax * 1.12)

    ax.set_title(s.species, fontsize=12.7, pad=7)
    # Left column gets the y label, bottom row the x label.
    if i % 2 == 0:
        ax.set_ylabel('Effective population size (Ne)')
    if i >= 2:
        ax.set_xlabel('Years')

    # Annotation box: the RMSE from `summary` plus the scaling constants used.
    rmse = summary.loc[summary.sample_id == s.sid, 'rmse_log10_ne_rust_vs_c'].iloc[0]
    ax.text(
        0.985, 0.965,
        f'RMSE(log10 Ne)={rmse:.3f}\nμ={s.mu:.2e}, g={s.gen_years:.0f}',
        transform=ax.transAxes,
        ha='right', va='top', fontsize=8.4, color='#425466',
        bbox={'boxstyle': 'round,pad=0.22', 'fc': 'white', 'ec': '#D5DFEA', 'lw': 0.75, 'alpha': 0.95},
    )

    # Panel letter (A, B, C, ...) in the upper-left corner.
    ax.text(0.012, 0.985, chr(ord('A') + i), transform=ax.transAxes,
            ha='left', va='top', fontsize=12, color='#44576B', fontweight='semibold')

# One shared legend, built from the first panel's handles.
h, l = axes[0].get_legend_handles_labels()
fig.legend(h, l, loc='upper center', ncol=2, bbox_to_anchor=(0.5, 1.01))
fig.suptitle('Real Data: PSMC-RS vs Original C PSMC', fontsize=17, y=1.04, fontweight='semibold')

save_figure_multi(fig, 'real_data_rust_vs_c')
plt.show()


In [None]:
# 2x2 grid of difference curves: log10(Ne_Rust) - log10(Ne_C) per sample.
fig, axes = plt.subplots(2, 2, figsize=(14, 7.8), dpi=250, constrained_layout=True)
axes = axes.flatten()

for i, s in enumerate(SAMPLES):
    ax = axes[i]
    c = curves[s.sid]

    # Shared log-spaced grid clipped to the plotted [1e3, 1e8] window.
    x_min = max(1e3, min(float(c['x_rust'].min()), float(c['x_c'].min())))
    x_max = min(1e8, max(float(c['x_rust'].max()), float(c['x_c'].max())))
    xg = np.geomspace(x_min, x_max, 650)

    # Floor at 1e-12 so log10 never sees zero.
    lr = np.log10(np.asarray([max(step_value(c['x_rust'], c['y_rust'], x), 1e-12) for x in xg]))
    lc = np.log10(np.asarray([max(step_value(c['x_c'], c['y_c'], x), 1e-12) for x in xg]))
    d = lr - lc

    ax.axhline(0.0, color='#64748B', lw=1.0, ls=(0, (4, 2)))
    ax.plot(xg, d, color=COL['delta'], lw=1.8)
    # interpolate=True closes the shaded areas exactly at zero crossings;
    # without it, fill_between leaves unshaded wedges wherever d changes sign.
    ax.fill_between(xg, 0, d, where=d >= 0, interpolate=True,
                    color=COL['delta'], alpha=0.10, linewidth=0)
    ax.fill_between(xg, 0, d, where=d < 0, interpolate=True,
                    color=COL['delta'], alpha=0.08, linewidth=0)

    style_axis(ax, yfmt=False)
    ax.set_xlim(1e3, 1e8)
    # Symmetric y-limits: 22% above the largest |delta|, clamped to [0.04, 1.0].
    lim = max(0.04, min(1.0, float(np.max(np.abs(d))) * 1.22))
    ax.set_ylim(-lim, lim)
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:+.2f}'))

    ax.set_title(s.species, fontsize=12.2, pad=6)
    if i % 2 == 0:
        ax.set_ylabel('Δ log10(Ne)  (Rust - C)')
    if i >= 2:
        ax.set_xlabel('Years')

    ax.text(0.012, 0.985, chr(ord('A') + i), transform=ax.transAxes,
            ha='left', va='top', fontsize=12, color='#44576B', fontweight='semibold')

fig.suptitle('Real Data: Difference Curve (Rust - C)', fontsize=16, y=1.03, fontweight='semibold')
save_figure_multi(fig, 'real_data_rust_minus_c_delta')
plt.show()


In [None]:
# Where everything was written, plus how to force fresh inference runs.
print('\n'.join([
    'Notebook outputs:',
    f' - Figures: {FIG_DIR}',
    f' - Tables : {TAB_DIR}',
    f" - Logs   : {LOG_DIR / 'commands.jsonl'}",
    '\nIf you want to re-run all inference from scratch in this notebook:',
    '  set FORCE_RUN=1 in environment before launching Jupyter, or edit FORCE_RUN=True in config cell.',
]))
