# Imports

In [16]:
import itertools
import os
import typing

from extract import PrefixType
from utils import file_utils
from utils_extraction import load_utils

# Constants

In [30]:
# Baseline environment variables for every experiment. Experiment configs are
# layered on top of these (see make_env_vars_for_experiment).
DEFAULT_ENV_VARS = {
    "MODEL": "meta-llama/Llama-2-7b-chat-hf",
    "DATASETS": "dbpedia-14",
    # "LABELED_DATASETS": "imdb",
    "EVAL_DATASETS": "burns",
    "PREFIX": "normal-bananashed",
    "METHOD_LIST": None,  # Must be set by each experiment config.
    "MODE": "auto",
    "LAYER": -1,
    "LR": 1e-2,
    "N_EPOCHS": 5000,
    "OPT": "sgd",
    "NUM_SEEDS": 1,
    "N_TRIES": 1,
    "SPAN_DIRS_COMBINATION": "convex",
    # Saving
    "SAVE_PARAMS": False,
    "SAVE_FIT_RESULT": True,
    "SAVE_FIT_PLOTS": True,
    "SAVE_STATES": False,
    "SAVE_ORTHOGONAL_DIRECTIONS": False,
}

# Datasets swept over by the command generators below.
ALL_DATASETS = [
    "imdb",
    "amazon-polarity",
    "ag-news",
    "dbpedia-14",
    "copa",
    "rte",
    "boolq",
    "qnli",
    "piqa",
]

# Every ordered (first, second) pair of datasets, including identical pairs,
# in the same row-major order that itertools.product would produce.
ALL_DATASET_PAIRS = [
    (first_ds, second_ds)
    for first_ds in ALL_DATASETS
    for second_ds in ALL_DATASETS
]

# Utils

In [24]:
# Maps a Hugging Face model name to the leading path component of experiment
# tags.
# NOTE(review): the Llama-3 and Mistral entries keep their "org/" prefix, so
# their tags have one more "/" level than the Llama-2 entries. Left unchanged
# because the tag decides where results are written; confirm it is intended.
MODEL_NAME_TO_TAG = {
    "meta-llama/Llama-2-7b-chat-hf": "llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf": "llama-2-13b-chat-hf",
    "meta-llama/Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
    "mistralai/Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
}


def make_tag(env_vars: dict) -> str:
    """Build the slash-separated experiment tag (used as NAME) from env vars.

    The tag has the form ``<model>/<prefix>/layer_<layer>/<method parts...>``.

    Side effect: if ``TEST_PREFIX`` is missing or None, it is set to ``PREFIX``
    in ``env_vars`` (the dict is mutated in place).

    Args:
        env_vars: Experiment env vars. Must contain MODEL, PREFIX, LAYER and
            METHOD_LIST, plus the method-specific keys read below
            (PSEUDOLABEL_* for "pseudolabel"; MODE, SUP_WEIGHT, UNSUP_WEIGHT,
            LR and N_EPOCHS for "CCS+LR").

    Returns:
        The experiment tag.

    Raises:
        ValueError: Unknown model; METHOD_LIST list/tuple with zero or more
            than one method; MODE == "auto" for CCS+LR.
        NotImplementedError: TEST_PREFIX differs from PREFIX; unknown
            pseudolabel select_fn; unsupported method.
    """
    parts = []

    # Model
    model = env_vars["MODEL"]
    if model not in MODEL_NAME_TO_TAG:
        raise ValueError(f"Unknown model: {model}")
    parts.append(MODEL_NAME_TO_TAG[model])

    # Prefix: the test prefix defaults to, and currently must equal, PREFIX.
    test_prefix = env_vars.get("TEST_PREFIX", None)
    if test_prefix is None:
        env_vars["TEST_PREFIX"] = env_vars["PREFIX"]
    elif env_vars["PREFIX"] != test_prefix:
        raise NotImplementedError("Different test prefix not supported")
    parts.append(f"{env_vars['PREFIX']}")

    # Layer
    parts.append(f"layer_{env_vars['LAYER']}")

    # Method: accept either a bare string or a single-element list/tuple.
    method = env_vars["METHOD_LIST"]
    if isinstance(method, (list, tuple)):
        if len(method) > 1:
            raise ValueError("Only one method supported at a time.")
        if not method:
            # Guard before indexing so an empty list gives a clear error
            # instead of an IndexError.
            raise ValueError("METHOD_LIST must contain exactly one method.")
        method = method[0]

    if method == "pseudolabel":
        parts.append("pseudolabel")
        parts.append(f"rounds_{env_vars['PSEUDOLABEL_N_ROUNDS']}")
        select_fn = env_vars["PSEUDOLABEL_SELECT_FN"]
        parts.append(f"select_{select_fn}")
        # The probability threshold only affects (and is only tagged for)
        # the high-confidence selection function.
        if select_fn == "high_confidence_consistency":
            parts.append(f"prob_thres_{env_vars['PSEUDOLABEL_PROB_THRESHOLD']}")
        elif select_fn != "all":
            raise NotImplementedError(f"Unknown select_fn: {select_fn}")
        parts.append(f"label_{env_vars['PSEUDOLABEL_LABEL_FN']}")
    elif method == "CCS+LR":
        mode = env_vars["MODE"]
        if mode == "auto":
            raise ValueError("Set MODE explicitly instead of using 'auto'.")
        parts.extend(
            [
                "ccs_lr",
                f"mode_{mode}",
                f"sup_weight_{env_vars['SUP_WEIGHT']}",
                f"unsup_weight_{env_vars['UNSUP_WEIGHT']}",
                f"lr_{env_vars['LR']}",
                f"n_epochs_{env_vars['N_EPOCHS']}",
            ]
        )
    else:
        raise NotImplementedError(f"Method {method} not supported.")

    return "/".join(parts)


