In [8]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [9]:
# Load pre-trained BERT model and tokenizer
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [23]:
# Tokenize input text
text = "Three years later, the coffin was [MASK] full of Jello."
tokenized_text = tokenizer.tokenize(text)

In [24]:
# Convert tokenized text to input tensor
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

In [25]:
# Find the masked token
masked_index = tokenized_text.index('[MASK]')
print('Masked index:', masked_index)

Masked index: 7


In [31]:
# Predict the masked token using the pre-trained model
model.eval()
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0][0, masked_index].topk(10)

In [32]:
# Print top 5 predicted tokens and their probabilities
predicted_tokens = [tokenizer.convert_ids_to_tokens([i.item()])[0] for i in predictions.indices]
probabilities = predictions.values.tolist()
for i in range(len(predicted_tokens)):
    print(f'Token: {predicted_tokens[i]}, Probability: {probabilities[i]}')

Token: ", Probability: 7.537574291229248
Token: also, Probability: 6.186984062194824
Token: still, Probability: 5.680833339691162
Token: made, Probability: 5.503735065460205
Token: actually, Probability: 5.488828659057617
Token: not, Probability: 5.45798921585083
Token: even, Probability: 5.002530097961426
Token: so, Probability: 4.9967217445373535
Token: completely, Probability: 4.941343784332275
Token: now, Probability: 4.924426078796387


In [33]:
for i in range(len(predicted_tokens)):
    print(predicted_tokens[i])

"
also
still
made
actually
not
even
so
completely
now
