In [19]:
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional


def to_float(x: Any) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None


def safe_mean(values: List[Optional[float]]) -> Optional[float]:
    vals = [v for v in values if v is not None]
    return (sum(vals) / len(vals)) if vals else None


def load_metrics(mfile: Path) -> Dict[str, Optional[float]]:
    try:
        data = json.loads(mfile.read_text())
    except Exception as e:
        raise RuntimeError(f"Failed to read {mfile}: {e}") from e

    return {
        "psi_mse": to_float(data.get("psi_mse")),
        "phi_mse": to_float(data.get("phi_mse")),
        "psi_rmse": to_float(data.get("psi_rmse")),
        "phi_rmse": to_float(data.get("phi_rmse")),
        # optional plain keys if present
        "loss": to_float(data.get("loss")),
        "mse": to_float(data.get("mse")),
        "rmse": to_float(data.get("rmse")),
    }


def find_metric_dirs(base_dir: Path) -> List[Path]:
    # Prefer explicit patterns, but also accept any immediate subdir that has metrics.json
    candidates = set(base_dir.glob("example_*")) | set(base_dir.glob("sample_*"))
    # Add any subdir containing metrics.json
    for child in base_dir.iterdir():
        if child.is_dir() and (child / "metrics.json").is_file():
            candidates.add(child)
    return sorted([p for p in candidates if (p / "metrics.json").is_file()])


def compute_summary(base_dir: Path) -> Dict[str, Any]:
    subdirs = find_metric_dirs(base_dir)
    records: List[Dict[str, Optional[float]]] = []
    for d in subdirs:
        mfile = d / "metrics.json"
        try:
            rec = load_metrics(mfile)
        except Exception as e:
            print(f"Warning: {e}")
            continue
        rec["example"] = d.name  # type: ignore[index]
        records.append(rec)

    psi_mse_mean = safe_mean([r.get("psi_mse") for r in records])
    phi_mse_mean = safe_mean([r.get("phi_mse") for r in records])
    psi_rmse_mean = safe_mean([r.get("psi_rmse") for r in records])
    phi_rmse_mean = safe_mean([r.get("phi_rmse") for r in records])

    loss_mean = safe_mean([r.get("loss") for r in records])
    print(records[0].keys())
    rmspe_mean = safe_mean([r.get("rmse") for r in records])

    # overall means: average across psi and phi values altogether (ignore None)
    mse_pool = [r.get("psi_mse") for r in records] + [r.get("phi_mse") for r in records]
    rmse_pool = [r.get("psi_rmse") for r in records] + [
        r.get("phi_rmse") for r in records
    ]
    overall_mse_mean = safe_mean(mse_pool)
    overall_rmse_mean = safe_mean(rmse_pool)

    plain_mse_mean = safe_mean([r.get("mse") for r in records])
    plain_rmse_mean = safe_mean([r.get("rmse") for r in records])

    summary: Dict[str, Any] = {
        "base_dir": str(base_dir),
        "num_examples": len(records),
        "psi_mse_mean": psi_mse_mean,
        "phi_mse_mean": phi_mse_mean,
        "overall_mse_mean": overall_mse_mean,
        "psi_rmse_mean": psi_rmse_mean,
        "phi_rmse_mean": phi_rmse_mean,
        "overall_rmse_mean": overall_rmse_mean,
        "plain_mse_mean": plain_mse_mean,
        "plain_rmse_mean": plain_rmse_mean,
        "loss_mean": loss_mean,
        "rmspe_mean": rmspe_mean,
    }
    return summary

In [23]:
json_path = "/workspaces/deeprte/output/train/L/results/L-test"

In [24]:
s = compute_summary(Path(json_path))

dict_keys(['psi_mse', 'phi_mse', 'psi_rmse', 'phi_rmse', 'loss', 'mse', 'rmse', 'example'])


In [25]:
s

{'base_dir': '/workspaces/deeprte/output/train/L/results/L-test',
 'num_examples': 100,
 'psi_mse_mean': None,
 'phi_mse_mean': None,
 'overall_mse_mean': None,
 'psi_rmse_mean': None,
 'phi_rmse_mean': None,
 'overall_rmse_mean': None,
 'plain_mse_mean': None,
 'plain_rmse_mean': 0.25350243479013446,
 'loss_mean': 0.012238887450657786,
 'rmspe_mean': 0.25350243479013446}