In [3]:
import os
import glob
from tensorboard.backend.event_processing import event_accumulator
from collections import defaultdict
import pandas as pd
import numpy as np

In [4]:
def load_bias_metrics_from_tensorboard(root_dir):
    """
    Load scalar metrics from the TensorBoard event files found under ``root_dir``.

    The tree is walked recursively. In every directory that holds at least one
    ``events.out.tfevents*`` file, only the most recently modified event file is
    read. Its scalars are pivoted into a table indexed by step (one column per
    tag), the training-loss and wet/dry-day tags are removed, and any step that
    is missing a value for one of the remaining tags is dropped.

    Args:
        root_dir: directory to scan.

    Returns:
        grouped_dfs: dict of {run_id: pd.DataFrame}, pivoted by step with cleaned tags.
        ``run_id`` is the name of the directory containing the event file, so two
        run directories with the same name in different parents share a key and
        the one visited last wins. Runs whose event file could not be read are
        reported on stdout and omitted (or contain only what was read before the failure).
    """
    drop_tags = [
        'Loss/train',
        'median_adjusted/Wet Days >1mm',
        'median_adjusted/Very Wet Days >10mm',
        'median_adjusted/Very Very Wet Days >20mm',
        'median_adjusted/Dry Days',
    ]

    # Step 1: Find latest event file for each run
    latest_event_files = {}
    for root, _dirs, files in os.walk(root_dir):
        event_paths = [os.path.join(root, f) for f in files
                       if f.startswith("events.out.tfevents")]
        if not event_paths:
            continue
        # normpath strips a trailing separator, which would otherwise make the
        # basename (and therefore the run id) an empty string.
        run_id = os.path.basename(os.path.normpath(root))
        latest_event_files[run_id] = max(event_paths, key=os.path.getmtime)

    # Step 2: Load scalars from each file
    run_data = defaultdict(list)
    for run_id, event_path in latest_event_files.items():
        try:
            ea = event_accumulator.EventAccumulator(event_path)
            ea.Reload()
            for tag in ea.Tags().get('scalars', []):
                for s in ea.Scalars(tag):
                    run_data[run_id].append((tag, s.step, s.value))
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole scan.
            print(f"⚠️ Failed to load {event_path}: {e}")

    # Step 3: Convert to cleaned pivoted DataFrames
    grouped_dfs = {}
    for run_id, records in run_data.items():
        df = pd.DataFrame(records, columns=["tag", "step", "value"])
        # The same (tag, step) pair can be logged more than once; pivot() raises
        # on such duplicates, so keep only the most recently logged value.
        df = df.drop_duplicates(subset=["tag", "step"], keep="last")
        pivoted = df.pivot(index='step', columns='tag', values='value').sort_index()
        pivoted = pivoted.drop(columns=drop_tags, errors='ignore')
        grouped_dfs[run_id] = pivoted.dropna()

    return grouped_dfs




In [5]:
# Relative path of the directory tree that is scanned for TensorBoard event files.
root_dir = "runs_revised/conus_gridmet_cnn/access_cm2-gridmet"
# One step-indexed metrics table per run, keyed by the run's directory name.
grouped_dfs = load_bias_metrics_from_tensorboard(root_dir)

In [None]:
# import os

# base_dir = "/pscratch/sd/k/kas7897/diffDownscale/jobs_revised_pca/access_cm2-gridmet"
# second_level_dirs = []

# for root, dirs, files in os.walk(base_dir):
#     # Only consider first-level subdirectories
#     if os.path.abspath(root) == os.path.abspath(base_dir):
#         for d in dirs:
#             subdir = os.path.join(root, d)
#             # List subdirectories inside each first-level subdirectory
#             for sub_root, sub_dirs, sub_files in os.walk(subdir):
#                 if os.path.abspath(sub_root) == os.path.abspath(subdir):
#                     for sd in sub_dirs:
#                         sd = sd[:8]
#                         second_level_dirs.append(sd)
#         break  # Only need to process the top level

# print(second_level_dirs)

# grouped_dfs = {k: v for k, v in grouped_dfs.items() if any(sub in k for sub in second_level_dirs)}


