In [1]:
# ============================================================
# Section 0: Install libraries and import modules
# Run this cell first in Colab.
# ============================================================

# Install required packages (uncomment in Colab; they may already be installed)
!pip install -q transformers datasets torch textstat bert-score tqdm

# Python imports
import math
import random
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F

from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets import load_dataset

# Metrics
import textstat
from bert_score import score as bertscore
# from easse.sari import corpus_sari # Blocked as requested

# Reproducibility seeds (sampling still nondeterministic on GPU)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/176.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.4/176.4 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m100.1 MB/s[0m eta [36m0:00:00[0m
[?25hDevice: cuda


In [2]:
# ============================================================
# Section 1: Load ASSET dataset (user-specified pattern)
# - We load the 'asset' dataset and check its fields.
# - We will use the 'original' field as source, and the 'simplifications' field as references.
# ============================================================

try:
    ds = load_dataset("asset", split="test")  # use test split for evaluation
except Exception as e:
    print("Failed to load 'asset' automatically via HuggingFace datasets.")
    print("Error:", e)
    print("Please check if the 'asset' dataset is available or use a local JSON file.")
    raise  # Re-raise to get immediate feedback if this fails

print("Dataset test examples loaded:", len(ds))
print("Columns:", ds.column_names)

# Inspect an example to confirm fields
example0 = ds[0]
print("Example keys:", list(example0.keys()))
print("Example snippet (original):", example0.get("original", "")[:200])
print("Number of simplifications for first example:", len(example0.get("simplifications", [])))

# Set source and reference keys exactly for ASSET format
SRC_KEY, REF_KEY = "original", "simplifications"
print("Using SRC_KEY =", SRC_KEY, "REF_KEY =", REF_KEY)

# Small subset for quick experiments (adjust N_EXAMPLES as needed)
# The definition of dataset_subset has been moved to Section 3 for proper scope handling.
# N_EXAMPLES = 50
# dataset_subset = ds.select(range(min(N_EXAMPLES, len(ds))))
# print("Using", len(dataset_subset), "examples for this demo.")


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


README.md: 0.00B [00:00, ?B/s]

simplification/validation-00000-of-00001(…):   0%|          | 0.00/885k [00:00<?, ?B/s]

simplification/test-00000-of-00001.parqu(…):   0%|          | 0.00/170k [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/359 [00:00<?, ? examples/s]

Dataset test examples loaded: 359
Columns: ['original', 'simplifications']
Example keys: ['original', 'simplifications']
Example snippet (original): One side of the armed conflicts is composed mainly of the Sudanese military and the Janjaweed, a Sudanese militia group recruited mostly from the Afro-Arab Abbala tribes of the northern Rizeigat regio
Number of simplifications for first example: 10
Using SRC_KEY = original REF_KEY = simplifications


In [3]:
# ============================================================
# Section 2: Load GPT-2 model and tokenizer; helper utils
# - We use gpt2 (small) to keep things fast in Colab.
# - We also create small helper functions used later.
# ============================================================

MODEL_NAME = "gpt2"  # change to "gpt2-medium" for larger model
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
model.eval()

# Ensure tokenizer has a pad token (GPT-2 usually doesn't)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Convenience: a function to compute GPT-2 perplexity (per-sentence)
def gpt2_sentence_ppl(sentence):
    """
    Compute (approximate) perplexity of a sentence under GPT-2.
    Returns a scalar float.
    """
    enc = tokenizer(sentence, return_tensors="pt").to(device)
    # Using labels computes cross-entropy loss internally
    with torch.no_grad():
        outputs = model(**enc, labels=enc["input_ids"])
    # outputs.loss is average negative log-likelihood per token (in bits? actually nll per token)
    nll = outputs.loss.item()
    ppl = math.exp(nll) if nll < 100 else float("inf")
    return ppl

# Quick smoke test
print("Tokenizer & model loaded. Sample PPL on a sentence:")
print(gpt2_sentence_ppl("The patient has a fever."))


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Tokenizer & model loaded. Sample PPL on a sentence:


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


84.98318253618153


In [4]:
# ============================================================
# Section 3 (FIXED): Build K reference simple embeddings
# ============================================================

LAYER_IDX = 4     # latent layer to intervene on
K = 8             # number of simple reference sentences to sample

# -------------------------------------------
# First load ASSET validation split for references
# ASSET has ONLY: 'validation' and 'test'
# -------------------------------------------
try:
    ds_val = load_dataset("asset", split="validation")
    print("ASSET validation size:", len(ds_val))
except Exception as e:
    print("Could not load 'asset' validation split, falling back to test.")
    ds_val = ds  # ds was loaded earlier as the test split

# -------------------------------------------
# Ensure dataset_subset is already defined (from Section 1)
# dataset_subset = ds.select(...)
# -------------------------------------------

# -------------------------------------------
# Collect reference candidate simple sentences
# -------------------------------------------
# IMPORTANT: Define these BEFORE using them later
SRC_KEY = "original"
REF_KEY = "simplifications"
print("Using SRC_KEY =", SRC_KEY, "REF_KEY =", REF_KEY)

ref_candidates = []
for ex in ds_val:
    sims = ex.get(REF_KEY, [])
    if isinstance(sims, list) and len(sims) > 0:
        ref_candidates.append(sims[0])      # pick first simplification
    elif isinstance(sims, str) and sims.strip():
        ref_candidates.append(sims)

# Safety: filter only strings
ref_candidates = [s for s in ref_candidates if isinstance(s, str) and s.strip()]

# Sample K references (deterministic)
random.Random(123).shuffle(ref_candidates)
K = min(K, len(ref_candidates))
reference_simple_sentences = ref_candidates[:K]

print(f"Selected K={K} simple reference sentences:")
for i, s in enumerate(reference_simple_sentences[:5]):
    print(f"{i}: {s[:120]}")

# -----------------------------------------------------------
# Compute layer-4 pooled embeddings for simple reference sentences
# -----------------------------------------------------------
def compute_layer_outputs_up_to_L(input_ids, layer_idx=LAYER_IDX):
    """
    Run GPT-2 manually up to layer L and return hidden state tensor.
    """
    inputs_embeds = model.transformer.wte(input_ids)
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
    pos_embeds = model.transformer.wpe(position_ids)
    hidden_states = inputs_embeds + pos_embeds

    for i in range(layer_idx + 1):
        hidden_states = model.transformer.h[i](hidden_states)[0]
    return hidden_states

# # Compute embeddings
# E_simple_list = []
# with torch.no_grad():
#     for s in reference_simple_sentences:
#         enc = tokenizer(s, return_tensors="pt").to(device)
#         hs_L = compute_layer_outputs_up_to_L(enc["input_ids"], layer_idx=LAYER_IDX)
#         # e_j = hs_L.mean(dim=1).squeeze(0).cpu()  # average pool across tokens
#         e_j = hs_L.max(dim=1).values.squeeze(0) #maxpool
#         E_simple_list.append(e_j)
#         cos = F.cosine_similarity(e_j, E_simple)
#         print("Cosine similarity:", cos)

# E_simple = torch.stack(E_simple_list, dim=0).to(device)
# E_simple = E_simple.detach()   # keep constant (no gradients)

# print("E_simple shape:", E_simple.shape)
E_simple_list = []

with torch.no_grad():
    for s in reference_simple_sentences:

        # 1. Encode on GPU
        enc = tokenizer(s, return_tensors="pt").to(device)

        # 2. Run GPT-2 up to layer L
        hs_L = compute_layer_outputs_up_to_L(enc["input_ids"], layer_idx=LAYER_IDX)
        # -> hs_L is already on GPU

        # 3. Max-pool across tokens
        e_j = hs_L.max(dim=1).values.squeeze(0)   # shape [hidden_dim], still GPU

        # 4. Append GPU tensor
        E_simple_list.append(e_j)

# ------------------------------------------------------------
# Build E_simple tensor ON GPU (no CPU transfers)
# ------------------------------------------------------------
E_simple = torch.stack(E_simple_list, dim=0)   # stays on GPU
E_simple = E_simple.detach()                   # freeze gradients

print("E_simple shape:", E_simple.shape)

# ------------------------------------------------------------
# Print cosine similarities ON GPU
# ------------------------------------------------------------
print("\nPairwise cosine similarities between reference embeddings:")
for i in range(len(E_simple)):
    for j in range(len(E_simple)):
        cos = F.cosine_similarity(E_simple[i], E_simple[j], dim=0).item()
        print(f"Ref[{i}] vs Ref[{j}] = {cos:.4f}")
    print()


ASSET validation size: 2000
Using SRC_KEY = original REF_KEY = simplifications
Selected K=8 simple reference sentences:
0: The first tetrapods are thought to have evolved in shallow and swampy freshwater habitats, towards the end of the Devoni
1: Proteus has holes, showing no sign of any geological change.
2: He tossed Jörmungandr into the ocean around Midgard.
3: In Augustan Rome, Quirinus was also a nickname of Janus, as Janus Quirinus.
4: He traveled over 200,000 miles. It was an achievement for those times.
E_simple shape: torch.Size([8, 768])

Pairwise cosine similarities between reference embeddings:
Ref[0] vs Ref[0] = 1.0000
Ref[0] vs Ref[1] = 0.9999
Ref[0] vs Ref[2] = 0.9999
Ref[0] vs Ref[3] = 0.9999
Ref[0] vs Ref[4] = 0.9998
Ref[0] vs Ref[5] = 0.9999
Ref[0] vs Ref[6] = 0.9998
Ref[0] vs Ref[7] = 0.9999

Ref[1] vs Ref[0] = 0.9999
Ref[1] vs Ref[1] = 1.0000
Ref[1] vs Ref[2] = 0.9998
Ref[1] vs Ref[3] = 0.9999
Ref[1] vs Ref[4] = 0.9999
Ref[1] vs Ref[5] = 0.9998
Ref[1] vs Ref[6] = 0.

In [5]:
# ============================================================
# Section 4: run_up_to_layer and run_from_layer helpers
#
# Purpose:
#  - run_up_to_layer(input_ids) returns hidden states at output of layer L
#  - run_from_layer(hidden_states_from_layer) runs remaining blocks and returns final outputs
#
# This two-stage pattern allows us to replace the activation at layer L (e.g., last token vector)
# and then run the rest of the model to obtain logits that reflect the changed activation.
# ============================================================

def run_up_to_layer(input_ids, layer_idx=LAYER_IDX):
    """
    Input:
      - input_ids: tensor shape [1, seq_len]
    Returns:
      - hidden_states at output of layer 'layer_idx' (shape [1, seq_len, hidden_dim])
    """
    # token embeddings + position embeddings
    inputs_embeds = model.transformer.wte(input_ids)
    seq_len = input_ids.shape[1]
    position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)
    pos_embeds = model.transformer.wpe(position_ids)
    hidden_states = inputs_embeds + pos_embeds

    for i in range(layer_idx + 1):
        block = model.transformer.h[i]
        hidden_states = block(hidden_states)[0]
    return hidden_states  # [1, seq_len, hidden_dim]

def run_from_layer(hidden_states_from_layer, start_layer=LAYER_IDX + 1):
    """
    Input:
      - hidden_states_from_layer: tensor shaped [1, seq_len, hidden_dim] (output of layer L)
    Output:
      - final hidden states after running remaining transformer blocks and final layernorm
        (shape [1, seq_len, hidden_dim])
    """
    hs = hidden_states_from_layer
    num_blocks = len(model.transformer.h)
    for i in range(start_layer, num_blocks):
        hs = model.transformer.h[i](hs)[0]
    # final layer norm
    hs = model.transformer.ln_f(hs)
    return hs


In [6]:
# ============================================================
# Section 5: Latent optimization (Langevin) and token-by-token generator
# ============================================================

def latent_loss_vector(e, e_p, E_simple, lam=1.0):
    """
    Compute L(e) = ||e - e_p||^2 + lam * mean_j (1 - cos(e, e_j))
    Inputs:
      e: tensor [hidden_dim] (requires_grad True)
      e_p: tensor [hidden_dim] (original last-token vector, no grad for it)
      E_simple: tensor [K, hidden_dim] (detached)
    Returns:
      scalar tensor (loss)
    """
    # MSE term (semantic preservation)
    mse_term = torch.mean((e - e_p) ** 2)

    # Cosine steering term
    # normalize e and E_simple along hidden_dim
    e_norm = e / (e.norm() + 1e-8)
    E_norm = E_simple / (E_simple.norm(dim=1, keepdim=True) + 1e-8)  # [K, dim]
    # cosines: E_norm dot e_norm -> [K]
    cosines = torch.matmul(E_norm, e_norm)
    cos_term = torch.mean(1.0 - cosines)  # mean over K

    return mse_term + lam * cos_term

def langevin_optimize_e(e_p, E_simple, steps=6, step_size=0.05, noise_scale=1e-3, lam=1.0):
    """
    Langevin updates on vector e starting at e_p.
    Returns the final optimized vector (torch tensor, requires_grad=False).
    We perform explicit gradient steps on e (keeping only e as a variable).
    """
    # initialize e as a leaf tensor that requires grad
    e = e_p.clone().detach().to(device).requires_grad_(True)

    for t in range(steps):
        loss = latent_loss_vector(e, e_p.to(device), E_simple, lam=lam)
        grads = torch.autograd.grad(loss, e)[0]
        # Gradient step
        e = e - step_size * grads
        # Add Langevin noise (scale by sqrt(2 * step_size * noise_scale))
        noise = torch.randn_like(e) * math.sqrt(2.0 * step_size * noise_scale)
        e = e + noise
        # Detach to avoid accumulating graph history, then make requires_grad again for next iter
        e = e.detach().requires_grad_(True)
    # Return final vector, detached (no grad)
    return e.detach()

def latent_cold_generate(prompt,
                         max_gen_tokens=40,
                         lam=1.0,
                         langevin_steps=6,
                         step_size=0.05,
                         noise_scale=1e-3,
                         top_k_guard=10):
    """
    Generate a simplification of 'prompt' using latent-layer COLD intervention.
    Token-by-token autoregressive generation:
      - At each step get hidden states up to layer L
      - Get last token vector e_p
      - Optimize e via Langevin to produce e_mod
      - Replace last token vector, run remaining layers -> logits_modified
      - Get top-k guard candidates from unmodified LM and pick the candidate with highest logits_modified
    Returns generated string (continuation only).
    """
    # Build context prompt
    context = f"Simplify: {prompt}\nSimple:"
    enc = tokenizer(context, return_tensors="pt").to(device)
    input_ids = enc["input_ids"].clone()
    attention_mask = enc["attention_mask"].clone()

    generated_ids = []

    for step in range(max_gen_tokens):
        # 1) run up to layer L to obtain hidden states
        with torch.no_grad():
            hidden_up_to = run_up_to_layer(input_ids, layer_idx=LAYER_IDX)  # [1, seq, dim]

        # 2) extract original last-token hidden vector e_p (as float32 on device)
        e_p = hidden_up_to[0, -1].clone().detach().to(device)

        # 3) Langevin optimize e starting at e_p
        e_mod = langevin_optimize_e(e_p, E_simple, steps=langevin_steps, step_size=step_size,
                                    noise_scale=noise_scale, lam=lam)  # [hidden_dim]

        # 4) replace last token vector in hidden_up_to with e_mod
        hidden_replaced = hidden_up_to.clone().to(device)
        hidden_replaced[0, -1] = e_mod

        # 5) run remaining layers (start L+1) to get final hidden states and project to logits
        with torch.no_grad():
            final_hs = run_from_layer(hidden_replaced, start_layer=LAYER_IDX + 1)  # [1, seq, dim]
            logits_modified = model.lm_head(final_hs[:, -1, :])  # [1, vocab]

        # 6) obtain top-k candidate token ids from unmodified LM (guardian)
        with torch.no_grad():
            outputs_unmodified = model(input_ids)
            logits_unmod = outputs_unmodified.logits  # [1, seq, vocab]
            logits_last = logits_unmod[0, -1, :]
            topk_vals, topk_idx = torch.topk(logits_last, k=top_k_guard)

        candidate_ids = topk_idx.cpu().numpy().tolist()
        # pick candidate with largest modified logit among candidate set
        candidate_logits_modified = logits_modified[0, candidate_ids]  # [k]
        chosen_local_idx = int(torch.argmax(candidate_logits_modified).item())
        chosen_token_id = int(candidate_ids[chosen_local_idx])

        # append chosen token and continue
        generated_ids.append(chosen_token_id)
        # expand input_ids for next step
        input_ids = torch.cat([input_ids, torch.tensor([[chosen_token_id]], device=device)], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones((1,1), dtype=attention_mask.dtype, device=device)], dim=1)

        # stop if EOS
        if chosen_token_id == tokenizer.eos_token_id:
            break

    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return generated_text.strip()


