In [3]:
import os
from typing import List, Tuple
import pandas as pd
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from collections import Counter

In [11]:
def tokenize(text: str) -> List[str]:
    return text.lower().split()


def get_ngrams(tokens: List[str], n: int) -> Counter:
    return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))


def sari_score(source: str, prediction: str, references: List[str]) -> Tuple[float, dict]:
    src_tokens = tokenize(source)
    pred_tokens = tokenize(prediction)
    ref_tokens_list = [tokenize(ref) for ref in references]

    keep_scores = []
    del_scores = []
    add_scores = []

    for n in range(1, 5):
        src_ngrams = get_ngrams(src_tokens, n)
        pred_ngrams = get_ngrams(pred_tokens, n)
        ref_ngrams_list = [get_ngrams(ref_tokens, n) for ref_tokens in ref_tokens_list]

        src_in_refs = Counter()
        for ngram in src_ngrams:
            count = sum(1 for ref_ngrams in ref_ngrams_list if ngram in ref_ngrams)
            if count > 0:
                src_in_refs[ngram] = count

        kept_ngrams = Counter()
        for ngram in pred_ngrams:
            if ngram in src_ngrams:
                kept_ngrams[ngram] = min(pred_ngrams[ngram], src_ngrams[ngram])

        keep_prec_num = sum(min(kept_ngrams[ng], 1) for ng in kept_ngrams if ng in src_in_refs)
        keep_prec_den = sum(kept_ngrams.values())
        keep_prec = keep_prec_num / keep_prec_den if keep_prec_den > 0 else 0

        keep_rec_num = sum(min(kept_ngrams[ng], 1) for ng in src_in_refs if ng in kept_ngrams)
        keep_rec_den = len(src_in_refs)
        keep_rec = keep_rec_num / keep_rec_den if keep_rec_den > 0 else 0

        keep_f1 = 2 * keep_prec * keep_rec / (keep_prec + keep_rec) if (keep_prec + keep_rec) > 0 else 0
        keep_scores.append(keep_f1)

        src_not_in_refs = set(ng for ng in src_ngrams if ng not in src_in_refs)
        deleted_ngrams = set(ng for ng in src_ngrams if ng not in pred_ngrams)

        del_prec_num = len(deleted_ngrams & src_not_in_refs)
        del_prec_den = len(deleted_ngrams)
        del_prec = del_prec_num / del_prec_den if del_prec_den > 0 else 0

        del_rec_num = len(deleted_ngrams & src_not_in_refs)
        del_rec_den = len(src_not_in_refs)
        del_rec = del_rec_num / del_rec_den if del_rec_den > 0 else 0

        del_f1 = 2 * del_prec * del_rec / (del_prec + del_rec) if (del_prec + del_rec) > 0 else 0
        del_scores.append(del_f1)

        added_ngrams = set(ng for ng in pred_ngrams if ng not in src_ngrams)

        ref_not_in_src = set()
        for ref_ngrams in ref_ngrams_list:
            ref_not_in_src.update(ng for ng in ref_ngrams if ng not in src_ngrams)

        add_prec_num = len(added_ngrams & ref_not_in_src)
        add_prec_den = len(added_ngrams)
        add_prec = add_prec_num / add_prec_den if add_prec_den > 0 else 0

        add_scores.append(add_prec)

    keep_avg = sum(keep_scores) / 4
    del_avg = sum(del_scores) / 4
    add_avg = sum(add_scores) / 4

    sari = (keep_avg + del_avg + add_avg) / 3 * 100

    components = {
        'keep': keep_avg * 100,
        'delete': del_avg * 100,
        'add': add_avg * 100
    }

    return sari, components

