## 0. Package Installation

In [18]:
# Install required packages if missing
%pip install pytorch-crf seqeval



In [None]:
import json
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from tqdm.notebook import tqdm
from collections import Counter
from typing import Dict, List, Tuple, Optional
import os

# Import CRF
try:
    from torchcrf import CRF
    CRF_AVAILABLE = True
    print("pytorch-crf available")
except ImportError:
    CRF_AVAILABLE = False
    print("pytorch-crf not installed. Install with: pip install pytorch-crf")
    print("Model will fall back to cross-entropy loss.")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

✓ pytorch-crf available
PyTorch version: 2.9.0+cu126
CUDA available: True


## 1. Data Loading and Preprocessing

In [20]:
# Load the dataset
file_path = "Preprocessed/MTA_Data_silver_relations.csv"
df = pd.read_csv(file_path)

# Parse JSON columns containing spans
df['affected_spans'] = df['affected_spans'].apply(lambda x: json.loads(x) if pd.notna(x) and x != '[]' else [])
df['direction_spans'] = df['direction_spans'].apply(lambda x: json.loads(x) if pd.notna(x) and x != '[]' else [])

# Convert date column for temporal splitting
df['date'] = pd.to_datetime(df['date'])

# Display sample
print(f"Total rows: {len(df):,}")
print(f"\nSample data:")
df[['header', 'affected_spans', 'direction_spans']].head(3)

  df['date'] = pd.to_datetime(df['date'])


Total rows: 226,160

Sample data:


Unnamed: 0,header,affected_spans,direction_spans
0,A C trains are delayed while we conduct emerge...,"[{'start': 0, 'end': 1, 'type': 'ROUTE', 'valu...",[]
1,L trains are running with delays in both direc...,"[{'start': 0, 'end': 1, 'type': 'ROUTE', 'valu...","[{'start': 36, 'end': 51, 'type': 'DIRECTION',..."
2,Jamaica-bound J trains are delayed while we re...,"[{'start': 14, 'end': 15, 'type': 'ROUTE', 'va...","[{'start': 0, 'end': 13, 'type': 'DIRECTION', ..."


In [21]:
# Define Label Map (same as DeBERTa)
labels_to_ids = {
    'O': 0,
    'B-ROUTE': 1,
    'I-ROUTE': 2,
    'B-DIRECTION': 3,
    'I-DIRECTION': 4
}
ids_to_labels = {v: k for k, v in labels_to_ids.items()}

# Special tokens
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
PAD_IDX = 0
UNK_IDX = 1
CHAR_PAD_IDX = 0
CHAR_UNK_IDX = 1
IGNORE_INDEX = -100

print("Label Map:", labels_to_ids)

Label Map: {'O': 0, 'B-ROUTE': 1, 'I-ROUTE': 2, 'B-DIRECTION': 3, 'I-DIRECTION': 4}


## 2. Stratified Temporal Splits

Using the same split strategy as DeBERTa: sort by date and stratify by complexity.

In [22]:
# Sort by date for temporal split
df_sorted = df.sort_values('date').reset_index(drop=True)

# Compute complexity metric for analysis (optional - not used for splitting)
df_sorted['num_dirs'] = df_sorted['direction_spans'].apply(len)
df_sorted['num_routes'] = df_sorted['affected_spans'].apply(len)
df_sorted['complexity_bin'] = pd.cut(
    df_sorted['num_dirs'] + df_sorted['num_routes'],
    bins=[-1, 0, 1, 2, float('inf')],
    labels=['none', 'single', 'double', 'multi']
)

print("Complexity distribution:")
print(df_sorted['complexity_bin'].value_counts())

# TRUE Temporal Split: 70% Train, 15% Val, 15% Test (NO SHUFFLING)
n = len(df_sorted)
train_end = int(n * 0.70)
val_end = int(n * 0.85)

train_df = df_sorted.iloc[:train_end].reset_index(drop=True)
val_df = df_sorted.iloc[train_end:val_end].reset_index(drop=True)
test_df = df_sorted.iloc[val_end:].reset_index(drop=True)

print(f"\nSplit sizes: Train: {len(train_df):,}, Val: {len(val_df):,}, Test: {len(test_df):,}")

