In [None]:
from __future__ import annotations

import json
import math
import re
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import FuncFormatter


def find_repo_root(start: Path) -> Path:
    """Walk upward from *start* to the first directory that looks like the repo root.

    A directory qualifies when it contains both a ``Cargo.toml`` entry and a
    ``src`` entry.  *start* itself is checked first, then each of its parents
    from nearest to farthest.

    Raises:
        RuntimeError: if neither *start* nor any parent qualifies.
    """
    candidates = (start, *start.parents)
    root = next(
        (d for d in candidates if (d / 'Cargo.toml').exists() and (d / 'src').exists()),
        None,
    )
    if root is None:
        raise RuntimeError(f'Cannot locate psmc-rs root from {start}')
    return root


# Locate the repository from the notebook's working directory and derive the
# three run folders that the figure cells read from and write to.
ROOT = find_repo_root(Path.cwd().resolve())
RUNS = ROOT.joinpath('experiment', 'runs')
MAIN, SUPP, REAL = (RUNS / name for name in ('main_text', 'supplementary', 'real_data'))

# Make sure every output folder exists before any figure or table is written.
for out_dir in (
    MAIN / 'figures',
    MAIN / 'tables',
    SUPP / 'figures',
    REAL / 'figures',
    REAL / 'tables',
):
    out_dir.mkdir(parents=True, exist_ok=True)


def save_figure_multi(fig, stem: str, out_dir: Path):
    """Write *fig* into *out_dir* as ``<stem>.png``, ``<stem>.svg`` and ``<stem>.pdf``.

    The PNG is saved with ``dpi=340``; the SVG and PDF are saved with
    ``dpi=None``.  All three use a tight bounding box with ``pad_inches=0.03``.
    The written paths are printed on one line prefixed with ``saved figure:``.
    """
    dpi_by_ext = {'png': 340, 'svg': None, 'pdf': None}
    written = []
    for ext, dpi in dpi_by_ext.items():
        target = out_dir / f'{stem}.{ext}'
        fig.savefig(target, dpi=dpi, bbox_inches='tight', pad_inches=0.03)
        written.append(str(target))
    print('saved figure:', ', '.join(written))


# Shared palette: one colour per plotted series ('true', 'rust', 'c') plus the
# neutral colours used for grid lines, axes and text.
COL = dict(
    true='#2E6FDD',
    rust='#D33F3F',
    c='#2E7D32',
    grid='#E4E8EF',
    axis='#445268',
    text='#1F2937',
)

# Global matplotlib defaults, grouped by the part of the figure they affect.
for _rc_group in (
    # Canvas / background
    {
        'figure.facecolor': 'white',
        'savefig.facecolor': 'white',
        'axes.facecolor': 'white',
    },
    # Axes frame, labels and titles
    {
        'axes.edgecolor': COL['axis'],
        'axes.linewidth': 0.9,
        'axes.labelcolor': COL['axis'],
        'axes.labelsize': 10.5,
        'axes.titlesize': 12.0,
    },
    # Ticks
    {
        'xtick.color': COL['axis'],
        'ytick.color': COL['axis'],
        'xtick.labelsize': 9.4,
        'ytick.labelsize': 9.4,
        'xtick.major.size': 3.8,
        'ytick.major.size': 3.8,
        'xtick.major.width': 0.9,
        'ytick.major.width': 0.9,
    },
    # Grid and default line width
    {
        'grid.color': COL['grid'],
        'grid.linewidth': 0.8,
        'grid.linestyle': '-',
        'grid.alpha': 0.95,
        'lines.linewidth': 2.0,
    },
    # Fonts
    {
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
        'font.size': 10,
    },
    # Legends
    {
        'legend.frameon': False,
        'legend.borderpad': 0.2,
        'legend.handlelength': 1.8,
        'legend.handletextpad': 0.5,
        'legend.columnspacing': 1.2,
    },
    # Font handling in vector output (see the matplotlib rcParams reference
    # for the meaning of fonttype 42 and svg.fonttype 'none').
    {
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        'svg.fonttype': 'none',
    },
):
    plt.rcParams.update(_rc_group)


def stylize(ax, *, xlog: bool = False, yfmt: bool = False):
    """Apply the shared axis look to *ax* in place.

    Always turns on the major grid for both axes, hides the top and right
    spines, and points the ticks outward.

    Args:
        ax: Axes object to restyle.
        xlog: When true, switch the x axis to a log scale first.
        yfmt: When true, render y tick labels as comma-grouped values with no
            decimals (format spec ``,.0f``).
    """
    def _thousands(value, _pos):
        return f'{value:,.0f}'

    if xlog:
        ax.set_xscale('log')
    ax.grid(True, which='major', axis='both')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.tick_params(direction='out')
    if yfmt:
        ax.yaxis.set_major_formatter(FuncFormatter(_thousands))


def panel_tag(ax, text: str):
    """Draw *text* as a bold panel label near the top-left corner of *ax*.

    The label is positioned in axes-fraction coordinates at (0.015, 0.985),
    anchored by its top-left corner, and drawn with ``zorder=30``.
    """
    label_style = dict(
        transform=ax.transAxes,
        ha='left',
        va='top',
        fontsize=14,
        fontweight='bold',
        color='#2D3B52',
        zorder=30,
    )
    ax.text(0.015, 0.985, text, **label_style)


def fig_legend(fig, handles, labels, *, ncol=3, y=0.985):
    """Add a frameless, horizontally centred figure-level legend and return it.

    Args:
        fig: Figure that receives the legend.
        handles: Legend handles, passed through to ``fig.legend``.
        labels: Legend labels, passed through to ``fig.legend``.
        ncol: Number of legend columns.
        y: Vertical anchor used in ``bbox_to_anchor=(0.5, y)``.
    """
    options = {
        'loc': 'upper center',
        'ncol': ncol,
        'bbox_to_anchor': (0.5, y),
        'frameon': False,
        'borderaxespad': 0.0,
    }
    legend = fig.legend(handles, labels, **options)
    return legend


