## 1. Imports <a id='imports'></a>

Import the required libraries for data processing, regex matching, and JSON handling.

In [1]:
import pandas as pd
import re
import json
from typing import List, Dict, Set, Tuple

## 2. Load Transit Codes <a id='load-transit-codes'></a>

Functions to load valid bus and subway codes from CSV files.

In [2]:
def load_bus_codes(filepath: str = "../Data/bus_codes.csv") -> Set[str]:
    # Load valid bus codes from CSV
    df = pd.read_csv(filepath)
    return set(df['route_short_name'].astype(str).tolist())

In [3]:
def load_subway_codes(filepath: str = "../Data/subway_codes.csv") -> Tuple[Set[str], Set[str], Set[str]]:
    # Returns (letter_codes, digit_codes, special_codes)
    all_codes = set(pd.read_csv(filepath, sep=';')['route_short_name'].astype(str).tolist())
    letter_codes = {c for c in all_codes if len(c) == 1 and c.isalpha()}
    digit_codes = {c for c in all_codes if len(c) == 1 and c.isdigit()}
    special_codes = all_codes - letter_codes - digit_codes
    return letter_codes, digit_codes, special_codes

## 3. Parsing Utilities <a id='parsing-utilities'></a>

Helper function to parse the Affected column JSON into a list of codes.

In [4]:
def parse_affected_column(affected_str: str) -> List[str]:
    # Parse Affected column JSON to list of codes
    try:
        if pd.isna(affected_str):
            return []
        return json.loads(affected_str)
    except (json.JSONDecodeError, TypeError):
        return []

## 4. Bus Code Extraction <a id='bus-code-extraction'></a>

Extract bus code spans from text (e.g., Q65, B44-SBS, BxM10).

In [5]:
def extract_bus_code_spans(text: str, valid_bus_codes: Set[str], affected_codes: List[str]) -> List[Dict]:
    # Extract bus code spans (Q65, B44-SBS, BxM10, etc.)
    spans = []
    
    valid_bus_codes_upper = {code.upper() for code in valid_bus_codes}
    affected_codes_upper = {code.upper() for code in affected_codes}
    
    # Bus pattern: B, Bx, BxM, M, Q, QM, S, SIM, X, BM prefixes with optional -SBS suffix
    bus_pattern = r'\b(B(?:x(?:M)?)?|M|Q(?:M)?|S(?:IM)?|X|BM)\d+[A-Z]?(?:-SBS)?\b'
    
    for match in re.finditer(bus_pattern, text, re.IGNORECASE):
        code = match.group(0)
        code_upper = code.upper()
        base_code = code_upper.replace('-SBS', '')  # Normalize: Q44-SBS -> Q44
        if code_upper in valid_bus_codes_upper or base_code in valid_bus_codes_upper or base_code in affected_codes_upper:
            spans.append({
                "start": match.start(),
                "end": match.end(),
                "label": "TRANSIT_CODE",
                "text": code
            })
    
    return spans

## 5. Subway Code Extraction <a id='subway-code-extraction'></a>

Functions to extract different types of subway codes:
- **Letter codes**: Single-letter subway lines (A-G, J, L, M, N, Q, R, W, Z)
- **Special codes**: SIR, 5X, 6X, 7X, FX, FS, GS, S
- **Digit codes**: Single-digit subway lines (1-7)

### 5.1 Letter Subway Codes

In [6]:
def extract_letter_subway_spans(text: str, letter_codes: Set[str], affected_codes: List[str]) -> List[Dict]:
    # Extract single-letter subway codes, filtering false positives like 'E 149 St'
    spans = []
    
    affected_letters = set(code for code in affected_codes if code in letter_codes)
    
    # Valid subway letters: A-G, H, J, L, M, N, Q, R, W, Z
    letter_pattern = r'\b([A-GHJLMNQRWZ])\b'
    
    for match in re.finditer(letter_pattern, text):
        letter = match.group(1)
        if letter in letter_codes:
            after_text = text[match.end():match.end()+20]
            before_text = text[max(0, match.start()-20):match.start()]
            
            # Skip E/W/N/S if followed by street number (e.g. "E 149 St")
            if letter in {'E', 'W', 'N', 'S'} and re.match(r'\s+\d+\s*(St|Av|Ave|Street|Avenue)\b', after_text, re.IGNORECASE):
                if letter not in affected_letters:
                    continue
            
            if letter in affected_letters:
                spans.append({
                    "start": match.start(),
                    "end": match.end(),
                    "label": "TRANSIT_CODE",
                    "text": letter
                })
    
    return spans

### 5.2 Special Subway Codes