def validate_env_vars(env_vars: dict) -> None:
    """Reject env vars with an unknown prefix or without a method.

    The allowed prefixes are ``typing.get_args(PrefixType)`` (PrefixType is
    presumably a ``typing.Literal`` — confirm in ``extract``).

    Raises:
        ValueError: PREFIX is not an allowed prefix; TEST_PREFIX is set (not
            None) but not an allowed prefix; METHOD_LIST is missing or falsy.
    """
    allowed_prefixes = typing.get_args(PrefixType)

    prefix = env_vars["PREFIX"]
    if prefix not in allowed_prefixes:
        raise ValueError(f"Unknown prefix: {prefix}")

    test_prefix = env_vars.get("TEST_PREFIX")
    # A missing/None TEST_PREFIX is fine; make_tag fills it in from PREFIX.
    if test_prefix is not None:
        if test_prefix not in allowed_prefixes:
            raise ValueError(f"Unknown test prefix: {test_prefix}")

    method_list = env_vars.get("METHOD_LIST")
    if not method_list:
        raise ValueError("METHOD_LIST must be set.")


def make_env_vars_for_experiment(experiment_config: dict) -> dict:
    """Merge an experiment config over DEFAULT_ENV_VARS, validate it, and tag it.

    Validation runs *before* the tag is built. Otherwise a missing METHOD_LIST
    would surface from make_tag as ``NotImplementedError("Method None not
    supported.")`` instead of validate_env_vars' descriptive ValueError.

    Args:
        experiment_config: Overrides for DEFAULT_ENV_VARS (not mutated).

    Returns:
        A new dict of env vars with NAME set to the experiment tag, and
        TEST_PREFIX filled in by make_tag when it was absent.
    """
    env_vars = DEFAULT_ENV_VARS.copy()
    env_vars.update(experiment_config)

    validate_env_vars(env_vars)
    # make_tag may also set TEST_PREFIX = PREFIX, which validate_env_vars has
    # already checked, so no second validation pass is needed.
    env_vars["NAME"] = make_tag(env_vars)
    return env_vars


# sbatch memory request (in GB) per model. Models missing from this map get
# no --mem flag at all.
MODEL_TO_SLURM_MEM = {
    "meta-llama/Llama-2-7b-chat-hf": 16,
    "meta-llama/Llama-2-13b-chat-hf": 16,
}


