In [None]:
import os
from typing import List, Tuple
import pandas as pd
import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from collections import Counter

In [None]:
def tokenize(text: str) -> List[str]:
    """Lowercase ``text`` and split it on runs of whitespace.

    Args:
        text: Raw sentence.

    Returns:
        The lowercased tokens in their original order; an empty list when
        ``text`` is empty or contains only whitespace.
    """
    lowered = text.lower()
    return lowered.split()


def get_ngrams(tokens: List[str], n: int) -> Counter:
    """Count every contiguous window of ``n`` tokens in ``tokens``.

    Args:
        tokens: Token sequence to slide over.
        n: Window size.

    Returns:
        A ``Counter`` keyed by n-gram tuples. It is empty when ``tokens``
        holds fewer than ``n`` items.
    """
    counts = Counter()
    window_starts = range(len(tokens) - n + 1)
    for start in window_starts:
        counts[tuple(tokens[start:start + n])] += 1
    return counts


def sari_score(source: str, prediction: str, references: List[str]) -> Tuple[float, dict]:
    """Score a simplification with a simplified, set-based SARI variant.

    For every n-gram order from 1 to 4 (over tokens produced by ``tokenize``)
    three operation scores are computed on *sets* of distinct n-grams:

    * keep   - F1 of the source n-grams that the prediction retains, judged
      against the source n-grams that occur in at least one reference;
    * delete - F1 of the source n-grams missing from the prediction, judged
      against the source n-grams that occur in no reference;
    * add    - precision of the prediction n-grams that are not in the
      source, judged against reference n-grams that are not in the source.

    Each operation score is averaged over the four orders (an order whose
    denominator is empty contributes 0), and the three averages are averaged
    again and scaled to 0-100.

    NOTE(review): published SARI (Xu et al., 2016) uses F1 for *add* and
    precision only for *delete*, and weights n-grams by reference agreement.
    This variant does not, so its numbers are presumably not comparable with
    published SARI figures - confirm that this is intended.

    Args:
        source: The original (complex) sentence.
        prediction: The system output to score.
        references: One or more reference simplifications.

    Returns:
        ``(sari, components)`` where ``sari`` is the overall score and
        ``components`` maps ``'keep'``, ``'delete'`` and ``'add'`` to the
        per-operation averages, all on a 0-100 scale.
    """

    def _ratio(numerator: int, denominator: int) -> float:
        # An empty denominator scores 0 rather than raising.
        return numerator / denominator if denominator > 0 else 0.0

    def _f1(precision: float, recall: float) -> float:
        total = precision + recall
        return 2 * precision * recall / total if total > 0 else 0.0

    src_tokens = tokenize(source)
    pred_tokens = tokenize(prediction)
    ref_tokens_list = [tokenize(ref) for ref in references]

    keep_scores = []
    del_scores = []
    add_scores = []

    for n in range(1, 5):
        # Only membership matters below, so work with the distinct n-grams.
        src_ngrams = set(get_ngrams(src_tokens, n))
        pred_ngrams = set(get_ngrams(pred_tokens, n))
        ref_ngrams = set()
        for ref_tokens in ref_tokens_list:
            ref_ngrams.update(get_ngrams(ref_tokens, n))

        src_in_refs = src_ngrams & ref_ngrams
        src_not_in_refs = src_ngrams - ref_ngrams

        # --- keep ---
        # Precision is measured over distinct kept n-grams. The previous
        # version divided a distinct-n-gram numerator by the summed
        # multiplicities, so a repeated n-gram (e.g. "the" twice) pushed
        # precision below 1 even for a prediction identical to the reference.
        kept_ngrams = src_ngrams & pred_ngrams
        correctly_kept = kept_ngrams & src_in_refs
        keep_prec = _ratio(len(correctly_kept), len(kept_ngrams))
        keep_rec = _ratio(len(correctly_kept), len(src_in_refs))
        keep_scores.append(_f1(keep_prec, keep_rec))

        # --- delete ---
        deleted_ngrams = src_ngrams - pred_ngrams
        correctly_deleted = deleted_ngrams & src_not_in_refs
        del_prec = _ratio(len(correctly_deleted), len(deleted_ngrams))
        del_rec = _ratio(len(correctly_deleted), len(src_not_in_refs))
        del_scores.append(_f1(del_prec, del_rec))

        # --- add (precision only in this variant) ---
        added_ngrams = pred_ngrams - src_ngrams
        ref_not_in_src = ref_ngrams - src_ngrams
        add_scores.append(_ratio(len(added_ngrams & ref_not_in_src), len(added_ngrams)))

    keep_avg = sum(keep_scores) / 4
    del_avg = sum(del_scores) / 4
    add_avg = sum(add_scores) / 4

    sari = (keep_avg + del_avg + add_avg) / 3 * 100

    components = {
        'keep': keep_avg * 100,
        'delete': del_avg * 100,
        'add': add_avg * 100
    }

    return sari, components