def parse_pattern_spec(pattern):
    """Parse a ``+``-separated pattern into ``(repeat, group_len)`` pairs.

    Each non-empty term is either ``"a*b"``, giving ``(a, b)``, or a bare
    integer ``"b"``, giving ``(1, b)``.  Whitespace around terms and around
    ``*`` is ignored, and empty terms are skipped.

    Returns:
        The list of pairs, or ``None`` when *pattern* is ``None`` or contains
        no non-empty term.

    Raises:
        ValueError: if a term is not made of integers.
    """
    if pattern is None:
        return None
    spec = []
    for term in (piece.strip() for piece in str(pattern).split('+')):
        if not term:
            continue
        repeat, star, length = term.partition('*')
        if star:
            spec.append((int(repeat.strip()), int(length.strip())))
        else:
            spec.append((1, int(term)))
    return spec or None


def parse_pattern_spec_legacy(pattern):
    """Parse a ``+``-separated pattern using the legacy term convention.

    Each non-empty term is either ``"a*b"``, giving ``(a, b)``, or a bare
    integer ``"a"``, giving ``(a, 1)``.  Whitespace around terms and around
    ``*`` is ignored, and empty terms are skipped.

    Returns:
        The list of pairs, or ``None`` when *pattern* is ``None`` or contains
        no non-empty term.

    Raises:
        ValueError: if a term is not made of integers.
    """
    if pattern is None:
        return None
    spec = []
    for term in (piece.strip() for piece in str(pattern).split('+')):
        if not term:
            continue
        count, star, size = term.partition('*')
        if star:
            spec.append((int(count.strip()), int(size.strip())))
        else:
            spec.append((int(term), 1))
    return spec or None


def expand_lam(lam_grouped, n_steps, pattern_spec, pattern_raw=None):
    """Expand grouped lambda values into one value per time interval.

    Args:
        lam_grouped: Grouped lambda values; each is converted to ``float``.
        n_steps: The result must contain exactly ``n_steps + 1`` values.
        pattern_spec: ``(repeat, group_len)`` pairs, or ``None`` when
            *lam_grouped* is already fully expanded.
        pattern_raw: Original pattern string; only consulted for the legacy
            fallback when *lam_grouped* does not match *pattern_spec*.

    Returns:
        A list of ``n_steps + 1`` floats.

    Raises:
        ValueError: if the input length matches neither layout, or the
            expanded length is not ``n_steps + 1``.
    """
    values = [float(v) for v in lam_grouped]

    def _spread(runs):
        # One width per grouped value: each (count, width) pair contributes
        # `count` consecutive values, each repeated `width` times.
        widths = [width for count, width in runs for _ in range(count)]
        expanded = []
        for pos, width in enumerate(widths):
            expanded.extend([values[pos]] * width)
        return expanded

    # No pattern: the values must already be one-per-interval.
    if pattern_spec is None:
        if len(values) != n_steps + 1:
            raise ValueError(f'lam length {len(values)} != n_steps+1 ({n_steps+1})')
        return values

    # Current layout: one grouped value per repeat in the pattern.
    if len(values) == sum(nr for nr, _ in pattern_spec):
        lam = _spread(pattern_spec)
        if len(lam) != n_steps + 1:
            raise ValueError(f'expanded lam length {len(lam)} != n_steps+1 ({n_steps+1})')
        return lam

    # Legacy layout: one value per legacy repeat plus a trailing value that is
    # appended unexpanded.
    legacy = parse_pattern_spec_legacy(pattern_raw)
    if legacy is not None and len(values) == sum(ts for ts, _ in legacy) + 1:
        lam = _spread(legacy)
        lam.append(values[-1])
        if len(lam) != n_steps + 1:
            raise ValueError(f'expanded legacy lam length {len(lam)} != n_steps+1 ({n_steps+1})')
        return lam

    raise ValueError('grouped lam length mismatch with pattern')


def compute_t_grid(n_steps: int, t_max: float, alpha: float = 0.1):
    """Return the ``n_steps + 1`` interval boundaries of the time grid.

    Boundary ``k`` (for ``k < n_steps``) is ``alpha * (exp(beta * k) - 1)``
    with ``beta = log(1 + t_max / alpha) / n_steps``.  The final boundary is
    set to ``t_max`` directly rather than computed from the formula.

    Returns:
        A float ndarray of length ``n_steps + 1``.
    """
    beta = math.log(1 + t_max / alpha) / n_steps
    inner = np.fromiter(
        (alpha * (math.exp(beta * k) - 1.0) for k in range(n_steps)),
        dtype=float,
    )
    return np.append(inner, float(t_max))


def curve_from_json(path: Path, mu: float, gen_years: float, bin_size: int = 100):
    """Build a step curve ``(x, y)`` from a JSON parameter file.

    The file must provide ``theta``, ``n_steps``, ``t_max`` and ``lam``; an
    optional ``pattern`` describes how ``lam`` is grouped.  With
    ``n0 = theta / (4 * mu * bin_size)`` the curve is::

        x = t_grid * 2 * gen_years * n0
        y = lam * n0

    A closing point at ``x = 1e8`` repeating the last ``y`` is appended so the
    final step extends to the right edge of the plots.

    Returns:
        Two float ndarrays ``(x, y)`` of equal length.
    """
    cfg = json.loads(path.read_text())
    theta = float(cfg['theta'])
    n_steps = int(cfg['n_steps'])
    t_max = float(cfg['t_max'])
    raw_pattern = cfg.get('pattern')
    lam_full = expand_lam(cfg['lam'], n_steps, parse_pattern_spec(raw_pattern), raw_pattern)
    lam = np.asarray(lam_full, dtype=float)

    grid = compute_t_grid(n_steps, t_max)
    n0 = theta / (4.0 * float(mu) * float(bin_size))
    x_vals = grid * 2.0 * float(gen_years) * n0
    y_vals = lam * n0

    # Closing point: hold the last level out to 1e8 on the x axis.
    x_vals = np.append(x_vals, 1e8)
    y_vals = np.append(y_vals, y_vals[-1])
    return np.asarray(x_vals, dtype=float), np.asarray(y_vals, dtype=float)


