In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

import torch
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.model.embgen_bert import EncoderEmbDecoderModel


# BERT Generator model
## Preload Encoder, Decoder and Tokenizer

In [3]:
model_name = 'google-bert/bert-base-uncased'

# leverage checkpoints for Bert2Bert model...
# use BERT's cls token as BOS token and sep token as EOS token
encoder: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
tokenizer = BertTokenizer.from_pretrained(model_name)

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer

## EncoderDecoderModel

In [None]:
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# create tokenizer...

input_ids = tokenizer(
    'This is a long article to summarize', add_special_tokens=False, return_tensors='pt'
).input_ids
labels = tokenizer('This is a short summary', return_tensors='pt').input_ids

# train...
loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels).loss
loss.backward()
input_ids



tensor([[2023, 2003, 1037, 2146, 3720, 2000, 7680, 7849, 4697]])

In [9]:
bert2bert.eval()
pred: Seq2SeqLMOutput = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels)
input_ids.shape, labels.shape, pred.logits.shape



(torch.Size([1, 9]), torch.Size([1, 7]), torch.Size([1, 7, 30522]))

In [15]:
gen_toks = bert2bert.generate(input_ids)
gen_toks



tensor([[  101, 19782, 24756, 24756, 24756, 24756, 24756, 24756, 24756, 24756,
         24756,   101, 26036,   101, 10431,   101, 10431,   101, 10431,   101],
        [  101, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431,
         10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431]])

In [8]:
tokenizer.decode(gen_toks.squeeze().detach().cpu().numpy())

'[CLS] hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui [CLS]'


## Bert Encoder and Decoder

In [43]:
tokenizer

