# Imports

In [1]:
import itertools
import os
import typing
from typing import Optional

from extract import PrefixType
from utils import file_utils
from utils.types import Milestones, PiecewiseLinearSchedule
from utils_extraction import load_utils
from utils_generation.construct_prompts import prompt_name_to_index

# Constants

In [2]:
# Baseline env vars for every experiment. make_env_vars_for_experiment copies
# this dict and lets the experiment config override individual keys.
DEFAULT_ENV_VARS = {
    "MODEL": "meta-llama/Llama-2-7b-chat-hf",
    "DATASETS": "dbpedia-14",
    # "LABELED_DATASETS": "imdb",
    "EVAL_DATASETS": "burns",
    "PREFIX": "normal",
    "TEST_PREFIX": "normal",
    "METHOD_LIST": None,  # Must be set by each experiment config.
    "MODE": "auto",
    "LAYER": -1,
    "LR": 1e-2,
    "N_EPOCHS": 5000,
    "WEIGHT_DECAY": 0.0,
    "OPT": "sgd",
    "NUM_SEEDS": 1,
    "N_TRIES": 1,
    "SPAN_DIRS_COMBINATION": "convex",
    # Saving
    "SAVE_PARAMS": False,
    "SAVE_FIT_RESULT": True,
    "SAVE_FIT_PLOTS": False,
    "SAVE_STATES": False,
    "SAVE_ORTHOGONAL_DIRECTIONS": False,
}

ALL_DATASETS = [
    "imdb",
    "amazon-polarity",
    "ag-news",
    "dbpedia-14",
    "copa",
    "rte",
    "boolq",
    "qnli",
    "piqa",
]
# Every ordered (first, second) pair of datasets, including (ds, ds).
ALL_DATASET_PAIRS = [(first, second) for first in ALL_DATASETS for second in ALL_DATASETS]

# Datasets that support the Alice explicit opinion prompt: currently every
# dataset except copa and piqa. This should eventually be ALL_DATASETS.
ALICE_PROMPT_DATASETS = [ds for ds in ALL_DATASETS if ds not in ("copa", "piqa")]
# Names of the Alice explicit-opinion prompts per dataset (numbered from 1).
ALICE_EXPLICIT_OPINION_PROMPT_NAMES = {
    ds: [f"alice_explicit_opinion_{i}" for i in range(1, n_prompts + 1)]
    for ds, n_prompts in (("imdb", 2), ("amazon-polarity", 2), ("ag-news", 8))
}

# Utils

In [3]:
# Short tags used in place of full model names when make_tag builds NAME.
# Models missing from this map are tagged with their full name, which for
# Hugging Face ids typically contains a '/'.
MODEL_NAME_TO_TAG = dict(
    [
        ("meta-llama/Llama-2-7b-chat-hf", "llama-2-7b-chat-hf"),
        ("meta-llama/Llama-2-13b-chat-hf", "llama-2-13b-chat-hf"),
    ]
)


# TODO: import this from the relevant load_utils module.
def make_loss_weight_schedule_tag(loss_weight: PiecewiseLinearSchedule) -> str:
    """Render a loss weight (constant or schedule) as a tag fragment.

    A constant weight (int or float) is rendered with ``str``. A piecewise
    linear schedule, given as (epoch, value) pairs, has epochs cast to int and
    values to float. The pairs are then sorted (by epoch, then value) and
    flattened to ``"<epoch>_<value>_<epoch>_<value>..."``.

    Raises:
        TypeError: if ``loss_weight`` is neither a number nor a list/tuple.
    """
    if isinstance(loss_weight, (list, tuple)):
        breakpoints = sorted((int(epoch), float(value)) for epoch, value in loss_weight)
        return "_".join(f"{epoch}_{value}" for epoch, value in breakpoints)

    if not isinstance(loss_weight, (float, int)):
        raise TypeError(
            "Loss weight schedule must be a float, or a list or tuple of (int, float) "
            f"tuples/lists, got {loss_weight}"
        )

    return str(loss_weight)


def make_ccs_lr_sub_tag(env_vars: dict) -> str:
    """Build the '/'-prefixed tag fragment for the CCS+LR method.

    The fragment records mode, supervised/unsupervised loss weights, any
    non-1 consistency/confidence weights (only when the unsupervised weight is
    non-zero), learning rate, epochs, and weight decay when it is positive.

    Raises:
        ValueError: if MODE is still "auto".
    """
    mode = env_vars["MODE"]
    if mode == "auto":
        raise ValueError("Set MODE explicitly instead of using 'auto'.")

    unsup_weight = env_vars["UNSUP_WEIGHT"]
    # Unsupervised weight is formatted before the supervised one, matching the
    # order in which errors for bad values surface.
    unsup_part = make_loss_weight_schedule_tag(unsup_weight)
    sup_part = make_loss_weight_schedule_tag(env_vars["SUP_WEIGHT"])
    segments = [
        "ccs_lr",
        f"mode_{mode}",
        f"sup_weight_{sup_part}",
        f"unsup_weight_{unsup_part}",
    ]

    # The unsupervised sub-term weights only matter when that loss is active;
    # the default weight of 1 is left out of the tag.
    if unsup_weight != 0:
        for name in ("CONSISTENCY_WEIGHT", "CONFIDENCE_WEIGHT"):
            value = env_vars.get(name, 1)
            if value != 1:
                segments.append(f"{name.lower()}_{make_loss_weight_schedule_tag(value)}")

    segments.append(f"lr_{env_vars['LR']}")
    segments.append(f"n_epochs_{env_vars['N_EPOCHS']}")

    weight_decay = env_vars["WEIGHT_DECAY"]
    if weight_decay > 0:
        segments.append(f"weight_decay_{weight_decay}")

    return "".join(f"/{segment}" for segment in segments)


def make_pseudolabel_sub_tag(env_vars: dict) -> str:
    """Build the '/'-prefixed tag fragment for the ``pseudolabel`` method.

    The fragment records the number of rounds, then the selection function
    and the thresholds it uses, then the labeling function (plus temperature
    for softmax). It ends with the CCS+LR fragment from make_ccs_lr_sub_tag.

    Raises:
        NotImplementedError: for an unrecognized PSEUDOLABEL_SELECT_FN.
    """
    pieces = [f"/pseudolabel/rounds_{env_vars['PSEUDOLABEL_N_ROUNDS']}"]

    # Selection function and the thresholds it depends on.
    select_fn = env_vars["PSEUDOLABEL_SELECT_FN"]
    pieces.append(f"/select_{select_fn}")
    if select_fn not in (
        "all",
        "confidence",
        "consistency",
        "confidence_consistency",
        "high_confidence_consistency",
    ):
        raise NotImplementedError(f"Unknown select_fn: {select_fn}")
    if select_fn in ("confidence", "confidence_consistency", "high_confidence_consistency"):
        pieces.append(f"/prob_thres_{env_vars.get('PSEUDOLABEL_PROB_THRESHOLD')}")
    if select_fn in ("consistency", "confidence_consistency"):
        pieces.append(
            f"/cons_thres_{env_vars.get('PSEUDOLABEL_CONSISTENCY_ERR_THRESHOLD')}"
        )

    # Labeling function.
    label_fn = env_vars["PSEUDOLABEL_LABEL_FN"]
    pieces.append(f"/label_{label_fn}")
    if label_fn == "softmax":
        pieces.append(f"/temp_{env_vars['PSEUDOLABEL_SOFTMAX_TEMP']}")

    pieces.append(make_ccs_lr_sub_tag(env_vars))
    return "".join(pieces)


def _as_dataset_list(datasets) -> list:
    """Normalize a DATASETS-style env var value to a list of dataset names.

    Accepts a list/tuple of names, a single name, or the string form of a list
    (e.g. ``"['imdb', 'rte']"``). The command-printing helpers in this file
    set EVAL_DATASETS in that string form.
    """
    if isinstance(datasets, str):
        text = datasets.strip()
        if text.startswith("["):
            import ast  # Local import: only needed for stringified lists.

            return [str(ds) for ds in ast.literal_eval(text)]
        return [datasets]
    return list(datasets)


def _make_prompt_index_str(prompt_index: dict, datasets: list) -> str:
    """Encode per-dataset prompt indices as ``"<ds>_<i>_<j>_<ds2>_<k>..."``.

    Datasets are visited in sorted order, and those not in ``datasets`` are
    skipped. Prompt names are resolved to indices with prompt_name_to_index.

    Raises:
        ValueError: for a prompt name that does not resolve to an index.
    """
    parts = []
    for ds in sorted(prompt_index):
        # Skip datasets that are not used.
        if ds not in datasets:
            continue
        parts.append(ds)
        for idx_or_name in prompt_index[ds]:
            if isinstance(idx_or_name, int):
                idx = idx_or_name
            else:
                idx = prompt_name_to_index(idx_or_name, ds)
                if idx is None:
                    raise ValueError(
                        f"Unknown prompt name: {idx_or_name} for dataset {ds}"
                    )
            parts.append(str(idx))
    return "_".join(parts)


