In [1]:
# Plan: quasi-randomly sample mix rates, then plot a contour of the results over them.

# Run this for all the toy datasets and Waterbirds.

# Later, repeat for celebA-0 and multi-nli with mix rate 0.5.

In [2]:
import os

# Work from the repo root: relative paths used below (e.g. "output/cc_mix_rate_sweep")
# resolve against it, and the project-local imports in the next cell presumably rely
# on it too. Set DIVERSE_GEN_ROOT to run from a different checkout/machine; the
# default is the original cluster path. os.chdir raises if the directory is missing.
REPO_ROOT = os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen/")
os.chdir(REPO_ROOT)

In [4]:
from itertools import product
from pathlib import Path
from datetime import datetime

import optuna

from losses.loss_types import LossType
from utils.exp_utils import get_study_args_dict, get_executor, run_experiments
from run_study import get_storage_path

In [11]:
# --- Optuna sweep sizing -------------------------------------------------------
N_TRIALS = 32
N_STARTUP_TRIALS = 8
NODES_PER_STUDY = 8
SAMPLER = "quasi-random"
STUDY_SCRIPT_NAME = "run_study.py"

# --- Experiment script and output location ------------------------------------
SCRIPT_NAME = "spur_corr_exp.py"
HPARM_PARENT_DIR = Path("output/cc_mix_rate_sweep")
HPARAM_DIR_NAME = None  # set to an existing sub-directory name to reuse it

# Reuse the named sweep directory when one is given; otherwise start a fresh,
# timestamped directory (only the fresh one is created here).
if HPARAM_DIR_NAME is not None:
    hparam_dir = Path(HPARM_PARENT_DIR, HPARAM_DIR_NAME)
else:
    hparam_dir_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    hparam_dir = HPARM_PARENT_DIR / hparam_dir_name
    hparam_dir.mkdir(parents=True, exist_ok=True)

In [6]:
MIX_RATES = [0.1, 0.25, 0.5, 0.75, 1.0]

# using aux weight 1.0 for now

# Settings shared by the image datasets below (spliced in after "dataset").
_RESNET50_5_EPOCHS = {"model": "Resnet50", "epochs": 5}

env_configs = {
    "toy_grid": dict(
        dataset="toy_grid",
        model="toy_model",
        epochs=100,
        batch_size=32,
        target_batch_size=128,
        lr=1e-3,
        optimizer="sgd",
    ),
    "fmnist_mnist": {"dataset": "fmnist_mnist", **_RESNET50_5_EPOCHS},
    "cifar_mnist": {"dataset": "cifar_mnist", **_RESNET50_5_EPOCHS},
    "waterbirds": {"dataset": "waterbirds", **_RESNET50_5_EPOCHS},
    # "celebA-0": {"dataset": "celebA-0", **_RESNET50_5_EPOCHS},
    # "multi-nli": {"dataset": "multi-nli", "model": "bert", "epochs": 1, "lr": 1e-5, "combine_neut_entail": True, "contra_no_neg": True},
}

# Both searched hyperparameters are linear-scale floats on [0, 1]; each entry gets
# its own copy of the spec so later per-entry edits cannot alias.
_UNIT_INTERVAL_FLOAT = {"type": "float", "range": (0, 1), "log": False}
hparam_map = {
    "mix_rate_lower_bound_01": {**_UNIT_INTERVAL_FLOAT},
    "mix_rate_lower_bound_10": {**_UNIT_INTERVAL_FLOAT},
}

def partition_ranges(hparam_map, n_partitions, partition_param=None):
    """Split a search space into ``n_partitions`` disjoint slices, one per node.

    Only one hyperparameter (``partition_param``; defaults to the first key of
    ``hparam_map``) is stratified into equal sub-ranges. Every other hyperparameter
    keeps its full range, so the union of the yielded slices covers the whole search
    box. Slicing every parameter with the same index would restrict node i to the
    i-th *diagonal* cell: with two parameters only ``n_partitions`` of the
    ``n_partitions ** 2`` cells would ever be sampled.

    Parameters whose spec has ``"log": True`` are split evenly in log space.

    Args:
        hparam_map: mapping of param name -> spec dict with at least
            ``"range": (low, high)``; other keys (``"type"``, ``"log"``) are copied.
        n_partitions: number of slices to yield (nothing is yielded when <= 0).
        partition_param: name of the hyperparameter to stratify; ``None`` selects
            the first key of ``hparam_map``.

    Yields:
        A fresh copy of ``hparam_map`` per slice, with the partitioned parameter's
        ``"range"`` replaced by that slice's ``(low, high)``. Inputs are never mutated.

    Raises:
        KeyError: if ``partition_param`` is not a key of ``hparam_map`` (raised on
            first iteration, since this is a generator).
    """
    if partition_param is None and hparam_map:
        partition_param = next(iter(hparam_map))
    if partition_param is not None and partition_param not in hparam_map:
        raise KeyError(f"partition_param {partition_param!r} is not in hparam_map")

    def _edge(k):
        # k-th boundary (0..n_partitions) of the chosen parameter's range.
        spec = hparam_map[partition_param]
        low, high = spec["range"]
        if k == n_partitions:
            return high  # exact original upper bound, no float drift
        if spec.get("log", False):
            return low * (high / low) ** (k / n_partitions)
        return low + k * ((high - low) / n_partitions)

    for i in range(n_partitions):
        # Copy each spec so per-slice edits never touch the caller's dicts.
        new_hparam_map = {name: {**spec} for name, spec in hparam_map.items()}
        if partition_param is not None:
            new_hparam_map[partition_param]["range"] = (_edge(i), _edge(i + 1))
        yield new_hparam_map

