In [None]:
import os
import json
import shutil
import subprocess
from typing import List, Dict

import pandas as pd
import torch
from tqdm.auto import tqdm

from transformers import AutoTokenizer, AutoConfig, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import torch.nn as nn
from dataclasses import dataclass

In [None]:
# Prefer local `data/test.csv` (no network needed). Fallback: load WikiLarge from HF.

FORCE_HF_TEST = False  # set True to ignore local data/ and use the full HF test split

DATA_TEST_PATH = os.path.join("data", "test.csv")

# Columns both loading paths must provide: "Normal" feeds test_sources,
# "Simple" feeds test_refs.
REQUIRED_TEST_COLUMNS = {"Normal", "Simple"}

use_local_test = (not FORCE_HF_TEST) and os.path.exists(DATA_TEST_PATH)

if use_local_test:
    test_origin = DATA_TEST_PATH
    test = pd.read_csv(DATA_TEST_PATH)
else:
    splits = {
        "train": "wiki.full.aner.ori.train.95.tsv",
        "validation": "wiki.full.aner.ori.valid.95.tsv",
        "test": "wiki.full.aner.ori.test.95.tsv",
    }
    test_origin = "hf://datasets/bogdancazan/wikilarge-text-simplification/" + splits["test"]
    test = pd.read_csv(test_origin, sep="\t")

# Validate the schema for BOTH sources so a wrong file fails with a clear message
# instead of a KeyError further down.
missing = REQUIRED_TEST_COLUMNS - set(test.columns)
if missing:
    raise ValueError(f"{test_origin} missing columns: {missing}")

if use_local_test:
    print(f"Loaded local test split: {DATA_TEST_PATH} ({len(test)} rows)")
    if len(test) <= 1000:
        print(
            "WARNING: This is a small subset test file. "
            "Set FORCE_HF_TEST=True (and ensure Colab has network) to evaluate on the full WikiLarge test split."
        )
else:
    print(f"Loaded HF test split ({len(test)} rows)")

# A missing cell would otherwise become the literal text "nan" after astype(str)
# and be fed to the models / written as a gold reference.
rows_before = len(test)
test = test.dropna(subset=["Normal", "Simple"]).reset_index(drop=True)
rows_dropped = rows_before - len(test)
if rows_dropped:
    print(f"WARNING: dropped {rows_dropped} row(s) with a missing Normal/Simple value.")

if test.empty:
    raise ValueError(f"{test_origin} contains no usable rows.")

test_sources = test["Normal"].astype(str).tolist()
test_refs = test["Simple"].astype(str).tolist()
print("Example source:", test_sources[0][:200])
print("Example ref:", test_refs[0][:200])

In [None]:
# ---- Unzip / prepare trained model folders ----

def ensure_dir_from_zip(zip_path: str, out_dir: str) -> None:
    """Ensure `out_dir` exists by extracting `zip_path` if needed.

    Args:
        zip_path: Archive to extract when `out_dir` is not already a directory.
        out_dir: Directory that must exist when the function returns.

    Raises:
        FileNotFoundError: If `out_dir` does not exist and `zip_path` is missing.
        Exception: Whatever `shutil.unpack_archive` raises for an unreadable
            archive; in that case the freshly created `out_dir` is removed again.
    """
    if os.path.isdir(out_dir):
        return
    if not os.path.exists(zip_path):
        raise FileNotFoundError(f"Missing {zip_path} and {out_dir} does not exist.")

    os.makedirs(out_dir, exist_ok=True)
    print(f"Extracting {zip_path} -> {out_dir} ...")
    try:
        shutil.unpack_archive(zip_path, out_dir)
    except BaseException:
        # A half-extracted (or empty) directory would make the isdir() shortcut above
        # succeed on the next run and hide the failure, so remove it before re-raising.
        # BaseException also covers a KeyboardInterrupt from interrupting the cell.
        shutil.rmtree(out_dir, ignore_errors=True)
        raise


