In [None]:
import os
# Run from the repository root so relative paths such as "configs/spur_corr"
# resolve. The checkout location can be overridden with the DIVERSE_GEN_ROOT
# environment variable; the default is the original hard-coded path.
os.chdir(os.environ.get("DIVERSE_GEN_ROOT", "/nas/ucb/oliveradk/diverse-gen/"))
import yaml

In [None]:
# Directory (relative to the working directory) into which the generated
# `{method}_{dataset}.yaml` config files are written.
CONF_DIR = "configs/spur_corr"

In [None]:
# One config file is generated for every (dataset, method) pair of these lists.
METHODS = ["DivDis", "DBAT", "TopK_0.1", "TopK_0.5", "ERM"]
DATASETS = ["toy_grid", "fmnist_mnist", "cifar_mnist", "waterbirds", "celebA-0", "multi-nli"]


# Per-dataset value written as `aggregate_mix_rate`; only emitted for methods
# whose name starts with "TopK".
AGGREGATE_MIX_RATE = {
    "toy_grid": False,
    "fmnist_mnist": False,
    "cifar_mnist": False,
    "waterbirds": True,
    "celebA-0": True,
    "multi-nli": True,
}

# Default batch sizes; the generated DBAT configs receive half of each value.
BATCH_SIZE = 32 
TARGET_BATCH_SIZE = 64

# Per-dataset batch sizes. The toy_grid target_batch_size (128) is the value
# the DBAT configs halve for toy_grid instead of TARGET_BATCH_SIZE.
BATCH_SIZES = {
    "toy_grid": {
        "batch_size": 32, 
        "target_batch_size": 128,
    }
}

# Extra key/value overrides for specific (method, dataset) pairs. They are
# merged into the generated config last, so they win on key clashes.
METHOD_DATASET_CONFIGS = {
    "DivDis": {
        "toy_grid": {
            "optimizer": "adamw",
        }
    }, 
    "TopK_0.1": {
        "toy_grid": {
            "optimizer": "sgd",
        }, 
        "multi-nli": {
            "mix_rate_interval_frac": 0.25,
        }
    },
    "TopK_0.5": {
        "toy_grid": {
            "optimizer": "sgd",
        }, 
        "multi-nli": {
            "mix_rate_interval_frac": 0.25,
        }
    },
}

# `aux_weight` written for each (method, dataset) pair. ERM has no entry
# because no `aux_weight` is emitted for it.
AUX_WEIGHTS = {
    "DivDis": {
        "toy_grid": 1.0,
        "fmnist_mnist": 1.0,
        "cifar_mnist": 1.5,
        "waterbirds": 8,
        "celebA-0": 2.5,
        "multi-nli": 64,
    },
    "DBAT": {
        "toy_grid": 0.4,
        "fmnist_mnist": 0.05,
        "cifar_mnist": 0.01,
        "waterbirds": 0.01,
        "celebA-0": 0.01,
        "multi-nli": 0.01,
    },
    "TopK_0.1": {
        "toy_grid": 1.5,
        "fmnist_mnist": 1.0,
        "cifar_mnist": 1.5,
        "waterbirds": 8,
        "celebA-0": 2.5,
        "multi-nli": 64,
    },
    "TopK_0.5": {
        "toy_grid": 1.0,
        "fmnist_mnist": 1.0,
        "cifar_mnist": 1.0,
        "waterbirds": 3.0,
        "celebA-0": 1.5,
        "multi-nli": 6,
    },
}

In [None]:
# Generate one Hydra config file per (dataset, method) pair under CONF_DIR.
# Each file pulls in the dataset and method defaults and then lists the
# overrides collected in `conf_updates` for that pair.
os.makedirs(CONF_DIR, exist_ok=True)  # open(..., "w") fails if the directory is missing

for dataset in DATASETS:
    # Dataset-specific batch sizes fall back to the global defaults, so the
    # toy_grid special case lives in BATCH_SIZES rather than as a literal here.
    dataset_batch_sizes = BATCH_SIZES.get(dataset, {})
    batch_size = dataset_batch_sizes.get("batch_size", BATCH_SIZE)
    target_batch_size = dataset_batch_sizes.get("target_batch_size", TARGET_BATCH_SIZE)

    for method in METHODS:
        conf_updates = {}
        if method != "ERM":
            conf_updates["aux_weight"] = AUX_WEIGHTS[method][dataset]
        if method.startswith("TopK"):
            conf_updates["aggregate_mix_rate"] = AGGREGATE_MIX_RATE[dataset]
        if method == "DBAT":
            # DBAT configs get half of both batch sizes.
            conf_updates["batch_size"] = batch_size // 2
            conf_updates["target_batch_size"] = target_batch_size // 2
        # Pair-specific extras are applied last so they win on key clashes.
        conf_updates.update(METHOD_DATASET_CONFIGS.get(method, {}).get(dataset, {}))

        conf_file = os.path.join(CONF_DIR, f"{method}_{dataset}.yaml")
        with open(conf_file, "w", encoding="utf-8") as f:
            f.write("defaults:\n")
            f.write(f"  - /dataset/{dataset}@_here_\n")
            f.write(f"  - /method/{method}@_here_\n")
            f.write("  - _self_\n\n")
            # Serialise the overrides with PyYAML so values are typed and
            # quoted by YAML rules instead of Python's str() (e.g. None -> null).
            # An empty dict is skipped because safe_dump would emit "{}".
            if conf_updates:
                yaml.safe_dump(conf_updates, f, default_flow_style=False, sort_keys=False)
        
        


In [None]:
import hydra
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf

In [None]:
# Config directory and config name (no `.yaml` suffix) composed in the cells
# below to inspect one of the `{method}_{dataset}` configs.
config_path = "configs/spur_corr"
config_name = "DBAT_waterbirds"

In [None]:
# Register `${div:x,y}` (integer floor division) for use in configs.
# replace=True keeps this cell re-runnable: without it a second execution
# fails because a resolver named "div" is already registered.
OmegaConf.register_new_resolver("div", lambda x, y: x // y, replace=True)

In [None]:
# Compose the selected config inside a temporary Hydra context; `cfg` stays
# available after the context exits.
with initialize(version_base=None, config_path=f"../{config_path}"):
    cfg = compose(config_name=config_name)

# Dump the composed config as YAML text for inspection.
print(OmegaConf.to_yaml(cfg))

In [None]:
# OmegaConf.resolve replaces interpolations in `cfg` in place and returns None,
# so bind `out` to the resolved config itself rather than to the return value.
OmegaConf.resolve(cfg)
out = cfg

In [None]:
# Look up a single key of the composed config.
cfg['head_1_epochs']

In [None]:
# Show the value bound to `out` in the resolve cell.
out