In [38]:
from transformers import ElectraForPreTraining, ElectraTokenizerFast
import torch

discriminator = ElectraForPreTraining.from_pretrained("google/electra-large-discriminator")
tokenizer = ElectraTokenizerFast.from_pretrained("google/electra-large-discriminator")

config.json:   0%|          | 0.00/668 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

Some weights of the model checkpoint at google/electra-large-discriminator were not used when initializing ElectraForPreTraining: ['electra.embeddings_project.bias', 'electra.embeddings_project.weight']
- This IS expected if you are initializing ElectraForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


model.safetensors:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [44]:
def print_fake_tokens(sentence):
    tokens = tokenizer.tokenize(sentence, add_special_tokens=True)
    embed = tokenizer.encode(sentence, return_tensors="pt")
    discriminator_outputs = discriminator(embed)
    # Shift logits to binary labels
    predictions = torch.round((torch.sign(discriminator_outputs.logits[0]) + 1) / 2)
    
    for token, prediction in zip(tokens, predictions.tolist()):
        print(f"{token:>10} : {int(prediction)}")

In [50]:
sentence = "The quick brown fox jumps over the lazy dog"
fake_sentence = "The quick brown fox jumps over fake lazy dog"

sentences = [
    "The scientist explained THEM the experiment worked",
    "Despite the weather, ENJOY enjoyed the picnic",
    "The student who MAKE yesterday passed the test",
    "She believes that is the BELIEF answer",
]

for sentence in sentences:
    print_fake_tokens(sentence)
    print("="*80)

     [CLS] : 0
       the : 0
 scientist : 0
 explained : 0
      them : 1
       the : 0
experiment : 0
    worked : 0
     [SEP] : 0
     [CLS] : 0
   despite : 0
       the : 0
   weather : 0
         , : 0
     enjoy : 1
   enjoyed : 0
       the : 0
    picnic : 0
     [SEP] : 0
     [CLS] : 0
       the : 0
   student : 0
       who : 0
      make : 1
 yesterday : 0
    passed : 0
       the : 0
      test : 0
     [SEP] : 0
     [CLS] : 0
       she : 0
  believes : 0
      that : 0
        is : 0
       the : 0
    belief : 1
    answer : 1
     [SEP] : 0


In [51]:
def get_token_scores(sentence):
    tokens = tokenizer.tokenize(sentence, add_special_tokens=True)
    embed = tokenizer.encode(sentence, return_tensors="pt")
    outputs = discriminator(embed)
    # Keep raw logits instead of binary predictions for more granular scoring
    scores = outputs.logits[0].tolist()
    return list(zip(tokens, scores))

def analyze_context_windows(text, window_size=10):
    # Get global scores
    global_scores = get_token_scores(text)
    
    # Get local scores by sliding window
    words = text.split()
    local_scores_by_token = {}
    
    for i in range(len(words) - window_size + 1):
        window = " ".join(words[i:i + window_size])
        local_scores = get_token_scores(window)
        
        # Store scores for each token appearance
        for token, score in local_scores[1:-1]:  # Skip special tokens
            if token not in local_scores_by_token:
                local_scores_by_token[token] = []
            local_scores_by_token[token].append(score)
    
    # Find interesting tokens (high global score, low local scores)
    interesting_tokens = []
    for token, global_score in global_scores[1:-1]:  # Skip special tokens
        if token in local_scores_by_token:
            avg_local_score = sum(local_scores_by_token[token]) / len(local_scores_by_token[token])
            score_diff = global_score - avg_local_score
            interesting_tokens.append((token, score_diff))
    
    # Sort by score difference
    return sorted(interesting_tokens, key=lambda x: x[1], reverse=True)

# Test with a sample passage
text = """The ancient castle stood atop the hill, its weathered stones telling stories of centuries past. 
Knights once roamed these halls, their armor gleaming in the torchlight. Now only whispers remain, 
echoing through the empty corridors."""

# Get potential cloze tokens
candidates = analyze_context_windows(text)
print("Top candidates for cloze gaps:")
for token, score_diff in candidates[:5]:
    print(f"{token:>15}: {score_diff:.3f}")

# Create cloze exercise
def create_cloze(text, num_gaps=3):
    candidates = analyze_context_windows(text)
    tokens_to_blank = [token for token, _ in candidates[:num_gaps]]
    
    # Create exercise by replacing selected tokens with gaps
    cloze_text = text
    for token in tokens_to_blank:
        cloze_text = cloze_text.replace(token, "_____")
    
    return cloze_text, tokens_to_blank

Top candidates for cloze gaps:
            the: 4.101
            the: 2.042
          stood: 2.003
         remain: 1.945
          these: 1.484


In [52]:
# Create and print a cloze exercise
text = "The cat chased the mouse through the garden while the dog slept peacefully under the tree."
exercise, answers = create_cloze(text, num_gaps=3)
print("\nExercise:")
print(exercise)
print("\nAnswers:", answers)


Exercise:
The _____ chased _____ mouse through _____ garden while _____ dog slept peacefully under _____ tree.

Answers: ['the', 'cat', 'the']


