In [1]:
# experiments/run_comparison.py
# Compare ALS (matrix-free) vs EM (dense) for low-rank+diag GLS/SUR.
# Assumes `lowrank_gls` is importable (i.e., package root on PYTHONPATH).

import time
import numpy as np

from lowrank_gls import (
    # solvers
    als_gls, em_gls,
    # calibration
    calibrate_alpha_gamma_s_cv3_conservative, finalize_on_dataset,
    # diagnostics & metrics
    coverage_and_mahalanobis, predict_Y, mse, test_nll,
)

# =========================
# Simulation (SUR and GLS)
# =========================

def simulate_sur(N_tr, N_te, K, p, k, seed=0):
    """Simulate a SUR dataset and split it into train/test rows.

    All K equations have their own design X_j = base + 0.5 * noise_j, where
    `base` is one (N, p) Gaussian matrix shared by every equation. Each
    equation has a (p, 1) Gaussian coefficient vector. Residuals are a k-factor
    term (loadings of shape (K, k), scaled by 1.2) plus independent
    per-equation Gaussian noise with variance in [0.02, 0.17).

    Returns:
        (Xs_tr, Y_tr, Xs_te, Y_te): the first N_tr rows form the training
        split and the remaining N_te rows the test split.
    """
    gen = np.random.default_rng(seed)
    n_total = N_tr + N_te

    # NOTE: the order of the random draws below fixes the data for a seed;
    # do not reorder them.
    shared = gen.standard_normal((n_total, p))
    designs = [shared + 0.50 * gen.standard_normal((n_total, p)) for _ in range(K)]
    coefs = [gen.standard_normal((p, 1)) for _ in range(K)]
    loadings = 1.2 * gen.standard_normal((K, k))
    idio_var = 0.02 + 0.15 * gen.random(K)
    factors = gen.standard_normal((n_total, k))

    mean_part = predict_Y(designs, coefs)
    factor_part = factors @ loadings.T
    idio_part = gen.standard_normal((n_total, K)) * np.sqrt(idio_var)[None, :]
    Y = mean_part + factor_part + idio_part

    train_rows = slice(None, N_tr)
    test_rows = slice(N_tr, None)
    return (
        [X[train_rows] for X in designs],
        Y[train_rows],
        [X[test_rows] for X in designs],
        Y[test_rows],
    )

def simulate_gls(N_tr, N_te, p_list, k, seed=0):
    """Simulate a general GLS dataset (per-equation designs) and split it.

    Equation j gets its own (N, p_list[j]) design, drawn independently of every
    other equation, and its own (p_list[j], 1) Gaussian coefficient vector.
    Residuals combine k latent factors (loadings of shape (K, k), scaled by
    1.2) with independent per-equation Gaussian noise whose variance lies in
    [0.02, 0.17), where K = len(p_list).

    Returns:
        (Xs_tr, Y_tr, Xs_te, Y_te): the first N_tr rows form the training
        split and the remaining N_te rows the test split.
    """
    rng = np.random.default_rng(seed)
    K = len(p_list)
    N = N_tr + N_te
    # Unlike simulate_sur, NOTHING is shared across equations here: each design
    # is its own independent base plus 0.5 * noise, i.e. i.i.d. N(0, 1.25)
    # entries. The two-draw form is kept so a given seed reproduces the same
    # data as before.
    Xs = []
    for p in p_list:
        base = rng.standard_normal((N, p))
        Xs.append(base + 0.50*rng.standard_normal((N, p)))
    B  = [rng.standard_normal((p, 1)) for p in p_list]
    F0 = 1.2 * rng.standard_normal((K, k))
    D0 = 0.02 + 0.15 * rng.random(K)
    U  = rng.standard_normal((N, k))
    Y  = predict_Y(Xs, B) + U @ F0.T + rng.standard_normal((N, K)) * np.sqrt(D0)[None, :]
    # split: first N_tr rows train, remaining N_te rows test
    Xs_tr = [X[:N_tr] for X in Xs]
    Xs_te = [X[N_tr:] for X in Xs]
    Y_tr  = Y[:N_tr]
    Y_te  = Y[N_tr:]
    return Xs_tr, Y_tr, Xs_te, Y_te

# =========================
# Pretty printing
# =========================