BertTokenizer(name_or_path='google-bert/bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [105]:
s1 = ['How many words are needed to describe an ocean? Let\'s try to keep it short!', 'abc def']
toks1 = tokenizer(s1, return_tensors='pt', padding=True)
toks1

{'input_ids': tensor([[  101,  2129,  2116,  2616,  2024,  2734,  2000,  6235,  2019,  4153,
          1029,  2292,  1005,  1055,  3046,  2000,  2562,  2009,  2460,   999,
           102],
        [  101,  5925, 13366,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [106]:
s2 = ['Ocean deep and blue. It\'s endless and charming. Not much words, I\'d say.', 'Once upon a time!']
toks2 = tokenizer(s2, return_tensors='pt', padding=True)
toks2

{'input_ids': tensor([[  101,  4153,  2784,  1998,  2630,  1012,  2009,  1005,  1055, 10866,
          1998, 11951,  1012,  2025,  2172,  2616,  1010,  1045,  1005,  1040,
          2360,  1012,   102],
        [  101,  2320,  2588,  1037,  2051,   999,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [107]:
enc_out: BaseModelOutputWithPastAndCrossAttentions = encoder(toks1.input_ids)
enc_out

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-0.0525,  0.2416, -0.2183,  ..., -0.2659,  0.5905,  0.8504],
         [ 0.5397,  0.1460, -0.4024,  ...,  0.3013,  0.5229,  0.2699],
         [ 0.6455, -0.1244,  0.4163,  ..., -0.1054,  0.5575,  0.1057],
         ...,
         [ 0.7216, -0.1501, -0.1798,  ..., -0.1326,  0.2695, -0.0244],
         [ 0.1543,  0.5561,  0.2382,  ...,  0.4147,  0.2552,  0.2053],
         [ 0.8313,  0.1728, -0.4358,  ...,  0.1262, -0.4243, -0.3342]],

        [[-0.4835, -0.0392,  0.1173,  ..., -0.0122,  0.6183, -0.0138],
         [ 0.3884, -0.1305,  0.2819,  ...,  0.0032,  0.9948,  0.0298],
         [-0.0101, -0.4335,  0.4002,  ...,  0.3235,  0.4988, -0.4418],
         ...,
         [ 0.2725, -0.3133,  1.0570,  ..., -0.2669,  0.2520, -0.2070],
         [ 0.3274, -0.3323,  0.9303,  ..., -0.2656,  0.2818, -0.1892],
         [ 0.4019, -0.1761,  0.9125,  ..., -0.3159,  0.2786, -0.3931]]],
       grad_fn=<NativeLayerNormBackward0>), past_key_val

In [108]:
toks1.input_ids.shape, enc_out.last_hidden_state.shape

(torch.Size([2, 21]), torch.Size([2, 21, 768]))

In [109]:

encoder.config

BertGenerationConfig {
  "_name_or_path": "google-bert/bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 101,
  "eos_token_id": 102,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert-generation",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [110]:
torch.allclose(enc_out[0], enc_out.last_hidden_state)

True

In [111]:
dec_out: CausalLMOutputWithCrossAttentions = decoder(input_ids=toks2.input_ids, encoder_hidden_states=enc_out.last_hidden_state)

In [112]:
toks2.input_ids.shape, dec_out.logits.shape

(torch.Size([2, 23]), torch.Size([2, 23, 30522]))

In [113]:
bert2bert.config.tie_encoder_decoder

False

In [114]:
bert2bert.config

EncoderDecoderConfig {
  "decoder": {
    "_name_or_path": "google-bert/bert-base-uncased",
    "add_cross_attention": true,
    "architectures": [
      "BertForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 101,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 102,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "gradient_checkpointing": false,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": true,
    "is_encoder_decod

In [115]:
decoder.get_output_embeddings()

Linear(in_features=768, out_features=30522, bias=True)

In [116]:
decoder.lm_head.decoder

Linear(in_features=768, out_features=30522, bias=True)

In [117]:
encoder.embeddings

BertGenerationEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [118]:
encoder.embeddings.word_embeddings

Embedding(30522, 768, padding_idx=0)

In [119]:
encoder.config.tie_word_embeddings

True

In [120]:
from transformers.models.bert_generation.modeling_bert_generation import BertGenerationEmbeddings

In [121]:
decoder.get_output_embeddings()

Linear(in_features=768, out_features=30522, bias=True)

In [122]:
decoder.get_input_embeddings()

Embedding(30522, 768, padding_idx=0)

In [123]:
torch.allclose(decoder.get_input_embeddings().weight, decoder.get_output_embeddings().weight)

True

In [124]:
torch.allclose(encoder.get_input_embeddings().weight, decoder.get_output_embeddings().weight)

True

In [125]:
id(decoder.get_input_embeddings().weight) == id(decoder.get_output_embeddings().weight)

True

In [126]:
id(encoder.get_input_embeddings().weight) == id(decoder.get_output_embeddings().weight)

False

In [127]:
id(bert2bert.get_input_embeddings().weight) == id(bert2bert.get_output_embeddings().weight)

False

## Custom override of EncoderDecoderModel

In [130]:
enc_out.last_hidden_state.shape, enc_out.last_hidden_state[:, 0].unsqueeze(0).shape

(torch.Size([2, 21, 768]), torch.Size([1, 2, 768]))

In [134]:
enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
tokenizer = BertTokenizer.from_pretrained(model_name)
eed_model = EncoderEmbDecoderModel(encoder=enc_model, decoder=dec_model)

loading configuration file config.json from cache at /home/misha/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594/config.json
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Model config BertGenerationConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 101,
  "eos_token_id": 102,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert-generation",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 3

In [135]:
s1 = ['How many words are needed to describe an ocean? Let\'s try to keep it short!', 'abc def']
toks1 = tokenizer(s1, return_tensors='pt', padding=True)
toks1

{'input_ids': tensor([[  101,  2129,  2116,  2616,  2024,  2734,  2000,  6235,  2019,  4153,
          1029,  2292,  1005,  1055,  3046,  2000,  2562,  2009,  2460,   999,
           102],
        [  101,  5925, 13366,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [146]:
s2 = 'Ocean deep and blue'
toks2 = tokenizer(s2, return_tensors='pt', padding=True)
toks2

{'input_ids': tensor([[ 101, 4153, 2784, 1998, 2630,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [147]:
eed_out: Seq2SeqLMOutput = eed_model(input_ids=toks1.input_ids, decoder_input_ids=toks2.input_ids, labels=toks2.input_ids)
print(toks1.input_ids.shape, toks2.input_ids.shape, eed_out.logits.shape)

torch.Size([2, 21]) torch.Size([1, 6]) torch.Size([1, 6, 30522])
