## Initialize

In [1]:
!pip3 install transformers
!pip3 install nltk

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from google.colab import output
output.clear()

import nltk
nltk.download('brown')
from nltk.corpus import brown
brown.words()

[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Unzipping corpora/brown.zip.


['The', 'Fulton', 'County', 'Grand', 'Jury', 'said', ...]

In [2]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

!gdown -O en_bert.tar.gz https://drive.google.com/uc?id=1-VJjnqLGKafSoiELTHmg-REVBz0a0QEC  # en
!tar xzf en_bert.tar.gz

our_tokenizer = AutoTokenizer.from_pretrained("en_bert")
our_model = AutoModelForMaskedLM.from_pretrained("en_bert").to(device)

Downloading...
From: https://drive.google.com/uc?id=1-VJjnqLGKafSoiELTHmg-REVBz0a0QEC
To: /content/en_bert.tar.gz
100% 408M/408M [00:03<00:00, 102MB/s]


Some weights of the model checkpoint at en_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
original_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
original_model = AutoModelForMaskedLM.from_pretrained("bert-base-cased").to(device)

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Check perplexity on Brown corpus

In [4]:
@torch.no_grad()
def test_perplexity(text, model, tokenizer, device):
    torch.manual_seed(42)

    encoded_input = tokenizer(text, return_tensors='pt', add_special_tokens=False)

    MASK_ID = tokenizer.convert_tokens_to_ids("[MASK]")
    CLS_ID = tokenizer.convert_tokens_to_ids("[CLS]")
    SEP_ID = tokenizer.convert_tokens_to_ids("[SEP]")

    length = encoded_input.input_ids.size(1)
    batch_size = 128
    mask_p = 0.15

    log_p_sum, denominator = 0.0, 0

    for batch in range(length // (batch_size*batch_size)):
        inputs = encoded_input.input_ids[:, batch_size*batch_size * batch : batch_size*batch_size * (batch + 1)].to(device)
        inputs = inputs.view(batch_size, batch_size)

        mask = torch.bernoulli(torch.full_like(inputs, mask_p, dtype=torch.float)).bool()
        masked_inputs = torch.where(mask, MASK_ID, inputs)
        masked_inputs = F.pad(masked_inputs, (1, 0), value=SEP_ID)
        masked_inputs = F.pad(masked_inputs, (0, 1), value=CLS_ID)

        output = model(input_ids=masked_inputs)["logits"][:, 1:-1, :]

        log_p = F.log_softmax(output, dim=-1)
        log_p = -log_p.gather(index=inputs.unsqueeze(-1), dim=-1).squeeze(-1)

        log_p_sum += (log_p * mask).sum().item()
        denominator += mask.sum().item()

        if (batch + 1) % 10 == 0:
            print(f"{batch} / {length // (batch_size*batch_size)}: {log_p_sum / denominator}")

    print()
    print("Cross entropy:", log_p_sum / denominator)
    print("Perplexity:", math.exp(log_p_sum / denominator))

In [5]:
text = ' '.join(brown.words())

In [8]:
test_perplexity(text, original_model, original_tokenizer, device)

9 / 82: 2.911851166866174
19 / 82: 2.941161776399491
29 / 82: 2.8926835863417772
39 / 82: 2.8869864985920164
49 / 82: 2.8260881025799858
59 / 82: 2.806196751910918
69 / 82: 2.8213818011892355
79 / 82: 2.838635135111492

Cross entropy: 2.8482865877238845
Perplexity: 17.258186105024055


In [9]:
test_perplexity(text, our_model, our_tokenizer, device)

9 / 83: 2.9177756486746267
19 / 83: 2.9529246580788806
29 / 83: 2.9447422322139154
39 / 83: 2.9429093938169895
49 / 83: 2.902524063451907
59 / 83: 2.8752949467120605
69 / 83: 2.900504649840149
79 / 83: 2.9309261142928444

Cross entropy: 2.9351692560192997
Perplexity: 18.8246890574932


## Check tokenizer

Show the most frequent tokens, which result in [UNK]

### Original BERT

Everything is tokenized in-vocabulary

In [22]:
from collections import Counter

encoded_input = original_tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
print("Number of subwords:", len(encoded_input["input_ids"]))
print("\nMost frequent UNKs:")
UNK_ID = original_tokenizer.convert_tokens_to_ids("[UNK]")
counter = Counter((text[start:end] for id, (start, end) in zip(encoded_input["input_ids"], encoded_input["offset_mapping"]) if id == UNK_ID))
print(counter.most_common(20))
print("\nNumber of UNKs:", len([_ for id in encoded_input["input_ids"] if id == UNK_ID]))

Number of subwords: 1344235

Most frequent UNKs:
[]

Number of UNKs: 0


### Our BERT

Lots of OOV tokens :(

In [23]:
from collections import Counter

encoded_input = our_tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
print("Number of subwords:", len(encoded_input["input_ids"]))
print("\nMost frequent UNKs:")
UNK_ID = our_tokenizer.convert_tokens_to_ids("[UNK]")
counter = Counter((text[start:end] for id, (start, end) in zip(encoded_input["input_ids"], encoded_input["offset_mapping"]) if id == UNK_ID))
print(counter.most_common(20))
print("\nNumber of UNKs:", len([_ for id in encoded_input["input_ids"] if id == UNK_ID]))

Number of subwords: 1366067

Most frequent UNKs:
[('`', 17674), ('?', 4694), ('!', 1597), ('$', 579), ('&', 166), ('Q', 54), ('X', 34), ('Quiney', 25), ('Queen', 22), ('{', 16), ('}', 16), ('Quint', 13), ('Quite', 10), ('Question', 7), ('Questions', 7), ('Quartet', 7), ('Quaker', 6), ('Xydis', 6), ('Queens', 5), ('EQU', 5)]

Number of UNKs: 25044


## Check probability on a random text

In [13]:
@torch.no_grad()
def show_perplexity(text, model, tokenizer, device):
    torch.manual_seed(42)

    encoded_input = tokenizer(text, return_tensors='pt', add_special_tokens=False)

    MASK_ID = tokenizer.convert_tokens_to_ids("[MASK]")
    CLS_ID = tokenizer.convert_tokens_to_ids("[CLS]")
    SEP_ID = tokenizer.convert_tokens_to_ids("[SEP]")

    N = encoded_input["input_ids"].size(-1)
    repeated_inputs = encoded_input["input_ids"].repeat(N, 1)
    masked_inputs = repeated_inputs.clone()
    for i in range(N):
        masked_inputs[i, i] = MASK_ID

    masked_inputs = F.pad(masked_inputs, (1, 0), value=SEP_ID)
    masked_inputs = F.pad(masked_inputs, (0, 1), value=CLS_ID)

    words = [tokenizer.decode([encoded_input["input_ids"][0, i].item()]) for i in range(N)]
    max_word_len = max(max(len(w) for w in words), 5) + 3
    print('ORIGINAL:   ', end='')
    for word in words:
        print((max_word_len - len(word)) * ' ' + word, end='')
    print()

    output = model(input_ids=masked_inputs.to(device))
    probs = F.log_softmax(output["logits"], dim=-1).cpu()[:, 1:-1]

    print("PROBS:      ", end='')
    for i in range(N):
        word_id = encoded_input["input_ids"][0, i].item()
        s = f"{probs[i, i, word_id].exp().item() * 100:02.2f}"
        print((max_word_len - len(s)) * ' ' + s, end='')
    
    print()
    print("ARGMAX:     ", end='')
    for i in range(N):
        max_word = tokenizer.decode(probs[i, i, :].argmax())
        print((max_word_len - len(max_word)) * ' ' + max_word, end='')

### Original BERT

In [14]:
sentence = "We tackle the generative learning trilemma with our novel method."
#sentence = "Hi, how are you?"
#sentence = "Привет! Как Ваши дела?"

show_perplexity(sentence, original_model, original_tokenizer, device)

ORIGINAL:            We     tackle        the     genera     ##tive   learning          t      ##ril       ##em       ##ma       with        our      novel     method          .
PROBS:            87.89       0.30      68.95      95.65      30.78       0.01      89.75      27.31      62.39      81.98      48.04       5.75       0.09       0.31       2.02
ARGMAX:              We introduced        the     genera     ##tive         of          t      ##rig       ##em       ##ma       with          a        own   approach         of

### Our BERT

Generally, our model is much less "certain" about the correct predictions -- most likely the result of a shorter training

In [15]:
sentence = "We tackle the generative learning trilemma with our novel method."
#sentence = "Hi, how are you?"
#sentence = "Привет! Как Ваши дела?"

show_perplexity(sentence, our_model, our_tokenizer, device)

ORIGINAL:            We     tackle        the     genera       ##ti       ##ve   learning        tri      ##lem       ##ma       with        our      novel     method          .
PROBS:            26.39       0.09      18.15      87.03      99.39      99.81       0.56       0.37      13.11       0.06      11.32       3.73       0.17       0.08      15.31
ARGMAX:              We    combine        the     genera       ##ti       ##ve      group         ka      ##lli        ##s        and          a   learning          .          .