In [1]:

# -*- coding: utf-8 -*-
"""
STANDALONE Hyb-Adam-UM ALGORITHM ONLY — ALL MISSINGNESS × REPS, NJ*-STYLE OUTPUT TABLES

Minimal-change rewrite of your "Hyb-Adam-UM only (30% rep1)" script so it matches the
output style of the TreePrior code you posted:

What this script does
---------------------
1) Load ORIGINAL full matrix, symmetrize & sanitize.
2) Build & freeze ALL masked matrices for each missingness level × replicate:
     - default: 30/50/65/85% × 5 reps  (match your TreePrior script)
     - same seeding pattern: rng = RandomState(BASE_SEED + (rep-1))
3) For each mask:
     - Run Hyb-Adam-UM (your manual Adam, NumPy-only, central-diff grads on missing LT)
     - Produce a COMPLETED matrix
4) Save:
     - All masked matrices to MASKED_DIR
     - All completed matrices to COMPLETED_DIR
     - Detailed results CSV (all masks)
     - Summary mean±std CSV (by missingness)
     - Numeric summary CSV
     - Training-style delta table CSV (Original + each missingness mean±std)

Notes
-----
- I did NOT change the optimizer math (central diff + Adam). I only wrapped it in an
  "all masks" driver and added the same table outputs as in the TreePrior script.
- This NumPy central-diff implementation is extremely expensive for n=100; it is intended
  for small matrices (e.g., 15×15). If you run n=100 with EPOCHS=30000 you will likely
  wait a very long time.
"""

import os, sys, json, math, warnings, zipfile, time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Set, Tuple, Union
from dataclasses import dataclass
from itertools import combinations

warnings.filterwarnings("ignore", category=UserWarning)

# ============================================================
# ----------------------- CONFIG -----------------------------
# ============================================================

# Data (align with your TreePrior script)
ORIG_CANDIDATES = [
    "Result_NW_15x15.txt",
    "/mnt/data/Result_NW_100x100.txt",
    "Result_NW_15x15.txt",
    "/mnt/data/Result_NW_15x15.txt",
]

# Missingness × reps (align with your TreePrior script)
MISSING_FRACS = (0.30, 0.50, 0.65, 0.85, 0.90)
REPS          = 5
BASE_SEED     = 55

# Output folders (NJ*-style)
MASKED_DIR    = "hyb_adam_um_missing_matrices"
COMPLETED_DIR = "hyb_adam_um_completed_matrices"

# Preserve observed entries EXACTLY in the final output?
# IMPORTANT: Your original hyb_adam_um_impute() reconstructs the whole matrix from
# observed LT + optimized missing LT, so observed entries are preserved by construction.
PRESERVE_OBSERVED = True

# -----------------------
# Training hyperparameters (keep your original defaults)
# -----------------------
EPOCHS              = 3000
PRINT_EVERY         = 1000
LR_INIT             = 0.04
WEIGHT_DECAY        = 0.0
CLIP_GRAD_NORM      = 5.0
H_CENTRAL_DIFF      = 5e-5
BETA1               = 0.9
BETA2               = 0.999
ADAM_EPS            = 1e-8
TRI_FRAC            = 1.0   # use ALL triplets
SEED_HYB            = 42

# LR schedule: earlier & also fixed milestones
SCHED_PATIENCE_BLOCKS = 7    # with PRINT_EVERY=100 -> ~700 epochs
LR_FACTOR             = 0.5
LR_MIN                = 1e-4
LR_MILESTONES         = [700, 2000]


# ============================================================
# ----------------------- Loading & helpers -------------------
# ============================================================

MISSING_VAL = -1.0

def load_matrix_with_candidates(cands):
    for p in cands:
        if os.path.exists(p):
            M = np.loadtxt(p)
            if np.nanmax(M) > 500:
                M = M / 1000.0
            return M, p
    raise FileNotFoundError(f"File not found. Tried: {cands}")

def symmetrize_full(D: np.ndarray) -> np.ndarray:
    D = 0.5*(D + D.T)
    np.fill_diagonal(D, 0.0)
    return D

def symmetrize_with_missing(D: np.ndarray) -> np.ndarray:
    D = D.copy().astype(float)
    n = D.shape[0]
    for i in range(n):
        for j in range(i+1, n):
            a, b = D[i, j], D[j, i]
            if a >= 0 and b >= 0: v = 0.5*(a+b)
            elif a >= 0:         v = a
            elif b >= 0:         v = b
            else:                v = MISSING_VAL
            D[i, j] = D[j, i] = v
    np.fill_diagonal(D, 0.0)
    return D

def _finite_fill(v, fallback=1.0):
    v = np.asarray(v)
    if np.isfinite(v).any():
        return float(np.nanmedian(v[np.isfinite(v)]))
    return float(fallback)

def sanitize_distance_matrix(D: np.ndarray, name: str="D", force_nonneg: bool=True) -> np.ndarray:
    """Ensure finite, symmetric, optionally nonnegative, 0-diagonal matrix."""
    M = np.array(D, dtype=float)
    n = M.shape[0]
    neg = (M < 0); np.fill_diagonal(neg, False)
    M[neg] = np.nan
    off = ~np.eye(n, dtype=bool)
    med = _finite_fill(M[off], fallback=1.0)
    M = np.nan_to_num(M, nan=med, posinf=med, neginf=med)
    q = np.quantile(M[off], 0.995)
    if np.isfinite(q) and q > 0: M[off] = np.minimum(M[off], q)
    M = 0.5*(M + M.T)
    if force_nonneg: M = np.maximum(M, 0.0)
    np.fill_diagonal(M, 0.0)
    if not np.isfinite(M).all():
        raise ValueError(f"{name} has non-finite entries after sanitize.")
    return M


# ============================================================
# ----------------------- Robust δ / Δ ------------------------
# ============================================================

OMEGA   = 2.0
EPS_NUM = 1.0e-12

def robust_delta_per_triplet_numpy(M: np.ndarray, triplets: np.ndarray) -> np.ndarray:
    """Return δ(i,j,k) for each triplet; M must be symmetric, nonnegative, 0-diagonal."""
    i, j, k = triplets[:, 0], triplets[:, 1], triplets[:, 2]
    a = M[i, j]; b = M[i, k]; c = M[j, k]
    S = np.stack([a, b, c], axis=1)
    S_sorted = -np.sort(-S, axis=1)
    a, b, c = S_sorted[:, 0], S_sorted[:, 1], S_sorted[:, 2]
    viol = a >= (b + c)
    denom_v = np.maximum(b + c, EPS_NUM)
    delta1 = np.maximum(a / denom_v, OMEGA)
    denomA = np.maximum(2.0 * b * c, EPS_NUM)
    denomB = np.maximum(2.0 * a * c, EPS_NUM)
    denomG = np.maximum(2.0 * a * b, EPS_NUM)
    cosA = np.clip((b*b + c*c - a*a) / denomA, -1.0, 1.0)
    cosB = np.clip((a*a + c*c - b*b) / denomB, -1.0, 1.0)
    cosG = np.clip((a*a + b*b - c*c) / denomG, -1.0, 1.0)
    A = np.arccos(cosA); B = np.arccos(cosB); G = np.arccos(cosG)
    Ang = np.stack([A, B, G], axis=1)
    Ang_sorted = -np.sort(-Ang, axis=1)
    A, B, G = Ang_sorted[:, 0], Ang_sorted[:, 1], Ang_sorted[:, 2]
    delta2 = (A - B) / np.maximum(G, EPS_NUM)
    return np.where(viol, delta1, delta2)

def robust_delta_sum_numpy(M: np.ndarray, triplets: np.ndarray) -> float:
    return float(robust_delta_per_triplet_numpy(M, triplets).sum())

