## Extension 3: Train `t5-simplification-2` (T5 + Encoder Adapter)

This notebook trains a *new* model directory `t5-simplification-2/` end-to-end on this project’s `data/train.csv` and validates on `data/dev.csv`.

### Architectural change (vs. prior T5 fine-tuning)
Instead of using plain `T5ForConditionalGeneration`, we add a **learned bottleneck adapter** applied to the encoder hidden states *before* decoding:

- Compute encoder hidden states `H` with T5 encoder.
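- Apply a residual adapter: `H' = LN(H + tanh(gate) * Up(Dropout(GELU(Down(H)))))`, where `gate` is a learned scalar initialized to 0; the decoder then attends to `H'` instead of `H`.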

Intuition: simplification often needs a compact “rewrite transform” between input semantics and surface form. The adapter adds capacity targeted at this transformation without changing the tokenizer or decoding interface.

### Relation to real-time simplification
Since extension 2 already reduced latency significantly in real-time settings, this change is meant primarily to improve **quality-per-compute** (better outputs at the same decoding settings) rather than to reduce latency further.

- **Why it can help**: the adapter adds small, targeted capacity for rewriting while keeping the same tokenizer + generation interface.
- **Latency impact**: it adds one extra MLP pass over encoder hidden states (usually minor). Most inference time is still dominated by autoregressive decoding (number of generated tokens × `num_beams`).
- **If you need strict lower latency**: the biggest knobs are typically lowering `num_beams`, lowering `max_length`, and/or using distillation/quantization; this adapter is complementary because it can help preserve quality when you dial beams down.

In [None]:
# Optional: install dependencies if missing (safe to re-run)
import importlib
import sys
import subprocess

def _ensure(pkg, pip_name=None):
    try:
        importlib.import_module(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pip_name or pkg])


# Make sure every third-party package this notebook imports is available.
# Each call is a no-op when the module already imports cleanly.
_ensure('numpy')
_ensure('pandas')
_ensure('tqdm')
_ensure('torch')
_ensure('transformers')
# Import name and pip name are passed explicitly here (they are identical).
_ensure('sentencepiece', 'sentencepiece')


In [None]:
import os
import json
import random
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformers import get_linear_schedule_with_warmup
from transformers.modeling_outputs import BaseModelOutput

from score import sari_score

def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy, and PyTorch (CPU plus all CUDA devices).

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Fix all RNG seeds once, up front, before any model or data loader is built.
set_seed(42)
# Use the GPU when PyTorch can see one; otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Last expression of the cell: shows the chosen device in the notebook output.
device


In [None]:
# Config
model_name = 't5-small'
output_dir = './t5-simplification-2'
max_length = 128
batch_size = 16
num_beams = 6

# Adapter hyperparameters
adapter_bottleneck = 256
adapter_dropout = 0.1

# Training schedule: a short warmup where only the adapter learns, then full
# fine-tuning of the base model together with the adapter.
# NOTE(review): both phases are set to 1 epoch (2 epochs total). If the intent
# is to match a ~20-epoch baseline budget, these values need to be raised.
phase1_epochs = 1      # train adapter only
phase2_epochs = 1     # fine-tune full model + adapter

# Learning rates for each phase, and the max gradient norm used for clipping.
lr_adapter_phase1 = 1e-3
lr_full_phase2 = 3e-5
grad_clip = 1.0

# Fast validation during training (full dev eval happens later)
quick_dev_eval_n = 200


In [None]:
# Load project CSVs
train_df = pd.read_csv('data/train.csv')
dev_df = pd.read_csv('data/dev.csv')
test_df = pd.read_csv('data/test.csv')

# Every split (including test) must provide both a 'Normal' (source) and a
# 'Simple' (reference) column; fail early with the actual column list if not.
for df_name, df in [('train', train_df), ('dev', dev_df), ('test', test_df)]:
    if not {'Normal', 'Simple'}.issubset(df.columns):
        raise ValueError(f'{df_name}.csv must have Normal,Simple columns; got {df.columns.tolist()}')
    # Force both columns to str so later string concatenation cannot fail.
    # NOTE(review): missing cells presumably become the literal text 'nan'
    # here — confirm the CSVs contain no empty values.
    df['Normal'] = df['Normal'].astype(str)
    df['Simple'] = df['Simple'].astype(str)

print('train/dev/test sizes:', len(train_df), len(dev_df), len(test_df))
# Last expression of the cell: preview the first training rows.
train_df.head()


In [None]:
class SimplificationDataset(Dataset):
    """Pairs of (source, target) sentences tokenized for seq2seq training.

    Each item is a dict with `input_ids` and `attention_mask` for the source
    (prefixed with 'simplify: ') and `labels` for the target, where padding
    positions in the labels are replaced by -100.
    """

    def __init__(self, sources, targets, tokenizer, max_length: int = 128):
        # Materialize both sequences so they can be indexed positionally.
        self.sources = list(sources)
        self.targets = list(targets)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sources)

    def _tokenize(self, text):
        """Tokenize one string, padded/truncated to `self.max_length`."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        encoded_source = self._tokenize('simplify: ' + self.sources[idx])
        encoded_target = self._tokenize(self.targets[idx])

        # The tokenizer returns a leading batch dimension of 1; drop it.
        labels = encoded_target['input_ids'].squeeze(0)
        # Mark padding with -100 so those positions are excluded from the loss.
        is_padding = labels == self.tokenizer.pad_token_id
        labels[is_padding] = -100

        item = {}
        item['input_ids'] = encoded_source['input_ids'].squeeze(0)
        item['attention_mask'] = encoded_source['attention_mask'].squeeze(0)
        item['labels'] = labels
        return item


In [None]:
class EncoderBottleneckAdapter(nn.Module):
    """Gated residual bottleneck MLP followed by LayerNorm.

    Computes `LN(h + tanh(gate) * Up(Dropout(GELU(Down(h)))))`, where `Down`
    projects from `d_model` to `bottleneck` and `Up` projects back.

    Args:
        d_model: Size of the last dimension of the incoming hidden states.
        bottleneck: Width of the intermediate projection.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model: int, bottleneck: int = 256, dropout: float = 0.1):
        super().__init__()
        # Submodule names are part of the saved state-dict keys; keep them.
        self.down = nn.Linear(d_model, bottleneck)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.up = nn.Linear(bottleneck, d_model)
        self.ln = nn.LayerNorm(d_model)
        # Scalar gate starts at exactly 0, so tanh(gate) == 0 and the residual
        # branch contributes nothing until training moves the gate.
        self.gate = nn.Parameter(torch.zeros(()))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return the normalized sum of `h` and the gated bottleneck update."""
        squeezed = self.act(self.down(h))
        update = self.up(self.dropout(squeezed))
        gate_scale = torch.tanh(self.gate)
        return self.ln(h + gate_scale * update)


@dataclass
class AdapterConfig:
    """Hyperparameters needed to rebuild an `EncoderBottleneckAdapter`."""

    # Width of the adapter's intermediate (bottleneck) projection.
    bottleneck: int
    # Dropout probability used inside the adapter.
    dropout: float


class T5WithEncoderAdapter(nn.Module):
    """T5 seq2seq model whose encoder output passes through a bottleneck adapter.

    The wrapped `T5ForConditionalGeneration` lives in `self.base` and the
    adapter in `self.adapter`. Both `forward` and `generate` run the encoder,
    adapt its last hidden state, and hand the result to the base model through
    its `encoder_outputs` argument.

    Args:
        base_model_name: Model name or local directory passed to
            `T5ForConditionalGeneration.from_pretrained`.
        bottleneck: Adapter bottleneck width.
        dropout: Adapter dropout probability.
    """

    def __init__(self, base_model_name: str, bottleneck: int = 256, dropout: float = 0.1):
        super().__init__()
        self.base = T5ForConditionalGeneration.from_pretrained(base_model_name)
        d_model = self.base.config.d_model
        self.adapter = EncoderBottleneckAdapter(d_model=d_model, bottleneck=bottleneck, dropout=dropout)
        # Kept so `save` can write the values needed to rebuild the adapter.
        self.adapter_cfg = AdapterConfig(bottleneck=bottleneck, dropout=dropout)

    def _adapted_encoder_outputs(self, input_ids, attention_mask):
        """Run the encoder once and return its outputs with adapted last hidden state.

        Shared by `forward` and `generate` so training and inference always
        apply the adapter the same way.
        """
        enc = self.base.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return BaseModelOutput(
            last_hidden_state=self.adapter(enc.last_hidden_state),
            hidden_states=enc.hidden_states,
            attentions=enc.attentions,
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Teacher-forced pass; returns the base model's output (with `.loss` when labels are given).

        Extra keyword arguments are forwarded unchanged to the base model.
        """
        encoder_outputs = self._adapted_encoder_outputs(input_ids, attention_mask)
        return self.base(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
            **kwargs,
        )

    @torch.no_grad()
    def generate(self, input_ids=None, attention_mask=None, **gen_kwargs):
        """Generate with the base model's `generate`, decoding from adapted encoder states.

        `gen_kwargs` (e.g. `num_beams`, `max_length`) are forwarded unchanged.
        """
        encoder_outputs = self._adapted_encoder_outputs(input_ids, attention_mask)
        return self.base.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            **gen_kwargs,
        )

    def save(self, output_dir: str, tokenizer):
        """Write the model to `output_dir`.

        Layout: `base/` (base model via `save_pretrained`), tokenizer files in
        `output_dir` itself, `adapter.pt` (adapter state dict) and
        `adapter_config.json` (bottleneck and dropout).
        """
        os.makedirs(output_dir, exist_ok=True)
        base_dir = os.path.join(output_dir, 'base')
        self.base.save_pretrained(base_dir)
        tokenizer.save_pretrained(output_dir)
        torch.save(self.adapter.state_dict(), os.path.join(output_dir, 'adapter.pt'))
        with open(os.path.join(output_dir, 'adapter_config.json'), 'w') as f:
            json.dump({
                'bottleneck': self.adapter_cfg.bottleneck,
                'dropout': self.adapter_cfg.dropout,
            }, f, indent=2)

    @classmethod
    def load(cls, output_dir: str):
        """Rebuild a model from a directory previously written by `save`.

        The returned model's weights are on CPU; the caller moves it to a device.
        """
        base_dir = os.path.join(output_dir, 'base')
        with open(os.path.join(output_dir, 'adapter_config.json'), 'r') as f:
            cfg = json.load(f)
        model = cls(base_model_name=base_dir, bottleneck=cfg['bottleneck'], dropout=cfg['dropout'])
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # you produced yourself, or consider passing weights_only=True if the
        # installed torch version supports it.
        sd = torch.load(os.path.join(output_dir, 'adapter.pt'), map_location='cpu')
        model.adapter.load_state_dict(sd)
        return model


In [None]:
def set_base_trainable(model: T5WithEncoderAdapter, trainable: bool):
    """Freeze or unfreeze the base T5 weights; the adapter always stays trainable.

    Args:
        model: Wrapper exposing `.base` and `.adapter` submodules.
        trainable: Value assigned to `requires_grad` on every base parameter.
    """
    model.base.requires_grad_(trainable)
    model.adapter.requires_grad_(True)


def train_one_epoch(model, train_loader, optimizer, scheduler=None):
    """Run one optimization pass over `train_loader` and return the mean batch loss.

    Reads the module-level `device` and `grad_clip` settings.

    Args:
        model: Module whose call returns an object with a `.loss` attribute.
        train_loader: Iterable of dict batches whose values support `.to(device)`.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler, also stepped once per batch.

    Returns:
        Sum of per-batch losses divided by the number of batches (0.0 if empty).
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc='train', leave=False)
    for raw_batch in progress:
        inputs = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        loss = model(**inputs).loss
        loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        running_loss += loss.item()
    num_batches = len(train_loader)
    return running_loss / max(1, num_batches)


@torch.no_grad()
def eval_loss(model, val_loader):
    """Return the mean per-batch loss of `model` over `val_loader` in eval mode.

    Batches are moved to the module-level `device`. The average is over
    batches, not over individual examples.
    """
    model.eval()
    loss_sum = 0.0
    for raw_batch in tqdm(val_loader, desc='val', leave=False):
        on_device = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        loss_sum += model(**on_device).loss.item()
    num_batches = len(val_loader)
    return loss_sum / max(1, num_batches)


@torch.no_grad()
def generate_text(model, tokenizer, sources):
    """Generate one simplification per source sentence with beam search.

    Sources are processed one at a time. Uses the module-level `max_length`,
    `num_beams`, and `device` settings.

    Args:
        model: Object exposing `eval()` and `generate(input_ids=..., attention_mask=..., ...)`.
        tokenizer: Tokenizer used both to encode the prompt and decode the output.
        sources: Iterable of source sentences (each is passed through `str`).

    Returns:
        List of decoded strings, in the same order as `sources`.
    """
    model.eval()
    outputs = []
    for text in tqdm(list(sources), desc='generate', leave=False):
        prompt = 'simplify: ' + str(text)
        batch = tokenizer(
            prompt,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        ).to(device)
        decode_options = dict(
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
        token_ids = model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            **decode_options,
        )
        # Only the first returned sequence is kept.
        outputs.append(tokenizer.decode(token_ids[0], skip_special_tokens=True))
    return outputs


@torch.no_grad()
def eval_sari(model, tokenizer, df, n: int = 200):
    """Return the mean SARI of `model` on up to `n` rows of `df`.

    Args:
        model: Model handed to `generate_text`.
        tokenizer: Tokenizer handed to `generate_text`.
        df: DataFrame with 'Normal' (source) and 'Simple' (reference) columns.
        n: Number of rows to sample; `None` or a value >= `len(df)` uses all rows.

    Returns:
        Mean of the per-example SARI scores as a Python float.
    """
    use_everything = n is None or n >= len(df)
    if use_everything:
        subset = df
    else:
        # Fixed random_state: every call scores the same sample of rows.
        subset = df.sample(n=n, random_state=5300).reset_index(drop=True)

    hypotheses = generate_text(model, tokenizer, subset['Normal'])

    per_example = [
        sari_score(str(source), str(hyp), [str(gold)])[0]
        for hyp, source, gold in zip(hypotheses, subset['Normal'], subset['Simple'])
    ]
    return float(np.mean(per_example))


### (Optional / Experimental) Limited attention buffer for real-time decoding

If you want to simulate a *limited attention buffer* during generation, a practical option for T5 is to cap the **decoder self-attention history** to the last *N* generated tokens (a sliding KV-cache window).

- This can reduce **memory** and sometimes **latency** when outputs get long.
- For sentence simplification, outputs are typically short, so the biggest real-time levers are still `num_beams`, `max_length`/`max_new_tokens`, and batching.
- This implementation uses **greedy decoding** (beam search + cache-windowing is more complex).

In [None]:
@torch.no_grad()
def _truncate_t5_past_key_values(past_key_values, window: int):
    """Keep only the last `window` tokens of decoder self-attention KV.

    T5 past_key_values is a list/tuple over layers.
    Each layer typically has: (self_k, self_v, cross_k, cross_v).
    We truncate only self_k/self_v on the sequence-length dimension.
    """
    if past_key_values is None or window is None:
        return past_key_values

    new_past = []
    for layer_past in past_key_values:
        if layer_past is None:
            new_past.append(layer_past)
            continue

        # Expected: (self_k, self_v, cross_k, cross_v)
        if len(layer_past) < 2:
            new_past.append(layer_past)
            continue

        self_k, self_v = layer_past[0], layer_past[1]
        rest = tuple(layer_past[2:])

        # Shapes are usually (batch, heads, seq_len, head_dim)
        if self_k is not None and self_k.ndim >= 3:
            self_k = self_k[:, :, -window:, :]
        if self_v is not None and self_v.ndim >= 3:
            self_v = self_v[:, :, -window:, :]

        new_past.append((self_k, self_v) + rest)

    return tuple(new_past)


@torch.no_grad()
def generate_text_greedy_limited_buffer(model: T5WithEncoderAdapter, tokenizer, sources, *,
                                       max_new_tokens: int = 128,
                                       decoder_buffer_tokens: int | None = 64):
    """Greedy decoding with an optional sliding decoder attention buffer.

    Each source is encoded once, adapted, and then decoded token by token with
    the base model's KV cache. When `decoder_buffer_tokens` is set, the cached
    decoder self-attention history is trimmed to that many positions before
    every step. Uses the module-level `max_length` and `device` settings.

    Args:
        model: Wrapper exposing `.base` (T5) and `.adapter`.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        sources: Iterable of source sentences (each is passed through `str`).
        max_new_tokens: Upper bound on decoding steps per source.
        decoder_buffer_tokens: Size of the self-attention cache window, or
            None to keep the full history.

    Returns:
        List of decoded strings, in the same order as `sources`.
    """
    model.eval()
    preds = []

    for src in tqdm(list(sources), desc='generate(greedy+buffer)', leave=False):
        enc = tokenizer(
            'simplify: ' + str(src),
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        ).to(device)

        # Compute adapted encoder outputs once
        enc_out = model.base.encoder(
            input_ids=enc['input_ids'],
            attention_mask=enc['attention_mask'],
            return_dict=True,
        )
        adapted = model.adapter(enc_out.last_hidden_state)
        encoder_outputs = BaseModelOutput(last_hidden_state=adapted)

        # T5 uses pad_token_id as decoder start token
        decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], device=device)
        past = None

        for _ in range(max_new_tokens):
            if past is None:
                # First step: no cache yet, feed the start token.
                dec_ids = decoder_input_ids
            else:
                # With a cache, only the newest token may be fed. Re-feeding
                # the earlier tokens would append them to the cache again,
                # duplicating context and growing the cache every step.
                dec_ids = decoder_input_ids[:, -1:]
                if decoder_buffer_tokens is not None:
                    past = _truncate_t5_past_key_values(past, decoder_buffer_tokens)

            out = model.base(
                encoder_outputs=encoder_outputs,
                attention_mask=enc['attention_mask'],
                decoder_input_ids=dec_ids,
                use_cache=True,
                past_key_values=past,
                return_dict=True,
            )

            # Greedy choice: highest-scoring token at the last position.
            next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
            past = out.past_key_values

            if next_token.item() == tokenizer.eos_token_id:
                break

        preds.append(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))

    return preds


# Example usage (optional):
# dev_preds_fast = generate_text_greedy_limited_buffer(best_model, best_tokenizer, dev_df['Normal'], max_new_tokens=64, decoder_buffer_tokens=64)
# print(dev_preds_fast[0])


In [None]:
# Build tokenizer, model, datasets and loaders from the config cell's settings.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5WithEncoderAdapter(model_name, bottleneck=adapter_bottleneck, dropout=adapter_dropout)
model.to(device)

train_ds = SimplificationDataset(train_df['Normal'], train_df['Simple'], tokenizer, max_length=max_length)
dev_ds = SimplificationDataset(dev_df['Normal'], dev_df['Simple'], tokenizer, max_length=max_length)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
dev_loader = DataLoader(dev_ds, batch_size=batch_size, shuffle=False)

# Best quick-dev SARI seen so far, shared across both phases; a checkpoint is
# written only when an epoch beats it.
best_quick_sari = -1.0

# Phase 1: train adapter only
set_base_trainable(model, trainable=False)
opt = AdamW(model.adapter.parameters(), lr=lr_adapter_phase1)
total_steps = len(train_loader) * phase1_epochs
# Linear schedule with the first 10% of steps (at least 1) used as warmup.
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=max(1, int(0.1 * total_steps)), num_training_steps=total_steps)

for epoch in range(phase1_epochs):
    tr_loss = train_one_epoch(model, train_loader, opt, sched)
    va_loss = eval_loss(model, dev_loader)
    # Quick SARI on at most quick_dev_eval_n dev rows (full dev eval comes later).
    quick_sari = eval_sari(model, tokenizer, dev_df, n=quick_dev_eval_n)
    print(f'[phase1 epoch {epoch+1}/{phase1_epochs}] train_loss={tr_loss:.4f} dev_loss={va_loss:.4f} quick_dev_SARI={quick_sari:.2f}')
    if quick_sari > best_quick_sari:
        best_quick_sari = quick_sari
        model.save(output_dir, tokenizer)
        print(f'  saved best so far -> {output_dir} (quick_dev_SARI={best_quick_sari:.2f})')

# Phase 2: fine-tune full model + adapter
set_base_trainable(model, trainable=True)
# Fresh optimizer and schedule covering all parameters at the lower phase-2 LR.
opt = AdamW(model.parameters(), lr=lr_full_phase2)
total_steps = len(train_loader) * phase2_epochs
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=max(1, int(0.1 * total_steps)), num_training_steps=total_steps)

for epoch in range(phase2_epochs):
    tr_loss = train_one_epoch(model, train_loader, opt, sched)
    va_loss = eval_loss(model, dev_loader)
    quick_sari = eval_sari(model, tokenizer, dev_df, n=quick_dev_eval_n)
    print(f'[phase2 epoch {epoch+1}/{phase2_epochs}] train_loss={tr_loss:.4f} dev_loss={va_loss:.4f} quick_dev_SARI={quick_sari:.2f}')
    # Overwrites the phase-1 checkpoint only if this epoch scores higher.
    if quick_sari > best_quick_sari:
        best_quick_sari = quick_sari
        model.save(output_dir, tokenizer)
        print(f'  saved best so far -> {output_dir} (quick_dev_SARI={best_quick_sari:.2f})')

print('best quick dev SARI:', best_quick_sari)


In [None]:
# Reload best saved model + tokenizer, then run full dev evaluation
best_model = T5WithEncoderAdapter.load(output_dir).to(device)
best_tokenizer = AutoTokenizer.from_pretrained(output_dir)

# Beam-search predictions for every dev row (reused by the comparison cell below).
dev_preds = generate_text(best_model, best_tokenizer, dev_df['Normal'])

# Per-example SARI and its keep / delete / add components, one reference each.
sari_scores, keep_scores, del_scores, add_scores = [], [], [], []
for pred, src, ref in zip(dev_preds, dev_df['Normal'], dev_df['Simple']):
    sari, comp = sari_score(str(src), str(pred), [str(ref)])
    sari_scores.append(sari)
    keep_scores.append(comp['keep'])
    del_scores.append(comp['delete'])
    add_scores.append(comp['add'])

# Report the means over the full dev set.
print('='*50)
print('Extension3 (T5+Adapter) DEV SARI Score Results')
print('='*50)
print(f'Number of samples: {len(sari_scores)}')
print()
print(f'  SARI:        {np.mean(sari_scores):.2f}')
print(f'    - Keep:    {np.mean(keep_scores):.2f}')
print(f'    - Delete:  {np.mean(del_scores):.2f}')
print(f'    - Add:     {np.mean(add_scores):.2f}')
print('='*50)


In [None]:
# --- Results add-on: limited decoder self-attention context (sliding window) ---
# This compares the existing “normal” generation (dev_preds, via beam search) to greedy decoding
# with/without a limited decoder self-attention buffer.

import time
import importlib
import sys
import subprocess


def _ensure_local(pkg, pip_name=None):
    try:
        importlib.import_module(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pip_name or pkg])


# matplotlib is needed for the plots in this cell; install it if it is missing.
_ensure_local('matplotlib', 'matplotlib')

import matplotlib.pyplot as plt


@torch.no_grad()
def _truncate_t5_self_kv_past(past_key_values, window):
    """Truncate only decoder self-attn KV history to last `window` tokens."""
    if past_key_values is None or window is None:
        return past_key_values

    new_past = []
    for layer_past in past_key_values:
        # Expected T5 format per-layer: (self_k, self_v, cross_k, cross_v)
        if layer_past is None:
            new_past.append(layer_past)
            continue

        if len(layer_past) < 2:
            new_past.append(layer_past)
            continue

        self_k, self_v = layer_past[0], layer_past[1]
        rest = tuple(layer_past[2:])

        # Typical shape: (batch, heads, seq_len, head_dim)
        if self_k is not None and getattr(self_k, 'ndim', 0) >= 3:
            self_k = self_k[:, :, -window:, :]
        if self_v is not None and getattr(self_v, 'ndim', 0) >= 3:
            self_v = self_v[:, :, -window:, :]

        new_past.append((self_k, self_v) + rest)

    return tuple(new_past)


@torch.no_grad()
def generate_greedy_with_optional_window(model, tokenizer, sources, *, max_new_tokens=128, window=None):
    """Greedy-decode each source and time the whole run.

    When `window` is set, the decoder self-attention KV cache is trimmed to
    the last `window` positions after every step. Uses the module-level
    `max_length` and `device` settings.

    Args:
        model: Wrapper exposing `.base` (T5) and `.adapter`.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        sources: Iterable of source sentences (each is passed through `str`).
        max_new_tokens: Upper bound on decoding steps per source.
        window: Self-attention cache window, or None for full history.

    Returns:
        `(preds, ms_per_example)`: decoded strings in input order, and the
        wall-clock milliseconds per example (tokenization included).
    """
    model.eval()

    outputs = []
    started = time.perf_counter()

    for text in tqdm(list(sources), desc=f'generate(greedy, window={window})', leave=False):
        batch = tokenizer(
            'simplify: ' + str(text),
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        ).to(device)
        source_mask = batch['attention_mask']

        # Encode once per source, then adapt the encoder's last hidden state.
        raw_states = model.base.encoder(
            input_ids=batch['input_ids'],
            attention_mask=source_mask,
            return_dict=True,
        ).last_hidden_state
        encoder_outputs = BaseModelOutput(last_hidden_state=model.adapter(raw_states))

        # Decoding starts from the pad token, which T5 uses as its start token.
        tokens = torch.tensor([[tokenizer.pad_token_id]], device=device)
        cache = None

        for _ in range(max_new_tokens):
            # Once a cache exists, only the newest token is fed to the decoder.
            step_input = tokens if cache is None else tokens[:, -1:]

            step = model.base(
                encoder_outputs=encoder_outputs,
                attention_mask=source_mask,
                decoder_input_ids=step_input,
                use_cache=True,
                past_key_values=cache,
                return_dict=True,
            )

            chosen = torch.argmax(step.logits[:, -1, :], dim=-1, keepdim=True)
            tokens = torch.cat([tokens, chosen], dim=1)
            cache = _truncate_t5_self_kv_past(step.past_key_values, window)

            if chosen.item() == tokenizer.eos_token_id:
                break

        outputs.append(tokenizer.decode(tokens[0], skip_special_tokens=True))

    finished = time.perf_counter()
    ms_per_example = (finished - started) * 1000.0 / max(1, len(outputs))
    return outputs, ms_per_example


def sari_components_for_preds(preds, df):
    """Score predictions against `df` and return per-example SARI arrays.

    Args:
        preds: Predicted simplifications, aligned with the rows of `df`.
        df: Mapping/DataFrame with 'Normal' (source) and 'Simple' (reference).

    Returns:
        Tuple of four NumPy arrays: (sari, keep, delete, add), one entry per
        example.
    """
    scored = [
        sari_score(str(source), str(hyp), [str(gold)])
        for hyp, source, gold in zip(preds, df['Normal'], df['Simple'])
    ]
    overall = np.array([total for total, _ in scored])
    keep_part = np.array([parts['keep'] for _, parts in scored])
    delete_part = np.array([parts['delete'] for _, parts in scored])
    add_part = np.array([parts['add'] for _, parts in scored])
    return (overall, keep_part, delete_part, add_part)


# 1) Existing “normal” dev_preds (beam search)
# Reuses the predictions computed in the full dev evaluation cell.
beam_sari, beam_keep, beam_del, beam_add = sari_components_for_preds(dev_preds, dev_df)

# 2) Greedy full context (window=None)
greedy_full_preds, greedy_full_ms = generate_greedy_with_optional_window(
    best_model, best_tokenizer, dev_df['Normal'], max_new_tokens=max_length, window=None
)
greedy_full_sari, greedy_full_keep, greedy_full_del, greedy_full_add = sari_components_for_preds(greedy_full_preds, dev_df)

# 3) Greedy limited decoder self-attn context (e.g., last 64 tokens)
# Same settings as (2) except for the cache window, so the two are comparable.
window_tokens = 64
limited_preds, limited_ms = generate_greedy_with_optional_window(
    best_model, best_tokenizer, dev_df['Normal'], max_new_tokens=max_length, window=window_tokens
)
limited_sari, limited_keep, limited_del, limited_add = sari_components_for_preds(limited_preds, dev_df)


# Summary table: one row per decoding method with mean SARI and components.
summary = pd.DataFrame([
    {
        'method': f'beam (normal, num_beams={num_beams})',
        'SARI': beam_sari.mean(),
        'KEEP': beam_keep.mean(),
        'DELETE': beam_del.mean(),
        'ADD': beam_add.mean(),
        # Beam search was not timed in this cell.
        'ms/example (measured)': np.nan,
    },
    {
        'method': 'greedy (full context)',
        'SARI': greedy_full_sari.mean(),
        'KEEP': greedy_full_keep.mean(),
        'DELETE': greedy_full_del.mean(),
        'ADD': greedy_full_add.mean(),
        'ms/example (measured)': greedy_full_ms,
    },
    {
        'method': f'greedy (decoder window={window_tokens})',
        'SARI': limited_sari.mean(),
        'KEEP': limited_keep.mean(),
        'DELETE': limited_del.mean(),
        'ADD': limited_add.mean(),
        'ms/example (measured)': limited_ms,
    },
]).set_index('method')

# A bare `summary` expression only renders when it is the last statement of a
# notebook cell; this one sits mid-cell, so print the table explicitly.
print(summary)


# Plots
# Three side-by-side panels: mean SARI per method, per-example SARI
# distributions for the two greedy runs, and their per-example difference.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Mean SARI bar
summary['SARI'].plot(kind='bar', ax=axes[0], title='DEV mean SARI')
axes[0].set_ylabel('SARI')
axes[0].grid(axis='y', alpha=0.3)

# Distribution comparison: greedy full vs limited
axes[1].hist(greedy_full_sari, bins=30, alpha=0.6, label='greedy full')
axes[1].hist(limited_sari, bins=30, alpha=0.6, label=f'greedy window={window_tokens}')
axes[1].set_title('Per-example SARI distribution (greedy)')
axes[1].set_xlabel('SARI')
axes[1].set_ylabel('count')
axes[1].legend()

# Delta plot (limited - full) on greedy
# Negative values mean the windowed run scored lower on that example.
delta = limited_sari - greedy_full_sari
axes[2].hist(delta, bins=30, alpha=0.8)
# Vertical reference line at zero (no change).
axes[2].axvline(0, color='k', linewidth=1)
axes[2].set_title('ΔSARI = windowed - full (greedy)')
axes[2].set_xlabel('ΔSARI')
axes[2].set_ylabel('count')

plt.tight_layout()
plt.show()


# A small per-example table (top changes)
# One row per dev example with all three predictions and the greedy scores.
per_example = pd.DataFrame({
    'Normal': dev_df['Normal'],
    'Reference': dev_df['Simple'],
    'beam_pred': dev_preds,
    'greedy_full_pred': greedy_full_preds,
    'greedy_window_pred': limited_preds,
    'greedy_full_SARI': greedy_full_sari,
    'greedy_window_SARI': limited_sari,
    'delta_SARI(window-full)': delta,
})

# Ascending sort: shows the 10 examples where windowing lowered SARI the most.
# Last expression of the cell, so the notebook renders it.
per_example.sort_values('delta_SARI(window-full)', ascending=True).head(10)


In [None]:
# Generate outputs on test.csv and write to a file (one simplification per line)
test_preds = generate_text(best_model, best_tokenizer, test_df['Normal'])
out_path = 'outputs_extension3.txt'
with open(out_path, 'w', encoding='utf-8') as f:
    for p in test_preds:
        # Strip surrounding whitespace so each prediction occupies one line.
        f.write(p.strip() + '\n')
print('wrote:', out_path, 'num_lines:', len(test_preds))

# Optional: evaluate on this repo's test.csv (it includes Simple references here)
# The loading cell already requires a 'Simple' column in test.csv, so this
# always runs in this notebook.
sari_scores = []
for pred, src, ref in zip(test_preds, test_df['Normal'], test_df['Simple']):
    s, _ = sari_score(str(src), str(pred), [str(ref)])
    sari_scores.append(s)
print('TEST SARI (local, if references exist):', float(np.mean(sari_scores)))


In [None]:
# Zip the trained model directory for easy submission/sharing
import shutil
# Base name of the archive file (the 'zip' format appends the extension).
# NOTE(review): this repeats the directory name from `output_dir`; keep the
# two in sync if `output_dir` changes.
zip_base = 't5-simplification-2'
# root_dir=output_dir: archive the contents of the saved model directory.
zip_path = shutil.make_archive(zip_base, 'zip', root_dir=output_dir)
# Last expression of the cell: shows the path of the created archive.
zip_path
