In [None]:
# Notebook detection: used below to choose CUDA visible devices and the matplotlib backend
def is_notebook() -> bool:
    """Report whether this code is executing inside a Jupyter kernel.

    Returns True only for the ZMQ-based shell (Jupyter notebook or qtconsole).
    Terminal IPython, any other IPython shell type, and a plain Python
    interpreter (where ``get_ipython`` is not defined) all yield False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # get_ipython is undefined outside IPython, e.g. under plain `python script.py`.
        return False
    return shell_name == 'ZMQInteractiveShell'

import os

# Evaluate once: the result drives both the GPU selection and the plotting backend below.
IN_NOTEBOOK = is_notebook()

if IN_NOTEBOOK:
    # An empty value exposes no CUDA devices to this interactive session;
    # set e.g. "1" to make a specific GPU visible.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # Debugging switches, enable as needed:
    # os.environ['CUDA_LAUNCH_BLOCKING']="1"
    # os.environ['TORCH_USE_CUDA_DSA'] = "1"

import matplotlib
if not IN_NOTEBOOK:
    # Non-interactive backend for headless (script / batch) execution.
    matplotlib.use('Agg')

# Work from the project directory so relative paths used later (SCRIPT_NAME, output/...)
# resolve. Override with the DIVERSE_GEN_ROOT environment variable on other machines;
# the default keeps the original location.
os.chdir(os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen/"))


In [None]:
import json
from functools import partial
from itertools import product
from typing import Optional, Literal, Callable
from tqdm import tqdm
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from copy import deepcopy
from datetime import datetime

import submitit
from submitit.core.utils import CommandFunction
import nevergrad as ng
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go

from losses.loss_types import LossType
from utils.exp_utils import get_executor, get_executor_local, run_experiments
from utils.utils import conf_to_args

In [None]:
# Experiment entry point passed to run_experiments for every config.
SCRIPT_NAME = "spur_corr_exp.py"
# Root of all sweep outputs; relative, so it resolves against the working directory.
EXP_DIR = Path("output") / "cc_mix_rate_sweep"
EXP_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# Shared training hyperparameters (the defaults for most datasets below).
BATCH_SIZE = 32
TARGET_BATCH_SIZE = 64
EPOCHS = 5
LR = 1e-4
OPTIMIZER = "adamw"

# Seeds: every (dataset, method, mix_rate) combination is launched once per seed.
seeds = [1, 2, 3]

# Mix rates swept over (stored in each config under "mix_rate").
mix_rates = [0.1, 0.25, 0.5, 0.75, 1.0]

# Methods: display name -> method-specific config overrides.
# The three TopK variants differ only in their mix-rate lower bound.
methods = {
    "DivDis": {"loss_type": LossType.DIVDIS},
    **{
        f"TopK {lower_bound}": {"loss_type": LossType.TOPK, "mix_rate_lower_bound": lower_bound}
        for lower_bound in (0.1, 0.5, 1.0)
    },
    "DBAT": {"loss_type": LossType.DBAT, "shared_backbone": False, "freeze_heads": True},
    "ERM": {"loss_type": LossType.ERM},
}
# Datasets: name -> dataset-level training settings.
# fmnist_mnist, cifar_mnist, waterbirds and celebA-0 use the shared defaults and differ only by name.
datasets = {
    "toy_grid": {
        "dataset": "toy_grid", "model": "toy_model", "epochs": 100, "lr": 1e-3,
        "optimizer": "sgd", "batch_size": BATCH_SIZE, "target_batch_size": 128,
    },
    **{
        ds_key: {
            "dataset": ds_key, "epochs": EPOCHS, "lr": LR, "optimizer": OPTIMIZER,
            "batch_size": BATCH_SIZE, "target_batch_size": TARGET_BATCH_SIZE,
        }
        for ds_key in ("fmnist_mnist", "cifar_mnist", "waterbirds", "celebA-0")
    },
    "multi-nli": {
        "dataset": "multi-nli", "model": "bert", "epochs": 2, "lr": 1e-5,
        "optimizer": OPTIMIZER, "lr_scheduler": "cosine",
        "batch_size": BATCH_SIZE, "target_batch_size": TARGET_BATCH_SIZE,
        "combine_neut_entail": True, "contra_no_neg": True,
    },
}

# Auxiliary-loss weight per (method, dataset), tuned according to total validation loss.
# Each row lists one method's weights in the dataset column order given to zip() below.
aux_weight_map = {
    method_label: dict(zip(
        ("toy_grid", "fmnist_mnist", "cifar_mnist", "waterbirds", "celebA-0", "multi-nli"),
        row,
    ))
    for method_label, row in [
        #             toy  fmnist cifar  wbirds celebA mnli
        ("TopK 0.1", (1.5, 1.0,   1.5,   8,     2.5,   16)),
        ("TopK 0.5", (1.0, 1.0,   1.0,   3.0,   1.5,   1.0)),
        ("TopK 1.0", (1.0, 1.0,   1.0,   5.0,   2.5,   16)),
        ("DivDis",   (1.0, 1.5,   1.5,   6,     1.5,   64)),
        ("DBAT",     (1.0, 1.0,   1.0,   1.0,   1.0,   1.0)),
        ("ERM",      (1.0, 1.0,   1.0,   1.0,   1.0,   1.0)),
    ]
}

# One config per (dataset, method, mix_rate, seed) combination. Dataset settings come first,
# and method settings override them on key collisions. The sweep values and the tuned
# auxiliary-loss weight are added last. Building each config in a single comprehension
# avoids a second pass that leaked its loop variables into the notebook namespace.
configs = {
    (ds_name, method_name, mix_rate, seed): {
        **ds,
        **method,
        "mix_rate": mix_rate,
        "seed": seed,
        "aux_weight": aux_weight_map[method_name][ds_name],
    }
    for (ds_name, ds), (method_name, method), mix_rate, seed
    in product(datasets.items(), methods.items(), mix_rates, seeds)
}

# DBAT runs use half of both batch_size and target_batch_size.
# NOTE(review): presumably to offset DBAT's non-shared backbones (shared_backbone=False) — confirm.
# Integer floor division keeps the sizes ints without a float round-trip.
for conf in configs.values():
    if conf["loss_type"] == LossType.DBAT:
        conf["batch_size"] //= 2
        conf["target_batch_size"] //= 2

def get_conf_dir(conf_name: tuple) -> str:
    """Build the relative output directory for one sweep configuration.

    Args:
        conf_name: A 4-item ``(dataset, method, mix_rate, seed)`` key.

    Returns:
        ``"<dataset>_<method>_<mix_rate>/<seed>"``, so all seeds of a combination
        share a parent folder. Method names may contain spaces (e.g. "TopK 0.1"),
        and those are kept verbatim.

    Raises:
        ValueError: if ``conf_name`` does not have exactly four items.
    """
    dataset_name, method_name, rate, run_seed = conf_name
    return "{}_{}_{}/{}".format(dataset_name, method_name, rate, run_seed)

# Give every run its own output directory below EXP_DIR.
for conf_name, conf in configs.items():
    conf.update(exp_dir=EXP_DIR.joinpath(get_conf_dir(conf_name)))


In [None]:
# Datasets whose runs are submitted with the larger memory request (mem_gb=32 below).
high_mem_ds = ["multi-nli", "celebA-0"]
high_mem_configs = {name: cfg for name, cfg in configs.items() if cfg["dataset"] in high_mem_ds}
# Everything else goes to the smaller request; keys are unique, so this is the exact complement.
low_mem_configs = {name: cfg for name, cfg in configs.items() if name not in high_mem_configs}


In [None]:
# Submit the low-memory sweep (mem_gb=16). The handles are also kept under their own name,
# because the high-memory cell reassigns `jobs` and would otherwise lose these.
executor = get_executor(EXP_DIR, mem_gb=16)
low_mem_jobs = jobs = run_experiments(executor, list(low_mem_configs.values()), SCRIPT_NAME)


In [None]:
# Submit the high-memory sweep (mem_gb=32). The handles are kept under their own name
# so they are not confused with, or clobbered by, the low-memory submission's `jobs`.
executor = get_executor(EXP_DIR, mem_gb=32)
high_mem_jobs = jobs = run_experiments(executor, list(high_mem_configs.values()), SCRIPT_NAME)