In [4]:
import os
import matplotlib 
# Run from the repository root so relative paths used later (e.g. EXP_DIR = "output/...")
# resolve against it. Set DIVERSE_GEN_ROOT to point at a checkout on another machine;
# without it, the original cluster location is used.
os.chdir(os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen/"))

In [5]:
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from itertools import product
import json
import copy

import numpy as np

from losses.loss_types import LossType
from utils.exp_utils import get_executor, run_experiments

In [6]:
# Sweep settings:
#   SCRIPT_NAME - script handed to run_experiments for every sampled config
#   EXP_DIR     - sweep output root, relative to the working directory set above
#   n_trials    - number of random configs drawn per (dataset, method) pair
SCRIPT_NAME: str = "spur_corr_exp.py"
EXP_DIR: str = "output/cc_dbat_aux_weight_sweep"
n_trials: int = 32

In [12]:
# Method-specific settings, keyed by method name. Each entry is merged on top of every
# dataset config ({**ds_config, **method_config}), so these values win on key clashes.
method_configs = {
    "DBAT": dict(
        loss_type=LossType.DBAT,
        shared_backbone=False,
        freeze_heads=True,
        binary=True,
    ),
}

# Per-dataset training settings, keyed by dataset name (each entry's "dataset" field
# repeats its key). batch_size / target_batch_size are the per-dataset base values.
dataset_configs = {
    "toy_grid": dict(
        dataset="toy_grid", model="toy_model", epochs=100,
        batch_size=32, target_batch_size=128, lr=1e-3, optimizer="sgd",
    ),
    "fmnist_mnist": dict(dataset="fmnist_mnist", model="Resnet50", epochs=5, batch_size=32, target_batch_size=64),
    "cifar_mnist": dict(dataset="cifar_mnist", model="Resnet50", epochs=5, batch_size=32, target_batch_size=64),
    "waterbirds": dict(dataset="waterbirds", model="Resnet50", epochs=5, batch_size=32, target_batch_size=64),
    "celebA-0": dict(dataset="celebA-0", model="Resnet50", epochs=2, batch_size=32, target_batch_size=64),
    # Currently disabled; uncomment to include it in the sweep.
    # "multi-nli": dict(dataset="multi-nli", model="bert", epochs=1, lr=1e-5, combine_neut_entail=True, contra_no_neg=True),
}

# Search ranges per method. The aux_weight bounds are base-10 exponents: each trial
# draws aux_weight = 10 ** uniform(lo, hi), i.e. [-1, 1] spans 0.1 .. 10 log-uniformly.
method_ranges = {
    "DBAT": dict(aux_weight=[-1, 1]),
}

# Build one base config per (dataset, method) pair. Method settings are merged on top of
# the dataset settings, so they win on key clashes.
# Batch sizes from the dataset config are then halved.
# NOTE(review): presumably to fit DBAT's non-shared backbones in memory — confirm.
configs = {}
for (ds_name, ds_config), (method_name, method_config) in product(dataset_configs.items(), method_configs.items()):
    merged = {**ds_config, **method_config}
    # Halve each batch-size key independently. A dataset that defines only one of them
    # no longer raises KeyError on the other.
    for batch_key in ("batch_size", "target_batch_size"):
        if batch_key in ds_config:
            merged[batch_key] = int(ds_config[batch_key] / 2)
    configs[(ds_name, method_name)] = merged



def get_conf_exp_dir(ds_name, method_name, i, base_dir=None):
    """Return the output directory for trial ``i`` of a (dataset, method) sweep.

    The layout is ``<base_dir>/<ds_name>_<method_name>/<i>``. The directory is not
    created here.

    Args:
        ds_name: Dataset key, e.g. "waterbirds".
        method_name: Method key, e.g. "DBAT".
        i: Trial index within the sweep.
        base_dir: Sweep root directory. Defaults to the notebook-level ``EXP_DIR``.

    Returns:
        pathlib.Path of the trial directory.
    """
    root = EXP_DIR if base_dir is None else base_dir
    # Join components explicitly rather than embedding "/" inside an f-string.
    return Path(root, f"{ds_name}_{method_name}", str(i))

# Draw n_trials random configs per (dataset, method) pair.
# aux_weight is sampled log-uniformly: the method range holds base-10 exponents.
# A fixed-seed local generator makes the sweep reproducible across kernel restarts and
# leaves the global np.random state untouched.
SAMPLE_SEED = 0  # change to draw a different sweep
sample_rng = np.random.default_rng(SAMPLE_SEED)

sampled_configs = []
for (ds_name, method_name), conf in configs.items():
    # Loop-invariant: the exponent range depends only on the method.
    aux_range = method_ranges[method_name]["aux_weight"]
    for i in range(n_trials):
        sample_conf = copy.deepcopy(conf)
        # Cast to builtin float/int: Generator returns numpy scalars, and np.int64 is not
        # JSON-serializable. NOTE(review): presumably configs are serialized downstream — confirm.
        sample_conf["aux_weight"] = float(10 ** sample_rng.uniform(aux_range[0], aux_range[1]))
        sample_conf["seed"] = int(sample_rng.integers(0, 1000000))
        sample_conf["exp_dir"] = get_conf_exp_dir(ds_name, method_name, i)
        sample_conf["plot_activations"] = False
        sampled_configs.append(sample_conf)

In [13]:
# Datasets whose trials are submitted with the larger memory request. A single constant
# keeps the two partitions complementary: every sampled config lands in exactly one list.
HIGH_MEM_DATASETS = {"celebA-0", "multi-nli"}
low_mem_configs = [conf for conf in sampled_configs if conf["dataset"] not in HIGH_MEM_DATASETS]
high_mem_configs = [conf for conf in sampled_configs if conf["dataset"] in HIGH_MEM_DATASETS]


In [18]:
# Submit the low-memory trials with mem_gb=16.
# Keep dedicated handles, because the high-memory submission reassigns `executor` / `jobs`.
low_mem_executor = get_executor(EXP_DIR, mem_gb=16)
low_mem_jobs = run_experiments(low_mem_executor, low_mem_configs, SCRIPT_NAME)
executor, jobs = low_mem_executor, low_mem_jobs  # original names, kept for later cells


In [19]:
# Submit the high-memory trials with mem_gb=32.
# Keep dedicated handles so these jobs remain reachable even if `jobs` is reused later.
high_mem_executor = get_executor(EXP_DIR, mem_gb=32)
high_mem_jobs = run_experiments(high_mem_executor, high_mem_configs, SCRIPT_NAME)
executor, jobs = high_mem_executor, high_mem_jobs  # original names, kept for later cells
