In [2]:
import itertools
import os

from utils_extraction import load_utils

In [3]:
def make_train_command_for_experiment(env_vars: dict) -> str:
    """Build the shell command line that submits one job with ``sbatch``.

    Args:
        env_vars: Mapping of environment variable name to value. Each value
            is interpolated with plain string formatting; nothing is
            shell-escaped here, so callers must pre-quote values that contain
            spaces or shell metacharacters.

    Returns:
        ``KEY=value`` assignments ordered by key and separated by single
        spaces, followed by ``sbatch run_single_extract.sh``.
    """
    assignments = []
    for var_name in sorted(env_vars):
        assignments.append(f"{var_name}={env_vars[var_name]}")
    prefix = " ".join(assignments)
    return f"{prefix} sbatch run_single_extract.sh"

In [6]:
experiment_configs = [{}]

# Environment variables shared by every experiment. Per-experiment settings
# are merged on top of these when the sweep is built.
default_env_vars = {
    "MODEL": "meta-llama/Llama-2-7b-chat-hf",
    # "EVAL_DATASETS": "burns",
    "PREFIX": "normal-bananashed",
    "TEST_PREFIX": "normal-bananashed",
    "METHOD_LIST": "CCS+LR-in-span",
    "MODE": "concat",
    "SUP_WEIGHT": 1,
    "UNSUP_WEIGHT": 0,
    "LR": 1,
    "N_EPOCHS": 10000,
    "OPT": "sgd",
    "NUM_SEEDS": 1,
    "N_TRIES": 1,
    "SPAN_DIRS_COMBINATION": "convex",
    "SAVE_PARAMS": True,
    "SAVE_FIT_RESULT": True,
    "SAVE_FIT_PLOTS": True,
    "SAVE_ORTHOGONAL_DIRECTIONS": False,
}

# Dataset names used on both sides of the train/test sweep.
all_datasets = [
    "imdb", "amazon-polarity", "ag-news",
    "dbpedia-14", "copa", "rte",
    "boolq", "qnli", "piqa",
]

# Every ordered (train, test) pair, including pairs where both are the same.
train_test_datasets = list(itertools.product(all_datasets, repeat=2))

# Base directory under which per-dataset orthogonal directions are looked up.
# load_orthogonal_directions_base_dir = None
load_orthogonal_directions_base_dir = (
    "/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/"
    "Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/"
    "meta-llama-Llama-2-7b-chat-hf"
)

# One environment-variable dict per experiment, in sweep order.
env_vars_list = []
# Iterate over num_orth_dirs and datasets.
num_orth_dirs_list = [20]
for num_orth_dirs in num_orth_dirs_list:
    name = f"Llama-2-7b-chat-hf_normal-bananashed_LR-in-CCS-span-convex_{num_orth_dirs}_orth_dirs"

    # Setup: CCS orthogonal directions are from train_ds. Train oracle LR on
    # test_ds over the span of these directions.
    for train_ds, test_ds in train_test_datasets:
        # Load CCS orthogonal directions from train set.
        datasets_str = load_utils.get_combined_datasets_str(
            [train_ds], labeled_datasets=[train_ds]
        )
        load_orthogonal_directions_dir = os.path.join(
            load_orthogonal_directions_base_dir, datasets_str
        )
        # Fail fast here rather than submitting a job that cannot find its
        # input directions.
        if not os.path.exists(load_orthogonal_directions_dir):
            raise ValueError(
                f"Could not find orthogonal directions for {train_ds} in {load_orthogonal_directions_dir}"
            )

        # Evaluate on both datasets, listed once when train_ds == test_ds.
        # sorted() over the set gives the same order on every run, whereas
        # list(set(...)) depends on per-process string hash randomization and
        # made the generated command non-reproducible.
        eval_datasets = sorted({train_ds, test_ds})

        # Arbitrarily use test set as the unlabeled set. This setting does not
        # matter since the oracle only uses the labeled set.
        exp_env_vars = dict(
            NAME=name,
            DATASETS=test_ds,
            LABELED_DATASETS=test_ds,  # Oracle: use labels from test set.
            # Wrapped in double quotes so the shell passes the list repr
            # (which contains spaces and single quotes) as a single value.
            EVAL_DATASETS=f'"{eval_datasets}"',
            LOAD_ORTHOGONAL_DIRECTIONS_DIR=load_orthogonal_directions_dir,
            NUM_ORTHOGONAL_DIRECTIONS=num_orth_dirs,
        )

        # Experiment-specific values take precedence over the defaults.
        env_vars = dict(default_env_vars, **exp_env_vars)
        env_vars_list.append(env_vars)

In [7]:
# Print one submission command per experiment, each followed by a blank line.
for command in map(make_train_command_for_experiment, env_vars_list):
    print(command, end="\n\n")

DATASETS=imdb EVAL_DATASETS="['imdb']" LABELED_DATASETS=imdb LOAD_ORTHOGONAL_DIRECTIONS_DIR=/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/meta-llama-Llama-2-7b-chat-hf/nolabel_imdb-label_imdb LR=1 METHOD_LIST=CCS+LR-in-span MODE=concat MODEL=meta-llama/Llama-2-7b-chat-hf NAME=Llama-2-7b-chat-hf_normal-bananashed_LR-in-CCS-span-convex_20_orth_dirs NUM_ORTHOGONAL_DIRECTIONS=20 NUM_SEEDS=1 N_EPOCHS=10000 N_TRIES=1 OPT=sgd PREFIX=normal-bananashed SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=True SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=1 TEST_PREFIX=normal-bananashed UNSUP_WEIGHT=0 sbatch run_single_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS="['amazon-polarity', 'imdb']" LABELED_DATASETS=amazon-polarity LOAD_ORTHOGONAL_DIRECTIONS_DIR=/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/m