In [1]:
import re
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
)

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

import pandas as pd
import numpy as np

## Load Data

In [2]:
in_df = pd.read_parquet("../data/CLOTH/cloth.parquet")
in_df = in_df[in_df.source != "high515"] # This item has no options.
in_df = in_df.reset_index(drop=True)
in_df

Unnamed: 0,article,options,answers,source,split,level
0,It is well known that Albert Einstein was one ...,"[[paper, food, water, air], [hold, take, carry...","[A, A, C, C, B, B, A, D]",high839,train,high
1,Douglas was my cousin. I first met him when he...,"[[day, week, month, year], [parent, cousin, un...","[D, B, A, D, C, B, B, A, C, D]",high2970,train,high
2,"Two weeks before Christmas, Mother told me we ...","[[wanted, lacked, refused, prepared], [also, s...","[B, A, C, B, A, A, B, D, C, C, B, D, B, A, B, D]",high849,train,high
3,"In 1930, a young African American, Vivien T. T...","[[always, often, occasionally, never], [chance...","[D, B, A, C, B, A, D, C, B, A, D, C, B, A, D, ...",high251,train,high
4,I was born in New York City . My first seven y...,"[[feeling, desire, taste, worry], [further, hi...","[B, C, D, A, C, A, B, D, C, A, B, D, C, B, C, ...",high2038,train,high
...,...,...,...,...,...,...
7125,"When I had something _ to do, I used to ask ...","[[difficult, glad, good, happy], [spoke, talke...","[A, C, B, B, A, C, C, A, B, A, C]",middle2789,test,middle
7126,"My name is Carla and I have got two sisters,...","[[me, I, my, I'm], [and, so, or, but], [also, ...","[B, D, D, B, C, C, D]",middle2788,test,middle
7127,A first-grade student named Vincent Butterfiel...,"[[speaking, saying, telling, talking], [hair, ...","[C, A, C, B, D, D, C]",middle2762,test,middle
7128,One night a man came to my house. He said to m...,"[[children, workers, farmers, cooks], [flowers...","[A, C, D, B, C, D, A, B]",middle2769,test,middle


## Get Correct Answers

In [3]:
exploded_df = in_df.explode(["options", "answers"])
exploded_df["answer_idx"] = exploded_df["answers"].astype("category").cat.codes
exploded_df["correct_answer"] = exploded_df.apply(lambda x: x["options"][x["answer_idx"]], axis=1)
exploded_df

Unnamed: 0,article,options,answers,source,split,level,answer_idx,correct_answer
0,It is well known that Albert Einstein was one ...,"[paper, food, water, air]",A,high839,train,high,0,paper
0,It is well known that Albert Einstein was one ...,"[hold, take, carry, bring]",A,high839,train,high,0,hold
0,It is well known that Albert Einstein was one ...,"[kindest, coolest, cleverest, coldest]",C,high839,train,high,2,cleverest
0,It is well known that Albert Einstein was one ...,"[think, talk, worry, set]",C,high839,train,high,2,worry
0,It is well known that Albert Einstein was one ...,"[suits, shoes, trousers, clothes]",B,high839,train,high,1,shoes
...,...,...,...,...,...,...,...,...
7129,I did very poorly in school. My headmaster tho...,"[terrible, excellent, wrong, right]",C,middle2708,test,middle,2,wrong
7129,I did very poorly in school. My headmaster tho...,"[stayed, laughed, lived, studied]",D,middle2708,test,middle,3,studied
7129,I did very poorly in school. My headmaster tho...,"[Themselves, myself, himself, herself]",B,middle2708,test,middle,1,myself
7129,I did very poorly in school. My headmaster tho...,"[proud, rich, poor, happy]",C,middle2708,test,middle,2,poor


In [4]:
df = (
    in_df.copy()
    .drop(columns=["options", "answers"])
    .rename(columns={
        "article": "text_with_gaps"
    })
)
df["original_words"] = (
    exploded_df
    .groupby(exploded_df.index)
    .agg({"correct_answer": lambda x: x.tolist()})
    .iloc[:, 0]
)

## Remove extra spaces around gaps

Extra spaces will make it hard to handle token labels later. We should also be careful that we are not introducing an extra space before a piece of punctuation when we put the words back into place.

In [5]:
df.sample().text_with_gaps.item()

