In [1]:
from transformers import EncoderDecoderModel, BertTokenizer, PreTrainedModel
import pytorch_lightning as pl

In [2]:
import logging
from typing import Optional

from transformers.configuration_encoder_decoder import EncoderDecoderConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


logger = logging.getLogger(__name__)

In [3]:
from src.models.modeling_ner import ModelForNERBase

In [4]:
class BaseWithPL(PreTrainedModel, pl.LightningModule):
    """Common base that is both a transformers ``PreTrainedModel`` and a Lightning module.

    The class adds no behaviour of its own; it only fixes the inheritance
    order (``PreTrainedModel`` first, ``pl.LightningModule`` second).
    """


In [5]:
class EncoderDecoderForNER(ModelForNERBase, BaseWithPL):

    def __init__(self, config=None, encoder=None, decoder=None, hparams=None):
        """Assemble the model from a configuration and/or ready-made sub-models.

        Args:
            config: Optional encoder-decoder configuration. When omitted it is
                derived from ``encoder.config`` and ``decoder.config``.
            encoder: Optional encoder module. When omitted it is built from
                ``config.encoder`` with ``AutoModel.from_config``.
            decoder: Optional decoder module. When omitted it is built from
                ``config.decoder`` with ``AutoModelWithLMHead.from_config``.
            hparams: Hyper-parameter container stored on ``self.hparams``
                (presumably read by ``get_value_or_default_hparam``, which is
                defined outside this class -- confirm).

        Raises:
            AssertionError: If neither ``config`` nor both sub-models are
                given, if ``config`` has an unexpected type, or if the encoder
                exposes output embeddings (i.e. it carries an LM head).
        """
        assert config is not None or (
            encoder is not None and decoder is not None
        ), "Either a configuration or an Encoder and a decoder has to be provided"
        if config is None:
            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            # NOTE(review): this class does not declare `config_class` itself; if no
            # base class sets it, `isinstance(config, None)` would raise TypeError.
            # Fall back to the encoder-decoder config type in that case.
            expected_config_class = getattr(self, "config_class", None) or EncoderDecoderConfig
            assert isinstance(config, expected_config_class), "config: {} has to be of type {}".format(
                config, expected_config_class
            )
        # Start the MRO lookup *after* BaseWithPL so that the initialiser taking
        # `config` runs, skipping the `__init__` of classes that come earlier in
        # the MRO (such as ModelForNERBase).
        super(BaseWithPL, self).__init__(config)

        if encoder is None:
            from transformers import AutoModel

            encoder = AutoModel.from_config(config.encoder)

        if decoder is None:
            from transformers import AutoModelWithLMHead

            decoder = AutoModelWithLMHead.from_config(config.decoder)

        self.encoder = encoder
        self.decoder = decoder
        # The placeholder in this message used to be left unformatted; fill it
        # with the encoder's class name so the failure is actionable.
        assert (
            self.encoder.get_output_embeddings() is None
        ), "The encoder {} should not have a LM Head. Please use a model without LM Head".format(
            type(self.encoder).__name__
        )

        # `hparams` must be set before `get_tokenizer()` is called below.
        self.hparams = hparams

        self.tokenizer = self.get_tokenizer()
        # TODO: token-weighted loss is currently disabled
        # (`self._token_weights = self._create_token_weights()`).

        # Dataset slots start empty; nothing in this class fills them.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        
    def tie_weights(self):
        # for now no weights tying in encoder-decoder
        pass

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_input_embeddings(self):
        return self.encoder.get_input_embeddings()


    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()


    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        hparams=None,
        *model_args,
        **kwargs
    ) -> PreTrainedModel:
        """Create the model from a pretrained encoder and a pretrained decoder.

        Keyword arguments prefixed with ``encoder_`` / ``decoder_`` are routed
        (with the prefix stripped) to the respective ``from_pretrained`` call.
        Ready-made instances can be supplied as ``encoder_model`` and
        ``decoder_model``; in that case nothing is loaded for that side.

        Args:
            encoder_pretrained_model_name_or_path: Identifier handed to
                ``AutoModel.from_pretrained`` when no ``encoder_model`` is given.
            decoder_pretrained_model_name_or_path: Identifier handed to
                ``AutoModelWithLMHead.from_pretrained`` when no
                ``decoder_model`` is given.
            hparams: Passed through to the class constructor. Note that it
                precedes ``*model_args``, so a third positional argument binds
                to ``hparams``.
            *model_args: Extra positional arguments, forwarded to the encoder's
                ``from_pretrained`` call only.
            **kwargs: Prefixed keyword arguments as described above.

        Returns:
            A new instance of ``cls`` wrapping the encoder and the decoder.

        Raises:
            AssertionError: If a side has neither an instance nor a
                name/path to load from.
        """

        def _strip_prefixed(prefix):
            # Select the kwargs that start with `prefix` and drop the prefix.
            selected = {}
            for name, val in kwargs.items():
                if name.startswith(prefix):
                    selected[name[len(prefix):]] = val
            return selected

        enc_kwargs = _strip_prefixed("encoder_")
        dec_kwargs = _strip_prefixed("decoder_")

        # --- encoder -----------------------------------------------------
        # Encoder and decoder are told apart at model level by the config flag
        # `is_decoder`, so it is forced to False here and checked below.
        encoder = enc_kwargs.pop("model", None)
        if encoder is None:
            assert (
                encoder_pretrained_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
            from transformers.modeling_auto import AutoModel

            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **enc_kwargs)
        encoder.config.is_decoder = False

        # --- decoder -----------------------------------------------------
        decoder = dec_kwargs.pop("model", None)
        if decoder is None:
            assert (
                decoder_pretrained_model_name_or_path is not None
            ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
            from transformers.modeling_auto import AutoModelWithLMHead

            if "config" not in dec_kwargs:
                # No explicit decoder config: load the default one and switch
                # it into decoder mode.
                from transformers import AutoConfig

                loaded_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
                if loaded_config.is_decoder is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                    )
                    loaded_config.is_decoder = True

                dec_kwargs["config"] = loaded_config

            # A caller-supplied config is used as is; only warn when it is not
            # flagged as a decoder.
            if dec_kwargs["config"].is_decoder is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **dec_kwargs)

        return cls(encoder=encoder, decoder=decoder, hparams=hparams)


    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        head_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_head_mask=None,
        decoder_inputs_embeds=None,
        labels=None,
        lm_labels=None,
        **kwargs,
    ):

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                **kwargs_encoder,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            inputs_embeds=decoder_inputs_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            lm_labels=lm_labels,
#             labels=labels,
            **kwargs_decoder,
        )

        return decoder_outputs + encoder_outputs


    def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"

        # first step
        if type(past) is tuple:
            encoder_outputs = past
        else:
            encoder_outputs = (past,)

        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)

        return {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_inputs["attention_mask"],
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
        }

    def _reorder_cache(self, past, beam_idx):
        # as a default encoder-decoder models do not re-order the past.
        # TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
        return past
        
    def _handle_batch(self, batch):
        """Trim a batch and run it through the model.

        Args:
            batch: Triple ``(input_ids, attention_mask, lm_labels)``.

        Returns:
            Whatever calling the model returns for the trimmed batch.
        """
        input_ids, attention_mask, lm_labels = self.trim_batch(batch)
        # The decoder is fed the same token ids as the encoder.
        model_inputs = dict(
            input_ids=input_ids,
            decoder_input_ids=input_ids,
            attention_mask=attention_mask,
            lm_labels=lm_labels,
        )
        return self(**model_inputs)
    
    @staticmethod
    def trim_matrix(mat, value):
        eq_val = (mat == value).float()
        eq_val = eq_val.cumsum(-1)
        index = torch.nonzero(eq_val == 1.)
        if len(index) and len(index) == len(mat):
            index = index[:, 1].max().item()
        else:
            index = mat.shape[-1]
        return index
    
    def trim_batch(self, batch):
        input_ids, attention_mask, lm_labels = batch
        input_ids_index = self.trim_matrix(input_ids, self.config.encoder.pad_token_id)
        lm_labels_index = self.trim_matrix(lm_labels, -100)
        
        index = max(input_ids_index, lm_labels_index)
        
        attention_mask = attention_mask[:, :index]
        input_ids = input_ids[:, :index]
        lm_labels = lm_labels[:, :index]
        return input_ids, attention_mask, lm_labels
        
    def get_tokenizer(self,):
        """Load the BERT tokenizer named by the ``pretrained_model`` hyper-parameter.

        The name is looked up through ``get_value_or_default_hparam`` with
        ``'bert-base-cased'`` as the default.
        """
        model_name = self.get_value_or_default_hparam('pretrained_model', 'bert-base-cased')
        tokenizer = BertTokenizer.from_pretrained(model_name)
        return tokenizer