def compute_normalized_delta(M: np.ndarray, triplets: np.ndarray) -> float:
    """Log-compressed mean δ in [0,1] (clipped), using max_reasonable_delta=100."""
    delta_vals = robust_delta_per_triplet_numpy(M, triplets)
    max_reasonable_delta = 100.0
    delta_norm_vals = np.log1p(delta_vals) / np.log1p(max_reasonable_delta)
    delta_norm_vals = np.clip(delta_norm_vals, 0.0, 1.0)
    return float(np.mean(delta_norm_vals))


# ============================================================
# ----------------------- RMSE helper -------------------------
# ============================================================

def rmse_on_lt(A: np.ndarray, B: np.ndarray) -> float:
    n = A.shape[0]
    i_lt, j_lt = np.tril_indices(n, k=-1)
    diff = (A - B)[i_lt, j_lt]
    return float(np.sqrt(np.mean(diff*diff)))


# ============================================================
# ----------------------- Mask generation ---------------------
# ============================================================

def simulate_missing(D_full: np.ndarray, frac_missing: float, rng: np.random.RandomState):
    """Mask lower-triangle pairs to -1 (symmetrically) at given fraction."""
    n = D_full.shape[0]
    lower = np.tril(np.ones((n,n), dtype=bool), k=-1)
    I, J = np.where(lower); m = len(I)
    drop = int(round(frac_missing*m))
    keep = np.ones(m, dtype=bool)
    if drop > 0:
        keep[rng.choice(m, size=drop, replace=False)] = False
    D_inc = D_full.copy().astype(float)
    for idx in range(m):
        i, j = I[idx], J[idx]
        if not keep[idx]:
            D_inc[i, j] = D_inc[j, i] = MISSING_VAL
    np.fill_diagonal(D_inc, 0.0)
    D_inc = symmetrize_with_missing(D_inc)
    obs_mask = (D_inc >= 0)
    np.fill_diagonal(obs_mask, True)
    return D_inc, obs_mask


# ============================================================
# ---------------- Hyb-Adam-UM (manual Adam) ------------------
# ============================================================

def setup_problem_for_hyb_adam(M_in: np.ndarray):
    """Prepare parameterization over missing lower-triangle entries."""
    n = M_in.shape[0]
    lower_mask = np.tril(np.ones((n, n), dtype=bool), k=-1)
    given_mask_lower   = lower_mask & (M_in >= 0.0)
    missing_mask_lower = lower_mask & (M_in <  0.0)
    given_pairs   = np.array(np.where(given_mask_lower)).T
    missing_pairs = np.array(np.where(missing_mask_lower)).T
    Ng = len(given_pairs)
    given_vals = M_in[given_mask_lower].astype(np.float64)
    init_val = float(np.mean(given_vals)) if Ng > 0 else 1.0
    x = np.full((len(missing_pairs),), init_val, dtype=np.float64)
    triplets = np.array(list(combinations(range(n), 3)), dtype=np.int32)

    def assemble_full(xvec: np.ndarray) -> np.ndarray:
        # Observed LT preserved EXACTLY because we copy given_vals to those LT positions.
        M = np.zeros((n, n), dtype=np.float64)
        if Ng > 0:
            gi, gj = given_pairs[:, 0], given_pairs[:, 1]
            M[gi, gj] = given_vals
        if len(missing_pairs) > 0:
            mi, mj = missing_pairs[:, 0], missing_pairs[:, 1]
            M[mi, mj] = xvec
        M = M + M.T
        np.fill_diagonal(M, 0.0)
        np.maximum(M, 0.0, out=M)
        return M

    return x, given_pairs, missing_pairs, given_vals, triplets, assemble_full

def central_diff_grad(x: np.ndarray, f, h: float = H_CENTRAL_DIFF) -> np.ndarray:
    g = np.zeros_like(x)
    for k in range(x.size):
        x[k] += h; f1 = f(x)
        x[k] -= 2*h; f2 = f(x)
        x[k] += h
        g[k] = (f1 - f2) / (2*h)
    return g

def hyb_adam_um_impute(D_in: np.ndarray, trip_all: np.ndarray, ntri: int, verbose: bool=True) -> Tuple[np.ndarray, Dict]:
    """
    Optimize missing LT with robust-Δ objective; FULL-BATCH so batch Δ == full Δ.
    Returns (completed_matrix, train_info_dict).
    """
    np.random.seed(SEED_HYB)

    x, _, _, _, all_triplets, assemble_full = setup_problem_for_hyb_adam(D_in)
    T = all_triplets.shape[0]

    if x.size == 0:
        M0 = assemble_full(x)
        info = {"skipped": True, "reason": "no missing pairs", "epochs": 0}
        return M0, info

    def objective_full(xvec: np.ndarray) -> float:
        M = assemble_full(xvec)
        return robust_delta_sum_numpy(M, all_triplets)

    # Initial
    delta0_full = objective_full(x)
    if verbose:
        print(f"Initial robust Δ: {delta0_full:.6f}")

    # Adam state
    m = np.zeros_like(x); v = np.zeros_like(x); t = 0
    lr = LR_INIT

    # Scheduler state
    best_full = float("inf")
    last_block_best = float("inf")
    no_improve_blocks = 0

    best_x = x.copy()
    best_loss = float("inf")

    for epoch in range(1, EPOCHS + 1):
        t += 1

        # FULL batch
        f_batch = lambda xvec: objective_full(xvec)

        # Central-diff gradient
        g = central_diff_grad(x, f_batch, h=H_CENTRAL_DIFF)

        # Weight decay
        if WEIGHT_DECAY > 0.0:
            g = g + WEIGHT_DECAY * x

        # Gradient clipping
        if CLIP_GRAD_NORM is not None:
            g_norm = float(np.linalg.norm(g))
            if g_norm > CLIP_GRAD_NORM and g_norm > 0:
                g = g * (CLIP_GRAD_NORM / g_norm)

        # Adam update
        m = BETA1 * m + (1 - BETA1) * g
        v = BETA2 * v + (1 - BETA2) * (g * g)
        m_hat = m / (1 - (BETA1 ** t))
        v_hat = v / (1 - (BETA2 ** t))
        x -= lr * (m_hat / (np.sqrt(v_hat) + ADAM_EPS))
        x = np.maximum(x, 0.0)

        # Fixed epoch milestones
        if epoch in LR_MILESTONES and lr > LR_MIN + 1e-12:
            new_lr = max(LR_MIN, lr * LR_FACTOR)
            if new_lr < lr - 1e-12 and verbose:
                print(f"Epoch {epoch:05d}: lr milestone {lr:.6f}->{new_lr:.6f}")
            lr = new_lr

        # Logging + plateau schedule
        if epoch % PRINT_EVERY == 0 or epoch == 1 or epoch == EPOCHS:
            full_loss  = objective_full(x)
            if verbose:
                print(f"Epoch {epoch:5d} | full Δ = {full_loss:.6f} | lr={lr:.5f}")

            # plateau-based reduction
            if full_loss < last_block_best - 1e-12:
                no_improve_blocks = 0
                last_block_best = full_loss
            else:
                no_improve_blocks += 1
                if no_improve_blocks >= SCHED_PATIENCE_BLOCKS and lr > LR_MIN + 1e-12:
                    new_lr = max(LR_MIN, lr * LR_FACTOR)
                    if new_lr < lr - 1e-12 and verbose:
                        print(f"Epoch {epoch:05d}: lr plateau {lr:.6f}->{new_lr:.6f}")
                    lr = new_lr
                    no_improve_blocks = 0

        # Track best
        full_loss_now = objective_full(x)
        if full_loss_now < best_loss - 1e-12:
            best_loss = full_loss_now
            best_x = x.copy()

        if full_loss_now < best_full - 1e-12:
            best_full = full_loss_now

    M_best = assemble_full(best_x)
    M_best = 0.5*(M_best + M_best.T)
    np.fill_diagonal(M_best, 0.0)
    M_best = np.maximum(M_best, 0.0)

    info = {
        "skipped": False,
        "epochs": int(EPOCHS),
        "final_lr": float(lr),
        "best_obj": float(best_loss),
        "missing_params": int(best_x.size),
    }
    return M_best, info


# ============================================================
# ------------------------- Tables helpers --------------------
# ============================================================