'Pop singer Peng Tan has tasted the joys of being at the top of the world. He has also  _  life\'s lows too. This  has taught  him that having a  _  picture of oneself is the key to  _  . "I grew  _  at the peak of my career, and I began to lose faith when things turned  _  me," he said." "Then I realized that dreams will  come true  only if I put myself in the  _  place." Peng, 29, will  _  at the Beijing Pop Festival at Chaoyang Park in Beijing held on September 8 to 9. He has  _  his first album Teen Spirit after he went solo from the rock band Dada. As the name  _  the album is about his reflection on his youth. "The  _  years is a special restless period in life, with lots of confusion, sensations, with wise and ridiculous ideas colliding," said Peng. When younger, he first  _  of being a painter, until one day the  _  singing of Cui Jian lit up his passion, for rock music. In 1996, he became the  _  singer in the band Dada, which he set up with his  _  from junior school. Soon, t

In [6]:
def normalize_spacing(text: str) -> str:
    """
    Note that this is not perfect.
    We would need much more complex processing to handle words next to quotation marks.
    """
    # First standardize all gap sequences to single underscore
    text = re.sub(r'_+', '_', text)
    
    # Remove spaces before gaps
    text = re.sub(r'\s+_', ' _', text)
    
    # Remove spaces after gaps when followed by punctuation
    text = re.sub(r'_\s+([.,!?;:])', r'_', text)
    
    # Remove spaces after gaps in other cases
    text = re.sub(r'_\s+', '_ ', text)
    
    return text

In [7]:
df["text_with_gaps"] = df["text_with_gaps"].apply(normalize_spacing)

In [8]:
df.text_with_gaps.sample().item()

'Welcome to our school. Our school is very big _ beautiful. There are two playgrounds and two football fields in _ school. There are also some _ in it and we can do many things in them. They are the Reading Club, the Drawing Club, the Swimming Club and so on. Many of _ are in the clubs. I am in the _ Club. I go to the club on Monday and Wednesday. On Monday afternoon, I go to the club and draw _ like apples and pears. On _ morning, I go to the club too. The art room _ modern. I like drawing there. Our school has an Open Day. The _ is 15 October. On that day, tea chers, students and parents are very _ We all like that day. What about your school? Can you tell me something about it?'

## Find Gap Positions

In [9]:
def find_gap_positions(text: str) -> list[int]:
    positions = []
    for match in re.finditer(r'_', text):
        positions.append(match.start())
    return positions

df["gap_positions"] = df["text_with_gaps"].apply(find_gap_positions)
df

Unnamed: 0,text_with_gaps,source,split,level,original_words,gap_positions
0,It is well known that Albert Einstein was one ...,high839,train,high,"[paper, hold, cleverest, worry, shoes, why, re...","[316, 380, 439, 659, 879, 980, 1079, 1123]"
1,Douglas was my cousin. I first met him when he...,high2970,train,high,"[year, cousin, meet, clothes, wear, same, save...","[81, 316, 383, 584, 653, 809, 863, 961, 1065, ..."
2,"Two weeks before Christmas, Mother told me we ...",high849,train,high,"[lacked, also, remove, complete, dawned, But, ...","[211, 263, 366, 460, 495, 569, 634, 675, 705, ..."
3,"In 1930, a young African American, Vivien T. T...",high251,train,high,"[never, desire, secretly, developed, medicine,...","[171, 209, 315, 378, 429, 507, 609, 658, 722, ..."
4,I was born in New York City . My first seven y...,high2038,train,high,"[desire, basic, chemistry, deciding, highly, e...","[135, 199, 300, 313, 418, 561, 671, 753, 819, ..."
...,...,...,...,...,...,...
7125,"When I had something _ to do, I used to ask my...",middle2789,test,middle,"[difficult, said, glad, invite, mother, yourse...","[21, 79, 114, 212, 328, 378, 542, 557, 636, 68..."
7126,"My name is Carla and I have got two sisters,...",middle2788,test,middle,"[I, but, too, some, in, aren't, after]","[104, 125, 271, 293, 306, 400, 460]"
7127,A first-grade student named Vincent Butterfiel...,middle2762,test,middle,"[telling, hair, without, expensive, scarves, o...","[120, 334, 477, 597, 635, 813, 839]"
7128,One night a man came to my house. He said to m...,middle2769,test,middle,"[children, food, found, Where, answer, hungry,...","[79, 164, 221, 387, 426, 463, 556, 596]"


In [10]:
# Testing gap-replacement logic

for row in df.sample(2).itertuples():
    text = row.text_with_gaps
    offset = 0
    for word, pos in zip(row.original_words, row.gap_positions):
        pos_ = pos + offset
        text = text[:pos_] + word + text[pos_ + 1:]
        offset += len(word) - 1 # minus 1 because the underscore has length 1
    print("="*80)
    print(text)

