In [1]:
import torch
from torch import nn
from transformers import PreTrainedModel

In [3]:
from transformers import AutoModel, AutoTokenizer, MT5Tokenizer, GPT2LMHeadModel, AutoConfig

# Hub identifiers of the two pretrained checkpoints that get grafted together.
ENCODER_MODEL = "bert-base-multilingual-cased"
DECODER_MODEL = "THUMT/mGPT"

# Source side: multilingual BERT and its matching tokenizer.
mBERT, mBERT_tokenizer = (
    AutoModel.from_pretrained(ENCODER_MODEL),
    AutoTokenizer.from_pretrained(ENCODER_MODEL),
)

# Target side: mGPT. AutoModel builds a bare GPT2Model, so the checkpoint's
# LM head is discarded (hence the "lm_head.weight ... not used" warning);
# the checkpoint is paired with an MT5Tokenizer.
decoder_config = AutoConfig.from_pretrained(DECODER_MODEL)
mGPT = AutoModel.from_pretrained(DECODER_MODEL, config=decoder_config)
mGPT_tokenizer = MT5Tokenizer.from_pretrained(DECODER_MODEL)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at THUMT/mGPT were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are i

In [4]:
# Display the decoder's module tree (embedding sizes, GPT2Block structure) for inspection.
mGPT

GPT2Model(
  (wte): Embedding(250100, 1024)
  (wpe): Embedding(1024, 1024)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0): GPT2Block(
      (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): GPT2Block(
      (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): 

In [5]:
from transformers.models.bart.modeling_bart import BartEncoderLayer, BartDecoderLayer

In [None]:
class Grafomer(PreTrainedModel):
  """Freshly initialised BART encoder/decoder layers that graft two pretrained models.

  The encoder layers refine the source-side hidden states. The decoder layers
  then run causal self-attention over the target-side hidden states and
  cross-attend to the refined source states.

  Args:
    config: a config exposing ``encoder_layers``, ``decoder_layers`` and the
      fields ``BartEncoderLayer`` / ``BartDecoderLayer`` read (presumably a
      ``BartConfig`` -- TODO confirm which config the notebook passes in).
  """

  def __init__(self, config):
    # PreTrainedModel.__init__ requires the config; omitting it fails immediately.
    super().__init__(config)

    self.encoder = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
    self.decoder = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])

    self.init_weights()

  def _init_weights(self, module):
    """Initialise Linear/Embedding weights from N(0, init_std), as BART does.

    ``init_weights`` applies this hook to every submodule. The base-class hook
    is not guaranteed to initialise anything, so it is overridden here.
    """
    std = getattr(self.config, "init_std", 0.02)
    if isinstance(module, nn.Linear):
      module.weight.data.normal_(mean=0.0, std=std)
      if module.bias is not None:
        module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
      module.weight.data.normal_(mean=0.0, std=std)
      if module.padding_idx is not None:
        module.weight.data[module.padding_idx].zero_()

  @staticmethod
  def _expand_padding_mask(mask, dtype, tgt_len):
    """Convert a 2-D keep-mask (1 = attend, 0 = padding) into an additive 4-D mask.

    BART attention layers require masks shaped ``(bsz, 1, tgt_len, src_len)``
    holding 0 for visible positions and a large negative value for masked ones.
    ``None`` and masks that are already 4-D are returned unchanged.
    """
    if mask is None or mask.dim() == 4:
      return mask
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

  @staticmethod
  def _causal_mask(bsz, tgt_len, dtype, device):
    """Additive ``(bsz, 1, tgt_len, tgt_len)`` mask that hides future positions."""
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    # Keep the large negative value strictly above the diagonal (j > i), zero elsewhere.
    mask = torch.triu(mask, diagonal=1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)

  def _decoder_self_attention_mask(self, decoder_attention_mask, decoder_hidden_states):
    """Combine a causal mask with an optional 2-D decoder padding mask.

    A mask that is already 4-D is assumed to be fully prepared and is returned as-is.
    """
    if decoder_attention_mask is not None and decoder_attention_mask.dim() == 4:
      return decoder_attention_mask
    bsz, tgt_len = decoder_hidden_states.shape[:2]
    dtype, device = decoder_hidden_states.dtype, decoder_hidden_states.device
    causal = self._causal_mask(bsz, tgt_len, dtype, device)
    if decoder_attention_mask is None:
      return causal
    padding = self._expand_padding_mask(decoder_attention_mask, dtype, tgt_len).to(device)
    # Summing two large negatives could overflow to -inf; clamp back into range.
    return (causal + padding).clamp(min=torch.finfo(dtype).min)

  def forward(
        self,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        decoder_hidden_states=None,
        decoder_attention_mask=None,
        output_attentions: bool = False,
        cross_attn_head_mask=None,
        use_cache=None,
        head_mask=None,
    ):
    """Run the graft encoder stack, then the graft decoder stack.

    Args:
      encoder_hidden_states: source-side hidden states, assumed
        ``(batch, src_len, d_model)`` -- TODO confirm against caller.
      encoder_attention_mask: 2-D ``(batch, src_len)`` keep-mask, an already
        expanded 4-D additive mask, or ``None``.
      decoder_hidden_states: target-side hidden states, assumed
        ``(batch, tgt_len, d_model)``. Required.
      decoder_attention_mask: optional 2-D padding mask or 4-D additive mask
        for decoder self-attention. A causal mask is always applied unless a
        4-D mask is supplied.
      output_attentions, use_cache: forwarded to every layer.
      head_mask: per-layer head masks. NOTE: shared between the encoder and
        decoder stacks, as in the original implementation.
      cross_attn_head_mask: per-layer head masks for decoder cross-attention.

    Returns:
      The final decoder hidden states.

    Raises:
      ValueError: if ``decoder_hidden_states`` is not provided.
    """
    if decoder_hidden_states is None:
      raise ValueError("Grafomer.forward requires decoder_hidden_states")

    src_len = encoder_hidden_states.shape[1]
    tgt_len = decoder_hidden_states.shape[1]
    encoder_self_mask = self._expand_padding_mask(
        encoder_attention_mask, encoder_hidden_states.dtype, src_len
    )
    cross_attention_mask = self._expand_padding_mask(
        encoder_attention_mask, decoder_hidden_states.dtype, tgt_len
    )
    decoder_self_mask = self._decoder_self_attention_mask(
        decoder_attention_mask, decoder_hidden_states
    )

    for idx, encoder_layer in enumerate(self.encoder):
      encoder_layer_outputs = encoder_layer(
          encoder_hidden_states,
          encoder_self_mask,
          layer_head_mask=(head_mask[idx] if head_mask is not None else None),
          output_attentions=output_attentions,
      )
      encoder_hidden_states = encoder_layer_outputs[0]

    for idx, decoder_layer in enumerate(self.decoder):
      decoder_layer_outputs = decoder_layer(
          decoder_hidden_states,
          attention_mask=decoder_self_mask,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=cross_attention_mask,
          layer_head_mask=(head_mask[idx] if head_mask is not None else None),
          cross_attn_layer_head_mask=(
              cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
          ),
          past_key_value=None,
          output_attentions=output_attentions,
          use_cache=use_cache,
      )
      decoder_hidden_states = decoder_layer_outputs[0]

    return decoder_hidden_states

