# In [None]:
# viz_three_models.py
import os
import re
import h5py
import numpy as np
import matplotlib.pyplot as plt
import csv

# -------------------------------
# User configuration
# -------------------------------
INDEX = 100                                   # Which test sample to visualize (also the HDF5 file stem)
RUN_STEM = "20251014_024804"                  # Must match your training/prediction run folder name
PROJECT_ROOT = "/home/mingyeong/2510_GAL2DM_ASIM_ViT"

# Ground-truth sample; must contain the 'input' and 'output_rho' datasets.
TEST_FILE = f"/scratch/adupuy/cosmicweb_asim/ASIM_TSC/samples/test/{INDEX}.hdf5"
# Per-experiment directories: CKPT_BASE/<tag> is searched for loss CSVs,
# PRED_BASE/<tag>/<INDEX>.hdf5 holds the model's prediction.
CKPT_BASE = os.path.join(PROJECT_ROOT, "results/vit", RUN_STEM)
PRED_BASE = os.path.join(PROJECT_ROOT, "results/vit_predictions", RUN_STEM)

TAGS = ["patch8", "patch4", "patch2"]         # Subfolders for the three experiments
SLICE_AXIS = 2                                # 0: D, 1: H, 2: W
SLICE_INDEX_MODE = "center"                   # "center" or int index

# Figures are written here; the directory is created as soon as this runs.
OUT_DIR = os.path.join(PROJECT_ROOT, "figs", "viz_models", RUN_STEM)
os.makedirs(OUT_DIR, exist_ok=True)

# -------------------------------
# Utilities
# -------------------------------
def _safesqueeze(arr):
    a = np.squeeze(arr)
    if a.ndim < 3:
        raise ValueError(f"Unexpected ndim after squeeze: {a.ndim}, shape={a.shape}")
    return a

def _ensure_input_shape(x):
    """
    Accept:
      (2,D,H,W) or (1,2,D,H,W) or (N,2,D,H,W) with N=1
    Return:
      (2,D,H,W)
    """
    arr = x
    if arr.ndim == 4 and arr.shape[0] == 2:
        return arr
    if arr.ndim == 5 and arr.shape[0] == 1 and arr.shape[1] == 2:
        return arr[0]
    if arr.ndim == 5 and arr.shape[1] == 2:
        if arr.shape[0] != 1:
            raise ValueError(f"Got batch N={arr.shape[0]} > 1; unexpected for test visualisation.")
        return arr[0]
    if arr.ndim == 6 and arr.shape[1] == 1 and arr.shape[2] == 2:
        return np.squeeze(arr, axis=0)[0]
    raise ValueError(f"Unsupported input shape: {arr.shape}")

def _get_slice(vol3d, axis=2, idx="center"):
    if idx == "center":
        idx = vol3d.shape[axis] // 2
    if axis == 0:
        return vol3d[idx, :, :]
    elif axis == 1:
        return vol3d[:, idx, :]
    elif axis == 2:
        return vol3d[:, :, idx]
    else:
        raise ValueError("axis must be 0,1,2")

def _log1p_safe(a):
    # Clip negatives to 0 for density-like fields, then log10(1+x)
    return np.log10(1.0 + np.clip(a, 0, None))

def _load_truth_inputs(test_file):
    """
    Read the input channels and the true density volume of one test sample.

    The 'input' dataset is normalised to (2, D, H, W) by ``_ensure_input_shape``;
    channel 0 is returned as ``ngal`` and channel 1 as ``vpec``. The
    'output_rho' dataset is returned as ``rho_true``. All three pass through
    ``_safesqueeze``.

    Raises
    ------
    FileNotFoundError
        If ``test_file`` does not exist.
    KeyError
        If 'input' or 'output_rho' is missing from the file.
    """
    if not os.path.exists(test_file):
        raise FileNotFoundError(f"TEST_FILE not found: {test_file}")
    with h5py.File(test_file, "r") as h5:
        required = ("input", "output_rho")
        if any(name not in h5 for name in required):
            raise KeyError("Required datasets 'input' and/or 'output_rho' missing in test file.")
        channels = _ensure_input_shape(h5["input"][:])     # (2,D,H,W)
        ngal = _safesqueeze(channels[0])
        vpec = _safesqueeze(channels[1])
        rho_true = _safesqueeze(h5["output_rho"][:])
    return ngal, vpec, rho_true

