from transformers import BertTokenizer, BertForMaskedLM
import torch

# Load the tokenizer and model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Input sentence with a masked token
sentence = "I love to [MASK] machine learning."

# Tokenize the sentence into ids (encode adds [CLS] and [SEP]) as a PyTorch tensor
input_ids = tokenizer.encode(sentence, return_tensors="pt")

# Find the position of the [MASK] token
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

# Forward pass through the model (inference only, so no gradient tracking)
with torch.no_grad():
    output = model(input_ids)

# Get the logits (unnormalized scores over the whole vocabulary) at the [MASK] position
mask_token_logits = output.logits[0, mask_token_index, :]

# Find the top 5 candidate tokens; indices[0] keeps only the first [MASK] if there are several
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

# Print the predicted words
print("Predicted words for the masked token:")
for token in top_5_tokens:
    print(tokenizer.decode([token]))


BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

Predicted words for the masked token:
do
study
practice
play
learn
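The cell above ranks raw logits and, because of `.indices[0]`, reports candidates for only the first [MASK]. The following is a minimal sketch, assuming a batch of one sentence, that scores every [MASK] position and converts the logits to probabilities with a softmax (which is monotonic, so the ranking matches the raw-logit ranking). The function name `top_k_per_mask` and the toy tensors are hypothetical illustrations, not part of the Transformers API.

```python
import torch


def top_k_per_mask(input_ids, logits, mask_token_id, k=5):
    """Hypothetical helper: top-k (position, token_ids, probabilities) for each [MASK].

    input_ids: LongTensor of shape (1, seq_len)
    logits:    FloatTensor of shape (1, seq_len, vocab_size), e.g. output.logits
    """
    # Sequence positions of every [MASK] token, not just the first one
    mask_positions = torch.where(input_ids[0] == mask_token_id)[0]
    results = []
    for pos in mask_positions.tolist():
        # Softmax normalizes the logits into a probability distribution over the vocabulary
        probs = torch.softmax(logits[0, pos], dim=-1)
        # topk returns values and indices sorted from most to least likely
        top = torch.topk(probs, k)
        results.append((pos, top.indices.tolist(), top.values.tolist()))
    return results


# Toy example: hypothetical 6-token vocabulary where id 5 plays the role of [MASK]
toy_ids = torch.tensor([[0, 5, 2, 5]])
toy_logits = torch.zeros(1, 4, 6)
toy_logits[0, 1] = torch.tensor([0.0, 3.0, 1.0, 2.0, 0.0, 0.0])
toy_logits[0, 3] = torch.tensor([4.0, 0.0, 0.0, 1.0, 2.0, 0.0])

for pos, ids, probs in top_k_per_mask(toy_ids, toy_logits, mask_token_id=5, k=3):
    print(pos, ids, [round(p, 3) for p in probs])
```

With the real cell, the same helper could be called as `top_k_per_mask(input_ids, output.logits, tokenizer.mask_token_id)`, and each list of ids mapped back to WordPiece tokens with `tokenizer.convert_ids_to_tokens(ids)`.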