def fmt_pm(mean_val, std_val, decimals=6) -> str:
    return f"{mean_val:.{decimals}f} ± {std_val:.{decimals}f}"

def mean_std(x: np.ndarray) -> Tuple[float, float]:
    m = float(np.mean(x))
    s = float(np.std(x, ddof=1)) if len(x) > 1 else 0.0
    return m, s


# ============================================================
# ------------------------- MAIN ------------------------------
# ============================================================

def main():
    # Load ORIGINAL
    D_orig, used_orig_path = load_matrix_with_candidates(ORIG_CANDIDATES)
    D0 = sanitize_distance_matrix(symmetrize_full(D_orig), "D_orig")

    n = D0.shape[0]
    labels = [f"T{i+1}" for i in range(n)]
    trip_all = np.array(list(combinations(range(n), 3)), dtype=np.int32)
    ntri = n*(n-1)*(n-2)//6

    # ORIGINAL Δ
    print("=== Calculating Δ for original matrix ===")
    Δ_original = robust_delta_sum_numpy(D0, trip_all)
    Δ_normalized_original = compute_normalized_delta(D0, trip_all)
    print(f"Original matrix file: {used_orig_path}")
    print(f"Number of triplets (n choose 3) = {ntri}")
    print(f"Original matrix Δ_total = {Δ_original:.6f}")
    print(f"Original Δ_normalized = {Δ_normalized_original:.6f}")
    print(f"Original Δ_per_triangle (Δ/ntri) = {Δ_original/ntri:.6f}")
    print()

    # Build & freeze ALL masked matrices (same seed pattern as TreePrior)
    masked_registry = []
    for frac in MISSING_FRACS:
        for r in range(1, REPS + 1):
            rng = np.random.RandomState(BASE_SEED + (r - 1))
            D_inc, obs_mask = simulate_missing(D0, frac, rng)
            masked_registry.append({"frac": frac, "rep": r, "D_inc": D_inc, "obs_mask": obs_mask})

    # Save all masked matrices
    os.makedirs(MASKED_DIR, exist_ok=True)
    for rec in masked_registry:
        pct = int(round(100 * rec["frac"]))
        fn = f"missing_p{pct}_rep{rec['rep']}.csv"
        pd.DataFrame(rec["D_inc"], index=labels, columns=labels).to_csv(os.path.join(MASKED_DIR, fn))

    # Run Hyb-Adam-UM on ALL masks
    print("=== Running Hyb-Adam-UM on all masks ===")
    os.makedirs(COMPLETED_DIR, exist_ok=True)

    results = []
    runtimes_50 = []

    for rec in masked_registry:
        frac = rec["frac"]; rep = rec["rep"]
        D_inc = rec["D_inc"]
        pct = int(round(100 * frac))
        print(f"\n  Processing {pct}% missing, replicate {rep}...")

        t0_50 = time.perf_counter() if abs(frac - 0.50) < 1e-12 else None
        t_all = time.perf_counter()

        # Train
        D_hyb, info = hyb_adam_um_impute(D_inc, trip_all=trip_all, ntri=ntri, verbose=True)

        runtime_total = time.perf_counter() - t_all
        runtime_50 = None
        if t0_50 is not None:
            runtime_50 = time.perf_counter() - t0_50
            runtimes_50.append(runtime_50)
            print(f"    Runtime for 50% missing, replicate {rep}: {runtime_50:.2f} sec ({runtime_50/60:.2f} min)")
        print(f"    End-to-end runtime: {runtime_total:.2f} sec ({runtime_total/60:.2f} min)")

        # Sanitize (consistent with your earlier script)
        D_hyb_san = sanitize_distance_matrix(D_hyb, f"Hyb-Adam-UM_p{pct}_rep{rep}")

        # Metrics vs original
        rmse_val = rmse_on_lt(D_hyb_san, D0)
        Δ_total  = robust_delta_sum_numpy(D_hyb_san, trip_all)
        Δ_norm   = compute_normalized_delta(D_hyb_san, trip_all)
        Δ_per_triangle = Δ_total / ntri
        Δ_relative = (Δ_total / Δ_original) if (Δ_original != 0) else float("nan")

        # Save completed matrix
        out_fn = f"HybAdamUM_completed_p{pct}_rep{rep}.csv"
        pd.DataFrame(D_hyb_san, index=labels, columns=labels).to_csv(os.path.join(COMPLETED_DIR, out_fn))

        results.append({
            "pct_missing": pct,
            "replicate": rep,
            "RMSE_vs_ORIG": rmse_val,
            "Δ_total": Δ_total,
            "Δ_normalized": Δ_norm,
            "Δ_per_triangle": Δ_per_triangle,
            "Δ_relative": Δ_relative,
            "runtime_seconds": runtime_50,   # keep TreePrior convention: only filled for 50%
            "completed_file": out_fn,
            "train_info": str(info),
        })

    # runtime summary for 50%
    if runtimes_50:
        avg_runtime_50 = float(np.mean(runtimes_50))
        min_runtime_50 = float(np.min(runtimes_50))
        max_runtime_50 = float(np.max(runtimes_50))
        print("\n=== Runtime summary for 50% missing matrices ===")
        print(f"Count: {len(runtimes_50)}")
        print(f"Average: {avg_runtime_50:.2f} sec ({avg_runtime_50/60:.2f} min)")
        print(f"Range:   {min_runtime_50:.2f} - {max_runtime_50:.2f} sec ({min_runtime_50/60:.2f} - {max_runtime_50/60:.2f} min)")

    # detailed results table
    results_df = pd.DataFrame(results)
    print("\n=== Hyb-Adam-UM results for all masks ===")
    try:
        from IPython.display import display
        display(results_df)
    except Exception:
        print(results_df.to_string(index=False))
    results_df.to_csv("hyb_adam_um_all_masks_detailed.csv", index=False)

    # summary (mean ± std over reps per missingness)
    summary_rows = []
    for frac in MISSING_FRACS:
        pct = int(round(100 * frac))
        mask_results = results_df[results_df["pct_missing"] == pct]

        rmse_vals = mask_results["RMSE_vs_ORIG"].astype(float).to_numpy()
        dt_vals   = mask_results["Δ_total"].astype(float).to_numpy()
        dn_vals   = mask_results["Δ_normalized"].astype(float).to_numpy()
        dpt_vals  = mask_results["Δ_per_triangle"].astype(float).to_numpy()
        dr_vals   = mask_results["Δ_relative"].astype(float).to_numpy()

        rmse_m, rmse_s = mean_std(rmse_vals)
        dt_m, dt_s     = mean_std(dt_vals)
        dn_m, dn_s     = mean_std(dn_vals)
        dpt_m, dpt_s   = mean_std(dpt_vals)
        dr_m, dr_s     = mean_std(dr_vals)

        summary_rows.append({
            "% Missing": f"{pct}%",
            "RMSE_mean": rmse_m, "RMSE_std": rmse_s,
            "Δ_total_mean": dt_m, "Δ_total_std": dt_s,
            "Δ_normalized_mean": dn_m, "Δ_normalized_std": dn_s,
            "Δ_per_triangle_mean": dpt_m, "Δ_per_triangle_std": dpt_s,
            "Δ_relative_mean": dr_m, "Δ_relative_std": dr_s
        })

    summary_df = pd.DataFrame(summary_rows)
    print("\n=== Hyb-Adam-UM summary (mean ± std over replicates) ===")
    try:
        from IPython.display import display
        display(summary_df)
    except Exception:
        print(summary_df.to_string(index=False))

    # formatted summary (NJ*-style)
    formatted_summary_df = summary_df.copy()
    formatted_summary_df["RMSE"] = [
        fmt_pm(row["RMSE_mean"], row["RMSE_std"], 6) for _, row in summary_df.iterrows()
    ]
    formatted_summary_df["Δ_total"] = [
        fmt_pm(row["Δ_total_mean"], row["Δ_total_std"], 4) for _, row in summary_df.iterrows()
    ]
    formatted_summary_df["Δ_normalized"] = [
        fmt_pm(row["Δ_normalized_mean"], row["Δ_normalized_std"], 6) for _, row in summary_df.iterrows()
    ]
    formatted_summary_df["Δ̄ (per triangle)"] = [
        fmt_pm(row["Δ_per_triangle_mean"], row["Δ_per_triangle_std"], 6) for _, row in summary_df.iterrows()
    ]
    formatted_summary_df["Δ_relative (to original)"] = [
        fmt_pm(row["Δ_relative_mean"], row["Δ_relative_std"], 6) for _, row in summary_df.iterrows()
    ]
    formatted_summary_df = formatted_summary_df[
        ["% Missing", "RMSE", "Δ_total", "Δ_normalized", "Δ̄ (per triangle)", "Δ_relative (to original)"]
    ]

    print("\n=== Hyb-Adam-UM formatted summary ===")
    try:
        from IPython.display import display
        display(formatted_summary_df)
    except Exception:
        print(formatted_summary_df.to_string(index=False))

    formatted_summary_df.to_csv("hyb_adam_um_summary_formatted.csv", index=False)
    pd.concat({
        "mean": summary_df[["RMSE_mean", "Δ_total_mean", "Δ_normalized_mean", "Δ_per_triangle_mean", "Δ_relative_mean"]],
        "std":  summary_df[["RMSE_std",  "Δ_total_std",  "Δ_normalized_std",  "Δ_per_triangle_std",  "Δ_relative_std"]]
    }, axis=1).to_csv("hyb_adam_um_summary_numeric.csv")

    # training-style delta table (final matrices)
    print(f"\n=== Delta Sum Table (training-style; COMPLETED matrices) for Hyb-Adam-UM ===")
    print(f"Original Δ_total = {Δ_original:.6f}")
    print(f"Original Δ_normalized = {Δ_normalized_original:.6f}")
    print(f"Number of triangles = {ntri}")
    print(f"Original Δ_per_triangle = {Δ_original/ntri:.6f}")

    delta_table_rows = [
        {"% Missing": "Original",
         "Hyb-Adam-UM Δ_total": f"{Δ_original:.4f}",
         "Δ_normalized": f"{Δ_normalized_original:.6f}",
         "Δ̄ (per triangle)": f"{Δ_original/ntri:.6f}",
         "Δ_relative (to original)": "1.0000"}
    ]

    for frac in MISSING_FRACS:
        pct = int(round(100 * frac))
        mask_results = results_df[results_df["pct_missing"] == pct]

        dt_vals  = mask_results["Δ_total"].astype(float).to_numpy()
        dn_vals  = mask_results["Δ_normalized"].astype(float).to_numpy()
        dpt_vals = mask_results["Δ_per_triangle"].astype(float).to_numpy()
        dr_vals  = mask_results["Δ_relative"].astype(float).to_numpy()

        dt_m, dt_s   = mean_std(dt_vals)
        dn_m, dn_s   = mean_std(dn_vals)
        dpt_m, dpt_s = mean_std(dpt_vals)
        dr_m, dr_s   = mean_std(dr_vals)

        delta_table_rows.append({
            "% Missing": f"{pct}%",
            "Hyb-Adam-UM Δ_total": f"{dt_m:.4f} ± {dt_s:.4f}",
            "Δ_normalized": f"{dn_m:.6f} ± {dn_s:.6f}",
            "Δ̄ (per triangle)": f"{dpt_m:.6f} ± {dpt_s:.6f}",
            "Δ_relative (to original)": f"{dr_m:.4f} ± {dr_s:.4f}"
        })

    delta_table_df = pd.DataFrame(delta_table_rows)
    try:
        from IPython.display import display
        display(delta_table_df)
    except Exception:
        print(delta_table_df.to_string(index=False))

    delta_table_df.to_csv("hyb_adam_um_delta_sum_training_style.csv", index=False)

    print("\n=== ALL PROCESSING COMPLETE! ===")
    print("Output files:")
    print(f"  - Masked matrices saved to: {MASKED_DIR}/")
    print(f"  - Completed matrices saved to: {COMPLETED_DIR}/")
    print("  - Detailed results: hyb_adam_um_all_masks_detailed.csv")
    print("  - Formatted summary: hyb_adam_um_summary_formatted.csv")
    print("  - Numeric summary: hyb_adam_um_summary_numeric.csv")
    print("  - Training-style delta table: hyb_adam_um_delta_sum_training_style.csv")
    print("\nNote: Δ_normalized uses log scaling: δ_norm = log(1+δ)/log(101), then clipped to [0,1].")


