In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.amp import autocast
from datasets import load_dataset
from tqdm import tqdm
from transformers import logging
import json
from math import ceil
logging.set_verbosity_error()

In [2]:
# --- Config ---
MODEL_NAME = "./bigbird_pegasus_fine_tune/checkpoint-80"
BATCH_SIZE = 32
MAX_INPUT_LEN = 4096
MAX_OUTPUT_LEN = 600
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TEXT_FIELD = "text"  # adjust if your input field is different
ID_FIELD = "title"
OUTPUT_FILE = "billsum_test_bigbird_tuned_pred.jsonl"

In [3]:
# --- Load Model and Tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
model.eval()

OptimizedModule(
  (_orig_mod): BigBirdPegasusForConditionalGeneration(
    (model): BigBirdPegasusModel(
      (shared): BigBirdPegasusScaledWordEmbedding(96103, 1024, padding_idx=0)
      (encoder): BigBirdPegasusEncoder(
        (embed_tokens): BigBirdPegasusScaledWordEmbedding(96103, 1024, padding_idx=0)
        (embed_positions): BigBirdPegasusLearnedPositionalEmbedding(4096, 1024)
        (layers): ModuleList(
          (0-15): 16 x BigBirdPegasusEncoderLayer(
            (self_attn): BigBirdPegasusEncoderAttention(
              (self): BigBirdPegasusBlockSparseAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=False)
                (key): Linear(in_features=1024, out_features=1024, bias=False)
                (value): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (output): Linear(in_features=1024, out_features=1024, bias=False)
            )
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-0

In [4]:
# --- Load Dataset ---
dataset = load_dataset("json", data_files="billsum_data/us_test_data_final_OFFICIAL.jsonl")["train"]

# Track already completed summaries
completed_ids = set()
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, "r") as f:
        for line in f:
            try:
                record = json.loads(line)
                completed_ids.add(record[ID_FIELD])
            except:
                continue

# Filter out completed records
print(len(completed_ids), "records already summarized")
print("Dataset length before filter:", len(dataset))
dataset = dataset.filter(lambda x: str(x[ID_FIELD]) not in completed_ids)
print("Dataset length after filter:", len(dataset))

dataset

224 records already summarized
Dataset length before filter: 3269


Filter:   0%|          | 0/3269 [00:00<?, ? examples/s]

Dataset length after filter: 3045


Dataset({
    features: ['text', 'summary', 'bill_id', 'title', 'text_len', 'sum_len'],
    num_rows: 3045
})

In [5]:
# Tokenize all texts
def preprocess(example):
    tokens = tokenizer(
        example["text"],
        truncation=True,
        max_length=MAX_INPUT_LEN,
        padding=False
    )
    return {
        "input_ids": tokens["input_ids"],
        "input_len": len(tokens["input_ids"]),
        "attention_mask": tokens["attention_mask"]
    }

print("Tokenizing...")
dataset = dataset.map(preprocess, remove_columns=["text"])
dataset = dataset.sort("input_len", reverse=True)

dataset

Tokenizing...


Map:   0%|          | 0/3045 [00:00<?, ? examples/s]

Dataset({
    features: ['summary', 'bill_id', 'title', 'text_len', 'sum_len', 'input_ids', 'input_len', 'attention_mask'],
    num_rows: 3045
})

In [6]:
# --- Run Inference ---
num_batches = ceil(len(dataset) / BATCH_SIZE)
summaries = []
all_titles = []
with open(OUTPUT_FILE, "a") as outfile:
    for i in tqdm(range(num_batches), desc="Generating summaries"):
        # Extract batch from dataset
        start = i * BATCH_SIZE
        end = min((i + 1) * BATCH_SIZE, len(dataset))
        batch = dataset.select(range(start, end))
        
        # Pad to max length in batch
        padded = tokenizer.pad(
            {
                "input_ids": batch["input_ids"],
                "attention_mask": batch["attention_mask"]
            },
            return_tensors="pt"
        ).to(DEVICE)

        with torch.no_grad(), autocast("cuda"):
            outputs = model.generate(
                input_ids=padded["input_ids"],
                attention_mask=padded["attention_mask"],
                max_length=MAX_OUTPUT_LEN,
                num_beams=4,
                do_sample=False,
                early_stopping=True,
                no_repeat_ngram_size=3
            )

        decoded = tokenizer.batch_decode(outputs, skip_special_tokens = True)

        for example, summary in zip(batch, decoded):
            json.dump({ID_FIELD: example[ID_FIELD], "pred_summary": summary}, outfile)
            outfile.write("\n")
        outfile.flush()

Generating summaries: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 96/96 [3:16:50<00:00, 123.03s/it]