In [None]:
class SimplificationDataset(Dataset):
    """Map-style dataset of (source, target) sentence pairs, tokenized on access.

    Each item is a dict with ``'input_ids'`` and ``'attention_mask'`` for the
    source (prefixed with ``"simplify: "``) and ``'labels'`` for the target,
    where padding positions in the labels are replaced by ``-100``.
    """

    def __init__(self, sources: List[str], targets: List[str], tokenizer, max_length: int = 128):
        """Store the raw text and tokenization settings.

        Args:
            sources: Input sentences.
            targets: Output sentences, aligned with ``sources`` by index.
            tokenizer: Callable tokenizer exposing ``pad_token_id``; it is
                called with ``max_length``, ``padding``, ``truncation`` and
                ``return_tensors`` keyword arguments.
            max_length: Length every sequence is padded/truncated to.
        """
        self.sources = sources
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of examples, taken from ``sources``."""
        return len(self.sources)

    def _encode(self, text: str):
        """Tokenize one string to fixed-length ``'pt'`` tensors."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx`` as a dict of 1-D tensors."""
        source_enc = self._encode("simplify: " + self.sources[idx])
        target_enc = self._encode(self.targets[idx])

        # squeeze(0) removes only the leading batch axis the tokenizer adds
        # for a single string (assumes shape (1, max_length) - TODO confirm).
        # A bare squeeze() would also drop the sequence axis when
        # max_length == 1, producing 0-d tensors that cannot be batched.
        labels = target_enc['input_ids'].squeeze(0)
        # Replace padding positions with -100 so they can be excluded from
        # the loss (presumably the model's ignore index - verify).
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': source_enc['input_ids'].squeeze(0),
            'attention_mask': source_enc['attention_mask'].squeeze(0),
            'labels': labels
        }

In [None]:
def train_model(
    model,
    tokenizer,
    train_sources: List[str],
    train_targets: List[str],
    val_sources: List[str],
    val_targets: List[str],
    output_dir: str,
    epochs: int = 3,
    batch_size: int = 16,
    learning_rate: float = 3e-5,
    device: str = 'cuda'
):
    """Fine-tune ``model`` on sentence pairs and checkpoint the best epoch.

    Training uses AdamW with a linear schedule whose warmup covers 10% of all
    optimizer steps. After each epoch the mean validation loss is computed;
    whenever it improves on the best seen so far, the model and tokenizer are
    saved to ``output_dir``.

    Args:
        model: Model to train; called with ``input_ids``, ``attention_mask``
            and ``labels`` and expected to return an object with ``.loss``.
        tokenizer: Tokenizer handed to ``SimplificationDataset`` and saved
            alongside the model.
        train_sources: Training inputs.
        train_targets: Training outputs, aligned with ``train_sources``.
        val_sources: Validation inputs.
        val_targets: Validation outputs, aligned with ``val_sources``.
        output_dir: Directory for the best checkpoint (created on demand).
        epochs: Number of passes over the training data.
        batch_size: Batch size for both loaders.
        learning_rate: AdamW learning rate.
        device: Device the model and batches are moved to.

    Returns:
        The model in its state after the final epoch. This is not necessarily
        the best-validation checkpoint, which lives on disk in ``output_dir``.
    """
    train_loader = DataLoader(
        SimplificationDataset(train_sources, train_targets, tokenizer),
        batch_size=batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        SimplificationDataset(val_sources, val_targets, tokenizer),
        batch_size=batch_size,
    )

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    model.to(device)

    def forward_loss(batch):
        # Shared by the training and validation loops: move one batch to the
        # target device and return the loss of a forward pass.
        return model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            labels=batch['labels'].to(device)
        ).loss

    best_val_loss = float('inf')

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            loss = forward_loss(batch)
            running_train_loss += loss.item()

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        avg_train_loss = running_train_loss / len(train_loader)

        # ---- validation pass (no gradients) ----
        model.eval()
        running_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                running_val_loss += forward_loss(batch).item()
        avg_val_loss = running_val_loss / len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss >= best_val_loss:
            continue

        # New best epoch: overwrite the checkpoint on disk.
        best_val_loss = avg_val_loss
        os.makedirs(output_dir, exist_ok=True)
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        print(f"Saved best model to {output_dir}")

    return model

