In [2]:
# set cuda visible devices
def is_notebook() -> bool:
    """Report whether this code is executing inside a Jupyter kernel.

    Returns:
        True only when ``get_ipython()`` exists and its shell class is named
        ``ZMQInteractiveShell`` (Jupyter notebook or qtconsole). Any other
        shell class, including ``TerminalInteractiveShell``, gives False, as
        does a plain Python interpreter where ``get_ipython`` is undefined.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython is only injected by IPython; its absence means plain Python.
        return False
    return shell_name == 'ZMQInteractiveShell'

import os
# Only inside a notebook kernel: set CUDA_VISIBLE_DEVICES to the empty string
# (the commented-out "1" is an alternative value kept for reference).
# NOTE(review): presumably this keeps the notebook kernel itself off the GPUs
# while the experiments run elsewhere — confirm. It has to happen before any
# CUDA-using library is imported, so keep this cell first.
if is_notebook():
    os.environ["CUDA_VISIBLE_DEVICES"] = "" #"1"
    # os.environ['CUDA_LAUNCH_BLOCKING']="1"
    # os.environ['TORCH_USE_CUDA_DSA'] = "1"

import matplotlib 
# Outside a notebook, select matplotlib's 'Agg' backend. This is done before
# matplotlib.pyplot is imported further down.
if not is_notebook():
    matplotlib.use('Agg')

# set directory
# NOTE(review): absolute, machine-specific path. The relative paths used later
# (e.g. EXP_DIR = "output/real_data_exps") resolve against this directory, so
# the notebook only runs unchanged on a host where this path exists.
os.chdir("/nas/ucb/oliveradk/diverse-gen/")

In [3]:
import json
from functools import partial
from itertools import product
from typing import Optional, Literal, Callable
from tqdm import tqdm
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from copy import deepcopy
from datetime import datetime
from collections import defaultdict

import submitit
from submitit.core.utils import CommandFunction
import nevergrad as ng
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go

from losses.loss_types import LossType
from utils.exp_utils import get_executor, get_executor_local, run_experiments
from utils.utils import conf_to_args

In [4]:
# Script that is handed to run_experiments for every config.
SCRIPT_NAME = "spur_corr_exp.py"

# Root directory (relative to the working directory) for all run outputs;
# created up front, including parents, and left alone if it already exists.
EXP_DIR = Path("output") / "real_data_exps"
EXP_DIR.mkdir(exist_ok=True, parents=True)

In [5]:
# Every (dataset, method) pair is run once per seed.
seeds = [1, 2, 3]

# Per-method overrides, keyed by the display name used in the result tables.
# TODO: add aux weights based on tuning
method_configs = {
    "DivDis": dict(loss_type=LossType.DIVDIS, aux_weight=3),
    "TopK 0.1": dict(loss_type=LossType.TOPK, mix_rate_lower_bound=0.1, aux_weight=1),
    "TopK 0.5": dict(loss_type=LossType.TOPK, mix_rate_lower_bound=0.5, aux_weight=3),
    "DBAT": dict(
        loss_type=LossType.DBAT,
        shared_backbone=False,
        freeze_heads=True,
        binary=True,
        batch_size=16,
        target_batch_size=32,
        aux_weight=0.5,
    ),
    "ERM": dict(loss_type=LossType.ERM, aux_weight=0),
}

# Per-dataset settings shared by every method.
dataset_configs = {
    "waterbirds": dict(dataset="waterbirds", model="Resnet50", epochs=5, source_cc=False),
}

# Full grid keyed by (dataset, method, seed). Method settings are merged over
# the dataset settings, so a method value wins when both define the same key.
configs = {
    (ds_name, method_name, seed): {**ds_config, **method_config, "seed": seed}
    for ds_name, ds_config in dataset_configs.items()
    for method_name, method_config in method_configs.items()
    for seed in seeds
}

def get_conf_dir(conf_name: tuple):
    """Build the output directory path for one run.

    Args:
        conf_name: ``(dataset, method, seed)`` key as used in ``configs``.

    Returns:
        The string ``"<EXP_DIR>/<dataset>_<method>/<seed>"``. The method name
        is inserted verbatim, so names such as "TopK 0.1" yield directory
        names containing a space.
    """
    dataset, method, seed = conf_name
    run_group = f"{dataset}_{method}"
    # TODO (carried over from the original): an EXP_DIR-less variant,
    # f"{ds}_{method}/{seed}", was marked "fix" — intent not visible here.
    return f"{EXP_DIR}/{run_group}/{seed}"

# Store each run's output directory on its own config dict (in place);
# get_exp_metrics later reads conf["exp_dir"] to locate metrics.json.
for conf_key, conf in configs.items():
    ds_name, method_name, seed = conf_key
    exp_dir = get_conf_dir(conf_key)
    conf["exp_dir"] = exp_dir

In [11]:
# Sanity check on the grid size: configs is keyed by unique
# (dataset, method, seed) tuples, so this equals
# len(dataset_configs) * len(method_configs) * len(seeds).
len(configs)

15

# Run Experiments

In [28]:
# Build an executor for EXP_DIR with mem_gb=16, then hand every config dict
# plus SCRIPT_NAME to run_experiments.
# NOTE(review): get_executor / run_experiments live in utils.exp_utils and are
# not shown here — presumably this submits one job per config via submitit;
# confirm before re-running, since re-executing this cell would launch the
# whole grid again.
executor = get_executor(EXP_DIR, mem_gb=16)
jobs = run_experiments(executor, list(configs.values()), SCRIPT_NAME)

# Plot Results

In [12]:
from typing import Literal
from pathlib import Path
def get_exp_metrics(conf: dict):
    """Load the metrics saved for a single experiment run.

    Args:
        conf: Run config; only ``conf["exp_dir"]`` is read.

    Returns:
        The parsed JSON content of ``<exp_dir>/metrics.json``.

    Raises:
        FileNotFoundError: If ``<exp_dir>/metrics.json`` does not exist.
    """
    metrics_path = Path(conf["exp_dir"]) / "metrics.json"
    if metrics_path.exists():
        with open(metrics_path, "r") as metrics_file:
            return json.load(metrics_file)
    raise FileNotFoundError(f"Metrics file not found for experiment {conf['exp_dir']}")

def get_max_acc(
    exp_metrics: dict,
    acc_metric: Literal["test_acc", "test_worst_acc", "test_acc_alt"]="test_acc",
    model_selection: Literal["acc", "loss", "weighted_loss", "repulsion_loss"]="acc"
):
    """Pick one accuracy value out of a run's metric history.

    The series ``<acc_metric>_0`` and ``<acc_metric>_1`` are combined with an
    elementwise maximum, and the entry at a selected index is returned.
    (Presumably the two series are the two heads and each entry is one
    epoch — confirm against the training script.)

    Args:
        exp_metrics: Metrics dict as loaded from ``metrics.json``.
        acc_metric: Prefix of the two accuracy series to combine.
        model_selection: How the index is chosen:
            ``"acc"`` takes the argmax of the combined accuracy itself, i.e.
            selection uses the same metric that is reported;
            ``"loss"`` takes the argmin of ``val_loss``;
            ``"weighted_loss"`` takes the argmin of ``val_weighted_loss``;
            ``"repulsion_loss"`` takes the argmin of
            ``target_val_weighted_repulsion_loss``.

    Returns:
        The combined accuracy at the selected index.

    Raises:
        ValueError: If ``model_selection`` is not one of the options above.
    """
    # Loss-based selection modes and the metric series each one minimises.
    loss_series_for = {
        "loss": "val_loss",
        "weighted_loss": "val_weighted_loss",
        "repulsion_loss": "target_val_weighted_repulsion_loss",
    }

    series_0 = np.array(exp_metrics[f'{acc_metric}_0'])
    series_1 = np.array(exp_metrics[f'{acc_metric}_1'])
    combined_accs = np.maximum(series_0, series_1)

    if model_selection == "acc":
        chosen_idx = np.argmax(combined_accs)
    elif model_selection in loss_series_for:
        chosen_idx = np.argmin(exp_metrics[loss_series_for[model_selection]])
    else:
        raise ValueError(f"Invalid model selection: {model_selection}")
    return combined_accs[chosen_idx]

def get_acc_results(
    exp_configs: list[dict],
    acc_metric: Literal["test_acc", "test_worst_acc", "test_acc_alt"]="test_acc",
    model_selection: Literal["acc", "loss", "weighted_loss", "repulsion_loss"]="acc",
    verbose: bool=False
):
    """Collect the selected accuracy of every run that has produced metrics.

    Args:
        exp_configs: Run configs; each needs an ``"exp_dir"`` entry.
        acc_metric: Forwarded to ``get_max_acc``.
        model_selection: Forwarded to ``get_max_acc``.
        verbose: When True, print a line for each run whose metrics file is
            missing.

    Returns:
        A flat list with one accuracy per run, in the order of
        ``exp_configs``. Runs without a metrics file are skipped, so the list
        can be shorter than ``exp_configs`` (or empty).
    """
    selected_accs = []
    for run_conf in exp_configs:
        try:
            run_metrics = get_exp_metrics(run_conf)
            run_acc = get_max_acc(run_metrics, acc_metric, model_selection)
        except FileNotFoundError:
            # Missing metrics are tolerated so partially finished sweeps can
            # still be summarised.
            if verbose:
                print(f"Metrics file not found for experiment {run_conf['exp_dir']}")
        else:
            selected_accs.append(run_acc)
    return selected_accs

In [13]:
from collections import defaultdict
# Group the run configs by method name so each method's runs are aggregated together.
exps_by_method = defaultdict(list)
for (ds_name, method_name, seed), conf in configs.items():
    exps_by_method[method_name].append(conf)


def _collect_by_method(acc_metric):
    """Return {method name: list of accuracies} for ``acc_metric``.

    Uses model_selection="acc" for every method, i.e. the reported value is
    the best combined accuracy found in each run's history.
    """
    collected = {}
    for name, method_exps in exps_by_method.items():
        collected[name] = get_acc_results(method_exps, model_selection="acc", acc_metric=acc_metric)
    return collected


results = _collect_by_method("test_acc")
results_alt = _collect_by_method("test_acc_alt")
results_worst = _collect_by_method("test_worst_acc")

In [14]:
def print_stats(results_dict):
    """Compute mean and standard deviation of the accuracies for each method.

    Despite its name this function prints nothing; callers format the
    returned values themselves.

    Args:
        results_dict: Mapping of method name to a sequence of accuracies, as
            produced by ``get_acc_results``. A sequence may be empty when none
            of a method's runs has written its metrics file yet.

    Returns:
        ``{method: {'mean': ..., 'std': ...}}``. The standard deviation is
        numpy's default population form (ddof=0). For an empty sequence both
        values are NaN.
    """
    stats = {}
    for method, values in results_dict.items():
        if len(values) == 0:
            # np.mean([]) / np.std([]) also give NaN but emit RuntimeWarnings;
            # handle the "no finished runs" case explicitly and quietly.
            stats[method] = {'mean': np.nan, 'std': np.nan}
            continue
        stats[method] = {'mean': np.mean(values), 'std': np.std(values)}
    return stats

# One section per accuracy metric: a heading followed by "mean ± std" per method.
report_sections = (
    ("Average Accuracy:", results),
    ("\nAlternative Accuracy:", results_alt),
    ("\nWorst-Group Accuracy:", results_worst),
)
for heading, section_results in report_sections:
    print(heading)
    for method, stats in print_stats(section_results).items():
        print(f"{method:10}: {stats['mean']:.3f} ± {stats['std']:.3f}")

Average Accuracy:
DivDis    : 0.929 ± 0.010
TopK 0.1  : 0.897 ± 0.007
TopK 0.5  : 0.906 ± 0.016
DBAT      : 0.866 ± 0.009
ERM       : 0.891 ± 0.012

Alternative Accuracy:
DivDis    : 0.804 ± 0.052
TopK 0.1  : 0.732 ± 0.019
TopK 0.5  : 0.925 ± 0.003
DBAT      : 0.707 ± 0.009
ERM       : 0.638 ± 0.014

Worst-Group Accuracy:
DivDis    : 0.739 ± 0.047
TopK 0.1  : 0.718 ± 0.058
TopK 0.5  : 0.789 ± 0.042
DBAT      : 0.625 ± 0.020
ERM       : 0.630 ± 0.047


In [25]:
# Display the raw per-run accuracies for each method: average, alternative
# and worst-group, in that order.
# NOTE(review): the saved output of this cell lists no ERM entries although
# the printed summary includes ERM — the output looks stale; re-run after
# the aggregation cell to refresh it.
results, results_alt, results_worst

({'DivDis': [0.9373489618301392, 0.9344149231910706, 0.9152571558952332],
  'TopK 0.1': [0.9036934971809387, 0.900414228439331, 0.8874697685241699],
  'TopK 0.5': [0.8929927349090576, 0.8962720036506653, 0.9285467863082886],
  'DBAT': [0.8555402159690857, 0.8765964508056641, 0.8669313192367554]},
 {'DivDis': [0.7402485609054565, 0.8034173250198364, 0.867966890335083],
  'TopK 0.1': [0.7323092818260193, 0.7556092739105225, 0.7093545198440552],
  'TopK 0.5': [0.9221608638763428, 0.9299275279045105, 0.9244045615196228],
  'DBAT': [0.7138419151306152, 0.714359700679779, 0.6939938068389893]},
 {'DivDis': [0.7663551568984985, 0.777258574962616, 0.672897219657898],
  'TopK 0.1': [0.781931459903717, 0.7303769588470459, 0.6417445540428162],
  'TopK 0.5': [0.8302180767059326, 0.7305296063423157, 0.8052959442138672],
  'DBAT': [0.6536585092544556, 0.6106430292129517, 0.6097561120986938]})