In [None]:
import os

# Launcher setup.
# - CUDA_VISIBLE_DEVICES="" exposes no CUDA devices to this kernel.
# - chdir to the repo root so relative paths used below (exp_scripts/..., output/...) resolve.
# The root defaults to the original cluster path but can be overridden via DIVERSE_GEN_ROOT,
# so the notebook is usable from other checkouts without editing this cell.
REPO_ROOT = os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen/")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.chdir(REPO_ROOT)

In [None]:
import json
from functools import partial
from itertools import product
from typing import Optional, Literal, Callable
from tqdm import tqdm
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from copy import deepcopy
from datetime import datetime
from collections import defaultdict

import submitit
from submitit.core.utils import CommandFunction
import numpy as np

from losses.loss_types import LossType
from utils.exp_utils import get_executor, get_executor_local, run_experiments, get_conf_dir
from utils.proc_data_utils import get_exp_metrics, get_max_acc, get_acc_results
from utils.utils import conf_to_args

In [None]:
# Experiment outputs are written to output/cc_mix_rate/<SUB_DIR>.
# Leave SUB_DIR as None to start a fresh timestamped run directory, or set it to a
# directory name to reuse an existing run.
SCRIPT_NAME = "exp_scripts/spur_corr_exp.py"
SUB_DIR = None
EXP_DIR = Path("output/cc_mix_rate")
SUB_DIR = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") if SUB_DIR is None else SUB_DIR
EXP_DIR = EXP_DIR / SUB_DIR
EXP_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# Shared training hyperparameters; individual dataset entries below may override them
# (e.g. toy_grid trains for 100 epochs, multi-nli uses its own lr).
BATCH_SIZE, TARGET_BATCH_SIZE = 32, 64
EPOCHS = 5
LR, OPTIMIZER = 1e-4, "adamw"

# Each (dataset, method, mix rate) combination is run once per seed.
SEEDS = [1, 2, 3]

# Mix rates swept for every dataset/method pair.
MIX_RATES = [0.1, 0.25, 0.5, 0.75, 1.0]

# Methods to compare. Each entry's keys are merged into every config built for that method.
methods = {
    "DivDis": {"loss_type": LossType.DIVDIS},
    # the two TopK variants differ only in mix_rate_lower_bound
    **{
        name: {
            "loss_type": LossType.TOPK,
            "mix_rate_lower_bound": lower_bound,
            "mix_rate_schedule": "linear",
        }
        for name, lower_bound in (("TopK 0.1", 0.1), ("TopK 0.5", 0.5))
    },
    "DBAT": {
        "loss_type": LossType.DBAT,
        "shared_backbone": False,
        "freeze_heads": True,
        "binary": True,
    },
    "ERM": {"loss_type": LossType.ERM},
}
# Per-dataset settings. Each entry's keys are merged into every config built for that dataset.
datasets = {
    "toy_grid": {
        "dataset": "toy_grid",
        "model": "toy_model",
        "epochs": 100,
        "batch_size": BATCH_SIZE,
        "target_batch_size": 128,
        "plot_activations": False,
    },
    # these four datasets use the shared hyperparameters unchanged
    **{
        name: {
            "dataset": name,
            "epochs": EPOCHS,
            "lr": LR,
            "optimizer": OPTIMIZER,
            "batch_size": BATCH_SIZE,
            "target_batch_size": TARGET_BATCH_SIZE,
        }
        for name in ("fmnist_mnist", "cifar_mnist", "waterbirds", "celebA-0")
    },
    "multi-nli": {
        "dataset": "multi-nli",
        "model": "bert",
        "epochs": 2,
        "lr": 1e-5,
        "optimizer": OPTIMIZER,
        "lr_scheduler": "cosine",
        "batch_size": BATCH_SIZE,
        "target_batch_size": TARGET_BATCH_SIZE,
        "combine_neut_entail": True,
        "contra_no_neg": False,
    },
}

# Auxiliary-loss weight per method and dataset, tuned according to total validation loss.
# Laid out as a table: one row per method, columns in the dataset order listed in the zip.
# ("TopK 1.0" has no entry in `methods`, so its row is currently unused.)
aux_weight_map = {
    method_name: dict(zip(
        ("toy_grid", "fmnist_mnist", "cifar_mnist", "waterbirds", "celebA-0", "multi-nli"),
        weights,
    ))
    for method_name, weights in (
        #              toy   fmnist cifar  wbirds celebA mnli
        ("TopK 0.1",  (1.5,  1.0,   1.5,   8,     2.5,   6)),
        ("TopK 0.5",  (1.0,  1.0,   1.0,   3.0,   1.5,   6)),
        ("TopK 1.0",  (1.0,  1.0,   1.0,   5.0,   2.5,   16)),
        ("DivDis",    (1.0,  1.5,   1.5,   6,     1.5,   64)),
        ("DBAT",      (1.0,  1.0,   1.0,   1.0,   1.0,   0.1)),
        ("ERM",       (1.0,  1.0,   1.0,   1.0,   1.0,   1.0)),
    )
}

# (lr, optimizer) overrides for specific method/dataset pairs; all other pairs keep
# the lr/optimizer from their dataset entry (if any).
lr_optim_map = {
    "TopK 0.1": {"toy_grid": (1e-3, "sgd")},
    "DivDis": {"toy_grid": (1e-3, "adamw")},
}

# One config per (dataset, method, mix rate, seed) combination, keyed by that tuple.
# Merge order is dataset, then method, then mix_rate/seed, so later keys win on overlap.
configs = {
    (ds_name, method_name, mix_rate, seed): {**ds_conf, **method_conf, "mix_rate": mix_rate, "seed": seed}
    for (ds_name, ds_conf), (method_name, method_conf), mix_rate, seed in product(
        datasets.items(), methods.items(), MIX_RATES, SEEDS
    )
    if (ds_name, method_name) != ("multi-nli", "ERM")
}
# multi-nli ERM is added once per seed rather than once per mix rate. Its key uses 0.0 as the
# mix rate, but the config itself carries no "mix_rate" entry.
# NOTE(review): presumably the script's default mix rate applies to these runs — TODO confirm.
if "multi-nli" in datasets:  # TODO: remove this
    for seed in SEEDS:
        configs[("multi-nli", "ERM", 0.0, seed)] = {**datasets["multi-nli"], **methods["ERM"], "seed": seed}

###  dataset x method adjustments
# aux weight: looked up per (method, dataset); a missing aux_weight_map entry raises KeyError
for (ds_name, method_name, mix_rate, seed), conf in configs.items():
    conf["aux_weight"] = aux_weight_map[method_name][ds_name]
# optimizer and lr: overridden only for (method, dataset) pairs listed in lr_optim_map
for (ds_name, method_name, mix_rate, seed), conf in configs.items():
    if ds_name not in lr_optim_map.get(method_name, ()):
        continue
    conf["lr"], conf["optimizer"] = lr_optim_map[method_name][ds_name]


# DBAT configs run with half of their batch_size and target_batch_size.
# NOTE(review): presumably to offset DBAT's unshared backbones (shared_backbone=False) — confirm.
# This mutates configs in place, so re-running it on its own would halve the sizes again.
for conf in configs.values():
    if conf["loss_type"] != LossType.DBAT:
        continue
    for size_key in ("batch_size", "target_batch_size"):
        conf[size_key] = int(conf[size_key] / 2)


# Give every config its own output directory, computed by get_conf_dir from the config's
# key tuple and EXP_DIR.
for conf_name in configs:
    configs[conf_name]["exp_dir"] = get_conf_dir(conf_name, EXP_DIR)