In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from dataclasses import dataclass
import io
import json
import os
from pathlib import Path
from pprint import pprint
import requests
import sys
from typing import Optional

if '..' not in sys.path: sys.path.append('..')

from datasets import load_dataset
import pandas as pd
import torch
from torch import nn
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer, AutoTokenizer
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions

from mllm.model.embgen_bert import EncEmbExpansionType, EncoderEmbDecoderModel
from mllm.data.qna import get_hotpotqa



# BERT Generator model
## Preload Encoder, Decoder and Tokenizer

In [3]:
model_name = 'google-bert/bert-base-uncased'

# leverage checkpoints for Bert2Bert model...
# use BERT's cls token as BOS token and sep token as EOS token
encoder: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
tokenizer = BertTokenizer.from_pretrained(model_name)

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer

## EncoderDecoderModel

In [None]:
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# create tokenizer...

input_ids = tokenizer(
    'This is a long article to summarize', add_special_tokens=False, return_tensors='pt'
).input_ids
labels = tokenizer('This is a short summary', return_tensors='pt').input_ids

# train...
loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels).loss
loss.backward()
input_ids



tensor([[2023, 2003, 1037, 2146, 3720, 2000, 7680, 7849, 4697]])

In [9]:
bert2bert.eval()
pred: Seq2SeqLMOutput = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels)
input_ids.shape, labels.shape, pred.logits.shape



(torch.Size([1, 9]), torch.Size([1, 7]), torch.Size([1, 7, 30522]))

In [15]:
gen_toks = bert2bert.generate(input_ids)
gen_toks



tensor([[  101, 19782, 24756, 24756, 24756, 24756, 24756, 24756, 24756, 24756,
         24756,   101, 26036,   101, 10431,   101, 10431,   101, 10431,   101],
        [  101, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431,
         10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431, 10431]])

In [8]:
tokenizer.decode(gen_toks.squeeze().detach().cpu().numpy())

'[CLS] hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui hui [CLS]'


## Bert Encoder and Decoder

In [43]:
tokenizer