def maybe_flatten_nested_dir(out_dir: str, nested_name: str, required_markers: List[str]) -> None:
    """Lift the contents of out_dir/nested_name up into out_dir when a zip nested them.

    Nothing happens when `out_dir` already contains every marker, when the nested
    directory is absent, or when the nested directory itself lacks a marker.

    Args:
        out_dir: Directory the archive was extracted into.
        nested_name: Name of the possibly redundant sub-directory inside `out_dir`.
        required_markers: Relative paths whose presence identifies the real model root.
    """

    def has_all_markers(folder: str) -> bool:
        # Same short-circuit behaviour as all(...) over the marker list.
        for marker in required_markers:
            if not os.path.exists(os.path.join(folder, marker)):
                return False
        return True

    inner_dir = os.path.join(out_dir, nested_name)
    should_flatten = (
        not has_all_markers(out_dir)
        and os.path.isdir(inner_dir)
        and has_all_markers(inner_dir)
    )
    if not should_flatten:
        return

    print(f"Flattening nested extraction: {inner_dir} -> {out_dir}")
    for entry in os.listdir(inner_dir):
        source_path = os.path.join(inner_dir, entry)
        target_path = os.path.join(out_dir, entry)
        shutil.move(source_path, target_path)
    shutil.rmtree(inner_dir)


# Baseline checkpoint
BASELINE_DIR = "t5-simplification"
BASELINE_ZIP = "t5-simplification.zip"

# Extension 3 checkpoint (encoder-adapter architecture)
EXT3_DIR = "t5-simplification-2"
EXT3_ZIP = "t5-simplification-2.zip"

# Baseline zip may be archived with a 'content/t5-simplification/' prefix, so it is
# unpacked into a scratch directory first and the right folder is moved into place.
if not os.path.isdir(BASELINE_DIR):
    # Fail with an explicit message instead of shutil's "not a zip file" error.
    if not os.path.exists(BASELINE_ZIP):
        raise FileNotFoundError(f"Missing {BASELINE_ZIP} and {BASELINE_DIR} does not exist.")

    tmp_extract = "_tmp_extract_baseline"
    if os.path.isdir(tmp_extract):
        shutil.rmtree(tmp_extract)
    os.makedirs(tmp_extract, exist_ok=True)

    try:
        print(f"Extracting {BASELINE_ZIP} -> {tmp_extract} ...")
        shutil.unpack_archive(BASELINE_ZIP, tmp_extract)

        candidate = os.path.join(tmp_extract, "content", "t5-simplification")
        if os.path.isdir(candidate):
            shutil.move(candidate, BASELINE_DIR)
        else:
            # No 'content/...' prefix: the scratch directory already holds exactly what
            # extracting the zip into BASELINE_DIR would produce, so reuse it rather
            # than unpacking the archive a second time.
            shutil.move(tmp_extract, BASELINE_DIR)
    finally:
        # Never leave the scratch directory behind, even when extraction fails.
        shutil.rmtree(tmp_extract, ignore_errors=True)

# Ext3 is expected to extract cleanly, but keep it robust.
ensure_dir_from_zip(EXT3_ZIP, EXT3_DIR)

# Handle archives that wrap everything in one more folder named like the target dir.
maybe_flatten_nested_dir(BASELINE_DIR, BASELINE_DIR, required_markers=["config.json"])
maybe_flatten_nested_dir(EXT3_DIR, EXT3_DIR, required_markers=["adapter_config.json", "adapter.pt"])

print("Baseline dir ok:", os.path.isdir(BASELINE_DIR))
print("Ext3 dir ok:", os.path.isdir(EXT3_DIR))

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

In [None]:
# ---- Load baseline model (plain T5 checkpoint) ----

print("Loading baseline model...")
# Tokenizer and weights both come from the local checkpoint folder (no hub access).
baseline_tokenizer = AutoTokenizer.from_pretrained(
    BASELINE_DIR,
    local_files_only=True,
)
baseline_model = T5ForConditionalGeneration.from_pretrained(
    BASELINE_DIR,
    local_files_only=True,
)
# Move to the selected device, then switch to inference mode.
baseline_model.to(DEVICE)
baseline_model.eval()
print("Loaded baseline.")

In [None]:
# ---- Load Extension 3 model (T5 + encoder adapter) ----