def print_metrics_block(title, metrics, diag_train, diag_test):
    """Print a titled metrics summary, then any available diagnostics.

    Only a fixed, ordered whitelist of metric names is printed; names that
    are missing from `metrics` or map to None are skipped. Each diagnostics
    dict that is not None is printed under its own header via `_pp_diag`.
    """
    ordered_names = (
        "mode", "solver", "K", "p", "k", "p_list_summary", "N_tr", "N_te",
        "lam_F", "lam_B", "alpha", "gamma", "tau", "s", "Sec_fit", "Mem_MB_est",
        "Test_MSE", "Test_NLL_perN_cal",
    )
    print(title)
    for name in ordered_names:
        value = metrics.get(name)
        if value is not None:
            print(f"{name}: {value}")

    sections = (
        ("\n[train (pre-calib raw F,D)]  ", diag_train),
        ("\n[test  (calibrated)]  ", diag_test),
    )
    for header, diag in sections:
        if diag is None:
            continue
        print(f"{header}K={diag['K']}  N={diag['N']}")
        _pp_diag(diag)

def _pp_diag(d):
    """Pretty-print one diagnostics dict (as returned by coverage_and_mahalanobis
    in this script's pipelines).

    Reads d["z_coverage_overall"], whose per-level entries carry "overall",
    "min_eq" and "max_eq". Levels absent from it are skipped. Also reads the
    scalars "m2_mean", "m2_var", "m2_frac_gt_95", "m2_frac_gt_99",
    "whiten_frob_dev" and "whiten_offdiag_max". The parenthesised values in the
    output are fixed reference targets printed for comparison, not computed.
    """
    coverage = d["z_coverage_overall"]
    for level in ("cov@90%", "cov@95%", "cov@99%"):
        if level not in coverage:
            continue
        stats = coverage[level]
        print(f"  z {level}: overall={stats['overall']:.3f}, "
              f"min_eq={stats['min_eq']:.3f}, max_eq={stats['max_eq']:.3f}")

    m2_line = f"  m2 mean={d['m2_mean']:.2f} (≈K), var={d['m2_var']:.2f} (≈2K)"
    print(m2_line)
    tail_line = (f"  frac m2>χ2_95={d['m2_frac_gt_95']:.3f} (≈0.05), "
                 f">χ2_99={d['m2_frac_gt_99']:.3f} (≈0.01)")
    print(tail_line)
    whiten_line = (f"  whiten frob_dev={d['whiten_frob_dev']:.3f}, "
                   f"offdiag_max={d['whiten_offdiag_max']:.3f}")
    print(whiten_line)

# =========================
# Pipelines (ALS / EM)
# =========================

# (start, stop, num) for the default alpha calibration grid. The grid is built
# fresh on every call instead of being a shared ndarray default argument.
_ALPHA_GRID_SPEC = (0.78, 1.10, 17)


def _resolve_alpha_grid(alpha_grid):
    """Return `alpha_grid`, or a fresh np.linspace(0.78, 1.10, 17) if it is None."""
    if alpha_grid is None:
        return np.linspace(*_ALPHA_GRID_SPEC)
    return alpha_grid