['a0b8fd4c', 'ab2538cf', '08fb4524', '43705867', 'bfcdd469', 'd6a01914', '18e94be5', '2eba82c4', '4a89eced', '6aab0ccc', '759cef29', 'e91a39c1', '2bac29a2', '51cfede1', '92c02791', 'b3bdb62f', '73a5cbfa', '15950e27', '4f247c2c', 'fe95099e', 'b9905a36', '6dc6b33f', 'b2b3ea84', 'fe7cb7a4', 'e2ce595b']


In [9]:
# Display the metrics table of one run; the run id is hard-coded and must be a key of grouped_dfs.
grouped_dfs['4de68857_1979_2000_2001_2014']

tag,Loss/validation,median_adjusted/CDD (Yearly),median_adjusted/CWD (Yearly),median_adjusted/R10mm,median_adjusted/R20mm,median_adjusted/R95pTOT,median_adjusted/R99pTOT,median_adjusted/Rx1day,median_adjusted/Rx5day,median_adjusted/SDII (Monthly)
step,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,19458.689453,-7.190606,11.618481,172.647736,550.0,2312.738281,2312.738281,2275.392578,2083.965332,1271.777588
10,296.030273,123.769463,-55.113121,-25.100586,7.898946,-17.15498,-17.15498,72.489929,42.600544,127.159988
20,83.359314,761.545776,-86.889351,-98.000145,-99.074608,-95.373024,-95.373024,-90.816208,-92.650383,-83.032486
30,55.154156,1190.998291,-98.571426,-100.0,-100.0,-100.0,-100.0,-99.504242,-99.666145,-98.919708
40,51.236343,1253.336182,-99.603172,-100.0,-100.0,-100.0,-100.0,-99.956123,-99.976555,-99.906738
50,30.533264,-64.312965,369.839844,-82.140617,-100.0,4.382234,4.382234,-26.120935,15.688938,-51.944843
60,13.804272,-10.167025,27.525837,2.498233,-77.833069,-13.489286,-13.489286,11.286108,38.667217,0.812907
70,12.780326,-5.115696,13.935658,1.406875,-76.275917,-22.942089,-22.942089,6.728521,31.89596,2.800834
80,10.871817,-2.767333,2.185374,-3.519678,-54.340134,-23.396389,-23.396389,11.722526,26.634626,3.980069
90,7.98353,-3.148599,-9.832252,-6.080954,-22.738094,-14.517997,-14.517997,23.930775,28.268358,12.486053


In [42]:
# root_dir1 = "runs_revised/conus_pca/access_cm2-gridmet"
# grouped_dfs_pca = load_bias_metrics_from_tensorboard(root_dir1)

# grouped_dfs = grouped_dfs | grouped_dfs_pca

In [9]:
import pandas as pd

def find_best_experiment_and_epoch(exp_dict, agg_method='median'):
    """
    Pick the (experiment, step) with the smallest aggregated absolute bias, and
    the best (experiment, step) for every individual column.

    Every column of every frame enters the aggregate, so columns that are not
    bias values (for example a loss column) should be removed by the caller first.

    Args:
        exp_dict: dict of {exp_name: pd.DataFrame} with index=step, columns=indices (bias %)
        agg_method: 'median', 'mean', or 'sum' to aggregate |bias| across indices

    Returns:
        best_overall: (exp, step, score) with the lowest aggregated score
        best_per_index: {index: (exp, step, bias)} with the bias closest to 0;
            an index that has no value in any experiment is left out
        score_df: dataframe with one row per (exp, step): 'exp', 'step', 'score'
            and the original columns

    Raises:
        ValueError: if agg_method is unknown, if exp_dict yields no rows, or if
            no row has a computable score.
    """
    aggregators = {
        'median': lambda values: values.median(),
        'mean': lambda values: values.mean(),
        'sum': lambda values: values.sum(),
    }
    # Validate up front so a bad argument is reported even when there is no data.
    if agg_method not in aggregators:
        raise ValueError("agg_method must be 'median', 'mean', or 'sum'")
    aggregate = aggregators[agg_method]

    rows = []
    for exp, df in exp_dict.items():
        for step, row in df.iterrows():
            rows.append({
                'exp': exp,
                'step': step,
                'score': aggregate(row.dropna().abs()),
                **row.to_dict()
            })

    if not rows:
        raise ValueError("exp_dict contains no rows to score")

    score_df = pd.DataFrame(rows)
    if score_df['score'].isna().all():
        raise ValueError("no row has a computable score (all values are NaN)")

    # Best overall (lowest aggregated score)
    best_overall_row = score_df.loc[score_df['score'].idxmin()]
    best_overall = (best_overall_row['exp'], best_overall_row['step'], best_overall_row['score'])

    # Best for each index (closest to 0 bias)
    indices = [col for col in score_df.columns if col not in ['exp', 'step', 'score']]
    best_per_index = {}
    for ind in indices:
        abs_bias = score_df[ind].abs()
        # Frames may carry different columns; a column without any value has no winner.
        if abs_bias.isna().all():
            continue
        best_row = score_df.loc[abs_bias.idxmin()]
        best_per_index[ind] = (best_row['exp'], best_row['step'], best_row[ind])

    return best_overall, best_per_index, score_df