class EncoderBottleneckAdapter(nn.Module):
    """Gated residual bottleneck applied to hidden states, followed by LayerNorm.

    Computes ``LayerNorm(h + tanh(gate) * up(dropout(gelu(down(h)))))``.

    Args:
        d_model: Size of the last dimension of the incoming hidden states.
        bottleneck: Width of the intermediate (down-projected) representation.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model: int, bottleneck: int = 256, dropout: float = 0.1):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state-dict keys and the order in which the layers are initialised.
        self.down = nn.Linear(in_features=d_model, out_features=bottleneck)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(p=dropout)
        self.up = nn.Linear(in_features=bottleneck, out_features=d_model)
        self.ln = nn.LayerNorm(normalized_shape=d_model)
        # Trainable scalar gate starting at 0: tanh(0) == 0, so the residual branch
        # contributes nothing at first and the output is just LayerNorm(h).
        self.gate = nn.Parameter(torch.tensor(0.0))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Return the adapted hidden states; same shape as `h`."""
        squeezed = self.down(h)
        activated = self.act(squeezed)
        delta = self.up(self.dropout(activated))
        gate_strength = torch.tanh(self.gate)
        return self.ln(h + gate_strength * delta)


@dataclass
class AdapterConfig:
    """Adapter hyperparameters recorded on `T5WithEncoderAdapter` as `adapter_cfg`."""

    # Bottleneck width passed to the encoder adapter.
    bottleneck: int
    # Dropout probability passed to the encoder adapter.
    dropout: float