def make_slurm_args(env_vars: dict) -> str:
    """Return extra sbatch flags for env_vars["MODEL"], or "" when there are none."""
    mem_gb = MODEL_TO_SLURM_MEM.get(env_vars["MODEL"])
    if mem_gb is None:
        return ""
    return f"--mem={mem_gb}gb"


def make_train_command_for_experiment(env_vars: dict, slurm: bool = True) -> str:
    """Build a shell command line that runs one experiment.

    The env vars become ``KEY=value`` assignments, sorted by key, placed in
    front of either ``sbatch [slurm args] slurm_extract.sh`` or
    ``./extract.sh``. Values are NOT shell-quoted, so callers must pre-quote
    values containing spaces or brackets (e.g. EVAL_DATASETS lists).

    Args:
        env_vars: Experiment environment variables.
        slurm: If True, submit via sbatch; otherwise run ./extract.sh.

    Returns:
        The full command line.
    """
    env_vars_str = " ".join(f"{key}={value}" for key, value in sorted(env_vars.items()))
    if slurm:
        # Drop empty pieces so models without slurm args don't produce
        # "sbatch  slurm_extract.sh" with a double space.
        cmd_parts = ("sbatch", make_slurm_args(env_vars), "slurm_extract.sh")
        cmd = " ".join(part for part in cmd_parts if part)
    else:
        cmd = "./extract.sh"

    return f"{env_vars_str} {cmd}"


def print_train_commands_for_experiments_all_datasets(
    experiment_configs: list[dict],
    slurm: bool = True,
    set_labeled_datasets: bool = False,
):
    """Print one train command per (experiment config, dataset in ALL_DATASETS).

    Each command trains on a single dataset (DATASETS) and evaluates on
    "burns". All env vars are built and validated before anything is printed,
    so an invalid config prints nothing.

    Args:
        experiment_configs: Base configs. They must not set DATASETS,
            LABELED_DATASETS or EVAL_DATASETS, which are filled in here.
        slurm: Passed through to make_train_command_for_experiment.
        set_labeled_datasets: If True, also set LABELED_DATASETS to the same
            dataset, as the CCS/LR cells in this notebook do by hand. If
            False (the default), LABELED_DATASETS is left unset.

    Raises:
        ValueError: If a config already sets one of the dataset keys.
    """
    reserved_keys = ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS")
    env_vars_list = []
    for experiment_config in experiment_configs:
        # The check depends only on the config, so do it once per config.
        for key in reserved_keys:
            if key in experiment_config:
                raise ValueError(f"{key} should not be set in ds_experiment_config.")

        for ds in ALL_DATASETS:
            ds_experiment_config = dict(
                experiment_config, DATASETS=ds, EVAL_DATASETS="burns"
            )
            if set_labeled_datasets:
                ds_experiment_config["LABELED_DATASETS"] = ds
            env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

    for env_vars in env_vars_list:
        print(make_train_command_for_experiment(env_vars, slurm=slurm), end="\n\n")


def print_train_commands_for_experiments_all_dataset_pairs(
    experiment_configs: list[dict], slurm: bool = True
):
    env_vars_list = []
    for experiment_config in experiment_configs:
        # Iterate over all dataset pairs.
        for labeled_ds, unlabeled_ds in ALL_DATASET_PAIRS:
            ds_experiment_config = experiment_config.copy()
            if "DATASETS" in ds_experiment_config:
                raise ValueError("DATASETS should not be set in ds_experiment_config.")
            if "LABELED_DATASETS" in ds_experiment_config:
                raise ValueError(
                    "LABELED_DATASETS should not be set in ds_experiment_config."
                )
            if "EVAL_DATASETS" in ds_experiment_config:
                raise ValueError(
                    "EVAL_DATASETS should not be set in ds_experiment_config."
                )

            ds_experiment_config["DATASETS"] = unlabeled_ds
            ds_experiment_config["LABELED_DATASETS"] = labeled_ds
            ds_experiment_config["EVAL_DATASETS"] = (
                f'"{list(set([labeled_ds, unlabeled_ds]))}"'
            )
            env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

    for env_vars in env_vars_list:
        print(make_train_command_for_experiment(env_vars, slurm=slurm), end="\n\n")

