# Figure 4 + Figure 5 (Only)

- Figure 4: Bootstrap dual-panel (zigzag + bottleneck, Rust vs C)
- Figure 5: Format consistency (psmcfa / mhs / vcf)

Output directory: `experiment/runs/figure4_figure5`


In [None]:
from __future__ import annotations

import importlib.util
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import numpy as np
import pandas as pd

# Prefer IPython's rich display; outside a notebook environment fall back to
# plain printing. Only a failed import triggers the fallback, so unrelated
# errors are not silently hidden.
try:
    from IPython.display import display
except ImportError:
    def display(x):
        """Fallback for IPython's ``display``: print *x* to stdout."""
        print(x)


def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that looks like the repo root.

    A directory qualifies when it contains both a ``Cargo.toml`` entry and a
    ``src`` entry. *start* itself is checked first, then each parent in turn.

    Raises:
        RuntimeError: if neither *start* nor any of its parents qualifies.
    """
    candidates = (start, *start.parents)
    root = next(
        (
            folder
            for folder in candidates
            if (folder / 'Cargo.toml').exists() and (folder / 'src').exists()
        ),
        None,
    )
    if root is None:
        raise RuntimeError(f'Cannot locate psmc-rs root from {start}')
    return root


# Locate the repository from the current working directory and load the
# supplementary pipeline script by path, under the module name 'supp'.
ROOT = find_repo_root(Path.cwd().resolve())
SCRIPT = ROOT / 'experiment' / 'scripts' / 'run_supplementary.py'
if not SCRIPT.exists():
    raise FileNotFoundError(f'missing script: {SCRIPT}')

spec = importlib.util.spec_from_file_location('supp', SCRIPT)
# Validate the spec *before* it is used: module_from_spec(None) would fail
# with an opaque AttributeError, and a bare assert is stripped under -O.
if spec is None or spec.loader is None:
    raise ImportError(f'cannot build an import spec for {SCRIPT}')
supp = importlib.util.module_from_spec(spec)
spec.loader.exec_module(supp)

supp.setup_publication_style()

# Output layout for this notebook; every directory is created up front.
RUN_DIR = ROOT / 'experiment' / 'runs' / 'figure4_figure5'
FIG_DIR = RUN_DIR / 'figures'
TABLE_DIR = RUN_DIR / 'tables'
BOOT_DIR = RUN_DIR / 'bootstrap'
LOG_DIR = RUN_DIR / 'logs'
for d in (RUN_DIR, FIG_DIR, TABLE_DIR, BOOT_DIR, LOG_DIR):
    d.mkdir(parents=True, exist_ok=True)

print('ROOT =', ROOT)
print('SCRIPT =', SCRIPT)
print('RUN_DIR =', RUN_DIR)
print('PSMC_RS_BIN =', supp.PSMC_RS_BIN)
print('C_PSMC_BIN =', supp.C_PSMC_BIN)


In [None]:
# Parameters
# FORCE is handed to the figure runners as their ``force`` argument.
FORCE = False
BOOT_MODELS = ['zigzag', 'bottleneck']
# Both bootstrap settings can be overridden through environment variables;
# the iteration count otherwise defaults to supp.N_ITER.
BOOTSTRAP_REPS = int(os.getenv('BOOTSTRAP_REPS', '100'))
BOOTSTRAP_ITERS = int(os.getenv('BOOTSTRAP_ITERS', str(supp.N_ITER)))

print(f'FORCE = {FORCE}')
print(f'BOOT_MODELS = {BOOT_MODELS}')
print(f'BOOTSTRAP_REPS = {BOOTSTRAP_REPS}')
print(f'BOOTSTRAP_ITERS = {BOOTSTRAP_ITERS}')


In [None]:
def save_figure_multi(fig, stem: str):
    """Save *fig* into FIG_DIR as ``<stem>.png``, ``<stem>.svg`` and ``<stem>.pdf``.

    The PNG is written at 320 dpi; the other formats pass ``dpi=None``.
    All files use a tight bounding box and carry a ``Creator`` metadata
    entry. The written paths are printed on one line.
    """
    written = []
    for suffix in ('png', 'svg', 'pdf'):
        target = FIG_DIR / f'{stem}.{suffix}'
        resolution = 320 if suffix == 'png' else None
        fig.savefig(
            target,
            dpi=resolution,
            bbox_inches='tight',
            metadata={'Creator': 'figure4_figure5_only.ipynb'},
        )
        written.append(target)
    print('saved figure:', ', '.join(map(str, written)))


def save_table_multi(df: pd.DataFrame, stem: str):
    """Write *df* into TABLE_DIR as CSV, TSV and Markdown, all without the index.

    The Markdown file is UTF-8 encoded and ends with a single newline.
    The three output paths are printed afterwards.
    """
    csv_path, tsv_path, md_path = (
        TABLE_DIR / f'{stem}.{suffix}' for suffix in ('csv', 'tsv', 'md')
    )
    df.to_csv(csv_path, index=False)
    df.to_csv(tsv_path, sep='\t', index=False)
    markdown = df.to_markdown(index=False)
    md_path.write_text(f'{markdown}\n', encoding='utf-8')
    print('saved table:', csv_path, tsv_path, md_path)


def compute_ci_from_curves(curves, x_grid):
    """Return pointwise 2.5%, 50% and 97.5% quantiles of *curves* on *x_grid*.

    Each curve is an ``(x, y)`` pair. It is evaluated at every point of
    *x_grid* with ``supp.step_value`` and the quantiles are taken across
    curves, so each returned array has one entry per grid point.

    Args:
        curves: iterable of ``(x, y)`` pairs; must contain at least one curve.
        x_grid: sequence of query positions.

    Returns:
        Tuple ``(q025, q500, q975)`` of arrays aligned with *x_grid*.

    Raises:
        ValueError: if *curves* is empty. Quantiles over zero curves would
            otherwise come back as NaN and only fail later, far from the cause.
    """
    curves = list(curves)
    if not curves:
        raise ValueError('compute_ci_from_curves needs at least one curve, got none')
    vals = np.asarray(
        [[supp.step_value(x, y, qx) for qx in x_grid] for x, y in curves],
        dtype=float,
    )
    # One quantile pass instead of three; rows follow the order of the q list.
    q025, q500, q975 = np.quantile(vals, [0.025, 0.5, 0.975], axis=0)
    return q025, q500, q975


def run_c_bootstrap_replicates(model_key: str, reps: int, iters: int, force: bool = False):
    """Produce *reps* C bootstrap outputs for *model_key* and return their paths.

    The splitfa tool is first run on the shared input, with its stdout
    captured to ``<model_key>.split.psmcfa``. Each replicate is then written
    to ``c_boot/replicate_NNN.psmc`` through ``supp.run_c(..., bootstrap=True)``.
    Files that already exist are reused unless *force* is true.

    Returns:
        List of replicate output paths, numbered from 001 to *reps*.
    """
    splitfa_bin = supp.ensure_splitfa()
    base_dir = BOOT_DIR / model_key
    replicate_dir = base_dir / 'c_boot'
    replicate_dir.mkdir(parents=True, exist_ok=True)

    split_psmcfa = base_dir / f'{model_key}.split.psmcfa'
    if force or not split_psmcfa.exists():
        source = supp.shared_input_path(model_key, supp.SIM_LENGTH_BP)
        supp.run_cmd(
            [str(splitfa_bin), str(source)],
            cwd=supp.ROOT,
            stdout_path=split_psmcfa,
        )

    replicate_paths = [
        replicate_dir / f'replicate_{number:03d}.psmc'
        for number in range(1, reps + 1)
    ]
    for target in replicate_paths:
        if force or not target.exists():
            supp.run_c(split_psmcfa, target, n_iter=iters, bootstrap=True)
    return replicate_paths


def run_rust_bootstrap(model_key: str, reps: int, iters: int, force: bool = False):
    """Run (or reuse) the Rust bootstrap for *model_key* and load its summary.

    The Rust tool is invoked through ``supp.run_rust`` with the bootstrap
    flags appended, unless both ``rust_boot/summary.tsv`` and
    ``rust_bootstrap_main.json`` already exist and *force* is false.

    Returns:
        ``(main_json_path, summary)`` where *summary* is the tab-separated
        ``summary.tsv`` read into a DataFrame.
    """
    base_dir = BOOT_DIR / model_key
    boot_out = base_dir / 'rust_boot'
    main_json = base_dir / 'rust_bootstrap_main.json'
    boot_out.mkdir(parents=True, exist_ok=True)

    summary_path = boot_out / 'summary.tsv'
    cached = summary_path.exists() and main_json.exists()
    if force or not cached:
        options = (
            ('--bootstrap', reps),
            ('--bootstrap-iters', iters),
            ('--bootstrap-block-size', supp.BOOTSTRAP_BLOCK_SIZE),
            ('--bootstrap-seed', supp.BOOTSTRAP_SEED),
            ('--bootstrap-dir', boot_out),
        )
        # Flatten to the alternating flag/value list the command line expects.
        cli_args = [token for flag, value in options for token in (flag, str(value))]
        supp.run_rust(
            supp.shared_input_path(model_key, supp.SIM_LENGTH_BP),
            main_json,
            n_iter=supp.N_ITER,
            threads=supp.RUST_THREADS_BASE,
            smooth_lambda=None,
            extra=cli_args,
        )

    return main_json, pd.read_csv(summary_path, sep='\t')


In [None]:

def _step_with_outline(ax, x, y, *, color, lw=2.1, ls='-', label=None, zorder=3, alpha=1.0):
    """Draw a post-step line on *ax* with a white halo behind it and return the line.

    The halo is a white stroke 1.5 points wider than the line itself, so the
    curve stays legible where it crosses fills or other curves.
    """
    drawn = ax.step(
        x,
        y,
        where='post',
        color=color,
        lw=lw,
        ls=ls,
        label=label,
        zorder=zorder,
        alpha=alpha,
    )
    curve = drawn[0]
    halo = pe.Stroke(linewidth=lw + 1.5, foreground='white', alpha=0.92)
    curve.set_path_effects([halo, pe.Normal()])
    return curve


def _figure4_model_data(key, force):
    """Gather every series drawn in one Figure 4 panel for model *key*.

    Runs (or reuses cached) Rust and C bootstraps, then evaluates the C curves
    and the true curve on the Rust summary's ``x_years`` grid so that all
    series share one x axis.

    Returns:
        ``(panel, width_rows)``: *panel* maps series names to arrays aligned
        with ``panel['x']``; *width_rows* holds one CI-width record per tool.

    Raises:
        RuntimeError: if no C replicate, or the C main run, yields a curve.
    """
    print(f'[Figure 4] {key}')
    _, rust_summary = run_rust_bootstrap(key, reps=BOOTSTRAP_REPS, iters=BOOTSTRAP_ITERS, force=force)

    c_main = BOOT_DIR / key / f'{key}.c.main.psmc'
    if force or not c_main.exists():
        supp.run_c(supp.shared_input_path(key, supp.SIM_LENGTH_BP), c_main, n_iter=supp.N_ITER)

    c_rep_paths = run_c_bootstrap_replicates(key, reps=BOOTSTRAP_REPS, iters=BOOTSTRAP_ITERS, force=force)
    # Replicates whose curve fails to load (None) are dropped rather than aborting.
    c_rep_curves = [cc for cc in map(supp.load_c_curve, c_rep_paths) if cc is not None]
    if not c_rep_curves:
        raise RuntimeError(f'no usable C bootstrap curves for model {key!r}')

    x_grid = rust_summary['x_years'].to_numpy(dtype=float)
    c_q025, c_q500, c_q975 = compute_ci_from_curves(c_rep_curves, x_grid)

    c_main_curve = supp.load_c_curve(c_main)
    # Apply the same None handling as for the replicates, so a missing main
    # curve fails with a clear message instead of a TypeError on subscripting.
    if c_main_curve is None:
        raise RuntimeError(f'could not load C main curve for model {key!r}: {c_main}')
    c_main_vals = np.asarray([supp.step_value(c_main_curve[0], c_main_curve[1], x) for x in x_grid], dtype=float)

    rust_main_vals = rust_summary['ne_main'].to_numpy(dtype=float)
    rust_q025 = rust_summary['ne_q025'].to_numpy(dtype=float)
    rust_q500 = rust_summary['ne_q500'].to_numpy(dtype=float)
    rust_q975 = rust_summary['ne_q975'].to_numpy(dtype=float)

    true_curve = supp.true_curve_for_model(key)
    true_vals = np.asarray([supp.step_value(true_curve[0], true_curve[1], x) for x in x_grid], dtype=float)

    width_rows = [
        {
            'model': key,
            'tool': 'rust',
            'ci_width_mean': float(np.mean(rust_q975 - rust_q025)),
            'ci_width_median': float(np.median(rust_q975 - rust_q025)),
        },
        {
            'model': key,
            'tool': 'c',
            'ci_width_mean': float(np.mean(c_q975 - c_q025)),
            'ci_width_median': float(np.median(c_q975 - c_q025)),
        },
    ]
    panel = {
        'x': x_grid,
        'true': true_vals,
        'rust_main': rust_main_vals,
        'rust_q025': rust_q025,
        'rust_q500': rust_q500,
        'rust_q975': rust_q975,
        'c_main': c_main_vals,
        'c_q025': c_q025,
        'c_q500': c_q500,
        'c_q975': c_q975,
    }
    return panel, width_rows


def _figure4_draw_panel(fig, cell, key, d):
    """Draw the Ne panel and the log10 residual panel for one model.

    *cell* is the outer-grid slot to subdivide; *d* is the panel dict built by
    ``_figure4_model_data``. Returns ``(ax, axd)``: the main axes and the
    residual axes below it, which share the x axis.
    """
    inner = cell.subgridspec(2, 1, height_ratios=[4.0, 1.35], hspace=0.06)
    ax = fig.add_subplot(inner[0])
    axd = fig.add_subplot(inner[1], sharex=ax)

    ax.fill_between(d['x'], d['rust_q025'], d['rust_q975'], step='post', alpha=0.16, color=supp.COLORS['rust'], label='Rust 95% CI')
    ax.fill_between(d['x'], d['c_q025'], d['c_q975'], step='post', alpha=0.16, color=supp.COLORS['c'], label='C 95% CI')
    _step_with_outline(ax, d['x'], d['true'], color=supp.COLORS['true'], ls=(0, (5, 3)), lw=2.2, label='True', zorder=1)
    _step_with_outline(ax, d['x'], d['c_main'], color=supp.COLORS['c'], lw=2.0, label='C main', zorder=2)
    _step_with_outline(ax, d['x'], d['rust_main'], color=supp.COLORS['rust'], lw=2.1, label='Rust main', zorder=3)

    ymax = max(np.max(d['true']), np.max(d['rust_q975']), np.max(d['c_q975']))
    supp.stylize_axis(ax, xlog=True)
    ax.set_xlim(1e3, 1e8)
    ax.set_ylim(0, ymax * 1.14)
    ax.set_title(supp.MODELS[key]['title'], fontsize=12.5, pad=8)
    ax.set_ylabel('Effective population size (Ne)')
    # The x tick labels live on the residual panel underneath.
    ax.tick_params(labelbottom=False)

    # Clamp to a tiny positive floor so log10 never sees zero or negatives.
    l_true = np.log10(np.maximum(d['true'], 1e-12))
    d_r = np.log10(np.maximum(d['rust_main'], 1e-12)) - l_true
    d_c = np.log10(np.maximum(d['c_main'], 1e-12)) - l_true
    d_rc = np.log10(np.maximum(d['rust_main'], 1e-12)) - np.log10(np.maximum(d['c_main'], 1e-12))

    axd.axhline(0.0, color='#6B7280', lw=1.0, ls=(0, (4, 2)))
    axd.plot(d['x'], d_r, lw=1.8, color=supp.COLORS['rust'], label='Rust - True')
    axd.plot(d['x'], d_c, lw=1.8, color=supp.COLORS['c'], label='C - True')
    axd.plot(d['x'], d_rc, lw=1.2, ls=(0, (3, 2)), color='#374151', label='Rust - C')

    supp.stylize_axis(axd, xlog=True, yfmt=None)
    axd.yaxis.set_major_formatter(supp.FuncFormatter(lambda y, _: f"{y:+.2f}"))
    # Symmetric y range: 25% headroom over the largest residual, clamped to [0.03, 0.50].
    lim = max(0.03, min(0.50, float(np.max(np.abs([d_r, d_c, d_rc]))) * 1.25))
    axd.set_ylim(-lim, lim)
    axd.set_xlabel(f'Years (g={supp.GEN_YEARS}, mu={supp.MU:.1e})')
    axd.set_ylabel('Delta log10(Ne)', fontsize=9)
    axd.tick_params(axis='both', labelsize=8.3)
    return ax, axd


def run_figure4(force=False):
    """Build Figure 4: bootstrap 95% CIs of Rust vs C for every model in BOOT_MODELS.

    For each model the Rust and C bootstraps are run (or reused from disk
    unless *force* is true). The CI-width table is saved, then one column per
    model is drawn, each with a main Ne panel and a residual panel. The figure
    is saved in all formats and shown.

    Args:
        force: when true, regenerate shared inputs and every cached run.

    Returns:
        DataFrame of mean/median CI width per model and tool, sorted by
        model then tool.

    Raises:
        RuntimeError: if the C main curve or all C replicates fail to load.
    """
    supp.ensure_pydeps()
    supp.ensure_tools(require_c=True)
    supp.ensure_shared_inputs(supp.SIM_LENGTH_BP, force=force)

    panel_data = {}
    width_rows = []
    for key in BOOT_MODELS:
        panel_data[key], rows = _figure4_model_data(key, force)
        width_rows.extend(rows)

    width_df = pd.DataFrame(width_rows).sort_values(['model', 'tool'])
    save_table_multi(width_df, 'figure_4_bootstrap_ci_width')

    fig = plt.figure(figsize=(14.5, 6.8), dpi=220, constrained_layout=True)
    # One column per model, so changing BOOT_MODELS cannot index past the grid.
    outer = fig.add_gridspec(1, len(BOOT_MODELS), wspace=0.12)
    main_axes = []
    diff_axes = []
    for i, key in enumerate(BOOT_MODELS):
        ax, axd = _figure4_draw_panel(fig, outer[0, i], key, panel_data[key])
        main_axes.append(ax)
        diff_axes.append(axd)

    supp.panel_labels(main_axes)
    # Legends are taken from the first column only: one shared figure legend
    # for the main panels, one in-axes legend for the residual panel.
    handles, labels = main_axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=5, bbox_to_anchor=(0.5, 1.04))
    h2, l2 = diff_axes[0].get_legend_handles_labels()
    diff_axes[0].legend(h2, l2, loc='lower left', fontsize=8.0)

    fig.suptitle(
        f"Figure 4. Bootstrap 95% CI (Rust vs C, {BOOTSTRAP_REPS} replicates): {' + '.join(BOOT_MODELS)}",
        fontsize=14.7,
        fontweight='bold',
        y=1.07,
    )
    save_figure_multi(fig, 'figure_4_bootstrap_zigzag_bottleneck')
    plt.show()

    return width_df


In [None]:
def run_figure5(force=False):
    """Run the supplementary S5 format-consistency pipeline and re-export it as Figure 5.

    Any ``S5_format_consistency.*`` figure or table that exists in the
    supplementary output directories is copied into this notebook's
    directories as ``figure_5_format_consistency.*``. Missing source files
    are skipped silently.

    Returns:
        Whatever ``supp.run_s5_format_consistency`` returns.
    """
    table = supp.run_s5_format_consistency(force=force)

    def _reexport(src_dir, dst_dir, suffixes):
        # Copy each available S5 artefact under its Figure 5 name.
        for suffix in suffixes:
            source = src_dir / f'S5_format_consistency.{suffix}'
            if not source.exists():
                continue
            shutil.copy2(source, dst_dir / f'figure_5_format_consistency.{suffix}')

    _reexport(supp.FIG_DIR, FIG_DIR, ('png', 'svg', 'pdf'))
    _reexport(supp.TABLE_DIR, TABLE_DIR, ('csv', 'tsv', 'md'))

    print('Figure 5 exported to', FIG_DIR)
    return table


In [None]:
# Build Figure 4 and show the CI-width table it returns.
figure4_table = run_figure4(force=FORCE)
display(figure4_table)

# Build Figure 5 (re-exported from the supplementary S5 pipeline) and show its table.
figure5_table = run_figure5(force=FORCE)
display(figure5_table)


In [None]:
# List everything now present in the figure and table output directories.
for heading, folder in (('Figures:', FIG_DIR), ('Tables:', TABLE_DIR)):
    print(heading)
    for p in sorted(folder.glob('*')):
        print(' -', p)
