In [3]:
# ============================================================================
# T5 Question Answering on SQuAD Dataset - Complete Pipeline
# ============================================================================

# Install required packages (uncomment if needed)
!pip install -q transformers datasets accelerate evaluate sentencepiece

# ============================================================================
# PART 1: PREPROCESSING & TOKENIZATION
# ============================================================================

import torch
from datasets import load_dataset
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import evaluate
import numpy as np
import time

# Check device
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


# Load SQuAD dataset
squad = load_dataset("rajpurkar/squad")
train_ds = squad["train"]
valid_ds = squad["validation"]

# Create smaller subsets for faster training
small_train = train_ds.select(range(5000))
small_valid = valid_ds.select(range(1000))

print(f"Train size: {len(small_train)}, Validation size: {len(small_valid)}")

# Explore dataset
example = squad["train"][0]
print("\n--- Example from dataset ---")
print("ID      :", example["id"])
print("Title   :", example["title"])
print("Question:", example["question"])
print("Context :", example["context"][:400], "...")
print("Answers :", example["answers"])

# Initialize tokenizer
model_checkpoint = "t5-base"
tokenizer = T5TokenizerFast.from_pretrained(model_checkpoint)

max_input_length = 512
max_target_length = 32

# Preprocessing function
def preprocess_function(examples):
    inputs = []
    targets = []
    for question, context, answers in zip(
        examples["question"],
        examples["context"],
        examples["answers"],
    ):
        input_str = f"question: {question} context: {context}"
        answer_text = answers["text"][0] if len(answers["text"]) > 0 else ""
        inputs.append(input_str)
        targets.append(answer_text)

    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize datasets
tokenized_train = small_train.map(
    preprocess_function,
    batched=True,
    remove_columns=small_train.column_names,
)

tokenized_valid = small_valid.map(
    preprocess_function,
    batched=True,
    remove_columns=small_valid.column_names,
)

print("\n--- Tokenization complete ---")
print("Tokenized train:", tokenized_train)
print("Tokenized valid:", tokenized_valid)

# Optional: Save tokenized data to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# save_path = "/content/drive/MyDrive/T5_QA_SQuAD_tokenized"
# tokenized_train.save_to_disk(save_path + "/train")
# tokenized_valid.save_to_disk(save_path + "/valid")

# ============================================================================
# PART 2: TRAINING & EVALUATION
# ============================================================================

# Load model
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
model.to(device)

# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
)

# Training arguments
batch_size = 4

training_args = Seq2SeqTrainingArguments(
    output_dir="./t5-base-squad-checkpoints",
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=3e-4,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=2,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),
    report_to=None, # Disable Weights & Biases
)

# Initialize trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train model
print("\n--- Starting training ---")
train_result = trainer.train()

# Save model
trainer.save_state()
trainer.save_model("./t5-base-squad-finetuned")
tokenizer.save_pretrained("./t5-base-squad-finetuned")

print("\n--- Training complete ---")
print(train_result)

# Optional: Save to Google Drive
# save_dir = "/content/drive/MyDrive/T5_QA_SQuAD_model"
# trainer.save_model(save_dir)
# tokenizer.save_pretrained(save_dir)
# trainer.state.save_to_json(save_dir + "/trainer_state.json")

# ============================================================================
# INFERENCE FUNCTIONS
# ============================================================================

def answer_question(context, question, max_output_length=32, num_beams=4):
    model.eval()
    input_text = f"question: {question} context: {context}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_length=max_output_length,
            num_beams=num_beams,
        )

    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return answer

def generate_answers_batch(examples, max_input_length=512, max_output_length=32, num_beams=4):
    inputs = [
        f"question: {q} context: {c}"
        for q, c in zip(examples["question"], examples["context"])
    ]
    model_inputs = tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    ).to(device)

    model.eval()
    with torch.no_grad():
        output_ids = model.generate(
            **model_inputs,
            max_length=max_output_length,
            num_beams=num_beams,
        )

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return preds

# ============================================================================
# EVALUATION ON VALIDATION SET
# ============================================================================

# Load SQuAD metric
squad_metric = evaluate.load("squad")

# Test on validation examples
print("\n--- Testing on validation examples ---")
from random import randrange

for _ in range(3):
    idx = randrange(len(valid_ds))
    ex = valid_ds[idx]
    pred = answer_question(ex["context"], ex["question"])
    print("="*80)
    print("Question:", ex["question"])
    print("Gold    :", ex["answers"]["text"][0])
    print("Pred    :", pred)

# Evaluate on subset
num_eval_samples = 500
val_subset = valid_ds.select(range(num_eval_samples))

all_predictions = []
all_references = []

print(f"\n--- Evaluating on {num_eval_samples} samples ---")
for start in range(0, len(val_subset), 16):
    end = start + 16
    batch = val_subset[start:end]
    preds = generate_answers_batch(batch)

    for i, pred in enumerate(preds):
        ex_id = batch["id"][i]
        ex_answers = batch["answers"][i]

        all_predictions.append({
            "id": ex_id,
            "prediction_text": pred,
        })
        all_references.append({
            "id": ex_id,
            "answers": ex_answers,
        })

results = squad_metric.compute(
    predictions=all_predictions,
    references=all_references,
)

print("\n--- Evaluation Results ---")
print(f"Exact Match: {results['exact_match']:.2f}")
print(f"F1 Score: {results['f1']:.2f}")

# ============================================================================
# ERROR ANALYSIS
# ============================================================================

