
# DeBERTa-v3 Relation Extraction for MTA Alerts

Train a DeBERTa-v3 classifier on entity-marked route–direction pairs from the silver dataset and evaluate on both silver and gold annotations.


In [1]:

# Install required packages if missing
#%pip install seqeval
"""The DeBERTa NER model must be run first for the RE model to work,
as it generates the NER predictions used for RE training and evaluation.
If you have not run the NER model yet,
please run the 6_deberta_ner.ipynb notebook first."""

'The DeBERTa NER model must be run first for the RE model to work,\nas it generates the NER predictions used for RE training and evaluation.\nIf you have not run the NER model yet,\nplease run the 6_deberta_ner.ipynb notebook first.'

In [2]:
# Make sure the NER checkpoint exists before going further; the NER notebook is
# only executed when the checkpoint directory is missing.
# The preprocessing script call below is intentionally left disabled.
#%run 6a_prepare_final_deberta_data.py

import os

if os.path.exists("models/deberta_ner_best"):
    # Checkpoint already on disk: nothing to train.
    print("Using existing NER checkpoint at models/deberta_ner_best")
else:
    # No checkpoint yet: run the NER notebook in this kernel to produce it.
    print("NER checkpoint not found. Running 6_deberta_ner.ipynb...")
    get_ipython().run_line_magic("run", "6_deberta_ner.ipynb")


Using existing NER checkpoint at models/deberta_ner_best


In [3]:

import json
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
from datasets import load_from_disk
from sklearn.metrics import classification_report, precision_recall_fscore_support
from tqdm.notebook import tqdm

# Reproducibility: seed the torch and NumPy global random number generators.
# NOTE(review): Python's built-in `random` module is not seeded here.
torch.manual_seed(42)
np.random.seed(42)



## 1. Load Prepared RE Data
Load the pre-tokenized relation extraction dataset (silver) and label mappings.


In [4]:

# Load the silver relation-extraction DatasetDict from disk and show its layout.
re_path = "final_data/re_dataset"
dataset = load_from_disk(re_path)
print("Loaded RE dataset from disk:")
print(dataset)

# Label <-> id mappings for the relation classes.
with open("final_data/label_mappings.json", "r") as f:
    mappings = json.load(f)

label2id = mappings["re_label2id"]

# JSON object keys are always strings, so turn them back into integer class ids.
id2label = dict(
    (int(raw_id), label_name)
    for raw_id, label_name in mappings["re_id2label"].items()
)
print("Relation Label Map:", label2id)


Loaded RE dataset from disk:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'route_id', 'direction_id', 'header_id', 'route_text', 'direction_text'],
        num_rows: 281017
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'route_id', 'direction_id', 'header_id', 'route_text', 'direction_text'],
        num_rows: 60218
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'route_id', 'direction_id', 'header_id', 'route_text', 'direction_text'],
        num_rows: 60218
    })
})
Relation Label Map: {'NO_RELATION': 0, 'HAS_DIRECTION': 1}


In [5]:

# Have every available split return the model inputs and labels as torch tensors.
torch_cols = ['input_ids', 'attention_mask', 'labels']
for split in ('train', 'validation', 'test'):
    if split not in dataset:
        continue
    dataset[split].set_format(type='torch', columns=torch_cols)



## 2. Tokenizer and Gold Dataset
Load tokenizer (with special markers) and the gold RE dataset for evaluation.


In [6]:
# Tokenizer: loaded from the prepared tokenizer directory (the one carrying the
# special marker tokens), using the slow implementation.
model_name = "microsoft/deberta-v3-base"
tokenizer_path = "final_data/tokenizer"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False)
print(f"Tokenizer loaded with vocab size {len(tokenizer):,}")

# Gold dataset (all samples in test split)
gold_re_path = "final_data_gold/re_dataset"
if not os.path.exists(gold_re_path):
    print(f"Gold RE dataset not found at {gold_re_path}. Running data prep...")
    import subprocess
    import sys

    # Launch the data-prep script with the interpreter that runs this kernel.
    # A bare "python3" is resolved through PATH and may point at a different
    # environment that lacks the notebook's dependencies.
    subprocess.run([sys.executable, "6a_prepare_final_deberta_data.py"], check=True)

# Load only the test split directly (train/val splits are empty for gold data)
from datasets import Dataset
gold_test_path = os.path.join(gold_re_path, "test")
gold_test_dataset = Dataset.load_from_disk(gold_test_path)
print(f"Loaded gold test dataset: {len(gold_test_dataset):,} samples")

Tokenizer loaded with vocab size 128,005
Loaded gold test dataset: 20,781 samples