In [10]:
# Score every (experiment, step) pair by the median absolute value across its columns.
best_overall, best_per_index, scores = find_best_experiment_and_epoch(grouped_dfs, agg_method='median')


In [11]:
# (experiment, step, score) of the pair with the lowest aggregated score.
best_overall

('4de68857_1979_2000_2001_2014', 220, 7.9664692878723145)

In [12]:
# For each column: (experiment, step, value) of the pair whose value is closest to zero.
best_per_index

{'Loss/validation': ('40970740_1979_2000_2001_2014', 380, 2.3028457164764404),
 'median_adjusted/CDD (Yearly)': ('74cc5d77_1979_2000_2001_2014',
  30,
  0.18543754518032074),
 'median_adjusted/CWD (Yearly)': ('cd2c368c_1979_2000_2001_2014',
  80,
  -0.0371057502925396),
 'median_adjusted/R10mm': ('cd2c368c_1979_2000_2001_2014',
  70,
  -0.5277726054191589),
 'median_adjusted/R20mm': ('6816184f_1979_2000_2001_2014', 130, 0.0),
 'median_adjusted/R95pTOT': ('74cc5d77_1979_2000_2001_2014',
  250,
  -0.13641822338104248),
 'median_adjusted/R99pTOT': ('74cc5d77_1979_2000_2001_2014',
  250,
  -0.13641822338104248),
 'median_adjusted/Rx1day': ('cd2c368c_1979_2000_2001_2014',
  70,
  5.9045257568359375),
 'median_adjusted/Rx5day': ('40970740_1979_2000_2001_2014',
  0,
  3.2955329418182373),
 'median_adjusted/SDII (Monthly)': ('4de68857_1979_2000_2001_2014',
  60,
  0.8129074573516846)}

In [13]:
from collections import Counter

def count_best_indices(best_per_index):
    """
    Count how many indices each experiment wins.

    Args:
        best_per_index: {index: (exp, step, bias)}; only the experiment name of
            each triple is used.

    Returns:
        dict of {exp: number of indices won}, in order of first appearance.
    """
    winners = (winner for winner, _step, _bias in best_per_index.values())
    return dict(Counter(winners))

# Tally the per-index winners and print one line per experiment.
counts = count_best_indices(best_per_index)
# The header previously contained a mis-decoded emoji (mojibake); restored here.
print("🏆 Best Index Counts Per Experiment:")
for exp, count in counts.items():
    print(f"{exp}: {count} indices")

üèÜ Best Index Counts Per Experiment:
40970740_1979_2000_2001_2014: 2 indices
74cc5d77_1979_2000_2001_2014: 3 indices
cd2c368c_1979_2000_2001_2014: 3 indices
6816184f_1979_2000_2001_2014: 1 indices
4de68857_1979_2000_2001_2014: 1 indices