# Pseudo-label

In [25]:
# Overrides (on top of DEFAULT_ENV_VARS) shared by all pseudo-labeling runs.
DEFAULT_PSEUDOLABEL_ENV_VARS = {
    "METHOD_LIST": "pseudolabel",
    "MODE": "concat",
    "SUP_WEIGHT": 1,
    "UNSUP_WEIGHT": 0,
    "LR": 1e-2,
    "N_EPOCHS": 5000,
    # Pseudolabel settings; make_tag encodes these into the experiment tag.
    # The probability threshold is only tagged for the
    # "high_confidence_consistency" select function.
    "PSEUDOLABEL_N_ROUNDS": 5,
    "PSEUDOLABEL_SELECT_FN": "high_confidence_consistency",
    "PSEUDOLABEL_PROB_THRESHOLD": 0.7,
    "PSEUDOLABEL_LABEL_FN": "argmax",
}

## select_fn=high_confidence_consistency label_fn=argmax

In [26]:
# Sweep every (model, layer, prefix) combination on top of the pseudolabel
# defaults.
prefixes = ["normal"]
models = [
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Meta-Llama-3-8B",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-3, -5, -7, -9]

experiment_configs = [
    dict(DEFAULT_PSEUDOLABEL_ENV_VARS, MODEL=model, LAYER=layer, PREFIX=prefix)
    for model, layer, prefix in itertools.product(models, layers, prefixes)
]

In [27]:
# Print one command per (experiment config, labeled/unlabeled dataset pair).
slurm = True
print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm)

DATASETS=imdb EVAL_DATASETS="['imdb']" LABELED_DATASETS=imdb LAYER=-3 LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-13b-chat-hf NAME=llama-2-13b-chat-hf/normal/layer_-3/pseudolabel/rounds_5/select_high_confidence_consistency/prob_thres_0.7/label_argmax NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal PSEUDOLABEL_LABEL_FN=argmax PSEUDOLABEL_N_ROUNDS=5 PSEUDOLABEL_PROB_THRESHOLD=0.7 PSEUDOLABEL_SELECT_FN=high_confidence_consistency SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=1 TEST_PREFIX=normal UNSUP_WEIGHT=0 sbatch --mem=16gb slurm_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS="['amazon-polarity', 'imdb']" LABELED_DATASETS=imdb LAYER=-3 LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-13b-chat-hf NAME=llama-2-13b-chat-hf/normal/layer_-3/pseudolabel/rounds_5/select_high_confidence_consistency/prob_thres_0.7/label_argmax NUM_SE

In [8]:
# Print one command per experiment config, with no dataset sweep: DATASETS
# and EVAL_DATASETS come from DEFAULT_ENV_VARS.
# NOTE(review): this cell's saved output (In [8], LAYER=-1) predates the
# current `experiment_configs`; re-run to refresh it.
env_vars_list = [
    make_env_vars_for_experiment(experiment_config)
    for experiment_config in experiment_configs
]

for env_vars in env_vars_list:
    print(make_train_command_for_experiment(env_vars), end="\n\n")

DATASETS=dbpedia-14 EVAL_DATASETS=burns LAYER=-1 LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-13b-chat-hf NAME=llama-2-13b-chat-hf/normal/layer_-1/pseudolabel/rounds_5/select_high_confidence_consistency/prob_thres_0.7/label_argmax NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal PSEUDOLABEL_LABEL_FN=argmax PSEUDOLABEL_N_ROUNDS=5 PSEUDOLABEL_PROB_THRESHOLD=0.7 PSEUDOLABEL_SELECT_FN=high_confidence_consistency SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=1 TEST_PREFIX=normal UNSUP_WEIGHT=0 sbatch --mem=16gb slurm_extract.sh

DATASETS=dbpedia-14 EVAL_DATASETS=burns LAYER=-1 LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Meta-Llama-3-8B NAME=meta-llama/Meta-Llama-3-8B/normal/layer_-1/pseudolabel/rounds_5/select_high_confidence_consistency/prob_thres_0.7/label_argmax NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal PSEUDOLABEL_LABEL

