In [1]:
import os 
os.chdir("/nas/ucb/oliveradk/diverse-gen/")

In [2]:
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from itertools import product
import json
import optuna

from submitit.helpers import CommandFunction

from losses.loss_types import LossType
from utils.exp_utils import get_executor, get_conf_dir, run_experiments
from utils.utils import conf_to_args
from run_study import get_storage_path


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
N_TRIALS = 32
N_STARTUP_TRIALS = 8
NODES_PER_STUDY = 8
STUDY_SCRIPT_NAME = "run_study.py"

SCRIPT_NAME = "spur_corr_exp.py"
HPARM_PARENT_DIR = Path("output/cc_aux_weight_sweep")
HPARAM_DIR_NAME = "2025-01-31_14-53-41"

if HPARAM_DIR_NAME is None:
    hparam_dir_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    hparam_dir = Path(HPARM_PARENT_DIR, hparam_dir_name)
    hparam_dir.mkdir(exist_ok=True, parents=True)
else:
    hparam_dir = Path(HPARM_PARENT_DIR, HPARAM_DIR_NAME)

### Configs
loss_configs = {
    "DivDis": {"loss_type": LossType.DIVDIS},
    "TopK 0.1": {"loss_type": LossType.TOPK, "mix_rate_lower_bound": 0.1},
    "TopK 0.5": {"loss_type": LossType.TOPK, "mix_rate_lower_bound": 0.5},
    "DBAT": {"loss_type": LossType.DBAT, "shared_backbone": False, "freeze_heads": True, "binary": True, "batch_size": 16, "target_batch_size": 32},
}

# TODO: shorten number of epochs
env_configs = {
    # "toy_grid": {"dataset": "toy_grid", "model": "toy_model", "epochs": 100, "batch_size": 32, "target_batch_size": 128, "lr": 1e-3, "optimizer": "sgd"},
    # "fmnist_mnist": {"dataset": "fmnist_mnist", "model": "Resnet50", "epochs": 5},
    # "cifar_mnist": {"dataset": "cifar_mnist", "model": "Resnet50", "epochs": 5},
    # "waterbirds": {"dataset": "waterbirds", "model": "Resnet50", "epochs": 5},
    # "celebA-0": {"dataset": "celebA-0", "model": "Resnet50", "epochs": 5},
    "multi-nli": {"dataset": "multi-nli", "model": "bert", "epochs": 1, "lr": 1e-5, "combine_neut_entail": True, "contra_no_neg": True},
}

method_hparam_map = {
    "DivDis": {"aux_weight": {"type": "float", "range": [1e0, 1e2], "log": True}},
    "TopK 0.1": {"aux_weight": {"type": "float", "range": [1e0, 1e1], "log": True}},
    "TopK 0.5": {"aux_weight": {"type": "float", "range": [1e0, 1e1], "log": True}},
    "DBAT": {"aux_weight": {"type": "float", "range": [1e-2, 3e0], "log": True}},
}

mix_rates = [None]

configs = list(product(env_configs.items(), loss_configs.items(), mix_rates))

dataset_to_mem_gb = {
    "toy_grid": 16,
    "fmnist_mnist": 16,
    "cifar_mnist": 16,
    "waterbirds": 16,
    "celebA-0": 32,
    "multi-nli": 32,
}

In [4]:
def get_study_args_dict(args: list[str], script_name: str, hparams: dict, n_trials: int, n_startup_trials: int, study_name: str, study_dir: Path):
    cmd_args = {
        "args": conf_to_args(args),
        "script_name": script_name,
        "hparams": hparams,
        "n_trials": n_trials,
        "n_startup_trials": n_startup_trials,
        "study_name": study_name,
        "study_dir": study_dir,
    }
    return cmd_args

In [5]:
def get_study_name(env_name, loss_name, mix_rate):
    return f"{env_name}_{loss_name}_{mix_rate}"


# Run Studies

In [16]:
x = {"foo": "bar"}.update({"baz": "biz"})
print(x)

None


