
# DeBERTa-v3 Relation Extraction for MTA Alerts

Train a DeBERTa-v3 sequence classifier on entity-marked route–direction pairs from the silver dataset. Then evaluate it on the silver test split, on gold annotations, and end-to-end on pairs built from predicted NER spans.


In [1]:

# Install required packages if missing
%pip install seqeval


Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16162 sha256=6dadc77dc338450559908ca088f73316e23a4c864e1a32825cd347df9331f7f2
  Stored in directory: /root/.cache/pip/wheels/5f/b8/73/0b2c1a76b701a677653dd79ece07cfabd7457989dbfbdcd8d7
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2


In [None]:

# Run data preparation script (generates NER + RE datasets and tokenizer if missing).
# NOTE(review): %run executes the script inside this kernel and afterwards copies the
# script's top-level names into the notebook namespace, so later cells can silently depend
# on them. For example, `re_pred_splits` (used in the end-to-end section) is not defined by
# any visible cell; it may come from this script or from a deleted cell — confirm.
%run 6a_prepare_final_deberta_data.py


In [5]:

import json
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
from datasets import load_from_disk
from sklearn.metrics import classification_report, precision_recall_fscore_support
from tqdm.notebook import tqdm

# Reproducibility: seed every RNG the notebook may touch (Python, NumPy, PyTorch).
# torch.manual_seed seeds the generators on all devices, including CUDA.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)



## 1. Load Prepared RE Data
Load the pre-tokenized relation extraction dataset (silver) and label mappings.


In [13]:

re_path = "final_data/re_dataset"
dataset = load_from_disk(re_path)
print("Loaded RE dataset from disk:")
print(dataset)

with open("final_data/label_mappings.json", "r", encoding="utf-8") as f:
    mappings = json.load(f)

label2id = mappings["re_label2id"]
# JSON object keys are always strings, so the id -> label map needs its keys cast back to int.
id2label = {int(k): v for k, v in mappings["re_id2label"].items()}

# Fail fast on inconsistent mappings rather than mislabelled metrics later on.
if {v: k for k, v in id2label.items()} != label2id:
    raise ValueError(f"re_label2id {label2id} and re_id2label {id2label} are not inverses of each other")
if "HAS_DIRECTION" not in label2id:
    # eval_epoch reports binary metrics with HAS_DIRECTION as the positive class.
    raise ValueError(f"Expected a 'HAS_DIRECTION' relation label, got {sorted(label2id)}")
print("Relation Label Map:", label2id)


Loaded RE dataset from disk:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'route_id', 'direction_id'],
        num_rows: 281017
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'route_id', 'direction_id'],
        num_rows: 60218
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'route_id', 'direction_id'],
        num_rows: 60218
    })
})
Relation Label Map: {'NO_RELATION': 0, 'HAS_DIRECTION': 1}


In [14]:

# Expose only the model inputs and target as torch tensors. The route_id/direction_id
# metadata columns remain in the dataset but are not returned by indexing/DataLoaders.
torch_cols = ['input_ids', 'attention_mask', 'labels']
for split_name in ('train', 'validation', 'test'):
    if split_name not in dataset:
        continue
    dataset[split_name].set_format(type='torch', columns=torch_cols)



## 2. Tokenizer and Gold Dataset
Load tokenizer (with special markers) and the gold RE dataset for evaluation.


In [15]:

# Tokenizer: the prepared tokenizer directory contains the added entity-marker tokens,
# which is why the model's embeddings are later resized to len(tokenizer).
model_name = "microsoft/deberta-v3-base"
tokenizer_path = "final_data/tokenizer"
# NOTE(review): `fix_mistral_regex` targets Mistral-style fast tokenizers and looks irrelevant
# for the slow (use_fast=False) DeBERTa-v3 tokenizer requested here — confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, fix_mistral_regex=True)
print(f"Tokenizer loaded with vocab size {len(tokenizer):,}")

# Gold dataset (all samples in test split)
gold_re_path = "final_data_gold/re_dataset"
if not os.path.exists(gold_re_path):
    import subprocess
    import sys

    print(f"Gold RE dataset not found at {gold_re_path}. Running data prep...")
    # Use the kernel's own interpreter: a bare "python3" on PATH may belong to a different
    # environment that lacks the packages this notebook has installed.
    subprocess.run([sys.executable, "6a_prepare_final_deberta_data.py"], check=True)
    if not os.path.exists(gold_re_path):
        raise FileNotFoundError(
            f"Data prep finished but {gold_re_path} was not created; "
            "check 6a_prepare_final_deberta_data.py"
        )

gold_dataset = load_from_disk(gold_re_path)
print("Gold RE dataset splits:", {k: len(v) for k, v in gold_dataset.items()})
gold_test_dataset = gold_dataset["test"]


Tokenizer loaded with vocab size 128,005
Gold RE dataset splits: {'train': 0, 'validation': 0, 'test': 20781}


In [16]:

# set_format mutates each Dataset in place, so gold_test_dataset (the same object as
# gold_dataset["test"]) also yields torch tensors after this cell.
gold_torch_cols = ['input_ids', 'attention_mask', 'labels']
for split_name in ('train', 'validation', 'test'):
    if split_name in gold_dataset:
        gold_dataset[split_name].set_format(type='torch', columns=gold_torch_cols)



## 3. Data Splits & DataLoaders
Use the silver train/val/test splits and a dedicated gold test loader.


In [17]:

# Silver splits
train_dataset, val_dataset, test_dataset = (dataset[name] for name in ("train", "validation", "test"))

print(f"Split sizes (silver) -> Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"Gold test size: {len(gold_test_dataset)}")

train_batch_size = 128
eval_batch_size = 64

# Settings shared by every loader; only the training loader shuffles.
loader_kwargs = dict(num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, **loader_kwargs)
val_loader, test_loader, gold_loader = (
    DataLoader(eval_ds, batch_size=eval_batch_size, **loader_kwargs)
    for eval_ds in (val_dataset, test_dataset, gold_test_dataset)
)


Split sizes (silver) -> Train: 281017, Val: 60218, Test: 60218
Gold test size: 20781



## 4. Class Weights for Imbalance
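Compute inverse-frequency class weights from the training labels, $w_c = \frac{N}{K \cdot n_c}$. Here $N$ is the number of training pairs, $K$ the number of relation classes, and $n_c$ the count of class $c$. The weights are passed to the cross-entropy loss.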


In [18]:

# Inverse-frequency class weights from the training labels.
num_classes = len(label2id)
# Weights are indexed by position below, so ids must be exactly 0..num_classes-1.
if sorted(label2id.values()) != list(range(num_classes)):
    raise ValueError(f"Expected contiguous relation ids 0..{num_classes - 1}, got {label2id}")

print("Calculating class weights...")
# torch.bincount per batch replaces a Python-level loop over every individual label.
count_tensor = torch.zeros(num_classes, dtype=torch.long)
for batch in tqdm(DataLoader(train_dataset, batch_size=512, num_workers=4), desc="Counting labels"):
    labels = batch['labels'].view(-1).long()
    if labels.numel() and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(f"Found label ids outside 0..{num_classes - 1} in the training set")
    count_tensor += torch.bincount(labels, minlength=num_classes)
label_counts = {i: int(c) for i, c in enumerate(count_tensor.tolist())}

total = sum(label_counts.values())
# w_c = total / (num_classes * count_c). A class absent from training is treated as count 1
# instead of raising ZeroDivisionError; its weight cannot affect the training loss because
# no training target carries that id.
class_weights = torch.tensor(
    [total / (num_classes * max(label_counts[i], 1)) for i in range(num_classes)],
    dtype=torch.float,
)

print("Label distribution (train):")
for name, idx in label2id.items():
    count = label_counts[idx]
    pct = (count / total * 100) if total else 0
    print(f"  {name}: {count:,} ({pct:.2f}%) weight={class_weights[idx]:.3f}")


Calculating class weights...


Counting labels:   0%|          | 0/549 [00:00<?, ?it/s]

Label distribution (train):
  NO_RELATION: 92,215 (32.81%) weight=1.524
  HAS_DIRECTION: 188,802 (67.19%) weight=0.744



## 5. Model & Optimizer
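Load DeBERTa-v3 for sequence classification and resize its token embeddings to cover the entity-marker tokens added to the tokenizer. Then set up the AdamW optimizer, a linear warmup/decay learning-rate schedule, and the class-weighted loss.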


In [19]:

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Passing the label maps stores them in model.config, so the checkpoint saved later carries
# the real relation names instead of the default LABEL_0 / LABEL_1.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
# Resize embeddings to account for special tokens added in tokenizer
tokenizer_size = len(tokenizer)
model.resize_token_embeddings(tokenizer_size)
model.to(device)

learning_rate = 3e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)

epochs = 3
warmup_ratio = 0.1
# One scheduler step per training batch; max(..., 1) guards against an empty loader.
total_steps = max(len(train_loader) * epochs, 1)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(warmup_ratio * total_steps),
    num_training_steps=total_steps,
)

# Class-weighted cross-entropy to counter the label imbalance measured above.
loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))


Using device: cuda


Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:

from torch.cuda.amp import GradScaler

def train_epoch(model, data_loader, optimizer, scheduler, device, loss_fct, scaler=None, max_grad_norm=1.0):
    """Run one training pass over ``data_loader`` and return the mean per-batch loss.

    Args:
        model: sequence-classification model whose forward output has ``.logits``.
        data_loader: iterable of batches with 'input_ids', 'attention_mask' and 'labels' tensors.
        optimizer: optimizer stepped once per batch (through the GradScaler).
        scheduler: learning-rate scheduler stepped once per batch.
        device: device the batches are moved to; should match the model's device.
        loss_fct: criterion called as ``loss_fct(logits, labels)``.
        scaler: optional GradScaler to reuse across epochs. When None, a fresh one is
            created for this call, which is the previous behaviour.
        max_grad_norm: threshold for gradient-norm clipping.

    Returns:
        Mean of the per-batch losses (0.0 for an empty loader).
    """
    model.train()
    # fp16 autocast and loss scaling only apply when the batches actually go to a CUDA device.
    use_amp = torch.device(device).type == "cuda"
    if scaler is None:
        # A disabled GradScaler is a pass-through, so the code below works the same on CPU/MPS.
        scaler = GradScaler(enabled=use_amp)
    total_loss = 0.0
    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        total_loss += loss.item()
        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true (unscaled) gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
    return total_loss / max(len(data_loader), 1)


def eval_epoch(model, data_loader, device, loss_fct):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    Args:
        model: sequence-classification model whose forward output has ``.logits``.
        data_loader: iterable of batches with 'input_ids', 'attention_mask' and 'labels' tensors.
        device: device the batches are moved to; should match the model's device.
        loss_fct: criterion called as ``loss_fct(logits, labels)``.

    Returns:
        ``(avg_loss, precision, recall, f1, report)``. precision/recall/F1 are binary metrics
        with ``HAS_DIRECTION`` as the positive class. ``report`` is sklearn's per-class
        classification report (an empty string when the loader yields no examples).
    """
    model.eval()
    # fp16 autocast only when the batches actually go to a CUDA device.
    use_amp = torch.device(device).type == "cuda"
    label_ids = list(range(len(label2id)))
    total_loss = 0.0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_loss += loss.item()

            preds = torch.argmax(logits, dim=-1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    avg_loss = total_loss / max(len(data_loader), 1)
    if not all_labels:
        # Nothing to score; sklearn would otherwise raise on the empty inputs.
        return avg_loss, 0.0, 0.0, 0.0, ""

    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='binary', pos_label=label2id['HAS_DIRECTION'], zero_division=0
    )
    # Passing `labels` keeps the report valid when some class is absent from both y_true and
    # y_pred; without it sklearn raises because target_names has more entries than classes.
    report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=[id2label[i] for i in label_ids],
        zero_division=0,
    )
    return avg_loss, precision, recall, f1, report



## 6. Training with Early Stopping
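After each epoch, compute validation F1 for the `HAS_DIRECTION` class and keep an in-memory copy of the best weights. Training stops early if F1 fails to improve for `patience` consecutive epochs, and the best weights are restored at the end.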


In [21]:

# Early-stopping state: best validation F1 so far and a CPU snapshot of the weights that reached it.
# NOTE(review): with patience (3) >= epochs (3), early stopping can at most fire on the final
# epoch, so it never shortens training; lower patience or raise epochs if that is intended.
best_f1, best_state = 0.0, None
patience, patience_counter = 3, 0

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, loss_fct)
    val_loss, val_prec, val_rec, val_f1, val_report = eval_epoch(model, val_loader, device, loss_fct)

    print(f"Train loss: {train_loss:.4f}")
    print(f"Val -> loss: {val_loss:.4f}, precision: {val_prec:.4f}, recall: {val_rec:.4f}, f1: {val_f1:.4f}")
    print(val_report)

    improved = val_f1 > best_f1
    if not improved:
        patience_counter += 1
        print(f"No improvement. Patience {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
        continue

    best_f1, patience_counter = val_f1, 0
    best_state = {param_name: tensor.cpu().clone() for param_name, tensor in model.state_dict().items()}
    print(f"New best model (F1={best_f1:.4f})")

# Restore the best snapshot (if any epoch improved on the initial 0.0 F1).
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Loaded best model with val F1={best_f1:.4f}")


Epoch 1/3


Training:   0%|          | 0/2196 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/941 [00:00<?, ?it/s]

Train loss: 0.0524
Val -> loss: 0.0182, precision: 0.9986, recall: 0.9973, f1: 0.9979
               precision    recall  f1-score   support

  NO_RELATION       0.98      0.99      0.99      8118
HAS_DIRECTION       1.00      1.00      1.00     52100

     accuracy                           1.00     60218
    macro avg       0.99      0.99      0.99     60218
 weighted avg       1.00      1.00      1.00     60218

New best model (F1=0.9979)
Epoch 2/3


Training:   0%|          | 0/2196 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/941 [00:00<?, ?it/s]

Train loss: 0.0081
Val -> loss: 0.0144, precision: 0.9987, recall: 0.9973, f1: 0.9980
               precision    recall  f1-score   support

  NO_RELATION       0.98      0.99      0.99      8118
HAS_DIRECTION       1.00      1.00      1.00     52100

     accuracy                           1.00     60218
    macro avg       0.99      0.99      0.99     60218
 weighted avg       1.00      1.00      1.00     60218

New best model (F1=0.9980)
Epoch 3/3


Training:   0%|          | 0/2196 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/941 [00:00<?, ?it/s]

Train loss: 0.0052
Val -> loss: 0.0136, precision: 0.9988, recall: 0.9975, f1: 0.9982
               precision    recall  f1-score   support

  NO_RELATION       0.98      0.99      0.99      8118
HAS_DIRECTION       1.00      1.00      1.00     52100

     accuracy                           1.00     60218
    macro avg       0.99      0.99      0.99     60218
 weighted avg       1.00      1.00      1.00     60218

New best model (F1=0.9982)
Loaded best model with val F1=0.9982



## 7. Save Model
Persist the relation extractor and tokenizer.


In [22]:

# Write the (best) model weights/config and the tokenizer side by side in one directory.
save_dir = "models/deberta_re_best"
os.makedirs(save_dir, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)
print(f"Saved best model to {save_dir}")


Saved best model to models/deberta_re_best



## 8. Evaluation on Silver Test Split
Compute pair-level precision/recall/F1 on held-out silver data.


In [23]:

# Score the restored best model on the held-out silver test split.
test_loss, test_prec, test_rec, test_f1, test_report = eval_epoch(model, test_loader, device, loss_fct)
print("SILVER TEST RESULTS")
for metric_name, metric_value in (
    ("Loss", test_loss),
    ("Precision", test_prec),
    ("Recall", test_rec),
    ("F1", test_f1),
):
    print(f"{metric_name}: {metric_value:.4f}")
print(test_report)


Evaluating:   0%|          | 0/941 [00:00<?, ?it/s]

SILVER TEST RESULTS
Loss: 0.0055
Precision: 0.9997
Recall: 0.9990
F1: 0.9993
               precision    recall  f1-score   support

  NO_RELATION       0.99      1.00      1.00      7646
HAS_DIRECTION       1.00      1.00      1.00     52572

     accuracy                           1.00     60218
    macro avg       1.00      1.00      1.00     60218
 weighted avg       1.00      1.00      1.00     60218




## 9. Evaluation on Gold Dataset
Evaluate the silver-trained model on gold labels (pair-level metrics).


In [24]:

# Release unused cached CUDA memory left over from training before the gold pass.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

gold_loss, gold_prec, gold_rec, gold_f1, gold_report = eval_epoch(model, gold_loader, device, loss_fct)
print("GOLD TEST RESULTS")
gold_metrics = {"Loss": gold_loss, "Precision": gold_prec, "Recall": gold_rec, "F1": gold_f1}
for metric_name, metric_value in gold_metrics.items():
    print(f"{metric_name}: {metric_value:.4f}")
print(gold_report)


Evaluating:   0%|          | 0/325 [00:00<?, ?it/s]

GOLD TEST RESULTS
Loss: 0.2463
Precision: 0.9268
Recall: 0.7921
F1: 0.8542
               precision    recall  f1-score   support

  NO_RELATION       0.96      0.99      0.97     17486
HAS_DIRECTION       0.93      0.79      0.85      3295

     accuracy                           0.96     20781
    macro avg       0.94      0.89      0.91     20781
 weighted avg       0.96      0.96      0.96     20781




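## 10. End-to-End Evaluation (DeBERTa NER spans to RE)
Use predicted spans from the DeBERTa NER model instead of gold spans to approximate real end-to-end performance. This cell expects `re_pred_splits`, the RE dataset built from the NER model's predicted spans, to already be defined in the kernel. No earlier cell in this notebook creates it.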


In [30]:

# `re_pred_splits` (RE examples built from the DeBERTa NER model's predicted spans, indexed by
# split name) is not created by any earlier cell in this notebook. Fail with an actionable
# message instead of a bare NameError on a fresh kernel.
if "re_pred_splits" not in globals():
    raise NameError(
        "re_pred_splits is not defined. Build the RE dataset from predicted NER spans "
        "before running the end-to-end evaluation."
    )

# Same torch view as the silver/gold splits. with_format returns a new formatted view instead
# of mutating re_pred_splits in place.
re_pred_test = re_pred_splits['test'].with_format(
    type='torch', columns=['input_ids', 'attention_mask', 'labels']
)
re_pred_loader = DataLoader(re_pred_test, batch_size=eval_batch_size, num_workers=4, pin_memory=True)

# Evaluate end-to-end on predicted spans (pair-level metrics)
e2e_loss, e2e_prec, e2e_rec, e2e_f1, e2e_report = eval_epoch(model, re_pred_loader, device, loss_fct)
print("END-TO-END RE RESULTS (DeBERTa NER spans → RE)")
print(f"Loss: {e2e_loss:.4f}")
print(f"Precision: {e2e_prec:.4f}")
print(f"Recall: {e2e_rec:.4f}")
print(f"F1: {e2e_f1:.4f}")
print(e2e_report)


Evaluating:   0%|          | 0/326 [00:00<?, ?it/s]

END-TO-END RE RESULTS (DeBERTa NER spans → RE)
Loss: 0.3634
Precision: 0.7886
Recall: 0.9785
F1: 0.8733
               precision    recall  f1-score   support

  NO_RELATION       1.00      0.96      0.98     18395
HAS_DIRECTION       0.79      0.98      0.87      2466

     accuracy                           0.97     20861
    macro avg       0.89      0.97      0.93     20861
 weighted avg       0.97      0.97      0.97     20861