def _run_pipeline(data, head, fit, fit_kwargs, *, k, lam_F, lam_B,
                  alpha_grid, Kfolds, val_frac, tail_target, cov_target,
                  r_boost, d_floor):
    """Fit one solver, calibrate on train, finalize and score on test.

    Args:
        data: (Xs_tr, Y_tr, Xs_te, Y_te) as returned by simulate_sur/simulate_gls.
        head: leading metrics entries (mode, solver, shape info) in print order.
        fit: solver callable (als_gls or em_gls). It is called as
            fit(Xs_tr, Y_tr, k, lam_F=..., lam_B=..., d_floor=..., **fit_kwargs)
            and its result is unpacked as (B, F, D, mem_mb, _).
        fit_kwargs: solver-specific keyword arguments.
        alpha_grid: calibration grid; None selects the default grid.

    Returns:
        (metrics, (B, F, D), (F_fin, D_fin), diag_tr, diag_te)

    NOTE(review): finalize_on_dataset receives the test responses Y_te, and the
    calibrated NLL and coverage diagnostics are then computed on that same Y_te.
    If finalize_on_dataset fits anything to Y, those "calibrated" test metrics
    are in-sample for that step and may be optimistic — confirm.
    """
    Xs_tr, Y_tr, Xs_te, Y_te = data
    alpha_grid = _resolve_alpha_grid(alpha_grid)

    # perf_counter is monotonic and high-resolution; time.time() is neither.
    t0 = time.perf_counter()
    B, F, D, mem_mb, _ = fit(Xs_tr, Y_tr, k, lam_F=lam_F, lam_B=lam_B,
                             d_floor=d_floor, **fit_kwargs)
    sec_fit = time.perf_counter() - t0

    # Pre-calibration diagnostics on the training split (raw F, D).
    diag_tr = coverage_and_mahalanobis(Y_tr, Xs_tr, B, F, D)

    # CV calibration uses the training split only.
    alpha, gamma, tau, s, use_eig = calibrate_alpha_gamma_s_cv3_conservative(
        Xs_tr, Y_tr, B, F, D,
        alpha_grid=alpha_grid, Kfolds=Kfolds, val_frac=val_frac,
        tail_target=tail_target, cov_target=cov_target,
        r_boost=r_boost, d_floor=d_floor,
    )
    F_fin, D_fin = finalize_on_dataset(
        Xs_te, Y_te, B, F, D, alpha, gamma, tau, s,
        use_eig=use_eig, r_boost=r_boost, d_floor=d_floor,
    )

    # Test metrics: MSE of the mean prediction and NLL/diagnostics under the
    # finalized (F_fin, D_fin).
    test_mse = mse(Y_te, predict_Y(Xs_te, B))
    test_nll_cal = test_nll(Y_te, Xs_te, B, F_fin, D_fin)
    diag_te = coverage_and_mahalanobis(Y_te, Xs_te, B, F_fin, D_fin)

    metrics = {
        **head,
        "lam_F": lam_F, "lam_B": lam_B,
        "alpha": alpha, "gamma": gamma, "tau": tau, "s": s,
        "Sec_fit": sec_fit, "Mem_MB_est": mem_mb,
        "Test_MSE": test_mse, "Test_NLL_perN_cal": test_nll_cal,
    }
    return metrics, (B, F, D), (F_fin, D_fin), diag_tr, diag_te


def _gls_p_summary(p_list):
    """Return (min, integer median, max) of the per-equation design widths."""
    return (min(p_list), int(np.median(p_list)), max(p_list))


def run_sur_pipeline_als_cv31(*,
    N_tr, N_te, K, p, k, seed,
    lam_F=1e-3, lam_B=1e-3, sweeps=12,
    alpha_grid=None, Kfolds=3, val_frac=0.5,
    tail_target=0.03, cov_target=0.975, r_boost=2, d_floor=1e-6,
    use_cg_beta=True, cg_maxit=800, cg_tol=3e-7, use_diag_precond=True
):
    """SUR experiment with the matrix-free ALS solver (als_gls).

    alpha_grid=None means np.linspace(0.78, 1.10, 17).
    Returns (metrics, (B, F, D), (F_fin, D_fin), diag_tr, diag_te).
    """
    data = simulate_sur(N_tr, N_te, K, p, k, seed=seed)
    head = {"mode": "SUR", "solver": "ALS", "K": K, "p": p, "k": k,
            "N_tr": N_tr, "N_te": N_te}
    fit_kwargs = {"sweeps": sweeps, "use_cg_beta": use_cg_beta,
                  "cg_maxit": cg_maxit, "cg_tol": cg_tol,
                  "use_diag_precond": use_diag_precond}
    return _run_pipeline(
        data, head, als_gls, fit_kwargs, k=k, lam_F=lam_F, lam_B=lam_B,
        alpha_grid=alpha_grid, Kfolds=Kfolds, val_frac=val_frac,
        tail_target=tail_target, cov_target=cov_target,
        r_boost=r_boost, d_floor=d_floor,
    )


def run_sur_pipeline_em_cv31(*,
    N_tr, N_te, K, p, k, seed,
    lam_F=1e-3, lam_B=1e-3, iters=45,
    alpha_grid=None, Kfolds=3, val_frac=0.5,
    tail_target=0.03, cov_target=0.975, r_boost=2, d_floor=1e-6
):
    """SUR experiment with the dense EM solver (em_gls).

    alpha_grid=None means np.linspace(0.78, 1.10, 17).
    Returns (metrics, (B, F, D), (F_fin, D_fin), diag_tr, diag_te).
    """
    data = simulate_sur(N_tr, N_te, K, p, k, seed=seed)
    head = {"mode": "SUR", "solver": "EM", "K": K, "p": p, "k": k,
            "N_tr": N_tr, "N_te": N_te}
    return _run_pipeline(
        data, head, em_gls, {"iters": iters}, k=k, lam_F=lam_F, lam_B=lam_B,
        alpha_grid=alpha_grid, Kfolds=Kfolds, val_frac=val_frac,
        tail_target=tail_target, cov_target=cov_target,
        r_boost=r_boost, d_floor=d_floor,
    )


