# Delay and Weather Correlation Analysis
This notebook analyzes the relationship between precipitation and train delays. Each train station is matched to its nearest weather station, and the analysis uses two months of Meteostat weather data.

In [1]:
import pandas as pd
import sqlite3
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import warnings
import ast
from sklearn.neighbors import BallTree

# Silence every warning category so the cell outputs stay uncluttered
warnings.simplefilter('ignore')

# Apply seaborn's white-grid theme to subsequent plots
sns.set_theme(style='whitegrid')

In [2]:
# Database connection (reused by the following cells through `conn`)
db_path = '../data/amse.sqlite'
conn = sqlite3.connect(db_path)

# Earth radius in km; converts haversine distances (radians) into kilometres
EARTH_RADIUS_KM = 6371


def _location_field(location, key):
    """Return ``location[key]`` when ``location`` is a dict, otherwise NaN.

    Rows whose location is NULL or not a dict end up with NaN coordinates and
    are removed by the later ``dropna`` instead of raising and aborting the
    whole matching step.
    """
    return location.get(key) if isinstance(location, dict) else np.nan


print("Loading Train Station Data...")
# Load train station coordinates and weather stations for matching
try:
    # Load geo data for stations
    stations_df = pd.read_sql_query("SELECT * FROM geo_data_train_stations", conn)
    # Ensure eva is string so it can be joined with the string eva of the timetables
    if 'evaNumbers' in stations_df.columns:
        stations_df['eva'] = stations_df['evaNumbers'].astype(str)

    # Drop invalid coordinates; copy so the column assignments below act on an
    # independent frame rather than on a derived view
    stations_df = stations_df.dropna(subset=['latitude', 'longitude']).copy()

    # Load weather stations
    print("Loading Weather Station Data...")
    weather_stations = pd.read_sql_query("SELECT * FROM weather_stations", conn)

    # Parse location column if it exists (string form of a Python dict literal)
    if 'location' in weather_stations.columns:
        weather_stations['location'] = weather_stations['location'].apply(
            lambda x: ast.literal_eval(x) if isinstance(x, str) else x
        )
        weather_stations['longitude'] = weather_stations['location'].apply(
            lambda loc: _location_field(loc, 'longitude')
        )
        weather_stations['latitude'] = weather_stations['location'].apply(
            lambda loc: _location_field(loc, 'latitude')
        )

    weather_stations = weather_stations.dropna(subset=['latitude', 'longitude'])

    # --- Nearest Neighbor Matching ---
    print("Matching Train Stations to Nearest Weather Stations...")

    # BallTree's haversine metric expects [lat, lon] in radians; force a float
    # array so coordinates stored as objects/strings are converted up front
    train_rad = np.deg2rad(stations_df[['latitude', 'longitude']].to_numpy(dtype=float))
    weather_rad = np.deg2rad(weather_stations[['latitude', 'longitude']].to_numpy(dtype=float))

    # Build Tree
    tree = BallTree(weather_rad, metric='haversine')

    # Query nearest neighbor (k=1 -> one distance/index per train station)
    dist, ind = tree.query(train_rad, k=1)
    nearest_idx = ind.ravel()

    # Assign nearest weather station. The id is stored as a string because the
    # weather table's `station` key is cast to string before the later merge.
    stations_df['weather_station_id'] = weather_stations['id'].astype(str).to_numpy()[nearest_idx]
    # Check if distance is reasonable (e.g., < 50km) - can be used for filtering
    stations_df['distance_km'] = dist.ravel() * EARTH_RADIUS_KM

    print(f"Matched {len(stations_df)} stations. Avg distance: {stations_df['distance_km'].mean():.2f} km")

except Exception as e:
    print(f"Error in station matching: {e}")
    # Create empty DF to prevent crashing; it carries every column the next
    # cell selects (including distance_km) so the failure does not resurface
    # there as a KeyError
    stations_df = pd.DataFrame(
        columns=['eva', 'name', 'latitude', 'longitude', 'weather_station_id', 'distance_km']
    )

# Load Train Datasets
train_stats_list = []


Loading Train Station Data...
Loading Weather Station Data...
Matching Train Stations to Nearest Weather Stations...
Matched 146 stations. Avg distance: 5.65 km


In [3]:
print("Loading Train Timetables...")

# Start from an empty accumulator on every run. Without this, re-executing the
# cell appends the same tables again and the sums below double-count trains
# and delay minutes.
train_stats_list = []

cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'timetables%'")
timetable_tables = [row[0] for row in cursor.fetchall()]