class T5WithEncoderAdapter(nn.Module):
    """T5 seq2seq model whose final encoder hidden states pass through an adapter.

    The wrapped `T5ForConditionalGeneration` lives in `self.base`; its encoder is run
    explicitly, the adapter transforms `last_hidden_state`, and the result is handed
    to the base model as pre-computed `encoder_outputs`.

    Args:
        base_model_name: Local path of the T5 checkpoint (loaded with
            `local_files_only=True`).
        bottleneck: Bottleneck width of the adapter.
        dropout: Dropout probability of the adapter.
    """

    def __init__(self, base_model_name: str, bottleneck: int = 256, dropout: float = 0.1):
        super().__init__()
        config = AutoConfig.from_pretrained(base_model_name, local_files_only=True)
        self.base = T5ForConditionalGeneration.from_pretrained(
            base_model_name,
            config=config,
            local_files_only=True,
        )
        d_model = self.base.config.d_model
        self.adapter = EncoderBottleneckAdapter(d_model=d_model, bottleneck=bottleneck, dropout=dropout)
        self.adapter_cfg = AdapterConfig(bottleneck=bottleneck, dropout=dropout)

    def _encode_with_adapter(self, input_ids, attention_mask) -> BaseModelOutput:
        """Run the base encoder and return its outputs with the adapted last hidden state."""
        enc = self.base.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return BaseModelOutput(
            last_hidden_state=self.adapter(enc.last_hidden_state),
            hidden_states=enc.hidden_states,
            attentions=enc.attentions,
        )

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Forward pass through encoder, adapter and the base model's decoder.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Encoder attention mask, also forwarded to the base model.
            labels: Optional target ids forwarded to the base model.
            **kwargs: Extra keyword arguments forwarded to the base model unchanged.

        Returns:
            The base model's output object (`return_dict=True`).
        """
        encoder_outputs = self._encode_with_adapter(input_ids, attention_mask)
        return self.base(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
            **kwargs,
        )

    @torch.no_grad()
    def generate(self, input_ids=None, attention_mask=None, **gen_kwargs):
        """Generate with the base model on top of the adapted encoder outputs.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Encoder attention mask.
            **gen_kwargs: Forwarded to `self.base.generate`.

        Returns:
            Whatever `self.base.generate` returns.
        """
        encoder_outputs = self._encode_with_adapter(input_ids, attention_mask)
        return self.base.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            **gen_kwargs,
        )

    @classmethod
    def load(cls, output_dir: str):
        """Rebuild a model from a saved directory.

        Reads `output_dir/base` (T5 checkpoint), `output_dir/adapter_config.json`
        (keys `bottleneck` and `dropout`) and `output_dir/adapter.pt` (adapter
        state dict).

        Args:
            output_dir: Directory containing the three artifacts above.

        Returns:
            A `T5WithEncoderAdapter` with the adapter weights loaded (on CPU).
        """
        base_dir = os.path.join(output_dir, "base")
        # JSON is UTF-8 by specification; do not depend on the platform default encoding.
        with open(os.path.join(output_dir, "adapter_config.json"), "r", encoding="utf-8") as f:
            cfg = json.load(f)
        model = cls(base_model_name=base_dir, bottleneck=cfg["bottleneck"], dropout=cfg["dropout"])
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
        # Only load adapter.pt from a trusted source; consider weights_only=True if the
        # installed torch version supports it.
        sd = torch.load(os.path.join(output_dir, "adapter.pt"), map_location="cpu")
        model.adapter.load_state_dict(sd)
        return model


print("Loading Extension 3 model...")
# Tokenizer files sit at the top level of the Extension 3 folder; the model itself is
# rebuilt by T5WithEncoderAdapter.load from the same folder.
ext3_tokenizer = AutoTokenizer.from_pretrained(
    EXT3_DIR,
    local_files_only=True,
)
ext3_model = T5WithEncoderAdapter.load(EXT3_DIR)
# Move to the selected device, then switch to inference mode.
ext3_model.to(DEVICE)
ext3_model.eval()
print("Loaded Extension 3.")

In [None]:
# ---- Completion-ratio helpers + batched generation ----

# Fractions of each source sentence (measured in whitespace-separated tokens) that are
# given to the models as input; 1.0 means the full sentence.
COMPLETION_RATIOS = [0.25, 0.5, 0.75, 1.0]


def get_sentence_prefix(source: str, completion_ratio: float) -> str:
    """Return the leading part of `source` covering `completion_ratio` of its tokens.

    Tokens are whitespace-separated; the prefix is re-joined with single spaces. The
    token count is rounded down, but at least one token is kept when the sentence has
    any.

    Args:
        source: Sentence to truncate (converted with `str()` first).
        completion_ratio: Fraction of tokens to keep, e.g. 0.25 or 1.0.

    Returns:
        The prefix as a single-space-joined string ("" for a sentence without tokens).
    """
    tokens = str(source).split()
    # Binary floating point makes products such as 100 * 0.57 evaluate to
    # 56.99999999999999; the tiny offset keeps truncation from losing a whole token.
    prefix_length = max(1, int(len(tokens) * completion_ratio + 1e-9))
    return " ".join(tokens[:prefix_length])


def generate_predictions(
    model,
    tokenizer,
    sources: List[str],
    *,
    completion_ratio: float,
    batch_size: int = 16,
    max_length: int = 128,
    num_beams: int = 4,
) -> List[str]:
    """Simplify sentence prefixes in batches and return one output line per source.

    Each source is cut to its `completion_ratio` prefix, prompted with "simplify: ",
    tokenized (truncated to `max_length`), and decoded with beam search.

    Args:
        model: Object exposing `eval()` and `generate(input_ids=..., attention_mask=..., ...)`.
        tokenizer: Tokenizer matching `model`.
        sources: Full source sentences.
        completion_ratio: Fraction of each source sentence used as model input.
        batch_size: Number of sentences per generation call.
        max_length: Limit used for both the tokenized input and the generated output.
        num_beams: Beam width.

    Returns:
        Decoded predictions, newline-free and stripped, in the order of `sources`.
    """
    model.eval()

    progress_label = f"Generating ({int(completion_ratio*100)}%)"
    batch_starts = range(0, len(sources), batch_size)

    outputs: List[str] = []
    for start in tqdm(batch_starts, desc=progress_label):
        prompts = []
        for sentence in sources[start : start + batch_size]:
            prompts.append("simplify: " + get_sentence_prefix(sentence, completion_ratio))

        encoded = tokenizer(
            prompts,
            max_length=max_length,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        encoded = encoded.to(DEVICE)

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
            )

        # Scoring scripts read one prediction per line, so fold newlines into spaces.
        for decoded in tokenizer.batch_decode(generated_ids, skip_special_tokens=True):
            outputs.append(decoded.replace("\n", " ").strip())

    return outputs

In [None]:
# ---- Run both models on the test set (25/50/75/100% completion) ----

baseline_preds_by_ratio: Dict[float, List[str]] = {}
ext3_preds_by_ratio: Dict[float, List[str]] = {}

for r in COMPLETION_RATIOS:
    # Same generation settings for both systems; baseline runs first, then Extension 3.
    for _model, _tokenizer, _preds_store in (
        (baseline_model, baseline_tokenizer, baseline_preds_by_ratio),
        (ext3_model, ext3_tokenizer, ext3_preds_by_ratio),
    ):
        _preds_store[r] = generate_predictions(
            _model,
            _tokenizer,
            test_sources,
            completion_ratio=r,
            batch_size=16,
            max_length=128,
            num_beams=4,
        )

    # Every source sentence must have exactly one prediction per system.
    assert len(baseline_preds_by_ratio[r]) == len(test_sources)
    assert len(ext3_preds_by_ratio[r]) == len(test_sources)

    print(f"Sample baseline pred ({int(r*100)}%):", baseline_preds_by_ratio[r][0][:200])
    print(f"Sample ext3 pred ({int(r*100)}%):", ext3_preds_by_ratio[r][0][:200])

In [None]:
# ---- Write `output/` bundle (predictions + gold labels + README) ----

OUTPUT_DIR = "output"
BASELINE_OUT_DIR = os.path.join(OUTPUT_DIR, BASELINE_DIR)
EXT3_OUT_DIR = os.path.join(OUTPUT_DIR, EXT3_DIR)

os.makedirs(BASELINE_OUT_DIR, exist_ok=True)
os.makedirs(EXT3_OUT_DIR, exist_ok=True)


def _as_single_line(text: str) -> str:
    """Replace line breaks with spaces so `text` occupies exactly one line in a .txt file."""
    return text.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")


# Note: for SARI (and to match the repo's existing evaluation), `sources.txt` is ALWAYS the full source,
# even when the model input is a prefix. Only the model predictions vary by completion ratio.
sources_path = os.path.join(OUTPUT_DIR, "sources.txt")
refs_path = os.path.join(OUTPUT_DIR, "references.txt")

# The .txt files are aligned line-by-line with predictions.txt (whose entries are already
# newline-free), so a source/reference containing a line break must not spill onto extra lines.
with open(sources_path, "w", encoding="utf-8") as f:
    f.write("\n".join(_as_single_line(s) for s in test_sources) + "\n")

with open(refs_path, "w", encoding="utf-8") as f:
    f.write("\n".join(_as_single_line(s) for s in test_refs) + "\n")

print("Wrote:")
print("-", sources_path)
print("-", refs_path)

# Write per-ratio predictions + aligned TSVs
baseline_preds_paths: Dict[float, str] = {}
ext3_preds_paths: Dict[float, str] = {}


def _write_prediction_lines(path: str, lines: List[str]) -> None:
    """Write `lines` to `path` as UTF-8, one per line, with a trailing newline."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.write("\n".join(lines) + "\n")