def load_c_curve(psmc_path: Path, mu: float, gen_years: float, bin_size: int = 100):
    """Build a step curve ``(x, y)`` from the last complete record of a ``.psmc`` file.

    The file is read as tab-separated lines.  An ``RD`` line opens a new
    record; within a record, ``TR`` supplies two floats (the first is used as
    theta), ``PA`` is kept verbatim, and each ``RS`` line contributes
    ``(int(field1), float(field2), float(field3))``.  Lines before the first
    ``RD`` are ignored.  The last record that has a ``PA`` line, a ``TR`` line
    and at least one ``RS`` row is used.

    With ``n0 = theta / (4 * mu * bin_size)``, each RS row ``(_, t, lam)``
    becomes ``x = 2 * n0 * t * gen_years`` and ``y = n0 * lam``; a closing
    point at ``x = 1e8`` repeating the last ``y`` is appended.

    Raises:
        ValueError: if no record is complete.
    """
    records = []
    for line in psmc_path.read_text().splitlines():
        fields = line.split('\t')
        if len(fields) < 2:
            # No tab at all, so none of the "<TAG>\t" prefixes can match.
            continue
        tag = fields[0]
        if tag == 'RD':
            records.append({'tr': None, 'pa': None, 'rs': []})
            continue
        if not records:
            continue
        current = records[-1]
        if tag == 'TR':
            _, theta_txt, rho_txt = fields[:3]
            current['tr'] = (float(theta_txt), float(rho_txt))
        elif tag == 'PA':
            current['pa'] = line
        elif tag == 'RS':
            current['rs'].append((int(fields[1]), float(fields[2]), float(fields[3])))

    # Search newest-first so a truncated trailing record falls back to the
    # previous complete one.
    best = next(
        (rec for rec in reversed(records) if rec['pa'] and rec['tr'] is not None and rec['rs']),
        None,
    )
    if best is None:
        raise ValueError(f'cannot parse valid block in {psmc_path}')

    theta = best['tr'][0]
    n0 = theta / (4.0 * float(mu) * float(bin_size))
    xs = [2.0 * n0 * tk * float(gen_years) for _, tk, _ in best['rs']]
    ys = [n0 * lk for _, _, lk in best['rs']]
    xs.append(1e8)
    ys.append(ys[-1])
    return np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)


def step_value(xs: np.ndarray, ys: np.ndarray, xq: float) -> float:
    """Evaluate a right-continuous step function at *xq*.

    Returns the ``ys`` entry paired with the last ``xs`` entry that is
    ``<= xq``.  Queries left of ``xs[0]`` return ``ys[0]``; queries right of
    the last breakpoint return ``ys[-1]``.  ``xs`` is expected to be sorted
    ascending, as ``np.searchsorted`` requires.
    """
    last = len(ys) - 1
    pos = int(np.searchsorted(xs, xq, side='right')) - 1
    if pos > last:
        pos = last
    if pos < 0:
        pos = 0
    return float(ys[pos])


def rmse_log10(curve_a, curve_b, x_min=1e3, x_max=1e8, n=500):
    """Root-mean-square difference of two step curves in log10 space.

    Both curves (``(xs, ys)`` pairs) are sampled with ``step_value`` on *n*
    geometrically spaced points between *x_min* and *x_max*.  Sampled values
    are floored at ``1e-12`` before taking ``log10``.

    Returns:
        The RMSE of ``log10(a) - log10(b)`` over the sample grid, as a float.
    """
    xa, ya = curve_a
    xb, yb = curve_b
    grid = np.geomspace(x_min, x_max, n)

    def _log_samples(xs, ys):
        floored = [max(step_value(xs, ys, x), 1e-12) for x in grid]
        return np.log10(np.asarray(floored, dtype=float))

    delta = _log_samples(xa, ya) - _log_samples(xb, yb)
    return float(np.sqrt(np.mean(delta ** 2)))


def _aggregate_perf_repeats(raw: pd.DataFrame) -> pd.DataFrame:
    """Reduce per-repeat rows to mean/std of wall time and peak RSS per (model, tool)."""
    return (
        raw.groupby(['model', 'tool'], as_index=False)
        .agg(
            wall_sec_mean=('wall_sec', 'mean'),
            wall_sec_std=('wall_sec', 'std'),
            peak_rss_mb_mean=('peak_rss_mb', 'mean'),
            peak_rss_mb_std=('peak_rss_mb', 'std'),
        )
        .sort_values(['model', 'tool'])
        .reset_index(drop=True)
    )