def _load_prediction(pred_file):
    """
    Load one model's 'prediction' dataset as a squeezed volume.

    Returns
    -------
    numpy.ndarray
        The stored 'prediction' array after ``_safesqueeze``.

    Raises
    ------
    FileNotFoundError
        If ``pred_file`` does not exist.
    KeyError
        If the file has no 'prediction' dataset.
    """
    if not os.path.exists(pred_file):
        raise FileNotFoundError(f"Prediction file not found: {pred_file}")
    with h5py.File(pred_file, "r") as h5:
        if "prediction" not in h5:
            raise KeyError(f"'prediction' dataset not found in {pred_file}")
        raw = h5["prediction"][:]
    return _safesqueeze(raw)   # (D,H,W)

def _candidate_loss_csvs(ckpt_dir):
    # Common names to try, ordered by preference
    names = [
        "train_log.csv", "history.csv", "loss.csv", "logs.csv",
        "training_history.csv", "metrics.csv"
    ]
    # Also any file matching loss*.csv or *history*.csv
    existing = []
    for n in names:
        p = os.path.join(ckpt_dir, n)
        if os.path.isfile(p):
            existing.append(p)
    for fname in sorted(os.listdir(ckpt_dir)):
        if re.search(r"(loss|history).*\.csv$", fname, re.IGNORECASE):
            p = os.path.join(ckpt_dir, fname)
            if p not in existing:
                existing.append(p)
    return existing

def _load_loss_csv(loss_csv_path):
    """
    Return dict: {"epoch": [...], "train": [...], "val": [...]}
    Column names accepted (case-insensitive):
      epoch, train_loss/train/TrainLoss, val_loss/val/ValLoss/valid/validation
    """
    keys = {"epoch": None, "train": None, "val": None}
    epochs, train, val = [], [], []
    with open(loss_csv_path, "r", newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        # map indices
        col_map = {c.lower(): i for i, c in enumerate(header)}
        # helper to find a column among aliases
        def find_idx(*aliases):
            for a in aliases:
                if a in col_map:
                    return col_map[a]
            return None
        ie = find_idx("epoch", "epochs")
        it = find_idx("train_loss", "train", "trainloss", "loss_train", "loss")
        iv = find_idx("val_loss", "val", "valloss", "valid", "validation", "loss_val")
        for row in reader:
            try:
                epochs.append(int(float(row[ie])) if ie is not None else len(epochs)+1)
                train.append(float(row[it]) if it is not None else np.nan)
                val.append(float(row[iv]) if iv is not None else np.nan)
            except Exception:
                # skip malformed lines
                continue
    return {"epoch": epochs, "train": train, "val": val}

def _plot_one_model(tag, ngal, rho_true, pred, loss_dict, out_png):
    """
    Save a 3-panel figure for one model: loss curve, true slice, predicted slice.

    Parameters
    ----------
    tag : str
        Experiment name used in panel titles.
    ngal : numpy.ndarray
        Galaxy input volume. Not drawn at present; kept so existing callers
        keep working.
    rho_true, pred : numpy.ndarray
        True and predicted density volumes. Both are cut with the module-level
        ``SLICE_AXIS`` / ``SLICE_INDEX_MODE`` so every model shows the same plane.
    loss_dict : dict or None
        History as returned by ``_load_loss_csv``; ``None`` or an empty history
        draws a "No loss CSV found" placeholder instead.
    out_png : str
        Destination image path.
    """
    rho_t_img = _log1p_safe(_get_slice(rho_true, axis=SLICE_AXIS, idx=SLICE_INDEX_MODE))
    rho_p_img = _log1p_safe(_get_slice(pred, axis=SLICE_AXIS, idx=SLICE_INDEX_MODE))

    # One colour scale for truth and prediction: independent autoscaling would
    # make a prediction with the wrong amplitude look deceptively similar.
    finite = np.concatenate([rho_t_img[np.isfinite(rho_t_img)],
                             rho_p_img[np.isfinite(rho_p_img)]])
    vmin, vmax = (float(finite.min()), float(finite.max())) if finite.size else (None, None)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5.5))
    try:
        # 1) Loss curve
        ax0 = axes[0]
        if loss_dict is not None and len(loss_dict.get("epoch", [])) > 0:
            ax0.plot(loss_dict["epoch"], loss_dict["train"], label="Train", lw=1.8)
            if not np.all(np.isnan(loss_dict["val"])):
                ax0.plot(loss_dict["epoch"], loss_dict["val"], label="Val", lw=1.8)
            ax0.set_xlabel("Epoch")
            ax0.set_ylabel("Loss")
            ax0.set_title(f"[{tag}] Training Curve")
            ax0.grid(True, ls="--", alpha=0.4)
            ax0.legend()
        else:
            ax0.text(0.5, 0.5, "No loss CSV found", ha="center", va="center", fontsize=12)
            ax0.set_axis_off()

        # 2) True slice
        im1 = axes[1].imshow(rho_t_img, origin="lower", cmap="inferno", vmin=vmin, vmax=vmax)
        axes[1].set_title(f"[{tag}] True ρ (log10(1+ρ))")
        cb1 = fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
        cb1.ax.set_ylabel("log10(1+ρ)")

        # 3) Pred slice
        im2 = axes[2].imshow(rho_p_img, origin="lower", cmap="inferno", vmin=vmin, vmax=vmax)
        axes[2].set_title(f"[{tag}] Predicted ρ̂ (log10(1+ρ̂))")
        cb2 = fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
        cb2.ax.set_ylabel("log10(1+ρ̂)")

        fig.tight_layout()
        fig.savefig(out_png, dpi=160)
    finally:
        # Release the figure even if drawing/saving fails, so a loop over
        # models does not accumulate open figures.
        plt.close(fig)
    print(f"[SAVE] {out_png}")

