# Imports

In [58]:
import itertools
import os

from utils_extraction import load_utils

# Utils

In [59]:
def make_train_command_for_experiment(env_vars: dict) -> str:
    """Build the shell command that submits one experiment via sbatch.

    Args:
        env_vars: Mapping from environment variable name to its value. Values
            are interpolated with their default string formatting and are NOT
            shell-quoted here, so any value needing quotes must already
            contain them.

    Returns:
        A string of ``KEY=value`` assignments, ordered by key and separated by
        single spaces, followed by `` sbatch run_single_extract.sh``.
    """
    assignments = []
    # Keys are emitted in sorted order so the command text does not depend on
    # the insertion order of the dict.
    for key in sorted(env_vars):
        assignments.append(f"{key}={env_vars[key]}")
    return " ".join(assignments) + " sbatch run_single_extract.sh"

# Constants

In [60]:
# Baseline environment variables shared by every generated command. The
# experiment cells below copy this mapping and override individual entries.
DEFAULT_ENV_VARS = {
    "MODEL": "meta-llama/Llama-2-7b-chat-hf",
    "DATASETS": "dbpedia-14",
    "LABELED_DATASETS": "imdb",
    "EVAL_DATASETS": "burns",
    "PREFIX": "normal-bananashed",
    "METHOD_LIST": "pseudolabel",
    "MODE": "concat",
    "SUP_WEIGHT": 1,
    "UNSUP_WEIGHT": 0,
    "LR": 1e-2,
    "N_EPOCHS": 5000,
    "OPT": "sgd",
    "NUM_SEEDS": 1,
    "N_TRIES": 1,
    "SPAN_DIRS_COMBINATION": "convex",
    # Pseudolabel settings.
    "PSEUDOLABEL_N_ROUNDS": 5,
    "PSEUDOLABEL_SELECT_FN": "high_confidence_consistency",
    "PSEUDOLABEL_PROB_THRESHOLD": 0.8,
    "PSEUDOLABEL_LABEL_FN": "argmax",
    # Which artifacts to save.
    "SAVE_PARAMS": False,
    "SAVE_FIT_RESULT": True,
    "SAVE_FIT_PLOTS": True,
    "SAVE_STATES": False,
    "SAVE_ORTHOGONAL_DIRECTIONS": False,
}

# Names of every dataset the sweeps iterate over.
ALL_DATASETS = [
    "imdb",
    "amazon-polarity",
    "ag-news",
    "dbpedia-14",
    "copa",
    "rte",
    "boolq",
    "qnli",
    "piqa",
]
# Every ordered (first, second) pairing of the datasets above, including the
# pairs where both entries are the same dataset.
ALL_DATASET_PAIRS = list(itertools.product(ALL_DATASETS, repeat=2))

# Pseudo-label

In [61]:
# Prompt prefixes to sweep over. The second entry is spelled to match the
# "normal-bananashed" prefix used in DEFAULT_ENV_VARS and in the result
# directory names elsewhere in this notebook (it was previously misspelled
# "normal-bananshed").
prefixes = ["normal", "normal-bananashed"]
# One override dict per prefix: a single pseudolabel round that selects all
# examples.
experiment_configs = [
    {"PREFIX": prefix, "PSEUDOLABEL_N_ROUNDS": 1, "PSEUDOLABEL_SELECT_FN": "all"}
    for prefix in prefixes
]

In [62]:
# Accumulates one env-var dict per generated experiment; reset here so that
# re-running this cell does not append duplicates.
env_vars_list = []