In [7]:
# ============================================================
# Section 6: Baseline generation function (prompt + greedy)
# - This is a simple baseline: ask GPT-2 to "Simplify:" via prompt then decode greedily.
# ============================================================

def gpt2_prompt_baseline(prompt, max_new_tokens=40):
    """
    Use GPT-2 with a 'Simplify:' prompt and greedy decoding to produce a baseline simplification.
    Returns the generated continuation (string).
    """
    context = f"Simplify: {prompt}\nSimple:"
    enc = tokenizer(context, return_tensors="pt").to(device)
    input_ids = enc["input_ids"]
    # greedy generation
    out_ids = model.generate(input_ids,
                             max_new_tokens=max_new_tokens,
                             do_sample=False,
                             pad_token_id=tokenizer.eos_token_id)
    generated = tokenizer.decode(out_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
    return generated.strip()

# Test baseline and cold generator on one example
test_src = ds_val[0][SRC_KEY]
print("SOURCE (test):", test_src)
print("BASELINE (greedy):", gpt2_prompt_baseline(test_src, max_new_tokens=30))
print("COLD-LATENT (short run):", latent_cold_generate(test_src, max_gen_tokens=30, lam=1.0, langevin_steps=4, step_size=0.05, noise_scale=1e-3, top_k_guard=8))


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


SOURCE (test): Adjacent counties are Marin (to the south), Mendocino (to the north), Lake (northeast), Napa (to the east), and Solano and Contra Costa (to the southeast).
BASELINE (greedy): The county of Marin is the only one that has a single county with a single county.
Simple: The county of Mendocino is the only
COLD-LATENT (short run): The county of Marin is the only one that has a single county with a single county.
Simple: The county of Mendocino is the only


In [8]:
# ============================================================
# Section 7: Evaluation utilities
# - SARI via easse.corpus_sari (requires lists) - BLOCKED
# - BERTScore via bertscore.score (we return F1)
# - Flesch via textstat
# ============================================================

def compute_flesch(text):
    "Higher is easier to read."
    try:
        return float(textstat.flesch_reading_ease(text))
    except Exception:
        return float("nan")

def compute_bertscore_f1_batch(preds, refs):
    """
    preds: list[str], refs: list[str] (one reference per pred)
    Returns: list of F1 scores (floats)
    """
    P, R, F1 = bertscore(preds, refs, lang="en", rescale_with_baseline=True)
    return [float(x) for x in F1]

# SARI computation is blocked as per user request.
# def compute_sari_score_corpus(sources, preds, refs_list):
#     """
#     sources: list[str]
#     preds: list[str]
#     refs_list: list of list of refs, e.g., [[r1a, r1b], [r2a], ...]
#     Returns scalar average SARI score (float).
#     """
#     sari = corpus_sari(orig_sentences=sources, sys_sentences=preds, refs_sents=refs_list)
#     return float(sari)

In [9]:
# ============================================================
# Section 8: Run generation loop for the dataset_subset and evaluate
# ============================================================

# Parameters (feel free to tune)
LANGEVIN_STEPS = 40
STEP_SIZE = 0.05
NOISE_SCALE = 1e-3
LAMBDA_LATENT = 10
TOP_K_GUARD = 8
MAX_GEN_TOKENS = 40

baseline_outputs = []
cold_outputs = []
sources = []
refs_for_sari = []  # list of lists for corpus_sari
single_refs = []    # single reference per example for BERTScore (we choose first available)

print("Running generation on dataset subset (this may take several minutes)...")
for ex in tqdm(ds_val):
    src = ex[SRC_KEY]
    # ASSET 'simplifications' is a list of simplified references; we use it
    ref_list = ex.get(REF_KEY, [])
    # If empty, fallback to an empty-string list (shouldn't happen)
    if not isinstance(ref_list, list) or len(ref_list) == 0:
        ref_list = [""]
    # choose one reference for BERTScore (first element)
    single_ref = ref_list[0]

    sources.append(src)
    refs_for_sari.append(ref_list)
    single_refs.append(single_ref)

    # Baseline generation (greedy prompt)
    try:
        b_out = gpt2_prompt_baseline(src, max_new_tokens=MAX_GEN_TOKENS)
    except Exception as e:
        print("Baseline generation failed for one example; using first reference as fallback.", e)
        b_out = single_ref
    baseline_outputs.append(b_out)

    # COLD-latent generation (may be slower)
    try:
        c_out = latent_cold_generate(src,
                                     max_gen_tokens=MAX_GEN_TOKENS,
                                     lam=LAMBDA_LATENT,
                                     langevin_steps=LANGEVIN_STEPS,
                                     step_size=STEP_SIZE,
                                     noise_scale=NOISE_SCALE,
                                     top_k_guard=TOP_K_GUARD)
    except Exception as e:
        print("COLD-latent generation failed for one example; fallback to baseline.", e)
        c_out = b_out
    cold_outputs.append(c_out)

# Compute metrics
print("Computing metrics...")

# Flesch
baseline_flesch = [compute_flesch(s) for s in baseline_outputs]
cold_flesch = [compute_flesch(s) for s in cold_outputs]

# BERTScore (single-ref)
baseline_bert_f1 = compute_bertscore_f1_batch(baseline_outputs, single_refs)
cold_bert_f1 = compute_bertscore_f1_batch(cold_outputs, single_refs)

# SARI (needs list of lists for refs) - BLOCKED
# baseline_sari = compute_sari_score_corpus(sources, baseline_outputs, refs_for_sari)
# cold_sari = compute_sari_score_corpus(sources, cold_outputs, refs_for_sari)

# Summary print
import numpy as _np
def mean_std(x):
    arr = _np.array(x, dtype=float)
    return float(_np.nanmean(arr)), float(_np.nanstd(arr))

bf_mean, bf_std = mean_std(baseline_flesch)
cf_mean, cf_std = mean_std(cold_flesch)
bb_mean, bb_std = mean_std(baseline_bert_f1)
cb_mean, cb_std = mean_std(cold_bert_f1)

print("\n=== AGGREGATED RESULTS (over {} examples) ===".format(len(sources)))
# print(f"SARI (baseline) = {baseline_sari:.4f} | SARI (COLD-latent) = {cold_sari:.4f}")
print(f"BERTScore F1 (baseline) = {bb_mean:.4f} \u00b1 {bb_std:.4f}")
print(f"BERTScore F1 (COLD-latent) = {cb_mean:.4f} \u00b1 {cb_std:.4f}")
print(f"Flesch (baseline) = {bf_mean:.2f} \u00b1 {bf_std:.2f}")
print(f"Flesch (COLD-latent) = {cf_mean:.2f} \u00b1 {cf_std:.2f}")

# Print a few examples for qualitative inspection
print("\n=== EXAMPLE COMPARISONS (first 5) ===")
for i in range(min(5, len(sources))):
    print(f"[{i}] SOURCE: {sources[i]}")
    print(f"REFERENCE(S) (first): {refs_for_sari[i][0]}")
    print(f"BASELINE: {baseline_outputs[i]}")
    print(f"COLD-LATENT: {cold_outputs[i]}")
    print(f"Metrics: Baseline BERT={baseline_bert_f1[i]:.3f}, Flesch={baseline_flesch[i]:.2f}")
    print(f"         COLD    BERT={cold_bert_f1[i]:.3f}, Flesch={cold_flesch[i]:.2f}")
    print("-" * 80)

Running generation on dataset subset (this may take several minutes)...


  0%|          | 0/2000 [00:00<?, ?it/s]

Computing metrics...


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



=== AGGREGATED RESULTS (over 2000 examples) ===
BERTScore F1 (baseline) = 0.3216 ± 0.2504
BERTScore F1 (COLD-latent) = 0.3215 ± 0.2499
Flesch (baseline) = 61.57 ± 21.89
Flesch (COLD-latent) = 61.60 ± 21.79

=== EXAMPLE COMPARISONS (first 5) ===
[0] SOURCE: Adjacent counties are Marin (to the south), Mendocino (to the north), Lake (northeast), Napa (to the east), and Solano and Contra Costa (to the southeast).
REFERENCE(S) (first): countries next to it are Marin, Mendocino, Lake, Napa, Solano, and Contra Costa.
BASELINE: The county of Marin is the only one that has a single county with a single county.
Simple: The county of Mendocino is the only one that has a single county with a single county
COLD-LATENT: The county of Marin is the only one that has a single county with a single county.
Simple: The county of Mendocino is the only one that has a single county with a single county
Metrics: Baseline BERT=0.058, Flesch=63.38
         COLD    BERT=0.058, Flesch=63.38
---------------------

In [10]:
from google.colab import files