# -------------------------------
# Load once: truth & inputs
# -------------------------------
ngal, vpec, rho_true = _load_truth_inputs(TEST_FILE)


def _first_usable_loss(ckpt_dir, tag):
    """
    Return the first parseable, non-empty loss history in ``ckpt_dir``, or None.

    Candidates come from ``_candidate_loss_csvs`` in preference order. A
    candidate that fails to parse is reported and skipped, not silently ignored.
    """
    for cand in _candidate_loss_csvs(ckpt_dir):
        try:
            ld = _load_loss_csv(cand)
        except Exception as e:  # best effort: a broken log must not stop plotting
            print(f"[WARN] {tag}: cannot parse loss CSV {cand}: {e}")
            continue
        if len(ld.get("epoch", [])) > 0:
            return ld
    return None


# -------------------------------
# Iterate over three models
# -------------------------------
for tag in TAGS:
    ckpt_dir = os.path.join(CKPT_BASE, tag)
    pred_file = os.path.join(PRED_BASE, tag, f"{INDEX}.hdf5")

    # Load prediction; a missing or unreadable model is skipped, not fatal.
    try:
        pred = _load_prediction(pred_file)
    except Exception as e:
        print(f"[WARN] Skip {tag}: cannot load prediction: {e}")
        continue

    if pred.shape != rho_true.shape:
        # Still plotted, but the same slice index may not be the same plane.
        print(f"[WARN] {tag}: prediction shape {pred.shape} != truth shape {rho_true.shape}")

    # Load loss curve if available
    loss_dict = None
    if os.path.isdir(ckpt_dir):
        loss_dict = _first_usable_loss(ckpt_dir, tag)
    else:
        print(f"[INFO] No ckpt dir for {tag}: {ckpt_dir}")

    # Output figure
    out_png = os.path.join(OUT_DIR, f"viz_{tag}_idx{INDEX}.png")
    _plot_one_model(tag, ngal, rho_true, pred, loss_dict, out_png)

print("[DONE] All available models processed.")