def make_tag(env_vars: dict) -> str:
    """Build the slash-separated experiment tag for a pseudolabel run.

    The tag has the form
    ``<model>/<prefix>/pseudolabel/rounds_<n>/select_<fn>[/prob_thres_<p>]/label_<fn>``
    where the ``prob_thres`` segment is present only for the
    ``high_confidence_consistency`` selection function.

    Side effect: when ``TEST_PREFIX`` is missing (or ``None``) in
    ``env_vars``, it is set in place to the value of ``PREFIX``.

    Args:
        env_vars: Experiment settings. Reads ``MODEL``, ``PREFIX``,
            ``METHOD_LIST``, ``PSEUDOLABEL_N_ROUNDS``,
            ``PSEUDOLABEL_SELECT_FN``, ``PSEUDOLABEL_LABEL_FN``, optionally
            ``TEST_PREFIX``, and ``PSEUDOLABEL_PROB_THRESHOLD`` when the
            selection function is ``high_confidence_consistency``.

    Returns:
        The tag string.

    Raises:
        ValueError: If ``MODEL`` is not the supported Llama-2 chat model.
        NotImplementedError: If ``TEST_PREFIX`` differs from ``PREFIX``, the
            method is not pseudolabel, or the selection function is unknown.
    """
    model_name = env_vars["MODEL"]
    if model_name != "meta-llama/Llama-2-7b-chat-hf":
        raise ValueError(f"Unknown model: {model_name}")
    segments = ["llama-2-7b-chat-hf"]

    # Default the test prefix to the train prefix; a differing one is rejected.
    if env_vars.get("TEST_PREFIX", None) is None:
        env_vars["TEST_PREFIX"] = env_vars["PREFIX"]
    elif env_vars["PREFIX"] != env_vars["TEST_PREFIX"]:
        raise NotImplementedError("Different test prefix not supported")
    segments.append(f"{env_vars['PREFIX']}")

    # The method may be given either as a bare string or a one-element list.
    method_list = env_vars["METHOD_LIST"]
    if not (method_list == "pseudolabel" or method_list == ["pseudolabel"]):
        raise NotImplementedError("Only pseudolabel method supported")
    segments.append("pseudolabel")

    segments.append(f"rounds_{env_vars['PSEUDOLABEL_N_ROUNDS']}")
    selection = env_vars["PSEUDOLABEL_SELECT_FN"]
    segments.append(f"select_{selection}")
    if selection == "high_confidence_consistency":
        # Only this selection function contributes the threshold to the tag.
        segments.append(f"prob_thres_{env_vars['PSEUDOLABEL_PROB_THRESHOLD']}")
    elif selection != "all":
        raise NotImplementedError(f"Unknown select_fn: {selection}")
    segments.append(f"label_{env_vars['PSEUDOLABEL_LABEL_FN']}")

    return "/".join(segments)


# Build one env-var dict per (experiment config, labeled dataset, unlabeled
# dataset) combination and collect them in env_vars_list.
for experiment_config in experiment_configs:
    # Iterate over all dataset pairs.
    for labeled_ds, unlabeled_ds in ALL_DATASET_PAIRS:
        # Start from the defaults and layer the experiment overrides on top.
        env_vars = DEFAULT_ENV_VARS.copy()
        env_vars.update(experiment_config)
        # make_tag also fills in TEST_PREFIX on env_vars when it is absent.
        env_vars["NAME"] = make_tag(env_vars)

        env_vars["DATASETS"] = unlabeled_ds
        env_vars["LABELED_DATASETS"] = labeled_ds
        # Evaluate on both datasets, de-duplicated when they coincide. The
        # names are sorted because plain set iteration order for strings can
        # change between interpreter runs, which would make the generated
        # commands differ from one kernel restart to the next. The value is
        # wrapped in double quotes so the list survives as one shell word.
        env_vars["EVAL_DATASETS"] = f'"{sorted({labeled_ds, unlabeled_ds})}"'

        env_vars_list.append(env_vars)

In [63]:
# Print each generated command followed by one blank line.
for env_vars in env_vars_list:
    print(make_train_command_for_experiment(env_vars) + "\n")

DATASETS=imdb EVAL_DATASETS="['imdb']" LABELED_DATASETS=imdb LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-7b-chat-hf NAME=llama-2-7b-chat-hf/normal/pseudolabel/rounds_1/select_all/label_argmax NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal PSEUDOLABEL_LABEL_FN=argmax PSEUDOLABEL_N_ROUNDS=1 PSEUDOLABEL_PROB_THRESHOLD=0.8 PSEUDOLABEL_SELECT_FN=all SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=1 TEST_PREFIX=normal UNSUP_WEIGHT=0 sbatch run_single_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS="['amazon-polarity', 'imdb']" LABELED_DATASETS=imdb LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-7b-chat-hf NAME=llama-2-7b-chat-hf/normal/pseudolabel/rounds_1/select_all/label_argmax NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal PSEUDOLABEL_LABEL_FN=argmax PSEUDOLABEL_N_ROUNDS=1 PSEUDOLABEL_PROB_THRESHOLD=0.8 PSEUDOLABEL_SE

