# PSMC-RS: MHS vs PSMCFA (500Mb, 4 demographic models)

This notebook:
1. Simulates **the same diploid genome** with `msprime` for 4 models (`constant`, `bottleneck`, `expansion`, `zigzag`).
2. Serializes each simulation into **both** input formats:
   - `psmcfa`
   - official `multihetsep/mhs` (`chrom pos nr_called alleles`)
3. Runs `psmc-rs` on both formats with matched inference settings.
4. Plots and quantifies the differences between the two inputs (`MHS` vs `PSMCFA`) and how well each recovers the true curve.

> The default length is **500,000,000 bp** per model. You can override it by setting the `SIM_LENGTH_BP` environment variable before launching Jupyter.


In [None]:
from __future__ import annotations

import json
import math
import os
import platform
import shlex
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import msprime

def _env_int(name: str, default: str) -> int:
    """Read an integer setting from the environment.

    Plain integers (``'50000000'``) and integral scientific notation
    (``'5e7'``) are both accepted. A bare ``int('5e7')`` raises
    ``ValueError``, which made the common shorthand for genome lengths
    unusable.

    Args:
        name: Environment variable to read.
        default: Text used when the variable is not set.

    Returns:
        The parsed integer.

    Raises:
        ValueError: If the text is not numeric or is not a whole number.
    """
    raw = os.environ.get(name, default).strip()
    try:
        return int(raw)
    except ValueError:
        value = float(raw)  # still raises ValueError for non-numeric text
        if not value.is_integer():
            raise ValueError(f'{name} must be an integer, got {raw!r}') from None
        return int(value)


# The notebook must be launched from the psmc-rs repository root: the build
# step and all experiment paths below are resolved relative to it.
ROOT = Path.cwd().resolve()
assert (ROOT / 'Cargo.toml').exists(), f'Please run this notebook in psmc-rs root. current={ROOT}'

# Simulated inputs and inference outputs live in separate sub-directories.
EXP_DIR = ROOT / 'experiments' / 'mhs_vs_psmcfa_500mb'
INPUT_DIR = EXP_DIR / 'inputs'
OUTPUT_DIR = EXP_DIR / 'outputs'
for d in (EXP_DIR, INPUT_DIR, OUTPUT_DIR):
    d.mkdir(parents=True, exist_ok=True)

# Path of the psmc-rs executable; defaults to the cargo release build output.
PSMC_RS_BIN = Path(os.environ.get('PSMC_RS_BIN', str(ROOT / 'target' / 'release' / 'psmc-rs')))

# Core run config
SIM_LENGTH_BP = _env_int('SIM_LENGTH_BP', '500000000')  # simulated sequence length
WINDOW_BP = _env_int('WINDOW_BP', '100')  # psmcfa window width, also passed as --mhs-bin-size
MU = float(os.environ.get('MU', '2.5e-8'))  # mutation rate (simulation and --mu)
RECOMB = float(os.environ.get('RECOMB', '1.25e-8'))  # recombination rate (simulation only)
GEN_YEARS = float(os.environ.get('GEN_YEARS', '25'))  # generations -> years when plotting

# Inference settings forwarded to the psmc-rs command line.
N_ITER = _env_int('N_ITER', '20')
T_MAX = float(os.environ.get('T_MAX', '15'))
N_STEPS = _env_int('N_STEPS', '64')
PATTERN = os.environ.get('PATTERN', '4+25*2+4+6')
BATCH_SIZE = _env_int('BATCH_SIZE', '300000')
THREADS = _env_int('THREADS', '1')
SMOOTH_LAMBDA = float(os.environ.get('SMOOTH_LAMBDA', '1e-3'))
# Any value other than '0' forces a fresh simulation even if inputs exist.
FORCE_SIM = os.environ.get('FORCE_SIM', '1').strip() != '0'

