In [26]:
import sys
import torch

In [4]:
import os
from pathlib import Path

# Root of a local ALBEF checkout (github.com/salesforce/ALBEF). Override it with
# the ALBEF_ROOT environment variable instead of editing this cell.
ALBEF_ROOT = Path(os.environ.get("ALBEF_ROOT", "/home/zaid/Source/ALBEF"))

# `from models.vit import ...` looks for a `models` package *inside* a sys.path
# entry, so the checkout root (not its `models/` subdirectory) has to be on the
# path. The guard keeps re-running this cell from adding duplicate entries.
if str(ALBEF_ROOT) not in sys.path:
    sys.path.insert(0, str(ALBEF_ROOT))

In [9]:
import torch
# `BertForMaskedLM` is the actual transformers class name; `BertModel` is the
# class instantiated further down.
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, BertModel

from models.vit import VisionTransformer

In [8]:
# Caption used for the single-sentence tokenization check below.
text = "this is a random image"

# Tokenizer for the `bert-base-uncased` checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [10]:
# Fix the RNG so the randomly initialised weights (and anything sampled later
# in the notebook) are the same after every kernel restart.
torch.manual_seed(0)

# Both models are built straight from their default constructors; no checkpoint
# is loaded, since this notebook only inspects shapes and plumbing.
lm = BertModel(BertConfig())
vm = VisionTransformer()

# Switch off dropout so repeated forward passes over the same input agree.
lm.eval()
vm.eval()

# Build the ids as int64 directly rather than via a float32 `torch.Tensor`.
tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long)

In [13]:
# Show the encoded ids; `encode` wraps the caption in special tokens (the
# first and last ids in the output).
tokens

tensor([ 101, 2023, 2003, 1037, 6721, 3746,  102])

# How does ALBEF handle tokenization?
The captions in the input JSON fed into ALBEF are raw text. They are loaded into the [`pretrain_dataset`](https://sourcegraph.com/github.com/salesforce/ALBEF/-/blob/dataset/caption_dataset.py?L97) class, where the `pre_caption` function does some basic preprocessing. 
It's then wrapped by a `create_dataset` function that doesn't alter the text data.
In summary, the dataset that comes out of `create_dataset` holds `str`-typed text, not integer input ids.
The dataset then gets passed into a `create_loader` function, which also does not modify the text data. 
The tokenization happens in the [`train`](https://sourcegraph.com/github.com/salesforce/ALBEF@9e9a5e952f72374c15cea02d3c34013554c86513/-/blob/Pretrain.py?L59) function.

```python
text_input = tokenizer(text, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device)
```


In [15]:
# Growing prefixes of one caption ("this", "this is", ...), so every row of the
# batch has a different length and padding becomes visible.
text = [
    " ".join("this is a random image".split()[:n_words])
    for n_words in range(1, 6)
]

In [16]:
# Same call as in ALBEF's training loop (minus the `.to(device)`): pad to the
# longest caption in the batch and truncate at 25 tokens.
text_input = tokenizer(
    text,
    padding='longest',
    truncation=True,
    max_length=25,
    return_tensors="pt",
)

In [17]:
# Inspect the batch: shorter captions are right-padded with id 0 and the
# attention mask is 0 at exactly those positions.
text_input

{'input_ids': tensor([[ 101, 2023,  102,    0,    0,    0,    0],
        [ 101, 2023, 2003,  102,    0,    0,    0],
        [ 101, 2023, 2003, 1037,  102,    0,    0],
        [ 101, 2023, 2003, 1037, 6721,  102,    0],
        [ 101, 2023, 2003, 1037, 6721, 3746,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1]])}

# How does sentence-pair tokenization look?

In [23]:
# Passing two parallel lists encodes them pairwise as (text[i], text[i])
# sentence pairs, using the same tokenizer call as the single-sentence cells.
text_input = tokenizer(
    text,
    text,
    padding='longest',
    truncation=True,
    max_length=50,
    return_tensors="pt",
)

In [24]:
# Inspect the pair encoding: the two sentences sit back to back, padding comes
# only at the end, and `token_type_ids` is 1 over the second sentence.
text_input

{'input_ids': tensor([[ 101, 2023,  102, 2023,  102,    0,    0,    0,    0,    0,    0,    0,
            0],
        [ 101, 2023, 2003,  102, 2023, 2003,  102,    0,    0,    0,    0,    0,
            0],
        [ 101, 2023, 2003, 1037,  102, 2023, 2003, 1037,  102,    0,    0,    0,
            0],
        [ 101, 2023, 2003, 1037, 6721,  102, 2023, 2003, 1037, 6721,  102,    0,
            0],
        [ 101, 2023, 2003, 1037, 6721, 3746,  102, 2023, 2003, 1037, 6721, 3746,
          102]]), 'token_type_ids': tensor([[0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1,

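The input ALBEF gets is a dictionary containing `input_ids` and the `attention_mask`. The attention mask is 1 for every non-pad token and 0 for padding. In sentence-pair mode, there is no padding between the sentences: they are separated only by the `[SEP]` token, all padding is appended after the second sentence, and `token_type_ids` is 1 for the second sentence (including its closing `[SEP]`) and 0 everywhere else.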

# Constructing a sentence pair from ViT / word embeddings
The ViT output sequence is always the same length, and if stacked together, has no ragged edges. The sentences cannot be stacked together without padding.

In [42]:
# A single random RGB image at 224x224, already carrying the batch dimension.
image = torch.rand(1, 3, 224, 224)

In [43]:
# Run the ViT on the random image and check the shape of what comes back:
# one sequence of embeddings per image, i.e. (batch, sequence, hidden).
vit_out = vm(image)
vit_out.shape

torch.Size([1, 197, 768])

In [None]:
# Rebuild the ragged caption batch: every prefix of the full caption, from one
# word up to all five.
text = [
    " ".join("this is a random image".split()[:n_words])
    for n_words in range(1, 6)
]

In [52]:
# One copy of the ViT output per caption, so the image batch lines up
# row-for-row with the text batch.
batch_size = len(text)
img_batch = torch.vstack([vit_out] * batch_size)
img_batch.shape

torch.Size([5, 197, 768])

In [125]:
# Tokenize the captions *without* [CLS]/[SEP]; the special tokens are added by
# hand further down so the image embeddings can be placed between them.
text_batch = tokenizer(
    text,
    padding='longest',
    truncation=True,
    max_length=25,
    return_tensors="pt",
    add_special_tokens=False,
)

In [126]:
# Inspect the batch: only word ids remain (no special tokens), right-padded
# with id 0 up to the longest caption.
text_batch

{'input_ids': tensor([[2023,    0,    0,    0,    0],
        [2023, 2003,    0,    0,    0],
        [2023, 2003, 1037,    0,    0],
        [2023, 2003, 1037, 6721,    0],
        [2023, 2003, 1037, 6721, 3746]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]])}

In [130]:
# One [CLS] id per caption. Sized from `text` directly so the cell does not
# depend on a variable left over from an earlier kernel session, and created as
# int64 because it is only ever used as an embedding index.
prefix = torch.full((len(text), 1, 1), tokenizer.cls_token_id, dtype=torch.long)

In [131]:
# One [SEP] id per caption, to sit between the image and the text embeddings.
# Sized from `text` directly and created as int64 (it is an embedding index).
separator = torch.full((len(text), 1, 1), tokenizer.sep_token_id, dtype=torch.long)

In [None]:
# One closing [SEP] id per caption, appended after the text embeddings.
# Sized from `text` directly and created as int64 (it is an embedding index).
eos = torch.full((len(text), 1, 1), tokenizer.sep_token_id, dtype=torch.long)

The `inputs_embeds` keyword, which we will be passing data into, is computed from `input_ids` by the `word_embeddings` layer when it is not provided (https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert/modeling_bert.py#L214).

In [137]:
# Look up the word embeddings for the caption ids (pad ids included), so they
# can be concatenated with the ViT output before being handed to BERT.
lang_input_embeds = lm.embeddings.word_embeddings(text_batch['input_ids'])

In [151]:
# Embed the three special-token id tensors with the same lookup table as the
# caption. Each id tensor carries an extra singleton axis, which `squeeze(1)`
# removes again after the lookup.
embed_ids = lm.embeddings.word_embeddings
prefix_embeds, sep_embeds, eos_embeds = (
    embed_ids(special_ids.long()).squeeze(1)
    for special_ids in (prefix, separator, eos)
)

In [156]:
# Sequence layout along dim=1:
#   [CLS] | ViT output | [SEP] | caption word embeddings | [SEP]
# NOTE(review): the caption block is right-padded, so for the shorter captions
# the final [SEP] lands after the pad embeddings, not right after the last word.
model_input = torch.cat([prefix_embeds, img_batch, sep_embeds, lang_input_embeds, eos_embeds], dim=1)

In [159]:
# Only `inputs_embeds` is passed: no attention mask and no token type ids, so
# nothing tells the model which caption positions are padding.
# NOTE(review): presumably BERT then falls back to an all-ones mask — confirm.
lm(inputs_embeds=model_input)

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-4.7683e-01, -6.3524e-01, -4.1642e-01,  ..., -9.7788e-01,
           1.8886e-01,  1.7899e+00],
         [-4.3331e-01,  3.0411e-01,  6.5908e-01,  ...,  5.7775e-01,
           7.4744e-01,  2.0552e+00],
         [-7.9278e-01, -7.9437e-01,  4.4264e-01,  ...,  6.4663e-01,
           7.7473e-01,  4.7329e-01],
         ...,
         [-1.7355e+00, -4.6115e-01,  1.0934e+00,  ...,  4.5150e-01,
           9.1559e-01,  1.2327e+00],
         [-6.4930e-01, -8.7906e-01, -4.5566e-01,  ..., -2.5346e-04,
          -2.1459e-01,  2.0685e+00],
         [-6.2172e-01, -5.5700e-01,  7.0330e-01,  ..., -4.7587e-01,
           5.4234e-01,  1.2120e+00]],

        [[-4.2127e-01, -3.4930e-01,  4.4559e-01,  ..., -3.3558e-01,
          -3.7299e-01,  1.3806e+00],
         [-2.9046e-01,  1.5188e-01,  3.7710e-01,  ...,  5.8053e-01,
           1.9641e+00,  6.7893e-01],
         [ 1.1188e-01, -7.8049e-01,  2.9377e-01,  ...,  4.1177e-01,
          -3.

## Computing the attention mask and `token_type_ids`
Let's ignore these for now and see what happens; they may turn out to be unimportant.