## Bert Averaging

We see in this notebook how can we calculate the average representation of each contextualized embedding

In [19]:
%load_ext autoreload
%autoreload 2
import os
from datetime import datetime
import fire
import torch
from torchtext import data
import torch.nn as nn
from transformers import (
    AdamW, BertTokenizer, BertModel
)


tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
bert = BertModel.from_pretrained('bert-base-multilingual-uncased')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Let's first tokenize the sentence

In [8]:
sentence = "This is a proof"

sent = tokenizer.tokenize(sentence)
print(sent)

['this', 'is', 'a', 'proof']


Then, convert it to ids. Also, we create a tensor with shape `(1, sent len)`

In [32]:
inp = torch.LongTensor(tokenizer.encode(sentence)).view(1, -1)

inp.shape


torch.Size([1, 6])

In [37]:
tokenizer.convert_ids_to_tokens(inp[0])

['[CLS]', 'this', 'is', 'a', 'proof', '[SEP]']

In [38]:
hidden, pooled = bert(inp)

## Using iterator and lengths

In [50]:
import torchtext.data as data

init_token = tokenizer.cls_token
eos_token  = tokenizer.sep_token
pad_token  = tokenizer.pad_token
unk_token  = tokenizer.unk_token

init_token_idx = tokenizer.cls_token_id
eos_token_idx  = tokenizer.sep_token_id
pad_token_idx  = tokenizer.pad_token_id
unk_token_idx  = tokenizer.unk_token_id

TEXT = data.Field(
    tokenize=tokenizer.tokenize,
    include_lengths = True,
    use_vocab=False,
    batch_first = True,
    preprocessing = tokenizer.convert_tokens_to_ids,
    init_token = init_token_idx,
    eos_token = eos_token_idx,
    pad_token = pad_token_idx,
    unk_token = unk_token_idx
)


In [56]:
ID = data.Field(sequential=False, use_vocab=False)
# All these arguments are because these are really floats
# See https://github.com/pytorch/text/issues/78#issuecomment-541203609
AVG = data.LabelField(dtype = torch.float, use_vocab=False, preprocessing=float)
STD = data.LabelField(dtype = torch.float, use_vocab=False, preprocessing=float)
SUBTASK_A = data.LabelField()

train_dataset = data.TabularDataset(
    "../../data/English/task_a_distant.xsmall.tsv",
    format="tsv", skip_header=True,
    fields=[("id", ID), ("text", TEXT), ("avg", AVG), ("std", STD)],
)

print(f"Train instances: {len(train_dataset)}")
BATCH_SIZE = 32


device = "cuda" if torch.cuda.is_available() else "cpu"

train_it = data.BucketIterator(
    train_dataset, batch_size=BATCH_SIZE, device=device,
    sort_key = lambda x: len(x.text), sort_within_batch = True,
)

Train instances: 908


In [60]:
batch = next(iter(train_it))

text, lens = batch.text

In [64]:
for i in range(32):
    print(tokenizer.convert_ids_to_tokens(text[i]))
    print(lens[i])

['[CLS]', '@', 'user', 'i', 'don', '[UNK]', 't', 'think', 'you', 'can', 'throw', '30', '[SEP]']
tensor(13, device='cuda:0')
['[CLS]', '@', 'user', 'when', 'he', 'hits', 'that', 'bong', 'and', 'almost', 'dies', '[UNK]', '[SEP]']
tensor(13, device='cuda:0')
['[CLS]', '@', 'user', 'wealth', 'is', 'measured', 'different', '##ly', 'among', 'people', 'bro', '[UNK]', '[SEP]']
tensor(13, device='cuda:0')
['[CLS]', 'this', 'head', '##ache', 'can', 'absolute', '##ly', 'fuck', '##in', 'po', '##ke', 'it', '[SEP]']
tensor(13, device='cuda:0')
['[CLS]', 'september', '3', 'first', 'day', 'of', 'school', '[UNK]', '#', 'pre', '##k', '[UNK]', '[SEP]']
tensor(13, device='cuda:0')
['[CLS]', '@', 'user', '@', 'user', 'same', 'energy', 'with', 'the', 'pair', 'phone', '.', '[SEP]']
tensor(13, device='cuda:0')
['[CLS]', 'rt', 'if', 'ur', 'dick', 'is', 'the', 'same', 'length', 'as', 'your', 'height', '[SEP]']
tensor(13, device='cuda:0')
['[CLS]', '@', 'user', 'you', 'always', 'come', 'through', '.', 'thanks', 

In [68]:
text = text.to(device)
bert = bert.to(device)
hidden, pooled = bert(text)

In [89]:

# Don't use CLS and the SEP token 
# Watch out if this could be also done to ignore padding...
hidden[:, 1:-1].sum(dim=1) / (hidden.shape[1] - 2)

tensor([[ 0.0406, -0.3458,  0.3251,  ...,  0.0376,  0.2888, -0.2381],
        [ 0.0726, -0.1260,  0.3359,  ..., -0.2967,  0.2420, -0.1426],
        [-0.1861, -0.0908,  0.1032,  ...,  0.1746,  0.3987, -0.5655],
        ...,
        [-0.1890, -0.1230,  0.3380,  ..., -0.1634,  0.6574, -0.3483],
        [-0.1171,  0.0681,  0.2051,  ...,  0.1395,  0.2031, -0.1296],
        [-0.0622, -0.1456,  0.4096,  ..., -0.1914,  0.2268, -0.1482]],
       device='cuda:0', grad_fn=<DivBackward0>)