# LR in CCS span

In [64]:
# Every ordered (train, test) pairing of the known datasets, including pairs
# where train and test are the same dataset.
train_test_datasets = list(itertools.product(ALL_DATASETS, repeat=2))

# Directory holding previously saved orthogonal directions, with one
# sub-directory per dataset combination.
# load_orthogonal_directions_base_dir = None
load_orthogonal_directions_base_dir = "/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/meta-llama-Llama-2-7b-chat-hf"

# Reset the accumulator so this cell only holds the LR-in-CCS-span commands.
env_vars_list = []
# Iterate over num_orth_dirs and datasets.
num_orth_dirs_list = [20]
for num_orth_dirs in num_orth_dirs_list:
    name = f"Llama-2-7b-chat-hf_normal-bananashed_LR-in-CCS-span-convex_{num_orth_dirs}_orth_dirs"

    # Setup: CCS orthogonal directions are from train_ds. Train oracle LR on
    # test_ds over the span of these directions.
    for train_ds, test_ds in train_test_datasets:
        # Load CCS orthogonal directions from train set.
        datasets_str = load_utils.get_combined_datasets_str(
            [train_ds], labeled_datasets=[train_ds]
        )
        load_orthogonal_directions_dir = os.path.join(
            load_orthogonal_directions_base_dir, datasets_str
        )
        # Fail fast here rather than submitting a job that cannot find its
        # input directions.
        if not os.path.exists(load_orthogonal_directions_dir):
            raise ValueError(
                f"Could not find orthogonal directions for {train_ds} in {load_orthogonal_directions_dir}"
            )

        # Arbitrarily use test set as the unlabeled set. This setting does not
        # matter since the oracle only uses the labeled set.
        exp_env_vars = dict(
            NAME=name,
            DATASETS=test_ds,
            LABELED_DATASETS=test_ds,  # Oracle: use labels from test set.
            # De-duplicated and sorted: plain set iteration order for strings
            # can change between interpreter runs, which would make the
            # generated commands non-reproducible. The surrounding double
            # quotes keep the list as a single shell word.
            EVAL_DATASETS=f'"{sorted({train_ds, test_ds})}"',
            LOAD_ORTHOGONAL_DIRECTIONS_DIR=load_orthogonal_directions_dir,
            NUM_ORTHOGONAL_DIRECTIONS=num_orth_dirs,
        )

        # Experiment-specific entries override the shared defaults.
        env_vars = dict(DEFAULT_ENV_VARS, **exp_env_vars)
        env_vars_list.append(env_vars)

In [65]:
# Print each generated command followed by one blank line.
for env_vars in env_vars_list:
    print(make_train_command_for_experiment(env_vars) + "\n")

DATASETS=imdb EVAL_DATASETS="['imdb']" LABELED_DATASETS=imdb LOAD_ORTHOGONAL_DIRECTIONS_DIR=/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/meta-llama-Llama-2-7b-chat-hf/nolabel_imdb-label_imdb LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-7b-chat-hf NAME=Llama-2-7b-chat-hf_normal-bananashed_LR-in-CCS-span-convex_20_orth_dirs NUM_ORTHOGONAL_DIRECTIONS=20 NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal-bananashed PSEUDOLABEL_LABEL_FN=argmax PSEUDOLABEL_N_ROUNDS=5 PSEUDOLABEL_PROB_THRESHOLD=0.8 PSEUDOLABEL_SELECT_FN=high_confidence_consistency SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=1 UNSUP_WEIGHT=0 sbatch run_single_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS="['amazon-polarity', 'imdb']" LABELED_DATASETS=amazon-polarity LOAD_ORTHOGONAL_DIRECTIONS_DIR=/nas/ucb/