In [1]:
import torch
from pprint import pprint
from attentif.masked_lm import MaskedLM, MaskedLMConfig
from transformers import AutoTokenizer

In [3]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

config = MaskedLMConfig(
    hidden_size=768,
    num_attention_heads=12,
    num_hidden_layers=12,
    vocab_size=tokenizer.vocab_size,
    pad_token_id=tokenizer.pad_token_id,
)

model = MaskedLM(config)
model.load_state_dict(torch.load("../outputs/masked_lm/model.pt"))
model.eval()

MaskedLM(
  (transformer_encoder): TransformerEncoder(
    (blocks): ModuleList(
      (0-11): 12 x TransformerEncoderBlock(
        (ln1): LayerNorm()
        (ln2): LayerNorm()
        (mha): MultiHeadAttention()
        (ffn): FeedForwardNetwork()
      )
    )
    (ln): LayerNorm()
  )
  (token_embedding): TokenEmbedding()
  (lm_head): LMHead()
)

In [4]:
def get_mask_index(
    input_ids: torch.Tensor, mask_token_id: "int | None" = None
) -> tuple[int, int]:
    """Locate the single mask token in a 2-D tensor of token ids.

    Args:
        input_ids: 2-D tensor of token ids, presumably (batch, seq) as returned
            by the tokenizer with ``return_tensors="pt"``.
        mask_token_id: token id to search for. Defaults to the notebook-level
            ``tokenizer.mask_token_id``.

    Returns:
        ``(batch_idx, token_idx)`` of the mask token, as Python ints.

    Raises:
        ValueError: if ``input_ids`` is not 2-D or does not contain exactly
            one mask token.
    """
    if mask_token_id is None:
        mask_token_id = tokenizer.mask_token_id
    if input_ids.dim() != 2:
        raise ValueError(
            f"expected a 2-D tensor of token ids, got {input_ids.dim()} dimension(s)"
        )
    # One row per matching position, each row being (batch_idx, token_idx).
    positions = torch.nonzero(input_ids == mask_token_id)
    if positions.shape[0] != 1:
        raise ValueError(
            f"expected exactly one mask token in input_ids, found {positions.shape[0]}"
        )
    batch_idx, token_idx = positions[0].tolist()
    return batch_idx, token_idx

def fill_mask(input: str, top_k: int = 5) -> list[dict]:
    """Return the ``top_k`` highest-scoring predictions for the [MASK] in ``input``.

    Each prediction is a dict ``{"answer": <decoded token>, "score": <float>}``,
    ordered from highest to lowest score. Scores are taken directly from the
    model output with no softmax applied, so they are unnormalized logits,
    not probabilities.
    """
    encoded = tokenizer(input, return_tensors="pt")
    input_ids = encoded["input_ids"]
    mask_position = get_mask_index(input_ids)

    with torch.no_grad():
        all_logits = model(input_ids, encoded["attention_mask"])
        # Assumes the model returns (batch, seq, vocab) logits; keep the
        # vocabulary scores at the mask position only.
        mask_logits = all_logits[mask_position[0], mask_position[1], :]

    top_values, top_indices = mask_logits.topk(top_k)

    predictions = []
    for value, index in zip(top_values, top_indices):
        predictions.append(
            {
                "answer": tokenizer.decode(int(index)),
                "score": float(value),
            }
        )
    return predictions


# Example prediction for a sentence with one masked word.
demo_predictions = fill_mask("All human [MASK] are born free and equal in dignity and rights.", top_k=5)
pprint(demo_predictions)

[{'answer': 'rights', 'score': 10.926982879638672},
 {'answer': 'services', 'score': 6.993497371673584},
 {'answer': 'beings', 'score': 6.798624038696289},
 {'answer': 'issues', 'score': 6.764981269836426},
 {'answer': 'groups', 'score': 6.7575764656066895}]
