# Baseline Relation Extraction for MTA Transit Alerts

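In this notebook I implement a rule-based relation extraction system that pairs DIRECTION and ROUTE entities using segment-based logic: a route takes the direction stated in the same text segment, where segments are delimited by major breaks such as newlines, parentheses, and colons that are not part of a time.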

## Extraction Schema

Each alert record carries three JSON-encoded lists:
- **affected_spans**: `[{"id": 0, "start": X, "end": Y, "type": "ROUTE", "value": "Q"}, ...]`
- **direction_spans**: `[{"id": 0, "start": X, "end": Y, "type": "DIRECTION", "value": "SOUTHBOUND"}, ...]`
- **relations**: `[{"route_span_id": 0, "direction_span_id": 0, "type": "HAS_DIRECTION"}, ...]`
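
To make the schema concrete, the sketch below builds one record for a hypothetical header, "Southbound Q trains are delayed", and checks it with `validate_record`, a hypothetical helper that is not part of this pipeline. It assumes only the constraints the schema implies: IDs are unique within each span list, offsets fall inside the header, and every relation points at existing span IDs.

```python
from typing import Dict, List


def validate_record(header: str, affected_spans: List[Dict],
                    direction_spans: List[Dict], relations: List[Dict]) -> List[str]:
    """Hypothetical checker: returns schema problems (empty list = consistent)."""
    problems = []
    for name, spans, expected_type in (("affected_spans", affected_spans, "ROUTE"),
                                       ("direction_spans", direction_spans, "DIRECTION")):
        ids = [s["id"] for s in spans]
        if len(ids) != len(set(ids)):
            problems.append(f"{name}: duplicate ids")
        for s in spans:
            if s["type"] != expected_type:
                problems.append(f"{name}[{s['id']}]: type {s['type']!r}, expected {expected_type!r}")
            if not 0 <= s["start"] < s["end"] <= len(header):
                problems.append(f"{name}[{s['id']}]: offsets out of range")
    route_ids = {s["id"] for s in affected_spans}
    direction_ids = {s["id"] for s in direction_spans}
    for rel in relations:
        if rel["route_span_id"] not in route_ids:
            problems.append(f"relation points to missing route {rel['route_span_id']}")
        if rel["direction_span_id"] not in direction_ids:
            problems.append(f"relation points to missing direction {rel['direction_span_id']}")
        if rel["type"] != "HAS_DIRECTION":
            problems.append(f"unexpected relation type {rel['type']!r}")
    return problems


# Hypothetical header; offsets are computed with str.find to avoid counting by hand
header = "Southbound Q trains are delayed"
d_start = header.find("Southbound")
r_start = header.find("Q")
record = {
    "affected_spans": [{"id": 0, "start": r_start, "end": r_start + 1,
                        "type": "ROUTE", "value": "Q"}],
    "direction_spans": [{"id": 0, "start": d_start, "end": d_start + len("Southbound"),
                         "type": "DIRECTION", "value": "SOUTHBOUND"}],
    "relations": [{"route_span_id": 0, "direction_span_id": 0, "type": "HAS_DIRECTION"}],
}
print(validate_record(header, **record))  # prints []
```

Note that `value` is a normalized label (e.g. `SOUTHBOUND`, `BOTH_DIRECTIONS`), so the checker deliberately does not compare it with the surface text at `header[start:end]`.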

In [1]:
# Import libraries
import pandas as pd
import json
import re
from typing import List, Dict, Tuple, Optional
from collections import defaultdict

## 1. Data Loading and Preprocessing

Functions to load the silver dataset and parse JSON span annotations.

In [2]:
# Load the silver dataset with direction and route spans
def load_silver_data(filepath: str) -> pd.DataFrame:
    print(f"Loading data from {filepath}...")
    df = pd.read_csv(filepath)
    print(f"Loaded {len(df):,} records")
    return df

In [3]:
# Parse JSON spans string into list of span dictionaries
def parse_spans(spans_json: str) -> List[Dict]:
    if pd.isna(spans_json) or spans_json == '[]':
        return []
    try:
        return json.loads(spans_json)
    except (json.JSONDecodeError, TypeError):
        return []
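
For any other non-empty cell, `parse_spans` returns whatever `json.loads` yields, so a cell holding `null` or a JSON object would come back as `None` or a dict, and `add_span_ids` would then fail or misbehave. A hypothetical stricter variant, `parse_spans_strict`, always returns a list of dicts. It assumes missing cells arrive from pandas as float NaN and checks that with `math.isnan` so the sketch does not need pandas.

```python
import json
import math
from typing import Any, Dict, List


def parse_spans_strict(spans_json: Any) -> List[Dict]:
    """Hypothetical stricter parse_spans: always returns a list of dicts."""
    # Missing CSV cells typically arrive as float NaN (or None)
    if spans_json is None or (isinstance(spans_json, float) and math.isnan(spans_json)):
        return []
    try:
        parsed = json.loads(spans_json)
    except (json.JSONDecodeError, TypeError):
        return []
    # json.loads("null") gives None and json.loads("{}") gives a dict: reject non-lists
    if not isinstance(parsed, list):
        return []
    # Keep only dict entries so downstream span.copy() calls are safe
    return [s for s in parsed if isinstance(s, dict)]


print(parse_spans_strict("null"), parse_spans_strict('[{"id": 0}]'))  # [] [{'id': 0}]
```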

## 2. Span Processing Utilities

Helper functions to manage span IDs and identify segment breaks in text.

In [4]:
# Add sequential IDs to spans if not already present
def add_span_ids(spans: List[Dict], start_id: int = 0) -> List[Dict]:
    result = []
    for i, span in enumerate(spans):
        span_with_id = span.copy()
        if 'id' not in span_with_id:
            span_with_id['id'] = start_id + i
        result.append(span_with_id)
    return result
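
`add_span_ids` keeps any existing `id` and otherwise uses `start_id` plus the span's list position, so a list in which only some spans are already labelled can end up with duplicate IDs. This may never happen in the silver data if spans are either all labelled or all unlabelled (an assumption), but the sketch below shows the collision and a hypothetical guard, `add_span_ids_unique`, that numbers new spans after the largest existing ID.

```python
from typing import Dict, List


def add_span_ids(spans: List[Dict], start_id: int = 0) -> List[Dict]:
    # Same logic as the notebook cell above
    result = []
    for i, span in enumerate(spans):
        span_with_id = span.copy()
        if 'id' not in span_with_id:
            span_with_id['id'] = start_id + i
        result.append(span_with_id)
    return result


def add_span_ids_unique(spans: List[Dict]) -> List[Dict]:
    # Hypothetical variant: new ids continue after the largest id already present
    used = [s["id"] for s in spans if "id" in s]
    next_id = max(used, default=-1) + 1
    result = []
    for span in spans:
        span_with_id = span.copy()
        if "id" not in span_with_id:
            span_with_id["id"] = next_id
            next_id += 1
        result.append(span_with_id)
    return result


mixed = [{"id": 1, "value": "Q"}, {"value": "L"}]
print([s["id"] for s in add_span_ids(mixed)])         # [1, 1]
print([s["id"] for s in add_span_ids_unique(mixed)])  # [1, 2]
```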

In [None]:
# Check if text between two entities contains a major segment break
# Major breaks are newlines, parentheses, and colons that are not part of a time such as "8:45"
def is_major_break(text_between: str) -> bool:
    # Check for newline
    if '\n' in text_between:
        return True
    
    # Check for parentheses
    if '(' in text_between or ')' in text_between:
        return True
    
    # Check for colon not followed by time pattern
    # Time pattern: colon followed by optional whitespace and two digits ("8:45", "10:30")
    colon_matches = list(re.finditer(r':', text_between))
    for match in colon_matches:
        after_colon = text_between[match.end():]
        # If colon is NOT followed by time pattern (digits), it's a major break
        if not re.match(r'\s*\d{2}', after_colon):
            return True
    
    return False
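
`is_major_break` scans each colon separately; the same rule can be written as one precompiled pattern, `[\n()]|:(?!\s*\d{2})`, where the negative lookahead rejects colons followed by optional whitespace and two digits. The sketch below (`MAJOR_BREAK` and `is_major_break_re` are hypothetical names) checks both versions against a few hand-written gaps between entities. The last case shows a limitation of the time heuristic: a label colon followed by a two-digit number is treated as a time, not a break.

```python
import re


def is_major_break(text_between: str) -> bool:
    # Notebook version, condensed
    if '\n' in text_between:
        return True
    if '(' in text_between or ')' in text_between:
        return True
    for match in re.finditer(r':', text_between):
        if not re.match(r'\s*\d{2}', text_between[match.end():]):
            return True
    return False


# Hypothetical single-pattern equivalent: a newline, a parenthesis, or a colon
# that is NOT followed by optional whitespace and two digits
MAJOR_BREAK = re.compile(r'[\n()]|:(?!\s*\d{2})')


def is_major_break_re(text_between: str) -> bool:
    return MAJOR_BREAK.search(text_between) is not None


cases = {
    " trains are delayed in ": False,
    "\n": True,
    " (": True,
    ": ": True,
    " at 8:45 ": False,
    " from 10:30 to 11:1": True,  # second colon is followed by a single digit
    "Note: 24 hours": False,      # caveat: " 24" is mistaken for a time
}
for text, expected in cases.items():
    assert is_major_break(text) == expected, text
    assert is_major_break_re(text) == expected, text
```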

## 3. Relation Extraction Logic

Core function that implements the rule-based relation extraction algorithm.

### Algorithm Overview:
1. **First pass**: Process entities left to right, tracking the active direction
   - DIRECTION entities set the active direction
   - ROUTE entities pair with the current active direction
   - Major breaks reset the active direction
2. **Second pass**: Assign each still-unpaired route to the first DIRECTION that follows it, unless a major break lies between them
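
The two passes can be sketched without any text handling, assuming entities are already sorted by position and reduced to `(type, id)` pairs, and that the positions of major breaks have been precomputed. The function name `two_pass_pairing` and the `breaks_before` set are hypothetical; the sketch stands in for running `is_major_break` on the header, so it approximates rather than reproduces `extract_relations`.

```python
from typing import Dict, List, Optional, Set, Tuple


def two_pass_pairing(entities: List[Tuple[str, int]],
                     breaks_before: Set[int]) -> Dict[int, int]:
    """Map route id -> direction id for entities sorted by position.

    breaks_before holds each position i where a major break separates
    entity i - 1 from entity i.
    """
    pairs: Dict[int, int] = {}
    unpaired: List[int] = []
    active: Optional[int] = None
    # First pass: left to right, carrying the most recent direction
    for pos, (kind, eid) in enumerate(entities):
        if pos in breaks_before:
            active = None  # a major break resets the active direction
        if kind == "DIRECTION":
            active = eid
        elif kind == "ROUTE":
            if active is not None:
                pairs[eid] = active
            else:
                unpaired.append(pos)
    # Second pass: an unpaired route takes the next direction if no break intervenes
    for pos in unpaired:
        for nxt in range(pos + 1, len(entities)):
            kind, eid = entities[nxt]
            if kind == "DIRECTION":
                if not any(b in breaks_before for b in range(pos + 1, nxt + 1)):
                    pairs[entities[pos][1]] = eid
                break  # only the first following direction is considered
    return pairs


# Hypothetical header: "Southbound Q trains are delayed.\nL trains run in both directions."
entities = [("DIRECTION", 0), ("ROUTE", 0), ("ROUTE", 1), ("DIRECTION", 1)]
print(two_pass_pairing(entities, breaks_before={2}))  # {0: 0, 1: 1}
```

Q shares a segment with "Southbound" and pairs in the first pass. The newline before L resets the active direction, so L is left unpaired and only picks up "both directions" in the second pass.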

In [None]:
def extract_relations(
    header: str, 
    direction_spans: List[Dict], 
    route_spans: List[Dict]
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    # Add IDs to spans
    route_spans_with_ids = add_span_ids(route_spans, start_id=0)
    direction_spans_with_ids = add_span_ids(direction_spans, start_id=0)
    
    if not direction_spans_with_ids or not route_spans_with_ids:
        return route_spans_with_ids, direction_spans_with_ids, []
    
    # Merge all entities for sequential processing
    entities = []
    for span in direction_spans_with_ids:
        entities.append({
            'start': span['start'],
            'end': span['end'],
            'id': span['id'],
            'value': span['value'],
            'entity_type': 'DIRECTION'
        })
    for span in route_spans_with_ids:
        entities.append({
            'start': span['start'],
            'end': span['end'],
            'id': span['id'],
            'value': span['value'],
            'entity_type': 'ROUTE'
        })
    
    # Sort by start position
    entities.sort(key=lambda x: x['start'])
    
    relations = []
    active_direction_id = None
    active_direction_end = None
    
    # Track route_id - direction_id pairings
    route_direction_pairs = {}
    
    # Track unpaired routes for second pass
    unpaired_routes = []
    
    for i, entity in enumerate(entities):
        # Check for major break from previous entity
        if i > 0:
            prev_entity = entities[i - 1]
            text_between = header[prev_entity['end']:entity['start']]
            
            if is_major_break(text_between):
                active_direction_id = None
                active_direction_end = None
        
        if entity['entity_type'] == 'DIRECTION':
            # Update active direction
            active_direction_id = entity['id']
            active_direction_end = entity['end']
        
        elif entity['entity_type'] == 'ROUTE':
            route_id = entity['id']
            
            if active_direction_id is not None:
                route_direction_pairs[route_id] = active_direction_id
            else:
                # Track unpaired route for second pass
                unpaired_routes.append({
                    'route_id': route_id,
                    'route_start': entity['start'],
                    'route_end': entity['end']
                })
    
    # Second pass: assign unpaired routes to the next direction in the same segment
    # This handles patterns like "L trains are delayed in both directions"
    for unpaired in unpaired_routes:
        route_id = unpaired['route_id']
        route_end = unpaired['route_end']
        
        # Find the first direction that starts after this route
        for entity in entities:
            if entity['entity_type'] == 'DIRECTION' and entity['start'] > route_end:
                # Pair only if no major break separates the route from this direction
                text_between = header[route_end:entity['start']]
                if not is_major_break(text_between):
                    route_direction_pairs[route_id] = entity['id']
                # Either way, stop at the first following direction
                break
    
    # Convert pairs to relation format
    for route_id, direction_id in route_direction_pairs.items():
        relations.append({
            'route_span_id': route_id,
            'direction_span_id': direction_id,
            'type': 'HAS_DIRECTION'
        })
    
    # Sort relations by route_span_id for consistent output
    relations.sort(key=lambda x: x['route_span_id'])
    
    return route_spans_with_ids, direction_spans_with_ids, relations

## 4. Dataset Processing

In [None]:
# Process the silver dataset and extract relations for each row
# Updates spans to include IDs and adds relations column
def process_dataset(input_path: str, output_path: str) -> pd.DataFrame:
    df = load_silver_data(input_path)
    
    print("Extracting relations...")
    updated_route_spans_list = []
    updated_direction_spans_list = []
    relations_list = []
    relation_names_list = []
    
    for idx, row in df.iterrows():
        if (idx + 1) % 50000 == 0:
            print(f"  Processed {idx + 1:,} / {len(df):,} records...")
        
        header = str(row['header']) if pd.notna(row['header']) else ""
        direction_spans = parse_spans(row.get('direction_spans', '[]'))
        route_spans = parse_spans(row.get('affected_spans', '[]'))
        
        updated_routes, updated_directions, relations = extract_relations(
            header, direction_spans, route_spans
        )
        
        # Create relation names by looking up the actual route and direction values
        relation_names = []
        # Build lookup maps
        route_id_to_value = {s['id']: s['value'] for s in updated_routes}
        direction_id_to_value = {s['id']: s['value'] for s in updated_directions}
        
        for rel in relations:
            route_value = route_id_to_value.get(rel['route_span_id'], 'UNKNOWN')
            direction_value = direction_id_to_value.get(rel['direction_span_id'], 'UNKNOWN')
            relation_names.append({
                'route': route_value,
                'direction': direction_value
            })
        
        updated_route_spans_list.append(json.dumps(updated_routes))
        updated_direction_spans_list.append(json.dumps(updated_directions))
        relation_names_list.append(json.dumps(relation_names))
        relations_list.append(json.dumps(relations))
    
    # Update dataframe with ID-enhanced spans and relations
    df['affected_spans'] = updated_route_spans_list
    df['direction_spans'] = updated_direction_spans_list
    df['relation_names'] = relation_names_list
    df['relations'] = relations_list
    
    print(f"\nWriting output to {output_path}...")
    df.to_csv(output_path, index=False)
    print(f"Successfully wrote {len(df):,} records to {output_path}")
    
    return df
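
Because `process_dataset` stores spans and relations as JSON strings inside CSV cells, any consumer of the output file has to `json.loads` each cell after reading it back. The sketch below uses the standard-library `csv` module rather than pandas (a substitution for the sake of a dependency-free example) to show that the JSON survives CSV quoting even when the cell contains commas and double quotes. The header and relation values are hypothetical.

```python
import csv
import io
import json

# Hypothetical row in the shape process_dataset writes: JSON strings inside CSV cells
relations = [{"route_span_id": 0, "direction_span_id": 0, "type": "HAS_DIRECTION"}]
header = 'Southbound Q trains are delayed, "express" service suspended'

buffer = io.StringIO(newline="")
writer = csv.writer(buffer)  # QUOTE_MINIMAL quotes cells containing commas or quotes
writer.writerow(["header", "relations"])
writer.writerow([header, json.dumps(relations)])

buffer.seek(0)
rows = list(csv.DictReader(buffer))
# Cells come back as plain strings; json.loads restores the list of dicts
restored = json.loads(rows[0]["relations"])
print(restored == relations, rows[0]["header"] == header)  # True True
```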

## 5. Statistical Analysis

Generate statistics on the extracted relations.

In [None]:
# Print compact EDA statistics for the extracted relations
def print_eda_stats(df: pd.DataFrame):
    print("\nRELATION EXTRACTION STATISTICS")
    
    # Parse relations and spans for analysis
    relations_counts = []
    direction_type_counts = defaultdict(int)
    total_relations = 0
    rows_with_relations = 0
    
    for idx, row in df.iterrows():
        relations = json.loads(row['relations']) if row['relations'] else []
        direction_spans = json.loads(row['direction_spans']) if row['direction_spans'] else []
        
        # Build direction_id -> value map
        dir_id_to_value = {s['id']: s['value'] for s in direction_spans}
        
        count = len(relations)
        relations_counts.append(count)
        total_relations += count
        
        if count > 0:
            rows_with_relations += 1
        
        for rel in relations:
            direction_id = rel['direction_span_id']
            direction_value = dir_id_to_value.get(direction_id, 'UNKNOWN')
            direction_type_counts[direction_value] += 1
    
    # Basic stats
    print(f"\nTotal rows: {len(df):,}")
    print(f"Rows with relations: {rows_with_relations:,} ({100*rows_with_relations/len(df):.1f}%)")
    print(f"Total relation pairs: {total_relations:,}")
    
    # Single vs multi relation distribution
    single_relation = sum(1 for c in relations_counts if c == 1)
    multi_relation = sum(1 for c in relations_counts if c > 1)
    print(f"\nSingle-relation rows: {single_relation:,}")
    print(f"Multi-relation rows: {multi_relation:,}")
    
    # Direction type distribution
    print("\nDirection Type Distribution:")
    for direction, count in sorted(direction_type_counts.items(), key=lambda x: -x[1]):
        pct = 100 * count / total_relations if total_relations > 0 else 0
        print(f"  {direction}: {count:,} ({pct:.1f}%)")
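
`print_eda_stats` only prints, which makes its numbers hard to test. A hypothetical refactor, `relation_stats`, computes the same quantities from the two JSON-string columns and returns them as a dictionary, using `collections.Counter` in place of `defaultdict(int)`. The toy input below is invented for illustration.

```python
import json
from collections import Counter
from typing import Dict, List


def relation_stats(relations_col: List[str], direction_spans_col: List[str]) -> Dict:
    """Hypothetical testable variant of print_eda_stats: returns numbers instead of printing."""
    counts: List[int] = []
    direction_counter: Counter = Counter()
    for rel_json, dir_json in zip(relations_col, direction_spans_col):
        relations = json.loads(rel_json) if rel_json else []
        spans = json.loads(dir_json) if dir_json else []
        dir_id_to_value = {s["id"]: s["value"] for s in spans}
        counts.append(len(relations))
        for rel in relations:
            direction_counter[dir_id_to_value.get(rel["direction_span_id"], "UNKNOWN")] += 1
    return {
        "total_rows": len(counts),
        "rows_with_relations": sum(1 for c in counts if c > 0),
        "total_relations": sum(counts),
        "single_relation_rows": sum(1 for c in counts if c == 1),
        "multi_relation_rows": sum(1 for c in counts if c > 1),
        # most_common() orders by descending count, like the printed distribution
        "direction_counts": dict(direction_counter.most_common()),
    }


# Toy input in the JSON-string format that process_dataset writes
dirs = json.dumps([{"id": 0, "start": 0, "end": 10, "type": "DIRECTION", "value": "SOUTHBOUND"}])
rels_two = json.dumps([
    {"route_span_id": 0, "direction_span_id": 0, "type": "HAS_DIRECTION"},
    {"route_span_id": 1, "direction_span_id": 0, "type": "HAS_DIRECTION"},
])
stats = relation_stats([rels_two, "[]"], [dirs, "[]"])
print(stats)
```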

## 6. Main Execution

Run the baseline relation extraction pipeline on the silver dataset.

In [9]:
input_path = 'Preprocessed/MTA_Data_silver_directions.csv'
output_path = 'Preprocessed/MTA_Data_silver_relations.csv'

df = process_dataset(input_path, output_path)
print_eda_stats(df)

Loading data from Preprocessed/MTA_Data_silver_directions.csv...
Loaded 226,160 records
Extracting relations...
  Processed 50,000 / 226,160 records...
  Processed 100,000 / 226,160 records...
  Processed 150,000 / 226,160 records...
  Processed 200,000 / 226,160 records...

Writing output to Preprocessed/MTA_Data_silver_relations.csv...
Successfully wrote 226,160 records to Preprocessed/MTA_Data_silver_relations.csv

RELATION EXTRACTION STATISTICS

Total rows: 226,160
Rows with relations: 179,548 (79.4%)
Total relation pairs: 293,474

Single-relation rows: 102,198
Multi-relation rows: 77,350

Direction Type Distribution:
  SOUTHBOUND: 89,350 (30.4%)
  NORTHBOUND: 87,501 (29.8%)
  BOTH_DIRECTIONS: 63,035 (21.5%)
  PLACE_BOUND: 27,173 (9.3%)
  UPTOWN: 4,746 (1.6%)
  DOWNTOWN: 4,668 (1.6%)
  WESTBOUND: 4,091 (1.4%)
  EASTBOUND: 3,978 (1.4%)
  MANHATTAN_BOUND: 3,784 (1.3%)
  BROOKLYN_BOUND: 2,059 (0.7%)
  QUEENS_BOUND: 1,923 (0.7%)
  BRONX_BOUND: 898 (0.3%)
  STATENISLAND_BOUND: 268 (0.1%)
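
The logged figures can be cross-checked: every row with relations is either a single- or a multi-relation row, and each relation contributes exactly one direction count, so these totals must agree. The snippet below copies the numbers from the run output above and recomputes the reported percentages.

```python
# Figures copied from the run log above
total_rows = 226_160
rows_with_relations = 179_548
total_pairs = 293_474
single_rows, multi_rows = 102_198, 77_350
direction_counts = {
    "SOUTHBOUND": (89_350, "30.4"),
    "NORTHBOUND": (87_501, "29.8"),
    "BOTH_DIRECTIONS": (63_035, "21.5"),
    "PLACE_BOUND": (27_173, "9.3"),
    "UPTOWN": (4_746, "1.6"),
    "DOWNTOWN": (4_668, "1.6"),
    "WESTBOUND": (4_091, "1.4"),
    "EASTBOUND": (3_978, "1.4"),
    "MANHATTAN_BOUND": (3_784, "1.3"),
    "BROOKLYN_BOUND": (2_059, "0.7"),
    "QUEENS_BOUND": (1_923, "0.7"),
    "BRONX_BOUND": (898, "0.3"),
    "STATENISLAND_BOUND": (268, "0.1"),
}

# Each row with relations is exactly one of single or multi
assert single_rows + multi_rows == rows_with_relations
# Each relation is counted under exactly one direction value
assert sum(count for count, _ in direction_counts.values()) == total_pairs
assert f"{100 * rows_with_relations / total_rows:.1f}" == "79.4"
for name, (count, pct) in direction_counts.items():
    assert f"{100 * count / total_pairs:.1f}" == pct, name
print(f"{total_pairs / rows_with_relations:.2f} relations per row with relations")  # 1.63
```

On average, a row with at least one relation carries about 1.63 route-direction pairs.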