def make_tag(env_vars: dict) -> str:
    """Build the experiment NAME tag, a '/'-separated path of settings.

    The segments appear in this order:
    - model;
    - prefixes;
    - prompt selection (only when an index or a non-default subset is set);
    - layer;
    - optional projection;
    - the method's hyperparameters;
    - the optimizer, when it is not "sgd".

    Raises:
        ValueError: if a split sets both its prompt subset and prompt index,
            if a prompt index is set without the matching datasets env var,
            for an unknown prompt name, if METHOD_LIST lists several methods,
            or if MODE is "auto" for a CCS+LR method.
        NotImplementedError: for an unsupported method.
    """
    # Model
    model = env_vars["MODEL"]
    tag = MODEL_NAME_TO_TAG.get(model, model)

    # Prefix
    prefix = env_vars["PREFIX"]
    labeled_prefix = env_vars.get("LABELED_PREFIX")
    if labeled_prefix is None:
        tag += f"/{prefix}"
    else:
        tag += f"/label_{labeled_prefix}-nolabel_{prefix}"
    tag += f"/test_{env_vars['TEST_PREFIX']}"

    # Prompt indices and subsets, per split.
    for label, idx_key, subset_key, ds_key in (
        ("prompts", "PROMPT_IDX", "PROMPT_SUBSET", "DATASETS"),
        (
            "labeled_prompts",
            "LABELED_PROMPT_IDX",
            "LABELED_PROMPT_SUBSET",
            "LABELED_DATASETS",
        ),
        ("test_prompts", "TEST_PROMPT_IDX", "TEST_PROMPT_SUBSET", "EVAL_DATASETS"),
    ):
        prompt_index = env_vars.get(idx_key)
        prompt_subset = env_vars.get(subset_key)
        if prompt_subset is not None and prompt_index is not None:
            raise ValueError(f"Both {subset_key} and {idx_key} are set.")
        if prompt_index is not None:
            datasets = env_vars.get(ds_key)
            if datasets is None:
                # Without this check, `ds not in None` raises an opaque TypeError.
                raise ValueError(f"{idx_key} is set but {ds_key} is not.")
            # NOTE(review): collection names such as "burns" are not expanded
            # here, so their per-dataset indices are skipped. Confirm intent.
            prompt_index_str = _make_prompt_index_str(
                prompt_index, _as_dataset_list(datasets)
            )
            tag += f"/{label}_{prompt_index_str}"
        # Add the prompt subset to the tag if it is not the default.
        elif prompt_subset is not None and prompt_subset != "default":
            tag += f"/{label}_{prompt_subset}"

    # Layer
    tag += f"/layer_{env_vars['LAYER']}"

    # Projection
    projection_method = env_vars.get("PROJECTION_METHOD")
    if projection_method is not None:
        tag += f"/proj_{projection_method}/n_comp_{env_vars['PROJECTION_N_COMPONENTS']}"

    # Method
    method = env_vars["METHOD_LIST"]
    if isinstance(method, (list, tuple)):
        if len(method) > 1:
            raise ValueError("Only one method supported at a time.")
        # An empty list falls through to NotImplementedError below instead
        # of raising IndexError.
        method = method[0] if method else None
    if method == "pseudolabel":
        tag += make_pseudolabel_sub_tag(env_vars)
    elif method == "CCS+LR":
        tag += make_ccs_lr_sub_tag(env_vars)
    elif method == "CCS+LR-in-span":
        mode = env_vars["MODE"]
        if mode == "auto":
            raise ValueError("Set MODE explicitly instead of using 'auto'.")
        # Format weights as make_ccs_lr_sub_tag does, so that schedules give
        # path-safe tags (numeric weights render exactly as before).
        sup_weight = make_loss_weight_schedule_tag(env_vars["SUP_WEIGHT"])
        unsup_weight = make_loss_weight_schedule_tag(env_vars["UNSUP_WEIGHT"])
        tag += (
            f"/ccs_lr_in_span/mode_{mode}"
            f"/sup_weight_{sup_weight}/unsup_weight_{unsup_weight}"
            f"/lr_{env_vars['LR']}/n_epochs_{env_vars['N_EPOCHS']}"
            f"/n_orth_dirs_{env_vars['NUM_ORTHOGONAL_DIRECTIONS']}"
            f"/span_dirs_combo_{env_vars['SPAN_DIRS_COMBINATION']}"
        )
    else:
        raise NotImplementedError(f"Method {method} not supported.")

    # Optimizer (omitted from the tag when it is the default "sgd").
    opt = env_vars["OPT"]
    if opt != "sgd":
        tag += f"/opt_{opt}"

    return tag


def validate_env_vars(env_vars: dict) -> None:
    """Validate the prefix and method settings of an experiment.

    Raises:
        ValueError: if PREFIX, LABELED_PREFIX or TEST_PREFIX is set to a value
            outside ``PrefixType``, or if METHOD_LIST is unset/empty.
    """
    for key in ("PREFIX", "LABELED_PREFIX", "TEST_PREFIX"):
        # Look up each key itself (previously TEST_PREFIX was checked three
        # times, and PREFIX/LABELED_PREFIX were never validated).
        prefix = env_vars.get(key)
        if prefix is not None and prefix not in typing.get_args(PrefixType):
            raise ValueError(f"Unknown {key}: {prefix}")

    if not env_vars.get("METHOD_LIST"):
        raise ValueError("METHOD_LIST must be set.")


def make_env_vars_for_experiment(experiment_config: dict) -> dict:
    """Merge an experiment config over DEFAULT_ENV_VARS and add its NAME tag.

    Keys in ``experiment_config`` override the defaults. The result gains a
    ``NAME`` entry built by make_tag.

    Raises:
        ValueError: from validate_env_vars or make_tag for invalid settings.
        NotImplementedError: from make_tag for an unsupported method.
    """
    env_vars = {**DEFAULT_ENV_VARS, **experiment_config}
    # Validate first: make_tag raises NotImplementedError("Method None ...")
    # for an unset METHOD_LIST, which used to hide the clearer validation error.
    validate_env_vars(env_vars)
    env_vars["NAME"] = make_tag(env_vars)
    return env_vars


# sbatch memory request per model, in GB (rendered as "--mem=<n>gb" by
# make_slurm_args). Models missing from this map get no --mem flag.
MODEL_TO_SLURM_MEM = dict.fromkeys(
    ("meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-chat-hf"), 16
)


def make_slurm_args(env_vars: dict, slurm_args: Optional[dict] = None) -> str:
    """Render sbatch flags for an experiment.

    Args:
        env_vars: Experiment env vars; only ``MODEL`` is read, to look up a
            ``--mem`` request in MODEL_TO_SLURM_MEM.
        slurm_args: Extra ``{flag: value}`` pairs rendered as ``--flag=value``.
            Defaults to none (``None`` replaces the former shared ``{}``
            default).

    Returns:
        The flags joined by single spaces (empty string if there are none).
    """
    args = []
    mem = MODEL_TO_SLURM_MEM.get(env_vars["MODEL"])
    if mem is not None:
        args.append(f"--mem={mem}gb")

    args.extend(f"--{key}={value}" for key, value in (slurm_args or {}).items())

    return " ".join(args)


def make_train_command_for_experiment(
    env_vars: dict, slurm: bool = True, slurm_args: Optional[dict] = None
) -> str:
    """Build the shell command that launches slurm_extract.sh for one experiment.

    Env vars are emitted as ``KEY="value"`` assignments, sorted by key, in
    front of either ``sbatch <flags> slurm_extract.sh`` (when ``slurm``) or
    ``./slurm_extract.sh``.

    Args:
        env_vars: Experiment env vars, e.g. from make_env_vars_for_experiment.
        slurm: Submit through sbatch instead of running the script directly.
        slurm_args: Extra sbatch ``{flag: value}`` pairs; ignored when
            ``slurm`` is False.

    NOTE(review): values are wrapped in double quotes without escaping, so a
    value containing ``"``, ``$`` or a backtick would break the command.
    """
    env_vars_str = " ".join(
        f'{key}="{value}"' for key, value in sorted(env_vars.items())
    )
    if slurm:
        # A separate name avoids rebinding the dict parameter to a string.
        slurm_flags = make_slurm_args(env_vars, slurm_args=slurm_args or {})
        cmd = f"sbatch {slurm_flags} slurm_extract.sh"
    else:
        cmd = "./slurm_extract.sh"

    return f"{env_vars_str} {cmd}"


def print_train_commands_for_experiments_all_datasets(
    experiment_configs: list[dict],
    slurm: bool = True,
    slurm_args: Optional[dict] = None,
):
    """Print one train command per (experiment config, dataset in ALL_DATASETS).

    Each run sets DATASETS to a single dataset and EVAL_DATASETS to "burns".

    Args:
        experiment_configs: Configs merged over DEFAULT_ENV_VARS.
        slurm: Submit through sbatch instead of running the script directly.
        slurm_args: Extra sbatch ``{flag: value}`` pairs. This parameter was
            added for parity with the other print_train_commands_* helpers.

    Raises:
        ValueError: if a config already sets DATASETS, LABELED_DATASETS or
            EVAL_DATASETS (these are filled in here).
    """
    if slurm_args is None:
        slurm_args = {}

    env_vars_list = []
    for experiment_config in experiment_configs:
        # These keys are filled in per dataset below.
        for key in ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS"):
            if key in experiment_config:
                raise ValueError(f"{key} should not be set in ds_experiment_config.")

        for ds in ALL_DATASETS:
            ds_experiment_config = dict(
                experiment_config, DATASETS=ds, EVAL_DATASETS="burns"
            )
            env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

    for env_vars in env_vars_list:
        print(
            make_train_command_for_experiment(
                env_vars, slurm=slurm, slurm_args=slurm_args
            ),
            end="\n\n",
        )


