# DeBERTa-v3 NER for MTA Service Alerts

This notebook implements a Named Entity Recognition (NER) pipeline using `microsoft/deberta-v3-base` to extract `ROUTE` and `DIRECTION` entities from MTA transit alerts.
It uses existing span annotations from the silver dataset.

In [None]:
# Install required packages if missing
# !pip install transformers torch pandas scikit-learn seqeval numpy

In [1]:
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, f1_score
from tqdm.notebook import tqdm
import os

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

ImportError: cannot import name 'AdamW' from 'transformers' (/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/transformers/__init__.py)

## 1. Data Loading and Preprocessing

In [None]:
# Load the dataset
file_path = "Preprocessed/MTA_Data_silver_relations.csv"
df = pd.read_csv(file_path)

# Parse JSON columns containing spans
# The dataset contains stringified JSON, so we convert them back to objects
df['affected_spans'] = df['affected_spans'].apply(json.loads)
df['direction_spans'] = df['direction_spans'].apply(json.loads)

# Convert date column for temporal splitting
df['date'] = pd.to_datetime(df['date'])

# Display sample
print(f"Total rows: {len(df)}")
df[['header', 'affected_spans', 'direction_spans']].head(3)

In [None]:
# Define Label Map
# 5 classes: O (Outside), B-ROUTE, I-ROUTE, B-DIRECTION, I-DIRECTION
labels_to_ids = {
    'O': 0,
    'B-ROUTE': 1,
    'I-ROUTE': 2,
    'B-DIRECTION': 3,
    'I-DIRECTION': 4
}
ids_to_labels = {v: k for k, v in labels_to_ids.items()}

print("Label Map:", labels_to_ids)

## 2. Dataset Class & Tokenization

In [None]:
class MTANERDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=128):
        self.data = df
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_map = labels_to_ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        text = row['header']
        
        # Combine spans
        # Structure: [{'start': 0, 'end': 1, 'type': 'ROUTE', ...}]
        spans = row['affected_spans'] + row['direction_spans']

        # Tokenize text
        # return_offsets_mapping=True gives us character start/end for each token
        encoding = self.tokenizer(
            text,
            return_offsets_mapping=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Get offsets (remove batch dimension)
        offset_mapping = encoding['offset_mapping'][0].tolist()
        
        # Initialize labels with 'O' (0)
        # -100 is used by PyTorch to ignore loss for special tokens (CLS, SEP, PAD)
        labels = [0] * len(offset_mapping)

        # Create BIO tags
        # Iterate over tokens and check if they fall inside any span
        for idx, (start, end) in enumerate(offset_mapping):
            # Skip special tokens (start=0, end=0 usually implies special or empty)
            # But be careful: (0,0) could be valid for some tokenizers, 
            # DeBERTa uses (0,0) for CLS/SEP/PAD usually.
            if start == end:
                labels[idx] = -100
                continue

            # Check each span
            token_label = 'O'
            for span in spans:
                span_start = span['start']
                span_end = span['end']
                span_type = span['type']  # ROUTE or DIRECTION

                # Check overlap
                if start >= span_start and end <= span_end:
                    # If this token starts at the beginning of the span, it's B-
                    # Note: strict equality (start == span_start) might miss subwords 
                    # if tokenizer splits differently. 
                    # We use a simple logic: if it's the first token encountered for this span, B-, else I-
                    # Better logic: if start == span_start -> B, else I.
                    
                    if start == span_start:
                        token_label = f"B-{span_type}"
                    else:
                        token_label = f"I-{span_type}"
                    break # Found a label, stop checking spans
            
            if token_label != 'O':
                labels[idx] = self.label_map[token_label]

        # Convert to tensor
        item = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(labels, dtype=torch.long)
        }
        return item

# Initialize Tokenizer
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')

## 3. Temporal Data Splits

Implement true temporal split without shuffling to prevent data leakage.

In [None]:
# Sort by date for temporal split
df_sorted = df.sort_values('date').reset_index(drop=True)

# Compute complexity metric for analysis (optional - not used for splitting)
df_sorted['num_dirs'] = df_sorted['direction_spans'].apply(len)
df_sorted['num_routes'] = df_sorted['affected_spans'].apply(len)
df_sorted['complexity_bin'] = pd.cut(
    df_sorted['num_dirs'] + df_sorted['num_routes'], 
    bins=[-1, 0, 1, 2, float('inf')], 
    labels=['none', 'single', 'double', 'multi']
)

print("Complexity distribution:")
print(df_sorted['complexity_bin'].value_counts())

