# DeBERTa-v3 NER for MTA Service Alerts

This notebook implements a Named Entity Recognition (NER) pipeline using `microsoft/deberta-v3-base` to extract `ROUTE` and `DIRECTION` entities from MTA transit alerts.
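Training data is a pre-tokenized, BIO-labelled dataset derived from the existing span annotations in the silver dataset, loaded from `final_data/ner_dataset`.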

In [None]:
# Install required packages if missing
%pip install seqeval

In [None]:
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from datasets import load_from_disk
from seqeval.metrics import classification_report, f1_score
from tqdm.notebook import tqdm
import os

# Seed the global RNGs (torch covers head init and DataLoader shuffling; numpy for anything else) so runs repeat.
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

## 1. Data Loading and Preprocessing

In [19]:
# The dataset was tokenized and label-aligned ahead of time; load every split from disk.
dataset_path = "final_data/ner_dataset"
dataset = load_from_disk(dataset_path)

print("Loaded dataset from disk:", dataset, sep="\n")

  df['date'] = pd.to_datetime(df['date'])


Total rows: 226160


Unnamed: 0,header,affected_spans,direction_spans
0,A C trains are delayed while we conduct emerge...,"[{'start': 0, 'end': 1, 'type': 'ROUTE', 'valu...",[]
1,L trains are running with delays in both direc...,"[{'start': 0, 'end': 1, 'type': 'ROUTE', 'valu...","[{'start': 36, 'end': 51, 'type': 'DIRECTION',..."
2,Jamaica-bound J trains are delayed while we re...,"[{'start': 14, 'end': 15, 'type': 'ROUTE', 'va...","[{'start': 0, 'end': 13, 'type': 'DIRECTION', ..."


In [20]:
# Load the label <-> id mappings saved with the dataset so predictions decode to the
# same tags that were used when the labels were built.
with open("final_data/label_mappings.json", "r", encoding="utf-8") as f:
    mappings = json.load(f)

labels_to_ids = mappings["ner_label2id"]
# JSON object keys are always strings, so convert them back to int for id -> label lookups.
ids_to_labels = {int(k): v for k, v in mappings["ner_id2label"].items()}

# Sanity checks: the two maps must be exact inverses, and ids must be 0..N-1 because
# `num_labels=len(labels_to_ids)` and the class-weight loop index classes by position.
assert ids_to_labels == {v: k for k, v in labels_to_ids.items()}, "ner_label2id and ner_id2label disagree"
assert sorted(ids_to_labels) == list(range(len(ids_to_labels))), "label ids must be contiguous from 0"

print("Label Map:", labels_to_ids)

Label Map: {'O': 0, 'B-ROUTE': 1, 'I-ROUTE': 2, 'B-DIRECTION': 3, 'I-DIRECTION': 4}


## 2. Tokenizer & Tensor Format

In [21]:
# The dataset on disk is already tokenized, so the tokenizer is only needed for saving
# next to the model checkpoint and for inference on raw text.
tokenizer_path = "final_data/tokenizer"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Expose just the model inputs and targets, as torch tensors, to the DataLoaders.
tensor_columns = ['input_ids', 'attention_mask', 'labels']
dataset.set_format(type='torch', columns=tensor_columns)



## 3. Data Splits & Class Weights

In [None]:
# Pull out the three pre-made splits.
train_dataset, val_dataset, test_dataset = (dataset[split] for split in ("train", "validation", "test"))

print(f"Split sizes: Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

batch_size = 64

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

Train: 158312, Val: 33924, Test: 33924


In [None]:
# Compute class weights for label imbalance: count each tag in the training labels,
# then use boosted inverse-frequency weights in the loss.
num_classes = len(labels_to_ids)
label_count_tensor = torch.zeros(num_classes, dtype=torch.long)

print("Calculating class weights...")
# Batch through a DataLoader so counting is vectorized per chunk instead of per example.
for batch in tqdm(DataLoader(train_dataset, batch_size=256, num_workers=4), desc="Counting labels"):
    labels = batch['labels'].flatten()
    # -100 marks positions the loss ignores (CrossEntropyLoss's default ignore_index); don't count them.
    valid_labels = labels[labels != -100]
    batch_counts = torch.bincount(valid_labels, minlength=num_classes)
    if batch_counts.numel() > num_classes:
        raise ValueError(f"Training labels contain an id >= {num_classes}; check label_mappings.json")
    label_count_tensor += batch_counts

label_counts = {label_id: int(count) for label_id, count in enumerate(label_count_tensor.tolist())}
total_counts = sum(label_counts.values())

# Extra multipliers on top of inverse frequency. Keyed by tag NAME (not numeric id) so
# they stay attached to the right tags even if the id order in label_mappings.json changes.
boost_factors = {
    "O": 1.0,            # no boost
    "B-ROUTE": 1.2,      # slight boost
    "I-ROUTE": 1.5,      # boost for continuity
    "B-DIRECTION": 1.5,  # boost for rare starts
    "I-DIRECTION": 1.5,  # boost for continuity
}

class_weights = []
for label_id in range(num_classes):
    count = label_counts[label_id]
    if count > 0:
        # Balanced inverse frequency, total / (num_classes * count), then the per-tag boost.
        weight = total_counts / (num_classes * count) * boost_factors.get(ids_to_labels[label_id], 1.0)
    else:
        # Tag never seen in training: stay neutral instead of dividing by zero.
        weight = 1.0
    class_weights.append(weight)

class_weights = torch.tensor(class_weights, dtype=torch.float)
print("\nLabel distribution in training set:")
for label_name, label_id in labels_to_ids.items():
    count = label_counts[label_id]
    pct = (count / total_counts * 100) if total_counts > 0 else 0
    print(f"  {label_name}: {count:,} ({pct:.2f}%) - weight: {class_weights[label_id]:.3f}")
print(f"\nClass Weights: {class_weights}")

Calculating class weights...


  0%|          | 0/158312 [00:00<?, ?it/s]

Class Weights: tensor([ 0.2196, 23.9177,  9.2577, 10.0748,  5.0693])


## 4. Model Configuration & Training Setup

In [24]:
# Pick the fastest available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Pretrained DeBERTa-v3 encoder plus a newly initialised token-classification head sized to our tag set.
model = AutoModelForTokenClassification.from_pretrained('microsoft/deberta-v3-base', num_labels=len(labels_to_ids))
model.to(device)

# Discriminative learning rates: a small LR for the pretrained encoder, a larger one for the fresh head.
encoder_lr, head_lr = 3e-5, 1e-4
optimizer_grouped_parameters = [
    {'params': model.deberta.parameters(), 'lr': encoder_lr},
    {'params': model.classifier.parameters(), 'lr': head_lr},
]
optimizer = AdamW(optimizer_grouped_parameters)

# Linear decay schedule with warmup over the first 10% of optimisation steps.
epochs = 3  # kept small for a demo run; increase for a full training run
total_steps = epochs * len(train_loader)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

# Class-weighted cross-entropy; positions labelled -100 are skipped by its default ignore_index.
loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))

Using device: cuda


Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 5. Training Loop

In [25]:
from torch.cuda.amp import GradScaler

def train_epoch(model, data_loader, optimizer, scheduler, device, loss_fct):
    model.train()
    total_loss = 0
    scaler = GradScaler()  # Initialize gradient scaler for FP16

    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        # Runs the forward pass with autocasting (FP16)
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Reshape logits and labels for loss calculation
            active_logits = logits.view(-1, len(labels_to_ids))
            active_labels = labels.view(-1)

            loss = loss_fct(active_logits, active_labels)

        total_loss += loss.item()

        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Optimizer's step() is skipped if gradients contain NaNs/Infs
        scaler.step(optimizer)

        # Updates the scale for next iteration
        scaler.update()
        scheduler.step()

    return total_loss / len(data_loader)

import contextlib


def eval_epoch(model, data_loader, device, loss_fct):
    """Evaluate ``model`` on ``data_loader`` and collect tag sequences for seqeval.

    FP16 autocast is used only on CUDA devices.

    Args:
        model: token-classification model whose forward output has ``.logits``.
        data_loader: iterable of dicts with 'input_ids', 'attention_mask' and 'labels' tensors.
        device: ``torch.device`` the model lives on.
        loss_fct: loss taking flattened ``(N, num_labels)`` logits and ``(N,)`` labels.

    Returns:
        ``(mean_loss, all_labels, all_preds)``. The last two are lists (one per example) of
        tag-name lists, restricted to positions whose gold label is not -100. mean_loss is
        0.0 for an empty loader.
    """
    model.eval()
    use_amp = device.type == "cuda"
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            autocast_ctx = (
                torch.amp.autocast(device_type="cuda", dtype=torch.float16)
                if use_amp else contextlib.nullcontext()
            )
            with autocast_ctx:
                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

            total_loss += loss.item()
            num_batches += 1

            # Copy predictions and labels to the host once per batch. Per-token .item()
            # calls on GPU tensors force a device sync for every token.
            preds_cpu = logits.argmax(dim=-1).cpu()
            labels_cpu = labels.cpu()
            for label_row, pred_row in zip(labels_cpu, preds_cpu):
                keep = label_row != -100  # drop positions excluded from scoring
                all_labels.append([ids_to_labels[i] for i in label_row[keep].tolist()])
                all_preds.append([ids_to_labels[i] for i in pred_row[keep].tolist()])

    return total_loss / max(num_batches, 1), all_labels, all_preds

## 6. Training Run (checkpointing and early stopping)

In [26]:
# Train for `epochs` epochs, checkpointing model + tokenizer whenever validation
# entity-level F1 improves. Stop once `patience` consecutive epochs fail to improve.
# NOTE(review): with patience == epochs (3 and 3), early stopping cannot shorten this run.
best_f1 = 0
patience = 3
patience_counter = 0

for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")

    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, loss_fct)
    print(f"Train Loss: {train_loss:.4f}")

    val_loss, val_labels, val_preds = eval_epoch(model, val_loader, device, loss_fct)
    f1, report = f1_score(val_labels, val_preds), classification_report(val_labels, val_preds)

    for line in (f"Val Loss: {val_loss:.4f}", f"Val F1: {f1:.4f}", "Classification Report:", report):
        print(line)

    improved = f1 > best_f1
    if improved:
        best_f1, patience_counter = f1, 0
        for artifact in (model, tokenizer):
            artifact.save_pretrained("models/deberta_ner_best")
        print("New best model saved!")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")

    if patience_counter >= patience:
        print("Early stopping triggered.")
        break


Epoch 1/3


  scaler = GradScaler()  # Initialize gradient scaler for FP16


Training:   0%|          | 0/2474 [00:00<?, ?it/s]

Train Loss: 0.0644


Evaluating:   0%|          | 0/531 [00:00<?, ?it/s]

Val Loss: 0.0013
Val F1: 0.9939
Classification Report:
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     28456
       ROUTE       0.98      1.00      0.99     22878

   micro avg       0.99      1.00      0.99     51334
   macro avg       0.99      1.00      0.99     51334
weighted avg       0.99      1.00      0.99     51334

New best model saved!

Epoch 2/3


Training:   0%|          | 0/2474 [00:00<?, ?it/s]

KeyboardInterrupt: 

## 7. Inference Test

In [None]:
def predict_ner(text, model, tokenizer, device):
    """Tag one string and return its non-"O" tokens with their predicted labels.

    Args:
        text: raw alert text.
        model: token-classification model whose forward output has ``.logits``.
        tokenizer: the tokenizer the model was trained with.
        device: device the model lives on.

    Returns:
        List of ``(token, label)`` pairs, in order, for every non-special token whose
        predicted tag is not "O". Tokens are the subword pieces from
        ``tokenizer.convert_ids_to_tokens``; pieces of one entity are not merged.
    """
    model.eval()
    # Ask the tokenizer which positions are special tokens instead of matching hard-coded
    # strings like "[CLS]"/"[SEP]", so this works with any tokenizer's special tokens.
    inputs = tokenizer(text, return_tensors="pt", return_special_tokens_mask=True)
    # special_tokens_mask is not a model input, so remove it before the forward call.
    special_mask = inputs.pop("special_tokens_mask")[0].bool().tolist()
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    preds = logits.argmax(dim=-1)[0].tolist()

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    entities = []
    for token, label_id, is_special in zip(tokens, preds, special_mask):
        label = ids_to_labels[label_id]
        if label != "O" and not is_special:
            entities.append((token, label))

    return entities

# Quick sanity check on one alert header, using the in-memory model.
test_text = "Jamaica-bound J trains are delayed"
entities = predict_ner(test_text, model, tokenizer, device)
print(entities)

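## 8. Test-Set Evaluation

**Note (recorded run):**
- The test F1 (0.5953) is far below the validation F1 reached during training (0.9939).
- The test split has more than twice the ROUTE support of the validation split (53,092 vs 22,878 entities), even though both splits contain 33,924 examples.
- Before trusting either number, verify that the test labels were generated the same way as the validation labels.
- The test loss below is unweighted cross-entropy, while the validation loss printed during training is class-weighted, so the two losses are not directly comparable.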

In [28]:
from seqeval.metrics import precision_score, recall_score

# Reload the best checkpoint written by the training loop, rather than using the in-memory `model`.
best_dir = "models/deberta_ner_best"
eval_tokenizer = AutoTokenizer.from_pretrained(best_dir)
eval_model = AutoModelForTokenClassification.from_pretrained(best_dir)
eval_model.to(device)

# Rebuild the test loader (single-process, no pinned memory) over the same on-disk test split.
test_loader = DataLoader(test_dataset, batch_size=batch_size)

def evaluate_ner(model, data_loader, device):
    """Compute entity-level test metrics (seqeval) and the mean loss for ``model``.

    The loss is UNWEIGHTED cross-entropy. It is not directly comparable to the
    class-weighted validation loss printed during training.

    Args:
        model: token-classification model whose forward output has ``.logits``.
        data_loader: sized iterable of dicts with 'input_ids', 'attention_mask', 'labels'.
        device: device the model lives on.

    Returns:
        Dict with keys "loss", "precision", "recall", "f1" and "report". Sequences only
        include positions whose gold label is not -100.
    """
    loss_fct_eval = torch.nn.CrossEntropyLoss(ignore_index=-100)
    all_true = []
    all_pred = []
    total_loss = 0.0
    model.eval()

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating test"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = loss_fct_eval(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_loss += loss.item()

            # One host copy per batch; per-token .item() on GPU tensors syncs every token.
            preds_cpu = logits.argmax(dim=-1).cpu()
            labels_cpu = labels.cpu()
            for label_row, pred_row in zip(labels_cpu, preds_cpu):
                keep = label_row != -100
                all_true.append([ids_to_labels[i] for i in label_row[keep].tolist()])
                all_pred.append([ids_to_labels[i] for i in pred_row[keep].tolist()])

    num_batches = max(len(data_loader), 1)
    return {
        "loss": total_loss / num_batches,
        "precision": precision_score(all_true, all_pred),
        "recall": recall_score(all_true, all_pred),
        "f1": f1_score(all_true, all_pred),
        "report": classification_report(all_true, all_pred)
    }


# Score the reloaded best checkpoint on the held-out test split.
test_metrics = evaluate_ner(eval_model, test_loader, device)
for display_name, metric_key in (("loss", "loss"), ("precision", "precision"), ("recall", "recall"), ("F1", "f1")):
    print(f"Test {display_name}: {test_metrics[metric_key]:.4f}")
print(test_metrics["report"])


Evaluating test:   0%|          | 0/531 [00:00<?, ?it/s]

Test loss: 0.3153
Test precision: 0.5030
Test recall: 0.7290
Test F1: 0.5953
              precision    recall  f1-score   support

   DIRECTION       0.53      0.44      0.48     30733
       ROUTE       0.50      0.89      0.64     53092

   micro avg       0.50      0.73      0.60     83825
   macro avg       0.51      0.67      0.56     83825
weighted avg       0.51      0.73      0.58     83825