def run_gls_pipeline_als_cv31(*,
    N_tr, N_te, p_list, k, seed,
    lam_F=1e-3, lam_B=1e-3, sweeps=12,
    alpha_grid=None, Kfolds=3, val_frac=0.5,
    tail_target=0.03, cov_target=0.975, r_boost=2, d_floor=1e-6,
    use_cg_beta=True, cg_maxit=800, cg_tol=3e-7, use_diag_precond=True
):
    """General-GLS experiment (per-equation p_list) with the ALS solver.

    alpha_grid=None means np.linspace(0.78, 1.10, 17).
    Returns (metrics, (B, F, D), (F_fin, D_fin), diag_tr, diag_te).
    """
    K = len(p_list)
    data = simulate_gls(N_tr, N_te, p_list, k, seed=seed)
    head = {"mode": "GLS", "solver": "ALS", "K": K,
            "p_list_summary": _gls_p_summary(p_list),
            "k": k, "N_tr": N_tr, "N_te": N_te}
    fit_kwargs = {"sweeps": sweeps, "use_cg_beta": use_cg_beta,
                  "cg_maxit": cg_maxit, "cg_tol": cg_tol,
                  "use_diag_precond": use_diag_precond}
    return _run_pipeline(
        data, head, als_gls, fit_kwargs, k=k, lam_F=lam_F, lam_B=lam_B,
        alpha_grid=alpha_grid, Kfolds=Kfolds, val_frac=val_frac,
        tail_target=tail_target, cov_target=cov_target,
        r_boost=r_boost, d_floor=d_floor,
    )


def run_gls_pipeline_em_cv31(*,
    N_tr, N_te, p_list, k, seed,
    lam_F=1e-3, lam_B=1e-3, iters=45,
    alpha_grid=None, Kfolds=3, val_frac=0.5,
    tail_target=0.03, cov_target=0.975, r_boost=2, d_floor=1e-6
):
    """General-GLS experiment (per-equation p_list) with the EM solver.

    alpha_grid=None means np.linspace(0.78, 1.10, 17).
    Returns (metrics, (B, F, D), (F_fin, D_fin), diag_tr, diag_te).
    """
    K = len(p_list)
    data = simulate_gls(N_tr, N_te, p_list, k, seed=seed)
    head = {"mode": "GLS", "solver": "EM", "K": K,
            "p_list_summary": _gls_p_summary(p_list),
            "k": k, "N_tr": N_tr, "N_te": N_te}
    return _run_pipeline(
        data, head, em_gls, {"iters": iters}, k=k, lam_F=lam_F, lam_B=lam_B,
        alpha_grid=alpha_grid, Kfolds=Kfolds, val_frac=val_frac,
        tail_target=tail_target, cov_target=cov_target,
        r_boost=r_boost, d_floor=d_floor,
    )

# =========================
# Kwarg splitter (avoid passing ALS-only args to EM)
# =========================

def _split_kwargs_for_als_em(kwargs):
    """Split one shared kwargs dict into ALS-specific and EM-specific copies.

    The ALS pipelines accept sweep/CG options that the EM pipelines do not.
    The EM pipelines accept `iters`, which the ALS pipelines do not. Each
    solver-only key is removed from the OTHER solver's copy, so a single
    kwargs dict can drive both runs without a TypeError. The input dict is
    not modified.

    Returns:
        (als_kwargs, em_kwargs): two new dicts.
    """
    als_only = ("sweeps", "use_cg_beta", "cg_maxit", "cg_tol", "use_diag_precond")
    # Previously missing: `iters` leaked into the ALS call and raised TypeError.
    em_only = ("iters",)
    als_kwargs = {key: val for key, val in kwargs.items() if key not in em_only}
    em_kwargs = {key: val for key, val in kwargs.items() if key not in als_only}
    return als_kwargs, em_kwargs

# =========================
# Comparison entry points
# =========================

