## Step 4: Fit the SPDE Spatial Model to the Data

This notebook:
1. Compiles the SPDE Stan model (`spde_pm25.stan`)
2. Loads the preprocessed data from `json_data/stan_data.json`
3. Runs MCMC sampling
4. Provides basic diagnostics and visualization

**Prerequisites**: Run the FEM preprocessing notebook first to generate `json_data/stan_data.json`

### Setup and Imports

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# CmdStanPy for Stan interface
from cmdstanpy import CmdStanModel, from_csv

### Load Preprocessed Data

In [None]:
# Read the Stan input produced by the FEM preprocessing notebook
data_path = Path("json_data/stan_data.json")

# Fail early with a pointer to the missing prerequisite step
if not data_path.exists():
    raise FileNotFoundError(
        f"Data file not found at {data_path}. "
        "Please run the FEM preprocessing notebook first."
    )

with data_path.open('r') as handle:
    stan_data = json.load(handle)

# Report problem size so a stale or mismatched preprocessing run is easy to spot
print(f"Loaded Stan data with {len(stan_data)} fields")
print(f"\nData dimensions:")
print(f"  Observations: {stan_data['N_obs']}")
print(f"  Mesh vertices: {stan_data['N_vertices']}")
print(f"  A matrix non-zeros: {stan_data['A_nnz']}")
print(f"  Q matrix non-zeros: {stan_data['Q_nnz']}")
print(f"\nPrior mean PM2.5: {stan_data['prior_mean']:.1f}")

### Compile Stan Model

In [None]:
# Compile the Stan program that defines the SPDE model
model_path = "spde_pm25.stan"
stan_file = Path(model_path)

if not stan_file.exists():
    raise FileNotFoundError(f"Stan model not found at {model_path}")

print("Compiling Stan model...")
model = CmdStanModel(stan_file=model_path)
print("Model compiled successfully!")

# Echo the head of the Stan source so the reader can see what is being fit
print("\nModel structure (first 30 lines):")
print("="*60)
with stan_file.open('r') as src:
    head_lines = src.readlines()[:30]
print(''.join(head_lines))
print("...")

### Configure Sampling Parameters

In [None]:
# Sampling configuration: every key is passed as a keyword argument to
# CmdStanModel.sample() in the next cell.
SEED = 20240615  # fixed seed so Restart & Run All reproduces the same draws

sampling_config = {
    'chains': 4,                # stated explicitly; the trace-plot cell reads chain ids from the draws
    'seed': SEED,
    'refresh': 10,              # print progress every 10 iterations
    'adapt_init_phase': 50,     # NUTS adaptation: initial fast window (iterations)
    'adapt_metric_window': 25,  # NUTS adaptation: first slow/metric window (iterations)
    'save_warmup': True,        # also write warmup iterations to the output CSVs
    'output_dir': 'tmp',        # CSV output directory, relative to the notebook
}

### Run MCMC Sampling

**Note**: This may take several minutes depending on your data size and computer.

In [None]:
# Launch MCMC; everything except the data comes from sampling_config
for banner in ("Starting MCMC sampling...", "This may take several minutes.\n"):
    print(banner)

fit = model.sample(data=stan_data, **sampling_config)

print("\nSampling completed successfully!")

In [None]:
# CmdStanPy posterior summary table, rounded to 2 decimals for display
summary = fit.summary()
summary.round(decimals=2)

### Parameter Summary

In [None]:
# Summary statistics for the scalar model parameters; the spatial field w
# (one value per mesh vertex) is deliberately left out.
main_params = ['alpha', 'sigma', 'tau']
RHAT_THRESHOLD = 1.01

# A parameter missing from the fit would otherwise just vanish from the table
missing_params = [p for p in main_params if p not in summary.index]
if missing_params:
    print(f"Warning: parameters not found in the fit summary: {missing_params}")
param_summary = summary[summary.index.isin(main_params)]

print("Parameter estimates:")
print("="*60)
display(param_summary[['Mean', 'StdDev', '5%', '95%', 'ESS_bulk', 'R_hat']].round(2))

# Convergence check. pandas' max() skips NaN (and is NaN for an empty column),
# and `NaN > threshold` is False, so missing/NaN R_hat must be caught explicitly
# rather than being reported as good convergence.
rhat = param_summary['R_hat']
max_rhat = rhat.max()
print(f"\nMax R_hat for main parameters: {max_rhat:.3f}")
if missing_params or rhat.isna().any():
    print("Warning: R_hat unavailable for some main parameters; convergence cannot be confirmed")
elif max_rhat > RHAT_THRESHOLD:
    print(f"Warning: Some R_hat values > {RHAT_THRESHOLD}, indicating potential convergence issues")
else:
    print(f"All R_hat values <= {RHAT_THRESHOLD}, indicating good convergence")

### Visualize Parameter Distributions

In [None]:
# Posterior draws as a DataFrame: one row per draw, one column per variable,
# plus bookkeeping columns such as 'chain__'.
draws = fit.draws_pd()

# Plot only the main parameters present in the fit, and take the chain ids
# from the draws instead of assuming a fixed number of chains.
plot_params = [p for p in main_params if p in draws.columns]
if not plot_params:
    raise ValueError(f"None of {main_params} were found in the posterior draws")
chain_ids = sorted(draws['chain__'].unique())

# squeeze=False keeps `axes` 2-D even when only one parameter is plotted
fig, axes = plt.subplots(len(plot_params), 2,
                         figsize=(12, 10 * len(plot_params) / 3),
                         squeeze=False)
fig.suptitle('Parameter Trace Plots and Distributions', fontsize=14)

for idx, param in enumerate(plot_params):
    ax_trace, ax_hist = axes[idx]

    # Trace plot: one line per chain
    for chain_id in chain_ids:
        chain_draws = draws.loc[draws['chain__'] == chain_id, param]
        ax_trace.plot(chain_draws.to_numpy(), alpha=0.7, linewidth=0.5)
    ax_trace.set_ylabel(param)
    ax_trace.set_xlabel('Iteration')
    ax_trace.set_title(f'{param} chains')

    # Histogram of all chains pooled, with mean and median marked
    ax_hist.hist(draws[param], bins=50, density=True, alpha=0.7, color='blue')
    ax_hist.axvline(draws[param].mean(), color='red', linestyle='--', label='Mean')
    ax_hist.axvline(draws[param].median(), color='green', linestyle='--', label='Median')
    ax_hist.set_xlabel(param)
    ax_hist.set_ylabel('Density')
    ax_hist.set_title(f'{param} distribution')
    ax_hist.legend()

plt.tight_layout()
plt.show()

### Posterior Predictive Checks

In [None]:
# Posterior predictive draws: one column per observation, named y_rep[i].
# NOTE(review): assumes y_rep[i] is generated for observation y[i], in order.
y_rep_cols = [col for col in draws.columns if col.startswith('y_rep[')]
y_rep = draws[y_rep_cols].to_numpy()

# Observed data
y_obs = np.array(stan_data['y'])

# Fail loudly instead of producing empty or misaligned statistics
if y_rep.shape[1] != y_obs.shape[0]:
    raise ValueError(
        f"Found {y_rep.shape[1]} y_rep columns but {y_obs.shape[0]} observations; "
        "check that the Stan model generates y_rep for every observation."
    )

# Posterior predictive mean and central 90% interval (5th-95th percentiles)
y_rep_mean = y_rep.mean(axis=0)
y_rep_lower = np.percentile(y_rep, 5, axis=0)
y_rep_upper = np.percentile(y_rep, 95, axis=0)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
one_to_one = [y_obs.min(), y_obs.max()]

# Observed vs predicted means
ax1 = axes[0]
ax1.scatter(y_obs, y_rep_mean, alpha=0.5, s=10)
ax1.plot(one_to_one, one_to_one, 'r--', label='1:1 line')
ax1.set_xlabel('Observed PM2.5')
ax1.set_ylabel('Predicted PM2.5 (mean)')
ax1.set_title('Observed vs Predicted')
ax1.legend()

# Coverage: share of ALL observations inside their 90% predictive interval
ax2 = axes[1]
in_interval = (y_obs >= y_rep_lower) & (y_obs <= y_rep_upper)
coverage = in_interval.mean() * 100

# Only every ERRORBAR_STRIDE-th observation is drawn, to keep the panel legible
ERRORBAR_STRIDE = 10
shown = slice(None, None, ERRORBAR_STRIDE)
# For strongly skewed predictive distributions the mean can fall outside the
# percentile interval; clip so matplotlib never receives negative error lengths.
yerr = np.clip(
    np.vstack([y_rep_mean[shown] - y_rep_lower[shown],
               y_rep_upper[shown] - y_rep_mean[shown]]),
    0, None,
)
ax2.errorbar(y_obs[shown], y_rep_mean[shown], yerr=yerr,
             fmt='o', alpha=0.3, markersize=3, elinewidth=0.5)
ax2.plot(one_to_one, one_to_one, 'r--')
ax2.set_xlabel('Observed PM2.5')
ax2.set_ylabel('Predicted PM2.5 (90% PI)')
ax2.set_title(f'90% Prediction Intervals\n(Coverage: {coverage:.1f}%, '
              f'every {ERRORBAR_STRIDE}th observation shown)')

plt.tight_layout()
plt.show()

# Residuals of the posterior predictive mean. Named ppc_residuals so they are
# not confused with the `residuals` computed later from alpha + A*w.
ppc_residuals = y_obs - y_rep_mean
print(f"Residual statistics:")
print(f"  Mean: {ppc_residuals.mean():.3f}")
print(f"  Std: {ppc_residuals.std():.3f}")
print(f"  RMSE: {np.sqrt((ppc_residuals**2).mean()):.3f}")
print(f"  MAE: {np.abs(ppc_residuals).mean():.3f}")

### Extract and Visualize Spatial Field

The spatial field `w` at mesh vertices captures the spatial correlation structure.

In [None]:
# Posterior draws of the latent field w: one column per mesh vertex
w_cols = [name for name in draws.columns if name.startswith('w[')]
w_samples = draws[w_cols].values
w_mean, w_std = w_samples.mean(axis=0), w_samples.std(axis=0)

print(f"Spatial field statistics:")
print(f"  Number of vertices: {len(w_mean)}")
print(f"  Mean value: {w_mean.mean():.3f}")
print(f"  Std of means: {w_mean.std():.3f}")
print(f"  Range: [{w_mean.min():.3f}, {w_mean.max():.3f}]")

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax_mean, ax_sd = axes

# Left panel: where the per-vertex posterior means sit relative to zero
ax_mean.hist(w_mean, bins=50, density=True, alpha=0.7, color='blue')
ax_mean.axvline(0, color='red', linestyle='--', label='Zero')
ax_mean.set_xlabel('Spatial field value')
ax_mean.set_ylabel('Density')
ax_mean.set_title('Distribution of Spatial Field (posterior means)')
ax_mean.legend()

# Right panel: spread of the per-vertex posterior standard deviations
ax_sd.hist(w_std, bins=50, density=True, alpha=0.7, color='orange')
ax_sd.set_xlabel('Posterior std deviation')
ax_sd.set_ylabel('Density')
ax_sd.set_title('Uncertainty in Spatial Field')

plt.tight_layout()
plt.show()

### Geographic Visualization of Spatial Patterns

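Reload the station coordinates and clean them exactly as in the FEM preprocessing step. Each row then lines up with an observation in `stan_data`, so the spatial field and residuals can be mapped at their geographic locations.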

In [None]:
# Libraries used only for the geographic maps below
import sys
import warnings

import geopandas as gpd
from plotnine import (
    aes,
    coord_fixed,
    geom_map,
    geom_point,
    ggplot,
    labs,
    scale_color_gradient2,
    theme_minimal,
)

# Suppress FutureWarning noise (this filter is global for the rest of the session)
warnings.filterwarnings('ignore', category=FutureWarning)

# Original station table with lon/lat coordinates
pm25_data = pd.read_csv('north_america_pm25.csv')

# geo_spde is imported from the parent directory; add it to sys.path only once
# so re-running this cell does not keep growing the search path.
if '..' not in sys.path:
    sys.path.insert(0, '..')
from geo_spde.coords import preprocess_coords

# Coordinates in original lon/lat
lon_lat_coords = pm25_data[['Longitude', 'Latitude']].to_numpy()

# Re-run the coordinate cleaning used for modeling so rows line up with stan_data.
# NOTE(review): these arguments must match the FEM preprocessing notebook exactly.
clean_coords, indices, proj_info = preprocess_coords(
    lon_lat_coords,
    remove_duplicates=True
)

# Cleaned data
pm25_clean = pm25_data.iloc[indices].copy()
lon_lat_clean = lon_lat_coords[indices]

# Guard against a silent mismatch with the observations the model was fit to;
# a misalignment here would put every mapped value at the wrong location.
n_model_obs = stan_data['N_obs']
if not (len(clean_coords) == len(lon_lat_clean) == n_model_obs):
    raise ValueError(
        f"Coordinate preprocessing kept {len(lon_lat_clean)} locations "
        f"({len(clean_coords)} projected) but the model has {n_model_obs} "
        "observations; re-check the preprocess_coords arguments."
    )

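### Spatial Field and Residuals at Observation Locations

Project the posterior-mean spatial field from the mesh vertices onto the observation sites with the sparse observation matrix $A$. Then compute residuals against the point prediction $\bar{\alpha} + A\bar{w}$. Nothing is written to disk here.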

In [None]:
# Unpack the sparse observation matrix A (compressed sparse row layout):
# A_w holds the non-zero values, A_v their column indices, A_u the row pointers.
# The index arrays are stored 1-based for Stan, so shift them to 0-based here.
A_w = np.array(stan_data['A_w'])
A_v, A_u = (np.array(stan_data[name]) - 1 for name in ('A_v', 'A_u'))

# Sparse CSR matrix-vector product, vectorized with numpy
def csr_matvec(w_vals, v_indices, u_pointers, x):
    """Multiply a sparse matrix stored in CSR form by a dense vector.

    Parameters
    ----------
    w_vals : array-like, shape (nnz,)
        Non-zero values, stored row by row.
    v_indices : array-like of int, shape (nnz,)
        0-based column index of each entry in ``w_vals``.
    u_pointers : array-like of int, shape (n_rows + 1,)
        0-based row pointers; row ``i`` owns entries
        ``u_pointers[i]:u_pointers[i + 1]``. Must start at 0 and end at nnz.
    x : array-like, shape (n_cols,)
        Dense vector to multiply.

    Returns
    -------
    numpy.ndarray of float, shape (n_rows,)
        The product ``A @ x``.

    Raises
    ------
    ValueError
        If the CSR arrays are inconsistent with each other.
    IndexError
        If a column index falls outside ``x``.
    """
    w_vals = np.asarray(w_vals, dtype=float)
    v_indices = np.asarray(v_indices, dtype=np.intp)
    u_pointers = np.asarray(u_pointers, dtype=np.intp)
    x = np.asarray(x, dtype=float)

    if u_pointers.ndim != 1 or u_pointers.size == 0:
        raise ValueError("u_pointers must be a non-empty 1-D array of row pointers")
    if w_vals.shape != v_indices.shape or w_vals.ndim != 1:
        raise ValueError("w_vals and v_indices must be 1-D arrays of equal length")

    n_rows = u_pointers.size - 1
    row_counts = np.diff(u_pointers)
    if u_pointers[0] != 0 or u_pointers[-1] != w_vals.size or np.any(row_counts < 0):
        raise ValueError("u_pointers must be non-decreasing, start at 0 and end at nnz")
    # Negative indices would silently wrap around in numpy, so reject them too
    if v_indices.size and (v_indices.min() < 0 or v_indices.max() >= x.size):
        raise IndexError("column index out of range for x")

    # Row id of every stored entry, then accumulate w * x[col] per row.
    # bincount adds entries in storage order, matching a row-by-row loop.
    row_ids = np.repeat(np.arange(n_rows), row_counts)
    contributions = w_vals * x[v_indices]
    return np.bincount(row_ids, weights=contributions, minlength=n_rows)

# Project the posterior-mean field from mesh vertices onto the observation sites
w_at_obs = csr_matvec(A_w, A_v, A_u, w_mean)

# Point prediction = posterior-mean intercept + spatial effect.
# NOTE(review): assumes the Stan linear predictor is alpha + A*w with no extra
# scaling or covariates; confirm against spde_pm25.stan.
alpha_mean = draws['alpha'].mean()
fixed_pred = alpha_mean  # intercept-only "fixed effects" part
y_pred_full = fixed_pred + w_at_obs
residuals = y_obs - y_pred_full

# Size of the spatial contribution at the observation sites
print(f"Spatial field contribution at observations:")
print(f"  Mean: {w_at_obs.mean():.3f}")
print(f"  Std: {w_at_obs.std():.3f}")
print(f"  Range: [{w_at_obs.min():.3f}, {w_at_obs.max():.3f}]")

In [None]:
# Country boundaries from Natural Earth (1:110m admin-0); keep only Canada,
# the USA and Mexico for the base map.
ne_countries_url = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
world = gpd.read_file(ne_countries_url)
north_america_map = world.loc[world['ISO_A3'].isin(['CAN', 'USA', 'MEX'])]

# Reproject the outlines into the projected CRS returned alongside clean_coords
north_america_proj = north_america_map.to_crs(proj_info['proj4_string'])

# One row per cleaned observation: geographic and projected location,
# observed and predicted PM2.5, spatial-field contribution and residual
viz_columns = {
    'Longitude': lon_lat_clean[:, 0],
    'Latitude': lon_lat_clean[:, 1],
    'x_proj': clean_coords[:, 0],
    'y_proj': clean_coords[:, 1],
    'PM25_observed': y_obs,
    'PM25_predicted': y_pred_full,
    'spatial_field': w_at_obs,
    'residual': residuals,
}
viz_df = pd.DataFrame(viz_columns)

In [None]:
# Maps of the spatial field and of the model residuals at each monitoring site.
# Both maps share one layout, so they are built by a single helper.
def plot_field_map(points, boundaries, column, legend_name, title):
    """Build a plotnine map of one column of ``points`` over ``boundaries``.

    Parameters
    ----------
    points : pandas.DataFrame
        Must contain 'x_proj', 'y_proj' and ``column``.
    boundaries : geopandas.GeoDataFrame
        Outlines drawn underneath, in the same projected CRS as the points.
    column : str
        Column mapped to point colour on a blue-white-red scale centred at 0.
    legend_name : str
        Colour-bar title.
    title : str
        Plot title.

    Returns
    -------
    plotnine.ggplot
    """
    return (ggplot() +
        geom_map(boundaries, fill='white', color='black', size=0.2) +
        geom_point(data=points,
                   mapping=aes(x='x_proj', y='y_proj', color=column),
                   size=0.5, alpha=0.8) +
        scale_color_gradient2(low='blue', mid='white', high='red',
                              midpoint=0, name=legend_name) +
        theme_minimal() +
        labs(title=title) +
        coord_fixed())


p_map = plot_field_map(viz_df, north_america_proj,
                       'spatial_field', 'Spatial\nField', "Spatial Field")
display(p_map)

p_map_residuals = plot_field_map(viz_df, north_america_proj,
                                 'residual', 'Residual\n(μg/m³)', "Model Residuals")
display(p_map_residuals)

### Summary

This notebook successfully:
1. Compiled the SPDE Stan model
2. Loaded preprocessed data from the FEM pipeline
3. Ran MCMC sampling with diagnostics
4. Visualized parameter distributions and convergence
5. Performed posterior predictive checks
6. Extracted and analyzed the spatial field
7. Mapped the spatial field and model residuals over North America

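The model captures spatial correlation through the latent field `w`. This notebook does not test whether that yields better predictions than a non-spatial model. Confirming it would require comparing against a non-spatial baseline, e.g. with LOO cross-validation.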