In [1]:
# import gpt3
import logging
import math
import os
from typing import List, Dict, Any, NewType

InputDataClass = NewType("InputDataClass", Any)
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from transformers import (
    T5Config,
    T5ForConditionalGeneration,
    T5Tokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    EarlyStoppingCallback
)
from transformers.trainer_utils import EvaluationStrategy
from transformers.integrations import TensorBoardCallback
import transformers
from transformers import Trainer

from feature_conversion_methods import format_instance

from custom_args import (
    DataTrainingArguments,
    ModelArguments
)
from metrics import evaluate
import torch
import datasets
import git
import time
from datetime import datetime
import sys
from tqdm import trange
import random 
import pandas as pd 
import jsonlines
from copy import deepcopy 

# Notebook-level logger, created from the stdlib `logging` module imported above.
logger = logging.getLogger(__name__)
# Switch the transformers library's own logging to its "info" verbosity setting.
transformers.logging.set_verbosity_info()
import re
def set_global_logging_level(level=logging.ERROR, prefices=[""]):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
          Default is `[""]` to match all active loggers. An empty list is treated the same way.
          The match is a case-sensitive `module_name.startswith(prefix)`; prefixes are taken
          literally, so characters such as "." have no special meaning.
    """
    # Literal prefix matching: building a regex from the raw prefixes let "." and
    # other metacharacters match unintended logger names (e.g. "a.b" matched "axb").
    # The default list is never mutated, so sharing it between calls is safe.
    prefixes = tuple(prefices) or ("",)

    # Iterate over a snapshot: the registry may gain entries while levels are being set.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(prefixes):
            logging.getLogger(name).setLevel(level)
# Raise the level of every logger whose name starts with "datasets" to ERROR.
# Only loggers that are already registered when this line runs are affected.
set_global_logging_level(logging.ERROR, ["datasets"])


# Lookup tables from a model-family key to the transformers classes used for it.
# "t5" is the only family listed.
CONFIG_MAPPING = {"t5": T5Config}
MODEL_MAPPING = {"t5": T5ForConditionalGeneration}
TOKENIZER_MAPPING = {"t5": T5Tokenizer}


def set_other_seeds(seed):
    """Turn off cuDNN benchmark mode and export `seed` as the PYTHONHASHSEED variable.

    NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up,
    so setting it here would affect child processes rather than this kernel — confirm.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    #cudnn.deterministic = True
    os.environ.update(PYTHONHASHSEED=str(seed))

# inspired by DefaultDataCollator from:
# https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py
# modified to perform batch-level padding.
class SequenceCollator:
    """Collate a list of feature dicts into one dict of `torch.long` tensors.

    Only the keys listed in `self.columns` are kept. List-valued features are
    right-padded to the longest entry in the batch using the per-key fill value
    from `self.pad_token_mapping`.
    """

    def __init__(self, model, pad_token):
        # `model` is stored but not used by the collation logic in this class.
        self.model = model
        # Fill value used when padding each supported column.
        self.pad_token_mapping = {
            "labels": -100,
            "attention_mask": 0,
            "decoder_attention_mask": 0,
            "input_ids": pad_token,
        }

        # Features that are forwarded to the batch; anything else is dropped.
        self.columns = [
            "input_ids",
            "attention_mask",
            "labels",
            "decoder_attention_mask",
        ]

    def __call__(self, examples: List[Dict[str, InputDataClass]]) -> Dict[str, torch.Tensor]:
        """Build the batch; the key set is taken from the first example."""
        batch = {}
        for key in examples[0].keys():
            if key not in self.columns:
                continue

            values = [example[key] for example in examples]

            # Batch-level padding: only needed when the feature is a list.
            if isinstance(values[0], list):
                longest = max(len(seq) for seq in values)
                filler = self.pad_token_mapping[key]
                values = [seq + [filler] * (longest - len(seq)) for seq in values]

            batch[key] = torch.tensor(values, dtype=torch.long)
        return batch


In [2]:
og_start_time = time.time()

parser = HfArgumentParser(
    (ModelArguments, DataTrainingArguments, TrainingArguments)
)

# The run configuration is hard-coded because a notebook kernel has no meaningful
# command line. It is kept in a named list so the exact arguments can be recorded
# next to the outputs further down.
cli_args = [
    "--model_type", "t5-base",
    "--tokenizer_name", "t5-base",
    "--task_name", "cos_e",
    "--output_dir", "./cos_e_output",
    "--n_shots", "10",
    "--do_train", "True",
]

model_args, data_args, training_args, unused_args = parser.parse_args_into_dataclasses(
    cli_args, return_remaining_strings=True)
if unused_args != []:
    raise ValueError(f"Received unused arguments: {unused_args}")

# Sanity checks on the argument combination used for GPT-3 runs.
if model_args.use_gpt3:
    assert training_args.do_train
    assert not training_args.do_eval
    assert data_args.generations_filepath is None
    if data_args.gpt3_max_eval_size is not None:
        assert data_args.gpt3_max_eval_size <= data_args.fewshot_eval_size
        assert data_args.gpt3_max_eval_size % 2 == 0
        assert data_args.gpt3_max_eval_size % 3 == 0

# When an evaluation file is specified manually, make sure exactly one dataset
# split is selected; the split is taken from the file path.
if data_args.generations_filepath is not None:
    training_args.do_train = False
    training_args.do_eval = False
    if "train" in data_args.generations_filepath:
        data_args.train_predict = True
        data_args.test_predict = False
        data_args.dev_predict = False
    elif "test" in data_args.generations_filepath:
        data_args.train_predict = False
        data_args.test_predict = True
        data_args.dev_predict = False
    elif "validation" in data_args.generations_filepath:
        data_args.train_predict = False
        data_args.test_predict = False
        data_args.dev_predict = True

if not training_args.do_train and data_args.generations_filepath is None:
    if not model_args.pretrained_model_file:
        raise Exception(
            "if not training a model from scratch, must specify a trained model to load for evaluation"
        )

if training_args.do_train:
    # create a save directory and a logfile
    training_args.output_dir = os.path.join(
        training_args.output_dir, datetime.now().strftime("%m%d%y_%H%M%S")
    )
    training_args.logging_dir = training_args.output_dir
    # The timestamped directory must be new. Because it is created empty right
    # here, no separate "exists and is not empty" check is needed afterwards.
    assert not os.path.exists(training_args.output_dir)
    os.makedirs(training_args.output_dir)

    handlers = [
        logging.FileHandler(os.path.join(training_args.output_dir, "logger.log")),
        logging.StreamHandler(),
    ]
else:
    # don't overwrite existing logfile or create new directory
    training_args.output_dir = model_args.pretrained_model_file
    handlers = [logging.StreamHandler()]

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    handlers=handlers,
)
logger.warning(
    "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
    training_args.local_rank,
    training_args.device,
    training_args.n_gpu,
    bool(training_args.local_rank != -1),
    training_args.fp16,
)
logger.info("Save path: %s", training_args.output_dir)

# get git hash and branch where deployed
repo = git.Repo(search_parent_directories=True)
git_hash = repo.head.object.hexsha
git_branch = repo.active_branch.name
logger.info("Git branch: %s", git_branch)
logger.info("Git hash: %s", git_hash)

model_class = "t5"
assert data_args.task_name in {"cos_e", "esnli", "sbic", "sensemaking", "ecqa"}

if training_args.do_train:
    # write command and args to file
    with open(
            os.path.join(training_args.output_dir, "commandline_args.txt"), "w"
    ) as f:
        f.write("Git branch: " + git_branch + "\n")
        f.write("Git hash: " + git_hash + "\n")
        f.write("Command:\n")
        # Record the arguments that were actually parsed above. `sys.argv` would
        # only contain the kernel launcher's arguments when run inside a notebook.
        f.write("\n".join(cli_args))

# Set seed
set_seed(training_args.seed)
set_other_seeds(training_args.seed)

# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
08/25/2022 17:39:46 - INFO - __main__ -   Save path: ./cos_e_output/082522_173946
08/25/2022 17:39:46 - INFO - __main__ -   Git branch: dev
08/25/2022 17:39:46 - INFO - __main__ -   Git hash: b3e471e4130d052883f821ed4b4dd50e701fd4c6


# tokenizer and model

In [3]:
import logging

# Bind `logger` to a stdlib logger for this cell.
logger = logging.getLogger(__name__)

# Model family handled by this notebook and the classes associated with it.
model_class = "t5"
CONFIG_MAPPING = dict(t5=T5Config)
MODEL_MAPPING = dict(t5=T5ForConditionalGeneration)
TOKENIZER_MAPPING = dict(t5=T5Tokenizer)

# Despite its name, `tokenizer_name` holds the tokenizer *class* for the family.
tokenizer_name = TOKENIZER_MAPPING[model_class]

logger.info("Loading pretrained tokenizer...")
# The checkpoint name is pinned here regardless of what was parsed earlier.
model_args.tokenizer_name = "t5-base"
# `cache_dir=model_args.cache_dir` is intentionally not passed.
tokenizer = tokenizer_name.from_pretrained(model_args.tokenizer_name)

08/25/2022 17:39:46 - INFO - __main__ -   Loading pretrained tokenizer...
loading file https://huggingface.co/t5-base/resolve/main/spiece.model from cache at /home/huangyongfeng/.cache/huggingface/transformers/684a47ca6257e4ca71f0037771464c5b323e945fbc58697d2fad8a7dd1a2f8ba.3b69006860e7b5d0a63ffdddc01ddcd6b7c318a6f4fd793596552c741734c62d
loading file https://huggingface.co/t5-base/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/t5-base/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/t5-base/resolve/main/tokenizer_config.json from cache at None
loading file https://huggingface.co/t5-base/resolve/main/tokenizer.json from cache at /home/huangyongfeng/.cache/huggingface/transformers/90de37880b5ff5ac7ab70ff0bd369f207e9b74133fa153c163d14c5bb0116207.8627f1bd5d270a9fd2e5a51c8bec3223896587cc3cfe13edeabb0992ab43c529
loading configuration file https://huggingface.co/t5-base/resolve/main/config.json from cache at /ho

In [4]:
import os

import torch
from torch import nn
from transformers.modeling_utils import (
    ModuleUtilsMixin, PushToHubMixin,
    logging, Union, Optional, Callable, unwrap_model, get_parameter_dtype,
    FLAX_WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME,
    is_offline_mode, is_remote_url, hf_bucket_url, cached_path
)

# NOTE: `logging` here is the object imported from `transformers.modeling_utils`
# in this cell, not the stdlib module imported in earlier cells. This line rebinds
# `logger`, and the import above rebinds `logging`, for any cell executed afterwards.
logger = logging.get_logger(__name__)


class PushToHubFriendlyModel(nn.Module, ModuleUtilsMixin, PushToHubMixin):
    """`nn.Module` base class with `save_pretrained` / `load` helpers.

    Both helpers operate on the state dict of the whole module. `save_pretrained`
    additionally reads and writes `self.pretrain_model.config`, so subclasses are
    expected to set a `pretrain_model` attribute before calling it.
    """

    def __init__(self):
        super().__init__()

    def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            save_config: bool = True,
            state_dict: Optional[dict] = None,
            save_function: Callable = torch.save,
            push_to_hub: bool = False,
            **kwargs,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            save_config (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to save the config of the model. Useful when in distributed training like TPUs and need
                to call this function on all processes. In this case, set :obj:`save_config=True` only on the main
                process to avoid race conditions.
            state_dict (nested dictionary of :obj:`torch.Tensor`):
                The state dictionary of the model to save. Will default to :obj:`self.state_dict()`, but can be used to
                only save parts of the model or if special precautions need to be taken when recovering the state
                dictionary of a model (like when using model parallelism).
            save_function (:obj:`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace :obj:`torch.save` by another method.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                .. warning::

                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
                    instead.

            kwargs:
                Additional key word arguments passed along to the
                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
        """
        # A file path is logged as an error and nothing is written.
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # Only save the model itself if we are using distributed training
        model_to_save = unwrap_model(self)

        # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
        # we currently don't use this setting automatically, but may start to use with v5
        dtype = get_parameter_dtype(model_to_save)
        self.pretrain_model.config.torch_dtype = str(dtype).split(".")[1]

        # Attach architecture to the config
        self.pretrain_model.config.architectures = [model_to_save.__class__.__name__]

        # Save the config
        if save_config:
            self.pretrain_model.config.save_pretrained(save_directory)

        # Save the model
        if state_dict is None:
            state_dict = model_to_save.state_dict()

        # Handle the case where some state_dict keys shouldn't be saved
        # if self._keys_to_ignore_on_save is not None:
        #     state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        save_function(state_dict, output_model_file)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")

    def load(self, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Adapted and simplified from `from_pretrained` in transformers.modeling_utils,
        but closer to `load_state_dict`: the weights are loaded into this already
        constructed model rather than into a newly created one.

        The checkpoint is resolved from a local directory, a local file, a URL or a
        hub model identifier, then loaded strictly (`strict=True`) into `self`.

        @param pretrained_model_name_or_path: directory, file, URL or model identifier of
            the checkpoint. May be None when `state_dict` is supplied through `kwargs`.
        @param model_args: accepted for signature compatibility; not used.
        @param kwargs: the options read below (`state_dict`, `cache_dir`, `from_tf`,
            `from_flax`, `force_download`, `resume_download`, `proxies`,
            `local_files_only`, `use_auth_token`, `revision`, `mirror`, ...).

        @raise EnvironmentError: when no checkpoint file can be located or downloaded.
        @raise ValueError: when a TensorFlow checkpoint is found but `from_tf` is not set, or
            when `from_tf`/`from_flax` is set without an explicit `state_dict`.
        @raise OSError: when the resolved file cannot be read as a PyTorch checkpoint.
        """
        # Several options are popped only so callers can pass the same keyword
        # arguments as to `from_pretrained`; not all of them are used below.
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        from_flax = kwargs.pop("from_flax", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        _fast_init = kwargs.pop("_fast_init", True)
        torch_dtype = kwargs.pop("torch_dtype", None)

        from_pt = not (from_tf or from_flax)

        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load model
        if pretrained_model_name_or_path is not None:
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint in priority if from_tf
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint in priority if from_tf
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint in priority if from_flax
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
                        f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False."
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                if not from_tf:
                    raise ValueError(
                        f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
                        "from_tf to True to load from this checkpoint."
                    )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                # set correct filename
                if from_tf:
                    filename = TF2_WEIGHTS_NAME
                elif from_flax:
                    filename = FLAX_WEIGHTS_NAME
                else:
                    filename = WEIGHTS_NAME

                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=filename,
                    revision=revision,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
                )
                # Keep the original failure attached as the cause.
                raise EnvironmentError(msg) from err

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None

        # Obtain the state dict unless the caller supplied one.
        if state_dict is None:
            if not from_pt:
                # Nothing in this method converts TF/Flax checkpoints, so without an
                # explicit state dict there is nothing that could be loaded.
                raise ValueError(
                    "`load` can only read PyTorch checkpoints; pass `state_dict` explicitly "
                    "or leave `from_tf` and `from_flax` unset."
                )
            try:
                # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
                state_dict = torch.load(resolved_archive_file, map_location="cpu")
            except Exception as err:
                raise OSError(
                    f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
                    f"at '{resolved_archive_file}'. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                ) from err
        self.load_state_dict(state_dict, strict=True)


In [5]:
# if data_args.generations_filepath is None:
#     model_name = MODEL_MAPPING[model_class]
#     if model_args.pretrained_model_file:
#         model = T5ForConditionalGeneration.from_pretrained(model_args.pretrained_model_file)

#         if model_args.dropout_rate:
#             raise Exception("can't update/specify dropout currently when load pretrained model from directory")

#     elif model_args.pretrained:
#         # load pretrained model from HuggingFace
#         logger.info("Loading pretrained model")
#         if model_args.dropout_rate:
#             model = model_name.from_pretrained(model_args.model_type, dropout_rate=model_args.dropout_rate)
#         else:
#             model = model_name.from_pretrained(model_args.model_type)
#     else:
#         # load model from scratch with no pretrained weights
#         config_name = CONFIG_MAPPING[model_class]()
#         # TODO (Sarah): NOTE THIS ONLY DOES T5-BASE; PASS IN ARGS HERE^
#         logger.info(
#             "Training new model from scratch using default config (NOTE: SMALL MODELS ONLY FOR NOW)"
#         )
#         if model_args.dropout_rate:
#             raise Exception("sure you want to train a model from scratch?")
#         model = model_name.from_config(config_name)
#     model.resize_token_embeddings(len(tokenizer))
# else:
#     model = None
import torch
from torch import nn
from transformers import AutoTokenizer
# from PushToHubFriendlyModel
from modeling_auto import AutoModelForSeq2SeqLM
from modeling_bart import BartForConditionalGeneration
from modeling_t5 import T5ForConditionalGeneration

class Model(PushToHubFriendlyModel):
    def __init__(self, args):
        """Load the backbone named by `args.bert.location` and create the prefix-tuning modules.

        Reads `args.prefix_tuning` (`prefix_sequence_length`, `mid_dim`, `prefix_dropout`),
        `args.bert.location`, `args.special_tokens` and `args.model.knowledge_usage`.

        Raises:
            ValueError: if the loaded backbone is neither a BART nor a T5
                conditional-generation model.
        """
        super().__init__()
        self.args = args

        # The prefix-tuning code
        self.preseqlen = args.prefix_tuning.prefix_sequence_length
        self.mid_dim = args.prefix_tuning.mid_dim

        print("prefix-tuning sequence length is {}.".format(self.preseqlen))

        # Load tokenizer and model.
        self.tokenizer = AutoTokenizer.from_pretrained(args.bert.location, use_fast=False)
        self.pretrain_model = AutoModelForSeq2SeqLM.from_pretrained(args.bert.location)
        self.config = self.pretrain_model.config

        # Layer and head counts are read from the decoder side of the config.
        backbone = self.pretrain_model
        if isinstance(backbone, BartForConditionalGeneration):
            n_layer, n_head = self.config.decoder_layers, self.config.decoder_attention_heads
        elif isinstance(backbone, (T5ForConditionalGeneration)):
            n_layer, n_head = self.config.num_decoder_layers, self.config.num_heads
        else:
            raise ValueError("Other models are not supported yet!")
        self.match_n_layer = n_layer
        self.match_n_head = n_head

        self.n_embd = self.config.d_model
        assert self.n_embd % self.match_n_head == 0
        self.match_n_embd = self.n_embd // self.match_n_head

        if args.special_tokens:
            self.tokenizer.add_tokens([token for _, token in args.special_tokens])
            self.pretrain_model.resize_token_embeddings(len(self.tokenizer))

        def reparam_mlp():
            # n_embd -> mid_dim -> (key, value) activations for every layer.
            return nn.Sequential(
                nn.Linear(self.n_embd, self.mid_dim),
                nn.Tanh(),
                nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.n_embd),
            )

        separate_knowledge = self.args.model.knowledge_usage == 'separate'

        # Prefix related.
        # NOTE: modules are created in a fixed order; changing it would change the
        # order in which parameters are initialised and registered.
        self.register_buffer('input_tokens', torch.arange(self.preseqlen).long())

        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans = reparam_mlp()
        if separate_knowledge:
            self.knowledge_trans = reparam_mlp()

        self.wte_enc = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans_enc = reparam_mlp()
        if separate_knowledge:
            self.knowledge_trans_enc = reparam_mlp()

        self.wte_dec = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans_dec = reparam_mlp()

        # Knowledge prompt.
        if separate_knowledge:
            self.knowledge_trans_dec = reparam_mlp()

        self.dropout = nn.Dropout(args.prefix_tuning.prefix_dropout)

#         if self.args.model.freeze_plm:
#             for param in self.pretrain_model.parameters():
#                 param.requires_grad = False
#         if self.args.model.freeze_prefix:
#             for param in self.wte.parameters():
#                 param.requires_grad = False
#             for param in self.control_trans.parameters():
#                 param.requires_grad = False
#             for param in self.wte_dec.parameters():
#                 param.requires_grad = False
#             for param in self.control_trans_dec.parameters():
#                 param.requires_grad = False
#             for param in self.wte_enc.parameters():
#                 param.requires_grad = False
#             for param in self.control_trans_enc.parameters():
#                 param.requires_grad = False

    def get_prompt(self, bsz=None, sample_size=1, description=None, knowledge=None):
        old_bsz = bsz
        bsz = bsz * sample_size
        input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1)
        temp_control = self.wte(input_tokens)
        if description is not None:
            temp_control = temp_control + description.repeat_interleave(sample_size, dim=0).unsqueeze(1)
        past_key_values = self.control_trans(temp_control)  # bsz, seqlen, layer*emb
        if knowledge is not None:
            past_key_values = torch.cat([past_key_values, self.knowledge_trans(knowledge.repeat_interleave(sample_size, dim=0))], dim=1)

        bsz, seqlen, _ = past_key_values.shape
        past_key_values = past_key_values.view(
            bsz, seqlen, self.match_n_layer * 2, self.match_n_head, self.match_n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)

        # Cross prefix
        temp_control_dec = self.wte_dec(input_tokens)
        if description is not None:
            temp_control_dec = temp_control_dec + description.repeat_interleave(sample_size, dim=0).unsqueeze(1)
        past_key_values_dec = self.control_trans_dec(
            temp_control_dec
        )  # bsz, seqlen, layer*emb
        if knowledge is not None:
            past_key_values_dec = torch.cat([past_key_values_dec, self.knowledge_trans_dec(knowledge.repeat_interleave(sample_size, dim=0))], dim=1)

        bsz, seqlen, _ = past_key_values_dec.shape
        past_key_values_dec = past_key_values_dec.view(
            bsz, seqlen, self.match_n_layer * 2, self.match_n_head, self.match_n_embd
        )
        past_key_values_dec = self.dropout(past_key_values_dec)
        past_key_values_dec = past_key_values_dec.permute([2, 0, 3, 1, 4]).split(2)

        # Encoder prefix
        input_tokens_enc = (
            self.input_tokens.unsqueeze(0).expand(old_bsz, -1)
        )
        temp_control_enc = self.wte_enc(input_tokens_enc)
        if description is not None:
            temp_control_enc = temp_control_enc + description.unsqueeze(1)
        past_key_values_enc = self.control_trans_enc(
            temp_control_enc
        )  # bsz, seqlen, layer*emb
        if knowledge is not None:
            past_key_values_enc = torch.cat([past_key_values_enc, self.knowledge_trans_enc(knowledge)], dim=1)

        bsz_enc, seqlen, _ = past_key_values_enc.shape
        past_key_values_enc = past_key_values_enc.view(
            bsz_enc,
            seqlen,
            self.match_n_layer * 2,
            self.match_n_head,
            self.match_n_embd,
        )
        past_key_values_enc = self.dropout(past_key_values_enc)
        past_key_values_enc = past_key_values_enc.permute([2, 0, 3, 1, 4]).split(2)

        result = []
        for i, key_val in enumerate(past_key_values):
            temp = dict()
            temp["decoder_prompt"] = {
                "prev_key": key_val[0].contiguous(),
                "prev_value": key_val[1].contiguous(),
                "prev_key_padding_mask": torch.zeros(bsz, seqlen)
                    .to(key_val.device)
                    .bool()
                # bsz, preseqlen
            }
            key_val_dec = past_key_values_dec[i]
            temp["cross_attention_prompt"] = {
                "prev_key": key_val_dec[0].contiguous(),
                "prev_value": key_val_dec[1].contiguous(),
                "prev_key_padding_mask": torch.zeros(bsz, seqlen)
                    .to(key_val_dec.device)
                    .bool(),
            }
            key_val_enc = past_key_values_enc[i]
            temp["encoder_prompt"] = {
                "prev_key": key_val_enc[0].contiguous(),
                "prev_value": key_val_enc[1].contiguous(),
                "prev_key_padding_mask": torch.zeros(bsz_enc, seqlen)
                    .to(key_val_enc.device)
                    .bool(),
            }
            result.append(temp)

        return result

    def get_description_representation(self, kwargs):
        """Encode the task description and return its first-position hidden state.

        Nothing is encoded (and ``kwargs`` is left untouched) unless both
        ``args.model.use_description`` and ``args.model.map_description`` are
        truthy. Otherwise the two description tensors are popped out of
        ``kwargs`` and fed to the encoder of the pretrained model selected by
        ``args.bert.location`` (T5 or BART checkpoints).

        :param kwargs: mutable dict carrying ``description_input_ids`` and
            ``description_attention_mask``; both keys are removed when used.
        :return: ``last_hidden_state[:, 0]`` of the encoder output, or ``None``
            when the description is not used.
        :raises KeyError: if a description entry is missing from ``kwargs``.
        :raises ValueError: if ``args.bert.location`` is not a known checkpoint.
        """
        model_cfg = self.args.model
        if not (model_cfg.use_description and model_cfg.map_description):
            return None

        # Pop first: the inputs are consumed even if the location check below fails.
        desc_ids = kwargs.pop("description_input_ids")
        desc_mask = kwargs.pop("description_attention_mask")

        location = self.args.bert.location
        if location in ("t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"):
            encoder = self.pretrain_model.encoder
        elif location in ("facebook/bart-base", "facebook/bart-large"):
            # The BART wrapper keeps its encoder one level deeper, under `.model`.
            encoder = self.pretrain_model.model.encoder
        else:
            raise ValueError()

        encoded = encoder(input_ids=desc_ids, attention_mask=desc_mask)
        # TODO: the first token from the encoder.
        return encoded.last_hidden_state[:, 0]

    def get_knowledge_representation(self, kwargs):
        """Encode the structured knowledge when it is supplied as a separate input.

        Behaviour depends on ``args.model.knowledge_usage``:

        * ``'concatenate'`` - nothing is encoded here and ``None`` is returned.
        * ``'separate'`` - ``knowledge_input_ids`` / ``knowledge_attention_mask``
          are popped from ``kwargs`` (defaulting to ``None`` when absent) and
          passed through the encoder of the pretrained model chosen by
          ``args.bert.location``.

        :param kwargs: mutable dict that may carry the knowledge tensors.
        :return: the encoder's full ``last_hidden_state``, or ``None``.
        :raises ValueError: for an unknown ``knowledge_usage`` or an
            unsupported ``args.bert.location``.
        """
        usage = self.args.model.knowledge_usage
        if usage == 'concatenate':
            return None
        if usage != 'separate':
            raise ValueError()

        # Missing entries fall back to None and are forwarded to the encoder as-is.
        know_ids = kwargs.pop("knowledge_input_ids", None)
        know_mask = kwargs.pop("knowledge_attention_mask", None)

        location = self.args.bert.location
        if location in ("t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"):
            encoder = self.pretrain_model.encoder
        elif location in ("facebook/bart-base", "facebook/bart-large"):
            encoder = self.pretrain_model.model.encoder
        else:
            raise ValueError()

        return encoder(input_ids=know_ids, attention_mask=know_mask).last_hidden_state

    def forward(self, input_ids, attention_mask, labels, **kwargs):
        """Run one training step and return the loss of the wrapped model.

        The optional description / knowledge inputs are taken out of ``kwargs``
        by the two ``get_*_representation`` helpers, turned into a prompt via
        ``get_prompt`` and handed to ``pretrain_model`` as ``past_prompt``.

        :param input_ids: batch of input token ids; ``shape[0]`` is the batch size.
        :param attention_mask: attention mask forwarded unchanged.
        :param labels: target ids forwarded unchanged.
        :param kwargs: extra inputs consumed by the description/knowledge helpers.
        :return: ``{'loss': <loss attribute of the pretrain_model output>}``.
        """
        batch_size = input_ids.shape[0]

        # Encode description, then knowledge (each helper may pop from kwargs).
        description = self.get_description_representation(kwargs)
        knowledge = self.get_knowledge_representation(kwargs)

        prompt = self.get_prompt(
            bsz=batch_size,
            description=description,
            knowledge=knowledge,
        )

        outputs = self.pretrain_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            past_prompt=prompt,
        )
        return {'loss': outputs.loss}

    def generate(self, input_ids, attention_mask, **kwargs):
        """Generate sequences with the wrapped model, conditioned on the prompt.

        The description / knowledge inputs are consumed from ``kwargs`` by the
        ``get_*_representation`` helpers; the prompt is then built with
        ``sample_size`` set to ``kwargs['num_beams']``. Whatever remains in
        ``kwargs`` (including ``num_beams``) is forwarded to
        ``pretrain_model.generate`` together with ``use_cache=True``.

        :param input_ids: batch of input token ids; ``shape[0]`` is the batch size.
        :param attention_mask: attention mask forwarded unchanged.
        :param kwargs: generation options; must contain ``num_beams``.
        :return: whatever ``pretrain_model.generate`` returns.
        :raises KeyError: if ``num_beams`` is not supplied.
        """
        batch_size = input_ids.shape[0]

        # Encode description, then knowledge (each helper may pop from kwargs).
        description = self.get_description_representation(kwargs)
        knowledge = self.get_knowledge_representation(kwargs)

        prompt = self.get_prompt(
            bsz=batch_size,
            sample_size=kwargs['num_beams'],
            description=description,
            knowledge=knowledge,
        )

        return self.pretrain_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_prompt=prompt,
            use_cache=True,
            **kwargs,
        )


08/25/2022 17:39:53 - INFO - faiss.loader -   Loading faiss with AVX2 support.
08/25/2022 17:39:53 - INFO - faiss.loader -   Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
08/25/2022 17:39:53 - INFO - faiss.loader -   Loading faiss.
08/25/2022 17:39:53 - INFO - faiss.loader -   Successfully loaded faiss.


In [6]:
# [model]
# name = unified.prefixtuning
# use_description = True
# concatenate_description = True
# # Should be one of (separate, concatenate)
# knowledge_usage = concatenate
# freeze_plm = True
# freeze_prefix = False

# [dataset]
# data_store_path = ./data
# upsample_temp = 1

# [seq2seq]
# constructor = seq2seq_construction.meta_tuning
# patience = 200

# [arg_paths]
# fetaqa = META_TUNING/fetaqa.cfg

# [evaluate]
# tool = metrics.meta_tuning.evaluator

# [prefix_tuning]
# prefix_sequence_length = 10
# mid_dim = 512
# prefix_dropout = 0.0

# [special_tokens]
# less = ' <'
# less_or_equal = ' <='

# [bert]
# location = t5-base
# Configure

In [7]:
import argparse
import configparser
import datetime
import os

DEFAULT_CONFIGURE_DIR = "configure"
DEFAULT_DATASET_DIR = "data"
DEFAULT_MODEL_DIR = "models"


class Args(object):
    """Forgiving attribute bag used to hold configuration values.

    * Reading an attribute that was never set yields ``None`` instead of
      raising ``AttributeError``.
    * Assigning ``None`` to a regular (non-dunder) attribute is silently ignored.
    * Iteration yields ``(name, value)`` pairs of user-set attributes in sorted
      order; ``len()`` counts them.
    * Calling the instance returns the object passed as ``contain``.
    """

    def __init__(self, contain=None):
        self.__self__ = contain
        # Create the slot first so that "__default__" is itself part of the
        # baseline snapshot taken on the next line.
        self.__default__ = None
        self.__default__ = set(dir(self))

    def __call__(self):
        return self.__self__

    def __getattribute__(self, name):
        # Dunder names bypass the dir() probe; dir() itself needs them.
        is_dunder = name.startswith("__") and name.endswith("__")
        if is_dunder or name in dir(self):
            return super().__getattribute__(name)
        return None

    def __setattr__(self, name, value):
        is_dunder = name.startswith("__") and name.endswith("__")
        if value is None and not is_dunder:
            # None assignments to ordinary attributes are dropped.
            return None
        return super().__setattr__(name, value)

    def __delattr__(self, name):
        # Only user-set attributes can be removed; baseline names are protected
        # and unknown names are ignored.
        removable = name in dir(self) and name not in self.__default__
        if removable:
            super().__delattr__(name)

    def __iter__(self):
        # give args elements dictionary order to ensure its replicate-ability
        user_names = set(dir(self)) - self.__default__
        pairs = [(key, getattr(self, key)) for key in user_names]
        pairs.sort()
        return iter(pairs)

    def __len__(self):
        return len(set(dir(self)).difference(self.__default__))


class String(object):
    """Helpers for interpreting raw configuration strings."""

    @staticmethod
    def to_basic(string):
        """
        Convert a raw string to the basic Python value it spells out.

        Tried in order: ``int``, ``float``, the booleans ``True``/``true`` and
        ``False``/``false``; anything else is returned as a string with
        surrounding quote characters removed.
        For example, "true" --> True as a bool value
        :param string: raw text, e.g. a value read from a config file
        :return: int, float, bool or str
        """
        for convert in (int, float):
            try:
                return convert(string)
            except ValueError:
                continue

        booleans = {"True": True, "true": True, "False": False, "false": False}
        if string in booleans:
            return booleans[string]

        # Quotes let a config value keep leading/trailing spaces, e.g. ' <'.
        return string.strip("\"'")


class Configure(object):
    """Build nested ``Args`` objects from INI-style configuration files."""

    @staticmethod
    def get_file_cfg(file):
        """
        get configurations in file.

        Every section of the file becomes an ``Args`` attribute on the result,
        and every option in that section becomes an attribute of the section,
        converted with ``String.to_basic``. ``ConfigParser.read`` skips files it
        cannot open, so a missing file yields an empty ``Args``.
        :param file: path of the configuration file
        :return: configure args
        """
        cfgargs = Args()
        parser = configparser.ConfigParser()
        parser.read(file)
        for section in parser.sections():
            setattr(cfgargs, section, Args())
            for item in parser.items(section):
                setattr(getattr(cfgargs, section), item[0], String.to_basic(item[1]))
        return cfgargs

    @staticmethod
    def _set_default_dirs(args):
        """Make sure ``args.dir`` exists and carries the default directories.

        An existing ``dir`` section parsed from the file is kept (the previous
        check ``args.dir is not Args`` compared an instance with the class and
        was therefore always true, discarding the section); only its ``model``,
        ``dataset`` and ``configure`` entries are overwritten with the defaults.
        :param args: configuration object to update in place
        :return: the same ``args`` object
        """
        if not isinstance(args.dir, Args):
            args.dir = Args()
        args.dir.model = DEFAULT_MODEL_DIR
        args.dir.dataset = DEFAULT_DATASET_DIR
        args.dir.configure = DEFAULT_CONFIGURE_DIR
        return args

    @staticmethod
    def refresh_args_by_file_cfg(file, prev_args):
        """
        Load ``file`` and fill in values from ``prev_args`` that the file lacks.

        ``prev_args`` is iterated as ``(name, value)`` pairs. Names may be
        dotted (``"a.b.c"``), in which case intermediate ``Args`` objects are
        created on demand. Entries whose value is ``None`` and the entry named
        ``"cfg"`` are skipped; values already present from the file win.
        :param file: path of the configuration file
        :param prev_args: iterable of ``(name, value)`` pairs, e.g. an ``Args``
        :return: configure args
        """
        args = Configure._set_default_dirs(Configure.get_file_cfg(file))
        for arg_name, arg in prev_args:
            if arg is None or arg_name == "cfg":
                continue
            names = arg_name.split(".")
            cur = args
            for name in names[:-1]:
                if getattr(cur, name) is None:
                    setattr(cur, name, Args())
                cur = getattr(cur, name)
            # Do not override what the file already defines.
            if getattr(cur, names[-1]) is None:
                setattr(cur, names[-1], arg)
        return args

    @staticmethod
    def Get(cfg):
        """
        Load ``cfg`` from the default configure directory and add default dirs.
        :param cfg: file name (or path) joined onto ``DEFAULT_CONFIGURE_DIR``
        :return: configure args
        """
        args = Configure.get_file_cfg(os.path.join(DEFAULT_CONFIGURE_DIR, cfg))
        return Configure._set_default_dirs(args)


In [8]:
# Parse ./cos_e_prefix.cfg into a nested Args object (one Args per config section).
# get_file_cfg is used directly, so no default `dir` entries are added here.
skt_args=Configure.get_file_cfg("./cos_e_prefix.cfg")

In [9]:
# Instantiate the Model wrapper from the parsed configuration.
model = Model(skt_args)

prefix-tuning sequence length is 10.


Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/t5-base/resolve/main/config.json from cache at /home/huangyongfeng/.cache/huggingface/transformers/91e9fe874e06c44883b535d6c950b8b89d6eaa3298d8e7fb3b2c78039e9f8b7b.66b9637a52aa11e9285cdd6e668cc0df14b3bcf0b6674cf3ba5353c542649637
Model config T5Config {
  "architectures": [
    "T5WithLMHeadModel"
  ],
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "gradient_checkpointing": false,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,

# datasets

In [19]:
import datasets
class SequenceCollator:
    """Collate tokenised examples into padded ``torch.long`` tensors.

    Only the columns listed in ``self.columns`` are kept; list-valued columns
    are right-padded to the longest entry in the batch using the per-column
    value from ``self.pad_token_mapping``.
    """

    def __init__(self, pad_token):
        """
        :param pad_token: id used to pad ``input_ids`` (labels are padded with
            -100, the masks with 0).
        """
        self.pad_token_mapping = {
            "labels": -100,
            "attention_mask": 0,
            "decoder_attention_mask": 0,
            "input_ids": pad_token,
        }
        self.columns = [
            "input_ids",
            "attention_mask",
            "labels",
            "decoder_attention_mask",
        ]

    def collate_batch(self, examples):
        """Batch ``examples`` for training.

        The set of keys is taken from the first example. Columns that are not
        lists (e.g. scalars) are stacked without padding. The input lists are
        not modified.

        :param examples: indexable, iterable collection of dict-like examples
        :return: dict mapping each kept column name to a ``torch.long`` tensor
        """
        batch = {}
        for key in examples[0].keys():
            if key not in self.columns:
                continue
            values = [item[key] for item in examples]

            # pad lists to max length
            if isinstance(values[0], list):
                max_length = max(map(len, values))
                pad_value = self.pad_token_mapping[key]
                values = [el + [pad_value] * (max_length - len(el)) for el in values]

            batch[key] = torch.tensor(values, dtype=torch.long)
        return batch

    def __call__(self, examples: List[Dict[str, InputDataClass]]) -> Dict[str, torch.Tensor]:
        """Callable entry point; identical to :meth:`collate_batch`."""
        return self.collate_batch(examples)
# Load every split of the configured task.
# NOTE(review): `data_args` and `version_arg` are not defined in this cell;
# `version_arg` is assigned in the data-loading cell (In [10]), so this line only
# works after that cell has been run.
dataset = datasets.load_dataset(data_args.task_name, version_arg)

  0%|          | 0/2 [00:00<?, ?it/s]

In [27]:
# Scratch check: run the collator over the raw training split and inspect which
# columns a raw example has. The collator only keeps the columns listed in
# SequenceCollator.columns.
seq_collector = SequenceCollator(0)
# Invoke the collator as a callable rather than through `__call__` directly.
train_ds = seq_collector(dataset['train'])
dataset['train'][0].keys()

dict_keys(['id', 'choices', 'question', 'abstractive_explanation', 'answer', 'extractive_explanation'])

In [10]:
# Load the task dataset from Hugging Face `datasets`, optionally subsample it for
# few-shot runs, format every split with `format_instance`, and keep two deep
# copies of the formatted splits (`data_splits` and `original_data_splits`).
# Relies on `data_args`, `training_args` and `tokenizer` defined outside this cell.
data_splits = {'train': None, 'validation': None, 'test': None}
original_data_splits = {'train': None, 'validation': None, 'test': None}  
# Overrides whatever io_format was previously set on data_args.
data_args.io_format="t5_fewshot_infilling_with_choices"
# Data loading from huggingface's datasets
# NOTE(review): there is no `else` for other task names; in that case nothing is
# loaded and both split dicts keep their None values.
if data_args.task_name in {"cos_e", "esnli"}:
    version_arg = None
    if data_args.task_name == "cos_e":
        assert data_args.version_name in {"v1.11", "v1.0"}
        version_arg = data_args.version_name

    # NOTE(review): `load_train` is assigned here and below but never read in this cell.
    load_train = True
    if (not training_args.do_train
        and not training_args.do_eval
        and not data_args.train_predict
    ):
        # don't load training dataset
        dataset = {}
        dataset["train"] = None
        dataset["validation"] = datasets.load_dataset(
            data_args.task_name, version_arg, split="validation"
        )
        data_splits['validation'] = dataset["validation"]

        if data_args.task_name == "esnli":
            dataset["test"] = datasets.load_dataset(data_args.task_name, split="test")
            data_splits['test'] = dataset["test"]
        load_train = False
    else:
        dataset = datasets.load_dataset(data_args.task_name, version_arg)

        if data_args.n_shots > 0: # Shots = number of training examples **per label** 
            if data_args.task_name == 'esnli': # Construct a *balanced* random sample of the size `data_args.n_shots*len(labels)` (for train) or `data_args.fewshot_eval_size` (for eval)
                for split in ["train", "validation", "test"]:
                    split_data = dataset[split]
                    label_subsets = []
                    labels = split_data.features['label'].names
                    sample_size = data_args.n_shots if split == "train" else int(data_args.fewshot_eval_size/len(labels))
                    # gpt3_max_eval_size, when set, replaces the eval sample size (split evenly across labels).
                    if data_args.gpt3_max_eval_size is not None and split != 'train':
                        assert len(labels) == 3
                        sample_size = data_args.gpt3_max_eval_size // len(labels)
                    for label in labels:
                        # The following is a hack to only run on `neutral` labels of `esnli` to get data for human eval
                        # if data_args.gpt3_max_eval_size is not None and split != 'train' and label != 'neutral':
                        #     continue
                        label_int = split_data.features['label'].str2int(label)
                        # NOTE(review): shuffle() is called without a seed argument here - confirm reproducibility is handled elsewhere.
                        label_set = split_data.filter(lambda example: example['label'] == label_int).shuffle() # all instances of labeled as `label`
                        label_subset = label_set.select(range(sample_size)) #select `sample_size` random instances labeled as `label`
                        label_subsets.append(label_subset)
                    dataset[split] = datasets.concatenate_datasets(label_subsets) #merge all label-specific instances
            elif data_args.task_name == 'cos_e': 
                for split in ["train", "validation"]: 
                    split_data = dataset[split]
                    sample_size = data_args.n_shots if split == "train" else int(data_args.fewshot_eval_size) #Shots for QA are not label-specific, i.e., `n_shots` is the training data size
                    if data_args.gpt3_max_eval_size is not None and split != 'train':
                        sample_size = data_args.gpt3_max_eval_size
                    # NOTE(review): the shuffle/select call is commented out, so `sample_size` is unused
                    # and the full split is kept even when n_shots > 0.
                    dataset[split] = split_data#.shuffle().select(range(sample_size)) # select `sample_size` random instances
            else: 
                # Unreachable under the enclosing task_name check, which only admits cos_e and esnli.
                raise ValueError('Only cos_e and esnli are supported by Huggingface datasets.')
    # Apply method, and format dataset to torch.Tensor outputs
    for split in dataset.keys():
        if dataset[split] is not None:
            # load_from_cache_file=False forces format_instance to be re-applied on every run.
            dataset[split] = dataset[split].map(
                lambda x: format_instance(
                    x,
                    tokenizer,
                    data_args.explanation_sep,
                    datasource=data_args.task_name,
                    io_format=data_args.io_format
                ),
                batched=False,
                load_from_cache_file=False,
            )
    # Working copies of the formatted splits (overwrites the validation/test
    # entries assigned in the eval-only branch above).
    data_splits["train"] = deepcopy(dataset["train"])
    data_splits["validation"] = deepcopy(dataset["validation"])
    if data_args.task_name == "esnli":
        data_splits["test"] = deepcopy(dataset["test"])

    # Second, independent set of deep copies of the same formatted splits.
    original_data_splits["train"] = deepcopy(dataset["train"])
    original_data_splits["validation"] = deepcopy(dataset["validation"])
    if data_args.task_name == "esnli":
        original_data_splits["test"] = deepcopy(dataset["test"])

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/9741 [00:00<?, ?ex/s]

  0%|          | 0/1221 [00:00<?, ?ex/s]

In [11]:
class SequenceCollator:
    """Data collator that pads list-valued columns and returns ``torch.long`` tensors."""

    def __init__(self, model, pad_token):
        """
        :param model: stored on the instance as ``self.model``; not used by ``__call__``.
        :param pad_token: id used to pad ``input_ids``.
        """
        self.model = model
        # Per-column fill value used when right-padding to the batch maximum.
        self.pad_token_mapping = dict(
            labels=-100,
            attention_mask=0,
            decoder_attention_mask=0,
            input_ids=pad_token,
        )
        # Only these columns survive collation; everything else is dropped.
        self.columns = ["input_ids", "attention_mask", "labels", "decoder_attention_mask"]

    def __call__(self, examples: List[Dict[str, InputDataClass]]) -> Dict[str, torch.Tensor]:
        """Re-format ``examples`` into a batch for training.

        The keys of the first example decide which columns are considered.
        List-valued columns are padded on the right to the longest entry;
        other values are stacked as they are.
        """
        collated = {}
        for column in examples[0].keys():
            if column not in self.columns:
                continue

            values = [example[column] for example in examples]

            if isinstance(values[0], list):
                width = max(len(seq) for seq in values)
                filler = self.pad_token_mapping[column]
                values = [seq + [filler] * (width - len(seq)) for seq in values]

            collated[column] = torch.tensor(values, dtype=torch.long)
        return collated

In [12]:
# os.environ["WANDB_DISABLED"] = "True"
# Configure early stopping / evaluation cadence, build the Trainer, and train.
# Relies on `data_args`, `training_args`, `model_args`, `model`, `model_class`,
# `tokenizer` and `data_splits` defined outside this cell.
if data_args.generations_filepath is None:
    # NOTE(review): the recorded output of this cell warns that the Trainer
    # already has a TensorBoardCallback, so this explicit one looks redundant - confirm.
    callbacks = [TensorBoardCallback()]
    if data_args.early_stopping_patience > 0:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience))
        training_args.load_best_model_at_end = True
    else:
        training_args.load_best_model_at_end = False  # use the last model state
    # Lower eval loss counts as better when selecting the best model.
    training_args.metric_for_best_model = 'eval_loss'
    training_args.greater_is_better = False
    # Evaluate once per epoch unless an explicit eval_steps interval is configured.
    if training_args.eval_steps is None:
        training_args.evaluation_strategy = EvaluationStrategy.EPOCH
    else:
        training_args.evaluation_strategy = EvaluationStrategy.STEPS

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=data_splits['train'],
        eval_dataset=data_splits['validation'],
        # NOTE(review): `model_class` (not `model`) is handed to the collator, which only stores it.
        data_collator=SequenceCollator(
            model=model_class, pad_token=tokenizer.pad_token_id
        ),
        callbacks=callbacks,
    )

# Training. Don't train if it is use_gpt3
# NOTE(review): `trainer` exists only when generations_filepath is None, but this
# branch runs whenever do_train is set and use_gpt3 is not.
if training_args.do_train and not model_args.use_gpt3:
    start_time = time.time()
    trainer.train()
    train_time = time.time() - start_time
    model = trainer.model
else:
    # No training: train_time is just the (near-zero) gap between two clock reads.
    start_time = time.time()
    train_time = time.time() - start_time

You are adding a <class 'transformers.integrations.TensorBoardCallback'> to the callbacks of this Trainer, but there is already one. The currentlist of callbacks is
:DefaultFlowCallback
TensorBoardCallback
WandbCallback
The following columns in the training set  don't have a corresponding argument in `Model.forward` and have been ignored: id, choices, question, question_encoding, abstractive_explanation, answer, extractive_explanation, decoder_attention_mask.
***** Running training *****
  Num examples = 9741
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 3654
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33mcuhk_lavilab[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.13.2 is available!  To upgrade, please ru

Step,Training Loss,Validation Loss


Saving model checkpoint to ./cos_e_output/082522_173946/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to ./cos_e_output/082522_173946/checkpoint-1000
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to ./cos_e_output/082522_173946/checkpoint-1500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to ./cos_e_output/082522_173946/checkpoint-2000
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to ./cos_e_output/082522_173946/checkpoint-2500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to ./cos_e_output/082522_173946/checkpoint-3000
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to ./cos_e_output/082522_173946/checkpoint-3500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.


Train

In [13]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mcuhk_lavilab[0m (use `wandb login --relogin` to force relogin)


In [14]:
import wandb

In [None]:
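# OK, this more or less runs now. The remaining task is to test the quality of the generated rationales; do that tomorrow morning.
# Merge the models from both sides and start evaluating the rationale results. Get this done tomorrow; if it yields good results, the paper basically has its first 0-to-1 result.
# The next step is to bring in the LLM's knowledge and push the results further.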