## select_fn=all label_fn=argmax

In [28]:
# Same as the pseudolabel defaults, but with a single round and
# select_fn="all". make_tag then omits the probability threshold from the tag.
DEFAULT_PSEUDOLABEL_SELECT_ALL_LABEL_ARGMAX_ENV_VARS = dict(
    DEFAULT_PSEUDOLABEL_ENV_VARS,
    PSEUDOLABEL_N_ROUNDS=1,
    PSEUDOLABEL_SELECT_FN="all",
)

prefixes = ["normal"]
models = [
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Meta-Llama-3-8B",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-1, -3, -5, -7, -9]

# One config per (model, layer, prefix) combination.
experiment_configs = [
    dict(
        DEFAULT_PSEUDOLABEL_SELECT_ALL_LABEL_ARGMAX_ENV_VARS,
        MODEL=model,
        LAYER=layer,
        PREFIX=prefix,
    )
    for model, layer, prefix in itertools.product(models, layers, prefixes)
]

In [29]:
# Print sbatch commands (slurm=True is the default) for every dataset pair.
print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm=True)

DATASETS=imdb EVAL_DATASETS="['imdb']" LABELED_DATASETS=imdb LAYER=-1 LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-13b-chat-hf NAME=llama-2-13b-chat-hf/normal/layer_-1/pseudolabel/rounds_1/select_all/label_argmax NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal PSEUDOLABEL_LABEL_FN=argmax PSEUDOLABEL_N_ROUNDS=1 PSEUDOLABEL_PROB_THRESHOLD=0.7 PSEUDOLABEL_SELECT_FN=all SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=1 TEST_PREFIX=normal UNSUP_WEIGHT=0 sbatch --mem=16gb slurm_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS="['amazon-polarity', 'imdb']" LABELED_DATASETS=imdb LAYER=-1 LR=0.01 METHOD_LIST=pseudolabel MODE=concat MODEL=meta-llama/Llama-2-13b-chat-hf NAME=llama-2-13b-chat-hf/normal/layer_-1/pseudolabel/rounds_1/select_all/label_argmax NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal PSEUDOLABEL_LABEL_FN=argmax PSEUDOLABEL_N_ROUNDS=1

# Logistic regression (LR) via the CCS+LR implementation (supervised loss only)

In [19]:
# Supervised-only CCS+LR (SUP_WEIGHT=1, UNSUP_WEIGHT=0): presumably plain
# logistic regression on the labeled set.
DEFAULT_LR_ENV_VARS = {
    "METHOD_LIST": "CCS+LR",
    "MODE": "concat",
    "SUP_WEIGHT": 1,
    "UNSUP_WEIGHT": 0,
    "LR": 1e-2,
    "N_EPOCHS": 5000,
}

