In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import Dict, List

class ClozeScorer:
    def __init__(self, model_name="answerdotai/ModernBERT-base"):
        """
        Initialize cloze scorer with ModernBERT
        """
        print(f"Loading model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.eval()

        # Get mask token
        self.mask_token = self.tokenizer.mask_token
        print(f"Loaded successfully. Mask token: {self.mask_token}")

    def get_masked_predictions(self, text_with_blank: str, top_k: int = 20) -> List[Dict[str, float]]:
        """
        Get probability distribution for multiple masked positions in text
        """
        # Tokenize the input
        inputs = self.tokenizer(text_with_blank, return_tensors="pt")

        # Get model predictions
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = outputs.logits

        # Find mask token positions
        mask_token_indices = torch.where(inputs["input_ids"][0] == self.tokenizer.mask_token_id)[0]

        # Looping in case of multiple blanks
        all_predictions = []
        for mask_position in mask_token_indices:
            # Get predictions for the masked position
            masked_predictions = predictions[0, mask_position.item()]

            # Apply softmax to get probabilities
            probs = torch.nn.functional.softmax(masked_predictions, dim=-1)

            # Get top k predictions
            top_k_probs, top_k_indices = torch.topk(probs, top_k)

            # Convert to dictionary
            predictions_dict = {}
            for prob, idx in zip(top_k_probs.tolist(), top_k_indices.tolist()):
                token = self.tokenizer.decode([idx]).strip()
                # print(f"idx: {idx}, token: {token}")
                predictions_dict[token] = prob

            all_predictions.append(predictions_dict)

        return all_predictions


def run_example(scorer, sentence, top_n=10):
    print(f"\nSentence: {sentence}")
    all_predictions = scorer.get_masked_predictions(sentence)

    for i, predictions in enumerate(all_predictions):
        print(f"\nPredictions for Mask {i+1}:")
        for word, prob in list(predictions.items())[:top_n]:
            print(f" '{word}': {prob:.4f} ({prob*100:.2f}%)")

In [5]:
print("Initializing...")
scorer = ClozeScorer()
sample = """This chapter explains the structure and functions of the Legislative Branch of the U.S. government,\
also known as Congress, which is responsible for making federal laws. Congress consists of two [MASK],\
the House of Representatives and the Senate. The House has 435 [MASK] members, with representation based\
on state population; for instance, populous California has 52 [MASK],  while Wyoming, the least populous\
[MASK], has one. Members serve two [MASK] terms, must be at least 25, and must reside in the state they [MASK]."""

run_example(scorer,sample)

Initializing...
Loading model: answerdotai/ModernBERT-base
Loaded successfully. Mask token: [MASK]

Sentence: This chapter explains the structure and functions of the Legislative Branch of the U.S. government,also known as Congress, which is responsible for making federal laws. Congress consists of two [MASK],the House of Representatives and the Senate. The House has 435 [MASK] members, with representation basedon state population; for instance, populous California has 52 [MASK],  while Wyoming, the least populous[MASK], has one. Members serve two [MASK] terms, must be at least 25, and must reside in the state they [MASK].

Predictions for Mask 1:
 'branches': 0.3623 (36.23%)
 'chambers': 0.2787 (27.87%)
 'houses': 0.0985 (9.85%)
 'bodies': 0.0932 (9.32%)
 'parts': 0.0351 (3.51%)
 'agencies': 0.0167 (1.67%)
 'divisions': 0.0166 (1.66%)
 'sections': 0.0129 (1.29%)
 'departments': 0.0094 (0.94%)
 'parties': 0.0085 (0.85%)

Predictions for Mask 2:
 'elected': 0.7594 (75.94%)
 'total': 0.0

In [7]:
def show_tokens(word):
    print(f"{word:<10}: {scorer.tokenizer.tokenize(word)}")

words = ["branches", "chambers", "bodies", "houses", "parts", "agencies", "representatives"]
for word in words:
    show_tokens(word)

branches  : ['br', 'anches']
chambers  : ['ch', 'ambers']
bodies    : ['b', 'odies']
houses    : ['houses']
parts     : ['parts']
agencies  : ['ag', 'encies']
representatives: ['represent', 'atives']


In [8]:
scorer.tokenizer.tokenize(sample)

['This',
 'Ġchapter',
 'Ġexplains',
 'Ġthe',
 'Ġstructure',
 'Ġand',
 'Ġfunctions',
 'Ġof',
 'Ġthe',
 'ĠLegislative',
 'ĠBranch',
 'Ġof',
 'Ġthe',
 'ĠU',
 '.',
 'S',
 '.',
 'Ġgovernment',
 ',',
 'also',
 'Ġknown',
 'Ġas',
 'ĠCongress',
 ',',
 'Ġwhich',
 'Ġis',
 'Ġresponsible',
 'Ġfor',
 'Ġmaking',
 'Ġfederal',
 'Ġlaws',
 '.',
 'ĠCongress',
 'Ġconsists',
 'Ġof',
 'Ġtwo',
 ' [MASK]',
 ',',
 'the',
 'ĠHouse',
 'Ġof',
 'ĠRepresentatives',
 'Ġand',
 'Ġthe',
 'ĠSenate',
 '.',
 'ĠThe',
 'ĠHouse',
 'Ġhas',
 'Ġ435',
 ' [MASK]',
 'Ġmembers',
 ',',
 'Ġwith',
 'Ġrepresentation',
 'Ġbased',
 'on',
 'Ġstate',
 'Ġpopulation',
 ';',
 'Ġfor',
 'Ġinstance',
 ',',
 'Ġpop',
 'ulous',
 'ĠCalifornia',
 'Ġhas',
 'Ġ52',
 ' [MASK]',
 ',',
 '  ',
 'while',
 'ĠWyoming',
 ',',
 'Ġthe',
 'Ġleast',
 'Ġpop',
 'ulous',
 '[MASK]',
 ',',
 'Ġhas',
 'Ġone',
 '.',
 'ĠMembers',
 'Ġserve',
 'Ġtwo',
 ' [MASK]',
 'Ġterms',
 ',',
 'Ġmust',
 'Ġbe',
 'Ġat',
 'Ġleast',
 'Ġ25',
 ',',
 'Ġand',
 'Ġmust',
 'Ġreside',
 'Ġin',
 'Ġthe',