if __name__ == "__main__":
    main()


=== Calculating Δ for original matrix ===
Original matrix file: Result_NW_15x15.txt
Number of triplets (n choose 3) = 455
Original matrix Δ_total = 105.474106
Original Δ_normalized = 0.037533
Original Δ_per_triangle (Δ/ntri) = 0.231811

=== Running Hyb-Adam-UM on all masks ===

  Processing 30% missing, replicate 1...
Initial robust Δ: 183.917145
Epoch     1 | full Δ = 327.061732 | lr=0.04000
Epoch 00700: lr milestone 0.040000->0.020000
Epoch  1000 | full Δ = 79.330071 | lr=0.02000
Epoch 02000: lr milestone 0.020000->0.010000
Epoch  2000 | full Δ = 78.471793 | lr=0.01000
Epoch  3000 | full Δ = 74.334304 | lr=0.01000
    End-to-end runtime: 12.34 sec (0.21 min)

  Processing 30% missing, replicate 2...
Initial robust Δ: 295.173674
Epoch     1 | full Δ = 415.442741 | lr=0.04000
Epoch 00700: lr milestone 0.040000->0.020000
Epoch  1000 | full Δ = 86.874481 | lr=0.02000
Epoch 02000: lr milestone 0.020000->0.010000
Epoch  2000 | full Δ = 83.946653 | lr=0.01000
Epoch  3000 | full Δ = 80.45949

Unnamed: 0,pct_missing,replicate,RMSE_vs_ORIG,Δ_total,Δ_normalized,Δ_per_triangle,Δ_relative,runtime_seconds,completed_file,train_info
0,30,1,0.011035,73.220974,0.028634,0.160925,0.694208,,HybAdamUM_completed_p30_rep1.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
1,30,2,0.013289,78.935976,0.029021,0.173486,0.748392,,HybAdamUM_completed_p30_rep2.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
2,30,3,0.004644,65.305514,0.02605,0.143529,0.619162,,HybAdamUM_completed_p30_rep3.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
3,30,4,0.006316,86.801353,0.031384,0.190772,0.822964,,HybAdamUM_completed_p30_rep4.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
4,30,5,0.002205,94.290873,0.033872,0.207233,0.893972,,HybAdamUM_completed_p30_rep5.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
5,50,1,0.013534,62.2195,0.024175,0.136746,0.589903,19.745462,HybAdamUM_completed_p50_rep1.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
6,50,2,0.014419,53.947371,0.020481,0.118566,0.511475,19.835471,HybAdamUM_completed_p50_rep2.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
7,50,3,0.008292,49.147614,0.02006,0.108017,0.465969,19.788318,HybAdamUM_completed_p50_rep3.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
8,50,4,0.014044,46.24362,0.018767,0.101634,0.438436,19.62431,HybAdamUM_completed_p50_rep4.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."
9,50,5,0.004616,61.339516,0.024655,0.134812,0.58156,19.686655,HybAdamUM_completed_p50_rep5.csv,"{'skipped': False, 'epochs': 3000, 'final_lr':..."



=== Hyb-Adam-UM summary (mean ± std over replicates) ===