In [None]:
def strong_baseline(
    model,
    tokenizer,
    source: str,
    max_length: int = 128,
    num_beams: int = 4,
    device: str = 'cuda'
) -> str:
    """Generate a simplification of one full sentence.

    The sentence is prefixed with ``"simplify: "``, tokenized (truncated to
    ``max_length``), and passed to ``model.generate``; the first returned
    sequence is decoded with special tokens skipped.

    Args:
        model: Generation model; switched to eval mode by this call.
        tokenizer: Tokenizer used for both encoding and decoding.
        source: Sentence to simplify.
        max_length: Limit passed to both the tokenizer and ``generate``.
        num_beams: ``num_beams`` value passed to ``generate``.
        device: Device the encoded input is moved to.

    Returns:
        The decoded output text.
    """
    model.eval()

    with torch.no_grad():
        encoded = tokenizer(
            "simplify: " + source,
            max_length=max_length,
            truncation=True,
            return_tensors='pt'
        ).to(device)

        generation_args = {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'max_length': max_length,
            'num_beams': num_beams,
            'early_stopping': True,
        }
        generated = model.generate(**generation_args)

        return tokenizer.decode(generated[0], skip_special_tokens=True)

In [None]:
# ---- Configuration ----
model_name = 't5-small'
output_dir = './t5-simplification'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 20
batch_size = 16
learning_rate = 3e-5

splits = {
    'train': 'wiki.full.aner.ori.train.95.tsv',
    'validation': 'wiki.full.aner.ori.valid.95.tsv',
    'test': 'wiki.full.aner.ori.test.95.tsv'
}

# ---- Data: one tab-separated file per split, read from the Hugging Face hub ----
print("Loading WikiLarge dataset...")
dataset_root = "hf://datasets/bogdancazan/wikilarge-text-simplification/"
train, val, test = (
    pd.read_csv(dataset_root + splits[split_name], sep="\t")
    for split_name in ("train", "validation", "test")
)

# ---- Fine-tuning ----
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# train_model returns the model as it is after the last epoch; the
# best-validation checkpoint is the one written to output_dir.
model = train_model(
    model,
    tokenizer,
    train_sources=train['Normal'].tolist(),
    train_targets=train['Simple'].tolist(),
    val_sources=val['Normal'].tolist(),
    val_targets=val['Simple'].tolist(),
    output_dir=output_dir,
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    device=device
)

# ---- Generation on the test split ----
model.to(device)
outputs = [
    strong_baseline(model, tokenizer, source, device=device)
    for source in tqdm(test['Normal'], desc="Generating simplifications")
]

# ---- Per-sentence scoring against the single reference ----
sari_scores, keep_scores, del_scores, add_scores = [], [], [], []

for source, reference, output in zip(test['Normal'], test['Simple'], outputs):
    sari, components = sari_score(source, output, [reference])
    sari_scores.append(sari)
    keep_scores.append(components['keep'])
    del_scores.append(components['delete'])
    add_scores.append(components['add'])

# ---- Report: mean of each score over the test set ----
rule = "=" * 50
print("\n" + rule)
print("Strong Baseline SARI Score Results")
print(rule)
print(f"Number of samples: {len(sari_scores)}")
print()
print(f"  SARI:        {sum(sari_scores) / len(sari_scores):.2f}")
print(f"    - Keep:    {sum(keep_scores) / len(keep_scores):.2f}")
print(f"    - Delete:  {sum(del_scores) / len(del_scores):.2f}")
print(f"    - Add:     {sum(add_scores) / len(add_scores):.2f}")
print(rule)

In [None]:
# Shell command: recursively archive the checkpoint directory (absolute Colab
# path) into t5-simplification.zip in the current working directory.
# NOTE(review): zip presumably stores the path without the leading "/", i.e.
# as content/t5-simplification/... - confirm it matches the later model_dir.
!zip -r t5-simplification.zip /content/t5-simplification

In [None]:
# Colab-only: hand t5-simplification.zip to the browser as a file download.
# The import lives in this cell because google.colab exists only on Colab.
from google.colab import files
files.download('t5-simplification.zip')