def print_train_commands_for_experiments_all_dataset_pairs(
    experiment_configs: list[dict],
    labeled_datasets: Optional[list[str]] = None,
    unlabeled_datasets: Optional[list[str]] = None,
    slurm: bool = True,
    slurm_args: Optional[dict] = None,
):
    """Print train commands for every (labeled, unlabeled) dataset pair.

    Pairs are the Cartesian product of ``labeled_datasets`` and
    ``unlabeled_datasets``; both default to ALL_DATASETS, and each element
    may be a dataset name or a list/tuple of names. For each config and pair:
    - DATASETS is set to the unlabeled entry;
    - LABELED_DATASETS is set to the labeled entry;
    - EVAL_DATASETS is set to the sorted union of both, as a stringified list.

    Raises:
        ValueError: if a config already sets DATASETS, LABELED_DATASETS or
            EVAL_DATASETS, or a dataset entry is not a str/list/tuple.
    """
    if labeled_datasets is None:
        labeled_datasets = ALL_DATASETS
    if unlabeled_datasets is None:
        unlabeled_datasets = ALL_DATASETS
    if slurm_args is None:
        slurm_args = {}
    dataset_pairs = list(itertools.product(labeled_datasets, unlabeled_datasets))

    env_vars_list = []
    for experiment_config in experiment_configs:
        # These keys are filled in per dataset pair below.
        for key in ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS"):
            if key in experiment_config:
                raise ValueError(f"{key} should not be set in ds_experiment_config.")

        for labeled_ds, unlabeled_ds in dataset_pairs:
            all_train_ds = set()
            for datasets in (labeled_ds, unlabeled_ds):
                if isinstance(datasets, (list, tuple)):
                    all_train_ds.update(datasets)
                elif isinstance(datasets, str):
                    all_train_ds.add(datasets)
                else:
                    raise ValueError(
                        f"Datasets must be a string, list, or tuple, got {datasets}"
                    )

            ds_experiment_config = dict(
                experiment_config,
                DATASETS=unlabeled_ds,
                LABELED_DATASETS=labeled_ds,
                # Sorted: the iteration order of a set of strings changes between
                # interpreter runs (hash randomization), which made the printed
                # commands non-reproducible.
                EVAL_DATASETS=f"{sorted(all_train_ds)}",
            )
            env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

    for env_vars in env_vars_list:
        print(
            make_train_command_for_experiment(
                env_vars, slurm=slurm, slurm_args=slurm_args
            ),
            end="\n\n",
        )


def print_train_commands_for_experiments_single_datasets(
    experiment_configs: list[dict],
    datasets: Optional[list[str]] = None,
    eval_datasets: Optional[list[str]] = None,
    slurm: bool = True,
    slurm_args: Optional[dict] = None,
):
    """Print one train command per (experiment config, training dataset).

    Each run uses the same dataset as both DATASETS and LABELED_DATASETS and
    evaluates on ``eval_datasets``. Both ``datasets`` and ``eval_datasets``
    default to ALL_DATASETS only when omitted (``None``), matching
    print_train_commands_for_experiments_all_dataset_pairs.

    Raises:
        ValueError: if a config already sets DATASETS, LABELED_DATASETS or
            EVAL_DATASETS (these are filled in here).
    """
    if datasets is None:
        datasets = ALL_DATASETS
    if eval_datasets is None:
        eval_datasets = ALL_DATASETS
    if slurm_args is None:
        slurm_args = {}

    env_vars_list = []
    for experiment_config in experiment_configs:
        # These keys are filled in per dataset below.
        for key in ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS"):
            if key in experiment_config:
                raise ValueError(f"{key} should not be set in ds_experiment_config.")

        for ds in datasets:
            ds_experiment_config = dict(
                experiment_config,
                DATASETS=ds,
                LABELED_DATASETS=ds,
                # Copy so experiments do not share one mutable list.
                EVAL_DATASETS=list(eval_datasets),
            )
            env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

    for env_vars in env_vars_list:
        print(
            make_train_command_for_experiment(
                env_vars, slurm=slurm, slurm_args=slurm_args
            ),
            end="\n\n",
        )

# LR (CCS+LR impl)

In [16]:
# Presets for logistic-regression (LR) runs. They reuse the CCS+LR method with
# the unsupervised loss weight set to 0, so only the supervised loss is active.
DEFAULT_LR_ENV_VARS = dict(
    METHOD_LIST="CCS+LR",
    MODE="concat",
    SUP_WEIGHT=1,
    UNSUP_WEIGHT=0,
    LR=1e-3,
    N_EPOCHS=5000,
    WEIGHT_DECAY=1,
    OPT="adam",
    LAYER=-1,
)

# labeled=test=normal
LR_NORMAL_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, test=bananashed prefix
LR_LABEL_NORMAL_TEST_BANANASHED_PREFIX_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, test=bananashed suffix
LR_LABEL_NORMAL_TEST_BANANASHED_SUFFIX_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",  # labeled (LR)
    TEST_PREFIX="normal-bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal (default prompts), test=alice explicit-opinion prompts
LR_LABEL_NORMAL_TEST_ALICE_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    PROMPT_SUBSET="default",  # train on default prompts
    TEST_PROMPT_SUBSET="alice_explicit_opinion",  # test on alice prompts
)


# Hyperparameter grid: every combination in the Cartesian product of the value
# lists becomes one config; uncomment entries to sweep them.
hyperparams = dict(
    MODEL=[
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Llama-2-13b-chat-hf",
        "mistralai/Mistral-7B-Instruct-v0.2",
        # "meta-llama/Llama-2-7b-chat-hf",
        # "openai-community/gpt2-xl",
        # "EleutherAI/gpt-j-6b",
    ],
    # LR=[1e-1, 1e-2],
    # WEIGHT_DECAY=[0.1],
    # PROJECTION_METHOD=["gaussian_random"],
    # N_EPOCHS=[1000],
    # PROJECTION_N_COMPONENTS=[400],
)
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]

# Iterate over product of parameters; grid values override the chosen preset.
experiment_configs = []
for config in hyperparam_configs:
    # TODO: choose which default env vars to use.
    experiment_configs.append(dict(LR_LABEL_NORMAL_TEST_BANANASHED_PREFIX_ENV_VARS, **config))

# One command per (config, training dataset); evaluation defaults to ALL_DATASETS.
print_train_commands_for_experiments_single_datasets(experiment_configs)

DATASETS="imdb" EVAL_DATASETS="['imdb', 'amazon-polarity', 'ag-news', 'dbpedia-14', 'copa', 'rte', 'boolq', 'qnli', 'piqa']" LABELED_DATASETS="imdb" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MODEL="meta-llama/Meta-Llama-3-8B-Instruct" NAME="meta-llama/Meta-Llama-3-8B-Instruct/normal/test_bananashed/layer_-1/ccs_lr/mode_concat/sup_weight_1/unsup_weight_0/lr_0.001/n_epochs_5000/weight_decay_1/opt_adam" NUM_SEEDS="1" N_EPOCHS="5000" N_TRIES="1" OPT="adam" PREFIX="normal" SAVE_FIT_PLOTS="False" SAVE_FIT_RESULT="True" SAVE_ORTHOGONAL_DIRECTIONS="False" SAVE_PARAMS="False" SAVE_STATES="False" SPAN_DIRS_COMBINATION="convex" SUP_WEIGHT="1" TEST_PREFIX="bananashed" TEST_PROMPT_SUBSET="default" UNSUP_WEIGHT="0" WEIGHT_DECAY="1" sbatch  slurm_extract.sh

DATASETS="amazon-polarity" EVAL_DATASETS="['imdb', 'amazon-polarity', 'ag-news', 'dbpedia-14', 'copa', 'rte', 'boolq', 'qnli', 'piqa']" LABELED_DATASETS="amazon-polarity" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MOD

# CCS (CCS+LR impl)

In [17]:
# Presets for unsupervised CCS runs. They reuse the CCS+LR method with the
# supervised loss weight set to 0, so only the unsupervised loss is active.
DEFAULT_CCS_ENV_VARS = dict(
    METHOD_LIST="CCS+LR",
    MODE="concat",
    SUP_WEIGHT=0,
    UNSUP_WEIGHT=1,
    LR=1e-3,
    N_EPOCHS=5000,
    OPT="adam",
    LAYER=-1,
)

# unlabeled=test=normal
CCS_NORMAL_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    TEST_PROMPT_SUBSET="default",
)

# unlabeled=test=bananashed prefix
CCS_NOLABEL_BANANASHED_PREFIX_TEST_BANANASHED_PREFIX_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="bananashed",
    TEST_PREFIX="bananashed",
    TEST_PROMPT_SUBSET="default",
)

# unlabeled=test=bananashed suffix
CCS_NOLABEL_BANANASHED_SUFFIX_TEST_BANANASHED_SUFFIX_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="normal-bananashed",
    TEST_PREFIX="normal-bananashed",
    TEST_PROMPT_SUBSET="default",
)

# unlabeled=test=alice
CCS_NOLABEL_ALICE_TEST_ALL_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    PROMPT_SUBSET="alice_explicit_opinion",
    TEST_PROMPT_SUBSET="alice_explicit_opinion",
)

# Hyperparameter grid: every combination in the Cartesian product of the value
# lists becomes one config; uncomment entries to sweep them.
hyperparams = dict(
    MODEL=[
        # "meta-llama/Meta-Llama-3-8B-Instruct",
        # "meta-llama/Llama-2-13b-chat-hf",
        # "mistralai/Mistral-7B-Instruct-v0.2",
        # "meta-llama/Llama-2-7b-chat-hf",
        "openai-community/gpt2-xl",
        "EleutherAI/gpt-j-6b",
    ],
    # PREFIX=["normal"],
    # TEST_PREFIX=["normal-bananashed"],
    # PROMPT_SUBSET=["alice_explicit_opinion"],
    # LR=[1e-1, 1e-2],
    # WEIGHT_DECAY=[0, 0.1, 1],
    # PROJECTION_METHOD=["gaussian_random"],
    # N_EPOCHS=[1000],
    # PROJECTION_N_COMPONENTS=[400],
)
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]