for r in COMPLETION_RATIOS:
    pct = int(r * 100)

    # One sub-folder per completion percentage under each model's output folder.
    baseline_ratio_dir = os.path.join(BASELINE_OUT_DIR, f"{pct}")
    ext3_ratio_dir = os.path.join(EXT3_OUT_DIR, f"{pct}")
    for _ratio_dir in (baseline_ratio_dir, ext3_ratio_dir):
        os.makedirs(_ratio_dir, exist_ok=True)

    baseline_preds_path = os.path.join(baseline_ratio_dir, "predictions.txt")
    ext3_preds_path = os.path.join(ext3_ratio_dir, "predictions.txt")

    _write_prediction_lines(baseline_preds_path, baseline_preds_by_ratio[r])
    _write_prediction_lines(ext3_preds_path, ext3_preds_by_ratio[r])

    baseline_preds_paths[r] = baseline_preds_path
    ext3_preds_paths[r] = ext3_preds_path

    # Side-by-side table: full source, gold reference, and both systems' outputs.
    aligned_tsv = os.path.join(OUTPUT_DIR, f"aligned_{pct}.tsv")
    aligned_columns = {
        "source": test_sources,
        "reference": test_refs,
        "pred_t5_simplification": baseline_preds_by_ratio[r],
        "pred_t5_simplification_2": ext3_preds_by_ratio[r],
    }
    pd.DataFrame(aligned_columns).to_csv(aligned_tsv, sep="\t", index=False)

    for _written in (baseline_preds_path, ext3_preds_path, aligned_tsv):
        print(f"- {_written}")

In [None]:
# ---- Zip outputs for easy download ----

# Archive the contents of OUTPUT_DIR as "output.zip" in the working directory.
zip_path = shutil.make_archive(
    base_name="output",
    format="zip",
    root_dir=OUTPUT_DIR,
)
print("Created:", zip_path)