First, install required libraries.

In [20]:
# Install required packages for text summarization and evaluation
%pip install --upgrade pip

# Core libs for summarization + evaluation
%pip install transformers datasets evaluate sentencepiece

# Metrics
%pip install rouge-score bert-score sacrebleu

# Sentence Transformers for embeddings
%pip install --quiet sentence-transformers transformers

# Notebook QoL (progress bars/widgets)
%pip install ipywidgets

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


Load and run the model. 

In [21]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch, re, time

MODEL = "facebook/bart-large-cnn"
MAX_IN = 1024  # BART limit

tok = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
tok.model_max_length = MAX_IN

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL,
    torch_dtype=torch.float32
).to(device)
model.eval()

# --- sentence helpers so we can enforce 3–5 sentences ---
_SENT_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")

def keep_3_to_5_sentences(text, min_sents: int = 3, max_sents: int = 5) -> str:
    """Trim the model output so the final summary is ~3–5 sentences."""
    sents = [s.strip() for s in _SENT_SPLIT_RE.split(text.strip()) if s.strip()]
    if len(sents) <= max_sents:
        # If the model happened to produce fewer than 3 sentences, we just keep them.
        # Using higher min_new_tokens below makes that rare.
        return " ".join(sents)
    return " ".join(sents[:max_sents])

# --- core summarization helpers ---
def _summarize_enc(enc, max_new: int = 160, min_new: int = 80) -> str:
    with torch.inference_mode():
        out = model.generate(
            **enc,
            max_new_tokens=max_new,
            min_new_tokens=min_new,
            num_beams=4,
            length_penalty=1.0,
            early_stopping=True,
        )
    raw = tok.batch_decode(out, skip_special_tokens=True)[0].strip()
    return keep_3_to_5_sentences(raw)

def summarize_safe(text: str, max_new: int = 160, min_new: int = 80) -> str:
    enc = tok(
        text,
        truncation=True,
        max_length=MAX_IN,
        return_tensors="pt",
    ).to(device)
    return _summarize_enc(enc, max_new=max_new, min_new=min_new)

def chunk_text(text: str, target_chars: int = 4000):
    """Greedy sentence-based chunking to keep inputs under the model limit."""
    sentences = _SENT_SPLIT_RE.split(text)
    chunks, cur, cur_len = [], [], 0
    for s in sentences:
        s = s.strip()
        if not s:
            continue
        if cur_len + len(s) > target_chars and cur:
            chunks.append(" ".join(cur))
            cur, cur_len = [s], len(s)
        else:
            cur.append(s)
            cur_len += len(s)
    if cur:
        chunks.append(" ".join(cur))
    return chunks

def summarize_long(text: str) -> str:
    """Handle very long articles via chunking then a second-pass summary."""
    chunks = chunk_text(text, target_chars=4000)
    partial_summaries = [summarize_safe(c, max_new=120, min_new=60) for c in chunks]
    stitched = " ".join(partial_summaries)
    # Final pass produces a global 3–5 sentence summary
    return summarize_safe(stitched, max_new=160, min_new=80)

# === load eval slice ===
ds = load_dataset("cnn_dailymail", "3.0.0", split="validation[:200]")

# === run summaries ===
preds, refs = [], []
t0 = time.time()
for ex in ds:
    article = ex["article"]
    if len(article) < 12000:
        pred = summarize_safe(article)
    else:
        pred = summarize_long(article)  # rough char cutoff for "long"
    preds.append(pred)
    refs.append(ex["highlights"])
t1 = time.time()
print(f"Summarized {len(preds)} docs in {t1 - t0:.1f}s")


Summarized 200 docs in 272.9s


Evaluate Model Results.

In [22]:
import evaluate
import numpy as np, pandas as pd, re, torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset

# --- Sanity checks ---
assert 'preds' in globals() and 'refs' in globals(), "Run the summarization cell first."
N = len(preds)
assert len(refs) == N, "preds and refs must be the same length."

# --- 1) ROUGE + BERTScore on preds vs reference summaries ---
rouge = evaluate.load("rouge")
berts = evaluate.load("bertscore")

rouge_res = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
bert_res  = berts.compute(predictions=preds, references=refs, lang="en")