for table in timetable_tables:
    try:
        # Table names are read from sqlite_master; double any embedded quote so
        # the quoted identifier stays well-formed
        safe_table = table.replace('"', '""')
        query = f'SELECT "eva", "pt", "ct" FROM "{safe_table}"'
        df = pd.read_sql_query(query, conn)

        # Parse timestamps; values that do not match the format become NaT
        df['pt'] = pd.to_datetime(df['pt'], format='%d.%m.%Y, %H:%M', errors='coerce')
        df['ct'] = pd.to_datetime(df['ct'], format='%d.%m.%Y, %H:%M', errors='coerce')
        # Rows without `pt` or `eva` cannot be assigned to a station-hour
        df = df.dropna(subset=['pt', 'eva']).copy()
        # A missing `ct` falls back to `pt`, i.e. it counts as zero delay
        df['ct'] = df['ct'].fillna(df['pt'])

        # Normalise the station key BEFORE grouping so the same station is not
        # split into separate int/str keys when tables are combined
        df['eva'] = df['eva'].astype(str)

        # Delay in minutes (negative when `ct` is earlier than `pt`)
        df['delay_minutes'] = (df['ct'] - df['pt']).dt.total_seconds() / 60.0

        # Date Hour ('h' is the non-deprecated spelling of the hourly alias)
        df['date_hour'] = df['pt'].dt.floor('h')

        # Aggregate per station-hour
        # We need to keep eva to join with station info
        station_hourly = df.groupby(['date_hour', 'eva']).agg(
            total_trains=('eva', 'count'),
            total_delay=('delay_minutes', 'sum'),
            avg_delay=('delay_minutes', 'mean')
        ).reset_index()

        train_stats_list.append(station_hourly)

    except Exception as e:
        print(f"Error processing table {table}: {e}")

# Combine all
if train_stats_list:
    train_data = pd.concat(train_stats_list, ignore_index=True)

    # Aggregate again in case a station-hour is split across tables; the mean is
    # rebuilt from the summed totals rather than by averaging per-table means
    train_data = train_data.groupby(['date_hour', 'eva']).agg(
        total_trains=('total_trains', 'sum'),
        total_delay=('total_delay', 'sum')
    ).reset_index()
    train_data['avg_delay'] = train_data['total_delay'] / train_data['total_trains']

    # Convert eva (already a string per table; kept as a final guarantee)
    train_data['eva'] = train_data['eva'].astype(str)

    # Merge with Stations (including matched weather station)
    # Using inner join to strictly keep only stations where we have weather mapping.
    # Only columns that actually exist are selected, so a fallback stations_df
    # without e.g. distance_km does not raise here; a missing 'eva' key still
    # fails loudly in the merge.
    wanted_cols = ['eva', 'name', 'latitude', 'longitude', 'weather_station_id', 'distance_km']
    station_cols = [c for c in wanted_cols if c in stations_df.columns]
    train_merged = pd.merge(train_data, stations_df[station_cols], on='eva', how='inner')

    print(f"Train data with station info: {len(train_merged)} records")
else:
    print("No train data loaded.")
    train_merged = pd.DataFrame()


Loading Train Timetables...
Train data with station info: 72907 records