# Verify temporal ordering
print(f"\nDate ranges:")
print(f"  Train: {train_df['date'].min()} to {train_df['date'].max()}")
print(f"  Val:   {val_df['date'].min()} to {val_df['date'].max()}")
print(f"  Test:  {test_df['date'].min()} to {test_df['date'].max()}")

# Show complexity distribution per split (for analysis only)
print("\nComplexity distribution per split:")
for name, split_df in [("Train", train_df), ("Val", val_df), ("Test", test_df)]:
    dist = split_df['complexity_bin'].value_counts(normalize=True) * 100
    print(f"{name}: {dict(dist.round(1))}")

Complexity distribution:
complexity_bin
double    102355
multi      85468
single     38337
none           0
Name: count, dtype: int64

Split sizes: Train: 158,312, Val: 33,924, Test: 33,924

Date ranges:
  Train: 2020-04-28 13:12:00 to 2024-04-11 17:52:00
  Val:   2024-04-11 17:52:00 to 2025-01-02 01:39:00
  Test:  2025-01-02 01:54:00 to 2025-08-30 23:55:00

Complexity distribution per split:
Train: {'double': np.float64(45.1), 'multi': np.float64(37.7), 'single': np.float64(17.2), 'none': np.float64(0.0)}
Val: {'double': np.float64(43.5), 'multi': np.float64(38.1), 'single': np.float64(18.5), 'none': np.float64(0.0)}
Test: {'double': np.float64(47.8), 'multi': np.float64(38.0), 'single': np.float64(14.2), 'none': np.float64(0.0)}


## 3. Build Vocabularies

Create word and character vocabularies from the training set.

In [23]:
def word_tokenize(text: str) -> List[Tuple[str, int, int]]:
    """
    Tokenize text into words with character offsets.
    Returns list of (token, start_char, end_char) tuples.
    """
    tokens = []
    for match in re.finditer(r"\S+", text):
        word = match.group()
        start = match.start()
        # Split on punctuation but keep as separate tokens
        for sub in re.finditer(r"[A-Za-z0-9]+|[^\sA-Za-z0-9]", word):
            s = start + sub.start()
            e = s + len(sub.group())
            tokens.append((sub.group(), s, e))
    return tokens

# Test tokenizer
sample_text = "Jamaica-bound J trains are delayed."
print(f"Sample tokenization: '{sample_text}'")
print(word_tokenize(sample_text))

Sample tokenization: 'Jamaica-bound J trains are delayed.'
[('Jamaica', 0, 7), ('-', 7, 8), ('bound', 8, 13), ('J', 14, 15), ('trains', 16, 22), ('are', 23, 26), ('delayed', 27, 34), ('.', 34, 35)]


In [24]:
# Build word vocabulary from training set
print("Building word vocabulary from training set...")
word_counter = Counter()
char_set = set()

for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Scanning text"):
    text = str(row['header']) if pd.notna(row['header']) else ""
    tokens = word_tokenize(text)
    for word, _, _ in tokens:
        word_lower = word.lower()
        word_counter[word_lower] += 1
        char_set.update(word)

# Create word vocabulary with frequency filtering
MIN_WORD_FREQ = 2
word2idx = {PAD_TOKEN: PAD_IDX, UNK_TOKEN: UNK_IDX}
for word, count in word_counter.most_common():
    if count >= MIN_WORD_FREQ:
        word2idx[word] = len(word2idx)

# Create character vocabulary
char2idx = {"<PAD>": CHAR_PAD_IDX, "<UNK>": CHAR_UNK_IDX}
for char in sorted(char_set):
    char2idx[char] = len(char2idx)

print(f"\nWord vocabulary size: {len(word2idx):,}")
print(f"Character vocabulary size: {len(char2idx):,}")
print(f"\nMost common words: {list(word_counter.most_common(20))}")

Building word vocabulary from training set...


Scanning text:   0%|          | 0/158312 [00:00<?, ?it/s]


Word vocabulary size: 4,156
Character vocabulary size: 97

Most common words: [('.', 189886), ('st', 141834), ('trains', 129927), ('are', 123956), ('a', 122700), ('-', 107456), ('at', 102195), ('we', 78062), ('to', 66080), ('running', 65666), ('with', 61721), ('and', 60923), ('train', 59919), ('after', 59227), ('from', 58764), ('the', 58460), (':', 55498), ('av', 52959), ('delays', 52655), ('in', 50199)]