In [None]:
import torch
from torch import nn

class GrafomerModel(PreTrainedModel):
  """Graft a pretrained encoder and a pretrained decoder with a ``Grafomer``.

  The encoder and decoder both run on their own inputs. The Grafomer layers
  then fuse them, and the result is added residually onto the decoder's
  hidden states.

  Args:
    config: config for the Grafomer layers (must expose ``d_model``).
    encoder: pretrained encoder; defaults to the notebook-level ``mBERT``.
    decoder: pretrained decoder; defaults to the notebook-level ``mGPT``.
  """

  def __init__(self, config, encoder=None, decoder=None):
    # PreTrainedModel.__init__ requires the config; omitting it fails immediately.
    super().__init__(config)

    # Fall back to the notebook globals so existing GrafomerModel(config) calls keep working.
    self.encoder = encoder if encoder is not None else mBERT
    self.decoder = decoder if decoder is not None else mGPT
    self.graformer = Grafomer(config)

    # mGPT's hidden size is 1024 (see the printed module); mBERT's may differ.
    # BART cross-attention and the final residual add both need matching
    # widths, so bridge any mismatch with a linear projection. When sizes
    # already match, Identity keeps the original computation.
    d_model = config.d_model
    encoder_dim = self.encoder.config.hidden_size
    decoder_dim = self.decoder.config.hidden_size
    self.encoder_proj = nn.Identity() if encoder_dim == d_model else nn.Linear(encoder_dim, d_model)
    self.decoder_proj = nn.Identity() if decoder_dim == d_model else nn.Linear(decoder_dim, d_model)
    self.output_proj = nn.Identity() if decoder_dim == d_model else nn.Linear(d_model, decoder_dim)
    # No wrapper-level init_weights(): the encoder/decoder carry pretrained
    # weights, and Grafomer initialises its own layers in its constructor.

  def _init_weights(self, module):
    """No-op: never re-initialise the pretrained encoder/decoder weights."""
    pass

  def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        decoder_attention_mask=None,
    ):
    """Encode the source, decode the target prefix, and fuse both with the Grafomer.

    Args:
      input_ids, attention_mask, head_mask, inputs_embeds, output_attentions,
        output_hidden_states, return_dict: forwarded to the encoder.
      decoder_input_ids: target-side token ids for the decoder.
      decoder_attention_mask: optional 2-D padding mask for the target side.
        It is forwarded to both the decoder and the Grafomer.

    Returns:
      The decoder hidden states plus the (projected) Grafomer output.
    """
    encoder_outputs = self.encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    encoder_hidden_state = self.encoder_proj(encoder_outputs[0])

    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
    )
    decoder_hidden_state = decoder_outputs[0]

    graformer_hidden_state = self.graformer(
        encoder_hidden_states=encoder_hidden_state,
        encoder_attention_mask=attention_mask,
        decoder_hidden_states=self.decoder_proj(decoder_hidden_state),
        decoder_attention_mask=decoder_attention_mask,
        output_attentions=bool(output_attentions),
    )

    # Residual fusion. Use a distinct name so the `output_hidden_states` flag isn't shadowed.
    fused_hidden_state = decoder_hidden_state + self.output_proj(graformer_hidden_state)
    return fused_hidden_state