In [None]:
def find_best_experiment_with_stability(exp_dict, agg_method='median', 
                                       stability_window=10, min_epochs=50,
                                       loss_weight=0.3, bias_weight=0.7):
    """
    Find best experiment considering both bias performance AND training stability.

    For every experiment with enough logged rows, each step after the first
    ``min_epochs`` rows gets a bias score (aggregated absolute bias) and the
    stability metrics returned by ``calculate_stability``. The two are blended
    into ``combined_score``; the row with the lowest combined score among the
    stable rows wins. If no row is stable, all rows are considered instead.

    Args:
        exp_dict: dict of {exp_name: pd.DataFrame} with index=step, columns=bias indices
        agg_method: 'median', 'mean', or 'sum' to aggregate |bias| across indices
        stability_window: number of recent epochs to check for stability
        min_epochs: minimum training epochs before considering a model; applied
            to the number of rows in each frame, not to the step values
        loss_weight, bias_weight: relative importance of loss vs bias (should sum to 1)

    Returns:
        best_stable: pd.Series, the winning row
        stable_results: pd.DataFrame of the rows the winner was chosen from

    Raises:
        ValueError: if agg_method is unknown, or if no experiment contributes a row.
    """
    aggregators = {
        'median': lambda values: values.median(),
        'mean': lambda values: values.mean(),
        'sum': lambda values: values.sum(),
    }
    # Same contract as find_best_experiment_and_epoch: reject unknown methods
    # instead of silently treating them as 'sum'.
    if agg_method not in aggregators:
        raise ValueError("agg_method must be 'median', 'mean', or 'sum'")
    aggregate = aggregators[agg_method]

    results = []

    for exp, df in exp_dict.items():
        if len(df) < min_epochs:
            continue

        # Get loss data (you'll need to load this separately)
        loss_data = load_loss_data(exp)  # You'll need to implement this

        for step in df.index[min_epochs:]:  # Only consider after min_epochs
            row = df.loc[step]
            bias_vals = row.dropna()

            if len(bias_vals) == 0:
                continue

            # 1. Calculate bias score
            bias_score = aggregate(bias_vals.abs())

            # 2. Check training stability
            stability_metrics = calculate_stability(loss_data, step, stability_window)

            # 3. Combined score
            combined_score = (loss_weight * stability_metrics['loss_score'] +
                              bias_weight * bias_score)

            results.append({
                'exp': exp,
                'step': step,
                'bias_score': bias_score,
                'loss_trend': stability_metrics['loss_trend'],
                'loss_variance': stability_metrics['loss_variance'],
                'is_stable': stability_metrics['is_stable'],
                'epochs_since_improvement': stability_metrics['epochs_since_improvement'],
                'combined_score': combined_score,
                **row.to_dict()
            })

    # Without this guard an empty frame has no 'is_stable' column and the
    # filter below fails with an unhelpful KeyError.
    if not results:
        raise ValueError(
            "no candidate rows: every experiment has at most min_epochs logged rows "
            "or only empty rows"
        )

    results_df = pd.DataFrame(results)

    # Filter to only stable models
    stable_results = results_df[results_df['is_stable'].astype(bool)].copy()

    if len(stable_results) == 0:
        print("⚠️ No stable models found! Relaxing stability criteria...")
        stable_results = results_df

    # Best stable model
    best_stable = stable_results.loc[stable_results['combined_score'].idxmin()]

    return best_stable, stable_results