def collect_examples_with_scores(num_samples=100):
    examples = []
    subset = valid_ds.select(range(num_samples))
    preds = generate_answers_batch(subset)

    for i, (ex, pred) in enumerate(zip(subset, preds)):
        gold_answers = ex["answers"]["text"]

        prediction = [{"id": ex["id"], "prediction_text": pred}]
        reference = [{"id": ex["id"], "answers": ex["answers"]}]
        scores = squad_metric.compute(
            predictions=prediction,
            references=reference,
        )

        examples.append({
            "id": ex["id"],
            "question": ex["question"],
            "context": ex["context"],
            "gold_answers": gold_answers,
            "pred": pred,
            "exact_match": scores["exact_match"],
            "f1": scores["f1"],
        })
    return examples

print("\n--- Error Analysis: Worst 5 predictions ---")
examples = collect_examples_with_scores(num_samples=200)
examples_sorted = sorted(examples, key=lambda x: x["f1"])

for ex in examples_sorted[:5]:
    print("="*80)
    print("ID       :", ex["id"])
    print("Question :", ex["question"])
    print("Gold     :", ex["gold_answers"])
    print("Pred     :", ex["pred"])
    print("EM / F1  :", ex["exact_match"], "/", ex["f1"])

# ============================================================================
# CUSTOM INFERENCE DEMO
# ============================================================================

print("\n--- Custom Inference Demo ---")
custom_context = """
Bandung is the capital city of West Java province in Indonesia.
It is known for its universities, cool climate, and surrounding volcanoes and tea plantations.
"""

custom_question = "What is Bandung the capital city of?"

print("Question:", custom_question)
print("Answer  :", answer_question(custom_context, custom_question))

# ============================================================================
# ABLATION STUDY: num_beams
# ============================================================================

print("\n--- Ablation Study: num_beams ---")
beam_values = [1, 2, 4, 8]
results_beams = []

for nb in beam_values:
    print(f"\nEvaluating with num_beams = {nb}")
    start_time = time.time()

    all_predictions = []
    all_references = []

    for start in range(0, len(val_subset), 16):
        end = start + 16
        batch = val_subset[start:end]
        preds = generate_answers_batch(batch, num_beams=nb)

        for i, pred in enumerate(preds):
            ex_id = batch["id"][i]
            ex_answers = batch["answers"][i]

            all_predictions.append({
                "id": ex_id,
                "prediction_text": pred,
            })
            all_references.append({
                "id": ex_id,
                "answers": ex_answers,
            })

    metrics = squad_metric.compute(
        predictions=all_predictions,
        references=all_references,
    )
    elapsed = time.time() - start_time

    results_beams.append({
        "num_beams": nb,
        "exact_match": metrics["exact_match"],
        "f1": metrics["f1"],
        "time_sec": elapsed,
    })

    print(f"EM: {metrics['exact_match']:.2f}, F1: {metrics['f1']:.2f}, Time: {elapsed:.1f}s")

print("\n--- Summary: num_beams ablation ---")
for r in results_beams:
    print(
        f"num_beams={r['num_beams']}: "
        f"EM={r['exact_match']:.2f}, "
        f"F1={r['f1']:.2f}, "
        f"time={r['time_sec']:.1f}s"
    )

# Clean up GPU memory
torch.cuda.empty_cache()

print("\n=== Pipeline Complete ===")

PyTorch version: 2.9.0+cu126
CUDA available: True
Device: cuda
Train size: 5000, Validation size: 1000

--- Example from dataset ---
ID      : 5733be284776f41900661182
Title   : University_of_Notre_Dame
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Context : Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of p ...
Answers : {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]



Map:   0%|          | 0/1000 [00:00<?, ? examples/s]


--- Tokenization complete ---
Tokenized train: Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 5000
})
Tokenized valid: Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})


  trainer = Seq2SeqTrainer(
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:


--- Starting training ---


[34m[1mwandb[0m: You chose "Don't visualize my results"


Step,Training Loss
100,1.4477
200,0.0785
300,0.0743
400,0.0664
500,0.0712
600,0.0679
700,0.0707
800,0.0696
900,0.0605
1000,0.0639


Step,Training Loss
100,1.4477
200,0.0785
300,0.0743
400,0.0664
500,0.0712
600,0.0679
700,0.0707
800,0.0696
900,0.0605
1000,0.0639



--- Training complete ---
TrainOutput(global_step=2500, training_loss=0.10195608563423157, metrics={'train_runtime': 1415.2876, 'train_samples_per_second': 7.066, 'train_steps_per_second': 1.766, 'total_flos': 6089578905600000.0, 'train_loss': 0.10195608563423157, 'epoch': 2.0})


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]


--- Testing on validation examples ---
Question: How might gravity effects be observed differently according to Newton?
Gold    : at larger distances.
Pred    : larger distances
Question: What is the prize offered for finding a solution to P=NP?
Gold    : $1,000,000
Pred    : US$1,000,000
Question: What color were the Bronco's uniforms in Super Bowl 50?
Gold    : white
Pred    : white

--- Evaluating on 500 samples ---

--- Evaluation Results ---
Exact Match: 85.40
F1 Score: 89.02

--- Error Analysis: Worst 5 predictions ---
ID       : 56d2045de7d4791d009025f5
Question : How many times have the Panthers been in the Super Bowl?
Gold     : ['2', 'second', 'second']
Pred     : eight
EM / F1  : 0.0 / 0.0
ID       : 56d2045de7d4791d009025f6
Question : Who did Denver beat in the AFC championship?
Gold     : ['New England Patriots', 'the New England Patriots', 'New England Patriots']
Pred     : the Arizona Cardinals
EM / F1  : 0.0 / 0.0
ID       : 56d6017d1c85041400946ec1
Question : Who did 