In [7]:
def extract_special_subway_spans(text: str, special_codes: Set[str], affected_codes: List[str]) -> List[Dict]:
    # Extract special codes: SIR, 5X, 6X, 7X, FX, FS, GS, S
    spans = []
    
    # Pattern -> normalized code mappings
    patterns = [
        (r'\b([567]X|FX|FS)\b', None),  # None = use matched text
        (r'\b(SIR)\b', 'SI'),
        (r'\b(Franklin\s+Av(?:enue)?\s+Shuttle)\b', 'FS'),
        (r'\b(Rockaway\s+Park\s+Shuttle)\b', 'H'),
        (r'\b(42\s*St\s+Shuttle)\b', 'GS'),
    ]
    
    for pattern, normalized in patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            code = normalized if normalized else match.group(0)
            if normalized or code in special_codes:
                spans.append({"start": match.start(), "end": match.end(), "label": "TRANSIT_CODE", "text": code})
    
    # S train/shuttle (captures only 'S')
    for match in re.finditer(r'\b(S)\s+(?:train|trains|shuttle)\b', text, re.IGNORECASE):
        spans.append({"start": match.start(1), "end": match.end(1), "label": "TRANSIT_CODE", "text": "S"})
    
    return spans

### 5.3 Digit Subway Codes

In [8]:
def extract_digit_subway_spans(text: str, digit_codes: Set[str], affected_codes: List[str]) -> List[Dict]:
    # Extract single-digit subway codes with context-aware rules
    spans = []
    matched_digits = set()
    affected_digits = set(code for code in affected_codes if code in digit_codes)
    
    def add_span(start, end, digit):
        if not any(s['start'] == start and s['end'] == end for s in spans):
            spans.append({"start": start, "end": end, "label": "TRANSIT_CODE", "text": digit})
            matched_digits.add(digit)
    
    # Context patterns for digits
    patterns = [
        r'\b([1-7])\s+(?:train|trains|service|line)\b',  # "2 trains"
        r'(?:train|trains|service|line)\s+([1-7])\b',    # "train 2"
        r'(?:Northbound|Southbound|Eastbound|Westbound|Uptown|Downtown)\s+([1-7])\b',
        r'\b[\w-]+-bound\s+([1-7])\b',
        r'\b([1-7])\s+(?=[\w\s]*(?:Northbound|Southbound|Eastbound|Westbound|bound))',
        r'(?:^|[.!?]\s+)([1-7])\b',  # Sentence start
    ]
    
    for pattern in patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            digit = match.group(1)
            if digit in digit_codes:
                add_span(match.start(1), match.end(1), digit)
    
    # Fallback: unmatched affected digits (skip street numbers)
    for digit in affected_digits - matched_digits:
        for match in re.finditer(rf'\b{digit}\b', text):
            before, after = text[max(0, match.start()-3):match.start()], text[match.end():match.end()+5]
            if re.search(r'^\s*(St|Av|Ave|Street|Avenue)\b', after, re.IGNORECASE): continue
            if re.search(r'\d$', before) or re.search(r'^\d', after): continue
            add_span(match.start(), match.end(), digit)
    
    return spans

## 6. Combined Span Extraction <a id='combined-span-extraction'></a>

Main function that combines all extraction methods based on agency type.

In [9]:
def extract_all_spans(text: str, agency: str, affected_codes: List[str],
                      valid_bus_codes: Set[str], letter_codes: Set[str],
                      digit_codes: Set[str], special_codes: Set[str]) -> List[Dict]:
    # Extract all transit code spans based on agency type
    spans = []
    all_codes = valid_bus_codes | letter_codes | digit_codes | special_codes
    
    if agency == "NYCT Bus":
        spans.extend(extract_bus_code_spans(text, valid_bus_codes, affected_codes))
    elif agency == "NYCT Subway":
        spans.extend(extract_letter_subway_spans(text, letter_codes, affected_codes))
        spans.extend(extract_special_subway_spans(text, special_codes, affected_codes))
        spans.extend(extract_digit_subway_spans(text, digit_codes, affected_codes))
    
    # Slash-separated codes (e.g. "2/3 trains")
    for match in re.finditer(r'\b([A-Z1-7](?:/[A-Z1-7])+)\s+(?:train|trains|bus|buses|line|service)\b', text, re.IGNORECASE):
        pos = match.start(1)
        for code in match.group(1).split('/'):
            if code in all_codes:
                spans.append({"start": pos, "end": pos + len(code), "label": "TRANSIT_CODE", "text": code})
            pos += len(code) + 1
    
    # Remove duplicates and sort by position
    seen = set()
    return [s for s in sorted(spans, key=lambda x: x['start']) if (s['start'], s['end']) not in seen and not seen.add((s['start'], s['end']))]

## 7. Normalization Utilities <a id='normalization-utilities'></a>

Function to normalize transit codes (remove -SBS, uppercase, remove leading zeros).

In [10]:
def normalize_code(code: str) -> str:
    # Normalize: remove -SBS, uppercase, remove leading zeros (Q044 -> Q44)
    n = code.replace('-SBS', '').upper()
    m = re.match(r'^([A-Z]+)0*(\d+)([A-Z]?)$', n)
    return f"{m.group(1)}{int(m.group(2))}{m.group(3)}" if m else n

## 8. Dataset Processing <a id='dataset-processing'></a>

Main processing function that:
1. Loads the input dataset
2. Extracts spans for each row
3. Updates the Affected column
4. Removes rows with no extracted codes
5. Saves the annotated dataset