# Iterate over product of parameters; grid values override the chosen preset.
experiment_configs = []
for config in hyperparam_configs:
    # Preset: one of the CCS_* dicts defined above (unlabeled=test=normal here).
    experiment_configs.append(dict(CCS_NORMAL_ENV_VARS, **config))

# One command per (config, training dataset); evaluation defaults to ALL_DATASETS.
print_train_commands_for_experiments_single_datasets(experiment_configs)

DATASETS="imdb" EVAL_DATASETS="['imdb', 'amazon-polarity', 'ag-news', 'dbpedia-14', 'copa', 'rte', 'boolq', 'qnli', 'piqa']" LABELED_DATASETS="imdb" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MODEL="openai-community/gpt2-xl" NAME="openai-community/gpt2-xl/normal/test_normal/layer_-1/ccs_lr/mode_concat/sup_weight_0/unsup_weight_1/lr_0.001/n_epochs_5000/opt_adam" NUM_SEEDS="1" N_EPOCHS="5000" N_TRIES="1" OPT="adam" PREFIX="normal" SAVE_FIT_PLOTS="False" SAVE_FIT_RESULT="True" SAVE_ORTHOGONAL_DIRECTIONS="False" SAVE_PARAMS="False" SAVE_STATES="False" SPAN_DIRS_COMBINATION="convex" SUP_WEIGHT="0" TEST_PREFIX="normal" TEST_PROMPT_SUBSET="default" UNSUP_WEIGHT="1" WEIGHT_DECAY="0.0" sbatch  slurm_extract.sh

DATASETS="amazon-polarity" EVAL_DATASETS="['imdb', 'amazon-polarity', 'ag-news', 'dbpedia-14', 'copa', 'rte', 'boolq', 'qnli', 'piqa']" LABELED_DATASETS="amazon-polarity" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MODEL="openai-community/gpt2-xl" NAME="openai-

# CCS+LR

In [14]:
# Presets for semi-supervised CCS+LR runs. These combine a supervised (LR)
# loss on labeled data with an unsupervised (CCS) loss on unlabeled data.
DEFAULT_CCS_LR_ENV_VARS = dict(
    METHOD_LIST="CCS+LR",
    MODE="concat",
    SUP_WEIGHT=3,
    UNSUP_WEIGHT=1,
    CONSISTENCY_WEIGHT=1,
    CONFIDENCE_WEIGHT=0.1,
    WEIGHT_DECAY=10,
    LR=1e-3,
    N_EPOCHS=10000,
    LAYER=-1,
    OPT="adam",
)

# labeled=unlabeled=test=normal
CCS_LR_NORMAL_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    LABELED_PREFIX="normal",
    PREFIX="normal",
    TEST_PREFIX="normal",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, unlabeled=test=bananashed prefix
CCS_LR_LABEL_NORMAL_NOLABEL_BANANASHED_PREFIX_TEST_BANANASHED_PREFIX_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    LABELED_PREFIX="normal",
    PREFIX="bananashed",
    TEST_PREFIX="bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, unlabeled=test=bananashed suffix
CCS_LR_LABEL_NORMAL_NOLABEL_BANANASHED_SUFFIX_TEST_BANANASHED_SUFFIX_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    LABELED_PREFIX="normal",  # labeled (LR)
    PREFIX="normal-bananashed",  # unlabeled (CCS)
    TEST_PREFIX="normal-bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, unlabeled=alice, test=alice
# NOTE: despite the "_TEST_ALL" suffix in the name, TEST_PROMPT_SUBSET is
# alice_explicit_opinion.
CCS_LR_LABEL_NORMAL_NOLABEL_ALICE_TEST_ALL_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    PREFIX="normal",  # unlabeled
    LABELED_PREFIX="normal",  # labeled
    TEST_PREFIX="normal",
    LABELED_PROMPT_SUBSET="default",
    PROMPT_SUBSET="alice_explicit_opinion",  # unlabeled
    TEST_PROMPT_SUBSET="alice_explicit_opinion",  # test on alice prompts
)

# Hyperparameter grid: every combination in the Cartesian product of the value
# lists becomes one config; uncomment entries to sweep them. Loss weights may
# also be piecewise-linear schedules given as (epoch, value) pairs.
hyperparams = dict(
    MODEL=[
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "mistralai/Mistral-7B-Instruct-v0.2",
        "meta-llama/Llama-2-13b-chat-hf",
        # "meta-llama/Llama-2-7b-chat-hf",
        # "openai-community/gpt2-xl",
        # "EleutherAI/gpt-j-6b",
    ],
    # LABELED_PREFIX=["normal"],
    # PREFIX=["normal-bananashed"],
    # TEST_PREFIX=["normal"],
    # PROMPT_SUBSET=["alice_explicit_opinion"],
    # PROJECTION_METHOD=["gaussian_random"],
    # PROJECTION_N_COMPONENTS=[10, 50],
    # WEIGHT_DECAY=[0, 1],
    # LR=[1e-3, 1e-2],
    # SUP_WEIGHT=[3, 10],
    # UNSUP_WEIGHT=[
    #     [(0, 0.0), (1000, 1.0)],
    #     [(0, 0.0), (10000, 1.0)],
    #     [(0, 0.0), (99, 0.0), (100, 1.0)],
    #     [(0, 0.0), (99, 0.0), (1000, 1.0)],
    # ],
)
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]


# Iterate over product of parameters; grid values override the chosen preset.
experiment_configs = []
for config in hyperparam_configs:
    # TODO: choose which default env vars to use.
    experiment_configs.append(
        dict(
            CCS_LR_LABEL_NORMAL_NOLABEL_ALICE_TEST_ALL_ENV_VARS,
            **config
        )
    )

# Every (labeled, unlabeled) pair from the product of these lists gets a run.
labeled_datasets = ALL_DATASETS
unlabeled_datasets = ALL_DATASETS
slurm_args = {}  # {"qos": "high"}

print_train_commands_for_experiments_all_dataset_pairs(
    experiment_configs,
    labeled_datasets=labeled_datasets,
    unlabeled_datasets=unlabeled_datasets,
    slurm=True,
    slurm_args=slurm_args,
)

CONFIDENCE_WEIGHT="0.1" CONSISTENCY_WEIGHT="1" DATASETS="imdb" EVAL_DATASETS="['imdb']" LABELED_DATASETS="imdb" LABELED_PREFIX="normal" LABELED_PROMPT_SUBSET="default" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MODEL="meta-llama/Meta-Llama-3-8B-Instruct" NAME="meta-llama/Meta-Llama-3-8B-Instruct/label_normal-nolabel_normal/test_normal/prompts_alice_explicit_opinion/test_prompts_alice_explicit_opinion/layer_-1/ccs_lr/mode_concat/sup_weight_3/unsup_weight_1/confidence_weight_0.1/lr_0.001/n_epochs_10000/weight_decay_10/opt_adam" NUM_SEEDS="1" N_EPOCHS="10000" N_TRIES="1" OPT="adam" PREFIX="normal" PROMPT_SUBSET="alice_explicit_opinion" SAVE_FIT_PLOTS="False" SAVE_FIT_RESULT="True" SAVE_ORTHOGONAL_DIRECTIONS="False" SAVE_PARAMS="False" SAVE_STATES="False" SPAN_DIRS_COMBINATION="convex" SUP_WEIGHT="3" TEST_PREFIX="normal" TEST_PROMPT_SUBSET="alice_explicit_opinion" UNSUP_WEIGHT="1" WEIGHT_DECAY="10" sbatch  slurm_extract.sh

CONFIDENCE_WEIGHT="0.1" CONSISTENCY_WEIGHT="1" DATAS

## Train = all except test ds

In [9]:
# Labeled datasets: all datasets except for the target domain, which is used for
# the CCS loss.

hyperparams = dict(
    MODEL=[
        "meta-llama/Meta-Llama-3-8B-Instruct",
        # "meta-llama/Llama-2-13b-chat-hf",
        # "mistralai/Mistral-7B-Instruct-v0.2",
        # "meta-llama/Llama-2-7b-chat-hf",
        # "openai-community/gpt2-xl",
        # "EleutherAI/gpt-j-6b",
    ],
)
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]


# Iterate over product of parameters; grid values override the chosen preset.
experiment_configs = []
for config in hyperparam_configs:
    # TODO: choose which default env vars to use.
    experiment_configs.append(
        dict(CCS_LR_LABEL_NORMAL_NOLABEL_BANANASHED_SUFFIX_TEST_BANANASHED_SUFFIX_ENV_VARS, **config)
    )

# labeled_datasets[i] holds every dataset except unlabeled_datasets[i], so the
# two lists are paired element-wise with zip (not a product).
labeled_datasets = [
    [ds for ds in ALL_DATASETS if ds != nolabel_ds] for nolabel_ds in ALL_DATASETS
]
unlabeled_datasets = ALL_DATASETS
slurm_args = {}  # {"qos": "high"}

env_vars_list = []
for experiment_config in experiment_configs:
    # These keys are filled in per dataset pair below.
    for key in ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS"):
        if key in experiment_config:
            raise ValueError(f"{key} should not be set in ds_experiment_config.")

    for labeled_ds, unlabeled_ds in zip(labeled_datasets, unlabeled_datasets):
        assert unlabeled_ds not in labeled_ds
        assert len(labeled_ds) == len(ALL_DATASETS) - 1
        # labeled_ds is a list of names and unlabeled_ds a single name here.
        all_train_ds = {unlabeled_ds, *labeled_ds}

        ds_experiment_config = experiment_config.copy()
        ds_experiment_config["DATASETS"] = unlabeled_ds
        ds_experiment_config["LABELED_DATASETS"] = labeled_ds
        # Sorted: set iteration order of strings varies between interpreter
        # runs (hash randomization), which made printed commands differ.
        ds_experiment_config["EVAL_DATASETS"] = f"{sorted(all_train_ds)}"
        env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