In [7]:
# Have the gold test set return the model inputs and labels as torch tensors.
gold_test_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)


## 3. Data Splits & DataLoaders
Use the silver train/val/test splits and a dedicated gold test loader.


In [8]:

# Silver splits
train_dataset, val_dataset, test_dataset = (
    dataset["train"],
    dataset["validation"],
    dataset["test"],
)

print(f"Split sizes (silver) -> Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"Gold test size: {len(gold_test_dataset)}")

train_batch_size = 128
eval_batch_size = 64

# Worker and pinning options shared by all four loaders.
loader_kwargs = dict(num_workers=4, pin_memory=True)

# Only the training loader shuffles; the evaluation loaders keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, **loader_kwargs)
val_loader, test_loader, gold_loader = (
    DataLoader(eval_split, batch_size=eval_batch_size, **loader_kwargs)
    for eval_split in (val_dataset, test_dataset, gold_test_dataset)
)


Split sizes (silver) -> Train: 281017, Val: 60218, Test: 60218
Gold test size: 20781



## 4. Class Weights for Imbalance
Compute inverse-frequency class weights from the training set.


In [9]:

# Count how often each relation label occurs in the training split.
num_classes = len(label2id)
label_counts = {id_: 0 for id_ in label2id.values()}
print("Calculating class weights...")
for batch in tqdm(DataLoader(train_dataset, batch_size=512, num_workers=4), desc="Counting labels"):
    # One .tolist() per batch instead of converting tensor elements one at a time.
    for lbl in batch['labels'].view(-1).tolist():
        label_counts[lbl] += 1

total = sum(label_counts.values())

# Inverse-frequency weights: total / (num_classes * count).
# A class with no training examples is treated as count 1 so the division
# cannot fail. Every label id is pre-initialised to 0 above, so a plain
# `.get(i, 1)` default would never apply to an empty class.
class_weights = torch.tensor(
    [total / (num_classes * max(label_counts.get(i, 0), 1)) for i in range(num_classes)],
    dtype=torch.float,
)

print("Label distribution (train):")
for name, idx in label2id.items():
    count = label_counts[idx]
    pct = (count / total * 100) if total else 0
    print(f"  {name}: {count:,} ({pct:.2f}%) weight={class_weights[idx]:.3f}")


Calculating class weights...


Counting labels:   0%|          | 0/549 [00:00<?, ?it/s]

Label distribution (train):
  NO_RELATION: 92,215 (32.81%) weight=1.524
  HAS_DIRECTION: 188,802 (67.19%) weight=0.744



## 5. Model & Optimizer
Load DeBERTa for sequence classification, resize embeddings for special tokens, and set optimizer/scheduler.


In [10]:
# Pick the compute device: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Pretrained encoder with a sequence-classification head sized to the relation labels.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(label2id))

# The tokenizer has extra special tokens, so grow the embedding matrix to match it.
tokenizer_size = len(tokenizer)
model.resize_token_embeddings(tokenizer_size)
print(f"Resized model embeddings to {tokenizer_size} tokens")

# Force float32 weights, then move the model to the selected device.
model = model.float()
model.to(device)
print(f"Model dtype: {next(model.parameters()).dtype}")

optimizer = AdamW(model.parameters(), lr=3e-5)

# Linear warmup for the first `warmup_ratio` of all optimizer steps, then linear decay.
epochs = 3
warmup_ratio = 0.1
total_steps = max(len(train_loader) * epochs, 1)
warmup_steps = int(warmup_ratio * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# Cross-entropy weighted by the per-class weights computed from the training split.
loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))

Using device: cuda


Loading weights:   0%|          | 0/198 [00:00<?, ?it/s]

