In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset, Dataset, DatasetDict

In [2]:
# 1) Prep your SQuAD-style data (flatten & split)
# Read the nested SQuAD-format JSON (records under the top-level "data" key)
# and flatten it to one row per question with a single gold answer span.
raw = load_dataset("json", data_files="f1_gp_qa.json", field="data")
split_name = list(raw.keys())[0]        # e.g. "train" or "validation"

rows = []
skipped = 0  # QAs dropped because no answer span could be located in the context
for rec in raw[split_name]:
    for para in rec["paragraphs"]:
        ctx = para["context"]
        for qa in para["qas"]:
            answers = qa.get("answers") or []
            if not answers:
                # Unanswerable / unlabeled question: there is no span to train on.
                skipped += 1
                continue
            text = answers[0]["text"]
            start = answers[0].get("answer_start")
            # Trust the stored offset only if it really points at the answer
            # text; otherwise fall back to searching the context.
            if start is None or start < 0 or ctx[start:start + len(text)] != text:
                start = ctx.find(text)
            if start < 0:
                # Answer text does not occur in the context; a -1 offset would
                # silently produce a garbage label downstream.
                skipped += 1
                continue
            rows.append({
                "context":      ctx,
                "question":     qa["question"],
                "answer_text":  text,
                "answer_start": start
            })
if skipped:
    print(f"Skipped {skipped} QA pairs with no usable answer span")

flat = Dataset.from_list(rows)
# Fixed seed so the 80/20 train/validation split is reproducible.
split = flat.train_test_split(test_size=0.2, seed=42)
train_ds, val_ds = split["train"], split["test"]


In [3]:
# 2) Load Longformer tokenizer & model
# Both are pulled from the same pretrained checkpoint so that the vocabulary
# matches the encoder weights; the QA head on top is freshly initialised and
# has to be fine-tuned before predictions are meaningful.
checkpoint = "allenai/longformer-base-4096"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)

Some weights of LongformerForQuestionAnswering were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# 3) Tokenization + sliding windows, with per-window answer labels
def prepare_features(examples, max_length=4096, stride=256):
    """Tokenize a batch of QA rows into fixed-size windows with span labels.

    Each (question, context) pair is encoded with the module-level
    ``tokenizer``. Contexts that do not fit in ``max_length`` tokens are split
    into overlapping windows (``stride`` tokens of overlap), so one input row
    can produce several output rows.

    Args:
        examples: batch dict with parallel lists under the keys
            ``"question"``, ``"context"``, ``"answer_text"`` and
            ``"answer_start"`` (character offset of the answer in the context).
        max_length: tokens per window. Defaults to 4096, matching the
            ``longformer-base-4096`` checkpoint; larger values can run past the
            model's position-embedding table.
        stride: number of overlapping tokens between consecutive windows.

    Returns:
        The tokenizer's batch encoding with ``offset_mapping`` removed and two
        added lists, ``start_positions`` and ``end_positions`` (inclusive token
        indices of the answer inside each window). Windows that do not fully
        contain the answer are labelled ``(0, 0)``.
    """
    tok = tokenizer(
        examples["question"], examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length"
    )
    # Label for windows without the answer. Index 0 is the leading <s>/CLS
    # token for RoBERTa-style tokenizers such as Longformer's.
    no_answer = 0

    starts, ends = [], []
    for i, offsets in enumerate(tok["offset_mapping"]):
        # Window i was cut from input row idx.
        idx = tok["overflow_to_sample_mapping"][i]
        start_char = examples["answer_start"][idx]
        end_char   = start_char + len(examples["answer_text"][idx])

        # Only context tokens (sequence id 1) carry offsets that are relative
        # to the context string. Question tokens use offsets relative to the
        # question, and special/padding tokens report (0, 0), so including
        # them in the search yields wrong or out-of-range labels.
        seq_ids = tok.sequence_ids(i)
        ctx_positions = [t for t, sid in enumerate(seq_ids) if sid == 1]
        if not ctx_positions:
            starts.append(no_answer)
            ends.append(no_answer)
            continue
        ctx_first, ctx_last = ctx_positions[0], ctx_positions[-1]

        # Answer missing (negative offset) or not fully inside this window.
        if (start_char < 0
                or offsets[ctx_first][0] > start_char
                or offsets[ctx_last][1] < end_char):
            starts.append(no_answer)
            ends.append(no_answer)
            continue

        # Last context token that begins at or before the answer start.
        s = ctx_first
        while s <= ctx_last and offsets[s][0] <= start_char:
            s += 1
        # First context token (scanning from the right) that ends at or after
        # the answer end.
        e = ctx_last
        while e >= ctx_first and offsets[e][1] >= end_char:
            e -= 1
        starts.append(s - 1)
        # Clamp so a degenerate (e.g. empty) answer never yields end < start.
        ends.append(max(e + 1, s - 1))

    tok["start_positions"] = starts
    tok["end_positions"]   = ends
    # Offsets were only needed to compute the labels; the model does not take them.
    tok.pop("offset_mapping")
    return tok