In [None]:
import zipfile

# Unpack the checkpoint archive into the current working directory.
with zipfile.ZipFile('t5-simplification.zip', mode='r') as archive:
    archive.extractall(path='./')


In [None]:
# output_dir = 'content/t5-simplification'
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = T5ForConditionalGeneration.from_pretrained(output_dir)
# tokenizer = AutoTokenizer.from_pretrained(output_dir)
# model.to(device)

# test = pd.read_csv(
#     "hf://datasets/bogdancazan/wikilarge-text-simplification/wiki.full.aner.ori.test.95.tsv",
#     sep="\t"
# )

# print("Generating predictions and calculating scores...")
# results = []
# for idx, row in tqdm(test.iterrows(), total=len(test)):
#     source = row['Normal']
#     reference = row['Simple']

#     output = strong_baseline(model, tokenizer, source, device=device)

#     sari, components = sari_score(source, output, [reference])

#     results.append({
#         'idx': idx,
#         'source': source,
#         'reference': reference,
#         'prediction': output,
#         'sari': sari,
#         'keep': components['keep'],
#         'delete': components['delete'],
#         'add': components['add']
#     })

# results_df = pd.DataFrame(results)

# results_df = results_df.sort_values('sari', ascending=False).reset_index(drop=True)

# print("\n" + "="*80)
# print("TOP 5 EXAMPLES (Highest SARI Scores)")
# print("="*80)
# for i, row in results_df.head(5).iterrows():
#     print(f"\nExample {i+1} - SARI: {row['sari']:.2f} (Keep: {row['keep']:.2f}, Del: {row['delete']:.2f}, Add: {row['add']:.2f})")
#     print(f"Source:     {row['source']}")
#     print(f"Reference:  {row['reference']}")
#     print(f"Prediction: {row['prediction']}")
#     print("-"*80)

# print("\n" + "="*80)
# print("BOTTOM 5 EXAMPLES (Lowest SARI Scores)")
# print("="*80)
# for i, row in results_df.tail(5).iterrows():
#     print(f"\nExample {i+1} - SARI: {row['sari']:.2f} (Keep: {row['keep']:.2f}, Del: {row['delete']:.2f}, Add: {row['add']:.2f})")
#     print(f"Source:     {row['source']}")
#     print(f"Reference:  {row['reference']}")
#     print(f"Prediction: {row['prediction']}")
#     print("-"*80)

# print("\n" + "="*80)
# print("SUMMARY STATISTICS")
# print("="*80)
# print(f"Total examples: {len(results_df)}")
# print(f"\nSARI Score Distribution:")
# print(f"  Mean:   {results_df['sari'].mean():.2f}")
# print(f"  Median: {results_df['sari'].median():.2f}")
# print(f"  Std:    {results_df['sari'].std():.2f}")
# print(f"  Min:    {results_df['sari'].min():.2f}")
# print(f"  Max:    {results_df['sari'].max():.2f}")
# print("="*80)

# results_df.to_csv('model_analysis_results.csv', index=False)
# print(f"\nResults saved to 'model_analysis_results.csv'")

In [None]:
def get_sentence_prefix(source: str, completion_ratio: float) -> str:
    """Return the leading fraction of ``source``, measured in whitespace tokens.

    Args:
        source: Full sentence.
        completion_ratio: Fraction of tokens to keep; the token count is
            truncated toward zero but never drops below one.

    Returns:
        The kept tokens joined by single spaces (so original spacing is
        normalised); an empty string when ``source`` has no tokens.
    """
    words = source.split()
    cutoff = int(len(words) * completion_ratio)
    if cutoff < 1:
        # Always expose at least the first token.
        cutoff = 1
    return ' '.join(words[:cutoff])