for env_vars in env_vars_list:
    print(
        make_train_command_for_experiment(env_vars, slurm=True, slurm_args=slurm_args),
        end="\n\n",
    )

CONFIDENCE_WEIGHT="1" CONSISTENCY_WEIGHT="1" DATASETS="imdb" EVAL_DATASETS="['ag-news', 'amazon-polarity', 'rte', 'boolq', 'copa', 'dbpedia-14', 'imdb', 'piqa', 'qnli']" LABELED_DATASETS="['amazon-polarity', 'ag-news', 'dbpedia-14', 'copa', 'rte', 'boolq', 'qnli', 'piqa']" LABELED_PREFIX="normal" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MODEL="meta-llama/Meta-Llama-3-8B-Instruct" NAME="meta-llama/Meta-Llama-3-8B-Instruct/label_normal-nolabel_normal-bananashed/layer_-1/ccs_lr/mode_concat/sup_weight_10/unsup_weight_1/lr_0.001/n_epochs_10000/weight_decay_10/opt_adam" NUM_SEEDS="1" N_EPOCHS="10000" N_TRIES="1" OPT="adam" PREFIX="normal-bananashed" SAVE_FIT_PLOTS="False" SAVE_FIT_RESULT="True" SAVE_ORTHOGONAL_DIRECTIONS="False" SAVE_PARAMS="False" SAVE_STATES="False" SPAN_DIRS_COMBINATION="convex" SUP_WEIGHT="10" TEST_PREFIX="normal-bananashed" TEST_PROMPT_SUBSET="default" UNSUP_WEIGHT="1" WEIGHT_DECAY="10" sbatch  slurm_extract.sh

CONFIDENCE_WEIGHT="1" CONSISTENCY_WEIGHT="

# NeurIPS experiments

## Constants

In [4]:
# Env-var presets for the NeurIPS runs. These definitions repeat the LR, CCS
# and CCS+LR presets from the earlier cells with the same values (presumably
# so this section can run on its own; keep the copies in sync).

# LR: the CCS+LR method with the unsupervised loss weight set to 0.
DEFAULT_LR_ENV_VARS = dict(
    METHOD_LIST="CCS+LR",
    MODE="concat",
    SUP_WEIGHT=1,
    UNSUP_WEIGHT=0,
    LR=1e-3,
    N_EPOCHS=5000,
    WEIGHT_DECAY=1,
    OPT="adam",
    LAYER=-1,
)

# labeled=test=normal
LR_NORMAL_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, test=bananashed prefix
LR_LABEL_NORMAL_TEST_BANANASHED_PREFIX_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, test=bananashed suffix
LR_LABEL_NORMAL_TEST_BANANASHED_SUFFIX_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",  # labeled (LR)
    TEST_PREFIX="normal-bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, test=alice
LR_LABEL_NORMAL_TEST_ALICE_ENV_VARS = dict(
    DEFAULT_LR_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    PROMPT_SUBSET="default",  # train on default prompts
    TEST_PROMPT_SUBSET="alice_explicit_opinion",  # test on alice prompts
)

# CCS: the CCS+LR method with the supervised loss weight set to 0.
DEFAULT_CCS_ENV_VARS = dict(
    METHOD_LIST="CCS+LR",
    MODE="concat",
    SUP_WEIGHT=0,
    UNSUP_WEIGHT=1,
    LR=1e-3,
    N_EPOCHS=5000,
    OPT="adam",
    LAYER=-1,
)

# unlabeled=test=normal
CCS_NORMAL_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    TEST_PROMPT_SUBSET="default",
)

# unlabeled=test=bananashed prefix
CCS_NOLABEL_BANANASHED_PREFIX_TEST_BANANASHED_PREFIX_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="bananashed",
    TEST_PREFIX="bananashed",
    TEST_PROMPT_SUBSET="default",
)

# unlabeled=test=bananashed suffix
CCS_NOLABEL_BANANASHED_SUFFIX_TEST_BANANASHED_SUFFIX_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="normal-bananashed",
    TEST_PREFIX="normal-bananashed",
    TEST_PROMPT_SUBSET="default",
)

# unlabeled=test=alice
CCS_NOLABEL_ALICE_TEST_ALL_ENV_VARS = dict(
    DEFAULT_CCS_ENV_VARS,
    PREFIX="normal",
    TEST_PREFIX="normal",
    PROMPT_SUBSET="alice_explicit_opinion",
    TEST_PROMPT_SUBSET="alice_explicit_opinion",
)

# CCS+LR: supervised loss on labeled data plus CCS loss on unlabeled data.
DEFAULT_CCS_LR_ENV_VARS = dict(
    METHOD_LIST="CCS+LR",
    MODE="concat",
    SUP_WEIGHT=3,
    UNSUP_WEIGHT=1,
    CONSISTENCY_WEIGHT=1,
    CONFIDENCE_WEIGHT=0.1,
    WEIGHT_DECAY=10,
    LR=1e-3,
    N_EPOCHS=10000,
    LAYER=-1,
    OPT="adam",
)

# labeled=unlabeled=test=normal
CCS_LR_NORMAL_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    LABELED_PREFIX="normal",
    PREFIX="normal",
    TEST_PREFIX="normal",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, unlabeled=test=bananashed prefix
CCS_LR_LABEL_NORMAL_NOLABEL_BANANASHED_PREFIX_TEST_BANANASHED_PREFIX_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    LABELED_PREFIX="normal",
    PREFIX="bananashed",
    TEST_PREFIX="bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, unlabeled=test=bananashed suffix
CCS_LR_LABEL_NORMAL_NOLABEL_BANANASHED_SUFFIX_TEST_BANANASHED_SUFFIX_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    LABELED_PREFIX="normal",  # labeled (LR)
    PREFIX="normal-bananashed",  # unlabeled (CCS)
    TEST_PREFIX="normal-bananashed",
    TEST_PROMPT_SUBSET="default",
)

# labeled=normal, unlabeled=alice, test=alice
CCS_LR_LABEL_NORMAL_NOLABEL_ALICE_TEST_ALL_ENV_VARS = dict(
    DEFAULT_CCS_LR_ENV_VARS,
    PREFIX="normal",  # unlabeled
    LABELED_PREFIX="normal",  # labeled
    TEST_PREFIX="normal",
    LABELED_PROMPT_SUBSET="default",
    PROMPT_SUBSET="alice_explicit_opinion",  # unlabeled
    TEST_PROMPT_SUBSET="alice_explicit_opinion",  # test on alice prompts
)

In [5]:
# Models run on every experiment setting.
MAIN_MODELS = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "meta-llama/Llama-2-13b-chat-hf",
]

# MAIN_MODELS followed by the additional models; a separate list object, so
# mutating one does not affect the other.
ALL_MODELS = MAIN_MODELS + [
    "meta-llama/Llama-2-7b-chat-hf",
    "openai-community/gpt2-xl",
    "EleutherAI/gpt-j-6b",
]

# Output directory passed as SAVE_DIR to the final experiments.
FINAL_SAVE_DIR = "extraction_results_final"

## Make experiment configs

In [6]:
experiment_configs = []

# CCS+LR: each entry pairs an env-var preset with the models it runs on, so a
# preset cannot drift out of step with its model list (as two parallel lists
# handed to zip() can).
for env_vars, models in (
    (CCS_LR_NORMAL_ENV_VARS, ALL_MODELS),
    (
        CCS_LR_LABEL_NORMAL_NOLABEL_BANANASHED_PREFIX_TEST_BANANASHED_PREFIX_ENV_VARS,
        MAIN_MODELS,
    ),
    (
        CCS_LR_LABEL_NORMAL_NOLABEL_BANANASHED_SUFFIX_TEST_BANANASHED_SUFFIX_ENV_VARS,
        MAIN_MODELS,
    ),
    (CCS_LR_LABEL_NORMAL_NOLABEL_ALICE_TEST_ALL_ENV_VARS, MAIN_MODELS),
):
    for model in models:
        # A fresh dict per run: the preset plus this model and the final save dir.
        experiment_configs.append(dict(env_vars, MODEL=model, SAVE_DIR=FINAL_SAVE_DIR))

print_train_commands_for_experiments_all_dataset_pairs(experiment_configs)

CONFIDENCE_WEIGHT="0.1" CONSISTENCY_WEIGHT="1" DATASETS="imdb" EVAL_DATASETS="['imdb']" LABELED_DATASETS="imdb" LABELED_PREFIX="normal" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MODEL="meta-llama/Meta-Llama-3-8B-Instruct" NAME="meta-llama/Meta-Llama-3-8B-Instruct/label_normal-nolabel_normal/test_normal/layer_-1/ccs_lr/mode_concat/sup_weight_3/unsup_weight_1/confidence_weight_0.1/lr_0.001/n_epochs_10000/weight_decay_10/opt_adam" NUM_SEEDS="1" N_EPOCHS="10000" N_TRIES="1" OPT="adam" PREFIX="normal" SAVE_DIR="extraction_results_final" SAVE_FIT_PLOTS="False" SAVE_FIT_RESULT="True" SAVE_ORTHOGONAL_DIRECTIONS="False" SAVE_PARAMS="False" SAVE_STATES="False" SPAN_DIRS_COMBINATION="convex" SUP_WEIGHT="3" TEST_PREFIX="normal" TEST_PROMPT_SUBSET="default" UNSUP_WEIGHT="1" WEIGHT_DECAY="10" sbatch  slurm_extract.sh

