In [26]:
import sys
import torch

In [4]:
# Put the local ALBEF checkout at the front of the module search path.
# NOTE(review): absolute, machine-specific path. The later `from models.vit import ...`
# resolves `models` as a package, which would normally need the *parent* of this
# directory on the path — confirm this entry is the intended one.
albef_models_dir = "/home/zaid/Source/ALBEF/models"
sys.path.insert(0, albef_models_dir)

In [180]:
from models.vit import VisionTransformer
from transformers import BertForMaskedLM, AutoTokenizer, BertConfig
import torch

In [8]:
# Load the tokenizer that belongs to the `bert-base-uncased` checkpoint and
# define a single sample caption to experiment with.
checkpoint_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
text = "this is a random image"

In [182]:
# Build the language model from a default `BertConfig` (not via `from_pretrained`)
# and the vision model with its constructor defaults.
bert_config = BertConfig()
lm = BertForMaskedLM(bert_config)
vm = VisionTransformer()

# How does ALBEF handle tokenization?
The captions in the input JSON fed into ALBEF are raw text. They are loaded into the [`pretrain_dataset`](https://sourcegraph.com/github.com/salesforce/ALBEF/-/blob/dataset/caption_dataset.py?L97) class, where the `pre_caption` function does some basic preprocessing. 
It's then wrapped by a `create_dataset` function that doesn't alter the text data.
So, in summary, the dataset that comes out of `create_dataset` holds the text as raw `str` values, not as integer input ids.
The dataset then gets passed into a `create_loader` function, which also does not modify the text data. 
The tokenization happens in the [`train`](https://sourcegraph.com/github.com/salesforce/ALBEF@9e9a5e952f72374c15cea02d3c34013554c86513/-/blob/Pretrain.py?L59) function.

```python
text_input = tokenizer(text, padding='longest', truncation=True, max_length=25, return_tensors="pt").to(device)
```


In [15]:
# Growing prefixes of one caption ("this", "this is", ...), so every row has a
# different length and the tokenizer has to pad.
caption_words = "this is a random image".split()
text = [" ".join(caption_words[:end]) for end in range(1, len(caption_words) + 1)]

In [16]:
# Same tokenizer call as the ALBEF `train` snippet quoted above, minus `.to(device)`.
tokenizer_kwargs = dict(padding='longest', truncation=True, max_length=25, return_tensors="pt")
text_input = tokenizer(text, **tokenizer_kwargs)

In [17]:
# Display the encoded batch: `input_ids`, `token_type_ids` and `attention_mask`.
text_input

{'input_ids': tensor([[ 101, 2023,  102,    0,    0,    0,    0],
        [ 101, 2023, 2003,  102,    0,    0,    0],
        [ 101, 2023, 2003, 1037,  102,    0,    0],
        [ 101, 2023, 2003, 1037, 6721,  102,    0],
        [ 101, 2023, 2003, 1037, 6721, 3746,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1]])}

# How does sentence-pair tokenization look?

In [23]:
# Encode every caption paired with itself. The tokenizer is called directly
# with `text_pair` instead of the deprecated `batch_encode_plus`; the batch of
# (text, text) pairs and the returned fields are the same.
text_input = tokenizer(text, text_pair=text, padding='longest', truncation=True, max_length=50, return_tensors="pt")

In [24]:
# Display the sentence-pair encoding; note that `token_type_ids` now marks the second segment with 1s.
text_input

{'input_ids': tensor([[ 101, 2023,  102, 2023,  102,    0,    0,    0,    0,    0,    0,    0,
            0],
        [ 101, 2023, 2003,  102, 2023, 2003,  102,    0,    0,    0,    0,    0,
            0],
        [ 101, 2023, 2003, 1037,  102, 2023, 2003, 1037,  102,    0,    0,    0,
            0],
        [ 101, 2023, 2003, 1037, 6721,  102, 2023, 2003, 1037, 6721,  102,    0,
            0],
        [ 101, 2023, 2003, 1037, 6721, 3746,  102, 2023, 2003, 1037, 6721, 3746,
          102]]), 'token_type_ids': tensor([[0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1,

The input ALBEF gets is a dictionary containing `input_ids` and the `attention_mask`. The attention mask covers all the non-pad tokens. In sentence-pair mode, there is no padding in between the sentences. The sentences are just separated by the `[SEP]` symbol. 

# Constructing a sentence pair from ViT / word embeddings
The ViT output sequence is always the same length, and if stacked together, has no ragged edges. The sentences cannot be stacked together without padding.

In [42]:
# Random stand-in for one 3x224x224 image, created directly with a leading batch axis of size 1.
image = torch.rand(1, 3, 224, 224)

In [43]:
# Run the dummy image through the vision transformer and show the shape of
# the sequence it returns.
vit_out = vm(image)
vit_out.shape

torch.Size([1, 197, 768])

In [None]:
# The same five growing caption prefixes used in the tokenization section,
# restated here so this section reads on its own.
caption_words = "this is a random image".split()
text = [" ".join(caption_words[:end]) for end in range(1, len(caption_words) + 1)]

In [52]:
# Repeat the single ViT output once per caption along the batch axis, so the
# image batch has as many rows as there are captions.
img_batch = torch.cat([vit_out for _ in text], dim=0)
img_batch.shape

torch.Size([5, 197, 768])

In [125]:
# Tokenize the captions with `add_special_tokens=False`, so no [CLS]/[SEP] ids
# are inserted here; the special tokens are built by hand in the cells below.
text_batch = tokenizer(
    text,
    add_special_tokens=False,
    padding='longest',
    truncation=True,
    max_length=25,
    return_tensors="pt",
)

In [126]:
# Display the caption-only encoding (padded on the right, no special tokens).
text_batch

{'input_ids': tensor([[2023,    0,    0,    0,    0],
        [2023, 2003,    0,    0,    0],
        [2023, 2003, 1037,    0,    0],
        [2023, 2003, 1037, 6721,    0],
        [2023, 2003, 1037, 6721, 3746]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]])}

In [130]:
# `batch_size` was never defined in this notebook, so derive it from the
# captions: one row per caption.
batch_size = len(text)
# One [CLS] token id per row, shape (batch_size, 1, 1). Built as an integer
# tensor so it can be fed to an embedding lookup.
prefix = torch.ones(batch_size, 1, 1, dtype=torch.long) * tokenizer.cls_token_id

In [131]:
# One [SEP] token id per caption, shape (len(text), 1, 1). Sized from `text`
# directly because `batch_size` was never defined in this notebook.
separator = torch.ones(len(text), 1, 1, dtype=torch.long) * tokenizer.sep_token_id

In [None]:
# Closing [SEP] token id per caption, shape (len(text), 1, 1). Sized from `text`
# directly because `batch_size` was never defined in this notebook.
eos = torch.ones(len(text), 1, 1, dtype=torch.long) * tokenizer.sep_token_id

The `inputs_embeds` keyword, which we will be passing data into, is computed from the `word_embeddings` layer when it is not provided (https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert/modeling_bert.py#L214).

In [137]:
# Look up the word embeddings for the caption token ids.
# `lm` is a `BertForMaskedLM`, which has no `.embeddings` attribute of its own
# (the table lives under `lm.bert.embeddings`), so go through the
# `get_input_embeddings()` accessor instead.
lang_input_embeds = lm.get_input_embeddings()(text_batch['input_ids'])

In [151]:
# Embed the hand-built [CLS] / [SEP] ids with the language model's own word
# embedding table. `lm` is a `BertForMaskedLM`, which has no `.embeddings`
# attribute (the table lives under `lm.bert.embeddings`), so use the
# `get_input_embeddings()` accessor.
word_embeddings = lm.get_input_embeddings()
# The id tensors are (batch, 1, 1); the lookup appends the hidden axis and
# `squeeze(1)` drops the extra singleton so each result has one token per row.
prefix_embeds = word_embeddings(prefix.long()).squeeze(1)
sep_embeds = word_embeddings(separator.long()).squeeze(1)
eos_embeds = word_embeddings(eos.long()).squeeze(1)

In [156]:
# Assemble the joint sequence along the token axis:
#   [CLS] | image tokens | [SEP] | caption tokens | [SEP]
# NOTE(review): the caption segment is padded to the longest caption, so for
# shorter captions the final [SEP] sits after the [PAD] positions — confirm
# this is the intended layout.
sequence_segments = [prefix_embeds, img_batch, sep_embeds, lang_input_embeds, eos_embeds]
model_input = torch.cat(sequence_segments, dim=1)

In [184]:
# Build an attention mask matching `model_input`, so the [PAD] positions in the
# caption segment are not attended to. The [CLS] prefix, the image tokens and
# both [SEP] tokens are always real inputs, so they get 1s; the caption part
# reuses the tokenizer's own mask.
caption_mask = text_batch['attention_mask']
num_rows = model_input.shape[0]
leading_mask = torch.ones(num_rows, 1 + img_batch.shape[1] + 1, dtype=caption_mask.dtype)
trailing_mask = torch.ones(num_rows, 1, dtype=caption_mask.dtype)
attention_mask = torch.cat([leading_mask, caption_mask, trailing_mask], dim=1)
lm.bert(inputs_embeds=model_input, attention_mask=attention_mask)

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 1.6899e+00, -5.4230e-01,  3.5317e-01,  ...,  6.5037e-01,
          -1.2622e+00, -3.3358e-01],
         [-3.1596e-02,  4.0642e-01,  5.5573e-01,  ..., -1.4848e-01,
          -8.9354e-01, -1.0798e-02],
         [ 1.3592e-01,  1.2805e-01, -6.3903e-01,  ..., -7.1887e-01,
          -1.1928e+00, -1.4440e+00],
         ...,
         [ 3.7314e-01,  6.8338e-01, -2.9855e-01,  ...,  5.7508e-01,
          -3.7531e-01,  6.8136e-01],
         [ 3.4729e-01,  9.2582e-01,  2.2197e-01,  ..., -7.3205e-01,
          -2.8642e-01, -4.6515e-01],
         [-4.6991e-02, -1.1992e+00,  9.2487e-01,  ...,  6.4147e-01,
          -3.8263e-01,  5.7462e-01]],

        [[ 1.5546e+00, -4.8769e-01,  1.3554e-01,  ..., -8.3493e-01,
          -2.3868e-01,  2.5389e-03],
         [ 4.6983e-01,  1.2182e-01,  7.3298e-02,  ...,  4.8344e-01,
          -3.0528e-01, -4.7187e-01],
         [ 5.5295e-01, -7.9896e-02, -8.7515e-01,  ...,  5.2899e-02,
          -5.

## Computing the `token_type_ids` mask

In [189]:
# Encode a single sentence pair. The tokenizer is called directly because
# `encode_plus` is deprecated in favour of `__call__`; the arguments and the
# returned fields (`input_ids`, `token_type_ids`, `attention_mask`) are unchanged.
sentence_pair = tokenizer(
    text="this is some text",
    text_pair="a second sentence, longer than the first",
    padding='longest',
    truncation=True,
    max_length=25,
    return_tensors="pt",
    add_special_tokens=True,
)

In [190]:
# Token ids of the encoded pair, then their shape, one per line.
print(sentence_pair.input_ids, sentence_pair.input_ids.shape, sep="\n")

tensor([[ 101, 2023, 2003, 2070, 3793,  102, 1037, 2117, 6251, 1010, 2936, 2084,
         1996, 2034,  102]])
torch.Size([1, 15])


In [191]:
# Attention mask of the encoded pair, then its shape, one per line.
print(sentence_pair.attention_mask, sentence_pair.attention_mask.shape, sep="\n")

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
torch.Size([1, 15])


In [193]:
# Segment ids of the encoded pair (0 for the first sentence, 1 for the second),
# then their shape, one per line.
print(sentence_pair.token_type_ids, sentence_pair.token_type_ids.shape, sep="\n")

tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
torch.Size([1, 15])
