In [1]:
import os
# Empty CUDA_VISIBLE_DEVICES: intended to keep this analysis kernel off the GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# set directory
# The repository root defaults to the original cluster location; set the
# DIVERSE_GEN_ROOT environment variable to run the notebook from another checkout.
os.chdir(os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen/"))

In [2]:
import json
from functools import partial
from itertools import product
from typing import Optional, Literal, Callable
from tqdm import tqdm
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from copy import deepcopy
from datetime import datetime
from collections import defaultdict

import submitit
from submitit.core.utils import CommandFunction
import nevergrad as ng
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go

from losses.loss_types import LossType
from utils.exp_utils import get_executor, get_executor_local, run_experiments
from utils.utils import conf_to_args

In [3]:
# Script name handed to run_experiments when the jobs are submitted.
SCRIPT_NAME = "spur_corr_exp.py"

# Root output directory for this sweep (relative to the working directory).
# It is created up front, parents included, and reused if it already exists.
EXP_DIR = Path("output") / "real_data_group_labels_exps"
EXP_DIR.mkdir(exist_ok=True, parents=True)

In [31]:
# Every (dataset, method) pair is run once per seed.
seeds = [1, 2, 3]

# TODO: add aux weights based on tuning
# Per-method overrides, keyed by the method name used to group results later on.
method_configs = {
    "DivDis": {
        "loss_type": LossType.DIVDIS,
        "aux_weight": 2.5,
    },
    "TopK 0.1": {
        "loss_type": LossType.TOPK,
        "mix_rate_lower_bound": 0.1,
        "aux_weight": 2.5,
    },
    "TopK 0.5": {
        "loss_type": LossType.TOPK,
        "mix_rate_lower_bound": 0.5,
        "aux_weight": 2.5,
    },
    "DBAT": {
        "loss_type": LossType.DBAT,
        "shared_backbone": False,
        "freeze_heads": True,
        "binary": True,
        "batch_size": 16,
        "target_batch_size": 32,
    },
}

# Per-dataset base settings; method settings are layered on top of these.
dataset_configs = {
    "waterbirds": {
        "dataset": "waterbirds",
        "model": "Resnet50",
        "epochs": 5,
        "source_cc": False,
        "use_group_labels": True,
    },
}

# One merged config per (dataset, method, seed). On a key clash the method
# config overrides the dataset config, and "seed" is always set last.
configs = {
    (ds_name, method_name, seed): {**ds_config, **method_config, "seed": seed}
    for ds_name, ds_config in dataset_configs.items()
    for method_name, method_config in method_configs.items()
    for seed in seeds
}

def get_conf_dir(ds_name, method_name, seed):
    """Return the output directory (as a string) for one run.

    The layout is ``<EXP_DIR>/<ds_name>_<method_name>/<seed>``.
    """
    run_group = f"{ds_name}_{method_name}"
    return f"{EXP_DIR}/{run_group}/{seed}"

# Attach the per-run output directory to every config and turn off activation plots.
# This mutates the dicts stored in `configs` in place.
for (ds_name, method_name, seed), conf in configs.items():
    exp_dir = get_conf_dir(ds_name, method_name, seed)
    conf.update({"exp_dir": exp_dir, "plot_activations": False})

# Run Experiments

In [9]:
# Create the executor and hand every config to run_experiments together with SCRIPT_NAME.
# NOTE(review): presumably get_executor returns a submitit executor that requests 16 GB of
# memory and uses EXP_DIR for its output, and run_experiments submits one job per config —
# confirm in utils.exp_utils.
executor = get_executor(EXP_DIR, mem_gb=16)
jobs = run_experiments(executor, list(configs.values()), SCRIPT_NAME)

# Process Results

In [58]:
# Spot-check a single DBAT run: take the first DBAT config and display the logged
# test-accuracy series of both heads.
# NOTE: get_exp_metrics is defined in a later cell of this notebook, so that cell
# must have been run before this one.
dbat_conf = [
    conf
    for conf in configs.values()
    if conf["loss_type"] == LossType.DBAT
][0]
exp_metrics = get_exp_metrics(dbat_conf)
# TODO: the DBAT results may be processed incorrectly. Possible causes: the validation
# loss is 0, leading to the wrong accuracy being selected, or there is a mismatch
# between the validation loss and the head index labels. Go back, check, and
# recompute the results.
(exp_metrics["test_acc_0"], exp_metrics["test_acc_1"])

([0.4295823276042938,
  0.4325163960456848,
  0.7816706895828247,
  0.7221263647079468,
  0.8864342570304871],
 [0.8983430862426758,
  0.9176734685897827,
  0.8935105204582214,
  0.8940283060073853,
  0.9038660526275635])

In [54]:
from typing import Literal
from pathlib import Path
def get_exp_metrics(conf: dict):
    """Load the metrics saved for one experiment.

    Args:
        conf: Experiment config; ``conf["exp_dir"]`` is the directory that
            should contain ``metrics.json``.

    Returns:
        The parsed contents of ``metrics.json``.

    Raises:
        FileNotFoundError: If ``metrics.json`` does not exist in the experiment directory.
    """
    metrics_path = Path(conf["exp_dir"]) / "metrics.json"
    if not metrics_path.exists():
        raise FileNotFoundError(f"Metrics file not found for experiment {conf['exp_dir']}")
    with metrics_path.open("r") as metrics_file:
        return json.load(metrics_file)

def get_max_acc(
    exp_metrics: dict,
    acc_metric: Literal["test_acc", "test_worst_acc", "test_acc_alt"]="test_acc",
    model_selection: Literal["acc", "loss", "weighted_loss", "repulsion_loss"]="acc", 
    head_idx: int = 0, 
    head_1_epochs: Optional[int] = None
):
    """Return the accuracy of one head at the entry chosen by ``model_selection``.

    Args:
        exp_metrics: Metrics dict; must contain the series ``f"{acc_metric}_{head_idx}"``
            and, for loss-based selection, the corresponding loss series.
        acc_metric: Prefix of the accuracy series to read.
        model_selection: How the entry is picked:
            ``"acc"`` takes the argmax of the accuracy series itself,
            ``"loss"`` the argmin of ``val_loss``,
            ``"weighted_loss"`` the argmin of ``val_weighted_loss``,
            ``"repulsion_loss"`` the argmin of ``target_val_weighted_repulsion_loss``.
        head_idx: Index of the head whose accuracy series is read.
        head_1_epochs: If given, the first ``head_1_epochs`` entries of the series
            that are used are dropped before selection.

    Returns:
        The selected accuracy (a numpy scalar).

    Raises:
        KeyError: If a required series is missing from ``exp_metrics``.
        ValueError: If ``model_selection`` is not recognised, or if no accuracy
            entries remain after dropping the first ``head_1_epochs``.
    """
    selection_keys = {
        "loss": "val_loss",
        "weighted_loss": "val_weighted_loss",
        "repulsion_loss": "target_val_weighted_repulsion_loss",
    }

    def _series(key):
        # Slice only the series that is actually needed. Slicing every value of
        # exp_metrics would fail on any entry that is not a sequence (e.g. a scalar).
        values = exp_metrics[key]
        return values if head_1_epochs is None else values[head_1_epochs:]

    acc_key = f'{acc_metric}_{head_idx}'
    accs = np.array(_series(acc_key))
    if accs.size == 0:
        raise ValueError(
            f"No entries left in '{acc_key}' after skipping the first {head_1_epochs} entries"
        )

    if model_selection == "acc":
        best_idx = np.argmax(accs)
    elif model_selection in selection_keys:
        best_idx = np.argmin(_series(selection_keys[model_selection]))
    else:
        raise ValueError(f"Invalid model selection: {model_selection}")
    return accs[best_idx]

# Result layout: get_acc_results returns a flat list with one selected accuracy per experiment config (configs without a metrics file are skipped); callers group these lists in a dict keyed by method name.
def get_acc_results(
    exp_configs: list[dict],
    acc_metric: Literal["test_acc", "test_worst_acc", "test_acc_alt"]="test_acc",
    model_selection: Literal["acc", "loss", "weighted_loss", "repulsion_loss"]="acc",
    verbose: bool=False, 
    head_idx: int= 0
):
    """Collect the selected accuracy of every experiment that has a metrics file.

    Args:
        exp_configs: Experiment configs; each needs ``"exp_dir"`` and ``"loss_type"``.
        acc_metric: Accuracy series prefix, forwarded to ``get_max_acc``.
        model_selection: Selection rule, forwarded to ``get_max_acc``.
        verbose: If true, print a message for every experiment whose metrics file is missing.
        head_idx: Head index, forwarded to ``get_max_acc``.

    Returns:
        A list of accuracies in the order of ``exp_configs``; experiments without
        a metrics file are skipped.
    """
    selected_accs = []
    for conf in exp_configs:
        try:
            run_metrics = get_exp_metrics(conf)
        except FileNotFoundError:
            if verbose:
                print(f"Metrics file not found for experiment {conf['exp_dir']}")
        else:
            # For DBAT runs, get_max_acc is told to skip the first 2 logged entries.
            skip = 2 if conf["loss_type"] == LossType.DBAT else None
            selected_accs.append(
                get_max_acc(run_metrics, acc_metric, model_selection, head_idx, skip)
            )
    return selected_accs

In [59]:
from collections import defaultdict
# Group every config of a method (all datasets and seeds) under its method name.
exps_by_method = defaultdict(list)
for (ds_name, method_name, seed), conf in configs.items():
    exps_by_method[method_name].append(conf)

# Head convention used below: for DBAT, head 1 is reported as the main head and
# head 0 as the alternative; for every other method it is the other way round.
# NOTE: model_selection="acc" makes get_max_acc take the argmax of the reported
# test series itself.

# Average accuracy of the main head.
results = {
    method_name: get_acc_results(
        method_exps,
        acc_metric="test_acc",
        model_selection="acc",
        verbose=True,
        head_idx=int(method_name == "DBAT"),
    )
    for method_name, method_exps in exps_by_method.items()
}

# Average accuracy of the alternative head.
results_alt = {
    method_name: get_acc_results(
        method_exps,
        acc_metric="test_acc",
        model_selection="acc",
        verbose=True,
        head_idx=int(method_name != "DBAT"),
    )
    for method_name, method_exps in exps_by_method.items()
}

# Worst-group accuracy of the main head.
results_worst = {
    method_name: get_acc_results(
        method_exps,
        acc_metric="test_worst_acc",
        model_selection="acc",
        verbose=True,
        head_idx=int(method_name == "DBAT"),
    )
    for method_name, method_exps in exps_by_method.items()
}

In [60]:
# Display the raw per-run accuracy lists (main head, alternative head, worst-group) for each method.
results, results_alt, results_worst

({'DivDis': [0.9207801222801208, 0.9190542101860046, 0.9328615665435791],
  'TopK 0.1': [0.8973075747489929, 0.8545046448707581, 0.8917846083641052],
  'TopK 0.5': [0.934242308139801, 0.9092164039611816, 0.9092164039611816],
  'DBAT': [0.9038660526275635, 0.9290645718574524, 0.9190542101860046]},
 {'DivDis': [0.9288919568061829, 0.9290645718574524, 0.9154297709465027],
  'TopK 0.1': [0.9057645797729492, 0.9335519671440125, 0.912150502204895],
  'TopK 0.5': [0.9537452459335327, 0.9376941919326782, 0.9473593235015869],
  'DBAT': [0.8864342570304871, 0.9223334193229675, 0.9192267656326294]},
 {'DivDis': [0.6806853413581848, 0.6495327353477478, 0.704049825668335],
  'TopK 0.1': [0.6510903239250183, 0.559190034866333, 0.722741425037384],
  'TopK 0.5': [0.8520249128341675, 0.7476718425750732, 0.8052959442138672],
  'DBAT': [0.8150776028633118, 0.8709534406661987, 0.8541020154953003]})

In [61]:
def _format_mean_std(values) -> str:
    """Format ``values`` as ``mean ± std`` in percent with one decimal place."""
    # np.std uses its default ddof=0 (population standard deviation).
    return f"{np.mean(values)*100:.1f} ± {np.std(values)*100:.1f}"


# Build all rows first and construct the frame once. Growing a DataFrame with
# pd.concat inside a loop copies the frame on every iteration and concatenates
# onto an empty frame.
summary_rows = [
    {
        'Method': method,
        'Average Acc': _format_mean_std(results[method]),
        'Alternative Acc': _format_mean_std(results_alt[method]),
        'Worst-Group Acc': _format_mean_std(results_worst[method]),
    }
    for method in results.keys()
]
df = pd.DataFrame(
    summary_rows,
    columns=['Method', 'Average Acc', 'Alternative Acc', 'Worst-Group Acc'],
)

# Print LaTeX table
print(df.to_latex(index=False, escape=True))

\begin{tabular}{llll}
\toprule
Method & Average Acc & Alternative Acc & Worst-Group Acc \\
\midrule
DivDis & 92.4 ± 0.6 & 92.4 ± 0.6 & 67.8 ± 2.2 \\
TopK 0.1 & 88.1 ± 1.9 & 91.7 ± 1.2 & 64.4 ± 6.7 \\
TopK 0.5 & 91.8 ± 1.2 & 94.6 ± 0.7 & 80.2 ± 4.3 \\
DBAT & 91.7 ± 1.0 & 90.9 ± 1.6 & 84.7 ± 2.3 \\
\bottomrule
\end{tabular}