In [5]:
# Load Weather Data, join it to the station-hour train statistics through the
# nearest weather station, and render the maps and analysis plots.
print("Loading Weather Data to match...")
try:
    weather_df = pd.read_sql_query("SELECT * FROM weather", conn)

    # Parse date
    weather_df['date'] = pd.to_datetime(weather_df['date'])

    # Create date_hour in one vectorised step (date + hour offset) instead of a
    # row-wise apply over the whole weather table
    weather_df['date_hour'] = weather_df['date'] + pd.to_timedelta(weather_df['hour'], unit='h')
    # NOTE(review): timetable and weather timestamps are joined as-is. If the
    # two sources use different time zones the hours are shifted - confirm.

    # Ensure station is string/key matches
    # weather table has 'station' column corresponding to ID
    weather_df['station'] = weather_df['station'].astype(str)

    # We want to keep relevant weather metrics
    # Ensure columns exist (handling potential variations in schema)
    cols = ['date_hour', 'station', 'prcp', 'temp', 'wspd']
    available_cols = [c for c in cols if c in weather_df.columns]
    weather_relevant = weather_df[available_cols]

    print(f"Weather data loaded: {len(weather_relevant)} records")

    # Merge Train and Weather on PROPER (Nearest) Station
    # train_merged has 'weather_station_id'
    # weather_relevant has 'station'
    # The left key is cast to string so both merge keys share one dtype.
    train_keyed = train_merged.assign(
        weather_station_id=train_merged['weather_station_id'].astype(str)
    )
    full_data = pd.merge(
        train_keyed,
        weather_relevant,
        left_on=['date_hour', 'weather_station_id'],
        right_on=['date_hour', 'station'],
        how='inner'
    )

    # Rename for clarity (no-op when the weather table has no 'prcp' column)
    full_data = full_data.rename(columns={'prcp': 'hourly_prcp'})

    print(f"Merged Dataset (Train + Nearest Weather): {len(full_data)} records")

    if not full_data.empty:

        def has_usable_signal(g):
            """True when a group has more than 10 rows and both series vary."""
            return len(g) > 10 and g['hourly_prcp'].std() > 0 and g['avg_delay'].std() > 0

        def calc_corr_daily(g):
            """Delay/precipitation correlation of a group, NaN without usable signal."""
            if has_usable_signal(g):
                return g['avg_delay'].corr(g['hourly_prcp'])
            return np.nan

        def location_field(location, key):
            """Return location[key] for dict locations, NaN otherwise."""
            return location.get(key) if isinstance(location, dict) else np.nan

        # --- Pre-calculate Metrics for Plots ---

        # 1. Spatial Correlation Data
        # valid stations (enough data and non-constant delay/precipitation)
        valid_stations = full_data.groupby('name').filter(has_usable_signal)

        # Build the per-station table from explicit records so that an empty
        # selection still yields a frame with the expected columns
        corr_records = [
            {'name': station_name, 'correlation': grp['avg_delay'].corr(grp['hourly_prcp'])}
            for station_name, grp in valid_stations.groupby('name')
        ]
        corr_df = pd.DataFrame(corr_records, columns=['name', 'correlation'])
        # An undefined correlation cannot be used as a marker size on the map
        corr_df = corr_df.dropna(subset=['correlation'])

        counts = valid_stations['name'].value_counts().reset_index()
        counts.columns = ['name', 'count']

        tr_stations = train_merged[['name', 'latitude', 'longitude']].drop_duplicates().copy()
        corr_df = pd.merge(corr_df, tr_stations, on='name')
        corr_df = pd.merge(corr_df, counts, on='name')

        # 2. Daily Time Series Data (mean over all station-hours of each day)
        daily_agg = full_data.set_index('date_hour').resample('D').agg({
            'avg_delay': 'mean',
            'hourly_prcp': 'mean'
        }).reset_index()

        # --- Visualizations ---

        # A. Maps (Smaller & Portrait)
        print("Generating Maps...")

        # Combined Map Prep
        tr_stations['type'] = 'Train Station'
        used_weather_ids = full_data['weather_station_id'].unique()
        w_stations_all = pd.read_sql_query("SELECT * FROM weather_stations", conn)
        if 'location' in w_stations_all.columns:
            w_stations_all['location'] = w_stations_all['location'].apply(
                lambda x: ast.literal_eval(x) if isinstance(x, str) else x
            )
            w_stations_all['longitude'] = w_stations_all['location'].apply(
                lambda loc: location_field(loc, 'longitude')
            )
            w_stations_all['latitude'] = w_stations_all['location'].apply(
                lambda loc: location_field(loc, 'latitude')
            )
        # Keep only the weather stations that were actually matched and merged
        w_st_used = w_stations_all[w_stations_all['id'].astype(str).isin(used_weather_ids.astype(str))].copy()
        w_st_used['name'] = w_st_used['id'].astype(str)
        w_st_used['type'] = 'Weather Station'
        map_cols = ['name', 'latitude', 'longitude', 'type']
        map_df = pd.concat([tr_stations[map_cols], w_st_used[map_cols]], ignore_index=True)

        # Fig 1: Combined Map
        fig_map = px.scatter_mapbox(
            map_df, lat="latitude", lon="longitude", color="type", zoom=5,
            title="Train & Weather Stations",
            color_discrete_map={'Train Station': 'blue', 'Weather Station': 'red'},
            height=500, width=400 # Portrait
        )
        fig_map.update_layout(mapbox_style="open-street-map", margin={"r":5,"t":40,"l":5,"b":5})
        fig_map.show()

        # Fig 2: Spatial Correlation Map
        # Symmetric colour range around zero so the diverging scale is centred
        limit = max(abs(corr_df['correlation'].min()), abs(corr_df['correlation'].max())) if not corr_df.empty else 0.5
        corr_df["abs_correlation"] = corr_df["correlation"].abs()
        fig_corr = px.scatter_mapbox(
            corr_df, lat="latitude", lon="longitude", color="correlation", size="abs_correlation",
            color_continuous_scale="RdBu_r", range_color=[-limit, limit], zoom=5,
            hover_data={'correlation': ':.2f', 'count': True},
            title="Spatial Correlation",
            height=500, width=400 # Portrait
        )
        fig_corr.update_layout(mapbox_style="open-street-map", margin={"r":5,"t":40,"l":5,"b":5})
        fig_corr.show()

        # Fig 3: Daily Animation Map
        print("Generating Daily Animation...")
        full_data['date_str'] = full_data['date_hour'].dt.date.astype(str)
        daily_records = [
            {'date_str': day, 'name': station_name, 'correlation': calc_corr_daily(grp)}
            for (day, station_name), grp in full_data.groupby(['date_str', 'name'])
        ]
        daily_corr_map = pd.DataFrame(daily_records, columns=['date_str', 'name', 'correlation'])
        daily_corr_map = daily_corr_map.dropna(subset=['correlation'])
        daily_corr_map["abs_correlation"] = daily_corr_map["correlation"].abs()
        daily_counts_map = full_data.groupby(['date_str', 'name']).size().reset_index(name='count')
        daily_corr_map = pd.merge(daily_corr_map, tr_stations[['name', 'latitude', 'longitude']], on='name')
        daily_corr_map = pd.merge(daily_corr_map, daily_counts_map, on=['date_str', 'name'])
        daily_corr_map = daily_corr_map.sort_values('date_str')

        if not daily_corr_map.empty:
            limit_anim = max(abs(daily_corr_map['correlation'].min()), abs(daily_corr_map['correlation'].max()))
            fig_anim = px.scatter_mapbox(
                daily_corr_map, lat="latitude", lon="longitude", color="correlation", size="abs_correlation",
                animation_frame="date_str", animation_group="name",
                color_continuous_scale="RdBu_r", range_color=[-limit_anim, limit_anim], zoom=4,
                title="Daily Correlation",
                height=500, width=400 # Portrait
            )
            fig_anim.update_layout(mapbox_style="open-street-map", margin={"r":5,"t":40,"l":5,"b":5})
            # The play-button menu is absent when there is only a single frame;
            # indexing it unguarded would raise and skip the plots below
            if fig_anim.layout.updatemenus:
                fig_anim.layout.updatemenus[0].buttons[0].args[1]['frame']['duration'] = 800
            fig_anim.show()

        # B. Combined Plots (Distribution, Scatter, Time Series)
        print("Generating Analysis Plots...")
        fig_plots = make_subplots(
            rows=1, cols=3,
            subplot_titles=("Correlation Distribution", "Rain vs Delay", "Daily Trends"),
            specs=[[{}, {}, {"secondary_y": True}]]
        )

        # 1. Distribution of per-station correlations
        fig_plots.add_trace(go.Histogram(x=corr_df['correlation'], name='Corr', marker_color='purple'), row=1, col=1)

        # 2. Scatter of every station-hour
        fig_plots.add_trace(go.Scatter(x=full_data['hourly_prcp'], y=full_data['avg_delay'], mode='markers', marker=dict(opacity=0.3, size=3), name='Points'), row=1, col=2)

        # 3. Time Series (delay on the primary axis, precipitation on the secondary)
        fig_plots.add_trace(go.Scatter(x=daily_agg['date_hour'], y=daily_agg['avg_delay'], name="Delay", line=dict(color='red')), row=1, col=3, secondary_y=False)
        fig_plots.add_trace(go.Bar(x=daily_agg['date_hour'], y=daily_agg['hourly_prcp'], name="Precipitation", marker_color='blue', opacity=0.3), row=1, col=3, secondary_y=True)

        fig_plots.update_layout(height=400, showlegend=False, title_text="Detailed Correlation Analysis",template="plotly_white")
        # Update axes titles
        fig_plots.update_xaxes(title_text="Correlation", row=1, col=1)
        fig_plots.update_xaxes(title_text="Precipitation (mm)", row=1, col=2)
        fig_plots.update_yaxes(title_text="Count", row=1, col=1)
        fig_plots.update_yaxes(title_text="Delay (min)", row=1, col=2)

        fig_plots.show()

    else:
        print("No intersecting data found between trains and their nearest weather stations.")

except Exception as e:
    # Report the failure with a full traceback instead of stopping the notebook
    print(f"Error in weather processing: {e}")
    import traceback
    traceback.print_exc()


Loading Weather Data to match...
Weather data loaded: 598108 records
Merged Dataset (Train + Nearest Weather): 28704 records
Generating Maps...


Generating Daily Animation...


Generating Analysis Plots...