In [11]:
def process_dataset(input_path: str, output_path: str):
    # Process dataset: extract spans, update Affected column, remove rows with no codes
    print(f"Loading dataset from {input_path}...")
    df = pd.read_csv(input_path)
    
    print("Loading transit code files...")
    valid_bus_codes = load_bus_codes()
    letter_codes, digit_codes, special_codes = load_subway_codes()
    
    span_annotations = []
    updated_affected = []
    rows_to_keep = []
    
    total_rows = len(df)
    print(f"Processing {total_rows} rows...")
    
    rows_removed_no_codes = 0
    rows_affected_updated = 0
    
    for idx, row in df.iterrows():
        header = str(row['header']) if pd.notna(row['header']) else ""
        agency = row['agency']
        affected_codes = parse_affected_column(row['affected'])
        
        spans = extract_all_spans(
            header, agency, affected_codes,
            valid_bus_codes, letter_codes, digit_codes, special_codes
        )
        
        # Get unique extracted codes in order of occurrence
        seen = set()
        extracted_codes = []
        for s in sorted(spans, key=lambda x: x['start']):
            if s['text'] not in seen:
                seen.add(s['text'])
                extracted_codes.append(s['text'])
        
        # Remove rows with no extracted codes
        if len(extracted_codes) == 0:
            rows_removed_no_codes += 1
            continue
        
        rows_to_keep.append(idx)
        
        # Check if Affected column was updated
        original_normalized = {normalize_code(c) for c in affected_codes}
        new_normalized = {normalize_code(c) for c in extracted_codes}
        
        if original_normalized != new_normalized:
            rows_affected_updated += 1
        
        output_spans = [{"start": s['start'], "end": s['end'], "type": "ROUTE", "value": s['text']} for s in spans]
        span_annotations.append(json.dumps(output_spans))
        updated_affected.append(json.dumps(extracted_codes))
    
    df = df.loc[rows_to_keep].copy()
    df = df.reset_index(drop=True)
    
    df['affected'] = updated_affected
    df['affected_spans'] = span_annotations
    
    # Reorder: place affected_spans after affected
    cols = [c for c in df.columns if c != 'affected_spans']
    cols.insert(cols.index('affected') + 1, 'affected_spans')
    df = df[cols]
    
    final_rows = len(df)
    
    print(f"Processing complete!")
    print(f"Total rows in input: {total_rows}")
    print(f"Rows removed (no extracted codes): {rows_removed_no_codes} ({100*rows_removed_no_codes/total_rows:.2f}%)")
    print(f"Rows with updated Affected column: {rows_affected_updated} ({100*rows_affected_updated/final_rows:.2f}%)")
    print(f"Final rows in output: {final_rows}")
    df.to_csv(output_path, index=False)
    
    return df

## 9. Run Pipeline <a id='run-pipeline'></a>

Execute the span labeling pipeline on the MTA dataset.

In [12]:
# Define input and output paths
input_path = "../Preprocessed/MTA_Data_preprocessed.csv"
output_path = "MTA_Data_preprocessed_routespans.csv"

# Process the dataset
df = process_dataset(input_path, output_path)

Loading dataset from ../Preprocessed/MTA_Data_preprocessed.csv...
Loading transit code files...
Processing 227210 rows...
Loading transit code files...
Processing 227210 rows...
Processing complete!
Total rows in input: 227210
Rows removed (no extracted codes): 1050 (0.46%)
Rows with updated Affected column: 8988 (3.97%)
Final rows in output: 226160
Processing complete!
Total rows in input: 227210
Rows removed (no extracted codes): 1050 (0.46%)
Rows with updated Affected column: 8988 (3.97%)
Final rows in output: 226160


In [13]:
# Display sample results
print("Sample of processed data:")
df.head()

Sample of processed data:


Unnamed: 0,alert_id,date,agency,status_label,affected,affected_spans,header
0,180128,11/05/2022 05:58:00 PM,NYCT Subway,delays,"[""A"", ""C""]","[{""start"": 0, ""end"": 1, ""type"": ""ROUTE"", ""valu...",A C trains are delayed while we conduct emerge...
1,189489,12/20/2022 07:09:00 PM,NYCT Subway,delays,"[""L""]","[{""start"": 0, ""end"": 1, ""type"": ""ROUTE"", ""valu...",L trains are running with delays in both direc...
2,189321,12/20/2022 12:31:00 AM,NYCT Subway,delays,"[""J""]","[{""start"": 14, ""end"": 15, ""type"": ""ROUTE"", ""va...",Jamaica-bound J trains are delayed while we re...
3,188948,12/18/2022 06:12:00 AM,NYCT Subway,delays,"[""Q""]","[{""start"": 11, ""end"": 12, ""type"": ""ROUTE"", ""va...",Southbound Q trains are running with delays af...
4,187749,12/12/2022 02:26:00 PM,NYCT Subway,delays,"[""B"", ""C""]","[{""start"": 11, ""end"": 12, ""type"": ""ROUTE"", ""va...",Southbound B C trains are running with delays ...