def calculate_stability(loss_data, current_step, window=10):
    """
    Calculate training stability metrics for the window ending at ``current_step``.

    Args:
        loss_data: sequence of loss values, indexed by position
        current_step: position in ``loss_data`` to evaluate
        window: number of preceding positions included in the window

    Returns:
        dict with keys
            is_stable: True when the trend is nearly flat, the variance is low
                and the best loss so far is recent
            loss_trend: slope of a straight-line fit through the window
            loss_variance: variance of the window
            loss_score: lowest loss in the window divided by the highest loss overall
            epochs_since_improvement: distance from the best loss so far
        When the window is not fully available (``current_step`` is smaller than
        ``window`` or lies past the end of ``loss_data``) a fixed "unstable"
        result is returned instead.
    """
    # A step past the end of loss_data would slice an empty or truncated window
    # and make the line fit fail, so treat it like the too-early case.
    if current_step < window or current_step >= len(loss_data):
        return {'is_stable': False, 'loss_trend': float('inf'),
                'loss_variance': float('inf'), 'loss_score': 1.0,
                'epochs_since_improvement': 0}

    recent_loss = loss_data[current_step-window:current_step+1]

    # 1. Loss trend (should be flat or slightly decreasing)
    loss_trend = np.polyfit(range(len(recent_loss)), recent_loss, 1)[0]

    # 2. Loss variance (should be low)
    loss_variance = np.var(recent_loss)

    # 3. Epochs since last significant improvement
    best_loss_idx = np.argmin(loss_data[:current_step+1])
    epochs_since_improvement = current_step - best_loss_idx

    # 4. Stability criteria (plain bool so it matches the early-return value)
    is_stable = bool(
        abs(loss_trend) < 0.001 and  # Trend is nearly flat
        loss_variance < 0.01 and     # Low variance
        epochs_since_improvement < window * 2  # Recent improvement
    )

    # 5. Loss score (normalized, lower is better)
    loss_score = min(recent_loss) / max(loss_data)  # Relative to worst loss

    return {
        'is_stable': is_stable,
        'loss_trend': loss_trend,
        'loss_variance': loss_variance,
        'loss_score': loss_score,
        'epochs_since_improvement': epochs_since_improvement
    }

def load_loss_data(exp_name):
    """
    Load training loss for a specific experiment.

    Placeholder: this has to be implemented against the TensorBoard data and
    should return the loss values by epoch as an array.

    Raises:
        NotImplementedError: always, until an implementation is provided. The
            previous stub returned None, which only surfaced later as an
            unrelated-looking TypeError in the caller.
    """
    raise NotImplementedError(
        f"load_loss_data is not implemented; cannot load loss for experiment {exp_name!r}"
    )

In [None]:
# auto_select.py
import json, math, re
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Dict, List, Any, Tuple, Optional

############################
# 1) Configurable settings #
############################
# Weight of each component of the composite score J (see _compute_J). Every
# component is "lower is better" once normalized. The values add up to 1.10
# rather than 1.0; _compute_J divides by the weight total, so only the relative
# sizes matter.
METRIC_WEIGHTS = {
    "pdf_gap": 0.25,          # 1 - PDF overlap (the conversion happens in _normalize_metric)
    "wasserstein": 0.15,
    "tail_q95": 0.10,
    "tail_q99": 0.10,
    "tail_rx1": 0.075,
    "tail_rx5": 0.075,
    "wet_sdii": 0.05,
    "wet_cdd": 0.05,
    "wet_cwd": 0.05,
    "wet_r10": 0.05,
    "wet_r20": 0.05,
    "trend_gap": 0.10,
}

# Weight of each temporal aggregation scale; the keys double as the
# "<scale>/" prefix that _compute_J looks for in metric names.
SCALE_WEIGHTS = {"daily": 0.5, "monthly": 0.3, "seasonal": 0.2}

EMA_ALPHA = 0.3            # weight of the newest value when smoothing J across epochs
MIN_EPOCHS = 10            # a run with fewer logged rows than this is not evaluated
IMPROVEMENT_FLOOR = 0.02   # best J must beat the early-training J by at least 2%, else the run is not "improved"
NONINCR_TOL = 0.01         # smoothed J may rise by at most 1% between consecutive epochs in the last 20% of the run

# Filenames expected inside each trial directory
VAL_LOG = "val_metrics.jsonl"    # one json per line: {"epoch": int, "loss": float, "metrics": {...}}
BASELINE_FILE = "baseline.json"  # raw-vs-obs metrics dict used for normalization (same keys as metrics)

#######################
# 2) Helper utilities #
#######################
def _ema(series: List[float], alpha: float) -> List[float]:
    out = []
    s = None
    for x in series:
        s = x if s is None else alpha * x + (1 - alpha) * s
        out.append(s)
    return out

def _safe_div(a: float, b: float, eps: float = 1e-8) -> float:
    return a / (b + eps)

