In [None]:

"""
Bayesian Hierarchical Regression: Train Delay Analysis
Using Stan (via CmdStanPy) - No PyMC
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# 1. LOAD & PREPARE DATA
# ============================================================================

df = pd.read_csv('train_delays.csv')
print(f"Data shape: {df.shape}")

# Drop rows missing any modelled column. pd.factorize codes a missing label as
# -1 (0 after the 1-based shift below), which violates Stan's lower=1 index
# bound, and Stan's normal likelihood rejects NaN outcomes. When nothing is
# missing this leaves df unchanged, so the 30% sample below is unaffected.
n_before_dropna = len(df)
df = df.dropna(subset=['route', 'region', 'delay_duration'])
if len(df) < n_before_dropna:
    print(f"Dropped {n_before_dropna - len(df)} rows with missing route/region/delay_duration")

# Sample 30% for speed
df = df.sample(frac=0.3, random_state=42).reset_index(drop=True)

# Create hierarchical indices: integer group codes plus the unique labels.
route_idx, routes = pd.factorize(df['route'])
region_idx, regions = pd.factorize(df['region'])
# Stan's `vector` data is real-valued; convert explicitly.
y = df['delay_duration'].to_numpy(dtype=float)

route_idx = route_idx + 1  # Stan uses 1-based indexing
region_idx = region_idx + 1

n_obs = len(y)
n_routes = len(routes)
n_regions = len(regions)

print(f"Sampled: {n_obs} records | Routes: {n_routes} | Regions: {n_regions}")

# 2. STAN MODEL CODE
#
# Integer index arrays use the `array[N] int<...> x;` declaration syntax. The
# older `int<...> x[N];` form was removed in Stan 2.33 and no longer compiles.

# Model 1: varying intercept per route, partially pooled toward mu.
stan_code_route = """
data {
  int<lower=0> n_obs;
  int<lower=0> n_routes;
  vector[n_obs] y;
  array[n_obs] int<lower=1, upper=n_routes> route_idx;
}
parameters {
  real mu;
  real<lower=0> sigma_route;
  vector[n_routes] alpha_route;
  real<lower=0> sigma;
}
model {
  mu ~ normal(0, 10);
  sigma_route ~ exponential(0.1);
  alpha_route ~ normal(mu, sigma_route);
  sigma ~ exponential(0.1);

  y ~ normal(alpha_route[route_idx], sigma);
}
"""

# Model 2: varying intercept per region, partially pooled toward mu.
stan_code_region = """
data {
  int<lower=0> n_obs;
  int<lower=0> n_regions;
  vector[n_obs] y;
  array[n_obs] int<lower=1, upper=n_regions> region_idx;
}
parameters {
  real mu;
  real<lower=0> sigma_region;
  vector[n_regions] alpha_region;
  real<lower=0> sigma;
}
model {
  mu ~ normal(0, 10);
  sigma_region ~ exponential(0.1);
  alpha_region ~ normal(mu, sigma_region);
  sigma ~ exponential(0.1);

  y ~ normal(alpha_region[region_idx], sigma);
}
"""

# Model 3: additive route + region effects around a single intercept.
# Two separate group means (mu_route, mu_region) would only enter the
# likelihood through their sum, leaving a non-identified ridge that hurts
# sampling. A single intercept with a normal(0, sqrt(200)) prior is exactly
# the implied prior on mu_route + mu_region when each is normal(0, 10).
# Zero-centred group effects therefore give the same posterior for
# sigma_route, sigma_region and sigma.
stan_code_combined = """
data {
  int<lower=0> n_obs;
  int<lower=0> n_routes;
  int<lower=0> n_regions;
  vector[n_obs] y;
  array[n_obs] int<lower=1, upper=n_routes> route_idx;
  array[n_obs] int<lower=1, upper=n_regions> region_idx;
}
parameters {
  real mu;

  real<lower=0> sigma_route;
  vector[n_routes] alpha_route;

  real<lower=0> sigma_region;
  vector[n_regions] alpha_region;

  real<lower=0> sigma;
}
model {
  mu ~ normal(0, sqrt(200));

  sigma_route ~ exponential(0.1);
  alpha_route ~ normal(0, sigma_route);

  sigma_region ~ exponential(0.1);
  alpha_region ~ normal(0, sigma_region);

  sigma ~ exponential(0.1);

  y ~ normal(mu + alpha_route[route_idx] + alpha_region[region_idx], sigma);
}
"""

# 3. INSTALL & RUN STAN
try:
    from cmdstanpy import CmdStanModel, cmdstan_path
    # Importing cmdstanpy is not enough: compiling a model also needs the
    # CmdStan toolchain. cmdstan_path() raises ValueError when none is found.
    cmdstan_path()
    stan_available = True
except ImportError:
    print("⚠️  CmdStanPy not installed. Using approximation instead.")
    stan_available = False
except ValueError:
    print("⚠️  CmdStan toolchain not found (run cmdstanpy.install_cmdstan()). Using approximation instead.")
    stan_available = False

if stan_available:
    from pathlib import Path

    def _compile_stan(code, name):
        """Write the Stan program ``code`` to ``<name>.stan`` and compile it.

        CmdStanModel is built from a path to a .stan file (``stan_file=``);
        it does not accept an in-memory program string.
        """
        stan_file = Path(f"{name}.stan")
        stan_file.write_text(code, encoding="utf-8")
        return CmdStanModel(stan_file=str(stan_file))

    # Shared sampler settings. The fixed seed (same value as the subsampling
    # random_state) makes reruns reproducible.
    sample_kwargs = dict(iter_sampling=1000, iter_warmup=1000, chains=2,
                         seed=42, show_progress=False)

    print("\n--- MODEL1: ROUTE EFFECTS (STAN) ---")
    model_route = _compile_stan(stan_code_route, "route_model")
    data_route = {'n_obs': n_obs, 'n_routes': n_routes, 'y': y, 'route_idx': route_idx}
    fit_route = model_route.sample(data=data_route, **sample_kwargs)

    print("✓ Route model fitted")
    print(fit_route.summary())

    print("\n--- MODEL 2: REGION EFFECTS (STAN) ---")
    model_region = _compile_stan(stan_code_region, "region_model")
    data_region = {'n_obs': n_obs, 'n_regions': n_regions, 'y': y, 'region_idx': region_idx}
    fit_region = model_region.sample(data=data_region, **sample_kwargs)

    print("✓ Region model fitted")
    print(fit_region.summary())

    print("\n--- MODEL 3: COMBINED (STAN) ---")
    model_combined = _compile_stan(stan_code_combined, "combined_model")
    data_combined = {'n_obs': n_obs, 'n_routes': n_routes, 'n_regions': n_regions,
                     'y': y, 'route_idx': route_idx, 'region_idx': region_idx}
    fit_combined = model_combined.sample(data=data_combined, **sample_kwargs)

    print("✓ Combined model fitted")
    print(fit_combined.summary())

    # ============================================================================
    # 4. EXTRACT POSTERIOR SAMPLES
    # ============================================================================

    # One row per draw, one column per Stan parameter (e.g. 'sigma_route').
    samples_route = fit_route.draws_pd()
    samples_region = fit_region.draws_pd()
    samples_combined = fit_combined.draws_pd()

else:
    print("Installing Stan alternative...")

# ============================================================================
# 5. BAYESIAN APPROXIMATION (If Stan not available)
# ============================================================================

if not stan_available:
    print("\n--- USING APPROXIMATION METHOD (No Stan) ---\n")

    # NOTE: this is a moment-based stand-in, not a Bayesian fit.
    # Exclude rows with a missing label (factorize code -1, i.e. 0 after the
    # 1-based shift) or a non-finite delay. Then re-code the groups 0..k-1 so
    # numpy can index and bincount them.
    keep = (route_idx > 0) & (region_idx > 0) & np.isfinite(y)
    y_obs = y[keep]
    route_levels, route_codes = np.unique(route_idx[keep], return_inverse=True)
    region_levels, region_codes = np.unique(region_idx[keep], return_inverse=True)
    n_route_groups = len(route_levels)
    n_region_groups = len(region_levels)

    def _group_means(values, codes, n_groups):
        """Return the mean of ``values`` within each group 0..n_groups-1 of ``codes``."""
        sums = np.bincount(codes, weights=values, minlength=n_groups)
        counts = np.bincount(codes, minlength=n_groups)  # every code occurs at least once
        return sums / counts

    # Model 1: Route effects
    print("MODEL 1: ROUTE EFFECTS")
    route_means = _group_means(y_obs, route_codes, n_route_groups)
    mu_route = np.mean(route_means)
    sigma_route = np.std(route_means)
    # Residual spread is measured around each observation's own route mean.
    # y.std() would be the TOTAL spread: it double-counts the between-route
    # variance and understates the route share.
    sigma_resid_route = np.std(y_obs - route_means[route_codes])

    print(f"Route mean: {mu_route:.2f}")
    print(f"  Route std: {sigma_route:.2f}")
    print(f"  Residual std: {sigma_resid_route:.2f}")

    var_route = sigma_route ** 2
    var_resid_route = sigma_resid_route ** 2
    total_var_route = var_route + var_resid_route

    print(f"  Route variance %: {(var_route/total_var_route)*100:.1f}%")
    print(f"  Residual variance %: {(var_resid_route/total_var_route)*100:.1f}%\n")

    # Model 2: Region effects (same construction, grouped by region)
    print("MODEL 2: REGION EFFECTS")
    region_means = _group_means(y_obs, region_codes, n_region_groups)
    mu_region = np.mean(region_means)
    sigma_region = np.std(region_means)
    sigma_resid_region = np.std(y_obs - region_means[region_codes])

    print(f"  Region mean: {mu_region:.2f}")
    print(f"  Region std: {sigma_region:.2f}")
    print(f"  Residual std: {sigma_resid_region:.2f}")

    var_region = sigma_region ** 2
    var_resid_region = sigma_resid_region ** 2
    total_var_region = var_region + var_resid_region

    print(f"  Region variance %: {(var_region/total_var_region)*100:.1f}%")
    print(f"  Residual variance %: {(var_resid_region/total_var_region)*100:.1f}%\n")

    # Model 3: Combined effects
    print("MODEL 3: COMBINED (ROUTE + REGION)")

    # Fit the additive model y = grand + route_effect + region_effect + e by
    # backfitting (alternating group-mean updates toward the least-squares fit).
    # This replaces fixed multipliers, which estimated nothing. If routes are
    # nested within regions, the split between the two effects is not unique.
    # Starting from zero region effects lets routes absorb the shared variation.
    grand_mean = np.mean(y_obs)
    route_eff = np.zeros(n_route_groups)
    region_eff = np.zeros(n_region_groups)
    for _ in range(200):
        prev_route_eff = route_eff
        route_eff = _group_means(y_obs - grand_mean - region_eff[region_codes],
                                 route_codes, n_route_groups)
        region_eff = _group_means(y_obs - grand_mean - route_eff[route_codes],
                                  region_codes, n_region_groups)
        if np.allclose(route_eff, prev_route_eff, atol=1e-8):
            break
    resid_combined = (y_obs - grand_mean
                      - route_eff[route_codes] - region_eff[region_codes])

    # Variance across groups of the fitted effects, matching how Models 1-2
    # use the spread of the group means.
    combined_route_var = np.var(route_eff)
    combined_region_var = np.var(region_eff)
    combined_resid_var = np.var(resid_combined)
    total_combined = combined_route_var + combined_region_var + combined_resid_var

    print(f"  Route variance %: {(combined_route_var/total_combined)*100:.1f}%")
    print(f"  Region variance %: {(combined_region_var/total_combined)*100:.1f}%")
    print(f"  Residual variance %: {(combined_resid_var/total_combined)*100:.1f}%\n")

# 6. VARIANCE DECOMPOSITION VISUALIZATION
#
# Turn each model's variance components into percentage shares and draw one
# pie chart per model, saved to variance_decomposition.png.

print("\n" + "=" * 60)
print("VARIANCE DECOMPOSITION")
print("=" * 60)

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

if not stan_available:
    # Shares from the moment-based fallback computed in section 5.
    labels_1 = ['Route', 'Residual']
    vals_1 = [(part / total_var_route) * 100
              for part in (var_route, var_resid_route)]

    labels_2 = ['Region', 'Residual']
    vals_2 = [(part / total_var_region) * 100
              for part in (var_region, var_resid_region)]

    labels_3 = ['Route', 'Region', 'Residual']
    vals_3 = [(part / total_combined) * 100
              for part in (combined_route_var, combined_region_var, combined_resid_var)]
else:
    def _mean_square(draws):
        """Posterior mean of a variance, given the draws of its standard deviation."""
        return np.mean(draws ** 2)

    # Standard-deviation draws pulled from the Stan posterior tables.
    sigma_route_samples = samples_route['sigma_route'].values
    sigma_region_samples = samples_region['sigma_region'].values
    sigma_combined_route = samples_combined['sigma_route'].values
    sigma_combined_region = samples_combined['sigma_region'].values
    sigma_resid = samples_combined['sigma'].values

    # Model 1: route vs residual.
    var_route_m1 = _mean_square(sigma_route_samples)
    var_resid_m1 = _mean_square(samples_route['sigma'].values)
    total_m1 = var_route_m1 + var_resid_m1

    # Model 2: region vs residual.
    var_region_m2 = _mean_square(sigma_region_samples)
    var_resid_m2 = _mean_square(samples_region['sigma'].values)
    total_m2 = var_region_m2 + var_resid_m2

    # Model 3: route, region and residual together.
    var_route_m3 = _mean_square(sigma_combined_route)
    var_region_m3 = _mean_square(sigma_combined_region)
    var_resid_m3 = _mean_square(sigma_resid)
    total_m3 = var_route_m3 + var_region_m3 + var_resid_m3

    labels_1 = ['Route', 'Residual']
    vals_1 = [(part / total_m1) * 100 for part in (var_route_m1, var_resid_m1)]

    labels_2 = ['Region', 'Residual']
    vals_2 = [(part / total_m2) * 100 for part in (var_region_m2, var_resid_m2)]

    labels_3 = ['Route', 'Region', 'Residual']
    vals_3 = [(part / total_m3) * 100
              for part in (var_route_m3, var_region_m3, var_resid_m3)]

# Plot: one pie per model; two-slice pies use the first two colours.
colors = ['#FF6B6B', '#4ECDC4', '#95E1D3']
pie_specs = (
    (vals_1, labels_1, colors[:2], 'Model 1: Route Only'),
    (vals_2, labels_2, colors[:2], 'Model 2: Region Only'),
    (vals_3, labels_3, colors, 'Model 3: Combined'),
)
for pie_ax, (pie_vals, pie_labels, pie_colors, pie_title) in zip(axes, pie_specs):
    pie_ax.pie(pie_vals, labels=pie_labels, autopct='%1.1f%%', colors=pie_colors)
    pie_ax.set_title(pie_title, fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('variance_decomposition.png', dpi=300, bbox_inches='tight')
print("\n✓ Saved: variance_decomposition.png")

# ============================================================================
# 7. SUMMARY TABLE
# ============================================================================

print("\n" + "="*60)
print("SUMMARY: HOW MUCH VARIABILITY IS EXPLAINED?")
print("="*60)

if stan_available:
    summary_data = {
        'Model': ['Route Only', 'Region Only', 'Combined'],
        'Route %': [f"{(var_route_m1/total_m1)*100:.1f}%", 'N/A', f"{(var_route_m3/total_m3)*100:.1f}%"],
        'Region %': ['N/A', f"{(var_region_m2/total_m2)*100:.1f}%", f"{(var_region_m3/total_m3)*100:.1f}%"],
        'Residual %': [f"{(var_resid_m1/total_m1)*100:.1f}%", f"{(var_resid_m2/total_m2)*100:.1f}%", f"{(var_resid_m3/total_m3)*100:.1f}%"]
    }
else:
    summary_data = {
        'Model': ['Route Only', 'Region Only', 'Combined'],
        'Route %': [f"{(var_route/total_var_route)*100:.1f}%", 'N/A', f"{(combined_route_var/total_combined)*100:.1f}%"],
        'Region %': ['N/A', f"{(var_region/total_var_region)*100:.1f}%", f"{(combined_region_var/total_combined)*100:.1f}%"],
        'Residual %