def incremental_simplify(
    model,
    tokenizer,
    source: str,
    completion_ratio: float,
    max_length: int = 128,
    num_beams: int = 4,
    length_penalty: float = 0.6,
    device: str = 'cuda'
) -> str:
    """Simplify only the first ``completion_ratio`` of ``source``.

    The prefix from ``get_sentence_prefix`` is prefixed with ``"simplify: "``,
    tokenized (truncated to ``max_length``) and passed to ``model.generate``
    together with ``length_penalty`` and ``no_repeat_ngram_size=2``; the first
    returned sequence is decoded with special tokens skipped.

    Args:
        model: Generation model; switched to eval mode by this call.
        tokenizer: Tokenizer used for both encoding and decoding.
        source: Full sentence; only its prefix is shown to the model.
        completion_ratio: Fraction of the sentence's tokens to expose.
        max_length: Limit passed to both the tokenizer and ``generate``.
        num_beams: ``num_beams`` value passed to ``generate``.
        length_penalty: ``length_penalty`` value passed to ``generate``.
        device: Device the encoded input is moved to.

    Returns:
        The decoded output text.
    """
    partial_source = get_sentence_prefix(source, completion_ratio)

    model.eval()
    with torch.no_grad():
        encoded = tokenizer(
            "simplify: " + partial_source,
            max_length=max_length,
            truncation=True,
            return_tensors='pt'
        ).to(device)

        generation_args = {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'max_length': max_length,
            'num_beams': num_beams,
            'length_penalty': length_penalty,
            'early_stopping': True,
            'no_repeat_ngram_size': 2,
        }
        generated = model.generate(**generation_args)

        return tokenizer.decode(generated[0], skip_special_tokens=True)


def adaptive_incremental_simplify(
    model,
    tokenizer,
    source: str,
    completion_ratio: float,
    max_length: int = 128,
    num_beams: int = 4,
    device: str = 'cuda'
) -> str:
    """Run ``incremental_simplify`` with a length penalty chosen from the ratio.

    The penalty is 0.4 for ratios up to 0.25, 0.5 up to 0.5, 0.6 up to 0.75,
    and 0.7 for anything larger.

    Args:
        model: Generation model, forwarded unchanged.
        tokenizer: Tokenizer, forwarded unchanged.
        source: Full sentence; only its prefix is shown to the model.
        completion_ratio: Fraction of the sentence's tokens to expose; also
            selects the length penalty.
        max_length: Forwarded unchanged.
        num_beams: Forwarded unchanged.
        device: Forwarded unchanged.

    Returns:
        The text returned by ``incremental_simplify``.
    """
    # (inclusive upper bound on the ratio, penalty), checked in ascending order.
    penalty_by_ceiling = ((0.25, 0.4), (0.5, 0.5), (0.75, 0.6))

    length_penalty = 0.7  # used when the ratio is above every ceiling
    for ceiling, penalty in penalty_by_ceiling:
        if completion_ratio <= ceiling:
            length_penalty = penalty
            break

    return incremental_simplify(
        model,
        tokenizer,
        source,
        completion_ratio,
        max_length=max_length,
        num_beams=num_beams,
        length_penalty=length_penalty,
        device=device,
    )



In [None]:
# ---- Configuration ----
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_length = 128

# Checkpoint location relative to the working directory, resolved to an
# absolute path before it is checked and loaded.
model_dir = os.path.abspath('content/t5-simplification')

splits = {
    'test': 'wiki.full.aner.ori.test.95.tsv'
}

# ---- Model: loaded strictly from local files ----
print("Loading fine-tuned T5 model...")
if not os.path.exists(model_dir):
    raise FileNotFoundError(
        f"Model directory {model_dir} not found. "
        "Please run strong_baseline.py first to train the model."
    )

local_only = {'local_files_only': True}
tokenizer = AutoTokenizer.from_pretrained(model_dir, **local_only)
model = T5ForConditionalGeneration.from_pretrained(model_dir, **local_only)
model.to(device)
model.eval()
print("Model loaded successfully!")

# ---- Data: test split only ----
print("Loading WikiLarge test dataset...")
test_path = "hf://datasets/bogdancazan/wikilarge-text-simplification/" + splits["test"]
test = pd.read_csv(test_path, sep="\t")
print(f"Loaded {len(test)} test examples")

completion_ratios = [0.25, 0.5, 0.75, 1.0]


def _evaluate_incremental(simplify_fn, **extra_generation_args):
    """Score ``simplify_fn`` on every test row at every completion ratio.

    ``simplify_fn`` is called with the notebook-level ``model``, ``tokenizer``,
    ``max_length`` and ``device``, ``num_beams=4``, and any extra keyword
    arguments. Each output is scored against the full source sentence and the
    row's single reference.

    Returns:
        Dict mapping each ratio to ``{'sari', 'keep', 'delete', 'add'}``
        lists holding one score per test row, in row order.
    """
    results = {
        ratio: {'sari': [], 'keep': [], 'delete': [], 'add': []}
        for ratio in completion_ratios
    }

    for _, row in tqdm(test.iterrows(), total=len(test), desc="Processing"):
        source, reference = row['Normal'], row['Simple']

        for ratio in completion_ratios:
            output = simplify_fn(
                model, tokenizer, source, ratio,
                max_length=max_length,
                num_beams=4,
                device=device,
                **extra_generation_args
            )

            sari, components = sari_score(source, output, [reference])
            bucket = results[ratio]
            bucket['sari'].append(sari)
            for part in ('keep', 'delete', 'add'):
                bucket[part].append(components[part])

    return results