One day a few years ago we had a guest of the uninvited variety. In fact, this uninvited guest was a bird--- a(n) sparrow to be more precise . "What's that?" I asked when I first heard the thump . "It sounds like Joe is outside playing basketball," my wife, Anita, said. She paused and listened more devotedly. "It's coming from the garage" she said. "Maybe it's one of the little kids ". We rushed out the door. Jonathan, our youngest, was easy to make trouble "If he's making holes in the wall again..." I said as I searched there. No children at all. But there was that sound again, coming from right up there. And that's when I spotted the sparrow. It was flying anxiously just inches below the ceiling. It was clearly trying to  get out  , but couldn't see that the way out wasn't up, but down and out through the open door So the bird continued beating its wings and hitting its head against the ceiling "Poor thing," Anita said. "It must be terrified" "Well, maybe it's because of me," I said 

## Define ClozeDataset

In [11]:
class ClozeDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer, max_length=512):
        """        
        Expected DataFrame columns:
        - text_with_gaps: Text with underscores where words were removed
        - original_words: String or list of removed words
        - gap_positions: String or list of gap positions (integers)
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.df = df
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        adjusted_gap_positions = []
        
        # Reconstruct original text by replacing underscores with original words
        text = row['text_with_gaps']
        offset = 0
        for word, pos in zip(row['original_words'], row['gap_positions']):
            pos_ = pos + offset
            adjusted_gap_positions.append(pos_) # We need this for constructing the labels
            text = text[:pos_] + word + text[pos_ + 1:]
            offset += len(word) - 1 # minus 1 because the underscore has length 1
        
        # Tokenize the original text
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_offsets_mapping=True,
            return_tensors='pt'
        )
        
        # Create labels: 1 for tokens that correspond to gap words, 0 for others
        labels = torch.zeros(len(encoding['input_ids'][0]))
        offset_mapping = encoding['offset_mapping'][0]
        
        # For each original word that should be a gap
        for word, start_pos in zip(row['original_words'], adjusted_gap_positions):
            end_pos = start_pos + len(word)
            # Find which token(s) correspond to this word
            for idx, (token_start, token_end) in enumerate(offset_mapping):
                if ( # a token starts inside the target word
                    (token_start.item() >= start_pos and token_start.item() < end_pos)
                    or # a token ends inside the target word
                    (token_end.item() > start_pos and token_end.item() <= end_pos) 
                ):
                    labels[idx] = 1
        
        return {
            'input_ids': encoding['input_ids'][0],
            'attention_mask': encoding['attention_mask'][0],
            'labels': labels
        }

## Define compute_metrics

In [12]:
def compute_metrics(eval_pred):
    
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)
    
    # Flatten the arrays and mask the padding (-100) tokens
    true_predictions = predictions.flatten()
    true_labels = labels.flatten()
    mask = true_labels != -100
    true_predictions = true_predictions[mask]
    true_labels = true_labels[mask]
    
    # Calculate metrics
    return {
        "precision": precision_score(true_labels, true_predictions, zero_division=0),
        "recall": recall_score(true_labels, true_predictions, zero_division=0),
        "f1": f1_score(true_labels, true_predictions, zero_division=0),
        "accuracy": accuracy_score(true_labels, true_predictions)
    }

## Train

In [14]:
def train_cloze_model(
    train_df: pd.DataFrame,
    eval_df: pd.DataFrame,
    model_name="microsoft/deberta-v3-base",
    output_dir="../bin",
):
    """    
    Args:
        train_df: Training DataFrame with columns [text_with_gaps, original_words, gap_positions]
        eval_df: Evaluation DataFrame with same columns
        model_name: Name of the pretrained model to use
        output_dir: Directory to save the model
    """
    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(
        model_name,
        num_labels=2  # Binary classification: gap or no gap
    )
    
    # Create datasets
    train_dataset = ClozeDataset(train_df, tokenizer)
    eval_dataset = ClozeDataset(eval_df, tokenizer)
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
        save_strategy="epoch",
        load_best_model_at_end=True,
    )
    
    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer),
        compute_metrics=compute_metrics,
    )
    
    # Train the model
    trainer.train()
    
    # Save the final model
    trainer.save_model("../bin/cloze-model")
    
    return model, tokenizer

model, tokenizer = train_cloze_model(
    df[df["split"] == "train"],
    df[df["split"] == "test"],
)

{'loss': 0.0831, 'grad_norm': 0.15343379974365234, 'learning_rate': 1.5162070633768748e-05, 'epoch': 0.7256894049346879}
{'eval_loss': 0.06652272492647171, 'eval_precision': 0.7896386045540004, 'eval_recall': 0.31154702052254185, 'eval_f1': 0.4468085106382979, 'eval_accuracy': 0.9775138376383764, 'eval_runtime': 48.2407, 'eval_samples_per_second': 16.853, 'eval_steps_per_second': 2.114, 'epoch': 1.0}
{'loss': 0.0734, 'grad_norm': 0.12325280904769897, 'learning_rate': 1.0324141267537495e-05, 'epoch': 1.4513788098693758}
{'eval_loss': 0.0652344822883606, 'eval_precision': 0.7893037336024218, 'eval_recall': 0.322344020440122, 'eval_f1': 0.457748127340824, 'eval_accuracy': 0.9777396602091021, 'eval_runtime': 48.4988, 'eval_samples_per_second': 16.763, 'eval_steps_per_second': 2.103, 'epoch': 2.0}
{'loss': 0.0694, 'grad_norm': 0.1444166600704193, 'learning_rate': 5.486211901306241e-06, 'epoch': 2.1770682148040637}
{'loss': 0.0664, 'grad_norm': 0.14548803865909576, 'learning_rate': 6.4828253

## Inference

In [52]:
def generate_cloze(text, model, tokenizer, threshold=0.5):
    """Generate a cloze exercise from input text."""
    encoding = tokenizer(
        text,
        return_tensors='pt',
        return_offsets_mapping=True,
        max_length=512,
        truncation=True,
        padding='max_length'
    )
    encoding.to('cuda')
    
    offset_mapping = encoding.pop('offset_mapping')

    # Get model predictions
    with torch.no_grad():
        outputs = model(**encoding)
        probabilities = torch.softmax(outputs.logits, dim=-1)
        predictions = (probabilities[0, :, 1] > threshold).cpu().numpy()

    print(outputs)
    
    # Convert predictions to gaps in the text
    offset_mapping = offset_mapping[0].cpu().numpy()
    text_list = list(text)

    # Replace predicted tokens with underscores
    for idx, pred in enumerate(predictions):
        if pred:
            start, end = offset_mapping[idx]
            if start < len(text_list):  # Check if within text bounds
                text_list[start:end] = ' ' + '_' * (end - start - 1)
    
    return ''.join(text_list)

In [55]:
example_text = '''Chevron's comprehensive training program ensures operators progress from new hires to Fully Qualified Operators (FQO) through a structured, multi-year process involving orientation, on-the-job training, and assessments. FQOs continue training for additional roles, with records meticulously maintained. Control Room Operators (CROs) and Head Operators (HOs) undergo specialized training, including console and simulator sessions, to maintain qualifications. Requalification processes address absences, ensuring operators remain current. Interns receive supervised exposure to operations, potentially transitioning to trainees. Unit School Instructors, selected for their expertise, facilitate training. The program emphasizes continuous learning and adaptation to procedural changes, documented via electronic systems to support operator competence and advancement.'''
example_text = '''The hugely popular video app was taken offline Saturday night in compliance with a law that effectively banned the service nationwide unless it splits off from ByteDance, its China-based owner. Last week, the Supreme Court upheld the law.
On Saturday, Google and Apple removed the app from their stores, a requirement of the ban, which also forbids web-hosting companies from providing back-end support to the app.
When Biden officials said they would leave enforcement of the law up to the Trump administration, web-hosting services were not confident they would not be prosecuted. The law outlines stiff penalties for violations that could cost the companies billions.'''

generate_cloze(example_text, model, tokenizer, threshold=0.50)

TokenClassifierOutput(loss=None, logits=tensor([[[ 1.5748, -2.4767],
         [ 2.9657, -3.0802],
         [ 2.4212, -2.4601],
         ...,
         [ 2.7631, -3.1936],
         [ 2.7631, -3.1936],
         [ 2.7631, -3.1936]]], device='cuda:0'), hidden_states=None, attentions=None)


'The hugely popular video app was taken offline Saturday night in compliance with a law that effectively banned the service nationwide unless it splits off from ByteDance, its China-based owner. Last week, the Supreme Court ______ the law.\nOn Saturday, Google and Apple removed the app from their stores, a ___________ of the ban, which also forbids web-hosting companies from providing back-end support to the app.\nWhen Biden officials said they would leave enforcement of the law up to the Trump administration, web-hosting services were not _________ they would not be prosecuted. The law outlines stiff penalties for violations that could cost the companies billions.'