Unnamed: 0,% Missing,RMSE_mean,RMSE_std,Δ_total_mean,Δ_total_std,Δ_normalized_mean,Δ_normalized_std,Δ_per_triangle_mean,Δ_per_triangle_std,Δ_relative_mean,Δ_relative_std
0,30%,0.007498,0.004572,79.710938,11.325689,0.029792,0.002963,0.175189,0.024892,0.755739,0.107379
1,50%,0.010981,0.004344,54.579524,7.131958,0.021628,0.002627,0.119955,0.015675,0.517468,0.067618
2,65%,0.019182,0.011137,46.558869,26.477558,0.017669,0.007561,0.102327,0.058192,0.441425,0.251034
3,85%,0.038451,0.004286,10.392753,3.186969,0.004444,0.001173,0.022841,0.007004,0.098534,0.030216
4,90%,0.052576,0.008726,341.764896,159.181034,0.069814,0.033886,0.751132,0.349848,3.240273,1.509195



=== Hyb-Adam-UM formatted summary ===


Unnamed: 0,% Missing,RMSE,Δ_total,Δ_normalized,Δ̄ (per triangle),Δ_relative (to original)
0,30%,0.007498 ± 0.004572,79.7109 ± 11.3257,0.029792 ± 0.002963,0.175189 ± 0.024892,0.755739 ± 0.107379
1,50%,0.010981 ± 0.004344,54.5795 ± 7.1320,0.021628 ± 0.002627,0.119955 ± 0.015675,0.517468 ± 0.067618
2,65%,0.019182 ± 0.011137,46.5589 ± 26.4776,0.017669 ± 0.007561,0.102327 ± 0.058192,0.441425 ± 0.251034
3,85%,0.038451 ± 0.004286,10.3928 ± 3.1870,0.004444 ± 0.001173,0.022841 ± 0.007004,0.098534 ± 0.030216
4,90%,0.052576 ± 0.008726,341.7649 ± 159.1810,0.069814 ± 0.033886,0.751132 ± 0.349848,3.240273 ± 1.509195



=== Delta Sum Table (training-style; COMPLETED matrices) for Hyb-Adam-UM ===
Original Δ_total = 105.474106
Original Δ_normalized = 0.037533
Number of triangles = 455
Original Δ_per_triangle = 0.231811


Unnamed: 0,% Missing,Hyb-Adam-UM Δ_total,Δ_normalized,Δ̄ (per triangle),Δ_relative (to original)
0,Original,105.4741,0.037533,0.231811,1.0000
1,30%,79.7109 ± 11.3257,0.029792 ± 0.002963,0.175189 ± 0.024892,0.7557 ± 0.1074
2,50%,54.5795 ± 7.1320,0.021628 ± 0.002627,0.119955 ± 0.015675,0.5175 ± 0.0676
3,65%,46.5589 ± 26.4776,0.017669 ± 0.007561,0.102327 ± 0.058192,0.4414 ± 0.2510
4,85%,10.3928 ± 3.1870,0.004444 ± 0.001173,0.022841 ± 0.007004,0.0985 ± 0.0302
5,90%,341.7649 ± 159.1810,0.069814 ± 0.033886,0.751132 ± 0.349848,3.2403 ± 1.5092



=== ALL PROCESSING COMPLETE! ===
Output files:
  - Masked matrices saved to: hyb_adam_um_missing_matrices/
  - Completed matrices saved to: hyb_adam_um_completed_matrices/
  - Detailed results: hyb_adam_um_all_masks_detailed.csv
  - Formatted summary: hyb_adam_um_summary_formatted.csv
  - Numeric summary: hyb_adam_um_summary_numeric.csv
  - Training-style delta table: hyb_adam_um_delta_sum_training_style.csv

Note: Δ_normalized uses log scaling: δ_norm = log(1+δ)/log(101), then clipped to [0,1].


In [2]:
# Cell — PICTURES OF TREES + HEATMAPS (ALL masks) for Hyb-Adam-UM (NO TreePrior)
#
# Compatible with the "Hyb-Adam-UM only — all missingness × reps" script above.
#
# What this does:
#   • Auto-finds ALL completed matrices produced by that script:
#       HybAdamUM_completed_p{30,50,65,85}_rep{1..5}.csv   (default 20 files)
#   • Loads ORIGINAL matrix
#   • Builds NJ tree for Original (once)
#   • Builds NJ trees for ALL completed matrices
#   • Saves:
#       - tree PNGs + Newicks (Original + all completed)
#       - benchmark CSVs (tree + matrix), detailed + mean±std by missingness
#       - heatmap grid PNG (Original + all completed; auto grid size)
#
# Output folder (single directory, minimal files by default):
#   trees_hyb_adam_um_all/
#     tree_Original.png
#     tree_Original.newick
#     tree_p30_rep1.png, tree_p30_rep1.newick, ... (all found)
#     benchmark_tree_all.csv
#     benchmark_matrix_all.csv
#     benchmark_tree_by_missingness_meanstd.csv
#     benchmark_matrix_by_missingness_meanstd.csv
#     heatmap_grid_original_plus_all.png
#
# Notes:
#   • Heatmap = plt.imshow(matrix), saved as PNG.
#   • Grid size is automatic (works for 16, 20, etc).
#   • This NJ implementation is a simple nonnegative NJ for visualization/benchmarking.
#     It is NOT NJ* for incomplete matrices (we are reading COMPLETED matrices).
#
# ------------------------------------------------------------

import os
import re
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")

# ============================================================
# -------------------------- CONFIG ---------------------------
# ============================================================

COMPLETED_DIR = "hyb_adam_um_completed_matrices"   # from the code above
OUT_DIR       = "trees_hyb_adam_um_all"

# Original matrix candidates (edit if needed)
ORIG_CANDIDATES = [
    "Result_NW_15x15.txt",
    "./Result_NW_15x15.txt",
    "/mnt/data/Result_NW_15x15.txt",
    "Result_NW_100x100.txt",
    "./Result_NW_100x100.txt",
    "/mnt/data/Result_NW_100x100.txt",
    "ta41_orig.txt",
    "./ta41_orig.txt",
    "/mnt/data/ta41_orig.txt",
]

# Heatmap options (keep minimal by default)
SAVE_HEATMAP_GRID          = True
SAVE_INDIVIDUAL_HEATMAPS   = False   # if True: saves (1 + N) PNGs
SAVE_PATRISTIC_HEATMAPS    = False   # if True: saves patristic grid too
SHOW_PLOTS_INLINE          = False

HEATMAP_CMAP = "viridis"
HEATMAP_DPI  = 220
TREE_DPI     = 180

# If your completed folder contains extra files, you can limit to first K
MAX_COMPLETED_TO_PROCESS = None   # e.g. 15, 20, or None for all found

# ============================================================
# ----------------------- IO HELPERS --------------------------
# ============================================================

def load_matrix_with_candidates(cands: List[str]) -> Tuple[np.ndarray, str]:
    for p in cands:
        if os.path.exists(p):
            M = np.loadtxt(p)
            # heuristic: if distances are in ~1000s, rescale to ~units
            if np.nanmax(M) > 500:
                M = M / 1000.0
            return M, p
    raise FileNotFoundError(f"File not found. Tried: {cands}")

def load_completed_csv(path_csv: str) -> Tuple[np.ndarray, List[str]]:
    if not os.path.exists(path_csv):
        raise FileNotFoundError(path_csv)
    df = pd.read_csv(path_csv, index_col=0)
    labels = list(df.columns)
    return df.values.astype(float), labels

# ============================================================
# ----------------------- SANITIZATION ------------------------
# ============================================================

def _finite_fill(v, fallback: float = 1.0) -> float:
    v = np.asarray(v, dtype=float)
    vv = v[np.isfinite(v)]
    if vv.size > 0:
        return float(np.nanmedian(vv))
    return float(fallback)