In [6]:
for (env_name, env_config), (loss_name, loss_config), mix_rate in configs: 
    # get configs
    conf = {**env_config, **loss_config, "mix_rate": mix_rate}
    study_name = get_study_name(env_name, loss_name, mix_rate)
    study_dir = Path(hparam_dir, study_name)
    study_dir.mkdir(exist_ok=True, parents=True)
    hparams = method_hparam_map[loss_name]
    
    # create study (must create it here to nodes don't conflict)
    study = optuna.create_study(study_name=study_name, storage=get_storage_path(study_dir), direction="minimize", load_if_exists=True)  
    
    # run study
    n_trials_per_node = N_TRIALS // NODES_PER_STUDY
    n_startup_trials_per_node = N_STARTUP_TRIALS // NODES_PER_STUDY
    executor = get_executor(study_dir, mem_gb=dataset_to_mem_gb[env_name], slurm_array_parallelism=NODES_PER_STUDY)
    cmd = get_study_args_dict(conf, SCRIPT_NAME, hparams, n_trials_per_node, n_startup_trials_per_node, study_name, study_dir)
    cmds = [{**cmd, "sampler_seed": i} for i in range(NODES_PER_STUDY)]
    jobs = run_experiments(executor, cmds, STUDY_SCRIPT_NAME)
    
    


[I 2025-01-31 14:53:49,541] A new study created in RDB with name: multi-nli_DivDis_None
[I 2025-01-31 14:53:49,807] A new study created in RDB with name: multi-nli_TopK 0.1_None
[I 2025-01-31 14:53:49,961] A new study created in RDB with name: multi-nli_TopK 0.5_None
[I 2025-01-31 14:53:50,114] A new study created in RDB with name: multi-nli_DBAT_None


In [12]:
print(jobs[0].stdout())

submitit INFO (2025-01-31 14:53:54,009) - Starting with JobEnvironment(job_id=806357_0, hostname=gan.ist.berkeley.edu, local_rank=0(1), node=0(1), global_rank=0(1))
submitit INFO (2025-01-31 14:53:54,009) - Loading pickle: /nas/ucb/oliveradk/diverse-gen/output/cc_aux_weight_sweep/2025-01-31_14-53-41/multi-nli_DBAT_None/806357_0_submitted.pkl
The following command is sent: "python run_study.py args=['dataset=multi-nli', 'model=bert', 'epochs=1', 'lr=1e-05', 'combine_neut_entail=True', 'contra_no_neg=True', 'loss_type=DBAT', 'shared_backbone=False', 'freeze_heads=True', 'binary=True', 'batch_size=16', 'target_batch_size=32', 'mix_rate=null'] script_name=spur_corr_exp.py hparams={'aux_weight': {'type': 'float', 'range': [0.01, 3.0], 'log': True}} n_trials=4 study_name=multi-nli_DBAT_None study_dir=output/cc_aux_weight_sweep/2025-01-31_14-53-41/multi-nli_DBAT_None"



# Process Results

In [15]:
from optuna.visualization import plot_slice
# load study 
hparam_dir = "output/cc_aux_weight_sweep/2025-01-31_14-53-41"
ENV_NAME = "multi-nli"
LOSS_NAME = "DBAT"
MIX_RATE = None

study_name = get_study_name(ENV_NAME, LOSS_NAME, MIX_RATE)
study_dir = Path(hparam_dir, study_name)
storage_path = get_storage_path(study_dir)
study = optuna.load_study(study_name=study_name, storage=storage_path)

# plot results
from optuna.visualization import plot_optimization_history, plot_param_importances, plot_contour

# Plot optimization history
fig_history = plot_optimization_history(study)
fig_history.show()

# Plot parameter importances
fig_importance = plot_param_importances(study)
fig_importance.show()

# plot slice 
fig_slice = plot_slice(study)
fig_slice.show()

# Plot contour of parameters (if you have 2+ parameters)
# try:
#     fig_contour = plot_contour(study)
#     fig_contour.show()
# except:
#     print("Contour plot requires 2+ parameters")