def _normalize_metric(value: float, baseline: float, lower_better=True) -> float:
    """
    Map metric to [0,1] where 0 ~ perfect, 1 ~ as bad as raw baseline.
    If lower_better is False (e.g., PDF overlap), we invert appropriately.
    """
    if not lower_better:
        # convert to a "gap" first: 1 - overlap
        value = 1.0 - value
        baseline = 1.0 - baseline
    # Normalize relative to baseline (clip to [0,1.5] but we'll clamp later)
    norm = _safe_div(value, baseline if baseline > 0 else 1.0)
    return max(0.0, min(1.0, norm))

def _compute_J(metrics: Dict[str, float], baseline: Dict[str, float]) -> float:
    """
    Compute single scalar J from (possibly multi-scale) metrics.
    Expect keys like "daily/pdf_overlap", "monthly/rx1_err", etc.
    """
    accum = 0.0
    wsum = 0.0
    for scale, sw in SCALE_WEIGHTS.items():
        scale_contrib = 0.0
        scale_wsum = 0.0

        # materialize a per-scale dict if present
        def get(key_suffix: str) -> Optional[float]:
            # prefer "<scale>/<key>", else plain "<key>"
            if f"{scale}/{key_suffix}" in metrics:
                return metrics[f"{scale}/{key_suffix}"]
            return metrics.get(key_suffix, None)

        # build normalized components
        parts = {
            "pdf_gap": _normalize_metric(
                get("pdf_overlap") if get("pdf_overlap") is not None else 1.0,  # overlap in [0,1]
                baseline.get(f"{scale}/pdf_overlap", baseline.get("pdf_overlap", 0.0)),
                lower_better=False
            ),
            "wasserstein": _normalize_metric(
                get("wasserstein") or 0.0,
                baseline.get(f"{scale}/wasserstein", baseline.get("wasserstein", 1.0)),
                lower_better=True
            ),
            "tail_q95": _normalize_metric(get("q95_err") or 0.0, baseline.get(f"{scale}/q95_err", 1.0)),
            "tail_q99": _normalize_metric(get("q99_err") or 0.0, baseline.get(f"{scale}/q99_err", 1.0)),
            "tail_rx1": _normalize_metric(get("rx1_err") or 0.0, baseline.get(f"{scale}/rx1_err", 1.0)),
            "tail_rx5": _normalize_metric(get("rx5_err") or 0.0, baseline.get(f"{scale}/rx5_err", 1.0)),
            "wet_sdii": _normalize_metric(get("sdii_err") or 0.0, baseline.get(f"{scale}/sdii_err", 1.0)),
            "wet_cdd":  _normalize_metric(get("cdd_err") or 0.0,  baseline.get(f"{scale}/cdd_err", 1.0)),
            "wet_cwd":  _normalize_metric(get("cwd_err") or 0.0,  baseline.get(f"{scale}/cwd_err", 1.0)),
            "wet_r10":  _normalize_metric(get("r10_err") or 0.0,  baseline.get(f"{scale}/r10_err", 1.0)),
            "wet_r20":  _normalize_metric(get("r20_err") or 0.0,  baseline.get(f"{scale}/r20_err", 1.0)),
            "trend_gap": _normalize_metric(get("trend_gap") or 0.0, baseline.get(f"{scale}/trend_gap", 1.0)),
        }

        for k, v in parts.items():
            w = METRIC_WEIGHTS.get(k, 0.0)
            scale_contrib += w * v
            scale_wsum += w

        if scale_wsum > 0:
            accum += sw * (scale_contrib / scale_wsum)
            wsum += sw

    return accum / wsum if wsum > 0 else 1.0

@dataclass
class EpochEval:
    """Evaluation record for a single epoch.

    NOTE(review): no code visible in this section constructs or reads this
    class; the field meanings below are taken from their names — confirm.
    """
    epoch: int      # epoch number
    loss: float     # loss value for the epoch
    J: float        # composite score J
    J_ema: float    # presumably the exponentially smoothed J — TODO confirm

@dataclass
class TrialResult:
    """Summary of one trial directory as produced by evaluate_trial."""
    trial_dir: str      # trial directory path, as a string
    best_epoch: int     # epoch whose raw J is lowest
    best_J: float       # raw J at best_epoch
    best_J_ema: float   # smoothed J at best_epoch
    final_loss: float   # loss of the last logged row
    improved: bool      # best J beat the early-training J by at least IMPROVEMENT_FLOOR
    good_tail: bool     # smoothed J did not rise noticeably over the last 20% of rows
    config_summary: Dict[str, Any]  # model_type, layers, parameter_scale, etc.

