# Rational Deletion

The following approach is adapted from [Ondov et al. (2024)](https://aclanthology.org/2024.naacl-long.220/).

The basic idea is to measure word probabilities with a masked language model in two ways: the first probability estimate uses the entire context of the Cloze passage and the second probability estimate uses only the local (sentence) context. We want words that are predictable given the full context, but cannot be easily guessed using only the local context. The distance between these two probability estimates indicates whether the word is more predictable in the full context than the local context.

We include an alternative implementation that conditions global probability on the the entire page, rather than just the Cloze passage. iTELL Cloze exercises are source-dependent and intended to assess comprehension of the page. This means that Cloze gaps needn't be guessable from the Cloze passage itself, if they are guessable based on the larger page context.

The approach handles simultaneous masking of subword tokens and allows for the following configuration:
  - Target number of blanks to generate
  - Minimum distance between blanks
  - Blacklisting of part of speech tags to prevent masking of high entropy words like proper nouns and numbers

We also collect predictions for other possible words. This includes a greedy search algorithm to identify whole-word predictions in the event of sub-word tokenization.

In [13]:
import json
from collections import defaultdict
from pprint import pp

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import spacy
from spacy.tokens import Doc, Span, Token
import numpy as np
import pandas as pd

torch.set_float32_matmul_precision("high")

In [6]:
class RationalClozeGenerator:
    def __init__(self, model_name: str = "answerdotai/ModernBERT-large"):
        # Load SpaCy for sentence splitting and preprocessing
        self.nlp = spacy.load("en_core_web_sm")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(self.device)
        self.model.eval()

        self.min_blank_distance = 7  # Minimum distance between blanks

        # Minimum predictability of alternatives
        # Use log probs to avoid underflow
        self.min_predictability = np.log(0.05)

        # Part-of-Speech Blacklist (do not delete these words)
        self.blacklist = [
            "PROPN",  # Proper nouns
            "NUM",  # Numbers
            "PUNCT",  # Punctuation
            "SYM",  # Symbols
            "X",  # Other
        ]

    def _get_leading_ws_tokens(self, doc: Doc) -> list[str]:
        """The ModernBERT Tokenizer will work fine if we give it tokens with leading spaces.
        SpaCy normally handles whitespace in terms of trailing space."""
        if not len(doc):
            return [""]

        tokens = [doc[0].text]
        # For tokens after the 0th, prepend trailing whitespace from the previous token.
        tokens += [doc[i - 1].whitespace_ + doc[i].text for i in range(1, len(doc))]
        return tokens

    def get_token_mappings(self, tokens: list[str]) -> dict[int, list[int]]:
        """Get mappings between word positions and token positions"""
        # Tokenize while keeping track of word IDs
        tokenized = self.tokenizer(
            tokens, return_tensors="pt", is_split_into_words=True
        )
        word_ids = tokenized.word_ids()

        # Create mapping from word position to token positions
        word_to_tokens = defaultdict(list)

        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is not None:
                word_to_tokens[word_idx].append(token_idx)

        return word_to_tokens

    def get_masked_logits(
        self, tokens: list[str], mask_idx: int
    ) -> tuple[torch.Tensor, int]:
        """Get model logits for a masked position in text"""
        # Get the word tokens and their alignment info
        word_to_tokens = self.get_token_mappings(tokens)

        # Find all token positions for the word we want to mask
        token_positions = word_to_tokens[mask_idx]

        # Create masked version of the text
        input_ids = self.tokenizer(
            tokens, is_split_into_words=True, return_tensors="pt"
        ).input_ids[0]
        masked_ids = input_ids.clone()

        # ID of the first subword token that we masked
        first_token_id = input_ids[token_positions[0]]

        # Mask all tokens corresponding to our target word
        masked_ids[token_positions] = self.tokenizer.mask_token_id

        # Get model outputs
        outputs = self.model(input_ids.unsqueeze(0).to(self.device))

        # Get logits
        logits = outputs.logits[0, token_positions, :]

        return logits, first_token_id

    def get_contextuality_score(
        self, page_doc: Doc, summary_doc: Doc, sent: Span, tok: Token, method: str = "kl"
    ) -> float:
        """Calculate contextuality score for a word position using full page context

        Args:
            page_doc: The full page text as a spaCy Doc
            summary_doc: The summary text as a spaCy Doc
            sent: The sentence from the summary containing the token
            tok: The token from the summary to evaluate
            method: "kl" for kl-divergence or "contextuality" for contextuality score

        Returns:
            Contextuality score
        """

        # Get logits for both full text and sentence text
        # For the full text context, we use the page + summary
        full_toks = self._get_leading_ws_tokens(page_doc) + self._get_leading_ws_tokens(
            summary_doc
        )
        full_pos = len(page_doc) + tok.i  # Position of token in full document
        full_logits, word_id = self.get_masked_logits(full_toks, full_pos)

        # For the local context, we use just the sentence from the summary
        sent_pos = tok.i - sent.start  # Position of token in the sentence
        sent_logits, _ = self.get_masked_logits([tok.text for tok in sent], sent_pos)

        # Calculate probabilities using first sub-word token
        full_probs = torch.softmax(full_logits[0], dim=0)
        sent_probs = torch.softmax(sent_logits[0], dim=0)

        p = full_probs[word_id]
        q = sent_probs[word_id]

        if method == "kl":
            # KL-divergence is p*log(p/q)
            score = float(p*torch.log2(p/q))
        elif method == "contextuality":
            # Contextuality is distance between full-text and sentence probability
            score = float(p - q)
        else:
            raise ValueError("Unknown method.")

        return score

    def choose_blank_positions(
        self, page_doc: Doc, summary_doc: Doc, num_blanks: int
    ) -> list[int]:
        """Choose positions to blank in the summary based on contextuality scores with full page"""
        scores = []
        valid_positions = []

        # Calculate scores for each position in the summary
        for i, sent in enumerate(summary_doc.sents):
            if i == 0:
                continue  # Skip first sentence
            for tok in sent:
                if (
                    len(tok.text) < 3
                    or tok.pos_ in self.blacklist
                    or tok.is_stop
                    or not tok.text.isalpha()
                ):
                    scores.append(-float("inf"))
                else:
                    # Calculate contextuality using both the full page and summary
                    score = self.get_contextuality_score(
                        page_doc, summary_doc, sent, tok
                    )
                    scores.append(score)
                valid_positions.append(tok.i)

        # Convert to numpy for easier manipulation
        scores = np.array(scores)

        # Choose positions greedily while maintaining minimum distance
        positions = []
        for _ in range(num_blanks):
            if np.all(scores == -float("inf")):
                break

            # Choose highest scoring position
            idx = np.argmax(scores)
            pos = valid_positions[idx]
            positions.append(pos)

            # Zero out scores within minimum distance
            start = max(0, idx - self.min_blank_distance)
            end = min(len(scores), idx + self.min_blank_distance + 1)
            scores[start:end] = -float("inf")

        return sorted(positions)

    def get_alternates(self, tokens: list[str], topk=5) -> list[dict]:
        """Get top k predictions for the masked positions in tokens

        Returns:
            List of dictionaries, one per masked position, with candidate words and their probabilities
        """
        predictions = []

        # Find all mask positions
        mask_positions = [i for i, token in enumerate(tokens) if token == "[MASK]"]

        for mask_pos in mask_positions:
            word_candidates = {}

            # Try different mask lengths (1, 2, or 3 tokens)
            for mask_length in range(1, 4):
                # Replace the single mask with multiple if needed
                masked_tokens = (
                    tokens[:mask_pos]
                    + ["[MASK]"] * mask_length
                    + tokens[mask_pos + 1 :]
                )

                # Get initial predictions for first token
                current_candidates = []
                logits, _ = self.get_masked_logits(masked_tokens, mask_pos)
                probs = torch.softmax(logits[0], dim=0)
                top_values, top_indices = torch.topk(probs, topk)

                # Start with first token candidates
                for idx, prob in zip(top_indices.tolist(), top_values.tolist()):
                    current_candidates.append(([idx], prob))

                # Build up multi-token predictions if needed
                for token_idx in range(1, mask_length):
                    new_candidates = []
                    for token_ids, prob in current_candidates:
                        # Fill in what we've predicted so far
                        partial_filled = tokens.copy()
                        filled_text = self.tokenizer.decode(token_ids)
                        remaining_masks = mask_length - token_idx

                        partial_filled = (
                            tokens[:mask_pos]
                            + [filled_text]
                            + ["[MASK]"] * remaining_masks
                            + tokens[mask_pos + 1 :]
                        )

                        # Get prediction for next position
                        next_logits, _ = self.get_masked_logits(
                            partial_filled, mask_pos + 1
                        )
                        next_probs = torch.softmax(next_logits[0], dim=0)
                        next_values, next_indices = torch.topk(next_probs, 1)

                        # Add to candidates
                        new_token_ids = token_ids + [next_indices[0].item()]
                        new_prob = prob * next_values[0].item()
                        new_candidates.append((new_token_ids, new_prob))

                    current_candidates = new_candidates

                # Add final decoded words
                for token_ids, prob in current_candidates:
                    word = self.tokenizer.decode(token_ids).strip()
                    if " " in word:
                        # Word contains a space (is actually multiple words)
                        continue
                    if word not in word_candidates or prob > word_candidates[word]:
                        word_candidates[word] = prob

            # Sort candidates by probability
            sorted_candidates = sorted(
                word_candidates.items(), key=lambda x: x[1], reverse=True
            )
            predictions.append({word: prob for word, prob in sorted_candidates[:topk]})

        return predictions

    def generate_cloze(
        self, page_text: str, summary_text: str, num_blanks: int
    ) -> tuple[str, list[str], list[dict[str, float]]]:
        """Generate a cloze text from summary using page for context

        Args:
            page_text: The full page text
            summary_text: The summary text to create gaps in
            num_blanks: Number of blanks to create

        Returns:
            Tuple of (cloze_text, answers, alternates)
        """
        # Process both texts
        page_doc = self.nlp(page_text)
        summary_doc = self.nlp(summary_text)

        # Choose positions to blank in the summary
        masked_positions = self.choose_blank_positions(
            page_doc, summary_doc, num_blanks
        )

        # Get the answers (the original words that will be blanked)
        answers = [summary_doc[pos].text for pos in masked_positions]

        # Replace tokens with mask
        summary_tokens = np.array(self._get_leading_ws_tokens(summary_doc))
        summary_tokens[masked_positions] = "[MASK]"
        summary_tokens = summary_tokens.tolist()

        # Construct cloze token input for gap predictions
        cloze_tokens = self._get_leading_ws_tokens(page_doc) + summary_tokens

        # Get gap predictions based on the full page context
        alternates = self.get_alternates(cloze_tokens)

        # Replace words with blanks in the summary
        cloze_text = ""
        for tok in summary_doc:
            if tok.i in masked_positions:
                cloze_text += "_" * len(tok.text) + tok.whitespace_
            else:
                cloze_text += tok.text_with_ws

        return cloze_text, answers, alternates

In [7]:
generator = RationalClozeGenerator()

In [8]:
page_text = """A cloze test (also cloze deletion test or occlusion test) is an exercise, test, or assessment in which a portion of text is masked and the participant is asked to fill in the masked portion of text. Cloze tests require the ability to understand the context and vocabulary in order to identify the correct language or part of speech that belongs in the deleted passages. This exercise is commonly administered for the assessment of native and second language learning and instruction."""
# page_text = """Alphabet Soup is the best kind of soup."""

summary_text = """The Cloze procedure, first introduced by Taylor, is a widely used method for creating reading 
comprehension tests inspired by the Gestalt principle of closure. Though many variations have been 
introduced and studied, the core concept is to mask words in prose and task the subject with providing 
the missing words."""

cloze_text, answers, alternates = generator.generate_cloze(
    page_text, summary_text, num_blanks=6
)
print("Cloze text:")
print(cloze_text)
print("\nAnswers:")
answer_dict = {
    answer: {word: round(prob, 2) for word, prob in pred.items()}
    for answer, pred in zip(answers, alternates)
}
pp(answer_dict)

Cloze text:
The Cloze procedure, first introduced by Taylor, is a widely used method for creating reading 
comprehension tests inspired by the Gestalt principle of closure. Though many variations have been 
introduced and _______, the core concept is to mask words in _____ and task the subject with providing 
the missing _____.

Answers:
{'studied': {'used': 0.21,
             'tested': 0.13,
             'developed': 0.11,
             'refined': 0.09,
             'adapted': 0.06},
 'prose': {'text': 0.59,
           'context': 0.09,
           'sentences': 0.07,
           'texts': 0.03,
           'isolation': 0.02},
 'words': {'words': 0.5,
           'information': 0.17,
           'word': 0.17,
           'text': 0.02,
           'parts': 0.01}}


In [None]:
page_summaries = {}
df = pd.read_csv("../data/itell-pages.csv")
for page in df.itertuples():
    if page.summary:
        # print(page["PageSummary"])
        # page_summaries[page["Slug"]] = page["PageSummary"]
        print(page.page)
        print("=" * 80)
        cloze_text, answers, preds = generator.generate_cloze(
            page.summary, page.text, num_blanks=6
        )
        print("Cloze text:")
        print(cloze_text)
        print("\nAnswers:")
        answer_dict = {
            answer: {word: round(prob, 2) for word, prob in pred.items()}
            for answer, pred in zip(answers, preds)
        }
        pp(answer_dict)
        print("=" * 80)