In [32]:
import torch
from typing import List
from dataclasses import dataclass
import re

@dataclass
class GapCandidate:
    word: str
    start_idx: int
    end_idx: int
    local_rtd_score: float
    global_rtd_score: float
    score_ratio: float

class ClozeGenerator:
    def __init__(self, model_name: str = "microsoft/deberta-v3-base"):
        """
        Initialize the Cloze generator with a DeBERTa-v3 model.
        DeBERTa-v3 was trained with RTD (Replaced Token Detection) objective.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        # Load model for token classification to get RTD scores
        self.model = AutoModel.from_pretrained(
            model_name,
            num_labels=2  # Binary classification: original vs replaced
        )
        self.model.eval()

    def get_rtd_scores(self, text: str, word_positions: List[tuple], context: str = None) -> List[float]:
        """
        Get RTD (Replaced Token Detection) scores for words at specified positions.
        A higher score indicates the model believes the token is original (not replaced).
        """
        # Prepare text with context if provided
        full_text = f"{context} {text}" if context else text
        
        # Tokenize the full text
        inputs = self.tokenizer(
            full_text,
            return_tensors="pt",
            return_offsets_mapping=True  # Get character positions for each token
        )

        # Get offset mapping to align character positions with tokens
        offset_map = inputs.pop("offset_mapping")[0].tolist()
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            # Get logits for original token prediction
            logits = outputs.logits[0]  # Shape: [sequence_length, 2]
            probs = torch.softmax(logits, dim=-1)
            original_probs = probs[:, 1]  # Probability of being original
                
        # Calculate scores for each word position
        scores = []
        for word_start, word_end in word_positions:
            # Adjust positions if context was added
            if context:
                word_start += len(context) + 1
                word_end += len(context) + 1
            
            # Find tokens that overlap with the word
            token_scores = []
            for token_idx, (token_start, token_end) in enumerate(offset_map):
                if token_end <= word_start:
                    continue
                if token_start >= word_end:
                    break
                token_scores.append(original_probs[token_idx].item())
            
            # Average the scores for all tokens in the word
            word_score = sum(token_scores) / len(token_scores) if token_scores else 0
            scores.append(word_score)
        
        return scores

    def find_gap_candidates(self, text: str, min_word_length: int = 4) -> List[GapCandidate]:
        """
        Find suitable gap candidates in the text by comparing local and global RTD scores.
        """
        # Find all words and their positions in the full text
        words_and_positions = [
            (m.group(), m.start(), m.end())
            for m in re.finditer(r'\b\w+\b', text)
        ]
        
        # Filter out short words and numbers
        valid_words = [
            (word, start, end) 
            for word, start, end in words_and_positions
            if len(word) >= min_word_length and not word.isdigit()
        ]
        
        if not valid_words:
            return []
        
        # Split positions for batch processing
        words, starts, ends = zip(*valid_words)
        positions = list(zip(starts, ends))
        
        # Get RTD scores for each word in local and global context
        local_scores = []
        for start, end in positions:
            # Get local context (sentence containing the word)
            sentence_bounds = self._get_sentence_bounds(text, start)
            local_text = text[sentence_bounds[0]:sentence_bounds[1]]
            local_pos = [(start - sentence_bounds[0], end - sentence_bounds[0])]
            score = self.get_rtd_scores(local_text, local_pos)[0]
            local_scores.append(score)
        
        # Get global scores using full text context
        global_scores = self.get_rtd_scores(text, positions)
        
        # Create candidates
        candidates = []
        for i, (word, start, end) in enumerate(valid_words):
            local_score = local_scores[i]
            global_score = global_scores[i]
            
            # Calculate ratio (lower means word is more replaceable in local context)
            score_ratio = local_score / global_score if global_score > 0 else float('inf')
            
            candidates.append(GapCandidate(
                word=word,
                start_idx=start,
                end_idx=end,
                local_rtd_score=local_score,
                global_rtd_score=global_score,
                score_ratio=score_ratio
            ))
        
        return sorted(candidates, key=lambda x: x.score_ratio)

    def _get_sentence_bounds(self, text: str, position: int) -> tuple:
        """
        Find the start and end positions of the sentence containing the given position.
        """
        sentence_end = text.find('.', position)
        if sentence_end == -1:
            sentence_end = len(text)
        
        sentence_start = text.rfind('.', 0, position)
        if sentence_start == -1:
            sentence_start = 0
        else:
            sentence_start += 1  # Move past the period
            
        return (sentence_start, sentence_end)

    def create_cloze_exercise(self, text: str, num_gaps: int = 5) -> str:
        """
        Create an open cloze exercise by selecting the best gap candidates.
        """
        candidates = self.find_gap_candidates(text)
        selected_candidates = candidates[:num_gaps]
        
        # Sort by position to process from end to beginning
        selected_candidates.sort(key=lambda x: x.start_idx, reverse=True)
        
        cloze_text = text
        for candidate in selected_candidates:
            gap_marker = "_____"
            cloze_text = (cloze_text[:candidate.start_idx] + 
                         gap_marker + 
                         cloze_text[candidate.end_idx:])
        
        return cloze_text