prefixes = ["normal"]
models = [
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Meta-Llama-3-8B",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-3, -5, -7, -9]

# One config per (model, layer, prefix), except (Llama-2-13b, layer -3).
# NOTE(review): that combination is skipped, presumably because it was already
# run; confirm.
experiment_configs = [
    dict(DEFAULT_LR_ENV_VARS, MODEL=model, LAYER=layer, PREFIX=prefix)
    for model, layer, prefix in itertools.product(models, layers, prefixes)
    if not (layer == -3 and model == "meta-llama/Llama-2-13b-chat-hf")
]

In [20]:
slurm = True

# Train and label on the same dataset, evaluate on "burns"; one command per
# (experiment config, dataset).
reserved_dataset_keys = ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS")
env_vars_list = []
for experiment_config in experiment_configs:
    for ds in ALL_DATASETS:
        for key in reserved_dataset_keys:
            if key in experiment_config:
                raise ValueError(f"{key} should not be set in ds_experiment_config.")

        ds_experiment_config = dict(
            experiment_config,
            DATASETS=ds,
            LABELED_DATASETS=ds,
            EVAL_DATASETS="burns",
        )
        env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

for env_vars in env_vars_list:
    print(make_train_command_for_experiment(env_vars, slurm=slurm), end="\n\n")



DATASETS=imdb EVAL_DATASETS=burns LABELED_DATASETS=imdb LAYER=-5 LR=0.01 METHOD_LIST=CCS+LR MODE=concat MODEL=meta-llama/Llama-2-13b-chat-hf NAME=llama-2-13b-chat-hf/normal/layer_-5/ccs_lr/mode_concat/sup_weight_1/unsup_weight_0/lr_0.01/n_epochs_5000 NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=1 TEST_PREFIX=normal UNSUP_WEIGHT=0 sbatch --mem=16gb slurm_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS=burns LABELED_DATASETS=amazon-polarity LAYER=-5 LR=0.01 METHOD_LIST=CCS+LR MODE=concat MODEL=meta-llama/Llama-2-13b-chat-hf NAME=llama-2-13b-chat-hf/normal/layer_-5/ccs_lr/mode_concat/sup_weight_1/unsup_weight_0/lr_0.01/n_epochs_5000 NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex

# CCS via the CCS+LR implementation (unsupervised loss only)

In [31]:
# Unsupervised-only CCS+LR (SUP_WEIGHT=0, UNSUP_WEIGHT=1), i.e. the CCS loss
# alone.
DEFAULT_CCS_ENV_VARS = {
    "METHOD_LIST": "CCS+LR",
    "MODE": "concat",
    "SUP_WEIGHT": 0,
    "UNSUP_WEIGHT": 1,
    "LR": 1e-2,
    "N_EPOCHS": 5000,
}

prefixes = ["normal"]
models = [
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Meta-Llama-3-8B",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-3, -5, -7, -9]

# One config per (model, layer, prefix) combination.
experiment_configs = [
    dict(DEFAULT_CCS_ENV_VARS, MODEL=model, LAYER=layer, PREFIX=prefix)
    for model, layer, prefix in itertools.product(models, layers, prefixes)
]

In [32]:
slurm = True

# Train (and set LABELED_DATASETS) on each dataset, evaluate on "burns"; one
# command per (experiment config, dataset).
reserved_dataset_keys = ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS")
env_vars_list = []
for experiment_config in experiment_configs:
    for ds in ALL_DATASETS:
        for key in reserved_dataset_keys:
            if key in experiment_config:
                raise ValueError(f"{key} should not be set in ds_experiment_config.")

        ds_experiment_config = dict(
            experiment_config,
            DATASETS=ds,
            LABELED_DATASETS=ds,
            EVAL_DATASETS="burns",
        )
        env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

for env_vars in env_vars_list:
    print(make_train_command_for_experiment(env_vars, slurm=slurm), end="\n\n")

DATASETS=imdb EVAL_DATASETS=burns LABELED_DATASETS=imdb LAYER=-3 LR=0.01 METHOD_LIST=CCS+LR MODE=concat MODEL=meta-llama/Llama-2-7b-chat-hf NAME=llama-2-7b-chat-hf/normal/layer_-3/ccs_lr/mode_concat/sup_weight_0/unsup_weight_1/lr_0.01/n_epochs_5000 NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP_WEIGHT=0 TEST_PREFIX=normal UNSUP_WEIGHT=1 sbatch --mem=16gb slurm_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS=burns LABELED_DATASETS=amazon-polarity LAYER=-3 LR=0.01 METHOD_LIST=CCS+LR MODE=concat MODEL=meta-llama/Llama-2-7b-chat-hf NAME=llama-2-7b-chat-hf/normal/layer_-3/ccs_lr/mode_concat/sup_weight_0/unsup_weight_1/lr_0.01/n_epochs_5000 NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex SUP

# Logistic regression (LR) restricted to the span of CCS orthogonal directions

In [12]:
# (train_ds, test_ds) over all ordered dataset pairs; same list as
# ALL_DATASET_PAIRS.
train_test_datasets = list(ALL_DATASET_PAIRS)

# Root directory holding the saved CCS orthogonal directions, one
# subdirectory per training-dataset combination.
# NOTE(review): absolute, user-specific cluster path; consider making it
# configurable. It must not be None: it is passed to os.path.join below.
load_orthogonal_directions_base_dir = "/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/meta-llama-Llama-2-7b-chat-hf"

env_vars_list = []
# Iterate over num_orth_dirs and datasets.
num_orth_dirs_list = [20]
for num_orth_dirs in num_orth_dirs_list:
    name = f"Llama-2-7b-chat-hf_normal-bananashed_LR-in-CCS-span-convex_{num_orth_dirs}_orth_dirs"

    # Setup: CCS orthogonal directions are from train_ds. Train oracle LR on
    # test_ds over the span of these directions.
    for train_ds, test_ds in train_test_datasets:
        # Load CCS orthogonal directions from train set.
        datasets_str = load_utils.get_combined_datasets_str(
            [train_ds], labeled_datasets=[train_ds]
        )
        load_orthogonal_directions_dir = os.path.join(
            load_orthogonal_directions_base_dir, datasets_str
        )
        if not os.path.exists(load_orthogonal_directions_dir):
            raise ValueError(
                f"Could not find orthogonal directions for {train_ds} in {load_orthogonal_directions_dir}"
            )

        # Deduplicate while keeping (train_ds, test_ds) order. set() ordering
        # of str depends on the per-process hash seed, so it made the printed
        # EVAL_DATASETS list order vary between runs.
        eval_datasets = list(dict.fromkeys([train_ds, test_ds]))

        # Arbitrarily use test set as the unlabeled set. This setting does not
        # matter since the oracle only uses the labeled set.
        exp_env_vars = dict(
            NAME=name,
            DATASETS=test_ds,
            LABELED_DATASETS=test_ds,  # Oracle: use labels from test set.
            EVAL_DATASETS=f'"{eval_datasets}"',
            LOAD_ORTHOGONAL_DIRECTIONS_DIR=load_orthogonal_directions_dir,
            NUM_ORTHOGONAL_DIRECTIONS=num_orth_dirs,
        )

        # NOTE(review): these env vars bypass make_env_vars_for_experiment, so
        # validate_env_vars never runs and METHOD_LIST stays None (from
        # DEFAULT_ENV_VARS). Confirm extract.sh handles that for this setup.
        env_vars = dict(DEFAULT_ENV_VARS, **exp_env_vars)
        env_vars_list.append(env_vars)

In [13]:
# Print the sbatch command for every "LR in CCS span" experiment.
# NOTE(review): the saved output below (`sbatch run_single_extract.sh`) is
# stale; the current make_train_command_for_experiment emits slurm_extract.sh.
for env_vars in env_vars_list:
    command = make_train_command_for_experiment(env_vars)
    print(command, end="\n\n")

DATASETS=imdb EVAL_DATASETS="['imdb']" LABELED_DATASETS=imdb LOAD_ORTHOGONAL_DIRECTIONS_DIR=/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/meta-llama-Llama-2-7b-chat-hf/nolabel_imdb-label_imdb LR=0.01 METHOD_LIST=None MODE=auto MODEL=meta-llama/Llama-2-7b-chat-hf NAME=Llama-2-7b-chat-hf_normal-bananashed_LR-in-CCS-span-convex_20_orth_dirs NUM_ORTHOGONAL_DIRECTIONS=20 NUM_SEEDS=1 N_EPOCHS=5000 N_TRIES=1 OPT=sgd PREFIX=normal-bananashed SAVE_FIT_PLOTS=True SAVE_FIT_RESULT=True SAVE_ORTHOGONAL_DIRECTIONS=False SAVE_PARAMS=False SAVE_STATES=False SPAN_DIRS_COMBINATION=convex sbatch run_single_extract.sh

DATASETS=amazon-polarity EVAL_DATASETS="['imdb', 'amazon-polarity']" LABELED_DATASETS=amazon-polarity LOAD_ORTHOGONAL_DIRECTIONS_DIR=/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/Llama-2-7b-chat-hf_normal-bananashed_CCS-in-CCS-span-convex_100_orth_dirs/meta-llama-Llama-2-7b-chat-hf/nolabel_imdb-label_i