# TRUE Temporal Split: 70% Train, 15% Val, 15% Test (NO SHUFFLING)
n = len(df_sorted)
train_end = int(n * 0.70)
val_end = int(n * 0.85)

train_df = df_sorted.iloc[:train_end].reset_index(drop=True)
val_df = df_sorted.iloc[train_end:val_end].reset_index(drop=True)
test_df = df_sorted.iloc[val_end:].reset_index(drop=True)

print(f"\nSplit sizes: Train: {len(train_df):,}, Val: {len(val_df):,}, Test: {len(test_df):,}")

# Verify temporal ordering
print(f"\nDate ranges:")
print(f"  Train: {train_df['date'].min()} to {train_df['date'].max()}")
print(f"  Val:   {val_df['date'].min()} to {val_df['date'].max()}")
print(f"  Test:  {test_df['date'].min()} to {test_df['date'].max()}")

# Show complexity distribution per split (for analysis only)
print("\nComplexity distribution per split:")
for name, split_df in [("Train", train_df), ("Val", val_df), ("Test", test_df)]:
    dist = split_df['complexity_bin'].value_counts(normalize=True) * 100
    print(f"{name}: {dict(dist.round(1))}")

# Create Datasets & DataLoaders
train_dataset = MTANERDataset(train_df, tokenizer)
val_dataset = MTANERDataset(val_df, tokenizer)
test_dataset = MTANERDataset(test_df, tokenizer)

batch_size = 16

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

## 3.5 Rule-Based Baseline

Simple regex-based extractor to establish baseline performance and verify data isn't trivially solvable.

In [None]:
# Compute Class Weights for Imbalance
# We count occurrences of each label in the training set
label_counts = {id: 0 for id in labels_to_ids.values()}

# Sample a subset to estimate weights to avoid iterating full dataset if huge
# Or iterate full dataset (safe for 200k rows)
print("Calculating class weights...")
for i in tqdm(range(len(train_dataset))):
    labels = train_dataset[i]['labels'].numpy()
    # Filter out -100
    valid_labels = labels[labels != -100]
    for label in valid_labels:
        label_counts[label] += 1

total_counts = sum(label_counts.values())
num_classes = len(labels_to_ids)

print("\nLabel distribution:")
for label_name, label_id in labels_to_ids.items():
    count = label_counts[label_id]
    pct = (count / total_counts) * 100 if total_counts > 0 else 0
    print(f"  {label_name}: {count:,} ({pct:.2f}%)")

# Inverse frequency weights with smoothing and manual boosting for rare classes
# Base weights: total / (num_classes * count)
class_weights = []
for i in range(num_classes):
    count = label_counts[i]
    if count > 0:
        # Smoothed inverse frequency
        weight = total_counts / (num_classes * count)
        # Cap weights to prevent extreme values
        weight = min(weight, 10.0)
    else:
        weight = 1.0
    class_weights.append(weight)

# Manual adjustments based on analysis:
# - Boost B-DIRECTION (idx 3) for better recall on direction starts
# - Boost I-ROUTE (idx 2) and I-DIRECTION (idx 4) for multi-token entities
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Apply additional boost for rare continuation tags
boost_factors = torch.tensor([1.0, 1.0, 1.5, 1.5, 1.5], dtype=torch.float)  # O, B-ROUTE, I-ROUTE, B-DIR, I-DIR
class_weights = class_weights * boost_factors

print("\nFinal Class Weights:")
for label_name, label_id in labels_to_ids.items():
    print(f"  {label_name}: {class_weights[label_id]:.3f}")

## 4. Model Configuration & Training Setup

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Load Model
model = AutoModelForTokenClassification.from_pretrained(
    'microsoft/deberta-v3-base',
    num_labels=len(labels_to_ids)
)
model.to(device)

# Optimizer with different learning rates
optimizer_grouped_parameters = [
    {'params': model.deberta.parameters(), 'lr': 3e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-4}
]
optimizer = AdamW(optimizer_grouped_parameters)

# Scheduler
epochs = 3 # Start with small number for demo, user can increase
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=int(0.1 * total_steps), 
    num_training_steps=total_steps
)

# Loss Function with Weights
loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))

## 5. Training Loop

In [None]:
def train_epoch(model, data_loader, optimizer, scheduler, device, loss_fct):
    model.train()
    total_loss = 0
    
    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        
        # Calculate loss manually to apply class weights
        # Flatten logits and labels
        active_loss = attention_mask.view(-1) == 1
        active_logits = logits.view(-1, len(labels_to_ids))
        active_labels = labels.view(-1)
        
        # Only compute loss for active labels (not -100 and not padded)
        # PyTorch CrossEntropy handles -100 ignoring automatically if we pass it,
        # but we also want to apply our class weights.
        loss = loss_fct(active_logits, active_labels)

        total_loss += loss.item()
        loss.backward()
        
        # Gradient Clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        scheduler.step()
        
    return total_loss / len(data_loader)