# One study per (environment, mix rate) pair. Each environment is kept as its
# (name, config) item so the submission loop can unpack both.
configs = [
    (env_item, mix_rate)
    for env_item in env_configs.items()
    for mix_rate in MIX_RATES
]

# Memory (GB) passed to get_executor as mem_gb, grouped by tier.
dataset_to_mem_gb = {
    **dict.fromkeys(("toy_grid", "fmnist_mnist", "cifar_mnist", "waterbirds"), 16),
    **dict.fromkeys(("celebA-0", "multi-nli"), 32),
}

def get_study_name(env_name, mix_rate):
    """Build the study name (also used as its output sub-directory): ``"<env_name>_<mix_rate>"``."""
    return "{}_{}".format(env_name, mix_rate)


In [12]:
# Submit one Optuna study per (environment, mix rate) pair. Each study is split
# across NODES_PER_STUDY array tasks, and each task gets its own slice of hparam_map
# and its own sampler seed.

# Per-node trial budgets are loop-invariant; refuse totals that would be silently
# truncated by integer division.
if N_TRIALS % NODES_PER_STUDY or N_STARTUP_TRIALS % NODES_PER_STUDY:
    raise ValueError(
        f"N_TRIALS ({N_TRIALS}) and N_STARTUP_TRIALS ({N_STARTUP_TRIALS}) must be "
        f"divisible by NODES_PER_STUDY ({NODES_PER_STUDY})"
    )
n_trials_per_node = N_TRIALS // NODES_PER_STUDY
n_startup_trials_per_node = N_STARTUP_TRIALS // NODES_PER_STUDY

all_jobs = {}  # study_name -> whatever run_experiments returned for that study
for (env_name, env_config), mix_rate in configs:
    # get configs
    conf = {**env_config, "mix_rate": mix_rate}
    study_name = get_study_name(env_name, mix_rate)
    study_dir = Path(hparam_dir, study_name)
    study_dir.mkdir(exist_ok=True, parents=True)

    # NOTE(review): study creation was disabled here. The original intent was to
    # create it up front so nodes don't conflict; presumably run_study.py now creates
    # or loads the study itself. Re-enable if concurrent creation races show up.
    # study = optuna.create_study(study_name=study_name, storage=get_storage_path(study_dir), direction="minimize", load_if_exists=True)

    # run study: one command per node, seeded by its node index
    cmds = [
        {
            **get_study_args_dict(conf, SCRIPT_NAME, hparams, n_trials_per_node, n_startup_trials_per_node, study_name, study_dir),
            "sampler_seed": node_idx,
            "sampler_type": SAMPLER,
        }
        for node_idx, hparams in enumerate(partition_ranges(hparam_map, NODES_PER_STUDY))
    ]
    assert len(cmds) == NODES_PER_STUDY, (len(cmds), NODES_PER_STUDY)
    executor = get_executor(study_dir, mem_gb=dataset_to_mem_gb[env_name], slurm_array_parallelism=NODES_PER_STUDY)

    # Keep every study's handles; `jobs` still holds the most recent ones.
    jobs = run_experiments(executor, cmds, STUDY_SCRIPT_NAME)
    all_jobs[study_name] = jobs

In [None]:
# I think everything is in place; what remains is deciding on the final outputs:

# - plot source acc and test acc in two columns for each method
# - plot performance of the selected mix rate (top-k) compared to the other mix rates and other methods
# - plot correlation between source acc diff (vs. ERM) and max(assumed mix rate - mix_rate, 0) (expecting positive)

# Plan looks good.