In [25]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch.nn as nn

## Quick start

In [26]:
# Load the tokenizer and the masked-language-model weights for the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "google-bert/bert-base-uncased",
)
model = AutoModelForMaskedLM.from_pretrained(
    "google-bert/bert-base-uncased",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)
# device_map="auto" lets the library decide where the weights are placed, so the
# inputs follow the model's device instead of a hard-coded "cuda" (which raises
# on CPU-only machines and can disagree with the device that was picked).
# NOTE(review): float16 weights on CPU may be slow or unsupported on older
# PyTorch builds -- confirm if this notebook must run without a GPU.
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to(model.device)

# Inference only: no gradients are needed for a forward pass.
with torch.no_grad():
    outputs = model(**inputs)
    predictions = outputs.logits

# masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
# predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
# predicted_token = tokenizer.decode(predicted_token_id)

# print(f"The predicted token is: {predicted_token}")

Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [27]:
# Size of the logits tensor; presumably (batch, tokens, vocabulary size).
predictions.size()

torch.Size([1, 14, 30522])

In [28]:
# Round-trip the first encoded sequence back to text to see the special tokens
# the tokenizer added around the sentence.
tokenizer.decode(inputs.input_ids[0])

'[CLS] plants create [MASK] through a process known as photosynthesis. [SEP]'

## BERT

In [29]:
# One Transformer encoder block: 128-wide model, 4 attention heads, 128-unit
# feed-forward sublayer, dropout 0.1, batch-first input layout.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=128,
    nhead=4,
    dim_feedforward=128,
    dropout=0.1,
    batch_first=True,
)

In [30]:
# Seed the RNG so the random input and the layer's dropout mask (dropout is
# active because modules start in training mode) are the same on every re-run.
torch.manual_seed(0)

N = 16   # batch size
S = 11   # sequence length
E = 128  # feature size; must equal the encoder layer's d_model

src = torch.rand(N, S, E)  # laid out (batch, seq, feature) since batch_first=True
out = encoder_layer(src)

# The encoder layer returns one vector per input position, same width as the input.
assert out.shape == src.shape

out.shape


torch.Size([16, 11, 128])

## Padding

In [31]:
# Take the final sequence position of every batch element (all features kept).
out[:, -1].shape

torch.Size([16, 128])