In [25]:
# Initialize word embeddings with Xavier uniform
WORD_EMBEDDING_DIM = 128

def initialize_embeddings(vocab_size: int, embedding_dim: int) -> np.ndarray:
    """Xavier uniform initialization for embeddings."""
    limit = np.sqrt(6.0 / (vocab_size + embedding_dim))
    embeddings = np.random.uniform(-limit, limit, (vocab_size, embedding_dim)).astype(np.float32)
    embeddings[PAD_IDX] = 0.0  # Keep padding as zeros
    return embeddings

pretrained_embeddings = initialize_embeddings(len(word2idx), WORD_EMBEDDING_DIM)
print(f"Initialized word embeddings: shape {pretrained_embeddings.shape}")

Initialized word embeddings: shape (4156, 128)


## 4. Dataset Class

PyTorch Dataset for BiLSTM-CRF NER with BIO label assignment.

In [26]:
class MTANERDataset(Dataset):
    """Dataset for NER with BiLSTM-CRF."""

    def __init__(
        self,
        dataframe: pd.DataFrame,
        word2idx: Dict[str, int],
        char2idx: Dict[str, int],
        max_seq_length: int = 128,
        max_word_length: int = 20,
    ):
        self.word2idx = word2idx
        self.char2idx = char2idx
        self.max_seq_length = max_seq_length
        self.max_word_length = max_word_length
        self.samples = self._build_samples(dataframe)

    def _build_samples(self, df: pd.DataFrame) -> List[Dict]:
        """Build samples with tokens and BIO labels."""
        samples = []
        for _, row in df.iterrows():
            text = str(row['header']) if pd.notna(row['header']) else ""
            tokens = word_tokenize(text)
            if not tokens:
                continue

            routes = row['affected_spans']
            directions = row['direction_spans']
            labels = self._assign_labels(tokens, routes, directions)

            samples.append({
                'tokens': tokens,
                'labels': labels,
                'text': text
            })
        return samples

    @staticmethod
    def _assign_labels(
        tokens: List[Tuple[str, int, int]],
        routes: List[Dict],
        directions: List[Dict]
    ) -> List[int]:
        """Assign BIO labels based on character-level span overlap."""
        labels = [labels_to_ids['O']] * len(tokens)

        def mark(spans: List[Dict], b_label: int, i_label: int):
            for span in spans:
                start, end = span['start'], span['end']
                inside = False
                for i, (_, token_start, token_end) in enumerate(tokens):
                    # Token overlaps with span
                    if token_start >= start and token_end <= end:
                        labels[i] = b_label if not inside else i_label
                        inside = True
                    elif token_end > end:
                        break

        mark(routes, labels_to_ids['B-ROUTE'], labels_to_ids['I-ROUTE'])
        mark(directions, labels_to_ids['B-DIRECTION'], labels_to_ids['I-DIRECTION'])
        return labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.samples[idx]
        tokens = sample['tokens'][:self.max_seq_length]
        labels = sample['labels'][:self.max_seq_length]
        length = len(tokens)

        # Convert tokens to word IDs
        word_ids = []
        for word, _, _ in tokens:
            word_lower = word.lower()
            word_ids.append(self.word2idx.get(word_lower, UNK_IDX))

        # Convert tokens to character IDs
        char_ids = []
        for word, _, _ in tokens:
            word_chars = []
            for c in word[:self.max_word_length]:
                word_chars.append(self.char2idx.get(c, CHAR_UNK_IDX))
            # Pad word to max_word_length
            word_chars += [CHAR_PAD_IDX] * (self.max_word_length - len(word_chars))
            char_ids.append(word_chars)

        # Pad sequences
        pad_len = self.max_seq_length - length
        word_ids += [PAD_IDX] * pad_len
        labels += [IGNORE_INDEX] * pad_len
        char_ids += [[CHAR_PAD_IDX] * self.max_word_length] * pad_len

        return {
            'word_ids': torch.tensor(word_ids, dtype=torch.long),
            'char_ids': torch.tensor(char_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
            'lengths': torch.tensor(length, dtype=torch.long)
        }

# Create datasets
print("Creating datasets...")
train_dataset = MTANERDataset(train_df, word2idx, char2idx)
val_dataset = MTANERDataset(val_df, word2idx, char2idx)
test_dataset = MTANERDataset(test_df, word2idx, char2idx)

print(f"Train samples: {len(train_dataset):,}")
print(f"Val samples: {len(val_dataset):,}")
print(f"Test samples: {len(test_dataset):,}")

# Show a sample
sample = train_dataset[0]
print(f"\nSample shapes:")
print(f"  word_ids: {sample['word_ids'].shape}")
print(f"  char_ids: {sample['char_ids'].shape}")
print(f"  labels: {sample['labels'].shape}")
print(f"  length: {sample['lengths'].item()}")

Creating datasets...
Train samples: 158,312
Val samples: 33,924
Test samples: 33,924

Sample shapes:
  word_ids: torch.Size([128])
  char_ids: torch.Size([128, 20])
  labels: torch.Size([128])
  length: 47


## 5. Model Architecture

BiLSTM-CRF with CharCNN for character-level features.

In [27]:
class CharCNN(nn.Module):
    """Character-level CNN for word representations."""

    def __init__(
        self,
        char_vocab_size: int,
        char_embedding_dim: int = 50,
        num_filters: int = 50,
        kernel_sizes: List[int] = [3, 4, 5],
        dropout: float = 0.3,
    ):
        super().__init__()
        self.char_embedding = nn.Embedding(
            char_vocab_size, char_embedding_dim, padding_idx=CHAR_PAD_IDX
        )

        self.convs = nn.ModuleList([
            nn.Conv1d(char_embedding_dim, num_filters, ks, padding=ks // 2)
            for ks in kernel_sizes
        ])

        self.dropout = nn.Dropout(dropout)
        self.output_dim = num_filters * len(kernel_sizes)

    def forward(self, char_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            char_ids: [batch_size, seq_len, max_word_len]
        Returns:
            char_repr: [batch_size, seq_len, output_dim]
        """
        batch_size, seq_len, max_word_len = char_ids.shape

        # Flatten: [batch_size * seq_len, max_word_len]
        char_ids = char_ids.view(-1, max_word_len)

        # Embed: [batch * seq_len, max_word_len, char_emb_dim]
        char_emb = self.char_embedding(char_ids)

        # Transpose for Conv1d: [batch * seq_len, char_emb_dim, max_word_len]
        char_emb = char_emb.transpose(1, 2)

        # Apply convolutions and max-pool
        conv_outputs = []
        for conv in self.convs:
            conv_out = F.relu(conv(char_emb))
            pooled = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
            conv_outputs.append(pooled)

        # Concatenate: [batch * seq_len, output_dim]
        char_repr = torch.cat(conv_outputs, dim=1)
        char_repr = self.dropout(char_repr)

        # Reshape: [batch_size, seq_len, output_dim]
        char_repr = char_repr.view(batch_size, seq_len, -1)

        return char_repr

print("✓ CharCNN defined")

✓ CharCNN defined


In [28]:
class BiLSTMCRFNER(nn.Module):
    """BiLSTM-CRF model for Named Entity Recognition."""

    def __init__(
        self,
        vocab_size: int,
        char_vocab_size: int,
        num_labels: int = len(labels_to_ids),
        word_embedding_dim: int = 128,
        char_embedding_dim: int = 50,
        hidden_dim: int = 256,
        num_layers: int = 2,
        dropout: float = 0.3,
        pretrained_embeddings: Optional[np.ndarray] = None,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_labels = num_labels

        # Word embeddings
        self.word_embedding = nn.Embedding(vocab_size, word_embedding_dim, padding_idx=PAD_IDX)
        if pretrained_embeddings is not None:
            self.word_embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))

        # Character CNN
        self.char_cnn = CharCNN(
            char_vocab_size=char_vocab_size,
            char_embedding_dim=char_embedding_dim,
            num_filters=50,
            kernel_sizes=[3, 4, 5],
            dropout=dropout,
        )

        # Combined embedding dimension
        combined_dim = word_embedding_dim + self.char_cnn.output_dim

        # BiLSTM encoder
        self.lstm = nn.LSTM(
            input_size=combined_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        self.dropout = nn.Dropout(dropout)

        # Linear projection to label space
        self.hidden2tag = nn.Linear(hidden_dim * 2, num_labels)

        # CRF layer
        if CRF_AVAILABLE:
            self.crf = CRF(num_labels, batch_first=True)
        else:
            self.crf = None

    def _get_lstm_features(
        self,
        word_ids: torch.Tensor,
        char_ids: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Get BiLSTM output features (emissions)."""
        batch_size, seq_len = word_ids.shape

        # Word embeddings
        word_emb = self.word_embedding(word_ids)

        # Character embeddings
        char_emb = self.char_cnn(char_ids)

        # Concatenate
        combined = torch.cat([word_emb, char_emb], dim=-1)
        combined = self.dropout(combined)

        # BiLSTM without packing (to maintain sequence length)
        lstm_out, _ = self.lstm(combined)
        lstm_out = self.dropout(lstm_out)

        # Project to label space
        emissions = self.hidden2tag(lstm_out)

        return emissions

    def forward(
        self,
        word_ids: torch.Tensor,
        char_ids: torch.Tensor,
        labels: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Compute loss."""
        emissions = self._get_lstm_features(word_ids, char_ids, lengths)

        if self.crf is not None:
            # CRF loss
            batch_size, seq_len = word_ids.shape
            mask = torch.arange(seq_len, device=word_ids.device).unsqueeze(0) < lengths.unsqueeze(1)

            # Replace IGNORE_INDEX with 0 for CRF (masked anyway)
            labels_for_crf = labels.clone()
            labels_for_crf[labels == IGNORE_INDEX] = 0

            # CRF returns negative log-likelihood
            loss = -self.crf(emissions, labels_for_crf, mask=mask, reduction='mean')
        else:
            # Fallback to cross-entropy
            loss = F.cross_entropy(
                emissions.view(-1, self.num_labels),
                labels.view(-1),
                ignore_index=IGNORE_INDEX,
            )

        return loss

    def decode(
        self,
        word_ids: torch.Tensor,
        char_ids: torch.Tensor,
        lengths: torch.Tensor
    ) -> List[List[int]]:
        """Decode best label sequence using Viterbi."""
        emissions = self._get_lstm_features(word_ids, char_ids, lengths)

        if self.crf is not None:
            batch_size, seq_len = word_ids.shape
            mask = torch.arange(seq_len, device=word_ids.device).unsqueeze(0) < lengths.unsqueeze(1)
            predictions = self.crf.decode(emissions, mask=mask)
        else:
            # Greedy decoding
            predictions = emissions.argmax(dim=-1).tolist()
            predictions = [pred[:length] for pred, length in zip(predictions, lengths.tolist())]

        return predictions

print("✓ BiLSTMCRFNER defined")

✓ BiLSTMCRFNER defined


## 6. Training Configuration & Class Weights

In [29]:
# Create DataLoaders
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

print(f"Batch size: {batch_size}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

Batch size: 64
Train batches: 2474
Val batches: 531
Test batches: 531


In [30]:
# Compute class weights (same strategy as DeBERTa)
print("Calculating class weights from training set...")
label_counts = {label_id: 0 for label_id in labels_to_ids.values()}

for sample in tqdm(train_dataset, desc="Counting labels"):
    labels = sample['labels'].numpy()
    valid_labels = labels[labels != IGNORE_INDEX]
    for label in valid_labels:
        label_counts[label] += 1

total_counts = sum(label_counts.values())
num_classes = len(labels_to_ids)

print("\nLabel distribution:")
for label_name, label_id in labels_to_ids.items():
    count = label_counts[label_id]
    pct = (count / total_counts) * 100 if total_counts > 0 else 0
    print(f"  {label_name}: {count:,} ({pct:.2f}%)")

# Inverse frequency weights with boosting (matching DeBERTa)
boost_factors = {
    0: 1.0,   # O
    1: 1.2,   # B-ROUTE
    2: 1.5,   # I-ROUTE
    3: 1.5,   # B-DIRECTION
    4: 1.5    # I-DIRECTION
}

class_weights = []
for i in range(num_classes):
    count = label_counts[i]
    if count > 0:
        weight = total_counts / (num_classes * count)
        weight *= boost_factors.get(i, 1.0)
    else:
        weight = 1.0
    class_weights.append(weight)

class_weights = torch.tensor(class_weights, dtype=torch.float)

print("\nFinal Class Weights:")
for label_name, label_id in labels_to_ids.items():
    print(f"  {label_name}: {class_weights[label_id]:.3f}")

Calculating class weights from training set...


Counting labels:   0%|          | 0/158312 [00:00<?, ?it/s]


Label distribution:
  O: 3,672,409 (87.88%)
  B-ROUTE: 259,088 (6.20%)
  I-ROUTE: 8,534 (0.20%)
  B-DIRECTION: 144,588 (3.46%)
  I-DIRECTION: 94,390 (2.26%)

Final Class Weights:
  O: 0.228
  B-ROUTE: 3.871
  I-ROUTE: 146.907
  B-DIRECTION: 8.671
  I-DIRECTION: 13.282


In [31]:
# Device setup
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Initialize model
model = BiLSTMCRFNER(
    vocab_size=len(word2idx),
    char_vocab_size=len(char2idx),
    num_labels=len(labels_to_ids),
    word_embedding_dim=WORD_EMBEDDING_DIM,
    char_embedding_dim=50,
    hidden_dim=256,
    num_layers=2,
    dropout=0.3,
    pretrained_embeddings=pretrained_embeddings,
)
model.to(device)

# Optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=2
)

# Loss function (for non-CRF fallback)
loss_fct = nn.CrossEntropyLoss(weight=class_weights.to(device), ignore_index=IGNORE_INDEX)

# Training config
epochs = 3
patience = 5
grad_clip = 5.0

print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"\nTraining config:")
print(f"  Epochs: {epochs}")
print(f"  Learning rate: {optimizer.param_groups[0]['lr']}")
print(f"  Batch size: {batch_size}")
print(f"  Gradient clipping: {grad_clip}")
print(f"  Early stopping patience: {patience}")

Using device: cuda

Model parameters: 3,244,256
Trainable parameters: 3,244,256

Training config:
  Epochs: 3
  Learning rate: 0.001
  Batch size: 64
  Gradient clipping: 5.0
  Early stopping patience: 5


## 7. Training Loop

In [32]:
def train_epoch(model, data_loader, optimizer, device, grad_clip=5.0):
    """Train model for one epoch."""
    model.train()
    total_loss = 0.0

    pbar = tqdm(data_loader, desc="Training")
    for batch in pbar:
        word_ids = batch['word_ids'].to(device)
        char_ids = batch['char_ids'].to(device)
        labels = batch['labels'].to(device)
        lengths = batch['lengths'].to(device)

        optimizer.zero_grad()
        loss = model(word_ids, char_ids, labels, lengths)
        loss.backward()

        if grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()
        total_loss += loss.item()

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    return total_loss / len(data_loader)

def evaluate(model, data_loader, device):
    """Evaluate model and compute span-level metrics."""
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(data_loader, desc="Evaluating")
        for batch in pbar:
            word_ids = batch['word_ids'].to(device)
            char_ids = batch['char_ids'].to(device)
            labels = batch['labels'].to(device)
            lengths = batch['lengths'].to(device)

            loss = model(word_ids, char_ids, labels, lengths)
            total_loss += loss.item()

            # Decode predictions
            predictions = model.decode(word_ids, char_ids, lengths)
            all_predictions.extend(predictions)

            # Extract gold labels
            for i, length in enumerate(lengths.tolist()):
                gold = labels[i, :length].tolist()
                all_labels.append(gold)

    # Convert to label strings for seqeval
    pred_labels = [[ids_to_labels[idx] for idx in seq] for seq in all_predictions]
    true_labels = [[ids_to_labels[idx] for idx in seq] for seq in all_labels]

    # Compute metrics
    avg_loss = total_loss / len(data_loader)
    f1 = f1_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels)
    recall = recall_score(true_labels, pred_labels)
    report = classification_report(true_labels, pred_labels)

    return avg_loss, f1, precision, recall, report

print("✓ Training functions defined")

✓ Training functions defined


In [33]:
# Training loop with early stopping
best_f1 = 0
patience_counter = 0
best_model_state = None

print("Starting training...\n")

for epoch in range(epochs):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"{'='*60}")

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, device, grad_clip)
    print(f"Train Loss: {train_loss:.4f}")

    # Evaluate
    val_loss, val_f1, val_precision, val_recall, val_report = evaluate(model, val_loader, device)

    print(f"\nValidation Results:")
    print(f"  Loss: {val_loss:.4f}")
    print(f"  Precision: {val_precision:.4f}")
    print(f"  Recall: {val_recall:.4f}")
    print(f"  F1: {val_f1:.4f}")
    print(f"\nClassification Report:")
    print(val_report)

    # Learning rate scheduling
    scheduler.step(val_f1)

    # Save best model
    if val_f1 > best_f1:
        best_f1 = val_f1
        patience_counter = 0
        best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        print(f"\n✓ New best model! F1: {best_f1:.4f}")
    else:
        patience_counter += 1
        print(f"\nNo improvement. Patience: {patience_counter}/{patience}")

    # Early stopping
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered at epoch {epoch + 1}")
        break

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"\n✓ Loaded best model with F1: {best_f1:.4f}")

Starting training...


Epoch 1/3


Training:   0%|          | 0/2474 [00:00<?, ?it/s]

Train Loss: 0.3882


Evaluating:   0%|          | 0/531 [00:00<?, ?it/s]


Validation Results:
  Loss: 0.0851
  Precision: 0.9955
  Recall: 0.9939
  F1: 0.9947

Classification Report:
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     29287
       ROUTE       0.99      0.99      0.99     53468

   micro avg       1.00      0.99      0.99     82755
   macro avg       1.00      1.00      1.00     82755
weighted avg       1.00      0.99      0.99     82755


✓ New best model! F1: 0.9947

Epoch 2/3


Training:   0%|          | 0/2474 [00:00<?, ?it/s]

Train Loss: 0.1213


Evaluating:   0%|          | 0/531 [00:00<?, ?it/s]


Validation Results:
  Loss: 0.0811
  Precision: 0.9961
  Recall: 0.9934
  F1: 0.9948

Classification Report:
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     29287
       ROUTE       0.99      0.99      0.99     53468

   micro avg       1.00      0.99      0.99     82755
   macro avg       1.00      0.99      1.00     82755
weighted avg       1.00      0.99      0.99     82755


✓ New best model! F1: 0.9948

Epoch 3/3


Training:   0%|          | 0/2474 [00:00<?, ?it/s]

Train Loss: 0.1087


Evaluating:   0%|          | 0/531 [00:00<?, ?it/s]


Validation Results:
  Loss: 0.0745
  Precision: 0.9949
  Recall: 0.9949
  F1: 0.9949

Classification Report:
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     29287
       ROUTE       0.99      0.99      0.99     53468

   micro avg       0.99      0.99      0.99     82755
   macro avg       1.00      1.00      1.00     82755
weighted avg       0.99      0.99      0.99     82755


✓ New best model! F1: 0.9949

✓ Loaded best model with F1: 0.9949


## 8. Save Model

In [34]:
# Save model and vocabularies
save_dir = "models/bilstm_ner_best"
os.makedirs(save_dir, exist_ok=True)

# Save model weights
torch.save({
    'model_state_dict': model.state_dict(),
    'word2idx': word2idx,
    'char2idx': char2idx,
    'labels_to_ids': labels_to_ids,
    'ids_to_labels': ids_to_labels,
    'config': {
        'vocab_size': len(word2idx),
        'char_vocab_size': len(char2idx),
        'num_labels': len(labels_to_ids),
        'word_embedding_dim': WORD_EMBEDDING_DIM,
        'char_embedding_dim': 50,
        'hidden_dim': 256,
        'num_layers': 2,
        'dropout': 0.3,
    }
}, os.path.join(save_dir, 'model.pt'))

print(f"✓ Model saved to {save_dir}")

✓ Model saved to models/bilstm_ner_best


## 9. Test Set Evaluation

In [35]:
# Evaluate on test set
print("Evaluating on test set...\n")
test_loss, test_f1, test_precision, test_recall, test_report = evaluate(model, test_loader, device)

print("="*60)
print("TEST SET RESULTS")
print("="*60)
print(f"Loss: {test_loss:.4f}")
print(f"Precision: {test_precision:.4f}")
print(f"Recall: {test_recall:.4f}")
print(f"F1: {test_f1:.4f}")
print(f"\nDetailed Classification Report:")
print(test_report)

Evaluating on test set...



Evaluating:   0%|          | 0/531 [00:00<?, ?it/s]

TEST SET RESULTS
Loss: 0.0732
Precision: 0.9973
Recall: 0.9916
F1: 0.9944

Detailed Classification Report:
              precision    recall  f1-score   support

   DIRECTION       1.00      1.00      1.00     30733
       ROUTE       1.00      0.99      0.99     53084

   micro avg       1.00      0.99      0.99     83817
   macro avg       1.00      0.99      1.00     83817
weighted avg       1.00      0.99      0.99     83817



## 10. Inference Demo

In [36]:
def predict_ner(text: str, model, word2idx, char2idx, device, max_seq_length=128, max_word_length=20):
    """Predict NER tags for a given text."""
    model.eval()

    # Tokenize
    tokens = word_tokenize(text)
    if not tokens:
        return []

    tokens = tokens[:max_seq_length]
    length = len(tokens)

    # Convert to IDs
    word_ids = [word2idx.get(word.lower(), UNK_IDX) for word, _, _ in tokens]

    char_ids = []
    for word, _, _ in tokens:
        word_chars = [char2idx.get(c, CHAR_UNK_IDX) for c in word[:max_word_length]]
        word_chars += [CHAR_PAD_IDX] * (max_word_length - len(word_chars))
        char_ids.append(word_chars)

    # Pad
    pad_len = max_seq_length - length
    word_ids += [PAD_IDX] * pad_len
    char_ids += [[CHAR_PAD_IDX] * max_word_length] * pad_len

    # Convert to tensors
    word_ids = torch.tensor([word_ids], dtype=torch.long).to(device)
    char_ids = torch.tensor([char_ids], dtype=torch.long).to(device)
    lengths = torch.tensor([length], dtype=torch.long).to(device)

    # Predict
    with torch.no_grad():
        predictions = model.decode(word_ids, char_ids, lengths)[0]

    # Format output
    entities = []
    for i, (word, start, end) in enumerate(tokens):
        label = ids_to_labels[predictions[i]]
        if label != 'O':
            entities.append((word, label, start, end))

    return entities

# Test on sample texts
test_texts = [
    "Jamaica-bound J trains are delayed",
    "Southbound Q65 and Q66 buses are running with delays",
    "Manhattan-bound E F trains are running express",
    "Downtown 2 trains are delayed",
    "G trains are running with delays in both directions"
]

print("Inference Examples:")
print("="*60)

for text in test_texts:
    entities = predict_ner(text, model, word2idx, char2idx, device)
    print(f"\nText: {text}")
    if entities:
        print("Entities:")
        for word, label, start, end in entities:
            print(f"  [{start}:{end}] {word} → {label}")
    else:
        print("  No entities found")

Inference Examples:

Text: Jamaica-bound J trains are delayed
Entities:
  [0:7] Jamaica → B-DIRECTION
  [7:8] - → I-DIRECTION
  [8:13] bound → I-DIRECTION
  [14:15] J → B-ROUTE

Text: Southbound Q65 and Q66 buses are running with delays
Entities:
  [0:10] Southbound → B-DIRECTION
  [11:14] Q65 → B-ROUTE
  [19:22] Q66 → B-ROUTE

Text: Manhattan-bound E F trains are running express
Entities:
  [0:9] Manhattan → B-DIRECTION
  [9:10] - → I-DIRECTION
  [10:15] bound → I-DIRECTION
  [16:17] E → B-ROUTE
  [18:19] F → B-ROUTE

Text: Downtown 2 trains are delayed
Entities:
  [0:8] Downtown → B-DIRECTION
  [9:10] 2 → B-ROUTE

Text: G trains are running with delays in both directions
Entities:
  [0:1] G → B-ROUTE
  [36:40] both → B-DIRECTION
  [41:51] directions → I-DIRECTION