print("\nEvaluating incremental simplification with fixed length penalty...")
fixed_results = _evaluate_incremental(incremental_simplify, length_penalty=0.6)

print("\nEvaluating incremental simplification with adaptive length penalty...")
adaptive_results = _evaluate_incremental(adaptive_incremental_simplify)

def _banner(title):
    """Print a blank line, then ``title`` framed by two 70-character '=' rules."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _print_component_table(title, results):
    """Print mean SARI / keep / delete / add per completion ratio for one strategy."""
    _banner(title)
    print(f"\n{'Completion Ratio':<20} {'SARI':<10} {'Keep':<10} {'Delete':<10} {'Add':<10}")
    print("-" * 70)

    for ratio in completion_ratios:
        per_ratio = results[ratio]
        sari_avg = np.mean(per_ratio['sari'])
        keep_avg = np.mean(per_ratio['keep'])
        del_avg = np.mean(per_ratio['delete'])
        add_avg = np.mean(per_ratio['add'])

        print(f"{ratio*100:>5.0f}%{'':<14} {sari_avg:>6.2f}    {keep_avg:>6.2f}    {del_avg:>6.2f}    {add_avg:>6.2f}")

    print("=" * 70)


# ---- Per-strategy tables ----
_print_component_table("Fixed Length Penalty Results", fixed_results)
_print_component_table("Adaptive Length Penalty Results", adaptive_results)

# ---- Side-by-side comparison: adaptive minus fixed mean SARI ----
_banner("Comparison: Fixed vs Adaptive Length Penalty")
print(f"\n{'Completion Ratio':<20} {'Fixed SARI':<15} {'Adaptive SARI':<15} {'Improvement':<15}")
print("-" * 70)

for ratio in completion_ratios:
    fixed_sari = np.mean(fixed_results[ratio]['sari'])
    adaptive_sari = np.mean(adaptive_results[ratio]['sari'])
    improvement = adaptive_sari - fixed_sari

    print(f"{ratio*100:>5.0f}%{'':<14} {fixed_sari:>8.2f}      {adaptive_sari:>8.2f}      {improvement:>+8.2f}")

print("=" * 70)

# ---- Degradation: each ratio's mean SARI as a share of the 100% ratio ----
full_context_sari = np.mean(adaptive_results[1.0]['sari'])
_banner("Quality Degradation Analysis (Adaptive Strategy)")
print(f"\nFull context (100%) SARI: {full_context_sari:.2f}")
print(f"\n{'Completion Ratio':<20} {'SARI':<15} {'% of Full Quality':<20}")
print("-" * 70)

for ratio in completion_ratios:
    sari = np.mean(adaptive_results[ratio]['sari'])
    # Report 0 rather than dividing by a non-positive full-context score.
    pct_quality = (sari / full_context_sari) * 100 if full_context_sari > 0 else 0
    print(f"{ratio*100:>5.0f}%{'':<14} {sari:>8.2f}      {pct_quality:>15.1f}%")

print("=" * 70)

# ---- Worked example: the first test row at every completion ratio ----
print("\n" + "=" * 70)
print("Example: Incremental Simplification")
print("=" * 70)

example_idx = 0
example_row = test.iloc[example_idx]
source = example_row['Normal']
reference = example_row['Simple']

print(f"\nSource: {source}")
print(f"Reference: {reference}")
print("\n" + "-" * 70)

for ratio in completion_ratios:
    # The prefix is recomputed here only for display; the simplifier derives
    # the same prefix internally from (source, ratio).
    prefix = get_sentence_prefix(source, ratio)
    output = adaptive_incremental_simplify(
        model,
        tokenizer,
        source,
        ratio,
        max_length=max_length,
        num_beams=4,
        device=device,
    )
    sari, _ = sari_score(source, output, [reference])

    print(
        f"\n[{ratio*100:.0f}% context]",
        f"  Prefix: {prefix}",
        f"  Output: {output}",
        f"  SARI: {sari:.2f}",
        sep="\n",
    )