# CoNLL-2003

In [17]:
from argparse import Namespace
import torch

In [7]:
from src.models.modeling_conll2003 import Conll2003Base

In [8]:
class EncoderDecoderForConll2003(Conll2003Base, EncoderDecoderForNER):
    """Encoder-decoder NER model combined with ``Conll2003Base``."""

    def get_tokenizer(self,):
        """Return the tokenizer provided by the parent classes, unchanged."""
        # Adding `self.entities_tokens` to the vocabulary is currently disabled.
        return super().get_tokenizer()

In [9]:
# Hyper-parameters handed to the model; built directly as a Namespace so the
# name `hparams` never refers to two different kinds of object.
hparams = Namespace(
    end_token='sep',
    pretrained_model='bert-base-uncased',
    labels_mode='words',
    merge_O=True,
)

In [10]:
# Initialize a Bert2Bert model: encoder and decoder are both loaded from
# 'bert-base-uncased', and the hyper-parameters are handed to the constructor.
model = EncoderDecoderForConll2003.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased', hparams=hparams)

In [11]:
# Data preparation hook; `prepare_data` is not defined in this notebook, so it
# comes from one of the imported base classes.
model.prepare_data()

In [12]:
# Fetch the examples; the result is indexed by split name (e.g. 'train') below.
examples = model.get_examples()