CONFIDENCE_WEIGHT="0.1" CONSISTENCY_WEIGHT="1" DATASETS="amazon-polarity" EVAL_DATASETS="['amazon-polarity', 'imdb']" LABELED_DATASETS="imdb" LABELED_PREFIX="normal" LAY

In [7]:
experiment_configs = []

# LR: each entry pairs an env-var preset with the models it runs on.
for env_vars, models in (
    (LR_NORMAL_ENV_VARS, ALL_MODELS),
    (LR_LABEL_NORMAL_TEST_BANANASHED_PREFIX_ENV_VARS, MAIN_MODELS),
    (LR_LABEL_NORMAL_TEST_BANANASHED_SUFFIX_ENV_VARS, MAIN_MODELS),
    (LR_LABEL_NORMAL_TEST_ALICE_ENV_VARS, MAIN_MODELS),
):
    for model in models:
        experiment_configs.append(dict(env_vars, MODEL=model, SAVE_DIR=FINAL_SAVE_DIR))

# CCS: same pairing scheme, appended after the LR runs.
for env_vars, models in (
    (CCS_NORMAL_ENV_VARS, ALL_MODELS),
    (CCS_NOLABEL_BANANASHED_PREFIX_TEST_BANANASHED_PREFIX_ENV_VARS, MAIN_MODELS),
    (CCS_NOLABEL_BANANASHED_SUFFIX_TEST_BANANASHED_SUFFIX_ENV_VARS, MAIN_MODELS),
    (CCS_NOLABEL_ALICE_TEST_ALL_ENV_VARS, MAIN_MODELS),
):
    for model in models:
        experiment_configs.append(dict(env_vars, MODEL=model, SAVE_DIR=FINAL_SAVE_DIR))


print_train_commands_for_experiments_single_datasets(experiment_configs)

DATASETS="imdb" EVAL_DATASETS="['imdb', 'amazon-polarity', 'ag-news', 'dbpedia-14', 'copa', 'rte', 'boolq', 'qnli', 'piqa']" LABELED_DATASETS="imdb" LAYER="-1" LR="0.001" METHOD_LIST="CCS+LR" MODE="concat" MODEL="meta-llama/Meta-Llama-3-8B-Instruct" NAME="meta-llama/Meta-Llama-3-8B-Instruct/normal/test_normal/layer_-1/ccs_lr/mode_concat/sup_weight_1/unsup_weight_0/lr_0.001/n_epochs_5000/weight_decay_1/opt_adam" NUM_SEEDS="1" N_EPOCHS="5000" N_TRIES="1" OPT="adam" PREFIX="normal" SAVE_DIR="extraction_results_final" SAVE_FIT_PLOTS="False" SAVE_FIT_RESULT="True" SAVE_ORTHOGONAL_DIRECTIONS="False" SAVE_PARAMS="False" SAVE_STATES="False" SPAN_DIRS_COMBINATION="convex" SUP_WEIGHT="1" TEST_PREFIX="normal" TEST_PROMPT_SUBSET="default" UNSUP_WEIGHT="0" WEIGHT_DECAY="1" sbatch  slurm_extract.sh

DATASETS="amazon-polarity" EVAL_DATASETS="['imdb', 'amazon-polarity', 'ag-news', 'dbpedia-14', 'copa', 'rte', 'boolq', 'qnli', 'piqa']" LABELED_DATASETS="amazon-polarity" LAYER="-1" LR="0.001" METHOD_LIS

# Pseudo-label + LR

In [None]:
# Base settings for pseudo-labeling on top of plain LR (supervised loss only:
# SUP_WEIGHT=1, UNSUP_WEIGHT=0). The PSEUDOLABEL_* keys configure the rounds.
DEFAULT_PSEUDOLABEL_ENV_VARS = {
    "METHOD_LIST": "pseudolabel",
    "MODE": "concat",
    "SUP_WEIGHT": 1,
    "UNSUP_WEIGHT": 0,
    "LR": 1e-3,
    "N_EPOCHS": 5000,
    # Pseudolabel
    "PSEUDOLABEL_N_ROUNDS": 5,
    "PSEUDOLABEL_SELECT_FN": "high_confidence_consistency",
    "PSEUDOLABEL_PROB_THRESHOLD": 0.7,
    "PSEUDOLABEL_LABEL_FN": "argmax",
}

## different select_fn and label_fn

In [None]:
hyperparams = dict(
    MODEL=[
        # Keep the trailing comma: without it, uncommenting one of the lines
        # below would silently concatenate the adjacent string literals into
        # a single bogus model name.
        "meta-llama/Meta-Llama-3-8B-Instruct",
        # "meta-llama/Llama-2-13b-chat-hf",
        # "mistralai/Mistral-7B-Instruct-v0.2",
        # "meta-llama/Llama-2-7b-chat-hf",
    ],
    LABELED_PREFIX=["normal"],
    PREFIX=["normal-bananashed"],
    TEST_PREFIX=["normal-bananashed"],
    OPT=["adam"],
    LR=[1e-3],
    WEIGHT_DECAY=[1],
    PSEUDOLABEL_SELECT_FN=["high_confidence_consistency"],
    PSEUDOLABEL_PROB_THRESHOLD=[0.5],
    PSEUDOLABEL_CONSISTENCY_ERR_THRESHOLD=[0.1],
    PSEUDOLABEL_LABEL_FN=["softmax"],
    PSEUDOLABEL_SOFTMAX_TEMP=[0.1, 0.3],
)
# Expand the grid: one config (name -> value) per element of the Cartesian
# product of the value lists above.
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]

# Each experiment is the pseudo-label defaults overridden by one grid point.
experiment_configs = []
for config in hyperparam_configs:
    experiment_configs.append(dict(DEFAULT_PSEUDOLABEL_ENV_VARS, **config))

print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm=True)

## select_fn=high_confidence_consistency label_fn=argmax

In [None]:
hyperparams = dict(
    MODEL=[
        # Keep the trailing comma: without it, uncommenting one of the lines
        # below would silently concatenate the adjacent string literals into
        # a single bogus model name.
        "meta-llama/Meta-Llama-3-8B-Instruct",
        # "meta-llama/Llama-2-13b-chat-hf",
        # "mistralai/Mistral-7B-Instruct-v0.2",
        # "meta-llama/Llama-2-7b-chat-hf",
    ],
    LABELED_PREFIX=["normal"],
    PREFIX=["normal-bananashed"],
    TEST_PREFIX=["normal-bananashed"],
    OPT=["adam"],
    LR=[1e-3],
    WEIGHT_DECAY=[1],
    PSEUDOLABEL_PROB_THRESHOLD=[0.5],
)
# Expand the grid: one config (name -> value) per element of the Cartesian
# product of the value lists above.
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]

# Each experiment is the pseudo-label defaults overridden by one grid point.
experiment_configs = []
for config in hyperparam_configs:
    experiment_configs.append(dict(DEFAULT_PSEUDOLABEL_ENV_VARS, **config))

print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm=True)

## select_fn=all label_fn=argmax

In [None]:
# Pseudo-label + LR with a single round using the "all" select function.
DEFAULT_PSEUDOLABEL_SELECT_ALL_LABEL_ARGMAX_ENV_VARS = dict(
    DEFAULT_PSEUDOLABEL_ENV_VARS, PSEUDOLABEL_N_ROUNDS=1, PSEUDOLABEL_SELECT_FN="all"
)

hyperparams = dict(
    MODEL=[
        # Keep the trailing comma: without it, uncommenting one of the lines
        # below would silently concatenate the adjacent string literals into
        # a single bogus model name.
        "meta-llama/Meta-Llama-3-8B-Instruct",
        # "meta-llama/Llama-2-13b-chat-hf",
        # "mistralai/Mistral-7B-Instruct-v0.2",
        # "meta-llama/Llama-2-7b-chat-hf",
    ],
    LABELED_PREFIX=["normal"],
    PREFIX=["normal-bananashed"],
    TEST_PREFIX=["normal-bananashed"],
    OPT=["adam"],
    LR=[1e-3],
    WEIGHT_DECAY=[1],
)
# Expand the grid: one config (name -> value) per element of the Cartesian
# product of the value lists above.
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]

# Each experiment is the select-all defaults overridden by one grid point.
experiment_configs = []
for config in hyperparam_configs:
    experiment_configs.append(
        dict(DEFAULT_PSEUDOLABEL_SELECT_ALL_LABEL_ARGMAX_ENV_VARS, **config)
    )

print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm=True)

# Pseudo-label + CCS+LR

In [None]:
# Base settings for pseudo-labeling on top of CCS+LR: both the supervised
# (SUP_WEIGHT=10) and unsupervised (UNSUP_WEIGHT=1) losses are active.
DEFAULT_PSEUDOLABEL_CCS_LR_ENV_VARS = {
    "METHOD_LIST": "pseudolabel",
    "MODE": "concat",
    "SUP_WEIGHT": 10,
    "UNSUP_WEIGHT": 1,
    "LR": 1e-3,
    "N_EPOCHS": 10000,
    # Pseudolabel
    "PSEUDOLABEL_N_ROUNDS": 5,
    "PSEUDOLABEL_SELECT_FN": "high_confidence_consistency",
    "PSEUDOLABEL_PROB_THRESHOLD": 0.7,
    "PSEUDOLABEL_LABEL_FN": "argmax",
}

## different select_fn and label_fn

