In [3]:
import os
import glob
from tensorboard.backend.event_processing import event_accumulator
from collections import defaultdict
import pandas as pd

In [1]:
def load_bias_metrics_from_tensorboard(root_dir, drop_tags=None):
    """
    Load scalar metrics from the newest TensorBoard event file of every run.

    Walks root_dir recursively. In each directory that contains files whose
    name starts with "events.out.tfevents", only the most recently modified
    one is read. All scalar tags are loaded, pivoted into one column per tag
    indexed by step, and the unwanted tags are removed.

    Args:
        root_dir: directory to scan. A missing or unreadable directory yields
            an empty result because os.walk ignores listing errors by default.
        drop_tags: iterable of tag names to remove from the result. Defaults
            to the training loss and the wet/dry-day tags listed below.

    Returns:
        grouped_dfs: dict of {run_id: pd.DataFrame}. run_id is the basename of
            the directory holding the event file; rows are steps (sorted),
            columns are the remaining tags. Steps that lack a value for any
            remaining tag are dropped.
    """
    if drop_tags is None:
        drop_tags = {
            'Loss/train',
            'median_adjusted/Wet Days >1mm',
            'median_adjusted/Very Wet Days >10mm',
            'median_adjusted/Very Very Wet Days >20mm',
            'median_adjusted/Dry Days'
        }

    # Step 1: Find latest event file for each run.
    # Runs are keyed by directory basename, so two directories with the same
    # basename overwrite each other (the one visited last by os.walk wins).
    latest_event_files = {}
    for root, _dirs, files in os.walk(root_dir):
        event_paths = [
            os.path.join(root, f) for f in files if f.startswith("events.out.tfevents")
        ]
        if not event_paths:
            continue
        latest_event_files[os.path.basename(root)] = max(event_paths, key=os.path.getmtime)

    # Step 2: Load scalars from each file as (tag, step, value) records.
    run_data = defaultdict(list)
    for run_id, event_path in latest_event_files.items():
        try:
            ea = event_accumulator.EventAccumulator(event_path)
            ea.Reload()
            for tag in ea.Tags().get('scalars', []):
                for s in ea.Scalars(tag):
                    run_data[run_id].append((tag, s.step, s.value))
        except Exception as e:
            # Best effort: report and move on to the next run. Records appended
            # before the failure stay in run_data for this run.
            print(f"⚠️ Failed to load {event_path}: {e}")

    # Step 3: Convert to cleaned pivoted DataFrames.
    grouped_dfs = {}
    for run_id, records in run_data.items():
        df = pd.DataFrame(records, columns=["tag", "step", "value"])
        # DataFrame.pivot raises ValueError on repeated (step, tag) pairs, which
        # happens when a tag is logged more than once at the same step. Keep the
        # record that appears last for each pair.
        df = df.drop_duplicates(subset=["step", "tag"], keep="last")
        pivoted = df.pivot(index='step', columns='tag', values='value').sort_index()
        pivoted = pivoted.drop(columns=list(drop_tags), errors='ignore')
        # dropna() removes every step where at least one remaining tag is missing.
        grouped_dfs[run_id] = pivoted.dropna()

    return grouped_dfs




In [None]:
# Load the cleaned per-run metric tables from the TensorBoard logs under
# root_dir (a relative path, resolved against the notebook's working directory).
# NOTE(review): this loads gfdl_esm4-gridmet runs, while the filter cell further
# down lists run ids from an access_cm2-gridmet job directory — confirm the two
# are meant to be combined.
root_dir = "runs_LSTM/conus_gridmet_new/gfdl_esm4-gridmet"
grouped_dfs = load_bias_metrics_from_tensorboard(root_dir)

In [28]:
import os

base_dir = "/pscratch/sd/k/kas7897/diffDownscale/jobs_revised_pca/access_cm2-gridmet"

# Collect the first 8 characters of every directory that sits exactly two
# levels below base_dir (base_dir/<first level>/<second level>).
#
# next(os.walk(path)) produces only the listing of `path` itself; the generator
# is never advanced, so nothing below that level is traversed. The fallback
# tuple covers a missing or unreadable directory, for which os.walk yields
# nothing — in that case the list stays empty and the filter below removes
# every run.
second_level_dirs = []
for first_level in next(os.walk(base_dir), (base_dir, [], []))[1]:
    first_level_path = os.path.join(base_dir, first_level)
    for second_level in next(os.walk(first_level_path), (first_level_path, [], []))[1]:
        second_level_dirs.append(second_level[:8])

print(second_level_dirs)

# Keep only the runs whose id contains one of the collected 8-character prefixes.
grouped_dfs = {k: v for k, v in grouped_dfs.items() if any(sub in k for sub in second_level_dirs)}


['a0b8fd4c', 'ab2538cf', '08fb4524', '43705867', 'bfcdd469', 'd6a01914', '18e94be5', '2eba82c4', '4a89eced', '6aab0ccc', '759cef29', 'e91a39c1', '2bac29a2', '51cfede1', '92c02791', 'b3bdb62f', '73a5cbfa', '15950e27', '4f247c2c', 'fe95099e', 'b9905a36', '6dc6b33f', 'b2b3ea84', 'fe7cb7a4', 'e2ce595b']


In [30]:
# Number of runs currently held in grouped_dfs.
len(grouped_dfs)

8

In [42]:
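# Disabled: optionally load a second set of runs and merge it into grouped_dfs
# (the dict union operator `|` requires Python 3.9+).
# root_dir1 = "runs_revised/conus_pca/access_cm2-gridmet"
# grouped_dfs_pca = load_bias_metrics_from_tensorboard(root_dir1)

# grouped_dfs = grouped_dfs | grouped_dfs_pca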

In [31]:
import pandas as pd

def find_best_experiment_and_epoch(exp_dict, agg_method='median'):
    """
    Find the (experiment, step) rows whose metric values are closest to zero.

    Every row of every DataFrame is scored by aggregating the absolute values
    of its non-NaN entries; a lower score is better. All columns of the input
    frames are treated as metrics.

    Args:
        exp_dict: dict of {exp_name: pd.DataFrame} with index=step, columns=indices (bias %)
        agg_method: 'median', 'mean', or 'sum' to aggregate bias across indices

    Returns:
        best_overall: (exp, step, score) of the row with the lowest score
        best_per_index: {index: (exp, step, bias)} with, for each metric column,
            the row whose value has the smallest absolute value
        score_df: dataframe with one row per (exp, step): columns 'exp', 'step',
            'score' followed by the metric columns

        Ties are resolved in favour of the first row in iteration order.

    Raises:
        ValueError: if agg_method is not supported, or if exp_dict provides no
            rows at all (empty dict or only empty DataFrames).
    """
    # Validate once, up front, so a bad argument is reported even when there
    # are no rows to iterate over.
    if agg_method not in ('median', 'mean', 'sum'):
        raise ValueError("agg_method must be 'median', 'mean', or 'sum'")

    rows = []
    for exp, df in exp_dict.items():
        for step, row in df.iterrows():
            abs_bias = row.dropna().abs()
            rows.append({
                'exp': exp,
                'step': step,
                'score': getattr(abs_bias, agg_method)(),
                **row.to_dict()
            })

    # Without rows there is no 'score' column and nothing to rank; fail with a
    # clear message instead of a KeyError from the lookup below.
    if not rows:
        raise ValueError("exp_dict contains no rows to score")

    score_df = pd.DataFrame(rows)

    # Best overall (lowest aggregated score)
    best_overall_row = score_df.loc[score_df['score'].idxmin()]
    best_overall = (best_overall_row['exp'], best_overall_row['step'], best_overall_row['score'])

    # Best for each index (closest to 0 bias)
    indices = [col for col in score_df.columns if col not in ['exp', 'step', 'score']]
    best_per_index = {}
    for ind in indices:
        best_row = score_df.loc[score_df[ind].abs().idxmin()]
        best_per_index[ind] = (best_row['exp'], best_row['step'], best_row[ind])

    return best_overall, best_per_index, score_df


In [32]:
# Score every (run, step) row by the median absolute value across its columns.
# NOTE(review): the stored best_per_index output lists a 'Loss/validation'
# column, so the validation loss is aggregated together with the bias metrics —
# confirm that is intended.
best_overall, best_per_index, scores = find_best_experiment_and_epoch(grouped_dfs, agg_method='median')


In [33]:
# (run id, step, aggregated score) of the lowest-scoring row.
best_overall

('b9905a36_1979_2000_2001_2014', 200, 12.997745513916016)

In [46]:
# Per metric: (run id, step, value) of the row whose value is closest to zero.
# NOTE(review): the stored output below names run ids that do not appear in the
# prefix list printed by the filter cell, and the execution counts are out of
# order, so it likely comes from an earlier kernel state — restart and run all
# to refresh. The R95pTOT and R99pTOT entries are also identical; verify the
# upstream logging of those two tags.
best_per_index

{'Loss/validation': ('a466b5af_1979_2000_2001_2014', 400, 1.7437440156936646),
 'median_adjusted/CDD (Yearly)': ('aba3352c_1979_2000_2001_2014',
  30,
  -0.10065137594938278),
 'median_adjusted/CWD (Yearly)': ('71c3c134_1979_2000_2001_2014',
  170,
  2.6077096462249756),
 'median_adjusted/R10mm': ('6197c63f_1979_2000_2001_2014',
  30,
  0.05280762165784836),
 'median_adjusted/R20mm': ('22601d24_1979_2000_2001_2014',
  0,
  -4.671151161193848),
 'median_adjusted/R95pTOT': ('d0594b3b_1979_2000_2001_2014',
  20,
  5.493357181549072),
 'median_adjusted/R99pTOT': ('d0594b3b_1979_2000_2001_2014',
  20,
  5.493357181549072),
 'median_adjusted/Rx1day': ('35bcd105_1979_2000_2001_2014',
  0,
  22.253765106201172),
 'median_adjusted/Rx5day': ('35bcd105_1979_2000_2001_2014',
  0,
  36.42094039916992),
 'median_adjusted/SDII (Monthly)': ('ef45008b_1979_2000_2001_2014',
  0,
  1.352627158164978)}

In [47]:
from collections import Counter

def count_best_indices(best_per_index):
    """
    Tally how many indices each experiment is the best performer for.

    Args:
        best_per_index: dict of {index: (exp, step, bias)}

    Returns:
        dict of {exp: number of indices}, ordered by first appearance of exp.
    """
    winners = (exp for exp, _step, _bias in best_per_index.values())
    return dict(Counter(winners))

# Report how many metric columns each run wins.
counts = count_best_indices(best_per_index)
print("🏆 Best Index Counts Per Experiment:")
for exp, count in counts.items():
    print(f"{exp}: {count} indices")

🏆 Best Index Counts Per Experiment:
a466b5af_1979_2000_2001_2014: 1 indices
aba3352c_1979_2000_2001_2014: 1 indices
71c3c134_1979_2000_2001_2014: 1 indices
6197c63f_1979_2000_2001_2014: 1 indices
22601d24_1979_2000_2001_2014: 1 indices
d0594b3b_1979_2000_2001_2014: 2 indices
35bcd105_1979_2000_2001_2014: 2 indices
ef45008b_1979_2000_2001_2014: 1 indices