BertTokenizer(name_or_path='google-bert/bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [105]:
s1 = ['How many words are needed to describe an ocean? Let\'s try to keep it short!', 'abc def']
toks1 = tokenizer(s1, return_tensors='pt', padding=True)
toks1

{'input_ids': tensor([[  101,  2129,  2116,  2616,  2024,  2734,  2000,  6235,  2019,  4153,
          1029,  2292,  1005,  1055,  3046,  2000,  2562,  2009,  2460,   999,
           102],
        [  101,  5925, 13366,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [106]:
s2 = ['Ocean deep and blue. It\'s endless and charming. Not much words, I\'d say.', 'Once upon a time!']
toks2 = tokenizer(s2, return_tensors='pt', padding=True)
toks2

{'input_ids': tensor([[  101,  4153,  2784,  1998,  2630,  1012,  2009,  1005,  1055, 10866,
          1998, 11951,  1012,  2025,  2172,  2616,  1010,  1045,  1005,  1040,
          2360,  1012,   102],
        [  101,  2320,  2588,  1037,  2051,   999,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [107]:
enc_out: BaseModelOutputWithPastAndCrossAttentions = encoder(toks1.input_ids)
enc_out

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-0.0525,  0.2416, -0.2183,  ..., -0.2659,  0.5905,  0.8504],
         [ 0.5397,  0.1460, -0.4024,  ...,  0.3013,  0.5229,  0.2699],
         [ 0.6455, -0.1244,  0.4163,  ..., -0.1054,  0.5575,  0.1057],
         ...,
         [ 0.7216, -0.1501, -0.1798,  ..., -0.1326,  0.2695, -0.0244],
         [ 0.1543,  0.5561,  0.2382,  ...,  0.4147,  0.2552,  0.2053],
         [ 0.8313,  0.1728, -0.4358,  ...,  0.1262, -0.4243, -0.3342]],

        [[-0.4835, -0.0392,  0.1173,  ..., -0.0122,  0.6183, -0.0138],
         [ 0.3884, -0.1305,  0.2819,  ...,  0.0032,  0.9948,  0.0298],
         [-0.0101, -0.4335,  0.4002,  ...,  0.3235,  0.4988, -0.4418],
         ...,
         [ 0.2725, -0.3133,  1.0570,  ..., -0.2669,  0.2520, -0.2070],
         [ 0.3274, -0.3323,  0.9303,  ..., -0.2656,  0.2818, -0.1892],
         [ 0.4019, -0.1761,  0.9125,  ..., -0.3159,  0.2786, -0.3931]]],
       grad_fn=<NativeLayerNormBackward0>), past_key_val

In [108]:
toks1.input_ids.shape, enc_out.last_hidden_state.shape

(torch.Size([2, 21]), torch.Size([2, 21, 768]))

In [109]:

encoder.config

BertGenerationConfig {
  "_name_or_path": "google-bert/bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 101,
  "eos_token_id": 102,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert-generation",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [110]:
torch.allclose(enc_out[0], enc_out.last_hidden_state)

True

In [111]:
dec_out: CausalLMOutputWithCrossAttentions = decoder(input_ids=toks2.input_ids, encoder_hidden_states=enc_out.last_hidden_state)

In [112]:
toks2.input_ids.shape, dec_out.logits.shape

(torch.Size([2, 23]), torch.Size([2, 23, 30522]))

In [113]:
bert2bert.config.tie_encoder_decoder

False

In [114]:
bert2bert.config

EncoderDecoderConfig {
  "decoder": {
    "_name_or_path": "google-bert/bert-base-uncased",
    "add_cross_attention": true,
    "architectures": [
      "BertForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 101,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 102,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "gradient_checkpointing": false,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": true,
    "is_encoder_decod

In [115]:
decoder.get_output_embeddings()

Linear(in_features=768, out_features=30522, bias=True)

In [116]:
decoder.lm_head.decoder

Linear(in_features=768, out_features=30522, bias=True)

In [117]:
encoder.embeddings

BertGenerationEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [118]:
encoder.embeddings.word_embeddings

Embedding(30522, 768, padding_idx=0)

In [119]:
encoder.config.tie_word_embeddings

True

In [120]:
from transformers.models.bert_generation.modeling_bert_generation import BertGenerationEmbeddings

In [121]:
decoder.get_output_embeddings()

Linear(in_features=768, out_features=30522, bias=True)

In [122]:
decoder.get_input_embeddings()

Embedding(30522, 768, padding_idx=0)

In [123]:
torch.allclose(decoder.get_input_embeddings().weight, decoder.get_output_embeddings().weight)

True

In [124]:
torch.allclose(encoder.get_input_embeddings().weight, decoder.get_output_embeddings().weight)

True

In [125]:
id(decoder.get_input_embeddings().weight) == id(decoder.get_output_embeddings().weight)

True

In [126]:
id(encoder.get_input_embeddings().weight) == id(decoder.get_output_embeddings().weight)

False

In [127]:
id(bert2bert.get_input_embeddings().weight) == id(bert2bert.get_output_embeddings().weight)

False

## Custom override of EncoderDecoderModel

In [4]:
enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
tokenizer = BertTokenizer.from_pretrained(model_name)
eed_model = EncoderEmbDecoderModel(encoder=enc_model, decoder=dec_model)

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer

In [135]:
s1 = ['How many words are needed to describe an ocean? Let\'s try to keep it short!', 'abc def']
toks1 = tokenizer(s1, return_tensors='pt', padding=True)
toks1

{'input_ids': tensor([[  101,  2129,  2116,  2616,  2024,  2734,  2000,  6235,  2019,  4153,
          1029,  2292,  1005,  1055,  3046,  2000,  2562,  2009,  2460,   999,
           102],
        [  101,  5925, 13366,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [146]:
s2 = 'Ocean deep and blue'
toks2 = tokenizer(s2, return_tensors='pt', padding=True)
toks2

{'input_ids': tensor([[ 101, 4153, 2784, 1998, 2630,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [147]:
eed_out: Seq2SeqLMOutput = eed_model(input_ids=toks1.input_ids, decoder_input_ids=toks2.input_ids, labels=toks2.input_ids)
print(toks1.input_ids.shape, toks2.input_ids.shape, eed_out.logits.shape)

torch.Size([2, 21]) torch.Size([1, 6]) torch.Size([1, 6, 30522])


## Datasets
### HPQA

In [4]:
df_hpqa = get_hotpotqa()
df_hpqa

Load /home/misha/data/hotpotqa/hotpot_train_v1.1.json


Unnamed: 0,supporting_facts,level,question,context,answer,_id,type
0,"[[Arthur's Magazine, 0], [First for Women, 0]]",medium,Which magazine was started first Arthur's Maga...,"[[Radio City (Indian radio station), [Radio Ci...",Arthur's Magazine,5a7a06935542990198eaf050,comparison
1,"[[Oberoi family, 0], [The Oberoi Group, 0]]",medium,The Oberoi family is part of a hotel company t...,"[[Ritz-Carlton Jakarta, [The Ritz-Carlton Jaka...",Delhi,5a879ab05542996e4f30887e,bridge
2,"[[Allie Goertz, 0], [Allie Goertz, 1], [Allie ...",hard,Musician and satirist Allie Goertz wrote a son...,"[[Lisa Simpson, [Lisa Marie Simpson is a ficti...",President Richard Nixon,5a8d7341554299441c6b9fe5,bridge
3,"[[Peggy Seeger, 0], [Peggy Seeger, 1], [Ewan M...",medium,What nationality was James Henry Miller's wife?,"[[Moloch: or, This Gentile World, [Moloch: or,...",American,5a82171f5542990a1d231f4a,bridge
4,"[[Cadmium chloride, 1], [Ethanol, 0]]",medium,Cadmium Chloride is slightly soluble in this c...,"[[Cadmium chloride, [Cadmium chloride is a whi...",alcohol,5a84dd955542997b5ce3ff79,bridge
...,...,...,...,...,...,...,...
90442,"[[Kerry Remsen, 1], [Bert Remsen, 0]]",medium,Kerry Remsen is the daughter of an actor with ...,"[[Kerry Remsen, [Kerry Remsen is an American a...",American,5a8f8db25542997ba9cb32b9,bridge
90443,"[[Northshore Mall, 0], [Northshore Mall, 4], [...",easy,"Who manages both Northshore Mall in Peabody, M...","[[Green Tree Mall, [Green Tree Mall is a shopp...",Simon Property Group,5ae4f3615542993aec5ec0fd,bridge
90444,"[[Charlee Johnson, 4], [DreamWorks, 0]]",medium,Charlee Johnson was part of a band that signed...,"[[Simon M. Woods, [Simon M. Woods is a British...",Amblin Partners,5a903fc95542990a984935bd,bridge
90445,"[[Salt to the Sea, 1], [MV Wilhelm Gustloff, 0]]",medium,What is the ship that sank in the Baltic sea a...,[[The I.V. Stalin White Sea – Baltic Sea Canal...,"MV ""Wilhelm Gustloff",5ab56e71554299494045efc8,bridge


In [10]:
df_hpqa.question.str.len().describe()

count    90447.000000
mean       105.604277
std         59.117430
min         13.000000
25%         69.000000
50%         90.000000
75%        120.000000
max        654.000000
Name: question, dtype: float64

In [27]:
for ctx in df_hpqa.context:
    assert type(ctx) == list, str(ctx)
    for l in ctx:
        assert type(l) == list and len(l) == 2
        assert type(l[0]) == str and type(l[1]) == list
        for s in l[1]:
            assert type(s) == str

In [29]:
df_hpqa.context.apply(lambda ctx: sum(len(s.split()) for l in ctx for s in l[1])).describe()

count    90447.000000
mean       886.234380
std        251.339109
min         29.000000
25%        721.000000
50%        867.000000
75%       1028.000000
max       2792.000000
Name: context, dtype: float64

In [33]:
df_hpqa.answer.str.len().describe()

count    90447.000000
mean        13.719040
std         11.828903
min          1.000000
25%          7.000000
50%         12.000000
75%         17.000000
max        623.000000
Name: answer, dtype: float64

In [47]:
row = df_hpqa.iloc[4]
row

supporting_facts                [[Cadmium chloride, 1], [Ethanol, 0]]
level                                                          medium
question            Cadmium Chloride is slightly soluble in this c...
context             [[Cadmium chloride, [Cadmium chloride is a whi...
answer                                                        alcohol
_id                                          5a84dd955542997b5ce3ff79
type                                                           bridge
Name: 4, dtype: object

In [48]:
print(row.question)
print(row.supporting_facts)

Cadmium Chloride is slightly soluble in this chemical, it is also called what?
[['Cadmium chloride', 1], ['Ethanol', 0]]


In [50]:
row.context

[['Cadmium chloride',
  ['Cadmium chloride is a white crystalline compound of cadmium and chlorine, with the formula CdCl.',
   ' It is a hygroscopic solid that is highly soluble in water and slightly soluble in alcohol.',
   ' Although it is considered to be ionic, it has considerable covalent character to its bonding.',
   ' The crystal structure of cadmium chloride (described below), composed of two-dimensional layers of ions, is a reference for describing other crystal structures.',
   ' Also known are CdCl•HO and CdCl•5HO.']],
 ['Water blue',
  ['Water blue, also known as aniline blue, Acid blue 22, Soluble Blue 3M, Marine Blue V, or C.I. 42755, is a chemical compound used as a stain in histology.',
   ' Water blue stains collagen blue in tissue sections.',
   ' It is soluble in water and slightly soluble in ethanol.']],
 ['Diflucortolone valerate',
  ['Diflucortolone valerate (also "Nerisone" cream/oily cream/ointment, "Neriderm" ointment, Japanese ジフルコルトロン (Jifurucorutoron ) is 

### Squad v2

In [4]:
ds_sq = load_dataset('squad_v2')
ds_sq_t, ds_sq_v = ds_sq['train'], ds_sq['validation']
ds_sq

Reusing dataset squad_v2 (/home/misha/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

In [5]:
row = ds_sq_t[4500]
row

{'id': '56cfc773234ae51400d9bf56',
 'title': 'Solar_energy',
 'context': 'Solar technologies are broadly characterized as either passive or active depending on the way they capture, convert and distribute sunlight and enable solar energy to be harnessed at different levels around the world, mostly depending on distance from the equator. Although solar energy refers primarily to the use of solar radiation for practical ends, all renewable energies, other than geothermal and tidal, derive their energy from the Sun in a direct or indirect way.',
 'question': 'How do renewable energies acquire energy from the sun?',
 'answers': {'text': ['direct or indirect'], 'answer_start': [449]}}

In [6]:
ans_txt, ans_start, ctx = row['answers']['text'][0], row['answers']['answer_start'][0], row['context']
ctx[ans_start:ans_start + len(ans_txt)]

'direct or indirect'

In [7]:
df_sq = ds_sq_t.to_pandas()

In [8]:
df_sq.question.str.len().describe()

count    130319.000000
mean         58.507739
std          73.757111
min           1.000000
25%          44.000000
50%          55.000000
75%          69.000000
max       25651.000000
Name: question, dtype: float64

In [9]:
df_sq.context.str.len().describe()

count    130319.000000
mean        754.566287
std         307.619239
min         151.000000
25%         561.000000
50%         692.000000
75%         891.000000
max        3706.000000
Name: context, dtype: float64

In [12]:
df_sq.iloc[0]

id                                   56be85543aeaaa14008c9063
title                                                 Beyoncé
context     Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...
question             When did Beyonce start becoming popular?
answers     {'text': ['in the late 1990s'], 'answer_start'...
Name: 0, dtype: object

In [17]:
ast = df_sq['answers'].apply(lambda anss: anss['answer_start'][0] if len(anss['answer_start']) > 0 else -1)
ast = ast[ast >= 0]
ast.describe()

count    86821.000000
mean       319.806475
std        266.390192
min          0.000000
25%        111.000000
50%        262.000000
75%        468.000000
max       3126.000000
Name: answers, dtype: float64

In [18]:
cl = df_sq['context'].apply(lambda c: len(c.split()))
cl.describe()

count    130319.000000
mean        119.614316
std          49.404411
min          20.000000
25%          89.000000
50%         110.000000
75%         141.000000
max         653.000000
Name: context, dtype: float64

### Eli5 Category

In [76]:
ds_eli5 = load_dataset("eli5_category", split="train[:5000]")

Downloading:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

No config specified, defaulting to: eli5_category/default


Downloading and preparing dataset eli5_category/default (download: 69.54 MiB, generated: 185.70 MiB, post-processed: Unknown size, total: 255.24 MiB) to /home/misha/.cache/huggingface/datasets/eli5_category/default/1.0.0/80106cc49322f1f5075e1387be4a5b74b95e0f56c40ff142b8999d0606aa1908...


  0%|          | 0/4 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/62.3M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.00M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.76M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.85M [00:00<?, ?B/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset eli5_category downloaded and prepared to /home/misha/.cache/huggingface/datasets/eli5_category/default/1.0.0/80106cc49322f1f5075e1387be4a5b74b95e0f56c40ff142b8999d0606aa1908. Subsequent calls will reuse this data.


In [78]:
ds_eli5 = ds_eli5.train_test_split(test_size=0.1)
ds_eli5

DatasetDict({
    train: Dataset({
        features: ['q_id', 'title', 'selftext', 'category', 'subreddit', 'answers', 'title_urls', 'selftext_urls'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['q_id', 'title', 'selftext', 'category', 'subreddit', 'answers', 'title_urls', 'selftext_urls'],
        num_rows: 500
    })
})

In [79]:
ds_eli5['train'][0]

{'q_id': '5lwp25',
 'title': 'Both the sun and moon appear flat, with no cues that one is further away than the other. What is the maximum distance that variations in depth can be perceived at, and why?',
 'selftext': '',
 'category': 'Biology',
 'subreddit': 'explainlikeimfive',
 'answers': {'a_id': ['dbz54c7'],
  'text': ["We don't have depth perception beyond about 10 meters. Anything further than that, your eyes are focused at infinity, and are pointing in the same direction. Beyond that, we use cues like its apparent size and brightness, or color shifts caused by the amount of air between you and the object (how a distant mountain looks more blue than a near one). We determine things are flat or round by their shadows."],
  'score': [4],
  'text_urls': [[]]},
 'title_urls': ['url'],
 'selftext_urls': ['url']}

## Model

In [160]:
enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
tokenizer = BertTokenizer.from_pretrained(model_name)
eed_model = EncoderEmbDecoderModel(encoder=enc_model, decoder=dec_model)

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer

In [143]:
df_sq = pd.concat([ds_sq['train'].to_pandas(), ds_sq['validation'].to_pandas()], axis=0)
df_sq

Unnamed: 0,id,title,context,question,answers
0,56be85543aeaaa14008c9063,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce start becoming popular?,"{'text': ['in the late 1990s'], 'answer_start'..."
1,56be85543aeaaa14008c9065,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,What areas did Beyonce compete in when she was...,"{'text': ['singing and dancing'], 'answer_star..."
2,56be85543aeaaa14008c9066,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce leave Destiny's Child and bec...,"{'text': ['2003'], 'answer_start': [526]}"
3,56bf6b0f3aeaaa14008c9601,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In what city and state did Beyonce grow up?,"{'text': ['Houston, Texas'], 'answer_start': [..."
4,56bf6b0f3aeaaa14008c9602,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In which decade did Beyonce become famous?,"{'text': ['late 1990s'], 'answer_start': [276]}"
...,...,...,...,...,...
11868,5737aafd1c456719005744ff,Force,"The pound-force has a metric counterpart, less...",What is the seldom used force unit equal to on...,"{'text': ['sthène', 'sthène', 'sthène', 'sthèn..."
11869,5ad28ad0d7d075001a4299cc,Force,"The pound-force has a metric counterpart, less...",What does not have a metric counterpart?,"{'text': [], 'answer_start': []}"
11870,5ad28ad0d7d075001a4299cd,Force,"The pound-force has a metric counterpart, less...",What is the force exerted by standard gravity ...,"{'text': [], 'answer_start': []}"
11871,5ad28ad0d7d075001a4299ce,Force,"The pound-force has a metric counterpart, less...",What force leads to a commonly used unit of mass?,"{'text': [], 'answer_start': []}"


In [None]:
import numpy as np
from mllm.train.embgen_bert import get_sq_batch, qna_loss


In [None]:
print(df_sq.columns)
inp_len = 128



Index(['id', 'title', 'context', 'question', 'answers'], dtype='object')


In [210]:
inds = [0, 100, 200]
inds = np.array(inds)
qnab = get_sq_batch(tkz=tokenizer, df_sq=df_sq, inds=inds, inp_len=inp_len)

In [211]:
print(qnab.ctx_toks.shape)
for qa_toks in qnab.qa_toks:
    print(qa_toks.shape, qa_toks == tokenizer.sep_token_id)
for qa_att_mask in qnab.qa_att_masks:
    print(qa_att_mask.shape)
for qa_tgt_mask in qnab.qa_tgt_masks:
    print(qa_tgt_mask.shape)

(6, 128)
(12,) [False False False False False False False  True False False False False]
(15,) [False False False False False False False False False False False False
 False  True False]
(18,) [False False False False False False False False False False False False
 False False False False  True False]
(4, 12)
(1, 15)
(1, 18)
(4, 12)
(1, 15)
(1, 18)


In [None]:
qas, qa_att_masks, qa_tgt_masks, ctxs = qnab.gen_tensors()
print(ctxs.shape)
for qa in qas:
    print(qa.shape, qa == tokenizer.sep_token_id)
for qa_att_mask in qa_att_masks:
    print(qa_att_mask.shape)
for qa_tgt_mask in qa_tgt_masks:
    print(qa_tgt_mask.shape)


torch.Size([6, 128])
torch.Size([12]) tensor([False, False, False, False, False, False, False,  True, False, False,
        False, False])
torch.Size([15]) tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True, False])
torch.Size([18]) tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False,  True, False])
torch.Size([4, 12])
torch.Size([1, 15])
torch.Size([1, 18])
torch.Size([4, 12])
torch.Size([1, 15])
torch.Size([1, 18])


In [213]:
enc_out: BaseModelOutputWithPastAndCrossAttentions = enc_model(input_ids=ctxs)
enc_emb = enc_out.last_hidden_state[:, 0].unsqueeze(0)
print(enc_out.last_hidden_state.shape, enc_emb.shape)

torch.Size([6, 128, 768]) torch.Size([1, 6, 768])


In [214]:
ind = 0
qa, qa_att_mask, qa_tgt_mask = qas[ind].unsqueeze(0), qa_att_masks[ind], qa_tgt_masks[ind]
qa = qa.repeat(len(qa_att_mask), 1)
print(qa.shape, qa_att_mask.shape)
print(qa)
print(qa_att_mask)
print(qa_tgt_mask)


torch.Size([4, 12]) torch.Size([4, 12])
tensor([[ 2043,  2106, 20773,  2707,  3352,  2759,  1029,   102,  1999,  1996,
          2397,  4134],
        [ 2043,  2106, 20773,  2707,  3352,  2759,  1029,   102,  1999,  1996,
          2397,  4134],
        [ 2043,  2106, 20773,  2707,  3352,  2759,  1029,   102,  1999,  1996,
          2397,  4134],
        [ 2043,  2106, 20773,  2707,  3352,  2759,  1029,   102,  1999,  1996,
          2397,  4134]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])
tensor([[False, False, False, False, False, False, False, False,  True, False,
         False, False],
        [False, False, False, False, False, False, False, False, False,  True,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
          True, False],
        [False, False, False, False, False, False, Fa

In [215]:
dec_out: CausalLMOutputWithCrossAttentions = dec_model(input_ids=qa, attention_mask=qa_att_mask, encoder_hidden_states=enc_emb)
print(dec_out.logits.shape)

torch.Size([4, 12, 30522])


In [216]:
dec_out.logits.min(), dec_out.logits.mean(), dec_out.logits.max(), dec_out.logits.std()

(tensor(-3.3552, grad_fn=<MinBackward1>),
 tensor(-0.1662, grad_fn=<MeanBackward0>),
 tensor(4.1287, grad_fn=<MaxBackward1>),
 tensor(0.5924, grad_fn=<StdBackward0>))

In [234]:
tgt_logits = dec_out.logits.masked_select(qa_tgt_mask.unsqueeze(-1))
tgt_logits = tgt_logits.reshape(dec_out.logits.shape[0], dec_out.logits.shape[2])
dec_out.logits.shape, tgt_logits.shape

(torch.Size([4, 12, 30522]), torch.Size([4, 30522]))

In [235]:
tgt_toks = qa.masked_select(qa_tgt_mask).unsqueeze(-1)
tgt_toks

tensor([[1999],
        [1996],
        [2397],
        [4134]])

In [240]:
probs_pred = torch.softmax(tgt_logits, dim=-1)
torch.gather(probs_pred, dim=-1, index=tgt_toks)

tensor([[6.3836e-05],
        [1.9808e-05],
        [2.8551e-05],
        [7.0084e-05]], grad_fn=<GatherBackward0>)

In [242]:
l = qna_loss(dec_out.logits, qa, qa_tgt_mask)
l

tensor(8.1433, grad_fn=<NegBackward0>)

In [244]:
df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
print(df)
print(df.sample(len(df)))

   a  b
0  1  x
1  2  y
2  3  z
   a  b
2  3  z
0  1  x
1  2  y


In [249]:
l = [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e'), (6, 'f')]
np.random.shuffle(l)
l

[(2, 'b'), (6, 'f'), (4, 'd'), (5, 'e'), (3, 'c'), (1, 'a')]

## EncoderDecoderModel middle layer type

In [6]:
model_name = 'google-bert/bert-base-uncased'
enc_emb_exp_type = EncEmbExpansionType.Mat

enc_model: BertGenerationEncoder = BertGenerationEncoder.from_pretrained(model_name, bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
dec_model: BertGenerationDecoder = BertGenerationDecoder.from_pretrained(
    model_name, add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
)
tokenizer = BertTokenizer.from_pretrained(model_name)
eed_model = EncoderEmbDecoderModel(encoder=enc_model, decoder=dec_model, enc_emb_exp_type=enc_emb_exp_type)
eed_model.config.enc_emb_exp_type

You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
You are using a model of type bert to instantiate a model of type bert-generation. This is not supported for all configurations of models and can yield errors.
Some weights of BertGenerationDecoder were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer

<EncEmbExpansionType.Mat: 'mat'>

In [7]:
cfg_enc = eed_model.config.encoder
cfg_enc

BertGenerationConfig {
  "_name_or_path": "google-bert/bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 101,
  "eos_token_id": 102,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert-generation",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [14]:
t = torch.randint(0, 9, (3, 5))
t

tensor([[5, 1, 6, 8, 8],
        [7, 0, 1, 8, 1],
        [1, 2, 6, 4, 6]])

In [21]:
torch.nn.functional.pad(t, (0, 0, 0, 1))

tensor([[5, 1, 6, 8, 8],
        [7, 0, 1, 8, 1],
        [1, 2, 6, 4, 6],
        [0, 0, 0, 0, 0]])