In [None]:
hyperparams = dict(
    MODEL=[
        # Keep the trailing comma: without it, uncommenting one of the lines
        # below would silently concatenate the adjacent string literals into
        # a single bogus model name.
        "meta-llama/Meta-Llama-3-8B-Instruct",
        # "meta-llama/Llama-2-13b-chat-hf",
        # "mistralai/Mistral-7B-Instruct-v0.2",
        # "meta-llama/Llama-2-7b-chat-hf",
    ],
    LABELED_PREFIX=["normal"],
    PREFIX=["normal-bananashed"],
    TEST_PREFIX=["normal-bananashed"],
    OPT=["adam"],
    N_EPOCHS=[10000],  # Same as the default; lower (e.g. 5k) for faster runs.
    LR=[1e-3],
    WEIGHT_DECAY=[10],
    PSEUDOLABEL_N_ROUNDS=[1],
    PSEUDOLABEL_SELECT_FN=["all"],
    # PSEUDOLABEL_PROB_THRESHOLD=[0.5],
    # PSEUDOLABEL_CONSISTENCY_ERR_THRESHOLD=[0.1],
    PSEUDOLABEL_LABEL_FN=["argmax"],
    # PSEUDOLABEL_SOFTMAX_TEMP=[0.3],
)
# Expand the grid: one config (name -> value) per element of the Cartesian
# product of the value lists above.
hyperparam_configs = [
    dict(zip(hyperparams, values))
    for values in itertools.product(*hyperparams.values())
]

# Each experiment is the pseudo-label CCS+LR defaults overridden by one grid point.
experiment_configs = []
for config in hyperparam_configs:
    experiment_configs.append(dict(DEFAULT_PSEUDOLABEL_CCS_LR_ENV_VARS, **config))

print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm=True)

## select_fn=high_confidence_consistency label_fn=argmax

In [None]:
# Sweep (model, layer, prefix) with the default pseudo-label CCS+LR settings.
prefixes = ["normal-bananashed"]
models = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    # "meta-llama/Llama-2-13b-chat-hf",
    # "mistralai/Mistral-7B-Instruct-v0.2",
    # "meta-llama/Llama-2-7b-chat-hf",
]
layers = [-1]

# Nested loops visit combinations in the same order as
# itertools.product(models, layers, prefixes).
experiment_configs = []
for model in models:
    for layer in layers:
        for prefix in prefixes:
            experiment_configs.append(
                dict(
                    DEFAULT_PSEUDOLABEL_CCS_LR_ENV_VARS,
                    MODEL=model,
                    LAYER=layer,
                    PREFIX=prefix,
                )
            )

print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm=True)

In [None]:
# # Only for experiment_configs
# env_vars_list = []
# for experiment_config in experiment_configs:
#     env_vars_list.append(make_env_vars_for_experiment(experiment_config))

# for env_vars in env_vars_list:
#     print(make_train_command_for_experiment(env_vars), end="\n\n")

## select_fn=all label_fn=argmax

In [None]:
# Pseudo-label CCS+LR with a single round using the "all" select function.
# dict(base, **overrides) copies the defaults and overrides two keys in place,
# matching the copy-then-update it replaces.
DEFAULT_PSEUDOLABEL_CCS_LR_SELECT_ALL_LABEL_ARGMAX_ENV_VARS = dict(
    DEFAULT_PSEUDOLABEL_CCS_LR_ENV_VARS,
    PSEUDOLABEL_N_ROUNDS=1,
    PSEUDOLABEL_SELECT_FN="all",
)

prefixes = ["bananashed"]
models = [
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "meta-llama/Llama-2-7b-chat-hf",
]
layers = [-1, -5, -9]

# Nested loops visit combinations in the same order as
# itertools.product(models, layers, prefixes).
experiment_configs = []
for model in models:
    for layer in layers:
        for prefix in prefixes:
            experiment_configs.append(
                dict(
                    DEFAULT_PSEUDOLABEL_CCS_LR_SELECT_ALL_LABEL_ARGMAX_ENV_VARS,
                    MODEL=model,
                    LAYER=layer,
                    PREFIX=prefix,
                )
            )

print_train_commands_for_experiments_all_dataset_pairs(experiment_configs, slurm=True)

# CCS in LR span

## Train orthogonal probes (LR span)

In [None]:
# Supervised-only (LR) runs that save NUM_ORTHOGONAL_DIRECTIONS orthogonal
# probe directions per dataset, later loaded as the "LR span".
DEFAULT_LR_SPAN_ENV_VARS = dict(
    METHOD_LIST="CCS+LR-in-span",
    MODE="concat",
    SUP_WEIGHT=1,
    UNSUP_WEIGHT=0,
    LR=1e-2,
    N_EPOCHS=5000,
    NUM_ORTHOGONAL_DIRECTIONS=100,
    SPAN_DIRS_COMBINATION="convex",
)

