In [30]:
import os 
os.chdir("/nas/ucb/oliveradk/diverse-gen/")

import yaml
from pathlib import Path
import json
from diverse_gen.utils.exp_utils import get_conf_dir
import pandas as pd


In [37]:
EXP_DIR = Path("output/random_network_baseline/2025-02-24_11-24-29")

In [38]:
DATASETS = [
    "toy_grid", 
    "fmnist_mnist", 
    "cifar_mnist", 
    "waterbirds", 
    "celebA-0", 
    # "multi-nli"
]
SEEDS = [1, 2, 3]


In [39]:
results = []
for ds_name in DATASETS:
    for seed in SEEDS:
        exp_dir = get_conf_dir((ds_name, "Random_Network", 0.0, seed), EXP_DIR)
        # load metrics
        metrics_path = Path(exp_dir) / "metrics.json"
        if not metrics_path.exists():
            print(f"Metrics file not found for {ds_name} {seed}")
            continue
        with open(metrics_path, "r") as f:
            metrics = json.load(f)
        with open(Path(exp_dir) / "config.yaml", "r") as f:
            config = yaml.safe_load(f)
        source_loss = float(metrics["val_source_loss"][0])
        source_acc = float(metrics["val_source_acc_0"][0])
        train = config["train"]
        results.append({"dataset": ds_name, "seed": seed, "source_loss": source_loss, "source_acc": source_acc, "train": train})

df = pd.DataFrame(results)



In [40]:
df

Unnamed: 0,dataset,seed,source_loss,source_acc,train
0,toy_grid,1,1.411544,0.5,False
1,toy_grid,2,1.409556,0.5,False
2,toy_grid,3,1.386931,0.5,False
3,fmnist_mnist,1,1.465185,0.410156,False
4,fmnist_mnist,2,1.471454,0.25,False
5,fmnist_mnist,3,1.541012,0.496094,False
6,cifar_mnist,1,1.451757,0.503906,False
7,cifar_mnist,2,1.369113,0.523438,False
8,cifar_mnist,3,1.485752,0.496094,False
9,waterbirds,1,1.56014,0.478448,False


In [41]:
summary = df.groupby('dataset').agg({
    'source_loss': ['mean', 'std']
}).round(4)
summary

Unnamed: 0_level_0,source_loss,source_loss
Unnamed: 0_level_1,mean,std
dataset,Unnamed: 1_level_2,Unnamed: 2_level_2
celebA-0,1.5306,0.1444
cifar_mnist,1.4355,0.06
fmnist_mnist,1.4926,0.0421
toy_grid,1.4027,0.0137
waterbirds,1.5974,0.0539