def sanitize_distance_matrix(D: np.ndarray, name: str = "D", force_nonneg: bool = True) -> np.ndarray:
    M = np.array(D, dtype=float)
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        raise ValueError(f"{name} must be square. Got {M.shape}.")
    n = M.shape[0]

    neg = (M < 0)
    np.fill_diagonal(neg, False)
    M[neg] = np.nan

    off = ~np.eye(n, dtype=bool)
    med = _finite_fill(M[off], fallback=1.0)
    M = np.nan_to_num(M, nan=med, posinf=med, neginf=med)

    try:
        q = np.quantile(M[off], 0.995)
    except Exception:
        q = np.nan
    if np.isfinite(q) and q > 0:
        M[off] = np.minimum(M[off], q)

    M = 0.5 * (M + M.T)
    if force_nonneg:
        M = np.maximum(M, 0.0)
    np.fill_diagonal(M, 0.0)

    if not np.isfinite(M).all():
        raise ValueError(f"{name} has non-finite entries after sanitize.")
    return M

# ============================================================
# -------------------------- NJ TREE --------------------------
# ============================================================

@dataclass
class NJTree:
    newick: str
    patristic: np.ndarray
    splits: Set[frozenset]
    adj: Dict[int, Dict[int, float]]
    root: int

def neighbor_joining_nonneg(D_full: np.ndarray, labels: List[str]) -> NJTree:
    D = sanitize_distance_matrix(D_full, "NJ_input", force_nonneg=True)
    n = len(labels)
    if D.shape != (n, n):
        raise ValueError(f"D shape {D.shape} != (n,n) with n={n}")

    adj: Dict[int, Dict[int, float]] = {}

    def add_edge(u: int, v: int, w: float):
        w = float(max(w, 1e-9))
        adj.setdefault(u, {})
        adj.setdefault(v, {})
        adj[u][v] = w
        adj[v][u] = w

    next_id = n
    idx2node = {i: i for i in range(n)}
    act = list(range(n))
    Dv = D.copy()

    while len(act) > 2:
        m = len(act)
        r = np.sum(Dv, axis=1)
        r = np.nan_to_num(r, nan=_finite_fill(r, 1.0), posinf=_finite_fill(r, 1.0), neginf=_finite_fill(r, 1.0))

        Q = (m - 2) * Dv - r[:, None] - r[None, :]
        Q = np.nan_to_num(Q, nan=np.inf, posinf=np.inf, neginf=np.inf)
        np.fill_diagonal(Q, np.inf)

        if not np.isfinite(Q).any():
            Q = Dv.copy()
            np.fill_diagonal(Q, np.inf)

        a_idx, b_idx = np.unravel_index(np.argmin(Q), Q.shape)
        if a_idx > b_idx:
            a_idx, b_idx = b_idx, a_idx

        i, j = act[a_idx], act[b_idx]
        dij = float(Dv[a_idx, b_idx])

        li = 0.5 * dij + (r[a_idx] - r[b_idx]) / (2 * (m - 2))
        lj = dij - li
        if not np.isfinite(li): li = 0.5 * dij
        if not np.isfinite(lj): lj = 0.5 * dij

        u = next_id
        next_id += 1

        add_edge(u, idx2node[i], li)
        add_edge(u, idx2node[j], lj)

        duk = {}
        for k in range(m):
            if k in (a_idx, b_idx):
                continue
            val = 0.5 * (Dv[a_idx, k] + Dv[b_idx, k] - dij)
            if not np.isfinite(val):
                val = _finite_fill([Dv[a_idx, k], Dv[b_idx, k], dij], fallback=0.0)
            duk[k] = float(val)

        mask = np.ones(m, dtype=bool)
        mask[b_idx] = False
        Dv = Dv[np.ix_(mask, mask)]

        new_a = a_idx
        for t in range(m - 1):
            if t == new_a:
                Dv[new_a, t] = 0.0
                continue
            old_t = t if t < b_idx else t + 1
            Dv[new_a, t] = Dv[t, new_a] = duk.get(old_t, 0.0)

        Dv = np.maximum(0.5 * (Dv + Dv.T), 0.0)
        np.fill_diagonal(Dv, 0.0)

        idx2node[i] = u
        act.pop(b_idx)

    i, j = act[0], act[1]
    add_edge(idx2node[i], idx2node[j], float(Dv[0, 1]))
    root = idx2node[i]

    def to_newick(x: int, parent: int = -1) -> str:
        if x < n:
            return labels[x]
        parts = []
        for v, w in adj.get(x, {}).items():
            if v == parent:
                continue
            parts.append(f"{to_newick(v, x)}:{w:.6f}")
        return "(" + ",".join(parts) + ")"

    newick = to_newick(root) + ";"

    def path_len(a: int, b: int) -> float:
        stack = [(a, -1, 0.0)]
        seen = set()
        while stack:
            x, p, acc = stack.pop()
            if x == b:
                return acc
            seen.add(x)
            for y, w in adj.get(x, {}).items():
                if y == p or y in seen:
                    continue
                stack.append((y, x, acc + w))
        return np.nan

    P = np.zeros((n, n), dtype=float)
    for a in range(n):
        for b in range(a + 1, n):
            d = path_len(a, b)
            P[a, b] = P[b, a] = d

    def compute_splits() -> Set[frozenset]:
        splits = set()
        seen_edges = set()
        for u in adj:
            for v in adj[u]:
                if (v, u) in seen_edges:
                    continue
                seen_edges.add((u, v))

                stack = [u]
                blocked = v
                visited = set([blocked])
                leafset = set()

                while stack:
                    x = stack.pop()
                    if x in visited:
                        continue
                    visited.add(x)
                    if x < n:
                        leafset.add(labels[x])
                    for y in adj.get(x, {}):
                        if y not in visited:
                            stack.append(y)

                if 1 < len(leafset) < n - 1:
                    splits.add(frozenset(sorted(leafset)))
        return splits

    return NJTree(newick=newick, patristic=P, splits=compute_splits(), adj=adj, root=root)

# ============================================================
# ---------------------- TREE DRAWING -------------------------
# ============================================================

def draw_nj_tree(nj: NJTree, labels: List[str], title: str, out_path: str, dpi: int = 160):
    n = len(labels)
    x_pos = {i: i for i in range(n)}
    y_pos: Dict[int, float] = {}

    def dfs(u: int, p: int = -1, y: float = 0.0):
        y_pos[u] = -y
        for v, w in nj.adj.get(u, {}).items():
            if v == p:
                continue
            dfs(v, u, y + w)

    dfs(nj.root, -1, 0.0)

    def leaf_span(node: int) -> float:
        if node < n:
            return float(x_pos[node])
        seen = set()
        stack = [node]
        leaves = []
        while stack:
            x = stack.pop()
            if x in seen:
                continue
            seen.add(x)
            if x < n:
                leaves.append(x_pos[x])
            else:
                for y in nj.adj.get(x, {}):
                    if y not in seen:
                        stack.append(y)
        return float(np.mean(leaves)) if leaves else 0.0

    plt.figure(figsize=(9, 4.2))
    drawn = set()
    for u in nj.adj:
        for v, w in nj.adj[u].items():
            if (v, u) in drawn:
                continue
            drawn.add((u, v))
            plt.plot([leaf_span(u), leaf_span(v)], [y_pos.get(u, 0.0), y_pos.get(v, 0.0)], linewidth=1.4)

    for i in range(n):
        plt.text(leaf_span(i), y_pos.get(i, 0.0), labels[i], ha="center", va="bottom", fontsize=8)

    plt.title(title)
    plt.xlabel("leaves (order = label order)")
    plt.ylabel("− branch length")
    plt.tight_layout()

    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, dpi=dpi, bbox_inches="tight")
    if SHOW_PLOTS_INLINE:
        plt.show()
    plt.close()

# ============================================================
# ------------------------ METRICS ----------------------------
# ============================================================

def rf_distance(t1: NJTree, t2: NJTree) -> int:
    return len(t1.splits - t2.splits) + len(t2.splits - t1.splits)

def rf_normalized(t1: NJTree, t2: NJTree, n_leaves: int) -> float:
    if n_leaves < 4:
        return 0.0
    rf = rf_distance(t1, t2)
    rf_max = 2 * (n_leaves - 3)
    return float(rf) / float(rf_max) if rf_max > 0 else 0.0