[1mDebertaV2ForSequenceClassification LOAD REPORT[0m from: microsoft/deberta-v3-base
Key                                     | Status     | 
----------------------------------------+------------+-
mask_predictions.LayerNorm.weight       | UNEXPECTED | 
lm_predictions.lm_head.dense.bias       | UNEXPECTED | 
mask_predictions.dense.weight           | UNEXPECTED | 
mask_predictions.dense.bias             | UNEXPECTED | 
lm_predictions.lm_head.dense.weight     | UNEXPECTED | 
lm_predictions.lm_head.bias             | UNEXPECTED | 
mask_predictions.classifier.weight      | UNEXPECTED | 
lm_predictions.lm_head.LayerNorm.bias   | UNEXPECTED | 
mask_predictions.classifier.bias        | UNEXPECTED | 
lm_predictions.lm_head.LayerNorm.weight | UNEXPECTED | 
mask_predictions.LayerNorm.bias         | UNEXPECTED | 
pooler.dense.weight                     | MISSING    | 
classifier.weight                       | MISSING    | 
pooler.dense.bias                       | MISSING    | 
classifier.bias  

Resized model embeddings to 128005 tokens
Model dtype: torch.float32


In [11]:
def train_epoch(model, data_loader, optimizer, scheduler, device, loss_fct):
    """Run one training pass over `data_loader` and return the mean batch loss.

    Args:
        model: callable taking `input_ids` and `attention_mask` keyword
            arguments and returning an object with a `.logits` attribute.
        data_loader: iterable of batches, each a mapping with the keys
            'input_ids', 'attention_mask' and 'labels'.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per batch.
        device: device the batch tensors are moved to.
        loss_fct: loss callable applied to (logits, labels).

    Returns:
        Sum of the batch losses divided by the number of batches
        (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0

    pbar = tqdm(data_loader, desc="Training", leave=True)
    for batch_idx, batch in enumerate(pbar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad(set_to_none=True)

        # Run in full precision (float32)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # Take the class count from the logits themselves so the function does
        # not depend on a notebook-level `label2id` global.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Read the scalar once; it is reused for the running mean and the progress bar.
        batch_loss = loss.item()
        total_loss += batch_loss

        # Update progress bar with current loss
        avg_loss = total_loss / (batch_idx + 1)
        pbar.set_postfix({'loss': f'{avg_loss:.4f}', 'batch_loss': f'{batch_loss:.4f}'})

    return total_loss / max(len(data_loader), 1)


def eval_epoch(model, data_loader, device, loss_fct):
    """Evaluate `model` on `data_loader` without gradient tracking.

    Relies on the notebook-level `label2id` and `id2label` mappings for the
    positive class id and the class names shown in the report.

    Args:
        model: callable taking `input_ids` and `attention_mask` keyword
            arguments and returning an object with a `.logits` attribute.
        data_loader: iterable of batches, each a mapping with the keys
            'input_ids', 'attention_mask' and 'labels'.
        device: device the batch tensors are moved to.
        loss_fct: loss callable applied to (logits, labels).

    Returns:
        Tuple `(avg_loss, precision, recall, f1, report)`. Precision, recall
        and F1 are binary metrics for the HAS_DIRECTION class, and `report` is
        the text produced by `classification_report`.
    """
    model.eval()
    total_loss = 0.0
    all_labels = []
    all_preds = []

    # Use full precision for evaluation
    with torch.no_grad():
        pbar = tqdm(data_loader, desc="Evaluating", leave=True)
        for batch_idx, batch in enumerate(pbar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            # Class count comes from the logits, not from a global mapping.
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

            total_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

            # Update progress bar with current loss
            avg_loss = total_loss / (batch_idx + 1)
            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})

    avg_loss = total_loss / max(len(data_loader), 1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels,
        all_preds,
        average='binary',
        pos_label=label2id['HAS_DIRECTION'],
        zero_division=0,
    )

    # Pass the full label list explicitly. Without `labels=`, the report infers
    # the classes from the data and raises when only one class shows up in
    # labels and predictions, because `target_names` then has the wrong length.
    class_ids = list(range(len(label2id)))
    report = classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=[id2label[i] for i in class_ids],
        zero_division=0,
    )
    return avg_loss, precision, recall, f1, report


## 6. Training with Early Stopping
Monitor validation F1 and keep the best checkpoint.


In [12]:
# Early-stopping bookkeeping; the best weights are kept as CPU copies.
best_f1 = 0.0
patience = 3
patience_counter = 0
best_state = None

print(f"\n{'='*70}")
print(f"TRAINING STARTED - {epochs} epochs, batch size {train_batch_size}")
print(f"{'='*70}\n")

# NOTE(review): with `patience` equal to the configured `epochs` (both 3), the
# early-stop branch can only fire on the final epoch, so it never shortens training.
for epoch in range(epochs):
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch + 1}/{epochs}")
    print(f"{'='*70}")

    # Training
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, loss_fct)
    print(f"\nTraining complete - Avg Loss: {train_loss:.4f}")

    # Validation
    val_metrics = eval_epoch(model, val_loader, device, loss_fct)
    val_loss, val_prec, val_rec, val_f1, val_report = val_metrics

    print(f"\n{'-'*70}")
    print(f"VALIDATION RESULTS:")
    print(f"  Loss:      {val_loss:.4f}")
    print(f"  Precision: {val_prec:.4f}")
    print(f"  Recall:    {val_rec:.4f}")
    print(f"  F1:        {val_f1:.4f}")
    print(f"{'-'*70}")
    print(f"\nDetailed Classification Report:")
    print(val_report)

    # Model checkpointing: only a strictly better validation F1 counts as progress.
    improved = val_f1 > best_f1
    if not improved:
        patience_counter += 1
        print(f"\nNo improvement. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered at epoch {epoch + 1}")
            break
        continue

    best_f1 = val_f1
    patience_counter = 0
    # Snapshot the weights on the CPU so later epochs cannot overwrite them.
    best_state = {
        param_name: param_value.cpu().clone()
        for param_name, param_value in model.state_dict().items()
    }
    print(f"\nNew best model! F1: {best_f1:.4f} (saved)")

print(f"\n{'='*70}")
print(f"TRAINING COMPLETE")
print(f"{'='*70}")

# Restore the best snapshot, if any epoch produced one.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Loaded best model with validation F1 = {best_f1:.4f}\n")


TRAINING STARTED - 3 epochs, batch size 128


EPOCH 1/3


Training:   0%|          | 0/2196 [00:00<?, ?it/s]

KeyboardInterrupt: 


## 7. Save Model
Persist the relation extractor and tokenizer.


In [None]:

# Write the model and the tokenizer into the same directory.
# NOTE(review): this saves whatever weights `model` currently holds; if the
# training cell was interrupted, the best snapshot was never restored.
save_dir = "models/deberta_re_best"
os.makedirs(save_dir, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)
print(f"Saved best model to {save_dir}")



## 8. Evaluation on Silver Test Split
Compute pair-level precision/recall/F1 on held-out silver data.


In [None]:

# Pair-level metrics on the held-out silver test split.
silver_metrics = eval_epoch(model, test_loader, device, loss_fct)
test_loss, test_prec, test_rec, test_f1, test_report = silver_metrics

print("SILVER TEST RESULTS")
for metric_name, metric_value in (
    ("Loss", test_loss),
    ("Precision", test_prec),
    ("Recall", test_rec),
    ("F1", test_f1),
):
    print(f"{metric_name}: {metric_value:.4f}")
print(test_report)


In [None]:
from collections import defaultdict

def _group_positive_pairs(header_ids, route_texts, direction_texts, label_ids, positive_id):
    """Map each header id to the set of (route_text, direction_text) pairs labelled `positive_id`."""
    relations = defaultdict(set)
    for hid, route_text, direction_text, label_id in zip(
        header_ids, route_texts, direction_texts, label_ids
    ):
        if label_id == positive_id:
            relations[hid].add((route_text, direction_text))
    return relations


def _precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, F1) from raw counts; any ratio with a zero denominator is 0.0."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0
    return precision, recall, f1


def _score_relation_sets(gold_relations, pred_relations, header_ids):
    """Compare per-header relation sets and build the metrics dict returned by eval_header_level."""
    total_tp = total_fp = total_fn = 0
    per_header = []

    for hid in header_ids:
        gold_set = gold_relations.get(hid, set())
        pred_set = pred_relations.get(hid, set())

        tp = len(gold_set & pred_set)
        fp = len(pred_set - gold_set)
        fn = len(gold_set - pred_set)
        total_tp += tp
        total_fp += fp
        total_fn += fn

        h_p, h_r, h_f = _precision_recall_f1(tp, fp, fn)
        per_header.append({
            "header_id": hid,
            "gold": len(gold_set),
            "pred": len(pred_set),
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "precision": h_p,
            "recall": h_r,
            "f1": h_f,
        })

    micro_p, micro_r, micro_f = _precision_recall_f1(total_tp, total_fp, total_fn)

    # Macro F1 over headers with at least one gold relation
    f1s = [h["f1"] for h in per_header if h["gold"] > 0]
    macro_f = float(np.mean(f1s)) if f1s else 0.0

    return {
        "micro_precision": micro_p,
        "micro_recall": micro_r,
        "micro_f1": micro_f,
        "macro_f1": macro_f,
        "total_tp": total_tp,
        "total_fp": total_fp,
        "total_fn": total_fn,
        "num_headers": len(per_header),
        "per_header": per_header,
    }


def eval_header_level(
    model,
    dataset,
    device,
    label2id,
    gold_dataset=None,
    batch_size=64,
    desc="Header-level eval",
    num_workers=4,
):
    """
    Evaluate RE at the alert-header level using relation-set matching.

    Predictions are produced from `dataset`.
    Gold relation sets come from `gold_dataset` when provided; otherwise
    they are taken from `dataset`.

    This keeps Oracle and E2E comparable by evaluating both against the
    same gold header-level relation sets.

    Relation pairs are compared as exact `(route_text, direction_text)`
    string tuples; no case or whitespace normalization is applied.

    Side effect: on exit (including when an exception is raised) `dataset`,
    and `gold_dataset` if it is a different object, are set to torch format
    restricted to input_ids / attention_mask / labels.

    Args:
        model: trained DeBERTa RE model
        dataset: HuggingFace Dataset used for inference (input pairs)
        device: torch device
        label2id: dict mapping label names to indices
        gold_dataset: optional HuggingFace Dataset with gold relation pairs
                      and columns: header_id, route_text, direction_text, labels
        batch_size: evaluation batch size
        desc: progress bar description
        num_workers: DataLoader worker processes used for inference

    Returns:
        dict with micro precision, recall, F1, macro F1 (over headers that
        have at least one gold relation), total TP/FP/FN counts, the number
        of headers evaluated and a per-header breakdown
    """
    model.eval()
    has_dir_id = label2id["HAS_DIRECTION"]
    tensor_cols = ["input_ids", "attention_mask", "labels"]
    gold_source = gold_dataset if gold_dataset is not None else dataset

    try:
        # --- 1) Read prediction-pair metadata before tensor formatting ---
        dataset.reset_format()
        pred_header_ids = list(dataset["header_id"])
        pred_route_texts = list(dataset["route_text"])
        pred_dir_texts = list(dataset["direction_text"])

        # --- 2) Run batched inference on the prediction dataset ---
        dataset.set_format(type="torch", columns=tensor_cols)
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)

        all_preds = []
        with torch.no_grad():
            for batch in tqdm(loader, desc=desc):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                # Use full precision (no autocast)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                all_preds.extend(torch.argmax(outputs.logits, dim=1).cpu().tolist())

        pred_relations = _group_positive_pairs(
            pred_header_ids, pred_route_texts, pred_dir_texts, all_preds, has_dir_id
        )

        # --- 3) Build gold relation sets from the chosen gold source ---
        gold_source.reset_format()
        gold_header_ids = list(gold_source["header_id"])
        gold_relations = _group_positive_pairs(
            gold_header_ids,
            list(gold_source["route_text"]),
            list(gold_source["direction_text"]),
            list(gold_source["labels"]),
            has_dir_id,
        )
    finally:
        # Restore tensor format for downstream notebook cells, even when
        # inference fails part-way through.
        dataset.set_format(type="torch", columns=tensor_cols)
        if gold_dataset is not None and gold_dataset is not dataset:
            gold_dataset.set_format(type="torch", columns=tensor_cols)

    # Keep the full gold header universe; include pred-only headers as well.
    eval_header_ids = sorted(set(gold_header_ids) | set(pred_header_ids))

    # --- 4) Per-header set-level TP/FP/FN, then micro-average ---
    return _score_relation_sets(gold_relations, pred_relations, eval_header_ids)


def print_re_results(results, title):
    """Print the header-level RE metrics from `results` under a `title` banner.

    Args:
        results: dict with the keys num_headers, total_tp, total_fp, total_fn,
            micro_precision, micro_recall, micro_f1 and macro_f1.
        title: heading shown between the two rule lines.
    """
    rule = "=" * 60
    counts = (
        f"  Total TP: {results['total_tp']}, "
        f"FP: {results['total_fp']}, "
        f"FN: {results['total_fn']}"
    )
    # The leading and trailing empty entries give the blank line before and after.
    lines = [
        "",
        rule,
        f"  {title}",
        rule,
        f"  Headers evaluated: {results['num_headers']}",
        counts,
        f"  Micro Precision: {results['micro_precision']:.4f}",
        f"  Micro Recall:    {results['micro_recall']:.4f}",
        f"  Micro F1:        {results['micro_f1']:.4f}",
        f"  Macro F1:        {results['macro_f1']:.4f}",
        rule,
        "",
    ]
    print("\n".join(lines))

## 9. Evaluation on Gold Dataset
Evaluate the silver-trained model on gold labels with header-level relation-set metrics.

In [None]:
# Release cached CUDA memory before evaluation when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Oracle setting: the gold dataset supplies both the candidate pairs and the
# reference relation sets.
oracle_eval_kwargs = dict(
    model=model,
    dataset=gold_test_dataset,
    gold_dataset=gold_test_dataset,
    device=device,
    label2id=label2id,
    desc="Oracle (gold spans)",
)
oracle_results = eval_header_level(**oracle_eval_kwargs)
print_re_results(oracle_results, "ORACLE RE RESULTS (Gold Entity Spans)")


## 9.5 Build Predicted-Span RE Inputs
<!-- E2E_DATASET_FROM_NER_PREDICTIONS -->
Generate `re_pred_splits` from DeBERTa NER predictions on the gold alerts. Section 10 uses this dataset for end-to-end evaluation.

In [None]:
# E2E_DATASET_FROM_NER_PREDICTIONS: helper utilities
from datasets import Dataset, DatasetDict
from transformers import AutoModelForTokenClassification


def _parse_json_list(value):
    if pd.isna(value) or value in ("", "[]"):
        return []
    try:
        parsed = json.loads(value)
        return parsed if isinstance(parsed, list) else []
    except (TypeError, json.JSONDecodeError):
        return []


def _normalize_entity_text(text):
    return " ".join(str(text).split()).strip().lower()


def _compute_token_offsets_slow(text, tok):
    # Mirror the span-to-token alignment logic used in data prep for slow tokenizers.
    tokens = tok.tokenize(text)
    offsets = []
    cursor = 0

    for token in tokens:
        clean = token.replace("\u2581", " ").strip()
        if not clean:
            offsets.append((cursor, cursor))
            continue

        remaining = text[cursor:]
        idx = remaining.find(clean)
        if idx == -1:
            idx = remaining.lower().find(clean.lower())

        if idx == -1:
            offsets.append((cursor, cursor))
        else:
            start = cursor + idx
            end = start + len(clean)
            offsets.append((start, end))
            cursor = end

    return offsets


def _decode_bio_spans(text, pred_ids, tok, id2label):
    """Turn per-position BIO label ids into character-level ROUTE and DIRECTION spans.

    Args:
        text: the raw input string that was tagged.
        pred_ids: sequence of predicted label ids, one per model input position.
        tok: tokenizer passed through to `_compute_token_offsets_slow`.
        id2label: dict mapping integer label id to label string; ids that are
            missing from it are treated as "O".

    Returns:
        Tuple `(route_spans, direction_spans)`. Each span is a dict with the
        keys `type`, `start`, `end`, `text` and `id`, where `id` is the span's
        index within its own entity type.
    """
    token_offsets = _compute_token_offsets_slow(text, tok)
    # Position 0 gets a (0, 0) placeholder — presumably the leading special
    # token added by the tokenizer; TODO confirm against the NER data prep.
    offset_map = [(0, 0)] + token_offsets

    # Make the offset map exactly as long as the predictions: pad with (0, 0)
    # placeholders or cut off the surplus.
    if len(offset_map) < len(pred_ids):
        offset_map.extend([(0, 0)] * (len(pred_ids) - len(offset_map)))
    else:
        offset_map = offset_map[:len(pred_ids)]

    spans = {"ROUTE": [], "DIRECTION": []}
    current = None

    def flush_current():
        """Store the open span (if it has non-zero width) and clear it."""
        nonlocal current
        if current is not None and current["end"] > current["start"]:
            current["text"] = text[current["start"]:current["end"]]
            spans[current["type"]].append(current)
        current = None

    for i, label_id in enumerate(pred_ids):
        label = id2label.get(int(label_id), "O")
        start, end = offset_map[i]

        # (0, 0) marks a placeholder position: it closes any open span and is skipped.
        if start == 0 and end == 0:
            flush_current()
            continue

        if label == "O":
            flush_current()
            continue

        # Labels other than "O" are expected to look like "<tag>-<entity type>".
        tag, ent_type = label.split("-", 1)

        if tag == "B":
            flush_current()
            current = {"type": ent_type, "start": start, "end": end}
            continue

        if tag == "I":
            # Extend the open span only when the type matches and the token starts
            # no more than one character after the span's current end; otherwise
            # the I- tag opens a new span, just like a B- tag.
            if current is not None and current["type"] == ent_type and start <= current["end"] + 1:
                current["end"] = max(current["end"], end)
            else:
                flush_current()
                current = {"type": ent_type, "start": start, "end": end}
            continue

        # Any other tag prefix just closes the open span.
        flush_current()

    # Close a span that is still open at the end of the sequence.
    flush_current()

    # Number the spans of each entity type in the order they were found.
    for ent_type in ("ROUTE", "DIRECTION"):
        for idx, span in enumerate(spans[ent_type]):
            span["id"] = idx

    return spans["ROUTE"], spans["DIRECTION"]


def _insert_entity_markers(text, route_span, direction_span):
    insertions = [
        (route_span["start"], "[ROUTE]", False),
        (route_span["end"], "[/ROUTE]", True),
        (direction_span["start"], "[DIR]", False),
        (direction_span["end"], "[/DIR]", True),
    ]

    # Insert from right to left so character indices stay valid.
    insertions.sort(key=lambda x: (-x[0], x[2]))

    marked = text
    for pos, marker, _ in insertions:
        marked = marked[:pos] + marker + " " + marked[pos:]

    return marked.strip()

In [None]:
# E2E_DATASET_FROM_NER_PREDICTIONS: build/load re_pred_splits
# Builds RE inputs from NER-predicted spans on the gold alerts, or loads them
# from disk when a usable cached copy exists.
pred_re_path = "final_data_gold/re_pred_dataset"
ner_ckpt_dir = "models/deberta_ner_best"
gold_csv_path = "Preprocessed/MTA_Data_Final_Gold.csv"
# Columns the cached dataset must provide to be reused.
required_cols = {
    "input_ids", "attention_mask", "labels",
    "route_id", "direction_id", "header_id", "route_text", "direction_text"
}

# The cache is accepted purely on path existence and column names.
# NOTE(review): a retrained NER checkpoint does not invalidate it; delete the
# directory to force a rebuild.
needs_rebuild = True
if os.path.exists(pred_re_path):
    re_pred_splits = load_from_disk(pred_re_path)
    existing_cols = set(re_pred_splits["test"].column_names)
    if required_cols.issubset(existing_cols):
        needs_rebuild = False
        print(f"Loaded predicted-span RE dataset from {pred_re_path}")
        print({k: len(v) for k, v in re_pred_splits.items()})
    else:
        print(f"Cached dataset at {pred_re_path} is missing required columns. Rebuilding...")

if needs_rebuild:
    if not os.path.exists(ner_ckpt_dir):
        raise FileNotFoundError(
            f"Missing NER checkpoint at {ner_ckpt_dir}. Run 6_deberta_ner.ipynb first."
        )

    # NER label ids -> BIO label strings (JSON keys are strings, hence int()).
    with open("final_data/label_mappings.json", "r") as f:
        mappings = json.load(f)
    ner_id2label = {int(k): v for k, v in mappings["ner_id2label"].items()}

    ner_tokenizer = AutoTokenizer.from_pretrained(ner_ckpt_dir, use_fast=False)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_ckpt_dir).to(device)
    ner_model.eval()

    # Sort the gold alerts by date and renumber them; the resulting row index
    # is used as `header_id` below.
    # NOTE(review): these ids must line up with the `header_id` values stored in
    # final_data_gold/re_dataset for the header-level scoring to match — confirm
    # against 6a_prepare_final_deberta_data.py.
    gold_df = pd.read_csv(gold_csv_path)
    gold_df["date"] = pd.to_datetime(gold_df["date"])
    gold_df = gold_df.sort_values("date").reset_index(drop=True)

    max_length = 256

    # Column accumulators for the dataset that is assembled at the end.
    all_input_ids = []
    all_attention_masks = []
    all_labels = []
    all_route_ids = []
    all_direction_ids = []
    all_header_ids = []
    all_route_texts = []
    all_direction_texts = []

    for header_id, row in tqdm(gold_df.iterrows(), total=len(gold_df), desc="Generating E2E RE pairs"):
        text = str(row["header"]) if pd.notna(row["header"]) else ""
        if not text:
            continue

        # Tag one header at a time with the NER model.
        ner_inputs = ner_tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = ner_model(
                input_ids=ner_inputs["input_ids"].to(device),
                attention_mask=ner_inputs["attention_mask"].to(device),
            )
            pred_ids = torch.argmax(outputs.logits, dim=2)[0].cpu().tolist()

        # Headers without at least one predicted route and one predicted
        # direction yield no candidate pairs.
        pred_routes, pred_dirs = _decode_bio_spans(text, pred_ids, ner_tokenizer, ner_id2label)
        if not pred_routes or not pred_dirs:
            continue

        gold_route_spans = _parse_json_list(row.get("affected_spans_gold", "[]"))
        gold_direction_spans = _parse_json_list(row.get("direction_spans_gold", "[]"))
        gold_relations = _parse_json_list(row.get("relations_gold", "[]"))

        # Gold span id -> surface text, for routes and directions separately.
        route_text_by_id = {}
        for span in gold_route_spans:
            if isinstance(span, dict) and all(k in span for k in ("id", "start", "end")):
                route_text_by_id[span["id"]] = text[span["start"]:span["end"]]

        direction_text_by_id = {}
        for span in gold_direction_spans:
            if isinstance(span, dict) and all(k in span for k in ("id", "start", "end")):
                direction_text_by_id[span["id"]] = text[span["start"]:span["end"]]

        # Gold relations as normalized (route text, direction text) pairs;
        # relations that reference an unknown span id are skipped.
        gold_pair_texts = set()
        for rel in gold_relations:
            if not isinstance(rel, dict):
                continue
            route_text = route_text_by_id.get(rel.get("route_span_id"))
            direction_text = direction_text_by_id.get(rel.get("direction_span_id"))
            if route_text is None or direction_text is None:
                continue
            gold_pair_texts.add((_normalize_entity_text(route_text), _normalize_entity_text(direction_text)))

        # Every predicted route is paired with every predicted direction.
        for route_span in pred_routes:
            route_text = text[route_span["start"]:route_span["end"]]
            for direction_span in pred_dirs:
                direction_text = text[direction_span["start"]:direction_span["end"]]

                # The marked text is encoded with the RE tokenizer (`tokenizer`),
                # not the NER tokenizer.
                marked_text = _insert_entity_markers(text, route_span, direction_span)
                re_inputs = tokenizer(
                    marked_text,
                    max_length=max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors=None,
                )

                # A pair is positive (1) when its normalized texts match a gold relation.
                pair_key = (_normalize_entity_text(route_text), _normalize_entity_text(direction_text))
                label = 1 if pair_key in gold_pair_texts else 0

                all_input_ids.append(re_inputs["input_ids"])
                all_attention_masks.append(re_inputs["attention_mask"])
                all_labels.append(label)
                all_route_ids.append(route_span["id"])
                all_direction_ids.append(direction_span["id"])
                all_header_ids.append(int(header_id))
                all_route_texts.append(route_text)
                all_direction_texts.append(direction_text)

    if not all_labels:
        raise RuntimeError(
            "No predicted RE candidates were created. Check the NER checkpoint and gold data paths."
        )

    re_pred_test = Dataset.from_dict({
        "input_ids": all_input_ids,
        "attention_mask": all_attention_masks,
        "labels": all_labels,
        "route_id": all_route_ids,
        "direction_id": all_direction_ids,
        "header_id": all_header_ids,
        "route_text": all_route_texts,
        "direction_text": all_direction_texts,
    })

    # Only the test split holds data; train and validation are empty selections
    # with the same columns.
    empty = re_pred_test.select([])
    re_pred_splits = DatasetDict({
        "train": empty,
        "validation": empty,
        "test": re_pred_test,
    })

    os.makedirs("final_data_gold", exist_ok=True)
    re_pred_splits.save_to_disk(pred_re_path)

    print(f"Saved predicted-span RE dataset to {pred_re_path}")
    print({k: len(v) for k, v in re_pred_splits.items()})


## 10. End-to-End Evaluation (DeBERTa NER spans to RE)
Use predicted spans from the DeBERTa NER model instead of gold spans to approximate real end-to-end performance.

**Note:** This section scores predicted-span RE outputs against gold relation sets from `gold_test_dataset` (same `header_id` universe as Oracle).

In [None]:
# End-to-end setting: candidate pairs come from NER-predicted spans, while the
# reference relation sets still come from the gold test dataset.
e2e_eval_kwargs = dict(
    model=model,
    dataset=re_pred_splits["test"],
    gold_dataset=gold_test_dataset,
    device=device,
    label2id=label2id,
    desc="End-to-End (predicted spans, gold-header scoring)",
)
e2e_results = eval_header_level(**e2e_eval_kwargs)
print_re_results(e2e_results, "END-TO-END RE RESULTS (Predicted NER Spans → RE)")


## 11. Oracle vs End-to-End Comparison
Compare the header-level metrics between Oracle (gold entity spans) and End-to-End (predicted NER spans) evaluation.

In [None]:
# Side-by-side table of the header-level metrics from both evaluation settings.
print(f"\n{'Metric':<20} {'Oracle':>10} {'End-to-End':>12}")
print(f"{'-'*44}")
for row_label, metric_key in (
    ("Micro Precision", "micro_precision"),
    ("Micro Recall", "micro_recall"),
    ("Micro F1", "micro_f1"),
    ("Macro F1", "macro_f1"),
):
    print(f"{row_label:<20} {oracle_results[metric_key]:>10.4f} {e2e_results[metric_key]:>12.4f}")

# Sanity reminder printed under the table.
print("\n" + "="*60)
print("IMPORTANT: Oracle F1 should be >= E2E F1")
print("If E2E > Oracle, this indicates a measurement error.")
print("="*60)