print("=== Overlap metrics ===")
print("ROUGE-Lsum F1:", round(rouge_res["rougeLsum"], 4))
print("BERTScore F1:", round(sum(bert_res["f1"]) / N, 4))

# --- 2) Prepare source articles for faithfulness metrics ---
# Recreate the same slice of CNN/DailyMail used for summarization
ds = load_dataset("cnn_dailymail", "3.0.0", split=f"validation[:{N}]")
refs_src = [ex["article"] for ex in ds]
assert len(refs_src) == N

# --- 3) Faithfulness metrics: sentence-level entailment/contradiction ---
embed = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
nli_tok = AutoTokenizer.from_pretrained("roberta-large-mnli")
nli = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli"
).eval().to("cuda:0" if torch.cuda.is_available() else "cpu")

def split_sents(t: str):
    return [s for s in re.split(r'(?<=[.!?])\s+', t.strip()) if s.strip()]

@torch.inference_mode()
def nli_probs(premise: str, hyp: str):
    enc = nli_tok(premise, hyp, return_tensors="pt",
                  truncation=True, max_length=512).to(nli.device)
    # logits order: [contradiction, neutral, entailment]
    p = nli(**enc).logits.softmax(-1)[0].tolist()
    return p[2], p[0]  # entailment, contradiction

def faithfulness_metrics(summary_text, source_text,
                         top_k=3, ent_th=0.5, con_th=0.3):
    sum_sents = split_sents(summary_text)
    src_sents = split_sents(source_text)
    if not sum_sents or not src_sents:
        return dict(ent_mean=0.0, con_mean=0.0, sent_count=0,
                    ent_pass=0.0, con_pass=0.0, both_pass=0.0)

    # Embed sentences
    sum_emb = embed.encode(sum_sents, convert_to_tensor=True)
    src_emb = embed.encode(src_sents, convert_to_tensor=True)

    # For each summary sentence, find top-k closest source sentences
    hits = []
    for i, s_vec in enumerate(sum_emb):
        cos_scores = util.cos_sim(s_vec, src_emb)[0]
        top_idx = torch.topk(cos_scores, k=min(top_k, len(src_sents))).indices.tolist()
        best_src = " ".join(src_sents[j] for j in top_idx)

        ent_p, con_p = nli_probs(best_src, sum_sents[i])
        hits.append((ent_p, con_p))

    ents = np.array([h[0] for h in hits])
    cons = np.array([h[1] for h in hits])

    ent_mean = float(ents.mean()) if len(ents) else 0.0
    con_mean = float(cons.mean()) if len(cons) else 0.0
    ent_pass = float((ents >= ent_th).mean()) if len(ents) else 0.0
    con_pass = float((cons <= con_th).mean()) if len(cons) else 0.0
    both_pass = float(((ents >= ent_th) & (cons <= con_th)).mean()) if len(ents) else 0.0

    return dict(
        ent_mean=ent_mean,
        con_mean=con_mean,
        sent_count=len(sum_sents),
        ent_pass=ent_pass,
        con_pass=con_pass,
        both_pass=both_pass,
    )

rows = []
for i, (pred, src) in enumerate(zip(preds, refs_src)):
    m = faithfulness_metrics(pred, src, top_k=3, ent_th=0.5, con_th=0.3)
    rows.append(dict(
        idx=i,
        **m,
        pred_len=len(pred.split()),
        src_len=len(src.split())
    ))

df = pd.DataFrame(rows)

print("\n=== Faithfulness metrics (summary vs source article) ===")
print("Entailment mean", df.ent_mean.mean().round(4))
print("Contradiction mean", df.con_mean.mean().round(4))
print("Sentences passing entailment", df.ent_pass.mean().round(4))
print("Sentences passing contradiction", df.con_pass.mean().round(4))
print("Sentences passing both", df.both_pass.mean().round(4))


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


=== Overlap metrics ===
ROUGE-Lsum F1: 0.2735
BERTScore F1: 0.8686


Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



=== Faithfulness metrics (summary vs source article) ===
Entailment mean 0.7895
Contradiction mean 0.0309
Sentences passing entailment 0.8045
Sentences passing contradiction 0.9694
Sentences passing both 0.8
