# DeBERTa-v3 NER for MTA Service Alerts

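This notebook fine-tunes `microsoft/deberta-v3-base` for Named Entity Recognition (NER), extracting ROUTE and
DIRECTION entities from MTA service alerts.
It trains on existing span annotations from the silver dataset and evaluates on both the held-out silver test split and a human-labeled gold set (in the recorded run: entity-level F1 ≈ 0.99 on silver test vs ≈ 0.85 on gold).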

In [2]:
# Install required packages if missing
%pip install seqeval

Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16162 sha256=018e833092b3a9b5faa5fc615ef0aa2d850ea725f447ee1b157fb5cdfb21f0ae
  Stored in directory: /root/.cache/pip/wheels/5f/b8/73/0b2c1a76b701a677653dd79ece07cfabd7457989dbfbdcd8d7
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2


In [None]:
# Run the data-preparation script (silver training data + gold evaluation data).
# `%run` executes it inside this kernel; the captured output shows it unzipping the
# preprocessed CSV archives. Presumably it also writes final_data/ and final_data_gold/,
# which later cells load — TODO confirm against 6a_prepare_final_deberta_data.py.
%run 6a_prepare_final_deberta_data.py

Archive:  Preprocessed/Preprocessed.zip
   creating: Preprocessed/Preprocessed/
  inflating: Preprocessed/Preprocessed/MTA_Data_Final_Gold.csv  
  inflating: Preprocessed/Preprocessed/MTA_Data_preprocessed.csv  
Archive:  Preprocessed/Prep_2.zip
  inflating: Preprocessed/MTA_Data_gold_dataset_unfilled.csv  
  inflating: Preprocessed/MTA_Data_silver_relations.csv  


In [3]:
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from datasets import load_from_disk
from seqeval.metrics import classification_report, f1_score
from tqdm.notebook import tqdm
import os

# Seed the global torch and NumPy RNGs (classifier-head init, DataLoader shuffling, dropout)
# so runs are repeatable; CUDA kernels may still introduce small nondeterminism.
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

## 1. Data Loading and Preprocessing

In [4]:
# Pre-tokenized DatasetDict with train/validation/test splits (presumably written by the
# data-prep script run above).
dataset_path = "final_data/ner_dataset"
dataset = load_from_disk(dataset_path)

print("Loaded dataset from disk:", dataset, sep="\n")

Loaded dataset from disk:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 158312
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 33924
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 33924
    })
})


In [5]:
# Load the label <-> id mappings (presumably written by the data-prep step).
# JSON is UTF-8 by specification, so do not rely on the platform's default encoding.
with open("final_data/label_mappings.json", "r", encoding="utf-8") as f:
    mappings = json.load(f)

labels_to_ids = mappings["ner_label2id"]
# JSON object keys are always strings; convert them back to int for id -> label lookup.
ids_to_labels = {int(k): v for k, v in mappings["ner_id2label"].items()}

# The two maps must be exact inverses, otherwise predictions are decoded with the wrong tags.
if {v: k for k, v in ids_to_labels.items()} != labels_to_ids:
    raise ValueError("ner_label2id and ner_id2label in label_mappings.json are not inverses")

print("Label Map:", labels_to_ids)

Label Map: {'O': 0, 'B-ROUTE': 1, 'I-ROUTE': 2, 'B-DIRECTION': 3, 'I-DIRECTION': 4}


## 2. Tokenizer & Tensor Format

In [None]:
# The dataset on disk is already tokenized, so only switch it to yield torch tensors
# for the three model-input columns.
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Slow (non-Rust) tokenizer from the local prep output; it is needed later to save
# checkpoints next to the model and to tokenize raw text for inference.
tokenizer_path = "final_data/tokenizer"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False)

In [7]:
# Load the gold (human-labeled) dataset, used only for final evaluation.
# If it is missing, regenerate it by running the data-prep script in a subprocess.
import os
import subprocess
import sys

gold_dataset_path = "final_data_gold/ner_dataset"

if not os.path.exists(gold_dataset_path):
    print(f"Gold dataset not found at {gold_dataset_path}. Running data prep to generate it...")
    # Use the kernel's own interpreter: a bare "python3" can resolve to a different
    # environment that lacks the packages the prep script needs.
    subprocess.run([sys.executable, "6a_prepare_final_deberta_data.py"], check=True)

if not os.path.exists(gold_dataset_path):
    raise FileNotFoundError(f"Gold dataset still missing at {gold_dataset_path} after data prep run.")

gold_dataset = load_from_disk(gold_dataset_path)
gold_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

gold_lengths = {split: len(gold_dataset[split]) for split in ["train", "validation", "test"]}
print(f"Gold dataset splits (train/val empty, all gold in test): {gold_lengths}")

# All gold examples are expected in the test split; an empty split would silently
# produce meaningless gold metrics later.
if gold_lengths["test"] == 0:
    raise ValueError(f"Gold test split at {gold_dataset_path} is empty.")

gold_test_dataset = gold_dataset["test"]


Gold dataset splits (train/val empty, all gold in test): {'train': 0, 'validation': 0, 'test': 600}


## 3. Data Splits & Class Weights

In [8]:
# Access the silver splits.
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
test_dataset = dataset["test"]

print(f"Split sizes (silver): Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
print(f"Gold total size (all gold in test): {len(gold_test_dataset)}")

train_batch_size = 128
eval_batch_size = 32  # smaller than training: a larger eval batch ran out of CUDA memory

# Pinned host memory only speeds up host -> CUDA copies. On CPU/MPS it brings no benefit
# (recent PyTorch versions warn about it), so enable it only when CUDA is available.
loader_kwargs = dict(num_workers=4, pin_memory=torch.cuda.is_available())

# Create DataLoaders; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, batch_size=eval_batch_size, **loader_kwargs)
test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, **loader_kwargs)

gold_loader = DataLoader(gold_test_dataset, batch_size=eval_batch_size, **loader_kwargs)


Split sizes (silver): Train: 158312, Val: 33924, Test: 33924
Gold total size (all gold in test): 600


In [9]:
# Compute class weights for the label imbalance.
# Count how often each label id occurs in the training set; positions labelled -100
# (excluded from the loss) are skipped.
num_classes = len(labels_to_ids)
label_totals = torch.zeros(num_classes, dtype=torch.long)

print("Calculating class weights...")
for batch in tqdm(DataLoader(train_dataset, batch_size=256, num_workers=4), desc="Counting labels"):
    labels = batch['labels'].flatten()
    valid_labels = labels[labels != -100]

    # bincount replaces torch.unique + a per-label Python loop; minlength keeps
    # classes absent from this batch at zero.
    batch_counts = torch.bincount(valid_labels, minlength=num_classes)
    if batch_counts.numel() > num_classes:
        raise ValueError(f"Found a label id >= {num_classes} in the training labels")
    label_totals += batch_counts

label_counts = {label_id: int(n) for label_id, n in enumerate(label_totals.tolist())}
total_counts = sum(label_counts.values())

# Inverse-frequency weights, total / (num_classes * count), multiplied by a per-class
# boost factor. Every entity tag is boosted (I-ROUTE, B-DIRECTION and I-DIRECTION most);
# a class never seen in training falls back to weight 1.0.
class_weights = []
boost_factors = {
    0: 1.0,   # O - no boost
    1: 1.2,   # B-ROUTE - slight boost
    2: 1.5,   # I-ROUTE - boost for continuity
    3: 1.5,   # B-DIRECTION - boost for rare starts
    4: 1.5    # I-DIRECTION - boost for continuity
}

for i in range(num_classes):
    count = label_counts[i]
    if count > 0:
        weight = total_counts / (num_classes * count)
        weight *= boost_factors.get(i, 1.0)
    else:
        weight = 1.0
    class_weights.append(weight)

class_weights = torch.tensor(class_weights, dtype=torch.float)
print("\nLabel distribution in training set:")
for label_name, label_id in labels_to_ids.items():
    count = label_counts[label_id]
    pct = (count / total_counts * 100) if total_counts > 0 else 0
    print(f"  {label_name}: {count:,} ({pct:.2f}%) - weight: {class_weights[label_id]:.3f}")
print(f"\nClass Weights: {class_weights}")

Calculating class weights...


Counting labels:   0%|          | 0/619 [00:00<?, ?it/s]


Label distribution in training set:
  O: 3,866,910 (85.03%) - weight: 0.235
  B-ROUTE: 260,352 (5.72%) - weight: 4.192
  I-ROUTE: 97,280 (2.14%) - weight: 14.025
  B-DIRECTION: 145,074 (3.19%) - weight: 9.404
  I-DIRECTION: 178,064 (3.92%) - weight: 7.662

Class Weights: tensor([ 0.2352,  4.1922, 14.0245,  9.4042,  7.6619])


## 4. Model Configuration & Training Setup

In [10]:
# Prefer CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Pretrained encoder plus a freshly initialised token-classification head.
# Passing the label maps stores them in the model config, so checkpoints saved with
# save_pretrained() decode to 'B-ROUTE' etc. instead of generic 'LABEL_1'.
model = AutoModelForTokenClassification.from_pretrained(
    'microsoft/deberta-v3-base',
    num_labels=len(labels_to_ids),
    id2label=ids_to_labels,
    label2id=labels_to_ids,
)
model.to(device)

# Two learning rates: a lower one for the pretrained encoder, a higher one for the new head.
# NOTE(review): AdamW's default weight decay (0.01 in PyTorch) is applied to every
# parameter here, including biases and LayerNorm weights.
optimizer_grouped_parameters = [
    {'params': model.deberta.parameters(), 'lr': 3e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-4}
]
optimizer = AdamW(optimizer_grouped_parameters)

# Linear decay with warm-up over the first 10% of all optimizer steps.
epochs = 3
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

# Class-weighted cross-entropy; its default ignore_index=-100 skips unlabeled positions.
loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))

Using device: cuda


pytorch_model.bin:   0%|          | 0.00/371M [00:00<?, ?B/s]

Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 5. Training Loop

In [11]:
from torch.cuda.amp import GradScaler

def train_epoch(model, data_loader, optimizer, scheduler, device, loss_fct):
    """Run one training epoch and return the mean per-batch training loss.

    On CUDA the forward pass runs under fp16 autocast. When the model weights are
    float32, gradients are also scaled with a GradScaler. Other devices use the
    plain fp32 path. Gradients are clipped to a global norm of 1.0, and the LR
    scheduler is stepped once per batch.

    Args:
        model: token-classification model whose output exposes ``.logits`` with the
            label dimension last.
        data_loader: yields dicts with 'input_ids', 'attention_mask' and 'labels' tensors.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: per-step learning-rate scheduler.
        device: torch.device each batch is moved to.
        loss_fct: criterion applied to flattened logits/labels. It must ignore the
            -100 positions (``CrossEntropyLoss`` does by default).

    Returns:
        float: summed batch losses divided by the number of batches (0.0 if empty).
    """
    model.train()
    total_loss = 0.0

    amp_enabled = device.type == "cuda"
    # GradScaler refuses to unscale fp16 gradients, so only enable it for fp32 weights.
    scaler = GradScaler(enabled=amp_enabled and next(model.parameters()).dtype == torch.float32)

    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=amp_enabled):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Flatten (batch, seq_len, num_labels) -> (batch*seq_len, num_labels); the label
            # count comes from the logits themselves rather than a notebook-level global.
            active_logits = logits.view(-1, logits.size(-1))
            active_labels = labels.view(-1)
            loss = loss_fct(active_logits, active_labels)

        total_loss += loss.item()

        if scaler.is_enabled():
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        scheduler.step()

    return total_loss / max(len(data_loader), 1)


def eval_epoch(model, data_loader, device, loss_fct):
    """Evaluate ``model`` and collect tag sequences for seqeval.

    Runs under ``torch.no_grad()``, with fp16 autocast on CUDA.

    Args:
        model: token-classification model whose output exposes ``.logits``.
        data_loader: yields dicts with 'input_ids', 'attention_mask' and 'labels' tensors.
        device: torch.device each batch is moved to.
        loss_fct: criterion applied to flattened logits/labels (the training loop passes
            the class-weighted loss).

    Returns:
        tuple: (mean per-batch loss, list of true tag sequences, list of predicted tag
        sequences). There is one sequence per example, containing only the positions
        whose label is not -100.
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []

    amp_enabled = device.type == "cuda"

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=amp_enabled):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=2)

            # Copy each batch to host once; indexing device tensors element by element
            # would force a device sync for every token.
            for true_row, pred_row in zip(labels.tolist(), preds.tolist()):
                kept = [(t, p) for t, p in zip(true_row, pred_row) if t != -100]
                all_labels.append([ids_to_labels[t] for t, _ in kept])
                all_preds.append([ids_to_labels[p] for _, p in kept])

    return total_loss / max(len(data_loader), 1), all_labels, all_preds


## 6. Execution

In [12]:
# Train for `epochs` epochs and keep the checkpoint with the best validation entity-level F1.
# Early stopping fires after `patience` consecutive non-improving epochs.
# NOTE: the first epoch always improves on -1.0, so with patience >= epochs the
# early-stopping branch can never fire.
best_model_dir = "models/deberta_ner_best"
best_f1 = -1.0
patience = 3
patience_counter = 0

for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")

    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, loss_fct)
    print(f"Train Loss: {train_loss:.4f}")

    val_loss, val_labels, val_preds = eval_epoch(model, val_loader, device, loss_fct)
    f1 = f1_score(val_labels, val_preds)
    report = classification_report(val_labels, val_preds)

    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val F1: {f1:.4f}")
    print("Classification Report:")
    print(report)

    if f1 <= best_f1:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
    else:
        best_f1, patience_counter = f1, 0
        # Save model and tokenizer side by side so the directory is self-contained.
        for artifact in (model, tokenizer):
            artifact.save_pretrained(best_model_dir)
        print("New best model saved!")

    if patience_counter >= patience:
        print("Early stopping triggered.")
        break


Epoch 1/3


Training:   0%|          | 0/1237 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/371M [00:00<?, ?B/s]

Train Loss: 0.0876


Evaluating:   0%|          | 0/1061 [00:00<?, ?it/s]

Val Loss: 0.0017
Val F1: 0.9904
Classification Report:
              precision    recall  f1-score   support

   DIRECTION       0.99      1.00      0.99     29288
       ROUTE       0.98      1.00      0.99     53468

   micro avg       0.98      1.00      0.99     82756
   macro avg       0.98      1.00      0.99     82756
weighted avg       0.98      1.00      0.99     82756

New best model saved!

Epoch 2/3


Training:   0%|          | 0/1237 [00:00<?, ?it/s]

Train Loss: 0.0026


Evaluating:   0%|          | 0/1061 [00:00<?, ?it/s]

Val Loss: 0.0019
Val F1: 0.9930
Classification Report:
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     29288
       ROUTE       0.98      1.00      0.99     53468

   micro avg       0.99      1.00      0.99     82756
   macro avg       0.99      1.00      0.99     82756
weighted avg       0.99      1.00      0.99     82756

New best model saved!

Epoch 3/3


Training:   0%|          | 0/1237 [00:00<?, ?it/s]

Train Loss: 0.0017


Evaluating:   0%|          | 0/1061 [00:00<?, ?it/s]

Val Loss: 0.0017
Val F1: 0.9933
Classification Report:
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     29288
       ROUTE       0.98      1.00      0.99     53468

   micro avg       0.99      1.00      0.99     82756
   macro avg       0.99      1.00      0.99     82756
weighted avg       0.99      1.00      0.99     82756

New best model saved!


## 7. Small Inference Test

In [13]:
def predict_ner(text, model, tokenizer, device, id2label=None):
    """Tag one string and return its non-"O" sub-word tokens with their predicted tags.

    Args:
        text: raw input string.
        model: token-classification model whose output exposes ``.logits``.
        tokenizer: tokenizer matching ``model``.
        device: torch.device the model lives on.
        id2label: optional mapping from label id to tag string. Defaults to the
            notebook-level ``ids_to_labels``.

    Returns:
        list[tuple[str, str]]: (token, tag) for every non-special token whose predicted
        tag is not "O", in input order. Tokens are sub-words (e.g. '▁Jamaica'), not
        merged entity spans.
    """
    if id2label is None:
        id2label = ids_to_labels
    model.eval()

    # Ask the tokenizer which positions it inserted itself (e.g. [CLS]/[SEP]) instead of
    # hard-coding their strings, so any tokenizer's special tokens are excluded.
    inputs = tokenizer(text, return_tensors="pt", return_special_tokens_mask=True)
    special_mask = inputs.pop("special_tokens_mask")[0].tolist()
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    preds = torch.argmax(logits, dim=2)[0].tolist()

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

    return [
        (token, id2label[label_id])
        for token, label_id, is_special in zip(tokens, preds, special_mask)
        if not is_special and id2label[label_id] != "O"
    ]

# Quick sanity check on a hand-written alert: print the tagged sub-word tokens.
test_text = "Jamaica-bound J trains are delayed"
sample_entities = predict_ner(test_text, model, tokenizer, device)
print(sample_entities)

[('▁Jamaica', 'B-DIRECTION'), ('-', 'I-DIRECTION'), ('bound', 'I-DIRECTION'), ('▁J', 'B-ROUTE')]


## 8. Silver and Gold Evaluation

In [None]:
from seqeval.metrics import precision_score, recall_score

# Prefer the best checkpoint written during training. If it is absent (e.g. training was
# interrupted before the first save), fall back to the in-memory model and tokenizer.
# NOTE(review): a checkpoint left in this directory by an earlier run would also be loaded.
# NOTE(review): `fix_mistral_regex` looks Mistral-specific, and eval_tokenizer is not used
# by the evaluation below — confirm whether either is needed.
best_dir = "models/deberta_ner_best"
have_checkpoint = os.path.exists(best_dir)
if not have_checkpoint:
    print(f"Warning: {best_dir} not found. Using current in-memory model for evaluation.")
    eval_tokenizer, eval_model = tokenizer, model
else:
    eval_tokenizer = AutoTokenizer.from_pretrained(best_dir, use_fast=False, fix_mistral_regex=True)
    eval_model = AutoModelForTokenClassification.from_pretrained(best_dir)
    print(f"Loaded best checkpoint from {best_dir}")

eval_model.to(device)

# The silver-test and gold DataLoaders were built earlier with eval_batch_size.

def evaluate_ner(model, data_loader, device):
    """Evaluate a token-classification model and return entity-level seqeval metrics.

    The loss is plain (unweighted) cross-entropy over positions whose label is not -100,
    so it is not directly comparable to a class-weighted validation loss. Inference runs
    in full precision under ``torch.no_grad()``.

    Args:
        model: token-classification model whose output exposes ``.logits``.
        data_loader: yields dicts with 'input_ids', 'attention_mask' and 'labels' tensors.
        device: torch.device each batch is moved to.

    Returns:
        dict: 'loss' (mean per batch), plus seqeval 'precision', 'recall' and 'f1', and
        the text 'report' from ``classification_report``.
    """
    loss_fct_eval = torch.nn.CrossEntropyLoss(ignore_index=-100)
    all_true = []
    all_pred = []
    total_loss = 0.0
    model.eval()

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = loss_fct_eval(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_loss += loss.item()

            preds = torch.argmax(logits, dim=2)

            # Copy each batch to host once; per-element indexing of device tensors would
            # force a device sync for every token.
            for true_row, pred_row in zip(labels.tolist(), preds.tolist()):
                kept = [(t, p) for t, p in zip(true_row, pred_row) if t != -100]
                all_true.append([ids_to_labels[t] for t, _ in kept])
                all_pred.append([ids_to_labels[p] for _, p in kept])

    num_batches = max(len(data_loader), 1)
    return {
        "loss": total_loss / num_batches,
        "precision": precision_score(all_true, all_pred),
        "recall": recall_score(all_true, all_pred),
        "f1": f1_score(all_true, all_pred),
        "report": classification_report(all_true, all_pred)
    }


# Release cached, unused GPU memory blocks before evaluation to give it more headroom.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


def _print_metrics(metrics):
    """Print an evaluate_ner result: the four scalar metrics (4 d.p.), then the report."""
    for name in ("loss", "precision", "recall", "f1"):
        print(f"  {name}: {metrics[name]:.4f}")
    print(metrics["report"])


silver_metrics = evaluate_ner(eval_model, test_loader, device)
print("Silver (silver labels) test performance:")
_print_metrics(silver_metrics)

print("Evaluating on all gold labels (all gold in test split)...")
gold_metrics = evaluate_ner(eval_model, gold_loader, device)
_print_metrics(gold_metrics)


Evaluating:   0%|          | 0/1061 [00:00<?, ?it/s]

Silver (silver labels) test performance):
  loss: 0.0044
  precision: 0.9896
  recall: 0.9997
  f1: 0.9946
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     30733
       ROUTE       0.99      1.00      0.99     53084

   micro avg       0.99      1.00      0.99     83817
   macro avg       0.99      1.00      1.00     83817
weighted avg       0.99      1.00      0.99     83817

Evaluating on all gold labels (all gold in test split)...


Evaluating:   0%|          | 0/19 [00:00<?, ?it/s]

  loss: 0.4030
  precision: 0.8142
  recall: 0.8809
  f1: 0.8462
              precision    recall  f1-score   support

   DIRECTION       0.86      0.89      0.87      2428
       ROUTE       0.79      0.87      0.83      3644

   micro avg       0.81      0.88      0.85      6072
   macro avg       0.82      0.88      0.85      6072
weighted avg       0.82      0.88      0.85      6072