########################
# 3) Core evaluation   #
########################
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """
    Read a JSON Lines file: one JSON document per line, blank lines ignored.

    The file is always decoded as UTF-8 rather than with the platform's
    default encoding.

    Args:
        path: file to read.

    Returns:
        The decoded documents, in file order.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON; the message names
            the file and the line number.
    """
    rows = []
    with path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Same exception type, with enough context to find the bad line.
                raise json.JSONDecodeError(
                    f"{path}, line {lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    return rows

def read_config_summary(trial_dir: Path) -> Dict[str, Any]:
    """
    Collect a few descriptive settings of a trial, on a best-effort basis.

    The model type is first guessed from the directory name (MLP, CNN or LSTM,
    case-insensitive). Then ``config.json``, ``config.yaml`` and
    ``train_config.yaml`` are read in that order if they exist; each one
    overrides the values found so far. YAML files need the optional PyYAML
    package. A config that cannot be read or is not a mapping is reported on
    stdout and skipped; it never raises.

    Args:
        trial_dir: trial directory.

    Returns:
        dict with keys model_type, layers, hidden, parameter_scale, epochs
        (None when unknown) and trial (the directory name).
    """
    keys = ("model_type", "layers", "hidden", "parameter_scale", "epochs")
    out: Dict[str, Any] = {k: None for k in keys}
    out["trial"] = trial_dir.name

    # Lightweight parse from path name (fallback)
    m = re.search(r"(MLP|CNN|LSTM)", trial_dir.name, re.I)
    if m:
        out["model_type"] = m.group(1).upper()

    # Try a few common config filenames
    for fname in ["config.json", "config.yaml", "train_config.yaml"]:
        p = trial_dir / fname
        if not p.exists():
            continue
        try:
            text = p.read_text(encoding="utf-8")
            if p.suffix == ".json":
                cfg = json.loads(text)
            else:
                import yaml  # optional third-party dependency (PyYAML)
                cfg = yaml.safe_load(text)
        except Exception as exc:
            # Still best effort, but say so instead of hiding the problem.
            print(f"[auto_select] skipping unreadable config {p}: {exc}")
            continue
        if not isinstance(cfg, dict):
            print(f"[auto_select] skipping config {p}: top level is not a mapping")
            continue
        out.update({k: cfg[k] for k in keys if k in cfg})
    return out

def evaluate_trial(trial_dir: Path) -> Optional[TrialResult]:
    """
    Score one trial directory from its validation log and baseline file.

    Every logged row is turned into a composite score J. The epoch with the
    lowest raw J is reported as best; the smoothed (EMA) curve is used only for
    the stability diagnostics.

    Args:
        trial_dir: directory expected to contain VAL_LOG and BASELINE_FILE.

    Returns:
        A TrialResult, or None when either file is missing or the log has fewer
        than MIN_EPOCHS rows.
    """
    val_path = trial_dir / VAL_LOG
    base_path = trial_dir / BASELINE_FILE
    if not (val_path.exists() and base_path.exists()):
        return None

    logs = load_jsonl(val_path)
    if len(logs) < MIN_EPOCHS:
        return None

    baseline = json.loads(base_path.read_text())

    epochs, losses, Js = [], [], []
    for position, entry in enumerate(logs):
        # a row without an explicit epoch is numbered by its position in the log
        epoch_no = int(entry.get("epoch", position))
        loss_value = float(entry.get("loss", math.nan))
        score = _compute_J(entry.get("metrics", {}), baseline)
        epochs.append(epoch_no)
        losses.append(loss_value)
        Js.append(score)

    J_ema = _ema(Js, EMA_ALPHA)

    # best epoch is picked on raw J; the EMA value at that epoch is kept alongside
    best_idx = min(range(len(Js)), key=Js.__getitem__)
    best_J = Js[best_idx]

    # diagnostic 1: relative improvement over the score early in training
    early = Js[min(5, len(Js) - 1)]
    relative_gain = (early - best_J) / max(early, 1e-6)
    improved = relative_gain >= IMPROVEMENT_FLOOR

    # diagnostic 2: the smoothed curve must not climb noticeably in the last 20% of rows
    tail = J_ema[int(0.8 * len(J_ema)):]
    good_tail = not any(
        later - earlier > NONINCR_TOL * max(earlier, 1e-6)
        for earlier, later in zip(tail, tail[1:])
    )

    return TrialResult(
        trial_dir=str(trial_dir),
        best_epoch=epochs[best_idx],
        best_J=best_J,
        best_J_ema=J_ema[best_idx],
        final_loss=float(losses[-1]),
        improved=improved,
        good_tail=good_tail,
        config_summary=read_config_summary(trial_dir),
    )

########################
# 4) Batch orchestration
########################
def scan_and_rank(root: str) -> Dict[str, Any]:
    """
    Evaluate every trial directory below ``root`` and rank the results.

    A trial directory is any directory (at any depth) that contains VAL_LOG.
    Trials that both improved and have a well-behaved tail are preferred; when
    there are none, all evaluated trials are ranked instead. Ranking is by
    best_J, then best_J_ema, then final_loss, ascending.

    Returns:
        dict with
            best: the top-ranked TrialResult as a dict, or None
            ranked: one summary row (dict) per ranked trial, best first
            n_total: number of trials that could be evaluated
            n_stable: number of trials that passed both diagnostics
    """
    trial_dirs = sorted(p for p in Path(root).glob("**/") if (p / VAL_LOG).exists())
    evaluated = (evaluate_trial(trial) for trial in trial_dirs)
    results = [outcome for outcome in evaluated if outcome is not None]

    # Prefer stable runs; if nothing is stable, fall back to everything
    stable = [r for r in results if r.improved and r.good_tail]
    finalists = sorted(
        stable or results,
        key=lambda r: (r.best_J, r.best_J_ema, r.final_loss),
    )

    # One flat summary row per finalist; config entries get a "cfg_" prefix
    table = [
        {
            "trial": Path(r.trial_dir).name,
            "best_epoch": r.best_epoch,
            "best_J": round(r.best_J, 6),
            "best_J_ema": round(r.best_J_ema, 6),
            "final_loss": round(r.final_loss, 6),
            "improved": r.improved,
            "good_tail": r.good_tail,
            **{f"cfg_{k}": v for k, v in r.config_summary.items()},
        }
        for r in finalists
    ]

    best = finalists[0] if finalists else None
    return {
        "best": asdict(best) if best else None,
        "ranked": table,
        "n_total": len(results),
        "n_stable": len(stable),
    }



In [None]:
# Command-line entry point: rank all trials under --exp_root, write the ranked
# table as CSV and the full summary as JSON, then print a one-line result.
# NOTE(review): this block parses command-line arguments and --exp_root is
# required — confirm it is only meant to run as a script, not as a notebook cell.
if __name__ == "__main__":
    # NOTE(review): sys is imported here but not used anywhere in this block.
    import argparse, csv, sys
    ap = argparse.ArgumentParser()
    ap.add_argument("--exp_root", required=True, help="Directory containing many trial subfolders")
    ap.add_argument("--out_csv", default="auto_select_results.csv")
    ap.add_argument("--out_json", default="auto_select_best.json")
    args = ap.parse_args()

    summary = scan_and_rank(args.exp_root)

    # CSV table for quick viewing (skipped when nothing was ranked);
    # the header is taken from the keys of the first row.
    rows = summary["ranked"]
    if rows:
        keys = list(rows[0].keys())
        with open(args.out_csv, "w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=keys)
            w.writeheader(); w.writerows(rows)

    # JSON with winner & counts (always written, even when no trial was found)
    with open(args.out_json, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"[auto_select] scanned={summary['n_total']} stable={summary['n_stable']}")
    if summary["best"]:
        print(f"[auto_select] BEST trial={Path(summary['best']['trial_dir']).name} epoch={summary['best']['best_epoch']} J={summary['best']['best_J']:.4f}")
    else:
        print("[auto_select] No valid trials found.")