def _upper_vec(M: np.ndarray) -> np.ndarray:
    M = np.asarray(M, dtype=float)
    iu = np.triu_indices(M.shape[0], k=1)
    v = M[iu]
    v = v[np.isfinite(v)]
    return v

def _pearson(a: np.ndarray, b: np.ndarray) -> float:
    a = np.asarray(a, float)
    b = np.asarray(b, float)
    if a.size == 0 or b.size == 0:
        return float("nan")
    if np.std(a) == 0 or np.std(b) == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])

def _spearman(a: np.ndarray, b: np.ndarray) -> float:
    a = np.asarray(a, float)
    b = np.asarray(b, float)
    if a.size == 0 or b.size == 0:
        return float("nan")
    ra = pd.Series(a).rank(method="average").to_numpy()
    rb = pd.Series(b).rank(method="average").to_numpy()
    return _pearson(ra, rb)

def patristic_metrics(t: NJTree, tref: NJTree) -> Dict[str, float]:
    v = _upper_vec(t.patristic)
    vr = _upper_vec(tref.patristic)
    m = min(v.size, vr.size)
    if m == 0:
        return {"pat_MAE": np.nan, "pat_RMSE": np.nan, "pat_Pearson": np.nan, "pat_Spearman": np.nan}
    v = v[:m]; vr = vr[:m]
    diff = v - vr
    return {
        "pat_MAE": float(np.mean(np.abs(diff))),
        "pat_RMSE": float(np.sqrt(np.mean(diff * diff))),
        "pat_Pearson": _pearson(v, vr),
        "pat_Spearman": _spearman(v, vr),
    }

def matrix_metrics(D: np.ndarray, Dref: np.ndarray) -> Dict[str, float]:
    v = _upper_vec(D)
    vr = _upper_vec(Dref)
    m = min(v.size, vr.size)
    if m == 0:
        return {"mat_MAE": np.nan, "mat_RMSE": np.nan, "mat_Pearson": np.nan, "mat_Spearman": np.nan}
    v = v[:m]; vr = vr[:m]
    diff = v - vr
    return {
        "mat_MAE": float(np.mean(np.abs(diff))),
        "mat_RMSE": float(np.sqrt(np.mean(diff * diff))),
        "mat_Pearson": _pearson(v, vr),
        "mat_Spearman": _spearman(v, vr),
    }

def fmt_pm(mean_val: float, std_val: float, decimals: int = 6) -> str:
    return f"{mean_val:.{decimals}f} ± {std_val:.{decimals}f}"

def mean_std_by_group(df: pd.DataFrame, group_col: str, metrics: List[str]) -> pd.DataFrame:
    rows = []
    for g, sub in df.groupby(group_col):
        row = {group_col: g}
        for m in metrics:
            vals = sub[m].astype(float).to_numpy()
            mu = float(np.mean(vals))
            sd = float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0
            row[m] = fmt_pm(mu, sd, decimals=6)
            row[m + "_mean"] = mu
            row[m + "_std"]  = sd
        rows.append(row)
    out = pd.DataFrame(rows).sort_values(group_col)
    return out

# ============================================================
# ------------------------- HEATMAPS --------------------------
# ============================================================

def plot_heatmap(M: np.ndarray, labels: List[str], title: str, out_png: str):
    plt.figure(figsize=(8.2, 6.6))
    im = plt.imshow(M, interpolation="nearest", cmap=HEATMAP_CMAP)
    plt.title(title, fontsize=12, pad=12)
    plt.xticks(range(len(labels)), labels, rotation=90, fontsize=7)
    plt.yticks(range(len(labels)), labels, fontsize=7)
    plt.colorbar(im, shrink=0.8)
    plt.tight_layout()
    plt.savefig(out_png, dpi=HEATMAP_DPI, bbox_inches="tight")
    if SHOW_PLOTS_INLINE:
        plt.show()
    plt.close()

def plot_heatmap_grid(mats: List[np.ndarray], titles: List[str], labels: List[str], out_png: str):
    assert len(mats) == len(titles)
    k = len(mats)
    ncol = int(math.ceil(math.sqrt(k)))
    nrow = int(math.ceil(k / ncol))
    fig, axes = plt.subplots(nrow, ncol, figsize=(5.3*ncol, 4.6*nrow))
    axes = np.atleast_1d(axes).ravel()

    im = None
    for idx in range(nrow * ncol):
        ax = axes[idx]
        if idx < k:
            im = ax.imshow(mats[idx], interpolation="nearest", cmap=HEATMAP_CMAP)
            ax.set_title(titles[idx], fontsize=9)
            ax.set_xticks(range(len(labels)))
            ax.set_xticklabels(labels, rotation=90, fontsize=5)
            ax.set_yticks(range(len(labels)))
            ax.set_yticklabels(labels, fontsize=5)
        else:
            ax.axis("off")

    fig.tight_layout()
    # colorbar on the right
    cbar_ax = fig.add_axes([0.92, 0.12, 0.015, 0.76])
    fig.colorbar(im, cax=cbar_ax, label="Distance")
    plt.savefig(out_png, dpi=HEATMAP_DPI, bbox_inches="tight")
    if SHOW_PLOTS_INLINE:
        plt.show()
    plt.close()

# ============================================================
# -------------------------- MAIN -----------------------------
# ============================================================

# 1) Discover ALL completed matrices for Hyb-Adam-UM (no TreePrior)
if not os.path.exists(COMPLETED_DIR):
    raise FileNotFoundError(f"Completed directory not found: {COMPLETED_DIR}")

# Matches filenames created by the Hyb-Adam-UM-only code above:
#   HybAdamUM_completed_p{pct}_rep{rep}.csv
pat = re.compile(r"^HybAdamUM_completed_p(\d+)_rep(\d+)\.csv$")
found = []
for fn in sorted(os.listdir(COMPLETED_DIR)):
    m = pat.match(fn)
    if m:
        pct = int(m.group(1))
        rep = int(m.group(2))
        found.append((pct, rep, fn))

if len(found) == 0:
    raise RuntimeError(
        f"No files matched pattern in {COMPLETED_DIR}: HybAdamUM_completed_p##_rep#.csv"
    )

found = sorted(found, key=lambda x: (x[0], x[1]))
if MAX_COMPLETED_TO_PROCESS is not None:
    found = found[:int(MAX_COMPLETED_TO_PROCESS)]

print(f"Found {len(found)} completed matrices in: {os.path.abspath(COMPLETED_DIR)}")
for pct, rep, fn in found[:10]:
    print(f"  p{pct} rep{rep}: {fn}")
if len(found) > 10:
    print("  ...")

os.makedirs(OUT_DIR, exist_ok=True)
print(f"\nOutput directory: {os.path.abspath(OUT_DIR)}")

# 2) Load ORIGINAL matrix
print("\nLoading original matrix...")
D0, D0_path = load_matrix_with_candidates(ORIG_CANDIDATES)
D0 = sanitize_distance_matrix(D0, "D_orig", force_nonneg=True)
print(f"  Original matrix loaded: {D0_path}")
print(f"  Original matrix shape:  {D0.shape}")

# 3) Load first completed matrix to get labels (authoritative)
first_path = os.path.join(COMPLETED_DIR, found[0][2])
D_first, labels = load_completed_csv(first_path)
if D_first.shape != D0.shape:
    raise ValueError(
        f"Original shape {D0.shape} != completed shape {D_first.shape}. "
        f"Fix ORIG_CANDIDATES or use matching original."
    )

# 4) Build Original NJ tree once
print("\nBuilding NJ tree for original matrix...")
tree_orig = neighbor_joining_nonneg(D0, labels)
draw_nj_tree(tree_orig, labels, "NJ (Original full matrix)", os.path.join(OUT_DIR, "tree_Original.png"), dpi=TREE_DPI)
with open(os.path.join(OUT_DIR, "tree_Original.newick"), "w") as f:
    f.write(tree_orig.newick + "\n")
print("  Saved: tree_Original.png, tree_Original.newick")

# 5) Process all completed matrices
tree_rows = []
mat_rows  = []

heat_mats   = [D0]
heat_titles = ["Original"]

