# Rational Deletion

The following approach is adapted from [Ondov et al. (2024)](https://aclanthology.org/2024.naacl-long.220/).

The basic idea is to measure word probabilities with a masked language model in two ways: the first probability estimate uses the entire context and the second probability estimate uses only the local (sentence) context. We want words that are predictable given the full context, but cannot be easily guessed using only the local context. The distance between these two probability estimates indicates whether the word is more predictable in the full context than the local context.

The approach handles simultaneous masking of subword tokens and allows for the following configuration:
  - Target number of blanks to generate
  - Minimum distance between blanks
  - Blacklisting of part of speech tags to prevent masking of high entropy words like proper nouns and numbers

We also collect predictions for other possible words. This includes a greedy search algorithm to identify whole-word predictions in the event of sub-word tokenization.

In [74]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import spacy
from spacy.tokens import Doc, Span, Token
import numpy as np

import json
from pprint import pp
from typing import List, Tuple
from collections import defaultdict

class RationalClozeGenerator:
    def __init__(self, model_name: str = "answerdotai/ModernBERT-large"):
        # Load SpaCy for sentence splitting and preprocessing
        self.nlp = spacy.load("en_core_web_sm")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(self.device)
        self.model.eval()
        
        self.min_blank_distance = 7 # Minimum distance between blanks
        # Minimum predictability of alternatives
        # Use log probs to avoid underflow
        self.min_predictability = np.log(0.05)

        # Part-of-Speech Blacklist (do not delete these words)
        self.blacklist = [
            "PROPN", # Proper nouns
            "NUM", # Numbers
            "PUNCT", # Punctuation
            "SYM", # Symbols
            "X", # Other
        ]

    def get_token_mappings(self, text: str) -> Tuple[List[int], List[int]]:
        """Get mappings between word positions and token positions"""
        # Tokenize while keeping track of word IDs
        tokenized = self.tokenizer(text, return_tensors="pt", is_split_into_words=True)
        word_ids = tokenized.word_ids()
        
        # Create mapping from word position to token positions
        word_to_tokens = {}
        
        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is not None:
                if word_idx not in word_to_tokens:
                    word_to_tokens[word_idx] = []
                word_to_tokens[word_idx].append(token_idx)
            
        return word_to_tokens

    def get_masked_logits(self, span: Doc|Span, mask_idx: int) -> torch.Tensor:
        """Get model logits for a masked position in text"""
        # Get the word tokens and their alignment info
        tokens = [tok.text for tok in span]
        word_to_tokens = self.get_token_mappings(tokens)
        
        # Find all token positions for the word we want to mask
        token_positions = word_to_tokens[mask_idx]
        
        # Create masked version of the text
        input_ids = self.tokenizer(tokens, is_split_into_words=True, return_tensors="pt").input_ids[0]
        masked_ids = input_ids.clone()

        # ID of the first subword token that we masked
        first_token_id = input_ids[token_positions[0]]

        # Mask all tokens corresponding to our target word
        masked_ids[token_positions] = self.tokenizer.mask_token_id

        # Get model outputs
        with torch.no_grad():
            outputs = self.model(input_ids.unsqueeze(0).to(self.device))
            
        # Get logits 
        logits = outputs.logits[0, token_positions, :]

        return logits, first_token_id

    def get_contextuality_score(self, doc: Doc, sent: Span, tok: Token) -> float:
        """Calculate contextuality score for a word position"""
        # Get logits for both full text and sentence contexts
        full_logits, word_id = self.get_masked_logits(doc, tok.i)
        sent_logits, _ = self.get_masked_logits(sent, tok.i - sent.start)
        
        # Calculate probabilities using first sub-word token
        full_probs = torch.softmax(full_logits[0], dim=0)
        sent_probs = torch.softmax(sent_logits[0], dim=0)

        # Contextuality is distance between full-text and sentence probability
        if float(full_probs[word_id].log()) > self.min_predictability:
            score = float(full_probs[word_id].log() - sent_probs[word_id].log())
        else:
            score = float("-inf")

        return score

    def choose_blank_positions(self, doc: Doc, num_blanks: int) -> list[int]:
        """Choose positions to blank based on contextuality scores"""
        scores = []
        valid_positions = []
        
        # Calculate scores and get predictions for each position
        for sent in doc.sents:
            for tok in sent:
                if (len(tok.text) < 3 or 
                    tok.pos_ in self.blacklist or 
                    tok.is_stop or 
                    not tok.text.isalpha()):
                    scores.append(-float('inf'))
                else:
                    score = self.get_contextuality_score(doc, sent, tok)
                    scores.append(score)
                valid_positions.append(tok.i)
                
        # Convert to numpy for easier manipulation
        scores = np.array(scores)
        
        # Choose positions greedily while maintaining minimum distance
        positions = []
        for _ in range(num_blanks):
            if np.all(scores == -float('inf')):
                break
                
            # Choose highest scoring position
            idx = np.argmax(scores)
            pos = valid_positions[idx]
            positions.append(pos)
            
            # Zero out scores within minimum distance
            start = max(0, idx - self.min_blank_distance)
            end = min(len(scores), idx + self.min_blank_distance + 1)
            scores[start:end] = -float('inf')

        return sorted(positions)

    def get_alternates(self, doc: Doc, positions: list[int], k: int = 5) -> List[Tuple[str, float]]:
        """Get top k predictions greater than self.min_predictability. Uses greedy search
        to find the top predictions for whole words when there are multiple subtokens.
        """
        
        tokens = [tok.text for tok in doc]
        words_to_tokens = self.get_token_mappings(tokens)
        input_ids = self.tokenizer(tokens, is_split_into_words=True, return_tensors="pt").input_ids[0]
        # print("Input:", self.tokenizer.decode(input_ids))

        predictions = []
        for pos in positions:
            # print("\nGapped Word:", doc[pos])
            logits, word_id = self.get_masked_logits(doc, pos)
            first_tok_probs = torch.softmax(logits[0], dim=0)
            top_probs, top_indices = torch.topk(first_tok_probs, k)
            # print("Num sub toks:", logits.shape[0])

            # Get topk predictions for multi-token words using greedy sampling
            word_predictions = defaultdict(float)
            for i in range(k):
                prob = top_probs[i]
                sub_toks = [top_indices[i].item()]
                if prob.log() < self.min_predictability:
                    break

                remaining_sub_toks = logits.shape[0] - 1
                while remaining_sub_toks > 0:
                    # Find which subtokens to mask
                    # And which subtokens to fill with top prediction
                    tok_idxs = words_to_tokens[pos]
                    to_fill = tok_idxs[:-remaining_sub_toks]
                    to_mask = tok_idxs[-remaining_sub_toks:]
                    next_sub_tok_idx = tok_idxs[-remaining_sub_toks]

                    # Mask and fill subtokens as needed
                    temp_ids = input_ids.detach().clone()
                    temp_ids[to_fill] = torch.tensor(sub_toks, dtype=torch.long)
                    temp_ids[to_mask] = self.tokenizer.mask_token_id

                    # print("Masked Token:", self.tokenizer.decode(input_ids[next_sub_tok_idx]))
                    # print("Masked Token ID:", input_ids[next_sub_tok_idx])
                    # print("Mask:", self.tokenizer.decode(temp_ids[next_sub_tok_idx]))
                    
                    # Collect predictions conditioned on previous subtokens
                    with torch.no_grad():
                        outputs = self.model(temp_ids.unsqueeze(0).to(self.device))
                    next_sub_tok_logits = outputs.logits[0, next_sub_tok_idx, :]
                    next_sub_tok_id = next_sub_tok_logits.softmax(dim=0).argmax(axis=-1).item()
                    kprobs, kidxs = torch.topk(next_sub_tok_logits.softmax(dim=0), 5)

                    # print("TopK:", kidxs)
                    # print("Pred_id:", next_sub_tok_id)
                    # print("Prediction:", self.tokenizer.decode(next_sub_tok_id))

                    sub_toks.append(next_sub_tok_id)

                    remaining_sub_toks -= 1

                # Assemble subtokens into word string
                # Normalize the word string
                # Different forms should be treated as the same answer:
                # e.g., "statistical", " Statistical", and "Statistical"
                token = self.tokenizer.decode(sub_toks).strip().lower()

                # Accumulate probabilities
                word_predictions[token] += float(prob)
            
            predictions.append(word_predictions)

        return predictions

    def generate_cloze(self, text: str, num_blanks: int) -> tuple[str, list[str]]:
        """Generate a cloze text with blanks and return answers"""
        doc = self.nlp(text)
        positions = self.choose_blank_positions(doc, num_blanks)
        answers = [doc[pos].text for pos in positions]

        alternates = self.get_alternates(doc, positions)

        # Replace words with blanks
        cloze_text = ""
        for tok in doc:
            if tok.i in positions:
                cloze_text += "_" * len(tok.text) + tok.whitespace_
            else:
                cloze_text += tok.text_with_ws
            
        return cloze_text, answers, alternates

In [75]:
generator = RationalClozeGenerator()

In [76]:
text = """The Cloze procedure, first introduced by Taylor, is a widely used method for creating reading 
comprehension tests inspired by the Gestalt principle of closure. Though many variations have been 
introduced and studied, the core concept is to mask words in prose and task the subject with providing 
the missing words."""

text = """Embarking on an international assignment, whether for work or study, entails navigating a complex 
landscape of emotional and cultural challenges. Initially marked by intrigue and excitement, expatriates often 
face culture shock and a period of adjustment before embracing their host culture. This journey necessitates 
meticulous preparation akin to other significant life changes, emphasizing the importance of adaptability, 
language proficiency, and cultural understanding. Successful expatriates are those who, rather than succumbing 
to frustration, leverage these experiences to enhance their personal and professional growth. The process of 
acculturation involves various emotional stages, including initial elation, culture shock, and eventual acceptance, 
followed by the challenges of reentry into one's native culture. Despite the potential for early termination of 
assignments due to family or personal issues, careful consideration and preparation can mitigate these risks, 
making international experience a valuable asset both personally and professionally."""

cloze_text, answers, preds = generator.generate_cloze(text, num_blanks=6)
print("Cloze text:")
print(cloze_text)
print("\nAnswers:")
answer_dict = {answer: {word: round(prob,2) for word, prob in pred.items()} for answer, pred in zip(answers, preds)}
pp(answer_dict)

Cloze text:
Embarking on an international assignment, whether for work or study, entails navigating a complex 
landscape of emotional and cultural challenges. _________ marked by intrigue and excitement, expatriates often 
face culture shock and a period of adjustment before embracing their host culture. This journey necessitates 
meticulous preparation akin to other significant life changes, emphasizing the importance of ____________, 
language proficiency, and cultural understanding. Successful ___________ are those who, rather than succumbing 
to frustration, leverage these experiences to enhance their personal and professional growth. The process of 
acculturation involves various emotional stages, including initial elation, culture shock, and eventual acceptance, 
followed by the challenges of reentry into one's native culture. _______ the potential for early termination of 
assignments due to family or ________ issues, careful consideration and preparation can mitigate these risk

In [79]:
page_summaries = {}
with open("../data/strapi-page-summaries.json") as f:
    for page in json.load(f):
        if page["PageSummary"]:
            # print(page["PageSummary"])
            # page_summaries[page["Slug"]] = page["PageSummary"]
            print(page["Slug"])
            print("="*80)
            cloze_text, answers, preds = generator.generate_cloze(page["PageSummary"], num_blanks=6)
            print("Cloze text:")
            print(cloze_text)
            print("\nAnswers:")
            answer_dict = {answer: {word: round(prob,2) for word, prob in pred.items()} for answer, pred in zip(answers, preds)}
            pp(answer_dict)
            print("="*80)

5-experimental-and-clinical-psychologists
Cloze text:
Experimental psychologists, _________ holding doctoral and master's degrees, conduct scientific research in various psychology subfields, often collaborating with students at universities. While some are trained clinicians, most focus on non-clinical areas such as cognitive or ______ psychology. Their research is crucial for understanding human behavior and developing _________ knowledge, which is vital for clinical practice. The interplay between research and practice is significant, as psychological disorders are ___________ testable. The effectiveness of treatments, like psychotherapy, relies on __________ validation. The clinical psychology community debates the emphasis on empirically _________ treatments, but there is consensus on the need for a scientific approach to ensure effective diagnosis and treatment.

Answers:
{'primarily': {'primarily': 1.0},
 'social': {'social': 0.95},
 'empirical': {'empirical': 1.0},
 'empiricall