# Tokenize each split with the same feature function. The raw text columns are
# dropped because one input row may expand into several windowed rows.
tokenized_splits = {}
for split_label, split_ds in (("train", train_ds), ("validation", val_ds)):
    tokenized_splits[split_label] = split_ds.map(
        prepare_features,
        batched=True,
        remove_columns=split_ds.column_names,
    )

train_tok = tokenized_splits["train"]
val_tok = tokenized_splits["validation"]

datasets = DatasetDict(tokenized_splits)


Map:   0%|          | 0/286 [00:00<?, ? examples/s]

Map:   0%|          | 0/72 [00:00<?, ? examples/s]

In [5]:
# 4) TrainingArguments: track eval_loss and save lowest-loss model
training_args = TrainingArguments(
    # Where checkpoints go; an existing directory is overwritten.
    output_dir="./longformer_qa",
    overwrite_output_dir=True,

    # Optimisation hyper-parameters.
    num_train_epochs=3,
    learning_rate=3e-5,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,

    # Training loss is logged every 100 optimizer steps.
    logging_strategy="steps",
    logging_steps=100,

    # Evaluation and checkpointing share the same per-epoch cadence so that
    # every evaluated state also has a checkpoint on disk.
    eval_strategy="epoch",
    save_strategy="epoch",

    # After training, reload the checkpoint with the smallest eval_loss.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower loss is better
)


In [6]:
# 5) Trainer & train
# Wire the model, hyper-parameters and tokenized splits together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"],
    # `tokenizer=` is deprecated on Trainer.__init__ (it emits a FutureWarning
    # and is slated for removal); `processing_class=` is the replacement.
    # NOTE(review): `processing_class` was added around transformers v4.46 —
    # confirm the installed version.
    processing_class=tokenizer,
)

# Runs fine-tuning; evaluation and checkpointing follow `training_args`.
trainer.train()

  trainer = Trainer(
Input ids are automatically padded to be a multiple of `config.attention_window`: 512


Epoch,Training Loss,Validation Loss
1,0.0,
2,0.0,


KeyboardInterrupt: 

In [None]:
# Display the raw (untokenized) validation split to check its columns and row count.
val_ds

In [None]:
# Display the tokenized validation split; it may hold more rows than `val_ds`
# because long contexts are split into several windows.
val_tok

In [None]:
# Find any windows whose labelled answer span is unusable: a start or end
# index outside the token sequence, or an end that precedes the start.
# Both ends are checked, since an out-of-range end label is as harmful to the
# span loss as an out-of-range start.
bad = []
for i, ex in enumerate(val_tok):
    n_tokens = len(ex["input_ids"])
    start, end = ex["start_positions"], ex["end_positions"]
    span_ok = 0 <= start < n_tokens and 0 <= end < n_tokens and start <= end
    if not span_ok:
        bad.append(i)
print("Bad examples:", bad)


In [None]:
# Number of attended (mask == 1) tokens in every validation window; printing
# the distinct values shows how much of each window is real content vs padding.
lengths = list(map(sum, val_tok["attention_mask"]))
print(set(lengths))

In [None]:
# Spot-check the span labels of the first three validation windows.
first_three = val_tok.select(range(3))
for window in first_three:
    print("starts:", window["start_positions"], "ends:", window["end_positions"])


In [7]:
# NOTE(review): this import sits mid-notebook; a later cell also uses `torch`
# and depends on this cell having run first. Consider moving it to the import
# cell at the top.
import torch
# Release cached, currently unused GPU memory held by PyTorch's allocator.
# This does not free tensors that are still referenced (e.g. by `model` or
# `trainer`), so it only helps after those references are dropped.
torch.cuda.empty_cache()


In [11]:

# Repeat of the cache release; relies on `torch` having been imported by an
# earlier cell. NOTE(review): the execution counts here are out of order, so
# this cell may not work on a fresh Restart & Run All — confirm or remove.
torch.cuda.empty_cache()