prefixes = ["normal"]
models = [
    # "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    # "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-1, -3, -5, -7, -9]

# Iterate over product of parameters.
experiment_configs = []
for model, layer, prefix in itertools.product(models, layers, prefixes):
    experiment_configs.append(
        dict(
            DEFAULT_LR_SPAN_ENV_VARS,
            MODEL=model,
            LAYER=layer,
            PREFIX=prefix,
            SAVE_ORTHOGONAL_DIRECTIONS=True,
            SAVE_FIT_PLOTS=False,
        )
    )

# Make a train command for each individual dataset for each experiment config.
env_vars_list = []
for experiment_config in experiment_configs:
    # These keys are filled in per dataset below, so the base config must not
    # already set them. Checked once per config instead of once per dataset.
    for key in ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS"):
        if key in experiment_config:
            raise ValueError(f"{key} should not be set in ds_experiment_config.")

    for ds in ALL_DATASETS:
        ds_experiment_config = dict(
            experiment_config,
            # With UNSUP_WEIGHT=0 only the labeled data is used, so DATASETS is
            # effectively unused; setting both to ds keeps the saved directory
            # name reproducible from a single dataset.
            DATASETS=ds,
            LABELED_DATASETS=ds,
            # Only the span directions are needed, but also evaluate on "burns"
            # (the file's default EVAL_DATASETS; presumably the full dataset
            # collection) so the results are available.
            EVAL_DATASETS="burns",
        )
        env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

slurm = True
for env_vars in env_vars_list:
    print(make_train_command_for_experiment(env_vars, slurm=slurm), end="\n\n")

## Train CCS in LR span

In [None]:
# Unsupervised-only (CCS) runs constrained to the span of previously saved
# orthogonal LR directions.
DEFAULT_CCS_IN_LR_SPAN_ENV_VARS = {
    "METHOD_LIST": "CCS+LR-in-span",
    "MODE": "concat",
    "SUP_WEIGHT": 0,
    "UNSUP_WEIGHT": 1,
    "LR": 1e-2,
    "N_EPOCHS": 1000,
    "SPAN_DIRS_COMBINATION": "convex",
}

# Per-model path templates ({prefix} and {layer} placeholders) for the saved
# orthogonal LR directions. The path components (sup_weight_1, unsup_weight_0,
# lr_0.01, n_epochs_5000, n_orth_dirs_100) mirror DEFAULT_LR_SPAN_ENV_VARS.
model_to_orthogonal_lr_directions_dir_template = {
    "meta-llama/Llama-2-13b-chat-hf": "/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/llama-2-13b-chat-hf/{prefix}/layer_{layer}/ccs_lr_in_span/mode_concat/sup_weight_1/unsup_weight_0/lr_0.01/n_epochs_5000/n_orth_dirs_100/span_dirs_combo_convex/meta-llama-Llama-2-13b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2": "/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/mistralai/Mistral-7B-Instruct-v0.2/{prefix}/layer_{layer}/ccs_lr_in_span/mode_concat/sup_weight_1/unsup_weight_0/lr_0.01/n_epochs_5000/n_orth_dirs_100/span_dirs_combo_convex/mistralai-Mistral-7B-Instruct-v0.2",
}


def get_orthogonal_lr_directions_dir(
    model: str, layer: int, prefix: str, dataset: str
) -> str:
    """Return the directory of saved orthogonal LR directions for one setting.

    Fills ``prefix`` and ``layer`` into the per-model path template from
    ``model_to_orthogonal_lr_directions_dir_template`` and appends the
    combined-datasets subdirectory for ``dataset``.

    Args:
        model: Model name; must be a key of
            ``model_to_orthogonal_lr_directions_dir_template``.
        layer: Layer index substituted into the ``layer_{layer}`` path part.
        prefix: Prompt prefix substituted into the ``{prefix}`` path part.
        dataset: Dataset the LR span was trained on.

    Returns:
        Path of the existing directory holding the orthogonal directions.

    Raises:
        KeyError: If ``model`` has no registered path template.
        ValueError: If the resolved directory does not exist.
    """
    try:
        base_dir_template = model_to_orthogonal_lr_directions_dir_template[model]
    except KeyError:
        known_models = sorted(model_to_orthogonal_lr_directions_dir_template)
        raise KeyError(
            f"No orthogonal LR directions template for model {model!r}; "
            f"known models: {known_models}"
        ) from None
    base_dir = base_dir_template.format(prefix=prefix, layer=layer)
    # The span-generation runs set both the unlabeled and labeled datasets to
    # the same dataset, so pass it for both to reproduce the saved subdirectory.
    datasets_str = load_utils.get_combined_datasets_str(
        dataset, labeled_datasets=dataset
    )
    load_orthogonal_directions_dir = os.path.join(base_dir, datasets_str)
    if not os.path.exists(load_orthogonal_directions_dir):
        raise ValueError(
            f"Could not find orthogonal directions directory: {load_orthogonal_directions_dir}"
        )

    return load_orthogonal_directions_dir


prefixes = ["normal"]
models = [
    # "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    # "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-1, -5, -9]
num_orth_dirs_list = [2, 4, 8, 16]

# For every (train_ds, test_ds) pair, train CCS on test_ds inside the span of
# the orthogonal LR directions learned on train_ds, and print its command.
experiment_configs = []
for model, layer, prefix, num_orth_dirs in itertools.product(
    models, layers, prefixes, num_orth_dirs_list
):
    # Resolve (and validate) every dataset's LR-span directory once per
    # setting: a missing directory fails before any command for this setting
    # is printed, and the filesystem is not re-checked for every test_ds.
    lr_dirs_by_dataset = {
        ds: get_orthogonal_lr_directions_dir(model, layer, prefix, ds)
        for ds in ALL_DATASETS
    }
    for train_ds, test_ds in itertools.product(ALL_DATASETS, ALL_DATASETS):
        # Deduplicate (when train_ds == test_ds) with a deterministic order;
        # list(set(...)) order depends on per-process string hash randomization.
        eval_datasets = list(dict.fromkeys([train_ds, test_ds]))
        experiment_config = dict(
            DEFAULT_CCS_IN_LR_SPAN_ENV_VARS,
            MODEL=model,
            LAYER=layer,
            PREFIX=prefix,
            DATASETS=test_ds,
            LABELED_DATASETS=test_ds,  # Unused because the supervised weight is 0.
            # NOTE(review): the value is wrapped in literal double quotes; confirm
            # this is what make_train_command_for_experiment expects.
            EVAL_DATASETS=f'"{eval_datasets}"',
            LOAD_ORTHOGONAL_DIRECTIONS_DIR=lr_dirs_by_dataset[train_ds],
            NUM_ORTHOGONAL_DIRECTIONS=num_orth_dirs,
        )
        experiment_configs.append(experiment_config)
        env_vars = make_env_vars_for_experiment(experiment_config)
        print(make_train_command_for_experiment(env_vars, slurm=True), end="\n\n")

# LR in CCS span

## Train orthogonal probes (CCS span)

In [None]:
# Unsupervised-only (CCS) runs that save NUM_ORTHOGONAL_DIRECTIONS orthogonal
# probe directions per dataset, later loaded as the "CCS span".
DEFAULT_CCS_SPAN_ENV_VARS = dict(
    METHOD_LIST="CCS+LR-in-span",
    MODE="concat",
    SUP_WEIGHT=0,
    UNSUP_WEIGHT=1,
    LR=1e-2,
    N_EPOCHS=1000,
    NUM_ORTHOGONAL_DIRECTIONS=100,
    SPAN_DIRS_COMBINATION="convex",
)

prefixes = ["normal"]
models = [
    # "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    # "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-1, -3, -5, -7, -9]

# Iterate over product of parameters.
experiment_configs = []
for model, layer, prefix in itertools.product(models, layers, prefixes):
    experiment_configs.append(
        dict(
            DEFAULT_CCS_SPAN_ENV_VARS,
            MODEL=model,
            LAYER=layer,
            PREFIX=prefix,
            SAVE_ORTHOGONAL_DIRECTIONS=True,
            SAVE_FIT_PLOTS=False,
        )
    )

# Make a train command for each individual dataset for each experiment config.
env_vars_list = []
for experiment_config in experiment_configs:
    # These keys are filled in per dataset below, so the base config must not
    # already set them. Checked once per config instead of once per dataset.
    for key in ("DATASETS", "LABELED_DATASETS", "EVAL_DATASETS"):
        if key in experiment_config:
            raise ValueError(f"{key} should not be set in ds_experiment_config.")

    for ds in ALL_DATASETS:
        ds_experiment_config = dict(
            experiment_config,
            DATASETS=ds,
            # With SUP_WEIGHT=0 (pure CCS) LABELED_DATASETS is effectively
            # unused; setting it to ds keeps the saved directory name
            # reproducible from a single dataset.
            LABELED_DATASETS=ds,
            # Only the CCS span directions are needed, but also evaluate on
            # "burns" (the file's default EVAL_DATASETS; presumably the full
            # dataset collection) so the results are available.
            EVAL_DATASETS="burns",
        )
        env_vars_list.append(make_env_vars_for_experiment(ds_experiment_config))

slurm = True
for env_vars in env_vars_list:
    print(make_train_command_for_experiment(env_vars, slurm=slurm), end="\n\n")

## Train LR in CCS span

In [None]:
# Supervised-only (LR) runs constrained to the span of previously saved
# orthogonal CCS directions.
DEFAULT_LR_IN_CCS_SPAN_ENV_VARS = {
    "METHOD_LIST": "CCS+LR-in-span",
    "MODE": "concat",
    "SUP_WEIGHT": 1,
    "UNSUP_WEIGHT": 0,
    "LR": 1e-2,
    "N_EPOCHS": 5000,
    "SPAN_DIRS_COMBINATION": "convex",
}

# Per-model path templates ({prefix} and {layer} placeholders) for the saved
# orthogonal CCS directions. The path components (sup_weight_0, unsup_weight_1,
# lr_0.01, n_epochs_1000, n_orth_dirs_100) mirror DEFAULT_CCS_SPAN_ENV_VARS.
model_to_orthogonal_ccs_directions_dir_template = {
    "meta-llama/Llama-2-13b-chat-hf": "/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/llama-2-13b-chat-hf/{prefix}/layer_{layer}/ccs_lr_in_span/mode_concat/sup_weight_0/unsup_weight_1/lr_0.01/n_epochs_1000/n_orth_dirs_100/span_dirs_combo_convex/meta-llama-Llama-2-13b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2": "/nas/ucb/ebronstein/Exhaustive-CCS/extraction_results/mistralai/Mistral-7B-Instruct-v0.2/{prefix}/layer_{layer}/ccs_lr_in_span/mode_concat/sup_weight_0/unsup_weight_1/lr_0.01/n_epochs_1000/n_orth_dirs_100/span_dirs_combo_convex/mistralai-Mistral-7B-Instruct-v0.2",
}


def get_orthogonal_ccs_directions_dir(
    model: str, layer: int, prefix: str, dataset: str
) -> str:
    """Return the directory of saved orthogonal CCS directions for one setting.

    Fills ``prefix`` and ``layer`` into the per-model path template from
    ``model_to_orthogonal_ccs_directions_dir_template`` and appends the
    combined-datasets subdirectory for ``dataset``.

    Args:
        model: Model name; must be a key of
            ``model_to_orthogonal_ccs_directions_dir_template``.
        layer: Layer index substituted into the ``layer_{layer}`` path part.
        prefix: Prompt prefix substituted into the ``{prefix}`` path part.
        dataset: Dataset the CCS span was trained on.

    Returns:
        Path of the existing directory holding the orthogonal directions.

    Raises:
        KeyError: If ``model`` has no registered path template.
        ValueError: If the resolved directory does not exist.
    """
    try:
        base_dir_template = model_to_orthogonal_ccs_directions_dir_template[model]
    except KeyError:
        known_models = sorted(model_to_orthogonal_ccs_directions_dir_template)
        raise KeyError(
            f"No orthogonal CCS directions template for model {model!r}; "
            f"known models: {known_models}"
        ) from None
    base_dir = base_dir_template.format(prefix=prefix, layer=layer)
    # The span-generation runs set both the unlabeled and labeled datasets to
    # the same dataset (only the unlabeled one is used by CCS), so pass it for
    # both to reproduce the saved subdirectory.
    datasets_str = load_utils.get_combined_datasets_str(
        dataset, labeled_datasets=dataset
    )
    load_orthogonal_directions_dir = os.path.join(base_dir, datasets_str)
    if not os.path.exists(load_orthogonal_directions_dir):
        raise ValueError(
            f"Could not find orthogonal directions directory: {load_orthogonal_directions_dir}"
        )

    return load_orthogonal_directions_dir


prefixes = ["normal"]
models = [
    # "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    # "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.2",
]
layers = [-1, -5, -9]
num_orth_dirs_list = [2, 4, 8, 16]

# For every (train_ds, test_ds) pair, train LR on train_ds inside the span of
# the orthogonal CCS directions learned on test_ds, and print its command.
experiment_configs = []
for model, layer, prefix, num_orth_dirs in itertools.product(
    models, layers, prefixes, num_orth_dirs_list
):
    # Resolve (and validate) every dataset's CCS-span directory once per
    # setting: a missing directory fails before any command for this setting
    # is printed, and the filesystem is not re-checked for every train_ds.
    ccs_dirs_by_dataset = {
        ds: get_orthogonal_ccs_directions_dir(model, layer, prefix, ds)
        for ds in ALL_DATASETS
    }
    for train_ds, test_ds in itertools.product(ALL_DATASETS, ALL_DATASETS):
        # Deduplicate (when train_ds == test_ds) with a deterministic order;
        # list(set(...)) order depends on per-process string hash randomization.
        eval_datasets = list(dict.fromkeys([train_ds, test_ds]))
        experiment_config = dict(
            DEFAULT_LR_IN_CCS_SPAN_ENV_VARS,
            MODEL=model,
            LAYER=layer,
            PREFIX=prefix,
            DATASETS=train_ds,  # Unused because the unsupervised weight is 0.
            LABELED_DATASETS=train_ds,
            # NOTE(review): the value is wrapped in literal double quotes; confirm
            # this is what make_train_command_for_experiment expects.
            EVAL_DATASETS=f'"{eval_datasets}"',
            LOAD_ORTHOGONAL_DIRECTIONS_DIR=ccs_dirs_by_dataset[test_ds],
            NUM_ORTHOGONAL_DIRECTIONS=num_orth_dirs,
        )
        experiment_configs.append(experiment_config)
        env_vars = make_env_vars_for_experiment(experiment_config)
        print(make_train_command_for_experiment(env_vars, slurm=True), end="\n\n")