def _compare_als_vs_em(tag, run_als, run_em, kwargs):
    """Run an ALS and an EM pipeline on shared kwargs and print a report.

    `tag` ("SUR" or "GLS") is interpolated into every header. Solver-only
    kwargs are routed through _split_kwargs_for_als_em. Printed deltas are
    ALS minus EM.
    """
    print(f"=== {tag}: ALS (matrix-free) vs EM (dense) ===")
    als_kw, em_kw = _split_kwargs_for_als_em(kwargs)

    als_metrics, _, _, als_tr, als_te = run_als(**als_kw)
    em_metrics, _, _, em_tr, em_te = run_em(**em_kw)

    print_metrics_block(f"\n--- ALS {tag} ---", als_metrics, als_tr, als_te)
    print_metrics_block(f"\n--- EM  {tag} ---", em_metrics, em_tr, em_te)

    print(f"\n--- {tag}: Deltas (ALS - EM) ---")
    for key in ("Test_MSE", "Test_NLL_perN_cal", "Sec_fit", "Mem_MB_est"):
        print(f"{key}: {als_metrics[key] - em_metrics[key]:+.4f}")


def compare_sur_als_vs_em(**kwargs):
    """Compare ALS vs EM on the SUR simulation; kwargs go to both pipelines."""
    _compare_als_vs_em("SUR", run_sur_pipeline_als_cv31,
                       run_sur_pipeline_em_cv31, kwargs)


def compare_gls_als_vs_em(**kwargs):
    """Compare ALS vs EM on the general-GLS simulation; kwargs go to both pipelines."""
    _compare_als_vs_em("GLS", run_gls_pipeline_als_cv31,
                       run_gls_pipeline_em_cv31, kwargs)

# =========================
# Script entry
# =========================

if __name__ == "__main__":
    # Solver and calibration settings used by both comparisons. Keys tagged
    # ALS-only are stripped from the EM run by _split_kwargs_for_als_em.
    shared = dict(
        lam_F=1e-3, lam_B=1e-3,
        sweeps=12,                   # ALS-only
        use_cg_beta=True,            # ALS-only
        cg_maxit=800, cg_tol=3e-7,   # ALS-only
        use_diag_precond=True,       # ALS-only
        Kfolds=3, val_frac=0.5,
        tail_target=0.03, cov_target=0.975,
        r_boost=3, d_floor=1e-6,
    )

    # SUR: K=120 equations, p=3 regressors each, k=8 latent factors.
    compare_sur_als_vs_em(
        N_tr=240, N_te=120, K=120, p=3, k=8, seed=12345,
        alpha_grid=np.linspace(0.78, 1.10, 17),   # fresh array per comparison
        **shared
    )

    # GLS: 120 equations whose widths cycle through 2, 6, 4, 8, 3.
    compare_gls_als_vs_em(
        N_tr=240, N_te=120, p_list=[2, 6, 4, 8, 3] * 24, k=8, seed=2027,
        alpha_grid=np.linspace(0.78, 1.10, 17),
        **shared
    )

=== SUR: ALS (matrix-free) vs EM (dense) ===

--- ALS SUR ---
mode: SUR
solver: ALS
K: 120
p: 3
k: 8
N_tr: 240
N_te: 120
lam_F: 0.001
lam_B: 0.001
alpha: 1.06
gamma: 3.1151434705040384e-05
tau: 3.5380366157200536
s: 1.03515625
Sec_fit: 0.10996603965759277
Mem_MB_est: 0.00864
Test_MSE: 11.541302021105702
Test_NLL_perN_cal: -48.965705275738294

[train (pre-calib raw F,D)]  K=120  N=240
  z cov@90%: overall=0.999, min_eq=0.992, max_eq=1.000
  z cov@95%: overall=1.000, min_eq=1.000, max_eq=1.000
  z cov@99%: overall=1.000, min_eq=1.000, max_eq=1.000
  m2 mean=119.98 (≈K), var=201.35 (≈2K)
  frac m2>χ2_95=0.033 (≈0.05), >χ2_99=0.004 (≈0.01)
  whiten frob_dev=0.813, offdiag_max=0.564

[test  (calibrated)]  K=120  N=120
  z cov@90%: overall=0.999, min_eq=0.983, max_eq=1.000
  z cov@95%: overall=1.000, min_eq=0.992, max_eq=1.000
  z cov@99%: overall=1.000, min_eq=1.000, max_eq=1.000
  m2 mean=129.60 (≈K), var=317.14 (≈2K)
  frac m2>χ2_95=0.150 (≈0.05), >χ2_99=0.050 (≈0.01)
  whiten frob_dev=12