In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.amp import autocast
from datasets import load_dataset, Dataset
from tqdm import tqdm
from transformers import logging
import json
from math import ceil
import numpy as np
import evaluate
# `logging` here is transformers' logging module (imported above), not the stdlib one;
# restrict its output to error-level messages to keep cell output readable.
logging.set_verbosity_error()

In [2]:
# --- Config ---
# Checkpoint used for both the tokenizer and the seq2seq model.
MODEL_NAME = "google/bigbird-pegasus-large-arxiv"

# Batching and length limits.
BATCH_SIZE = 36        # examples per generation batch
MAX_INPUT_LEN = 2000   # tokenizer truncation length for the input text
MAX_OUTPUT_LEN = 600   # `max_length` passed to generate()

# Record fields.
TEXT_FIELD = "text"  # adjust if your input field is different
ID_FIELD = "title"   # key used to match records across files

# Files.
OUTPUT_FILE = "billsum_test_mmr_bigbird_pred.jsonl"  # predictions are appended here
MMR_FILE = "billsum_test_mmr.jsonl"                  # MMR records read as model input

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [3]:
# --- Load Model and Tokenizer ---
# Tokenizer and weights both come from the MODEL_NAME checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)
# Wrap the module with torch.compile (same mode/fullgraph settings as before).
model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

# Switch to evaluation mode; as the cell's last expression this also displays
# the module structure.
model.eval()