if SAVE_PATRISTIC_HEATMAPS:
    pat_mats   = [tree_orig.patristic]
    pat_titles = ["Patristic (Original)"]

for pct, rep, fn in found:
    path = os.path.join(COMPLETED_DIR, fn)
    print(f"\nProcessing: {fn} (p{pct}, rep{rep})")

    D_comp, labels2 = load_completed_csv(path)
    labels_use = labels2 if labels2 != labels else labels

    if D_comp.shape != D0.shape:
        print(f"  Skip: shape mismatch {D_comp.shape} vs original {D0.shape}")
        continue

    D_comp = sanitize_distance_matrix(D_comp, f"D_comp_p{pct}_rep{rep}", force_nonneg=True)

    # build NJ tree
    tree_comp = neighbor_joining_nonneg(D_comp, labels_use)

    # draw + save tree
    tree_png = os.path.join(OUT_DIR, f"tree_p{pct}_rep{rep}.png")
    draw_nj_tree(tree_comp, labels_use, f"NJ (Hyb-Adam-UM — {pct}% missing, rep{rep})", tree_png, dpi=TREE_DPI)

    # save newick
    newick_path = os.path.join(OUT_DIR, f"tree_p{pct}_rep{rep}.newick")
    with open(newick_path, "w") as f:
        f.write(tree_comp.newick + "\n")

    # benchmarks (tree-level)
    tr = {
        "pct_missing": pct,
        "replicate": rep,
        "file": fn,
        "RF": int(rf_distance(tree_comp, tree_orig)),
        "RF_norm": float(rf_normalized(tree_comp, tree_orig, len(labels_use))),
        "n_splits": int(len(tree_comp.splits)),
    }
    tr.update(patristic_metrics(tree_comp, tree_orig))
    tree_rows.append(tr)

    # benchmarks (matrix-level)
    mr = {"pct_missing": pct, "replicate": rep, "file": fn}
    mr.update(matrix_metrics(D_comp, D0))
    mat_rows.append(mr)

    print(f"  Done: RF_norm = {tr['RF_norm']:.6f} | saved -> {os.path.basename(tree_png)}")

    # collect for heatmap grid
    heat_mats.append(D_comp)
    heat_titles.append(f"p{pct} r{rep} (RF={tr['RF_norm']:.3f})")

    if SAVE_PATRISTIC_HEATMAPS:
        pat_mats.append(tree_comp.patristic)
        pat_titles.append(f"Patristic p{pct} r{rep}")

    # optional individual heatmaps
    if SAVE_INDIVIDUAL_HEATMAPS:
        plot_heatmap(
            D_comp, labels_use,
            f"Hyb-Adam-UM completed (p{pct}, rep{rep})",
            os.path.join(OUT_DIR, f"heatmap_completed_p{pct}_rep{rep}.png")
        )

# 6) Save detailed tables
df_tree = pd.DataFrame(tree_rows).sort_values(["pct_missing", "replicate"])
df_mat  = pd.DataFrame(mat_rows).sort_values(["pct_missing", "replicate"])

tree_csv = os.path.join(OUT_DIR, "benchmark_tree_all.csv")
mat_csv  = os.path.join(OUT_DIR, "benchmark_matrix_all.csv")
df_tree.to_csv(tree_csv, index=False)
df_mat.to_csv(mat_csv, index=False)

print("\n" + "="*70)
print("Detailed Tree Benchmark Results (ALL)")
print("="*70)
with pd.option_context("display.max_rows", 300, "display.max_columns", 300, "display.width", 240):
    print(df_tree.to_string(index=False))

print("\n" + "="*70)
print("Detailed Matrix Benchmark Results (ALL)")
print("="*70)
with pd.option_context("display.max_rows", 300, "display.max_columns", 300, "display.width", 240):
    print(df_mat.to_string(index=False))

# 7) Mean ± SD by missingness
tree_metrics_list = ["RF", "RF_norm", "pat_MAE", "pat_RMSE", "pat_Pearson", "pat_Spearman"]
mat_metrics_list  = ["mat_MAE", "mat_RMSE", "mat_Pearson", "mat_Spearman"]

df_tree_ms = mean_std_by_group(df_tree, "pct_missing", tree_metrics_list)
df_mat_ms  = mean_std_by_group(df_mat,  "pct_missing", mat_metrics_list)

tree_ms_csv = os.path.join(OUT_DIR, "benchmark_tree_by_missingness_meanstd.csv")
mat_ms_csv  = os.path.join(OUT_DIR, "benchmark_matrix_by_missingness_meanstd.csv")
df_tree_ms.to_csv(tree_ms_csv, index=False)
df_mat_ms.to_csv(mat_ms_csv, index=False)

print("\n" + "="*70)
print("Tree metrics: mean ± SD by missingness")
print("="*70)
print(df_tree_ms[["pct_missing"] + tree_metrics_list].to_string(index=False))

print("\n" + "="*70)
print("Matrix metrics: mean ± SD by missingness")
print("="*70)
print(df_mat_ms[["pct_missing"] + mat_metrics_list].to_string(index=False))

# 8) Heatmaps
if SAVE_HEATMAP_GRID:
    out_grid = os.path.join(OUT_DIR, "heatmap_grid_original_plus_all.png")
    plot_heatmap_grid(heat_mats, heat_titles, labels, out_grid)
    print(f"\nSaved heatmap grid: {os.path.abspath(out_grid)}")

if SAVE_PATRISTIC_HEATMAPS:
    out_grid_pat = os.path.join(OUT_DIR, "heatmap_grid_patristic_original_plus_all.png")
    plot_heatmap_grid(pat_mats, pat_titles, labels, out_grid_pat)
    print(f"Saved patristic heatmap grid: {os.path.abspath(out_grid_pat)}")

print("\n" + "="*70)
print("ALL PROCESSING COMPLETE!")
print("="*70)
print(f"Output directory: {os.path.abspath(OUT_DIR)}")
print("\nSaved:")
print(f"  - Trees: 1 Original + {len(df_tree)} completed PNGs + Newicks")
print(f"  - Detailed CSVs: {os.path.basename(tree_csv)}, {os.path.basename(mat_csv)}")
print(f"  - Mean±SD CSVs:  {os.path.basename(tree_ms_csv)}, {os.path.basename(mat_ms_csv)}")
if SAVE_HEATMAP_GRID:
    print(f"  - Heatmap grid: heatmap_grid_original_plus_all.png")
if SAVE_INDIVIDUAL_HEATMAPS:
    print(f"  - Individual heatmaps: {len(df_mat)} PNGs (plus original if you add it)")


Found 25 completed matrices in: /home/user/bioinformatics/try match 2/Hyb-Adam-UM/hyb_adam_um_completed_matrices
  p30 rep1: HybAdamUM_completed_p30_rep1.csv
  p30 rep2: HybAdamUM_completed_p30_rep2.csv
  p30 rep3: HybAdamUM_completed_p30_rep3.csv
  p30 rep4: HybAdamUM_completed_p30_rep4.csv
  p30 rep5: HybAdamUM_completed_p30_rep5.csv
  p50 rep1: HybAdamUM_completed_p50_rep1.csv
  p50 rep2: HybAdamUM_completed_p50_rep2.csv
  p50 rep3: HybAdamUM_completed_p50_rep3.csv
  p50 rep4: HybAdamUM_completed_p50_rep4.csv
  p50 rep5: HybAdamUM_completed_p50_rep5.csv
  ...

Output directory: /home/user/bioinformatics/try match 2/Hyb-Adam-UM/trees_hyb_adam_um_all

Loading original matrix...
  Original matrix loaded: Result_NW_15x15.txt
  Original matrix shape:  (15, 15)

Building NJ tree for original matrix...
  Saved: tree_Original.png, tree_Original.newick

Processing: HybAdamUM_completed_p30_rep1.csv (p30, rep1)
  Done: RF_norm = 0.583333 | saved -> tree_p30_rep1.png

Processing: HybAdamUM_comp