def _float_or_nan(value) -> float:
    """Coerce *value* to float; missing or unparseable values become NaN."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return float('nan')


def load_main_perf_summary() -> pd.DataFrame:
    """Load the runtime/memory summary for the main-text benchmarks.

    Sources are tried in order and the first usable one wins:

    1. ``tables/table_2_runtime_memory.csv`` - an already aggregated table.
    2. ``tables/perf_raw_repeats.csv`` - per-repeat rows, aggregated here.
    3. ``logs/commands.jsonl`` - one JSON record per line; records whose
       ``cmd`` mentions ``/perf/<model>.<rust|c>.rep<N>.<json|psmc>`` are
       collected, de-duplicated per (model, tool, rep) keeping the last
       occurrence, then aggregated.

    A CSV source is usable when it has the required columns and its value
    columns are not entirely NaN.  A CSV that cannot be read or processed is
    skipped with a warning instead of aborting the fallback chain.

    Returns:
        A DataFrame sorted by ``model`` and ``tool`` with columns
        ``wall_sec_mean``, ``wall_sec_std``, ``peak_rss_mb_mean`` and
        ``peak_rss_mb_std`` (plus whatever else the summary CSV carries).

    Raises:
        FileNotFoundError: if none of the sources yields any data.
    """
    import warnings  # local import keeps this block free of new file-level deps

    tables_dir = MAIN / 'tables'

    # 1) Pre-aggregated summary table.
    p_summary = tables_dir / 'table_2_runtime_memory.csv'
    if p_summary.exists():
        try:
            d = pd.read_csv(p_summary)
            needed = {'model', 'tool', 'wall_sec_mean', 'wall_sec_std', 'peak_rss_mb_mean', 'peak_rss_mb_std'}
            if needed.issubset(d.columns) and not d[['wall_sec_mean', 'peak_rss_mb_mean']].isna().all().all():
                return d.sort_values(['model', 'tool']).reset_index(drop=True)
        except Exception as exc:  # best-effort source: report it, then fall through
            warnings.warn(f'skipping {p_summary}: {exc!r}', stacklevel=2)

    # 2) Raw per-repeat table.
    p_raw = tables_dir / 'perf_raw_repeats.csv'
    if p_raw.exists():
        try:
            raw = pd.read_csv(p_raw)
            if {'model', 'tool', 'wall_sec', 'peak_rss_mb'}.issubset(raw.columns) and not raw[['wall_sec', 'peak_rss_mb']].isna().all().all():
                return _aggregate_perf_repeats(raw)
        except Exception as exc:  # best-effort source: report it, then fall through
            warnings.warn(f'skipping {p_raw}: {exc!r}', stacklevel=2)

    # 3) Command log.
    log = MAIN / 'logs' / 'commands.jsonl'
    rows = []
    if log.exists():
        pat = re.compile(r"/perf/(?P<model>[^/]+)\.(?P<tool>rust|c)\.rep(?P<rep>\d+)\.(?:json|psmc)")
        for line in log.read_text(encoding='utf-8').splitlines():
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if not isinstance(rec, dict):
                # Valid JSON that is not an object (list, number, ...) has no 'cmd'.
                continue
            m = pat.search(str(rec.get('cmd', '')))
            if not m:
                continue
            rows.append({
                'model': m.group('model'),
                'tool': m.group('tool'),
                'rep': int(m.group('rep')),
                # A JSON null or non-numeric metric becomes NaN rather than
                # raising from float().
                'wall_sec': _float_or_nan(rec.get('wall_sec')),
                'peak_rss_mb': _float_or_nan(rec.get('peak_rss_mb')),
            })

    if not rows:
        raise FileNotFoundError('No usable runtime/memory data found in main_text tables or logs.')

    raw = pd.DataFrame(rows).drop_duplicates(['model', 'tool', 'rep'], keep='last')
    return _aggregate_perf_repeats(raw)


# Echo the resolved locations so a misdetected repository root shows up in the cell output.
for _label, _location in (('ROOT', ROOT), ('MAIN', MAIN), ('SUPP', SUPP), ('REAL', REAL)):
    print(f'{_label} =', _location)


In [None]:
# Figure 2: Accuracy on simulated data (500 Mb)

# Every simulated scenario uses the same mutation rate and generation time.
_SCALING = {'mu': 2.5e-8, 'gen_years': 25.0}

# Scenario table.  'true_params' holds the keyword arguments for the matching
# true-curve builder; piecewise 'events' are (time, size ratio) pairs.
MODELS = {
    'constant': {
        **_SCALING,
        'true_kind': 'constant',
        'true_params': {'ne': 10000.0},
        'title': 'Constant',
    },
    'bottleneck': {
        **_SCALING,
        'true_kind': 'piecewise',
        'true_params': {
            'ne0': 20000.0,
            'events': [(0.01, 0.05), (0.015, 0.5), (0.05, 0.25), (0.5, 0.5)],
        },
        'title': 'Bottleneck',
    },
    'expansion': {
        **_SCALING,
        'true_kind': 'piecewise',
        'true_params': {
            'ne0': 10000.0,
            'events': [(0.01, 0.1), (0.06, 1.0), (0.2, 0.5), (1.0, 1.0), (2.0, 2.0)],
        },
        'title': 'Expansion',
    },
    'zigzag': {
        **_SCALING,
        'true_kind': 'piecewise',
        'true_params': {
            'ne0': 1000.0,
            'events': [(0.1, 5.0), (0.6, 20.0), (2.0, 5.0), (10.0, 10.0), (20.0, 5.0)],
        },
        'title': 'Zigzag',
    },
}

# Panel order used by the figures and tables.
MODEL_ORDER = ['constant', 'bottleneck', 'expansion', 'zigzag']


def true_curve_constant(ne: float):
    """Return a flat two-point curve at level *ne* spanning x = 1e3 to 1e8."""
    xs = np.array((1e3, 1e8), dtype=float)
    ys = np.full(2, ne, dtype=float)
    return xs, ys


def true_curve_piecewise(ne0: float, events: List[Tuple[float, float]], gen_years: float):
    """Build the step curve of a piecewise-constant size history.

    Args:
        ne0: Size at the left edge of the curve (x = 1e3).
        events: ``(t, ratio)`` pairs, processed in ascending ``t``.  Each one
            places a breakpoint at ``x = t * 4 * ne0 * gen_years`` (never left
            of 1e3) where the level becomes ``ratio * ne0``.
        gen_years: Multiplier applied to the breakpoint positions.

    Returns:
        ``(xs, ys)`` float ndarrays, closed by a point at ``x = 1e8`` that
        repeats the last level.
    """
    ordered = sorted(events, key=lambda event: event[0])
    breakpoints = [max(1e3, t_4n0 * 4.0 * ne0 * gen_years) for t_4n0, _ in ordered]
    levels = [ratio * ne0 for _, ratio in ordered]

    xs = [1e3, *breakpoints, 1e8]
    ys = [ne0, *levels]
    ys.append(ys[-1])
    return np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)


def true_curve(model_key: str):
    """Return the ``(xs, ys)`` true-history curve for the scenario *model_key*.

    Scenarios whose ``true_kind`` is ``'constant'`` are built by
    ``true_curve_constant``; all others by ``true_curve_piecewise``, which
    additionally receives the scenario's ``gen_years``.
    """
    spec = MODELS[model_key]
    kind = spec['true_kind']
    if kind != 'constant':
        return true_curve_piecewise(gen_years=spec['gen_years'], **spec['true_params'])
    return true_curve_constant(**spec['true_params'])


# Load the true, PSMC-RS and C curves for every scenario and score both
# implementations against the truth.
outputs_dir = MAIN / 'outputs'
rows = []
curves = {}
for k in MODEL_ORDER:
    rust_json = outputs_dir / f'{k}.rust.main.json'
    c_psmc = outputs_dir / f'{k}.c.main.psmc'
    if not (rust_json.exists() and c_psmc.exists()):
        raise FileNotFoundError(f'missing main output for {k}: {rust_json} / {c_psmc}')

    scaling = {'mu': MODELS[k]['mu'], 'gen_years': MODELS[k]['gen_years']}
    entry = {
        'true': true_curve(k),
        'rust': curve_from_json(rust_json, **scaling),
        'c': load_c_curve(c_psmc, **scaling),
    }
    curves[k] = entry
    rows.append({
        'model': k,
        'rmse_log10_ne_rust': rmse_log10(entry['true'], entry['rust']),
        'rmse_log10_ne_c': rmse_log10(entry['true'], entry['c']),
    })


# Table 1: one RMSE row per scenario.
table1 = pd.DataFrame(rows)
table1.to_csv(MAIN / 'tables' / 'table_1_rmse_notitle.csv', index=False)
print(table1)

fig, axes = plt.subplots(2, 2, figsize=(13.0, 8.8), dpi=280)
fig.subplots_adjust(left=0.07, right=0.99, bottom=0.08, top=0.92, wspace=0.20, hspace=0.26)
axes = axes.flatten()

# Draw order (and therefore legend order): truth first, then the C baseline,
# then PSMC-RS on top.
series_styles = (
    ('true', dict(color=COL['true'], lw=2.2, ls=(0, (5, 3)), label='True history', zorder=1)),
    ('c', dict(color=COL['c'], lw=2.6, alpha=0.85, label='C baseline', zorder=2)),
    ('rust', dict(color=COL['rust'], lw=1.9, label='PSMC-RS', zorder=3)),
)

for i, (ax, k) in enumerate(zip(axes, MODEL_ORDER)):
    panel_curves = curves[k]
    for name, style in series_styles:
        sx, sy = panel_curves[name]
        ax.step(sx, sy, where='post', **style)

    stylize(ax, xlog=True, yfmt=True)
    ax.set_xlim(1e3, 1e8)
    # Leave 12% headroom above the tallest of the three curves.
    ymax = max(np.max(panel_curves[name][1]) for name in ('true', 'rust', 'c'))
    ax.set_ylim(0, ymax * 1.12)

    ax.set_title(MODELS[k]['title'], pad=5)
    # Only the left column gets a y label and only the bottom row an x label.
    if i % 2 == 0:
        ax.set_ylabel('Effective population size (Ne)')
    if i >= 2:
        ax.set_xlabel('Years (g=25, μ=2.5e-08)')
    panel_tag(ax, 'ABCD'[i])

# All panels carry the same series, so the first panel's handles serve the whole figure.
handles, labels = axes[0].get_legend_handles_labels()
fig_legend(fig, handles, labels, ncol=3, y=0.985)

save_figure_multi(fig, 'figure_2_simulated_equivalence', MAIN / 'figures')
plt.show()


In [None]:
# Figure 3: Empirical genomes (PSMC-RS vs C baseline)

# Per-sample fit summary. The columns read below are sample_id, mu, gen_years,
# rust_json and c_psmc.
fit = pd.read_csv(REAL / 'tables' / 'real_data_rust_vs_c_fit_summary.csv')

# Panel order (A-D) and the display name shown as each panel's title.
sample_order = ['HomoSapiens', 'HLemySub1', 'HLhydTec1', 'HLpelCas1']
label_map = {
    'HomoSapiens': 'Homo sapiens (Papuan Highlands)',
    'HLemySub1': 'Emydura subglobosa',
    'HLhydTec1': 'Hydromedusa tectifera',
    'HLpelCas1': 'Pelusios castaneus',
}

# Keep only the plotted samples and sort them into panel order.
fit = fit[fit['sample_id'].isin(sample_order)].copy()
fit['sample_id'] = pd.Categorical(fit['sample_id'], categories=sample_order, ordered=True)
fit = fit.sort_values('sample_id').reset_index(drop=True)

# `rows` collects one RMSE record per sample for the CSV written after the loop.
rows = []
fig, axes = plt.subplots(2, 2, figsize=(13.0, 8.8), dpi=280)
fig.subplots_adjust(left=0.07, right=0.99, bottom=0.08, top=0.92, wspace=0.20, hspace=0.26)
axes = axes.flatten()

for i, sid in enumerate(sample_order):
    # Use the first summary row for this sample; a missing sample is an error.
    row = fit[fit['sample_id'] == sid]
    if row.empty:
        raise ValueError(f'missing sample in fit summary: {sid}')
    row = row.iloc[0]

    mu = float(row['mu'])
    gen_years = float(row['gen_years'])
    # NOTE(review): the paths are used exactly as stored in the CSV, so relative
    # entries resolve against the current working directory, not ROOT - confirm.
    rust_json = Path(row['rust_json'])
    c_psmc = Path(row['c_psmc'])

    # Build both curves with the sample's own scaling, then score PSMC-RS
    # against the C baseline.
    xr, yr = curve_from_json(rust_json, mu=mu, gen_years=gen_years)
    xc, yc = load_c_curve(c_psmc, mu=mu, gen_years=gen_years)
    rmse = rmse_log10((xc, yc), (xr, yr))
    rows.append({'sample_id': sid, 'species': label_map[sid], 'rmse_log10_ne_rs_vs_c': rmse})

    # C baseline underneath (thicker, slightly transparent), PSMC-RS on top.
    ax = axes[i]
    ax.step(xc, yc, where='post', color=COL['c'], lw=2.6, alpha=0.86, label='C baseline', zorder=2)
    ax.step(xr, yr, where='post', color=COL['rust'], lw=1.9, label='PSMC-RS', zorder=3)

    stylize(ax, xlog=True, yfmt=True)
    ax.set_xlim(1e3, 1e8)
    # Leave 12% headroom above the taller of the two curves.
    ax.set_ylim(0, max(np.max(yc), np.max(yr)) * 1.12)
    ax.set_title(label_map[sid], pad=5)

    # Only the left column gets a y label and only the bottom row an x label.
    if i % 2 == 0:
        ax.set_ylabel('Effective population size (Ne)')
    if i >= 2:
        ax.set_xlabel('Years')

    panel_tag(ax, 'ABCD'[i])

# Persist the per-sample RMSE values next to the other real-data tables.
rmse_emp = pd.DataFrame(rows)
rmse_emp.to_csv(REAL / 'tables' / 'figure_3_empirical_rmse_rs_vs_c.csv', index=False)
print(rmse_emp)

# All panels carry the same two series, so the first panel's handles serve the whole figure.
h, l = axes[0].get_legend_handles_labels()
fig_legend(fig, h, l, ncol=2, y=0.985)
save_figure_multi(fig, 'figure_3_empirical_genomes_rust_vs_c', REAL / 'figures')
plt.show()


In [None]:
# Figure 4: Computational performance and scalability

from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# Inputs: runtime/memory summary plus the two supplementary scaling tables.
perf = load_main_perf_summary()
s3 = pd.read_csv(SUPP / 'tables' / 'S3_memory_scaling_summary.csv')
s4 = pd.read_csv(SUPP / 'tables' / 'S4_thread_scaling_summary.csv')

order = ['constant', 'bottleneck', 'expansion', 'zigzag']
model_labels = [name.capitalize() for name in order]

# Save a clean Table 2 snapshot used by this figure, with models in panel
# order rather than alphabetical order.
table2 = perf.copy()
table2['model'] = pd.Categorical(table2['model'], categories=order, ordered=True)
table2 = table2.sort_values(['model', 'tool']).reset_index(drop=True)
table2.to_csv(MAIN / 'tables' / 'table_2_runtime_memory_notitle.csv', index=False)
print(table2)

fig, axes = plt.subplots(2, 2, figsize=(13.1, 8.9), dpi=280)
fig.subplots_adjust(left=0.07, right=0.995, bottom=0.08, top=0.91, wspace=0.23, hspace=0.30)

# Bar geometry shared by panels A and B: one slot per model, two bars per slot.
x = np.arange(len(order))
width = 0.34

def _grouped_bar_panel(ax, summary, metric, ylabel, tag, *, models, labels, bar_width):
    """Draw one PSMC-RS vs C grouped-bar panel with +/- 1 SD error bars.

    Args:
        ax: Axes to draw on.
        summary: Table with ``tool`` and ``model`` columns plus
            ``<metric>_mean`` and ``<metric>_std``.
        metric: Column prefix, e.g. ``'wall_sec'``.
        ylabel: Y-axis label.
        tag: Panel letter passed to ``panel_tag``.
        models: Model keys in x-axis order; rows are reindexed to this order.
        labels: Tick labels, one per entry in *models*.
        bar_width: Width of each bar; the two tools are offset by half of it.
    """
    positions = np.arange(len(models))
    for tool, color, offset in (
        ('rust', COL['rust'], -bar_width / 2),
        ('c', COL['c'], bar_width / 2),
    ):
        per_model = summary[summary.tool == tool].set_index('model').reindex(models)
        heights = per_model[f'{metric}_mean'].to_numpy(dtype=float)
        # A missing SD is drawn as a zero-length error bar.
        errors = per_model[f'{metric}_std'].fillna(0.0).to_numpy(dtype=float)
        ax.bar(
            positions + offset,
            heights,
            width=bar_width,
            color=color,
            alpha=0.92,
            yerr=errors,
            error_kw=dict(ecolor='#111827', lw=1.0, capsize=3.2, capthick=1.0),
            linewidth=0,
        )
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=14, ha='right')
    ax.set_ylabel(ylabel)
    stylize(ax)
    # Categorical x axis: keep only the horizontal grid lines.
    ax.grid(True, axis='y')
    ax.grid(False, axis='x')
    panel_tag(ax, tag)


# Panel A: runtime bars
_grouped_bar_panel(
    axes[0, 0], table2, 'wall_sec', 'Wall-clock time (s)', 'A',
    models=order, labels=model_labels, bar_width=width,
)

# Panel B: peak RSS bars
_grouped_bar_panel(
    axes[0, 1], table2, 'peak_rss_mb', 'Peak RSS (MB)', 'B',
    models=order, labels=model_labels, bar_width=width,
)

# Panel C: memory scaling (mean line with a +/- 1 SD ribbon per tool)
ax = axes[1, 0]
for tool, marker in (('rust', 'o'), ('c', 's')):
    color = COL[tool]
    g = s3[s3['tool'] == tool].sort_values('length_mb')
    xx = g['length_mb'].to_numpy(dtype=float)
    yy = g['peak_rss_mb_mean'].to_numpy(dtype=float)
    # A missing SD collapses the ribbon onto the line at that point.
    sd = g['peak_rss_mb_std'].fillna(0.0).to_numpy(dtype=float)
    lower, upper = yy - sd, yy + sd
    ax.fill_between(xx, lower, upper, color=color, alpha=0.08, linewidth=0)
    ax.plot(xx, yy, marker=marker, ms=4.8, lw=2.0, color=color)
ax.set_xlabel('Sequence length (Mb)')
ax.set_ylabel('Peak RSS (MB)')
stylize(ax)
panel_tag(ax, 'C')

# Panel D: thread scaling
ax = axes[1, 1]
ss = s4.sort_values('threads').copy()
threads = ss['threads'].to_numpy(dtype=float)
wall = ss['wall_sec_mean'].to_numpy(dtype=float)
wall_sd = ss['wall_sec_std'].fillna(0.0).to_numpy(dtype=float)

# The run with the fewest threads is the baseline for every speedup ratio.
base_idx = int(np.argmin(threads))
base_wall = float(wall[base_idx])
base_sd = float(wall_sd[base_idx])

# Guard the divisions against zero wall times.
safe_wall = np.maximum(wall, 1e-12)
speedup_obs = base_wall / safe_wall

# First-order error propagation for a ratio of two independent measurements:
# the relative SDs add in quadrature.
rel_base = base_sd / max(base_wall, 1e-12)
speedup_sd = speedup_obs * np.sqrt(rel_base**2 + (wall_sd / safe_wall)**2)
# The baseline point is a measurement divided by itself: its speedup is 1 by
# construction, so it carries no uncertainty. The independence assumption
# above would otherwise give it a spurious error bar.
speedup_sd[base_idx] = 0.0

ideal = threads / float(np.min(threads))

ax.errorbar(threads, speedup_obs, yerr=speedup_sd, color=COL['rust'], marker='o', lw=2.0, capsize=3.2)
ax.plot(threads, ideal, color=COL['true'], lw=1.9, ls=(0, (5, 3)))

ax.set_xlabel('Threads')
ax.set_ylabel('Speedup (x)')
stylize(ax)
panel_tag(ax, 'D')

# One compact global legend (no overlap with panels).
# The handles are built by hand rather than collected from the axes, since the
# panels above are drawn without labels.
legend_handles = [
    Patch(facecolor=COL['rust'], edgecolor='none', label='PSMC-RS'),
    Patch(facecolor=COL['c'], edgecolor='none', label='C baseline'),
    # Same colour, width and dash pattern as the ideal-speedup line in panel D.
    Line2D([0], [0], color=COL['true'], lw=1.9, ls=(0, (5, 3)), label='Ideal linear'),
    # Neutral grey swatch standing in for the per-tool SD ribbons.
    Patch(facecolor='#6B7280', alpha=0.10, edgecolor='none', label='±1 SD ribbon'),
]
fig.legend(
    handles=legend_handles,
    loc='upper center',
    bbox_to_anchor=(0.5, 0.985),
    ncol=4,
    frameon=False,
    handlelength=1.7,
    columnspacing=1.0,
    handletextpad=0.4,
)

# Write PNG/SVG/PDF copies into the main-text figures folder, then display.
save_figure_multi(fig, 'figure_4_performance_scalability', MAIN / 'figures')
plt.show()


In [None]:
# Supplementary Figure S1: EM convergence diagnostics

# EM trace table; the columns read below are model, tool, iter and loglike.
s2 = pd.read_csv(SUPP / 'tables' / 'S2_em_convergence_trace.csv')
models = ['constant', 'bottleneck', 'expansion', 'zigzag']
title_map = {name: name.capitalize() for name in models}

fig, axes = plt.subplots(2, 2, figsize=(13.0, 8.8), dpi=280)
fig.subplots_adjust(left=0.08, right=0.99, bottom=0.09, top=0.92, wspace=0.22, hspace=0.30)
axes = axes.flatten()

# C baseline is drawn first (thicker, no markers); PSMC-RS goes on top with markers.
trace_styles = (
    ('c', dict(color=COL['c'], lw=2.4, alpha=0.92, label='C baseline')),
    ('rust', dict(color=COL['rust'], lw=1.9, marker='o', ms=3.2, label='PSMC-RS')),
)

for i, (ax, m) in enumerate(zip(axes, models)):
    per_model = s2[s2['model'] == m].sort_values(['tool', 'iter'])
    for tool, style in trace_styles:
        trace = per_model[per_model['tool'] == tool]
        # A tool with no rows for this model is simply left out of the panel.
        if len(trace):
            ax.plot(trace['iter'], trace['loglike'], **style)

    ax.set_title(title_map[m], pad=5)
    ax.set_xlabel('EM iteration')
    if i % 2 == 0:
        ax.set_ylabel('Log-likelihood')
    stylize(ax)
    panel_tag(ax, 'ABCD'[i])

# The figure legend is taken from the first panel's handles.
handles, labels = axes[0].get_legend_handles_labels()
fig_legend(fig, handles, labels, ncol=2, y=0.985)

save_figure_multi(fig, 'supplementary_figure_S1_em_convergence', SUPP / 'figures')
plt.show()


In [None]:
# Supplementary Figure S2: format-level inference consistency (psmcfa / mhs / vcf)

# Format-consistency table. The columns read below are model, the two
# rmse_curve_log10_*_vs_psmcfa columns and the four delta_{theta,rho}_rel_* columns.
s5 = pd.read_csv(SUPP / 'tables' / 'S5_format_consistency.csv')
order = ['constant', 'bottleneck', 'expansion', 'zigzag']
model_labels = ['Constant', 'Bottleneck', 'Expansion', 'Zigzag']

# Sort rows into panel order rather than alphabetical order.
s5['model'] = pd.Categorical(s5['model'], categories=order, ordered=True)
s5 = s5.sort_values('model').reset_index(drop=True)

# `right=0.88` leaves room for panel B's legend, which is placed outside the axes.
fig, axes = plt.subplots(1, 2, figsize=(13.4, 5.1), dpi=280)
fig.subplots_adjust(left=0.08, right=0.88, bottom=0.16, top=0.90, wspace=0.24)

# Panel A: curve-level RMSE
ax = axes[0]
# NOTE(review): tick positions come from len(s5) while the tick labels come
# from the 4-entry model_labels list; this presumably assumes exactly one row
# per model in `order` - confirm against the CSV.
x = np.arange(len(s5), dtype=float)
width = 0.34
ax.bar(x - width/2, s5['rmse_curve_log10_mhs_vs_psmcfa'], width=width, color=COL['c'], alpha=0.90, label='MHS vs PSMCFA')
ax.bar(x + width/2, s5['rmse_curve_log10_vcf_vs_psmcfa'], width=width, color=COL['rust'], alpha=0.90, label='VCF vs PSMCFA')
ax.set_xticks(x)
ax.set_xticklabels(model_labels, rotation=14, ha='right')
ax.set_ylabel('Curve RMSE (log10 Ne)')
stylize(ax)
ax.legend(loc='upper right', frameon=False, handlelength=1.4, borderaxespad=0.2)
panel_tag(ax, 'A')

# Panel B: parameter-level relative differences (theta and rho)
ax = axes[1]
xx = np.arange(len(s5), dtype=float)
# Small horizontal offset so the MHS and VCF markers do not sit on top of each other.
dx = 0.04
# Colour encodes the input format (MHS vs VCF); line style and marker encode
# the parameter (solid/circle for theta, dashed/square for rho).
ax.plot(xx - dx, s5['delta_theta_rel_mhs'], color=COL['c'], lw=1.9, marker='o', ms=4.4, label='θ rel diff (MHS)')
ax.plot(xx + dx, s5['delta_theta_rel_vcf'], color=COL['rust'], lw=1.9, marker='o', ms=4.4, label='θ rel diff (VCF)')
ax.plot(xx - dx, s5['delta_rho_rel_mhs'], color=COL['c'], lw=1.6, ls=(0, (4, 2)), marker='s', ms=4.0, label='ρ rel diff (MHS)')
ax.plot(xx + dx, s5['delta_rho_rel_vcf'], color=COL['rust'], lw=1.6, ls=(0, (4, 2)), marker='s', ms=4.0, label='ρ rel diff (VCF)')
ax.set_xticks(xx)
ax.set_xticklabels(model_labels, rotation=14, ha='right')
ax.set_ylabel('Relative difference')
ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
stylize(ax)
# Legend sits just outside the right edge of the axes, vertically centred.
ax.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), frameon=False)
panel_tag(ax, 'B')

save_figure_multi(fig, 'supplementary_figure_S2_format_consistency', SUPP / 'figures')
plt.show()


In [None]:
# Supplementary Figure S3: modernized user interfaces (TUI + interactive HTML report)

import matplotlib.image as mpimg

# Screenshots are expected under experiment/assets/ui; the figure still renders
# (with placeholder text) when either one is missing.
ui_dir = ROOT / 'experiment' / 'assets' / 'ui'
path_tui = ui_dir / 'tui_run.png'
path_html = ui_dir / 'html_report.png'

# (panel tag, image path, placeholder text shown when the image is absent)
items = [
    ('A', path_tui, 'Place TUI screenshot at\nexperiment/assets/ui/tui_run.png'),
    ('B', path_html, 'Place HTML report screenshot at\nexperiment/assets/ui/html_report.png'),
]

# Load whatever images exist and size each column by its image's
# width/height ratio (at least 0.7); missing images get a ratio of 1.0.
loaded = []
width_ratios = []
for _, p, _ in items:
    if p.exists():
        img = mpimg.imread(p)
        h, w = img.shape[:2]
        loaded.append(img)
        width_ratios.append(max(0.7, float(w) / float(h)))
    else:
        loaded.append(None)
        width_ratios.append(1.0)

fig, axes = plt.subplots(
    1,
    2,
    figsize=(14.0, 5.5),
    dpi=280,
    gridspec_kw={'width_ratios': width_ratios},
)
fig.subplots_adjust(left=0.03, right=0.995, bottom=0.06, top=0.93, wspace=0.04)

# `tag` and `path` are unpacked but unused in this loop; the tags are drawn in
# the separate loop below.
for ax, (tag, path, fallback), img in zip(axes, items, loaded):
    # No ticks; a thin light frame around every panel.
    ax.set_xticks([])
    ax.set_yticks([])
    for sp in ax.spines.values():
        sp.set_visible(True)
        sp.set_linewidth(0.8)
        sp.set_edgecolor('#D6DEE8')

    if img is not None:
        ax.imshow(img, interpolation='antialiased')
        # Keep original image aspect ratio (no stretching)
        ax.set_aspect('equal', adjustable='box')
        ax.set_facecolor('white')
    else:
        # Placeholder panel: instructions in the centre, a hint below.
        ax.set_facecolor('#F6F8FC')
        ax.text(
            0.5,
            0.5,
            fallback,
            ha='center',
            va='center',
            fontsize=11,
            color='#475569',
            linespacing=1.4,
            transform=ax.transAxes,
        )
        ax.text(
            0.5,
            0.16,
            'Figure will render automatically after screenshots are added.',
            ha='center',
            va='center',
            fontsize=9.2,
            color='#64748B',
            transform=ax.transAxes,
        )

# Panel tags are placed in figure coordinates just above each axes' top-left corner.
# NOTE(review): get_position() is read before the figure has been drawn; with
# set_aspect('equal', adjustable='box') the axes box may be shrunk at draw
# time, so the tag may sit at the un-adjusted corner - confirm visually.
for tag, ax in zip(['A', 'B'], axes):
    bb = ax.get_position()
    fig.text(
        bb.x0 - 0.006,
        bb.y1 + 0.004,
        tag,
        ha='left',
        va='bottom',
        fontsize=13.5,
        fontweight='bold',
        color='#2D3B52',
        bbox=dict(facecolor='white', edgecolor='none', pad=0.16, alpha=0.96),
    )

save_figure_multi(fig, 'supplementary_figure_S3_interfaces', SUPP / 'figures')
plt.show()



In [None]:
# Final recap: where the PNG copy of each figure is expected to be.
print('Done.')
for _title, _png in (
    ('Figure 2', MAIN / 'figures' / 'figure_2_simulated_equivalence.png'),
    ('Figure 3', REAL / 'figures' / 'figure_3_empirical_genomes_rust_vs_c.png'),
    ('Figure 4', MAIN / 'figures' / 'figure_4_performance_scalability.png'),
    ('Supplementary S1', SUPP / 'figures' / 'supplementary_figure_S1_em_convergence.png'),
    ('Supplementary S2', SUPP / 'figures' / 'supplementary_figure_S2_format_consistency.png'),
    ('Supplementary S3', SUPP / 'figures' / 'supplementary_figure_S3_interfaces.png'),
):
    print(f'{_title}:', _png)