print(f'ROOT={ROOT}')
print(f'Python={sys.executable}')
print(f'Platform={platform.platform()}')
print(f'msprime={msprime.__version__}')
print(f'PSMC_RS_BIN={PSMC_RS_BIN}')
print(f'SIM_LENGTH_BP={SIM_LENGTH_BP:,}, WINDOW_BP={WINDOW_BP}, symbols={math.ceil(SIM_LENGTH_BP / WINDOW_BP):,}')
print(f'MU={MU:.3e}, RECOMB={RECOMB:.3e}, GEN_YEARS={GEN_YEARS}')
print(f'N_ITER={N_ITER}, T_MAX={T_MAX}, N_STEPS={N_STEPS}, PATTERN={PATTERN}')
print(f'BATCH_SIZE={BATCH_SIZE}, THREADS={THREADS}, SMOOTH_LAMBDA={SMOOTH_LAMBDA}')
print(f'FORCE_SIM={FORCE_SIM} (set FORCE_SIM=0 to reuse existing inputs)')


In [None]:
def run_cmd(cmd: List[str], cwd: Optional[Path] = None, check: bool = True):
    """Run an external command, capturing its output and wall-clock time.

    Args:
        cmd: Program and arguments, passed to ``subprocess.run`` as a list.
        cwd: Working directory for the child process, or ``None`` to inherit.
        check: When true, a non-zero exit status prints the command with its
            captured output and raises.

    Returns:
        A dict with keys ``cmd`` (shell-quoted command line), ``returncode``,
        ``stdout``, ``stderr`` and ``wall_sec``.

    Raises:
        RuntimeError: If ``check`` is true and the command exits non-zero.
    """
    workdir = None if cwd is None else str(cwd)

    t0 = time.perf_counter()
    completed = subprocess.run(cmd, cwd=workdir, text=True, capture_output=True)
    elapsed = time.perf_counter() - t0

    quoted = ' '.join(map(shlex.quote, cmd))
    if check and completed.returncode != 0:
        # Dump everything we captured so the failure can be diagnosed in place.
        for chunk in (quoted, '--- stdout ---', completed.stdout, '--- stderr ---', completed.stderr):
            print(chunk)
        raise RuntimeError(f'command failed: rc={completed.returncode}')

    return {
        'cmd': quoted,
        'returncode': completed.returncode,
        'stdout': completed.stdout,
        'stderr': completed.stderr,
        'wall_sec': elapsed,
    }


def ensure_psmc_rs_release_binary() -> Path:
    """Return the psmc-rs binary path, building it with cargo if it is missing.

    Raises:
        RuntimeError: If the cargo build fails (raised by ``run_cmd``).
        FileNotFoundError: If the binary is still absent after a build.
    """
    if not PSMC_RS_BIN.exists():
        print('[build] cargo build --release')
        run_cmd(['cargo', 'build', '--release'], cwd=ROOT)
        # A successful build that did not produce the expected file means
        # PSMC_RS_BIN points somewhere cargo does not write to.
        if not PSMC_RS_BIN.exists():
            raise FileNotFoundError(f'psmc-rs binary not found after build: {PSMC_RS_BIN}')
    return PSMC_RS_BIN


# Make sure the psmc-rs release binary exists (building it if necessary)
# before any later cell tries to run it.
ensure_psmc_rs_release_binary()