def eval_epoch(model, data_loader, device, loss_fct):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            
            loss = loss_fct(logits.view(-1, len(labels_to_ids)), labels.view(-1))
            total_loss += loss.item()
            
            # Get predictions
            preds = torch.argmax(logits, dim=2)
            
            # Convert to list and handle ignores (-100)
            for i in range(len(labels)):
                true_labels = labels[i]
                pred_labels = preds[i]
                
                true_list = []
                pred_list = []
                
                for j in range(len(true_labels)):
                    if true_labels[j] != -100:
                        true_list.append(ids_to_labels[true_labels[j].item()])
                        pred_list.append(ids_to_labels[pred_labels[j].item()])
                
                all_labels.append(true_list)
                all_preds.append(pred_list)
                
    return total_loss / len(data_loader), all_labels, all_preds

## 6. Execution

In [None]:
best_f1 = 0
patience = 3
patience_counter = 0

for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, loss_fct)
    print(f"Train Loss: {train_loss:.4f}")
    
    val_loss, val_labels, val_preds = eval_epoch(model, val_loader, device, loss_fct)
    
    # Compute Metrics
    f1 = f1_score(val_labels, val_preds)
    report = classification_report(val_labels, val_preds)
    
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val F1: {f1:.4f}")
    print("Classification Report:")
    print(report)
    
    # Save Best Model & Early Stopping
    if f1 > best_f1:
        best_f1 = f1
        patience_counter = 0
        # Save model
        model.save_pretrained("models/deberta_ner_best")
        tokenizer.save_pretrained("models/deberta_ner_best")
        print("New best model saved!")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")
        
    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

## 7. Evaluation on Test Set

Evaluate the saved best NER model on the held-out test split using span-aware metrics.


In [None]:
from seqeval.metrics import precision_score, recall_score

best_dir = "models/deberta_ner_best"
eval_tokenizer = AutoTokenizer.from_pretrained(best_dir, fix_mistral_regex=True)
eval_model = AutoModelForTokenClassification.from_pretrained(best_dir)
eval_model.to(device)

test_dataset = MTANERDataset(test_df, eval_tokenizer)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

def evaluate_ner(model, data_loader, device):
    loss_fct_eval = torch.nn.CrossEntropyLoss(ignore_index=-100)
    all_true = []
    all_pred = []
    total_loss = 0.0
    model.eval()

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating test"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = loss_fct_eval(logits.view(-1, len(labels_to_ids)), labels.view(-1))
            total_loss += loss.item()

            preds = torch.argmax(logits, dim=2)

            for i in range(len(labels)):
                true_seq = []
                pred_seq = []
                for j in range(len(labels[i])):
                    if labels[i][j].item() == -100:
                        continue
                    true_seq.append(ids_to_labels[labels[i][j].item()])
                    pred_seq.append(ids_to_labels[preds[i][j].item()])
                all_true.append(true_seq)
                all_pred.append(pred_seq)

    num_batches = max(len(data_loader), 1)
    return {
        "loss": total_loss / num_batches,
        "precision": precision_score(all_true, all_pred),
        "recall": recall_score(all_true, all_pred),
        "f1": f1_score(all_true, all_pred),
        "report": classification_report(all_true, all_pred)
    }


test_metrics = evaluate_ner(eval_model, test_loader, device)
print(f"Test loss: {test_metrics['loss']:.4f}")
print(f"Test precision: {test_metrics['precision']:.4f}")
print(f"Test recall: {test_metrics['recall']:.4f}")
print(f"Test F1: {test_metrics['f1']:.4f}")
print(test_metrics["report"])


## 8. Inference Test

In [None]:
def predict_ner(text, model, tokenizer, device):
    model.eval()
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=2)[0]
    
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    entities = []
    
    for token, label_id in zip(tokens, preds):
        label = ids_to_labels[label_id.item()]
        if label != "O" and token not in ["[CLS]", "[SEP]"]:
            entities.append((token, label))
            
    return entities

active_model = eval_model if "eval_model" in globals() else model
active_tokenizer = eval_tokenizer if "eval_tokenizer" in globals() else tokenizer

# Test
test_text = "Jamaica-bound J trains are delayed"
print(predict_ner(test_text, active_model, active_tokenizer, device))