In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
from torch import nn
from transformers.modeling_utils import (
    ModuleUtilsMixin, PushToHubMixin,
    logging, Union, Optional, Callable, unwrap_model, get_parameter_dtype,
    FLAX_WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME,
    is_offline_mode, is_remote_url, hf_bucket_url, cached_path
)

logger = logging.get_logger(__name__)


class PushToHubFriendlyModel(nn.Module, ModuleUtilsMixin, PushToHubMixin):
    """``nn.Module`` base class offering Hub-style ``save_pretrained`` and ``load``.

    ``save_pretrained`` reads and updates ``self.pretrain_model.config``, so a
    concrete subclass must define a ``pretrain_model`` attribute before saving.
    """

    def __init__(self):
        super().__init__()

    def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            save_config: bool = True,
            state_dict: Optional[dict] = None,
            save_function: Callable = torch.save,
            push_to_hub: bool = False,
            **kwargs,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.

        If ``save_directory`` is an existing file, an error is logged and nothing is saved.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            save_config (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to save the config of the model. Useful when in distributed training like TPUs and need
                to call this function on all processes. In this case, set :obj:`save_config=True` only on the main
                process to avoid race conditions.
            state_dict (nested dictionary of :obj:`torch.Tensor`):
                The state dictionary of the model to save. Will default to :obj:`self.state_dict()`, but can be used to
                only save parts of the model or if special precautions need to be taken when recovering the state
                dictionary of a model (like when using model parallelism).
            save_function (:obj:`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace :obj:`torch.save` by another method.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                .. warning::

                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
                    instead.

            kwargs:
                Additional key word arguments passed along to the
                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            # `commit_message` must be removed from kwargs before they are forwarded to the repo helper.
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # Only save the model itself if we are using distributed training
        model_to_save = unwrap_model(self)

        # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
        # we currently don't use this setting automatically, but may start to use with v5
        dtype = get_parameter_dtype(model_to_save)
        self.pretrain_model.config.torch_dtype = str(dtype).split(".")[1]

        # Attach architecture to the config
        self.pretrain_model.config.architectures = [model_to_save.__class__.__name__]

        # Save the config
        if save_config:
            self.pretrain_model.config.save_pretrained(save_directory)

        # Save the model
        if state_dict is None:
            state_dict = model_to_save.state_dict()

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        save_function(state_dict, output_model_file)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")

    def load(self, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Adapted and simplified from ``from_pretrained`` in transformers.modeling_utils,
        but closer to ``load_state_dict``: the weights are resolved (local directory, local file,
        remote URL or Hub identifier) and loaded into this already-constructed model.

        @param pretrained_model_name_or_path: directory, file, URL or Hub model id holding the weights;
            may be ``None`` when ``state_dict`` is passed through ``kwargs``.
        @param model_args: accepted for signature compatibility, unused.
        @param kwargs: ``state_dict``, ``cache_dir``, ``from_tf``, ``from_flax``, ``force_download``,
            ``resume_download``, ``proxies``, ``local_files_only``, ``use_auth_token``, ``revision``,
            ``mirror``, ``_from_pipeline`` and ``_from_auto`` are honoured; ``config``,
            ``ignore_mismatched_sizes``, ``output_loading_info``, ``_fast_init`` and ``torch_dtype``
            are accepted but ignored.
        @raise EnvironmentError: when no weights file can be located or fetched.
        @raise OSError: when the resolved PyTorch checkpoint cannot be deserialized.
        @raise ValueError: when a TensorFlow checkpoint is found but ``from_tf`` is not set, or when
            no state dict is available to load (``from_tf``/``from_flax`` without ``state_dict``).
        """
        state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_tf = kwargs.pop("from_tf", False)
        from_flax = kwargs.pop("from_flax", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        # Options of `from_pretrained` that have no effect here; they are still consumed
        # so that callers can pass the same keyword set.
        for ignored_option in ("config", "ignore_mismatched_sizes", "output_loading_info", "_fast_init", "torch_dtype"):
            kwargs.pop(ignored_option, None)

        from_pt = not (from_tf or from_flax)

        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load model
        if pretrained_model_name_or_path is not None:
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint in priority if from_tf
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint in priority if from_tf
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint in priority if from_flax
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
                        f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False."
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                if not from_tf:
                    raise ValueError(
                        f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
                        "from_tf to True to load from this checkpoint."
                    )
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                # Treat the argument as a Hub model id and pick the weights filename for the framework.
                if from_tf:
                    filename = TF2_WEIGHTS_NAME
                elif from_flax:
                    filename = FLAX_WEIGHTS_NAME
                else:
                    filename = WEIGHTS_NAME

                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=filename,
                    revision=revision,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
                )
                # Chain the cause so the underlying download/cache failure stays visible.
                raise EnvironmentError(msg) from err

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None

        if from_pt and state_dict is None:
            try:
                # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
                state_dict = torch.load(resolved_archive_file, map_location="cpu")
            except Exception as err:
                raise OSError(
                    f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
                    f"at '{resolved_archive_file}'. "
                    "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                ) from err

        if state_dict is None:
            # Reached with from_tf/from_flax and no explicit state_dict: this helper never
            # converts TF/Flax checkpoints, so there is nothing that could be loaded.
            raise ValueError(
                "No state dict to load: converting TensorFlow/Flax checkpoints is not supported by `load`; "
                "pass a PyTorch `state_dict` explicitly or load a PyTorch checkpoint."
            )
        self.load_state_dict(state_dict, strict=True)


In [2]:
import torch
from torch import nn
from transformers import AutoTokenizer
# from PushToHubFriendlyModel
from modeling_auto import AutoModelForSeq2SeqLM
from modeling_bart import BartForConditionalGeneration
from modeling_t5 import T5ForConditionalGeneration

class Model(PushToHubFriendlyModel):
    """Prefix-tuning wrapper around a pretrained BART or T5 seq2seq model.

    Three prefixes are learned (decoder self-attention, decoder cross-attention
    and encoder self-attention). Each one comes from an embedding over
    ``preseqlen`` positions followed by a two-layer MLP, and is handed to the
    wrapped model through its ``past_prompt`` argument.
    """

    def __init__(self, args):
        """Build tokenizer, backbone and prefix modules from ``args``.

        Reads ``args.prefix_tuning.{prefix_sequence_length, mid_dim, prefix_dropout}``,
        ``args.bert.location``, ``args.special_tokens`` and
        ``args.model.{knowledge_usage, freeze_plm, freeze_prefix}``.

        Raises:
            ValueError: if the loaded backbone is neither BART nor T5.
        """
        super().__init__()
        self.args = args

        # Prefix hyper-parameters.
        self.preseqlen = args.prefix_tuning.prefix_sequence_length
        self.mid_dim = args.prefix_tuning.mid_dim

        print("prefix-tuning sequence length is {}.".format(self.preseqlen))

        # Load tokenizer and model.
        self.tokenizer = AutoTokenizer.from_pretrained(args.bert.location, use_fast=False)
        self.pretrain_model = AutoModelForSeq2SeqLM.from_pretrained(
            args.bert.location
        )
        self.config = self.pretrain_model.config

        # The two supported backbones name their layer/head counts differently.
        if isinstance(self.pretrain_model, BartForConditionalGeneration):
            self.match_n_layer = self.config.decoder_layers
            self.match_n_head = self.config.decoder_attention_heads
        elif isinstance(self.pretrain_model, (T5ForConditionalGeneration)):
            self.match_n_layer = self.config.num_decoder_layers
            self.match_n_head = self.config.num_heads
        else:
            raise ValueError("Other models are not supported yet!")

        self.n_embd = self.config.d_model
        assert self.n_embd % self.match_n_head == 0
        self.match_n_embd = self.n_embd // self.match_n_head

        if args.special_tokens:
            self.tokenizer.add_tokens([token for _, token in args.special_tokens])
            self.pretrain_model.resize_token_embeddings(len(self.tokenizer))

        # Prefix related.
        self.register_buffer('input_tokens', torch.arange(self.preseqlen).long())
        use_knowledge = self.args.model.knowledge_usage == 'separate'

        # NOTE: the modules below are created in a fixed order on purpose; the order
        # determines both parameter registration and random initialisation.
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans = self._build_reparam_mlp()
        if use_knowledge:
            self.knowledge_trans = self._build_reparam_mlp()

        self.wte_enc = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans_enc = self._build_reparam_mlp()
        if use_knowledge:
            self.knowledge_trans_enc = self._build_reparam_mlp()

        self.wte_dec = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans_dec = self._build_reparam_mlp()

        # Knowledge prompt.
        if use_knowledge:
            self.knowledge_trans_dec = self._build_reparam_mlp()

        self.dropout = nn.Dropout(args.prefix_tuning.prefix_dropout)

        if self.args.model.freeze_plm:
            for param in self.pretrain_model.parameters():
                param.requires_grad = False
        if self.args.model.freeze_prefix:
            # Only the prefix embeddings and their control MLPs are frozen here;
            # the optional knowledge_trans* modules are left trainable.
            frozen_modules = (
                self.wte, self.control_trans,
                self.wte_dec, self.control_trans_dec,
                self.wte_enc, self.control_trans_enc,
            )
            for module in frozen_modules:
                for param in module.parameters():
                    param.requires_grad = False

    def _build_reparam_mlp(self):
        """Return a fresh ``Linear -> Tanh -> Linear`` block mapping n_embd to layer*2*n_embd."""
        return nn.Sequential(
            nn.Linear(self.n_embd, self.mid_dim),
            nn.Tanh(),
            nn.Linear(self.mid_dim, self.match_n_layer * 2 * self.n_embd),
        )

    def _prefix_key_values(self, embedding, transform, knowledge_transform_name, tokens, description, knowledge):
        """Compute the per-layer key/value tensors for one prefix.

        ``knowledge_transform_name`` is resolved lazily because the knowledge
        MLPs only exist when ``knowledge_usage == 'separate'``.

        Returns a tuple with one tensor per layer, each stacking key and value
        along the first dimension.
        """
        control = embedding(tokens)
        if description is not None:
            control = control + description.unsqueeze(1)
        key_values = transform(control)  # bsz, seqlen, layer*emb
        if knowledge is not None:
            knowledge_transform = getattr(self, knowledge_transform_name)
            key_values = torch.cat([key_values, knowledge_transform(knowledge)], dim=1)

        bsz, seqlen, _ = key_values.shape
        key_values = key_values.view(
            bsz, seqlen, self.match_n_layer * 2, self.match_n_head, self.match_n_embd
        )
        key_values = self.dropout(key_values)
        # -> (layer*2, bsz, n_head, seqlen, head_dim), then grouped into (key, value) pairs.
        return key_values.permute([2, 0, 3, 1, 4]).split(2)

    @staticmethod
    def _prompt_entry(key_val):
        """Pack one layer's stacked (key, value) tensor into the prompt dict format."""
        key, value = key_val[0], key_val[1]
        return {
            "prev_key": key.contiguous(),
            "prev_value": value.contiguous(),
            # All-False mask sized from the key itself (bsz, seqlen): no prefix position is padding.
            "prev_key_padding_mask": torch.zeros(
                key.shape[0], key.shape[2], dtype=torch.bool, device=key_val.device
            ),
        }

    def get_prompt(self, bsz=None, sample_size=1, description=None, knowledge=None):
        """Build the per-layer prefix prompts for the wrapped model.

        Args:
            bsz: batch size of the encoder input (required).
            sample_size: number of copies per example for the decoder-side
                prefixes (the caller passes the beam count during generation).
            description: optional tensor added to every prefix position.
            knowledge: optional tensor whose transformed states are appended
                to the prefix along the sequence dimension.

        Returns:
            A list with one dict per layer holding ``decoder_prompt``,
            ``cross_attention_prompt`` and ``encoder_prompt`` entries.

        Raises:
            ValueError: if ``bsz`` is not given.
        """
        if bsz is None:
            raise ValueError("`bsz` must be provided to build the prefix prompts.")

        # Decoder-side prefixes are replicated `sample_size` times; the encoder side is not.
        tokens_dec = self.input_tokens.unsqueeze(0).expand(bsz * sample_size, -1)
        tokens_enc = self.input_tokens.unsqueeze(0).expand(bsz, -1)

        description_rep = None
        if description is not None:
            description_rep = description.repeat_interleave(sample_size, dim=0)
        knowledge_rep = None
        if knowledge is not None:
            knowledge_rep = knowledge.repeat_interleave(sample_size, dim=0)

        # Call order (decoder, cross, encoder) is kept fixed so dropout consumes
        # random numbers in a stable sequence.
        decoder_kv = self._prefix_key_values(
            self.wte, self.control_trans, "knowledge_trans",
            tokens_dec, description_rep, knowledge_rep,
        )
        # Cross prefix
        cross_kv = self._prefix_key_values(
            self.wte_dec, self.control_trans_dec, "knowledge_trans_dec",
            tokens_dec, description_rep, knowledge_rep,
        )
        # Encoder prefix
        encoder_kv = self._prefix_key_values(
            self.wte_enc, self.control_trans_enc, "knowledge_trans_enc",
            tokens_enc, description, knowledge,
        )

        result = []
        for layer_dec, layer_cross, layer_enc in zip(decoder_kv, cross_kv, encoder_kv):
            result.append({
                "decoder_prompt": self._prompt_entry(layer_dec),
                "cross_attention_prompt": self._prompt_entry(layer_cross),
                "encoder_prompt": self._prompt_entry(layer_enc),
            })

        return result

    def _encode(self, input_ids, attention_mask):
        """Run the backbone's encoder and return its ``last_hidden_state``.

        The encoder is picked by backbone class, mirroring the dispatch done in
        ``__init__``, so checkpoints loaded from a local path work as well.
        """
        if isinstance(self.pretrain_model, T5ForConditionalGeneration):
            encoder = self.pretrain_model.encoder
        elif isinstance(self.pretrain_model, BartForConditionalGeneration):
            encoder = self.pretrain_model.model.encoder
        else:
            raise ValueError("Other models are not supported yet!")
        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return outputs.last_hidden_state

    def get_description_representation(self, kwargs):
        """Pop the description inputs from ``kwargs`` and encode them.

        Returns the encoder state of the first token, or ``None`` when
        ``use_description`` / ``map_description`` are not both enabled.
        ``kwargs`` is mutated: the description entries are removed from it.
        """
        if not (self.args.model.use_description and self.args.model.map_description):
            return None

        description_input_ids = kwargs.pop("description_input_ids")
        description_attention_mask = kwargs.pop("description_attention_mask")
        hidden_states = self._encode(description_input_ids, description_attention_mask)
        return hidden_states[:, 0]  # TODO: the first token from the encoder.

    def get_knowledge_representation(self, kwargs):
        """Pop the knowledge inputs from ``kwargs`` and encode them.

        Returns the full encoder hidden states when ``knowledge_usage`` is
        ``'separate'`` and ``None`` when it is ``'concatenate'``.

        Raises:
            ValueError: for any other ``knowledge_usage`` value.
        """
        knowledge_usage = self.args.model.knowledge_usage
        if knowledge_usage == 'separate':
            knowledge_input_ids = kwargs.pop("knowledge_input_ids", None)
            knowledge_attention_mask = kwargs.pop("knowledge_attention_mask", None)
            return self._encode(knowledge_input_ids, knowledge_attention_mask)
        if knowledge_usage == 'concatenate':
            return None
        raise ValueError(
            "knowledge_usage must be 'separate' or 'concatenate', got {!r}".format(knowledge_usage)
        )

    def forward(self,
                input_ids,
                attention_mask,
                labels,
                **kwargs,
                ):
        """Compute the training loss; returns ``{'loss': loss}``."""
        bsz = input_ids.shape[0]

        # Encode description.
        description_representation = self.get_description_representation(kwargs)

        # Encode knowledge.
        knowledge_representation = self.get_knowledge_representation(kwargs)

        past_prompt = self.get_prompt(
            bsz=bsz, description=description_representation, knowledge=knowledge_representation,
        )

        loss = self.pretrain_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            past_prompt=past_prompt,
        ).loss
        return {'loss': loss}

    def generate(self,
                 input_ids,
                 attention_mask,
                 **kwargs):
        """Generate with the prefix prompts; remaining ``kwargs`` go to the backbone's ``generate``."""
        bsz = input_ids.shape[0]

        # Encode description.
        description_representation = self.get_description_representation(kwargs)

        # Encode knowledge.
        knowledge_representation = self.get_knowledge_representation(kwargs)

        # Decoder-side prefixes need one copy per beam. When the caller does not pass
        # `num_beams`, use the config value instead of failing with a KeyError.
        # NOTE(review): assumes the backbone's generate() falls back to config.num_beams too — confirm.
        num_beams = kwargs.get('num_beams')
        if num_beams is None:
            num_beams = getattr(self.config, 'num_beams', None) or 1

        past_prompt = self.get_prompt(
            bsz=bsz, sample_size=num_beams, description=description_representation, knowledge=knowledge_representation,
        )
        generated_ids = self.pretrain_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_prompt=past_prompt,
            use_cache=True,
            **kwargs,
        )

        return generated_ids


In [3]:
import argparse
import configparser
import datetime
import os

DEFAULT_CONFIGURE_DIR = "configure"
DEFAULT_DATASET_DIR = "data"
DEFAULT_MODEL_DIR = "models"


class Args(object):
    """Permissive attribute bag used for (nested) configuration values.

    * Reading an attribute that was never set yields ``None`` instead of raising
      (dunder names keep the normal ``AttributeError`` behaviour).
    * Assigning ``None`` to a regular attribute is silently ignored.
    * Iteration yields ``(name, value)`` pairs of user-set attributes sorted by
      name; ``len()`` counts them.
    * Calling the instance returns the object passed as ``contain``.
    """

    def __init__(self, contain=None):
        self.__self__ = contain
        # Bind `__default__` first so that the baseline snapshot taken on the
        # next line already contains the name `__default__` itself.
        self.__default__ = None
        self.__default__ = set(dir(self))

    def __call__(self):
        return self.__self__

    def __getattribute__(self, name):
        # EAFP lookup: avoids building and scanning dir(self) on every attribute read.
        try:
            return super().__getattribute__(name)
        except AttributeError:
            if name[:2] == "__" and name[-2:] == "__":
                raise
            return None

    def __setattr__(self, name, value):
        # None is never stored for regular attributes; dunder names are always stored.
        if not (value is None) or (name[:2] == "__" and name[-2:] == "__"):
            return super().__setattr__(name, value)

    def __delattr__(self, name):
        # Deleting unknown or built-in attributes is a no-op rather than an error.
        if name in dir(self) and name not in self.__default__:
            super().__delattr__(name)

    def __iter__(self):
        # give args elements dictionary order to ensure its replicate-ability
        user_names = set(dir(self)) - self.__default__
        return iter(sorted((arg, getattr(self, arg)) for arg in user_names))

    def __len__(self):
        return len(set(dir(self)) - self.__default__)


class String(object):
    """Helpers for turning raw configuration strings into Python values."""

    @staticmethod
    def to_basic(string):
        """
        Interpret a raw string as the basic value it spells out.
        Tried in order: int, float, bool ("True"/"true"/"False"/"false");
        anything else is returned as a string with surrounding quotes removed.
        For example, "true" --> True as a bool value
        :param string: the raw text value
        :return: int, float, bool or str
        """
        # Numeric readings take precedence, the narrower type first.
        for convert in (int, float):
            try:
                return convert(string)
            except ValueError:
                continue

        if string in ("True", "true"):
            return True
        if string in ("False", "false"):
            return False

        # Quoting a value is how leading/trailing spaces can be kept in it.
        return string.strip("\"'")


class Configure(object):
    """Builds nested ``Args`` objects from INI-style configuration files."""

    @staticmethod
    def _apply_default_dirs(args):
        """Make sure ``args.dir`` exists and points at the default directories.

        An existing ``dir`` section read from the file is kept (only its
        ``model``/``dataset``/``configure`` entries are overwritten).
        """
        # `args.dir` is an Args instance when the file had a [dir] section, else None.
        if not isinstance(args.dir, Args):
            args.dir = Args()
        args.dir.model = DEFAULT_MODEL_DIR
        args.dir.dataset = DEFAULT_DATASET_DIR
        args.dir.configure = DEFAULT_CONFIGURE_DIR
        return args

    @staticmethod
    def get_file_cfg(file):
        """
        get configurations in file.

        Every section becomes a nested ``Args`` and every option an attribute
        whose value is converted with ``String.to_basic``. A file that cannot
        be read is skipped silently by ``ConfigParser.read``, which results in
        an empty ``Args``.
        :param file: path of the configuration file
        :return: configure args
        """
        cfgargs = Args()
        parser = configparser.ConfigParser()
        parser.read(file)
        for section in parser.sections():
            setattr(cfgargs, section, Args())
            section_args = getattr(cfgargs, section)
            for key, raw_value in parser.items(section):
                setattr(section_args, key, String.to_basic(raw_value))
        return cfgargs

    @staticmethod
    def refresh_args_by_file_cfg(file, prev_args):
        """
        Load ``file`` and fill in values from ``prev_args`` that the file does not set.

        ``prev_args`` is iterated as ``(dotted_name, value)`` pairs; ``None``
        values and the name ``cfg`` are skipped. Values coming from the file
        take precedence over those in ``prev_args``.
        :param file: path of the configuration file
        :param prev_args: iterable of (dotted name, value) pairs
        :return: configure args
        """
        args = Configure._apply_default_dirs(Configure.get_file_cfg(file))
        for arg_name, arg in prev_args:
            if arg is None or arg_name == "cfg":
                continue
            *parents, leaf = arg_name.split(".")
            cur = args
            # Walk (and create on demand) the nested Args along the dotted path.
            for name in parents:
                if getattr(cur, name) is None:
                    setattr(cur, name, Args())
                cur = getattr(cur, name)
            if getattr(cur, leaf) is None:
                setattr(cur, leaf, arg)
        return args

    @staticmethod
    def Get(cfg):
        """
        Load ``cfg`` from the default configure directory and set the default dirs.
        :param cfg: file name relative to DEFAULT_CONFIGURE_DIR
        :return: configure args
        """
        args = Configure.get_file_cfg(os.path.join(DEFAULT_CONFIGURE_DIR, cfg))
        return Configure._apply_default_dirs(args)


In [4]:
# Parse the prefix-tuning configuration file into a nested Args object
# (Configure.get_file_cfg yields an empty Args if the file cannot be read).
skt_args=Configure.get_file_cfg("./cos_e_prefix.cfg")
# Instantiate the prefix-tuning wrapper; this loads the tokenizer and backbone
# named by skt_args.bert.location.
t5_model = Model(skt_args)

prefix-tuning sequence length is 10.


In [5]:
#coding=utf-8
# from transformers import T5ForConditionalGeneration
from transformers.file_utils import ModelOutput
import torch
from modeling_transformer import MyTransformer, SemanticMatch
from torch import nn

device = torch.device("cuda:0")


class GenMC(nn.Module):
    """Multiple-choice scorer built on a T5 encoder-decoder.

    A clue is generated from the question, the question+clue representation is
    matched against each question+option encoding, and a linear layer turns
    every match vector into one score per option.
    """

    def __init__(self, model_path, num_hidden_layers, alpha, beta):
        """
        :param model_path: name or path given to ``T5ForConditionalGeneration.from_pretrained``
        :param num_hidden_layers: layer count for the ``SemanticMatch`` and ``MyTransformer`` heads
        :param alpha: weight of the option-classification loss
        :param beta: weight of the clue-generation loss
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta

        # Use the first GPU when there is one; fall back to CPU so the model
        # can still be constructed on CUDA-less hosts.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.t5_model = T5ForConditionalGeneration.from_pretrained(model_path)
        dim = self.t5_model.config.d_model
        # The heads carry a plain `.device` attribute that get_option_score reads
        # to decide where to move its inputs.
        self.option_linear = nn.Linear(dim, 1).to(device)
        self.option_linear.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.semantic_matching = SemanticMatch(dim, num_hidden_layers).to(device)
        self.semantic_matching.device = device
        num_attention_heads = dim // 64
        self.transformer_laryer_de = MyTransformer(dim, num_attention_heads, num_hidden_layers).to(device)
        self.transformer_laryer_de.device = device
        self.relu = nn.ReLU()

        # Model parallelism is currently disabled; to re-enable it:
        # self.t5_model.parallelize(
        #     self._balanced_device_map(self.t5_model.config.num_layers, torch.cuda.device_count()))

    @staticmethod
    def _balanced_device_map(layer_num, n_gpu):
        """Split ``layer_num`` layer indices evenly over ``n_gpu`` devices.

        Leftover layers go to the last device. Returns an empty dict when
        ``n_gpu`` is not positive.
        """
        if n_gpu <= 0:
            return {}
        layer_per_gpu = layer_num // n_gpu
        device_map = {
            n: [i + n * layer_per_gpu for i in range(layer_per_gpu)]
            for n in range(n_gpu)
        }
        device_map[n_gpu - 1] += list(range(n_gpu * layer_per_gpu, layer_num))
        return device_map

    def forward(self, q_ids, q_mask, qo_ids, qo_mask, choice_num, clue_ids=None, answers=None, rationale_ids=None):
        """Return the training loss, or the option scores at inference time.

        When both ``answers`` and ``clue_ids`` are given, returns
        ``alpha * option_loss + beta * clue_generation_loss``; otherwise returns
        ``(opt_score, output_sequences)``. ``rationale_ids`` is accepted but
        currently unused.
        """
        self.choice_num = choice_num
        opt_score, output_sequences = self.get_option_score(q_ids, q_mask, qo_ids, qo_mask)
        if answers is None or clue_ids is None:
            return opt_score, output_sequences

        local_device = self.t5_model.device
        t5_output = self.t5_model(input_ids=q_ids.to(local_device), attention_mask=q_mask.to(local_device),
                                  labels=clue_ids.to(local_device))
        loss_ans = t5_output.loss
        # Targets must sit on the same device as the scores.
        loss = self.criterion(opt_score, answers.to(opt_score.device))
        return self.alpha * loss + self.beta * loss_ans

    def get_option_score(self, q_ids, q_mask, qo_ids, qo_mask):
        """Score every option; returns ``(opt_score, output_sequences)``.

        ``opt_score`` is reshaped to ``(-1, self.choice_num)``, so
        ``self.choice_num`` must have been set (``forward`` does this).
        ``output_sequences`` are the generated token ids without their first
        position.
        """
        local_device = self.t5_model.encoder.device
        # Encode question+option and the bare question.
        t5_output = self.t5_model.encoder(input_ids=qo_ids.to(local_device), attention_mask=qo_mask.to(local_device))
        encoder_qo = t5_output[0]

        t5_output = self.t5_model.encoder(input_ids=q_ids.to(local_device), attention_mask=q_mask.to(local_device))
        encoder_q = t5_output[0]

        # Greedy-decode a clue from the question encoding, keeping decoder hidden states.
        local_device = self.t5_model.device
        t5_output = self.t5_model.generate(
            encoder_outputs=ModelOutput(last_hidden_state=encoder_q.to(local_device)),
            attention_mask=q_mask.to(local_device),
            do_sample=False,
            output_hidden_states=True,
            return_dict_in_generate=True
        )
        output_sequences = t5_output.sequences
        # Drop the first position (presumably the decoder start token — TODO confirm).
        output_sequences = output_sequences[:, 1:].contiguous()
        decoder_o = t5_output.decoder_hidden_states
        # Keep the last entry of each generation step (presumably the top decoder
        # layer — TODO confirm) and concatenate the steps along the sequence axis.
        decoder_o = [item[-1] for item in decoder_o]
        decoder_o = torch.cat(decoder_o, dim=1)

        # Mask out generated ids 0 and 1 (presumably T5's pad and EOS ids — TODO confirm).
        output_sequences_mask = (output_sequences != 0) & (output_sequences != 1)
        output_sequences_mask = output_sequences_mask.long()

        # Question states followed by clue states; align devices before concatenating.
        decoder_qo = torch.cat([encoder_q.to(decoder_o.device), decoder_o], dim=1)
        output_sequences_mask = torch.cat(
            [q_mask.to(output_sequences_mask.device), output_sequences_mask], dim=1)
        local_device = self.transformer_laryer_de.device
        decoder_qo, _ = self.transformer_laryer_de(decoder_qo.to(local_device), output_sequences_mask.to(local_device))

        # Repeat mask and states once per option so they line up with encoder_qo.
        output_sequences_mask_ex = output_sequences_mask.unsqueeze(dim=1)
        output_sequences_mask_ex = output_sequences_mask_ex.expand(
            [output_sequences_mask_ex.size(0), self.choice_num, output_sequences_mask_ex.size(-1)]).contiguous()
        output_sequences_mask_ex = output_sequences_mask_ex.view(-1, output_sequences_mask.size(-1))
        decoder_qo = decoder_qo.unsqueeze(dim=1)
        decoder_qo = decoder_qo.expand(
            [decoder_qo.size(0), self.choice_num, decoder_qo.size(-2), decoder_qo.size(-1)]).contiguous()
        decoder_qo = decoder_qo.view(-1, decoder_qo.size(-2), decoder_qo.size(-1))

        local_device = self.semantic_matching.device
        semantic_vec, _, _ = self.semantic_matching(encoder_qo.to(local_device), decoder_qo.to(local_device),
                                                    qo_mask.to(local_device), output_sequences_mask_ex.to(local_device))
        local_device = self.option_linear.device
        opt_score = self.option_linear(semantic_vec.to(local_device)).view(-1, self.choice_num)

        return opt_score, output_sequences

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [6]:
import argparse

# Run configuration for GenMC, expressed as command-line style options.
# The parser is built at module level (the ``if __name__ == '__main__':`` guard
# is intentionally not used in this notebook) and is fed an explicit argument
# list at the bottom of the cell instead of reading ``sys.argv``.
parser = argparse.ArgumentParser()
# NOTE: options that declare a default are left optional. Marking them
# ``required=True`` as well made the declared defaults unreachable.
parser.add_argument("--model_path",
                    default='t5-base',
                    type=str)
parser.add_argument("--choice_num",
                    default=5,
                    type=int)
parser.add_argument("--data_path_train",
                    default='./data/csqa/in_hourse/train.jsonl',
                    type=str)
parser.add_argument("--data_path_dev",
                    default='./data/csqa/in_hourse/dev.jsonl',
                    type=str)
parser.add_argument("--data_path_test",
                    default='./data/csqa/in_hourse/test.jsonl',
                    type=str)
parser.add_argument("--results_save_path",
                    default='./results/',
                    type=str)
parser.add_argument("--train_batch_size",
                    default=64,
                    type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
                    default=8,
                    type=int,
                    help="Total batch size for eval.")
parser.add_argument('--gradient_accumulation_steps',
                    type=int,
                    default=4,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")

parser.add_argument("--output_dir",
                    default='./outputs/',
                    type=str,
                    help="The output directory where the model checkpoints will be written.")
parser.add_argument("--init_checkpoint",
                    default=None,
                    type=str,
                    help="Initial checkpoint (usually from a pre-trained BERT model)")
parser.add_argument("--max_len",
                    default=64,
                    type=int,
                    help="The maximum total input sequence length after WordPiece tokenization. \n"
                         "Sequences longer than this will be truncated, and sequences shorter \n"
                         "than this will be padded.")
parser.add_argument("--max_len_gen",
                    default=32,
                    type=int,
                    help="The maximum total output sequence length for decoder")
parser.add_argument("--lr",
                    default=1e-5,
                    type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--epoch_num",
                    default=30,
                    type=int,
                    help="Total number of training epochs to perform.")
parser.add_argument('--num_hidden_layers',
                    type=int,
                    default=1,
                    help="The number of hidden layer for co-matching and encoder-decoder interaction transformer")
parser.add_argument('--alpha',
                    type=float,
                    default=1)
parser.add_argument('--beta',
                    type=float,
                    default=1)
parser.add_argument('--seed',
                    type=int,
                    default=1,
                    help="random seed for initialization")
parser.add_argument("--name_save_prix",
                    default='GenMC_CSQA',
                    type=str)
parser.add_argument('--external_sent_num',
                    type=int,
                    default=None,
                    help="The number of retrieved sentences")

# Explicit argument list: the values used for this notebook run.
args = parser.parse_args(["--model_path", "t5-base", "--choice_num", "5",
                          "--data_path_train", "../GenMC/data/csqa/in_hourse/train.jsonl",
                          "--data_path_dev", "../GenMC/data/csqa/in_hourse/dev.jsonl",
                          "--data_path_test", "../GenMC/data/csqa/in_hourse/test.jsonl"])



In [7]:
# Build the GenMC model around the existing `t5_model` object, together with the
# num_hidden_layers, alpha and beta settings parsed into `args`.
model = GenMC(t5_model, args.num_hidden_layers, args.alpha, args.beta)


In [8]:
# coding=utf-8
from transformers import T5Tokenizer
from tqdm import trange
import os
import random
import torch
from utils import compute_rouges, save_dataset, read_dataset, set_seed, save_model
# from model.modeling_genmc import GenMC
import json
import argparse

# Device that get_input_feature moves every tokenized batch tensor to.
# NOTE(review): the first notebook cell sets CUDA_VISIBLE_DEVICES="1" before
# importing torch, so "cuda:0" presumably maps to physical GPU 1 — confirm.
device = torch.device("cuda:0")


def get_input_feature(samples, max_source_length, max_len_gen, choice_num, external_sent_num=None):
    """Convert a batch of multiple-choice samples into tokenized model inputs.

    Uses the module-level ``tokenizer`` and ``device``.

    Args:
        samples: list of dicts. Each is read via ``sample['question']['stem']``,
            ``sample['question']['choices']`` (items with ``'text'`` and
            optionally ``'para'``), ``sample['cos-e']`` and, when present,
            ``sample['answerKey']`` (defaults to ``"A"``). A sample with fewer
            than ``choice_num`` choices is padded IN PLACE with dummy "error"
            choices, so later lookups by predicted index stay valid.
        max_source_length: truncation length for the question and
            question+option inputs.
        max_len_gen: truncation length for the clue and rationale targets.
        choice_num: number of options encoded per question.
        external_sent_num: when not None, at most this many sentences from each
            option's ``'para'`` are appended to the option text.

    Returns:
        A 9-tuple, in this order:
        ``(q_ids, q_mask, qo_ids, qo_mask, clue_ids, answers, rationale_ids,
        output_clue, output_rationale)``.
        ``clue_ids`` has pad positions replaced by -100; ``rationale_ids`` keeps
        the tokenizer's pad ids. ``output_clue`` / ``output_rationale`` are the
        raw target strings.
    """
    sep = ' \\n '
    digit_keys = ('1', '2', '3', '4', '5', '6')
    output_clue = []
    output_rationale = []
    answers = []
    input_ids_q = []
    input_ids_qo = []
    for sample in samples:
        answerKey = sample.get('answerKey', "A")
        question = sample['question']['stem']
        choices = sample['question']['choices']
        # Pad short questions. The dummy label continues the A, B, C, ...
        # sequence from the current number of choices (it used to be derived
        # from the number of keys in the sample dict).
        while len(choices) < choice_num:
            choices.append({"text": "error", "para": "", "label": chr(ord('A') + len(choices))})
        for opt in choices[:choice_num]:
            content = ""
            if external_sent_num is not None and 'para' in opt:
                para = opt["para"]
                if isinstance(para, str):
                    # A plain string is split into sentences on "."
                    para = para.split(".")
                if isinstance(para, list):
                    content = sep + " ".join(para[:external_sent_num])
                else:
                    print('lack retrieval')
            input_ids_qo.append(question + sep + opt['text'] + content)

        input_ids_q.append(question + sep)
        # Answer keys are either '1'..'6' or 'A', 'B', ...; map both to 0-based.
        if answerKey in digit_keys:
            answer = ord(answerKey) - ord('1')
        else:
            answer = ord(answerKey) - ord('A')
        answers.append(answer)
        output_clue.append(choices[answer]['text'])
        output_rationale.append(sample['cos-e'])

    def tokenizer_fun(input_ids, max_len):
        """Tokenize a list of strings and return (ids, attention mask) on ``device``."""
        encoding = tokenizer(input_ids,
                             padding='longest',
                             max_length=max_len,
                             truncation=True,
                             return_tensors="pt")
        ids = encoding.input_ids.to(device)
        mask = encoding.attention_mask.to(device)
        return ids, mask

    q_ids, q_mask = tokenizer_fun(input_ids_q, max_source_length)
    qo_ids, qo_mask = tokenizer_fun(input_ids_qo, max_source_length)
    clue_ids, _ = tokenizer_fun(output_clue, max_len_gen)
    # Replace padding in the clue targets with -100, done on-tensor instead of
    # round-tripping through nested Python lists.
    clue_ids = clue_ids.masked_fill(clue_ids == tokenizer.pad_token_id, -100)
    answers = torch.tensor(answers, dtype=torch.long).to(device)

    # NOTE: unlike clue_ids, the rationale targets are returned with the
    # tokenizer's pad ids left in place (no -100 masking is applied here).
    rationale_ids, _ = tokenizer_fun(output_rationale, max_len_gen)

    return q_ids, q_mask, qo_ids, qo_mask, clue_ids, answers, rationale_ids, output_clue, output_rationale


@torch.no_grad()
def eval(model, test_examples, tokenizer, eval_batch_size, choice_num, max_len, max_len_gen, external_sent_num):
    """Score ``model`` on ``test_examples`` for answer accuracy and clue ROUGE-L.

    Args:
        model: called as ``model(q_ids, q_mask, qo_ids, qo_mask, choice_num)``
            and expected to return ``(scores, output_sequences)``.
        test_examples: list of sample dicts accepted by ``get_input_feature``;
            each must also carry an ``'id'``.
        tokenizer: used to decode the generated sequences. Note that
            ``get_input_feature`` itself reads the module-level tokenizer.
        eval_batch_size: number of examples per forward pass.
        choice_num: number of options per question.
        max_len: truncation length for the encoder inputs.
        max_len_gen: truncation length for the generation targets.
        external_sent_num: forwarded to ``get_input_feature``.

    Returns:
        ``(accuracy, rouge_l, results)`` where ``results`` is a list of
        ``"<id>,<predicted label>"`` strings, one per example.
    """
    count, count_right = 0, 0
    results = []
    model.eval()
    # Ceiling division so a trailing partial batch is still evaluated.
    step_count = (len(test_examples) + eval_batch_size - 1) // eval_batch_size
    sources, targets = [], []
    for step in trange(step_count):
        beg_index = step * eval_batch_size
        batch_example = test_examples[beg_index:beg_index + eval_batch_size]
        # get_input_feature returns nine values; only some are needed here.
        (q_ids, q_mask, qo_ids, qo_mask, _clue_ids, answers,
         _rationale_ids, output_clue, _output_rationale) = get_input_feature(
            batch_example, max_len, max_len_gen, choice_num, external_sent_num)
        scores, output_sequences = model(q_ids, q_mask, qo_ids, qo_mask, choice_num)

        scores = scores.cpu().detach().tolist()
        answers = answers.cpu().detach().tolist()
        for option_scores, gold, example in zip(scores, answers, batch_example):
            p_ans = option_scores.index(max(option_scores))
            label = example['question']['choices'][p_ans]['label']
            if p_ans == gold:
                count_right += 1
            count += 1
            results.append(example['id'] + "," + label)
        sources += tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
        targets += output_clue

    rouge_score = compute_rouges(sources, targets)['rouge-l']

    return count_right / count, rouge_score, results


# Set the CUDA device ordering and visibility environment variables.
# NOTE(review): these are assigned after torch has been imported and `device`
# has been created in this cell; the CUDA runtime reads them when it
# initialises, so confirm they still take effect at this point. The first
# notebook cell already sets CUDA_VISIBLE_DEVICES before importing torch.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"



In [9]:
# Run identifier built from the hyper-parameters, used as the leaf directory
# for both checkpoints and result files.
file_name = f'lr_{args.lr}_seed_{args.seed}_bs_{args.train_batch_size}_ga_{args.gradient_accumulation_steps}_layer_num_{args.num_hidden_layers}_alpha_{args.alpha}_beta_{args.beta}'
# Honour --output_dir / --results_save_path instead of hard-coding the roots.
# With the parser defaults ('./outputs/' and './results/') the resulting paths
# are the same as before; the trailing "" keeps the trailing separator.
output_model_path = os.path.join(args.output_dir, args.name_save_prix, file_name, "")
path_save_result = os.path.join(args.results_save_path, args.name_save_prix, file_name, "")

os.makedirs(path_save_result, exist_ok=True)
set_seed(args.seed)
train_examples = read_dataset(args.data_path_train)
dev_examples = read_dataset(args.data_path_dev)
test_examples = read_dataset(args.data_path_test)

# Train on train+dev and use the test split in the dev role as well.
# NOTE(review): model selection / early stopping below is therefore driven by
# the test split — confirm this is the intended protocol.
train_examples = train_examples + dev_examples
dev_examples = test_examples

In [21]:
# Inspect the 'positives' and 'negatives' fields of the first training example.
# len(train_examples), train_examples[0].keys(), \\
train_examples[0]['positives'],train_examples[0]['negatives'],

(['Person who commits murder will have to undergo legal punishment.',
  'No person is mortal in this world.',
  'If  he or she is not caught, evenually he or she will have his own death.'],
 ['Ocean is a natures creation, person would not go to ocean eventually',
  'Fear is of getting caught, person will never feel fearful once he is not caught',
  'There is no point of imprisonment if person is not caught',
  'There is no point of incarceration if person is not caught'],
 "Person committing muder will have to undergo legal punishment. Although he is not caught, eventually he will have his own death at some point of time as no one is mortal. There is no question of imprisonment or incarceration if person is not caught for muder. Also person will not be fearful now as he is not caught. Ocean is a nature's creation and person would not got to ocean eventually.")

In [22]:
# Show the first training example's 'explanation' field next to its 'cos-e' field.
train_examples[0]['explanation'], train_examples[0]['cos-e']

("Person committing muder will have to undergo legal punishment. Although he is not caught, eventually he will have his own death at some point of time as no one is mortal. There is no question of imprisonment or incarceration if person is not caught for muder. Also person will not be fearful now as he is not caught. Ocean is a nature's creation and person would not got to ocean eventually.",
 'death is something that occurs to everyone and is the only certainty in life.')

In [10]:



# Echo the effective run configuration so it is stored with the cell output.
print(json.dumps({"lr": args.lr, "model": args.model_path, "seed": args.seed,
                  "bs": args.train_batch_size,
                  'gradient_accumulation_steps': args.gradient_accumulation_steps,
                  "epoch": args.epoch_num,
                  "train_path": args.data_path_train,
                  "dev_path": args.data_path_dev,
                  "test_path": args.data_path_test,
                  "train_size": len(train_examples),
                  "dev_size": len(dev_examples),
                  "test_size": len(test_examples),
                  'num_hidden_layers': args.num_hidden_layers,
                  'external_sent_num': args.external_sent_num,
                  "alpha": args.alpha, "beta": args.beta}, indent=2))

# --train_batch_size is the total batch per optimizer update; each forward /
# backward pass sees this smaller micro-batch.
train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
tokenizer = T5Tokenizer.from_pretrained(args.model_path)
# `model` is the GenMC instance created in an earlier cell.
# model = GenMC(args.model_path, args.num_hidden_layers, args.alpha, args.beta)

if args.init_checkpoint is not None:
    checkpoint = torch.load(args.init_checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)

step_count, step_all, early_stop = 0, 0, 0
best_dev_rouge_score, best_test_rouge_score = 0, 0
tr_loss, nb_tr_steps = 0, 0

# Accuracy of the starting model is the bar the first epoch has to beat.
best_dev_acc, _, _ = eval(model, dev_examples, tokenizer, args.eval_batch_size, args.choice_num, args.max_len,
                          args.max_len_gen, args.external_sent_num)
print('best_dev_acc:',best_dev_acc)
best_test_acc = 0
for epoch in range(args.epoch_num):
    early_stop += 1
    # Deterministic per-epoch shuffle of the training order.
    order = list(range(len(train_examples)))
    random.seed(args.seed + epoch)
    random.shuffle(order)
    model.train()
    # Ceiling division so a trailing partial micro-batch is still used.
    step_count = (len(train_examples) + train_batch_size - 1) // train_batch_size
    step_trange = trange(step_count)
    for step in step_trange:
        step_all += 1
        beg_index = step * train_batch_size
        order_index = order[beg_index:beg_index + train_batch_size]
        batch_example = [train_examples[index] for index in order_index]
        # Unpack in the exact order get_input_feature returns its values:
        # rationale_ids comes BEFORE the raw clue / rationale strings.
        (q_ids, q_mask, qo_ids, qo_mask, clue_ids, answers,
         rationale_ids, output_clue, output_rationale) = get_input_feature(
            batch_example,
            max_source_length=args.max_len,
            max_len_gen=args.max_len_gen,
            choice_num=args.choice_num,
            external_sent_num=args.external_sent_num)
        loss = model(q_ids, q_mask, qo_ids, qo_mask, args.choice_num, clue_ids, answers,rationale_ids)

        loss = loss.mean()
        tr_loss += loss.item()
        nb_tr_steps += 1
        loss = loss / args.gradient_accumulation_steps
        loss.backward()
        # Also step on the last micro-batch of the epoch, so gradients from a
        # partial accumulation group are applied instead of leaking into the
        # next epoch's first update.
        is_last_step = step + 1 == step_count
        if (step + 1) % args.gradient_accumulation_steps == 0 or is_last_step:
            optimizer.step()
            # scheduler.step()
            optimizer.zero_grad()

        loss_show = ' Epoch:' + str(epoch) + " loss:" + str(round(tr_loss / nb_tr_steps, 4))
        step_trange.set_postfix_str(loss_show)

    dev_acc, dev_rouge_score, results_dev = eval(model, dev_examples, tokenizer, args.eval_batch_size,
                                                 args.choice_num, args.max_len, args.max_len_gen,
                                                 args.external_sent_num)
    print('dev_acc:', dev_acc)
    if dev_acc > best_dev_acc:
        # New best on dev: persist predictions, reset patience, score test.
        save_dataset(os.path.join(path_save_result, 'dev.csv'), results_dev)
        early_stop = 0
        test_acc, test_rouge_score, results_test = eval(model, test_examples, tokenizer, args.eval_batch_size,
                                                        args.choice_num, args.max_len, args.max_len_gen,
                                                        args.external_sent_num)
        save_dataset(os.path.join(path_save_result, 'test.csv'), results_test)
        best_dev_acc, best_test_acc, best_dev_rouge_score, best_test_rouge_score = dev_acc, test_acc, dev_rouge_score, test_rouge_score

        # save_model(output_model_path, model, optimizer)
        print('new best dev acc:', dev_acc, 'test_acc:', test_acc, 'rouge:', dev_rouge_score)

    # Stop after five consecutive epochs without a dev-accuracy improvement.
    if early_stop >= 5:
        break

print('best dev acc:', best_dev_acc, 'best_test_acc:', best_test_acc,
      'best_dev_rouge_score:', best_dev_rouge_score, 'best_test_rouge_score:', best_test_rouge_score)


{
  "lr": 1e-05,
  "model": "t5-base",
  "seed": 1,
  "bs": 64,
  "gradient_accumulation_steps": 4,
  "epoch": 30,
  "train_path": "../GenMC/data/csqa/in_hourse/train.jsonl",
  "dev_path": "../GenMC/data/csqa/in_hourse/dev.jsonl",
  "test_path": "../GenMC/data/csqa/in_hourse/test.jsonl",
  "train_size": 9741,
  "dev_size": 1221,
  "test_size": 1221,
  "num_hidden_layers": 1,
  "external_sent_num": null,
  "alpha": 1,
  "beta": 1
}


  0%|                                                                                                                                                                               | 0/153 [00:00<?, ?it/s]


AttributeError: 'Model' object has no attribute 'encoder'

In [24]:
# Inspect the encoder stack of the model held at `t5_model.pretrain_model`.
t5_model.pretrain_model.encoder

T5Stack(
  (embed_tokens): Embedding(32102, 768)
  (block): ModuleList(
    (0): T5Block(
      (layer): ModuleList(
        (0): T5LayerSelfAttention(
          (SelfAttention): T5Attention(
            (q): Linear(in_features=768, out_features=768, bias=False)
            (k): Linear(in_features=768, out_features=768, bias=False)
            (v): Linear(in_features=768, out_features=768, bias=False)
            (o): Linear(in_features=768, out_features=768, bias=False)
            (relative_attention_bias): Embedding(32, 12)
          )
          (layer_norm): T5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): T5LayerFF(
          (DenseReluDense): T5DenseReluDense(
            (wi): Linear(in_features=768, out_features=3072, bias=False)
            (wo): Linear(in_features=3072, out_features=768, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (layer_norm): T5LayerNorm()
          (dropout): Dropout(p=0.1, i