In [13]:
# Show the first training example (its source sentence and tagged target).
examples['train'][0]

Source: EU rejects German call to boycott British lamb .
Target: EU [Organization] rejects [Other] German [Miscellaneous] call to boycott [Other] British [Miscellaneous] lamb . [Other]

In [14]:
# Pull a single batch from the training dataloader for a quick smoke test.
batch = next(iter(model.train_dataloader()))

In [15]:
# A batch is a triple, unpacked here in the same order `_handle_batch` uses.
input_ids, attention_mask, lm_labels = batch

In [18]:
# Trim the batch and run one forward pass; in the recorded output the first
# element of the returned tuple is the loss tensor.
model._handle_batch(batch)

	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)


(tensor(11.1760, grad_fn=<NllLossBackward>),
 tensor([[[-4.2656, -4.5210, -4.5161,  ..., -4.8522, -5.3251, -5.5572],
          [-2.4787, -2.8884, -2.7178,  ..., -3.6082, -3.3271, -6.9371],
          [-7.8276, -7.6111, -7.9250,  ..., -6.8847, -8.1065, -2.6276],
          ...,
          [-4.9362, -5.9066, -5.5162,  ..., -7.5794, -5.6245, -5.0047],
          [-4.2324, -5.2047, -4.8554,  ..., -6.9375, -5.1227, -3.6019],
          [-2.9852, -3.9339, -3.6173,  ..., -6.0910, -3.8285, -2.0919]],
 
         [[-6.5920, -6.4046, -6.6076,  ..., -6.9532, -6.5586, -1.2814],
          [-5.3631, -5.4862, -5.3050,  ..., -6.1829, -5.0344, -3.2574],
          [-8.7467, -8.8679, -8.9009,  ..., -8.8365, -8.4121, -2.8041],
          ...,
          [-3.8576, -4.6478, -4.2220,  ..., -6.9258, -5.2491, -0.6430],
          [-3.3311, -4.0553, -3.6056,  ..., -5.8988, -4.6884, -0.2115],
          [-2.2159, -2.9656, -2.5046,  ..., -4.7028, -3.5162,  0.3270]]],
        grad_fn=<AddBackward0>),
 tensor([[[-0.3623, -0.

In [None]:
# Same call that `_handle_batch` makes, but on the untrimmed batch tensors.
outs = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids, lm_labels=lm_labels)

In [None]:
# First element of the output tuple (the loss in the `_handle_batch` run above).
outs[0]

In [None]:
# Inspect the raw label tensor of the batch.
lm_labels

In [None]:
# Look up which vocabulary token the id 1031 stands for.
model.tokenizer.convert_ids_to_tokens(1031)