In [None]:
def coalescent_events_to_generations(n0: float, events_4n0: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
    """Rescale ``(time, size ratio)`` events to ``(generations, size)``.

    Each time is multiplied by ``4 * n0`` and each ratio by ``n0``.

    Args:
        n0: Reference population size.
        events_4n0: Pairs of (time in units of 4*n0 generations, size
            relative to n0).

    Returns:
        A new list of (time in generations, absolute size) pairs, in the
        same order as the input.
    """
    converted = []
    for t4, ratio in events_4n0:
        generations = t4 * 4.0 * n0
        size = ratio * n0
        converted.append((generations, size))
    return converted


# Demographic models to simulate, keyed by a short identifier.
# Each entry holds:
#   'title'      - panel title used when plotting,
#   'n0'         - initial (present-day) population size given to msprime,
#   'events_gen' - (time in generations, new size) changes; written here as
#                  (time in 4*n0 units, size as a multiple of n0) and rescaled
#                  by coalescent_events_to_generations,
#   'seed'       - random seed for the ancestry simulation (seed + 1 is used
#                  for mutations).
MODELS: Dict[str, Dict] = {
    'constant': {
        'title': 'Constant',
        'n0': 10_000.0,
        'events_gen': [],
        'seed': 42,
    },
    'bottleneck': {
        'title': 'Bottleneck',
        'n0': 20_000.0,
        'events_gen': coalescent_events_to_generations(20_000.0, [
            (0.01, 0.05),
            (0.015, 0.5),
            (0.05, 0.25),
            (0.5, 0.5),
        ]),
        'seed': 43,
    },
    'expansion': {
        'title': 'Expansion',
        'n0': 10_000.0,
        'events_gen': coalescent_events_to_generations(10_000.0, [
            (0.01, 0.1),
            (0.06, 1.0),
            (0.2, 0.5),
            (1.0, 1.0),
            (2.0, 2.0),
        ]),
        'seed': 44,
    },
    'zigzag': {
        'title': 'Zigzag',
        'n0': 1_000.0,
        'events_gen': coalescent_events_to_generations(1_000.0, [
            (0.1, 5.0),
            (0.6, 20.0),
            (2.0, 5.0),
            (10.0, 10.0),
            (20.0, 5.0),
        ]),
        'seed': 45,
    },
}
# Fixed processing and plotting order for the models above.
MODEL_ORDER = ['constant', 'bottleneck', 'expansion', 'zigzag']


@dataclass
class ModelPaths:
    """Input and output file locations for one demographic model."""

    # Simulated genome in PSMCFA format (psmc-rs input).
    psmcfa: Path
    # The same simulated genome in multihetsep (MHS) format (psmc-rs input).
    mhs: Path
    # psmc-rs JSON result for the PSMCFA input.
    rust_psmcfa_json: Path
    # psmc-rs JSON result for the MHS input.
    rust_mhs_json: Path


def model_paths(key: str) -> ModelPaths:
    """Build the input/output file paths used for model ``key``.

    Inputs are placed under ``INPUT_DIR`` and psmc-rs results under
    ``OUTPUT_DIR``; every file name starts with the model key.
    """
    psmcfa_input = INPUT_DIR / f'{key}.psmcfa'
    mhs_input = INPUT_DIR / f'{key}.multihetsep'
    psmcfa_result = OUTPUT_DIR / f'{key}.psmcfa.rust.json'
    mhs_result = OUTPUT_DIR / f'{key}.mhs.rust.json'
    return ModelPaths(psmcfa_input, mhs_input, psmcfa_result, mhs_result)


def true_curve_for_model(key: str):
    """Return the simulated ("true") size history of a model as a step curve.

    The curve starts at 1e3 years with the model's ``n0``, adds one point per
    size-change event (generations converted to years with ``GEN_YEARS`` and
    clamped to at least 1e3), and ends at 1e8 years holding the last size.

    Returns:
        ``(years, sizes)`` as two float arrays of equal length.
    """
    spec = MODELS[key]
    events = sorted(spec['events_gen'], key=lambda event: event[0])

    years = [1e3]
    years.extend(max(1e3, t_gen * GEN_YEARS) for t_gen, _ in events)
    years.append(1e8)

    sizes = [spec['n0']]
    sizes.extend(ne for _, ne in events)
    sizes.append(sizes[-1])  # hold the oldest size out to the right edge

    return np.array(years, dtype=float), np.array(sizes, dtype=float)


In [None]:
def build_demography(n0: float, events_gen: List[Tuple[float, float]]) -> msprime.Demography:
    """Create a single-population msprime demography with size changes.

    Args:
        n0: Initial size of the population ``pop0``.
        events_gen: ``(time, new size)`` pairs; they are applied in order of
            increasing time regardless of the order given.

    Returns:
        The configured ``msprime.Demography``.
    """
    demography = msprime.Demography()
    demography.add_population(name='pop0', initial_size=float(n0))
    ordered_events = sorted(events_gen, key=lambda event: event[0])
    for when, size in ordered_events:
        demography.add_population_parameters_change(
            time=float(when),
            population='pop0',
            initial_size=float(size),
        )
    return demography


def write_psmcfa(path: Path, seq: str, header: str = 'chr1'):
    """Write ``seq`` to ``path`` as a FASTA-style file wrapped at 60 columns.

    Parent directories are created if needed. The file starts with one
    header line, followed by the sequence in lines of at most 60 characters.
    """
    line_width = 60
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open('w', encoding='utf-8') as out:
        out.write(f'> {header}\n')
        start = 0
        total = len(seq)
        while start < total:
            out.write(seq[start:start + line_width] + '\n')
            start += line_width


def simulate_model_to_both_formats(key: str, force: bool = False):
    """Simulate one model and write the genome as both PSMCFA and MHS.

    One diploid individual is simulated with msprime. Heterozygous sites are
    then encoded twice:

    * PSMCFA: one symbol per ``WINDOW_BP`` window, ``'K'`` if the window
      contains at least one heterozygous site and ``'T'`` otherwise.
    * MHS: one row per heterozygous site with the chromosome name, 1-based
      position, distance to the previously emitted row, and the two alleles.

    Args:
        key: Model identifier in ``MODELS``.
        force: Re-simulate even if both input files already exist.

    Returns:
        A summary dict. When existing files are reused it only holds
        ``model``, ``sim_skipped`` and the two paths; otherwise it also holds
        site counts, row/symbol counts and file sizes in MB.
    """
    spec = MODELS[key]
    paths = model_paths(key)

    reuse_existing = (not force) and paths.psmcfa.exists() and paths.mhs.exists()
    if reuse_existing:
        return {
            'model': key,
            'sim_skipped': True,
            'psmcfa': str(paths.psmcfa),
            'mhs': str(paths.mhs),
        }

    seed = int(spec['seed'])
    ancestry = msprime.sim_ancestry(
        samples={'pop0': 1},
        ploidy=2,
        demography=build_demography(spec['n0'], spec['events_gen']),
        sequence_length=float(SIM_LENGTH_BP),
        recombination_rate=float(RECOMB),
        random_seed=seed,
    )
    # Mutations use a different seed from the ancestry simulation.
    mutated = msprime.sim_mutations(ancestry, rate=float(MU), random_seed=seed + 1)

    has_het = np.zeros(math.ceil(SIM_LENGTH_BP / WINDOW_BP), dtype=bool)
    mhs_rows = []
    last_emitted_pos = 0
    n_sites = 0
    n_het = 0

    for variant in mutated.variants():
        pos = int(variant.site.position) + 1  # 1-based
        if not 1 <= pos <= SIM_LENGTH_BP:
            continue
        n_sites += 1

        allele_a = variant.alleles[int(variant.genotypes[0])]
        allele_b = variant.alleles[int(variant.genotypes[1])]
        if allele_a == allele_b:
            continue  # homozygous: contributes to neither format

        n_het += 1
        has_het[(pos - 1) // WINDOW_BP] = True

        # The MHS third column counts sites since the previously emitted row.
        called = pos - last_emitted_pos
        if called > 0:
            mhs_rows.append((pos, called, f'{allele_a}{allele_b}'))
            last_emitted_pos = pos

    seq = ''.join(['K' if flag else 'T' for flag in has_het])
    write_psmcfa(paths.psmcfa, seq, header='chr1')

    paths.mhs.parent.mkdir(parents=True, exist_ok=True)
    with paths.mhs.open('w', encoding='utf-8') as out:
        for pos, called, alleles in mhs_rows:
            out.write(f'chr1\t{pos}\t{called}\t{alleles}\n')

    return {
        'model': key,
        'sim_skipped': False,
        'psmcfa': str(paths.psmcfa),
        'mhs': str(paths.mhs),
        'all_sites': n_sites,
        'het_sites': n_het,
        'mhs_rows': len(mhs_rows),
        'psmcfa_symbols': len(seq),
        'psmcfa_mb': paths.psmcfa.stat().st_size / 1e6,
        'mhs_mb': paths.mhs.stat().st_size / 1e6,
    }


In [None]:
# Simulate (or reuse, when FORCE_SIM is off) every model in a fixed order and
# collect one summary record per model.
sim_rows = []
for key in MODEL_ORDER:
    print(f'[simulate] {key} (length={SIM_LENGTH_BP:,} bp)')
    sim_rows.append(simulate_model_to_both_formats(key, force=FORCE_SIM))

# Show the per-model simulation summary as a table.
sim_df = pd.DataFrame(sim_rows)
display(sim_df)


In [None]:
def run_psmc_rs(input_path: Path, input_format: str, output_json: Path):
    """Run psmc-rs on one input file with the notebook's inference settings.

    Args:
        input_path: Input file to analyse.
        input_format: Value for ``--input-format``; ``'mhs'`` additionally
            passes ``--mhs-bin-size`` set to ``WINDOW_BP``.
        output_json: Where psmc-rs should write its JSON result.

    Returns:
        The record returned by ``run_cmd`` (command line, return code,
        captured output and wall time).
    """
    cmd = [str(PSMC_RS_BIN), str(input_path), str(output_json), str(N_ITER)]

    options = {
        '--input-format': input_format,
        '--t-max': str(T_MAX),
        '--n-steps': str(N_STEPS),
        '--pattern': PATTERN,
        '--mu': str(MU),
        '--smooth-lambda': str(SMOOTH_LAMBDA),
        '--batch-size': str(BATCH_SIZE),
        '--threads': str(THREADS),
    }
    for flag, value in options.items():
        cmd.extend((flag, value))
    cmd.append('--no-progress')

    if input_format == 'mhs':
        cmd.extend(('--mhs-bin-size', str(WINDOW_BP)))

    return run_cmd(cmd, cwd=ROOT)

# Run psmc-rs twice per model (once per input format) and record the wall time
# of each run together with the path of the JSON it produced.
run_rows = []
for key in MODEL_ORDER:
    p = model_paths(key)

    print(f'[run] {key} / psmcfa')
    r1 = run_psmc_rs(p.psmcfa, 'psmcfa', p.rust_psmcfa_json)
    run_rows.append({
        'model': key,
        'format': 'psmcfa',
        'wall_sec': r1['wall_sec'],
        'output_json': str(p.rust_psmcfa_json),
    })

    print(f'[run] {key} / mhs')
    r2 = run_psmc_rs(p.mhs, 'mhs', p.rust_mhs_json)
    run_rows.append({
        'model': key,
        'format': 'mhs',
        'wall_sec': r2['wall_sec'],
        'output_json': str(p.rust_mhs_json),
    })

# One row per (model, format) run; reused later for the runtime comparison.
run_df = pd.DataFrame(run_rows)
display(run_df)


In [None]:
def parse_pattern_spec(pattern):
    """Parse a ``'+'``-separated pattern into ``(nr, gl)`` pairs.

    A token ``'a*b'`` yields ``(a, b)``; a bare token ``'b'`` yields
    ``(1, b)``. Empty tokens are ignored.

    Returns:
        The list of pairs, or ``None`` when ``pattern`` is ``None`` or
        contains no tokens.

    Raises:
        ValueError: If a token is not numeric or either number is <= 0.
    """
    if pattern is None:
        return None

    groups = []
    for raw in str(pattern).split('+'):
        token = raw.strip()
        if not token:
            continue
        pieces = token.split('*', 1)
        if len(pieces) == 2:
            nr = int(pieces[0].strip())
            gl = int(pieces[1].strip())
        else:
            nr = 1
            gl = int(token)
        if min(nr, gl) <= 0:
            raise ValueError(f'invalid pattern token: {token}')
        groups.append((nr, gl))
    return groups or None


def parse_pattern_spec_legacy(pattern):
    """Parse a ``'+'``-separated pattern into legacy ``(ts, gs)`` pairs.

    A token ``'a*b'`` yields ``(a, b)``; a bare token ``'a'`` yields
    ``(a, 1)``. Empty tokens are ignored.

    Returns:
        The list of pairs, or ``None`` when ``pattern`` is ``None`` or
        contains no tokens.

    Raises:
        ValueError: If a token is not numeric or either number is <= 0.
    """
    if pattern is None:
        return None

    tokens = [chunk.strip() for chunk in str(pattern).split('+')]
    parsed = []
    for token in tokens:
        if token == '':
            continue
        if '*' not in token:
            ts, gs = int(token), 1
        else:
            left, right = token.split('*', 1)
            ts, gs = int(left.strip()), int(right.strip())
        if not (ts > 0 and gs > 0):
            raise ValueError(f'invalid legacy token: {token}')
        parsed.append((ts, gs))
    return parsed if parsed else None


def expand_lam(lam_grouped, n_steps, pattern_spec, pattern_raw=None):
    """Expand grouped lambda values to one value per time interval.

    Three layouts are recognised, tried in this order:

    1. ``pattern_spec`` is ``None``: the values are already per interval.
    2. One value per ``(nr, gl)`` group repetition: each value is repeated
       ``gl`` times.
    3. Legacy layout derived from ``pattern_raw``: one value per ``ts``
       repetition (repeated ``gs`` times) plus one trailing value.

    Args:
        lam_grouped: Lambda values as stored in the result file.
        n_steps: Number of time steps; the result has ``n_steps + 1`` values.
        pattern_spec: Output of ``parse_pattern_spec`` or ``None``.
        pattern_raw: Raw pattern string used for the legacy interpretation.

    Returns:
        A list of ``n_steps + 1`` floats.

    Raises:
        ValueError: If the input length matches no layout, or the expanded
            length is not ``n_steps + 1``.
    """
    values = [float(v) for v in lam_grouped]
    n_intervals = n_steps + 1

    if pattern_spec is None:
        if len(values) != n_intervals:
            raise ValueError(f'lam length {len(values)} != n_steps+1 ({n_intervals})')
        return values

    expected_c = sum(nr for nr, _ in pattern_spec)
    if len(values) == expected_c:
        # One repeat count per stored value, in storage order.
        repeats = [gl for nr, gl in pattern_spec for _ in range(nr)]
        lam = [v for v, gl in zip(values, repeats) for _ in range(gl)]
        if len(lam) != n_intervals:
            raise ValueError(f'expanded lam length {len(lam)} != n_steps+1 ({n_intervals})')
        return lam

    legacy_spec = parse_pattern_spec_legacy(pattern_raw)
    if legacy_spec is None:
        expected_legacy = None
    else:
        expected_legacy = sum(ts for ts, _ in legacy_spec) + 1

    if expected_legacy is not None and len(values) == expected_legacy:
        repeats = [gs for ts, gs in legacy_spec for _ in range(ts)]
        lam = [v for v, gs in zip(values, repeats) for _ in range(gs)]
        lam.append(values[-1])  # the trailing value is kept as-is
        if len(lam) != n_intervals:
            raise ValueError(f'expanded legacy lam length {len(lam)} != n_steps+1 ({n_intervals})')
        return lam

    message = f'grouped lam length {len(values)} != expected_c {expected_c}'
    if expected_legacy is not None:
        message += f' and != expected_legacy {expected_legacy}'
    raise ValueError(message)


def compute_t_grid(n_steps: int, t_max: float, alpha: float = 0.1):
    """Return the ``n_steps + 1`` time boundaries of the inference grid.

    Boundary ``k`` (for ``k < n_steps``) is ``alpha * (exp(beta * k) - 1)``
    with ``beta = log(1 + t_max / alpha) / n_steps``; the final boundary is
    set to ``t_max`` exactly.

    Returns:
        A float array of length ``n_steps + 1``.
    """
    growth = math.log(1 + t_max / alpha) / n_steps
    boundaries = []
    for k in range(n_steps):
        boundaries.append(alpha * (math.exp(growth * k) - 1.0))
    boundaries.append(float(t_max))
    return np.asarray(boundaries, dtype=float)


def curve_from_json(path: Path):
    """Convert a psmc-rs JSON result into a plottable ``(years, Ne)`` curve.

    The JSON must contain ``theta``, ``n_steps``, ``t_max`` and ``lam``.
    ``mu`` and ``bin_size`` fall back to the notebook's ``MU`` and
    ``WINDOW_BP`` when absent, and ``pattern`` is optional.

    Args:
        path: Path to the JSON file written by psmc-rs.

    Returns:
        ``(x, y)`` float arrays: time in years and effective population
        size. A final point at 1e8 years repeating the last size is appended
        so the step curve extends to the right edge of the plots.

    Raises:
        KeyError: If a required key is missing.
        ValueError: If the lambda values do not match the pattern/grid
            (raised by ``expand_lam``).
    """
    # JSON is UTF-8 by specification; decode it explicitly instead of relying
    # on the platform's locale encoding.
    d = json.loads(path.read_text(encoding='utf-8'))
    theta = float(d['theta'])
    mu = float(d.get('mu', MU))
    n_steps = int(d['n_steps'])
    t_max = float(d['t_max'])
    bin_size = float(d.get('bin_size', WINDOW_BP))

    pattern_raw = d.get('pattern')
    pattern_spec = parse_pattern_spec(pattern_raw)
    lam = np.asarray(expand_lam(d['lam'], n_steps, pattern_spec, pattern_raw=pattern_raw), dtype=float)

    # theta is per bin, so divide by bin_size to get the per-base value
    # before solving theta = 4 * N0 * mu for N0.
    n0 = theta / (4.0 * mu * bin_size)
    t = compute_t_grid(n_steps, t_max)

    # Scale the grid to years and the relative sizes to absolute Ne.
    x = t * 2.0 * GEN_YEARS * n0
    y = lam * n0

    x = np.append(x, 1e8)
    y = np.append(y, y[-1])
    return x, y


def step_value(xs: np.ndarray, ys: np.ndarray, xq: float) -> float:
    """Evaluate a right-continuous step function at ``xq``.

    The value is ``ys[i]`` for the last ``i`` with ``xs[i] <= xq``. Queries
    left of ``xs[0]`` return ``ys[0]``, and the index is capped at the last
    element of ``ys``.
    """
    pos = int(np.searchsorted(xs, xq, side='right')) - 1
    last = len(ys) - 1
    if pos > last:
        pos = last
    if pos < 0:
        pos = 0
    return float(ys[pos])


def rmse_log10(curve_a, curve_b, x_min=1e3, x_max=1e8, n=400):
    """Root-mean-square difference of two step curves in log10 space.

    Both curves are sampled with ``step_value`` on ``n`` geometrically
    spaced points between ``x_min`` and ``x_max``. Points where either value
    is non-positive or non-finite are ignored.

    Args:
        curve_a: ``(xs, ys)`` of the first curve.
        curve_b: ``(xs, ys)`` of the second curve.
        x_min: Left end of the sampling grid.
        x_max: Right end of the sampling grid.
        n: Number of sampling points.

    Returns:
        The RMSE of ``log10(a) - log10(b)``, or NaN if no point is usable.
    """
    xa, ya = curve_a
    xb, yb = curve_b
    grid = np.geomspace(x_min, x_max, n)

    va = np.array([step_value(xa, ya, q) for q in grid])
    vb = np.array([step_value(xb, yb, q) for q in grid])

    # log10 is only defined for positive, finite values on both curves.
    usable = (va > 0) & (vb > 0) & np.isfinite(va) & np.isfinite(vb)
    if usable.sum() == 0:
        return float('nan')

    log_diff = np.log10(va[usable]) - np.log10(vb[usable])
    return float(np.sqrt(np.mean(log_diff ** 2)))


In [None]:
# One RMSE record per model, filled in while plotting.
metrics = []

# 2x2 grid: one panel per model, flattened so panels can be indexed by i.
fig, axes = plt.subplots(2, 2, figsize=(14, 10), dpi=140)
axes = axes.ravel()

for i, key in enumerate(MODEL_ORDER):
    ax = axes[i]
    p = model_paths(key)

    # Simulated history plus the two histories inferred by psmc-rs.
    true_curve = true_curve_for_model(key)
    psmcfa_curve = curve_from_json(p.rust_psmcfa_json)
    mhs_curve = curve_from_json(p.rust_mhs_json)

    tx, ty = true_curve
    ax.step(tx, ty, where='post', lw=2.1, ls='--', color='dodgerblue', label='True')

    x1, y1 = psmcfa_curve
    x2, y2 = mhs_curve
    ax.step(x1, y1, where='post', lw=2.0, color='tomato', label='PSMCFA input')
    ax.step(x2, y2, where='post', lw=2.0, color='seagreen', label='MHS input')

    # Leave 15% headroom above the tallest of the three curves.
    ymax = max(float(np.max(ty)), float(np.max(y1)), float(np.max(y2)))
    ax.set_title(MODELS[key]['title'])
    ax.set_xscale('log')
    ax.set_xlim(1e3, 1e8)
    ax.set_ylim(0, ymax * 1.15)
    ax.set_xlabel(f'Years (g={GEN_YEARS}, mu={MU:.1e})')
    ax.set_ylabel('Effective population size (Ne)')
    ax.grid(alpha=0.25)

    # Accuracy of each input against the truth, and agreement between inputs.
    metrics.append({
        'model': key,
        'rmse_log10_ne_vs_true_psmcfa': rmse_log10(true_curve, psmcfa_curve),
        'rmse_log10_ne_vs_true_mhs': rmse_log10(true_curve, mhs_curve),
        'rmse_log10_ne_mhs_vs_psmcfa': rmse_log10(mhs_curve, psmcfa_curve),
    })

# All panels draw the same three series, so one shared legend is built from
# the first panel; the rect leaves room for it above the grid.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=3, frameon=False)
fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

# Join the RMSE metrics with per-format wall times (one column per format)
# and add the MHS/PSMCFA runtime ratio.
metrics_df = pd.DataFrame(metrics).sort_values('model')
run_pivot = run_df.pivot(index='model', columns='format', values='wall_sec').reset_index()
summary_df = metrics_df.merge(run_pivot, on='model', how='left')
summary_df['speed_ratio_mhs_over_psmcfa'] = summary_df['mhs'] / summary_df['psmcfa']

print('Summary metrics (lower RMSE is better)')
display(summary_df)


## Notes

- This is a **format comparison** under matched simulation and inference settings.
- `MHS` and `PSMCFA` are generated from the **same simulated diploid genome** per model.
- If the runtime is too long on your machine, reduce `SIM_LENGTH_BP` (e.g., `SIM_LENGTH_BP=50000000`) and rerun.
