In [1]:
import os
# Work from the repository root so the relative output/ and results/ paths used below resolve.
# DIVERSE_GEN_ROOT overrides the original cluster checkout path when running elsewhere.
os.chdir(os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen/"))

In [2]:
from collections import defaultdict
from pathlib import Path
import itertools
from typing import Optional
import json
from itertools import product

import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf

from diverse_gen.utils.exp_utils import get_conf_dir

In [3]:
# Run-output roots (relative to the current working directory) and the directory that
# receives the aggregated CSV tables; the results directory is created if missing.
_OUTPUT_ROOT = "output/incomplete_waterbirds"
MAIN_DIR = f"{_OUTPUT_ROOT}/main"
GROUP_LABELS_DIR = f"{_OUTPUT_ROOT}/group_labels"

RESULTS_DIR = "results/incomplete_waterbirds"
os.makedirs(RESULTS_DIR, exist_ok=True)

In [4]:
METHODS = ["TopK_0.1", "TopK_0.5", "ERM", "DBAT", "DivDis"]
SEEDS = [1, 2, 3]
DATASETS = ["waterbirds"]
# A single None mix rate; it is passed through unchanged to get_conf_dir when locating
# run directories (presumably "no explicit mix rate" — confirm against get_conf_dir).
MIX_RATES = [None]

# Every (dataset, method, mix_rate, seed) run, ordered with the seed varying fastest.
configs = [
    (dataset, method, mix_rate, seed)
    for dataset in DATASETS
    for method in METHODS
    for mix_rate in MIX_RATES
    for seed in SEEDS
]

In [5]:
def _select_epoch_indices(metrics: dict, dataset: str, method: str, select_epoch: bool) -> list[int]:
    """Return the epoch indices to report for one run.

    With ``select_epoch`` a single epoch is chosen: the last epoch for ``toy_grid``,
    the epoch minimizing ``val_source_loss`` for DBAT, and the epoch minimizing
    ``val_loss`` for everything else. Without it, every epoch in ``val_loss`` is returned.
    """
    if not select_epoch:
        return list(range(len(metrics["val_loss"])))
    if dataset == "toy_grid":
        return [len(metrics["val_loss"]) - 1]
    # DBAT is model-selected on the source-only validation loss rather than the total val loss.
    criterion = "val_source_loss" if method == "DBAT" else "val_loss"
    return [int(np.argmin(metrics[criterion]))]


# TODO: redo this
def get_results(configs: list[tuple[str, str, Optional[float], int]], parent_dir: str = MAIN_DIR, select_epoch: bool = True,
                group_labels: bool = False) -> pd.DataFrame:
    """Load each run's ``metrics.json`` and return one row per reported epoch.

    Args:
        configs: ``(dataset, method, mix_rate, seed)`` tuples; each is resolved to a run
            directory with ``get_conf_dir(config, parent_dir)``.
        parent_dir: root directory containing the run directories.
        select_epoch: if True, report one model-selected epoch per run (last epoch for
            ``toy_grid``, min ``val_source_loss`` for DBAT, min ``val_loss`` otherwise);
            if False, report every epoch.
        group_labels: if True, the ``test_acc_alt_*`` metrics are not read and only the
            per-head columns are produced; the summary columns ``acc``, ``acc_alt``,
            ``worst_acc`` and ``multi_label_acc`` are omitted.

    Returns:
        DataFrame with the run identifiers, per-head validation/test metrics, the
        reported ``epoch`` and (unless ``group_labels``) the summary columns. DBAT uses
        head 0 for ``acc``/``worst_acc`` and head 1 for ``acc_alt``; other methods take
        the max over both heads for each summary metric. ``multi_label_acc`` is the mean
        of ``acc`` and ``acc_alt`` for every method.

    Raises:
        FileNotFoundError: if a run directory has no ``metrics.json``.
        KeyError: if a required metric is missing from ``metrics.json``.
    """
    results = []
    for dataset, method, mix_rate, seed in configs:
        exp_dir = get_conf_dir((dataset, method, mix_rate, seed), parent_dir)
        with open(os.path.join(exp_dir, "metrics.json"), "r") as f:
            metrics = json.load(f)

        for idx in _select_epoch_indices(metrics, dataset, method, select_epoch):
            test_acc_0 = metrics["test_acc_0"][idx]
            test_acc_1 = metrics["test_acc_1"][idx]
            worst_acc_0 = metrics["test_worst_acc_0"][idx]
            worst_acc_1 = metrics["test_worst_acc_1"][idx]

            row = {
                "Dataset": dataset,
                "Method": method,
                "Mix_Rate": mix_rate,
                "Seed": seed,
                # Validation loss at the reported epoch (not necessarily its minimum).
                "val_loss": metrics["val_loss"][idx],
                "source_acc_0": metrics["val_source_acc_0"][idx],
                "source_acc_1": metrics["val_source_acc_1"][idx],
                "source_loss": metrics["val_source_loss"][idx],
                "acc_0": test_acc_0,
                "acc_1": test_acc_1,
                "worst_acc_0": worst_acc_0,
                "worst_acc_1": worst_acc_1,
                "epoch": idx,
            }

            if not group_labels:
                test_acc_alt_0 = metrics["test_acc_alt_0"][idx]
                test_acc_alt_1 = metrics["test_acc_alt_1"][idx]
                if method == "DBAT":
                    # DBAT's heads are read as fixed roles: head 0 main task, head 1 alternative.
                    acc, acc_alt, worst_acc = test_acc_0, test_acc_alt_1, worst_acc_0
                else:
                    # NOTE(review): each metric is maxed independently, so acc and worst_acc
                    # may come from different heads.
                    acc = max(test_acc_0, test_acc_1)
                    acc_alt = max(test_acc_alt_0, test_acc_alt_1)
                    worst_acc = max(worst_acc_0, worst_acc_1)
                row.update({
                    "acc": acc,
                    "acc_alt": acc_alt,
                    "worst_acc": worst_acc,
                    # Previously only set for DBAT, leaving NaN for every other method.
                    "multi_label_acc": np.mean([acc, acc_alt]),
                })
            results.append(row)

    return pd.DataFrame(results)


In [6]:
df = get_results(configs, parent_dir=MAIN_DIR)
df.to_csv(os.path.join(RESULTS_DIR, "results.csv"), index=False)

# Expose each headline metric under a shared "Accuracy" column so the same
# table-printing code can summarize any of them.
acc_df, alt_acc_df, worst_acc_df = (
    df.copy().rename(columns={metric: "Accuracy"})
    for metric in ("acc", "acc_alt", "worst_acc")
)

In [7]:
# ERM is left out of the group-label comparison (presumably no group-label runs exist for it).
df_gl = get_results(
    [(dataset, method, mix_rate, seed) for dataset, method, mix_rate, seed in configs if method != "ERM"],
    parent_dir=GROUP_LABELS_DIR,
    group_labels=True,
)
df_gl.to_csv(os.path.join(RESULTS_DIR, "df_gl.csv"), index=False)

# Group-label rows carry no summary columns, so head 0 stands in for the main and
# worst-group accuracy and head 1 for the alternative accuracy.
acc_gl_df, alt_acc_gl_df, worst_acc_gl_df = (
    df_gl.copy().rename(columns={head_metric: "Accuracy"})
    for head_metric in ("acc_0", "acc_1", "worst_acc_0")
)


In [8]:
# metrics = json.load(open("output/incomplete_waterbirds/group_labels/waterbirds_DivDis_None/1/metrics.json"))

In [9]:
# Create DataFrame with all metrics
def print_latex_table(acc_df, alt_acc_df, worst_acc_df):
    df = pd.DataFrame({
        'Method': [],
        'Average Acc': [],
        'Alternative Acc': [],
        'Worst-Group Acc': []
    })

    for method in METHODS:
        avg_acc = f"{acc_df[acc_df['Method'] == method]['Accuracy'].mean()*100:.1f} ± {acc_df[acc_df['Method'] == method]['Accuracy'].std()*100:.1f}"
        alt_acc = f"{alt_acc_df[alt_acc_df['Method'] == method]['Accuracy'].mean()*100:.1f} ± {alt_acc_df[alt_acc_df['Method'] == method]['Accuracy'].std()*100:.1f}"
        worst_acc = f"{worst_acc_df[worst_acc_df['Method'] == method]['Accuracy'].mean()*100:.1f} ± {worst_acc_df[worst_acc_df['Method'] == method]['Accuracy'].std()*100:.1f}"
        
        df = pd.concat([df, pd.DataFrame({
            'Method': [method.replace("_", " ")],
            'Average Acc': [avg_acc],
            'Alternative Acc': [alt_acc],
            'Worst-Group Acc': [worst_acc]
        })], ignore_index=True)

    # Print LaTeX table
    print(df.to_latex(index=False, escape=True))

In [10]:
print_latex_table(acc_df=acc_df, alt_acc_df=alt_acc_df, worst_acc_df=worst_acc_df)

\begin{tabular}{llll}
\toprule
Method & Average Acc & Alternative Acc & Worst-Group Acc \\
\midrule
TopK 0.1 & 88.9 ± 1.3 & 74.0 ± 3.6 & 53.9 ± 12.8 \\
TopK 0.5 & 92.8 ± 0.4 & 90.7 ± 1.1 & 70.9 ± 2.3 \\
ERM & 84.7 ± 2.9 & 63.1 ± 2.9 & 51.1 ± 10.2 \\
DBAT & 60.2 ± 19.4 & 68.7 ± 1.5 & 33.1 ± 23.4 \\
DivDis & 91.0 ± 3.2 & 72.9 ± 1.3 & 67.7 ± 8.1 \\
\bottomrule
\end{tabular}



In [11]:
print_latex_table(acc_df=acc_gl_df, alt_acc_df=alt_acc_gl_df, worst_acc_df=worst_acc_gl_df)

\begin{tabular}{llll}
\toprule
Method & Average Acc & Alternative Acc & Worst-Group Acc \\
\midrule
TopK 0.1 & 87.9 ± 0.7 & 89.7 ± 2.6 & 54.3 ± 8.8 \\
TopK 0.5 & 92.5 ± 0.9 & 94.8 ± 0.5 & 75.6 ± 7.0 \\
ERM & nan ± nan & nan ± nan & nan ± nan \\
DBAT & 89.6 ± 0.8 & 91.8 ± 0.7 & 50.3 ± 5.4 \\
DivDis & 91.7 ± 1.4 & 92.4 ± 0.9 & 66.9 ± 3.2 \\
\bottomrule
\end{tabular}