OptimizedModule(
  (_orig_mod): BigBirdPegasusForConditionalGeneration(
    (model): BigBirdPegasusModel(
      (shared): BigBirdPegasusScaledWordEmbedding(96103, 1024, padding_idx=0)
      (encoder): BigBirdPegasusEncoder(
        (embed_tokens): BigBirdPegasusScaledWordEmbedding(96103, 1024, padding_idx=0)
        (embed_positions): BigBirdPegasusLearnedPositionalEmbedding(4096, 1024)
        (layers): ModuleList(
          (0-15): 16 x BigBirdPegasusEncoderLayer(
            (self_attn): BigBirdPegasusEncoderAttention(
              (self): BigBirdPegasusBlockSparseAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=False)
                (key): Linear(in_features=1024, out_features=1024, bias=False)
                (value): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (output): Linear(in_features=1024, out_features=1024, bias=False)
            )
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-0

In [5]:
# --- Load Dataset ---
billsum_test = load_dataset("json", data_files="billsum_data/us_test_data_final_OFFICIAL.jsonl")["train"]

# Track already completed summaries so a re-run only processes what is missing
# from OUTPUT_FILE.
completed_ids = set()
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, "r") as f:
        for line in f:
            try:
                record = json.loads(line)
                # Store ids as str: the filter below compares str(x[ID_FIELD]),
                # so a non-string id stored raw would never match.
                completed_ids.add(str(record[ID_FIELD]))
            except (json.JSONDecodeError, KeyError, TypeError):
                # Best effort: skip lines that are not valid JSON objects or
                # lack the id field, but let KeyboardInterrupt and other
                # unexpected errors propagate.
                continue

# Filter out completed records
print(len(completed_ids), "records already summarized")
print("Dataset length before filter:", len(billsum_test))
billsum_test = billsum_test.filter(lambda x: str(x[ID_FIELD]) not in completed_ids)
print("Dataset length after filter:", len(billsum_test))

# Exit early if everything is done
if len(billsum_test) == 0:
    print("All entries have already been summarized.")
    # NOTE(review): inside a Jupyter kernel exit() may shut the kernel down
    # rather than just stopping this cell -- confirm this is intended.
    exit()

2664 records already summarized
Dataset length before filter: 3269


Filter:   0%|          | 0/3269 [00:00<?, ? examples/s]

Dataset length after filter: 605


In [6]:
# Load MMR summaries: one JSON record per line of MMR_FILE.
# If the file does not exist, `mmr` stays empty.
mmr = []
if os.path.exists(MMR_FILE):
    with open(MMR_FILE, "r") as f:
        for line in f:
            try:
                mmr.append(json.loads(line))
            except json.JSONDecodeError:
                # Skip lines that are not valid JSON; anything else
                # (e.g. KeyboardInterrupt) is allowed to propagate.
                continue

# Displayed as the cell output: number of MMR records loaded.
len(mmr)

3269

In [7]:
# Pair each remaining test record with the MMR record that shares its ID_FIELD:
# the model input text comes from the MMR record, the reference summary from
# the test record. Test records with no matching MMR record are dropped.

# Index the MMR records once so every lookup is a dict access instead of a
# full scan of `mmr`. setdefault keeps the FIRST record seen for each id,
# matching the previous first-match-then-break behaviour.
mmr_by_id = {}
for record in mmr:
    mmr_by_id.setdefault(record[ID_FIELD], record)

data_dict = {
    ID_FIELD: [],
    TEXT_FIELD: [],
    "summary": []
}
for data in tqdm(billsum_test):
    record = mmr_by_id.get(data[ID_FIELD])
    if record is None:
        continue
    data_dict[ID_FIELD].append(record[ID_FIELD])
    data_dict[TEXT_FIELD].append(record[TEXT_FIELD])
    data_dict["summary"].append(data["summary"])

dataset = Dataset.from_dict(data_dict)
print("Records found:", len(dataset))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 605/605 [00:00<00:00, 4571.05it/s]

Records found: 605





In [8]:
# Tokenize all texts
def preprocess(example):
    """Tokenize one example's input text without padding.

    Reads ``example[TEXT_FIELD]``, truncates to ``MAX_INPUT_LEN`` tokens and
    returns the columns to add to the dataset:

    - ``input_ids``: token ids of the (possibly truncated) text
    - ``input_len``: number of token ids, used afterwards as the sort key
    - ``attention_mask``: mask returned by the tokenizer
    """
    tokens = tokenizer(
        # Use the configured field name rather than a hard-coded "text" so the
        # TEXT_FIELD setting is honoured everywhere.
        example[TEXT_FIELD],
        truncation=True,
        max_length=MAX_INPUT_LEN,
        padding=False
    )
    return {
        "input_ids": tokens["input_ids"],
        "input_len": len(tokens["input_ids"]),
        "attention_mask": tokens["attention_mask"]
    }

print("Tokenizing...")
# The raw text column is dropped once it has been tokenized.
dataset = dataset.map(preprocess, remove_columns=[TEXT_FIELD])
# Longest inputs first, so examples of similar length end up in the same batch.
dataset = dataset.sort("input_len", reverse=True)

Tokenizing...


Map:   0%|          | 0/605 [00:00<?, ? examples/s]

In [9]:
# --- Run Inference ---
# Generate summaries batch by batch and append each result to OUTPUT_FILE as
# one JSON line, flushing after every batch so finished work is on disk.
num_batches = ceil(len(dataset) / BATCH_SIZE)

# Autocast is only wanted on the GPU. Passing DEVICE (instead of a hard-coded
# "cuda") keeps the context valid on CPU-only machines, where it is disabled.
use_amp = DEVICE == "cuda"

with open(OUTPUT_FILE, "a") as outfile:
    for i in tqdm(range(num_batches), desc="Generating summaries"):
        # Extract batch from dataset (the last batch may be shorter)
        start = i * BATCH_SIZE
        end = min(start + BATCH_SIZE, len(dataset))
        batch = dataset.select(range(start, end))

        # Pad to max length in batch
        padded = tokenizer.pad(
            {
                "input_ids": batch["input_ids"],
                "attention_mask": batch["attention_mask"]
            },
            return_tensors="pt"
        ).to(DEVICE)

        with torch.no_grad(), autocast(DEVICE, enabled=use_amp):
            outputs = model.generate(
                input_ids=padded["input_ids"],
                attention_mask=padded["attention_mask"],
                max_length=MAX_OUTPUT_LEN,
                num_beams=4,
                do_sample=False,
                early_stopping=True,
                no_repeat_ngram_size=3
            )

        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # One JSON object per line: id, generated summary, reference summary.
        for example, summary in zip(batch, decoded):
            row = {ID_FIELD: example[ID_FIELD], "prediction": summary, "reference": example["summary"]}
            outfile.write(json.dumps(row) + "\n")
        outfile.flush()

Generating summaries: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [21:16<00:00, 75.10s/it]


In [10]:
# Reload every saved prediction/reference pair from OUTPUT_FILE for scoring.
predictions = []
references = []
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, "r") as f:
        for line in f:
            try:
                record = json.loads(line)
                # Read BOTH fields before appending anything: if "reference"
                # were missing after the prediction had already been appended,
                # the two lists would fall out of step and every later pair
                # would be scored against the wrong reference.
                prediction = record["prediction"]
                reference = record["reference"]
            except (json.JSONDecodeError, KeyError, TypeError):
                # Skip malformed or incomplete lines; other errors propagate.
                continue
            predictions.append(prediction)
            references.append(reference)

print(len(predictions))

3269


In [11]:
# Compute ROUGE over all loaded prediction/reference pairs and print the scores.
rouge = evaluate.load("rouge")

results = rouge.compute(
    predictions=predictions,
    references=references,
)
print(results)

Downloading builder script: 0.00B [00:00, ?B/s]

{'rouge1': 0.17589358357783513, 'rouge2': 0.024158811870154152, 'rougeL': 0.11686338195226312, 'rougeLsum': 0.13006141959853468}


In [12]:
# Compute BERTScore (English) for the same prediction/reference pairs.
# `results` is read by the following cell, which averages its
# "precision", "recall" and "f1" entries.
bertscore = evaluate.load("bertscore")

results = bertscore.compute(
    predictions=predictions,
    references=references,
    lang="en",
)

Downloading builder script: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

In [13]:
# Report the mean of each BERTScore component stored in `results`.
for metric in ("precision", "recall", "f1"):
    print(f'Average {metric}: {np.mean(results[metric])}')

Average precision: 0.7980728778280444
Average recall: 0.7877096881301953
Average f1: 0.7923708279113151
