https://colab.research.google.com/drive/11ScUROhQM6DznUe4qqJB3klfuXywjmTQ?usp=sharing

In [0]:
! pip install -q torch
! pip install -q pytorch-pretrained-bert

In [0]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Load pre-trained model tokenizer (vocabulary).
# `modelpath` names the pre-trained checkpoint; the same value is reused later
# when loading the model weights, so tokenizer and model come from one source.
modelpath = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(modelpath)

In [0]:

# Input text. The leading "dummy." serves as a placeholder first sentence; its
# two tokens are assigned segment id 0 further down.
text = "dummy. although he had already eaten a large meal, he was still very hungry."
# Word that will be replaced by [MASK] and then predicted back by the model.
target = "hungry"
# Split the text into the tokenizer's tokens; displayed as the cell output.
tokenized_text = tokenizer.tokenize(text)
tokenized_text

['dummy',
 '.',
 'although',
 'he',
 'had',
 'already',
 'eaten',
 'a',
 'large',
 'meal',
 ',',
 'he',
 'was',
 'still',
 'very',
 'hungry',
 '.']

In [0]:

# Mask a token that we will try to predict back with `BertForMaskedLM`.
# The token list is rebuilt from `text` first so that this cell is idempotent:
# masking mutates the list, and without the rebuild a second run would fail
# because `target` is no longer present to be found by `.index()`.
tokenized_text = tokenizer.tokenize(text)
# Position of the first occurrence of the target token.
masked_index = tokenized_text.index(target)
tokenized_text[masked_index] = '[MASK]'
tokenized_text

['dummy',
 '.',
 'although',
 'he',
 'had',
 'already',
 'eaten',
 'a',
 'large',
 'meal',
 ',',
 'he',
 'was',
 'still',
 'very',
 '[MASK]',
 '.']

In [0]:
# Convert the (masked) token strings to their vocabulary indices.
# The resulting list has one integer id per token, in the same order.
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
indexed_tokens

[24369,
 1012,
 2348,
 2002,
 2018,
 2525,
 8828,
 1037,
 2312,
 7954,
 1010,
 2002,
 2001,
 2145,
 2200,
 103,
 1012]

In [0]:
# Segment ids tell the model which sentence (A or B, see paper) each token
# belongs to. Start with one entry per token, all set to 1 (second sentence).
segments_ids = [1 for _ in tokenized_text]
segments_ids

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [0]:
# The placeholder first sentence ("dummy" and ".") occupies the first two
# positions; mark those tokens as belonging to segment 0.
for dummy_position in (0, 1):
    segments_ids[dummy_position] = 0
segments_ids

[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [0]:

# Turn the id lists into PyTorch tensors and prepend a batch dimension of
# size one, since the inputs are passed to the model as a single-item batch.
tokens_tensor = torch.tensor(indexed_tokens).unsqueeze(0)
segments_tensors = torch.tensor(segments_ids).unsqueeze(0)
tokens_tensor, segments_tensors

(tensor([[24369,  1012,  2348,  2002,  2018,  2525,  8828,  1037,  2312,  7954,
           1010,  2002,  2001,  2145,  2200,   103,  1012]]),
 tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))

In [0]:
# Fetch the pre-trained masked-LM weights and switch the module to evaluation
# mode right away (the architecture printout shows dropout layers, which
# evaluation mode turns off).
model = BertForMaskedLM.from_pretrained(modelpath).eval()
# Display the model architecture as the cell output.
model


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
   

In [0]:

# Predict all tokens.
# This is inference only, so the forward pass runs under `torch.no_grad()`:
# no autograd graph is recorded for the prediction scores, saving memory.
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)
# The vocabulary id with the highest score at the masked position is the
# model's guess for the hidden token.
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
predicted_index, predicted_token

(7501, ['hungry'])

In [0]:

# Summarise the experiment: the input sentence, its masked form, and the
# token the model chose for the masked position.
report_rows = [
    ("Original:", text),
    ("Masked:", " ".join(tokenized_text)),
    ("Predicted token:", predicted_token),
]
for label, value in report_rows:
    print(label, value)
# Header for the runner-up candidates listed by the following cell.
print("Other options:")


Original: dummy. although he had already eaten a large meal, he was still very hungry.
Masked: dummy . although he had already eaten a large meal , he was still very [MASK] .
Predicted token: ['hungry']
Other options:


In [0]:
# Just curious about what the next few options look like.
# Rank the scores at the masked position once with `torch.topk` instead of
# repeatedly overwriting the best score with a sentinel value. This leaves
# `predictions`, `predicted_index` and `predicted_token` untouched, so the cell
# can be re-run (and earlier cells re-displayed) with consistent results.
NUM_ALTERNATIVES = 10
mask_scores = predictions[0, masked_index]
# Ask for one extra candidate because the best-scoring token is the prediction
# that has already been reported; indices come back sorted best-first.
_, ranked_indices = torch.topk(mask_scores, NUM_ALTERNATIVES + 1)
alternative_ids = [
    token_id for token_id in ranked_indices.tolist() if token_id != predicted_index
][:NUM_ALTERNATIVES]
for alternative_id in alternative_ids:
    print(tokenizer.convert_ids_to_tokens([alternative_id]))

['strong']
['tired']
['weak']
['angry']
['concerned']
['alert']
['nervous']
['pale']
['.']
['excited']