In [8]:
class SimplificationDataset(Dataset):

    def __init__(self, sources: List[str], targets: List[str], tokenizer, max_length: int = 128):
        self.sources = sources
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sources)

    def __getitem__(self, idx):
        source = "simplify: " + self.sources[idx]
        target = self.targets[idx]

        source_enc = self.tokenizer(
            source,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        target_enc = self.tokenizer(
            target,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        labels = target_enc['input_ids'].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': source_enc['input_ids'].squeeze(),
            'attention_mask': source_enc['attention_mask'].squeeze(),
            'labels': labels
        }

In [9]:
def train_model(
    model,
    tokenizer,
    train_sources: List[str],
    train_targets: List[str],
    val_sources: List[str],
    val_targets: List[str],
    output_dir: str,
    epochs: int = 3,
    batch_size: int = 16,
    learning_rate: float = 3e-5,
    device: str = 'cuda'
):
    train_dataset = SimplificationDataset(train_sources, train_targets, tokenizer)
    val_dataset = SimplificationDataset(val_sources, val_targets, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    model.to(device)
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            train_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        avg_train_loss = train_loss / len(train_loader)

        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                val_loss += outputs.loss.item()

        avg_val_loss = val_loss / len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            os.makedirs(output_dir, exist_ok=True)
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            print(f"Saved best model to {output_dir}")

    return model

In [9]:
def strong_baseline(
    model,
    tokenizer,
    source: str,
    max_length: int = 128,
    num_beams: int = 4,
    device: str = 'cuda'
) -> str:
    model.eval()

    with torch.no_grad():
        input_text = "simplify: " + source
        input_enc = tokenizer(
            input_text,
            max_length=max_length,
            truncation=True,
            return_tensors='pt'
        ).to(device)

        output_ids = model.generate(
            input_ids=input_enc['input_ids'],
            attention_mask=input_enc['attention_mask'],
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True
        )

        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return output_text

In [12]:
model_name = 't5-small'
output_dir = './t5-simplification'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 20
batch_size = 16
learning_rate = 3e-5

splits = {
    'train': 'wiki.full.aner.ori.train.95.tsv',
    'validation': 'wiki.full.aner.ori.valid.95.tsv',
    'test': 'wiki.full.aner.ori.test.95.tsv'
}

print("Loading WikiLarge dataset...")
train = pd.read_csv(
    "hf://datasets/bogdancazan/wikilarge-text-simplification/" + splits["train"],
    sep="\t"
)
val = pd.read_csv(
    "hf://datasets/bogdancazan/wikilarge-text-simplification/" + splits["validation"],
    sep="\t"
)
test = pd.read_csv(
    "hf://datasets/bogdancazan/wikilarge-text-simplification/" + splits["test"],
    sep="\t"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

model = train_model(
    model,
    tokenizer,
    train_sources=train['Normal'].tolist(),
    train_targets=train['Simple'].tolist(),
    val_sources=val['Normal'].tolist(),
    val_targets=val['Simple'].tolist(),
    output_dir=output_dir,
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    device=device
)

model.to(device)
outputs = []
for source in tqdm(test['Normal'], desc="Generating simplifications"):
    output = strong_baseline(model, tokenizer, source, device=device)
    outputs.append(output)


sari_scores = []
keep_scores = []
del_scores = []
add_scores = []

for output, source, reference in zip(outputs, test['Normal'], test['Simple']):
    sari, components = sari_score(source, output, [reference])
    sari_scores.append(sari)
    keep_scores.append(components['keep'])
    del_scores.append(components['delete'])
    add_scores.append(components['add'])

print("\n" + "=" * 50)
print("Strong Baseline SARI Score Results")
print("=" * 50)
print(f"Number of samples: {len(sari_scores)}")
print()
print(f"  SARI:        {sum(sari_scores) / len(sari_scores):.2f}")
print(f"    - Keep:    {sum(keep_scores) / len(keep_scores):.2f}")
print(f"    - Delete:  {sum(del_scores) / len(del_scores):.2f}")
print(f"    - Add:     {sum(add_scores) / len(add_scores):.2f}")
print("=" * 50)

Loading WikiLarge dataset...


Epoch 1/20: 100%|██████████| 9303/9303 [12:01<00:00, 12.90it/s]


Epoch 1: Train Loss = 1.7627, Val Loss = 1.4882
Saved best model to ./t5-simplification


Epoch 2/20: 100%|██████████| 9303/9303 [11:57<00:00, 12.96it/s]


Epoch 2: Train Loss = 1.5156, Val Loss = 1.4143
Saved best model to ./t5-simplification


Epoch 3/20: 100%|██████████| 9303/9303 [11:59<00:00, 12.93it/s]


Epoch 3: Train Loss = 1.4465, Val Loss = 1.3770
Saved best model to ./t5-simplification


Epoch 4/20: 100%|██████████| 9303/9303 [12:05<00:00, 12.82it/s]


Epoch 4: Train Loss = 1.4079, Val Loss = 1.3585
Saved best model to ./t5-simplification


Epoch 5/20: 100%|██████████| 9303/9303 [11:44<00:00, 13.20it/s]


Epoch 5: Train Loss = 1.3829, Val Loss = 1.3476
Saved best model to ./t5-simplification


Epoch 6/20: 100%|██████████| 9303/9303 [11:57<00:00, 12.96it/s]


Epoch 6: Train Loss = 1.3632, Val Loss = 1.3366
Saved best model to ./t5-simplification


Epoch 7/20: 100%|██████████| 9303/9303 [11:57<00:00, 12.96it/s]


Epoch 7: Train Loss = 1.3484, Val Loss = 1.3273
Saved best model to ./t5-simplification


Epoch 8/20: 100%|██████████| 9303/9303 [11:57<00:00, 12.96it/s]


Epoch 8: Train Loss = 1.3346, Val Loss = 1.3245
Saved best model to ./t5-simplification


Epoch 9/20: 100%|██████████| 9303/9303 [11:57<00:00, 12.97it/s]


Epoch 9: Train Loss = 1.3238, Val Loss = 1.3166
Saved best model to ./t5-simplification


Epoch 10/20: 100%|██████████| 9303/9303 [12:02<00:00, 12.87it/s]


Epoch 10: Train Loss = 1.3141, Val Loss = 1.3140
Saved best model to ./t5-simplification


Epoch 11/20: 100%|██████████| 9303/9303 [11:44<00:00, 13.20it/s]


Epoch 11: Train Loss = 1.3058, Val Loss = 1.3094
Saved best model to ./t5-simplification


Epoch 12/20: 100%|██████████| 9303/9303 [11:57<00:00, 12.96it/s]


Epoch 12: Train Loss = 1.2981, Val Loss = 1.3060
Saved best model to ./t5-simplification


Epoch 13/20: 100%|██████████| 9303/9303 [11:58<00:00, 12.95it/s]


Epoch 13: Train Loss = 1.2920, Val Loss = 1.3048
Saved best model to ./t5-simplification


Epoch 14/20: 100%|██████████| 9303/9303 [11:58<00:00, 12.94it/s]


Epoch 14: Train Loss = 1.2872, Val Loss = 1.3035
Saved best model to ./t5-simplification


Epoch 15/20: 100%|██████████| 9303/9303 [11:57<00:00, 12.96it/s]


Epoch 15: Train Loss = 1.2822, Val Loss = 1.3029
Saved best model to ./t5-simplification


Epoch 16/20: 100%|██████████| 9303/9303 [11:59<00:00, 12.92it/s]


Epoch 16: Train Loss = 1.2774, Val Loss = 1.3003
Saved best model to ./t5-simplification


Epoch 17/20: 100%|██████████| 9303/9303 [11:45<00:00, 13.18it/s]


Epoch 17: Train Loss = 1.2746, Val Loss = 1.2992
Saved best model to ./t5-simplification


Epoch 18/20: 100%|██████████| 9303/9303 [11:59<00:00, 12.94it/s]


Epoch 18: Train Loss = 1.2715, Val Loss = 1.2991
Saved best model to ./t5-simplification


Epoch 19/20: 100%|██████████| 9303/9303 [11:59<00:00, 12.94it/s]


Epoch 19: Train Loss = 1.2699, Val Loss = 1.2992


Epoch 20/20: 100%|██████████| 9303/9303 [11:58<00:00, 12.95it/s]


Epoch 20: Train Loss = 1.2688, Val Loss = 1.2985
Saved best model to ./t5-simplification


Generating simplifications: 100%|██████████| 191/191 [01:03<00:00,  3.01it/s]


Strong Baseline SARI Score Results
Number of samples: 191

  SARI:        33.43
    - Keep:    59.90
    - Delete:  31.98
    - Add:     8.41





In [15]:
!zip -r t5-simplification.zip /content/t5-simplification

  adding: content/t5-simplification/ (stored 0%)
  adding: content/t5-simplification/model.safetensors (deflated 8%)
  adding: content/t5-simplification/spiece.model (deflated 48%)
  adding: content/t5-simplification/generation_config.json (deflated 27%)
  adding: content/t5-simplification/special_tokens_map.json (deflated 85%)
  adding: content/t5-simplification/tokenizer.json (deflated 74%)
  adding: content/t5-simplification/tokenizer_config.json (deflated 95%)
  adding: content/t5-simplification/config.json (deflated 63%)


In [16]:
from google.colab import files
files.download('t5-simplification.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [5]:
import zipfile

with zipfile.ZipFile('t5-simplification.zip', 'r') as zip_ref:
    zip_ref.extractall('./')


In [12]:
# Load the best model
output_dir = 'content/t5-simplification'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = T5ForConditionalGeneration.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model.to(device)

# Load test data
test = pd.read_csv(
    "hf://datasets/bogdancazan/wikilarge-text-simplification/wiki.full.aner.ori.test.95.tsv",
    sep="\t"
)

# Generate predictions and calculate SARI scores for each example
print("Generating predictions and calculating scores...")
results = []
for idx, row in tqdm(test.iterrows(), total=len(test)):
    source = row['Normal']
    reference = row['Simple']

    # Generate simplification
    output = strong_baseline(model, tokenizer, source, device=device)

    # Calculate SARI score
    sari, components = sari_score(source, output, [reference])

    results.append({
        'idx': idx,
        'source': source,
        'reference': reference,
        'prediction': output,
        'sari': sari,
        'keep': components['keep'],
        'delete': components['delete'],
        'add': components['add']
    })

# Create DataFrame
results_df = pd.DataFrame(results)

# Sort by SARI score
results_df = results_df.sort_values('sari', ascending=False).reset_index(drop=True)

# Display top 5 examples (highest SARI scores)
print("\n" + "="*80)
print("TOP 5 EXAMPLES (Highest SARI Scores)")
print("="*80)
for i, row in results_df.head(5).iterrows():
    print(f"\nExample {i+1} - SARI: {row['sari']:.2f} (Keep: {row['keep']:.2f}, Del: {row['delete']:.2f}, Add: {row['add']:.2f})")
    print(f"Source:     {row['source']}")
    print(f"Reference:  {row['reference']}")
    print(f"Prediction: {row['prediction']}")
    print("-"*80)

# Display bottom 5 examples (lowest SARI scores)
print("\n" + "="*80)
print("BOTTOM 5 EXAMPLES (Lowest SARI Scores)")
print("="*80)
for i, row in results_df.tail(5).iterrows():
    print(f"\nExample {i+1} - SARI: {row['sari']:.2f} (Keep: {row['keep']:.2f}, Del: {row['delete']:.2f}, Add: {row['add']:.2f})")
    print(f"Source:     {row['source']}")
    print(f"Reference:  {row['reference']}")
    print(f"Prediction: {row['prediction']}")
    print("-"*80)

# Summary statistics
print("\n" + "="*80)
print("SUMMARY STATISTICS")
print("="*80)
print(f"Total examples: {len(results_df)}")
print(f"\nSARI Score Distribution:")
print(f"  Mean:   {results_df['sari'].mean():.2f}")
print(f"  Median: {results_df['sari'].median():.2f}")
print(f"  Std:    {results_df['sari'].std():.2f}")
print(f"  Min:    {results_df['sari'].min():.2f}")
print(f"  Max:    {results_df['sari'].max():.2f}")
print("="*80)

# Save results to CSV for further analysis
results_df.to_csv('model_analysis_results.csv', index=False)
print(f"\nResults saved to 'model_analysis_results.csv'")

Generating predictions and calculating scores...


  7%|▋         | 14/191 [00:04<01:00,  2.91it/s]


KeyboardInterrupt: 