# Central Asia AQ/Weather/Mobility Tutorial

This notebook demonstrates spatiotemporal prediction on the Central Asia AQ/Weather/Mobility dataset using BayesNF.

<a target="_blank" href="https://colab.research.google.com/github/enorenio/bayesnf/blob/main/central_asia_bayesnf.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# 1. Install Required Libraries

Install the necessary libraries, including bayesnf, cartopy, contextily, geopandas, and kagglehub.

In [None]:
# Install Required Libraries
!pip install bayesnf
!pip -q install cartopy
!pip -q install contextily
!pip -q install geopandas
!pip install kagglehub[pandas-datasets]

# 2. Download and Prepare Central Asia Dataset

Use kagglehub to download the 'mlbyalex/central-asia-aq-weather-mobility-hourly' dataset.  
Set the `file_path` variable to the desired CSV file and load the data into a pandas DataFrame.

In [None]:
import kagglehub
from kagglehub import KaggleDatasetAdapter

# Set the path to the file you'd like to load
file_path = "central_asia_aq_weather_mobility_hourly.csv"  # Update with actual file name if needed

# Load the latest version.
# The `datetime` column is parsed on read. Without this, pandas keeps it as raw
# CSV strings: the time-series plots below would get a categorical x-axis, and
# the hourly (`freq='H'`) time indexing used by BayesNF would receive strings
# rather than timestamps.
# NOTE(review): assumes the timestamp column is named `datetime`, as it is
# everywhere else in this notebook.
df = kagglehub.load_dataset(
  KaggleDatasetAdapter.PANDAS,
  "mlbyalex/central-asia-aq-weather-mobility-hourly",
  file_path,
  pandas_kwargs={"parse_dates": ["datetime"]},
  # Other options (e.g. sql_query) are documented at
  # https://github.com/Kaggle/kagglehub/blob/main/README.md#kaggledatasetadapterpandas
)

# Rich (HTML) display of the first 5 records instead of a plain-text print.
df.head()

# 3. Import Libraries

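Import the libraries used throughout the notebook:
- the standard-library modules `math`, `os` and `warnings`;
- the plotting and geospatial stack: matplotlib, cartopy, contextily, geopandas, shapely, and `make_axes_locatable` from `mpl_toolkits.axes_grid1`;
- the numerical libraries jax, numpy and pandas;
- scikit-learn's `train_test_split`.

All warnings are silenced to keep the output readable.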

In [None]:
import math
import os
import warnings
warnings.simplefilter('ignore')

import contextily as ctx
import geopandas as gpd
import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from cartopy import crs as ccrs
from mpl_toolkits.axes_grid1 import make_axes_locatable
from shapely.geometry import Point
from sklearn.model_selection import train_test_split

# 4. Load and Inspect Data

Inspect the first few records of the DataFrame loaded above. Confirm that the data is in long format, with the columns BayesNF needs.

In [None]:
# Preview the first 20 rows to check the table layout before any modelling.
df.iloc[:20]

BayesNF expects the dataframe to be in "long" format.  
Each row should show a single observation (e.g., `aq_value` or other target) at a given point in time (`datetime` column) and in space (`latitude` and `longitude` columns, which show the centroid of the location).  
The `location` column provides a human-readable name for the measurement site.

# 5. Plot Spatial Snapshots

Use geopandas to plot spatial snapshots of the data at different time points, visualizing the distribution of measurements over the region.

In [None]:
# Optional Central Asia boundary shapefile. When it exists, each station is drawn
# as the region polygon that contains it. Otherwise (and for stations outside
# every polygon) the station's own point is drawn, so this cell still runs on a
# fresh kernel without the shapefile.
REGION_SHAPEFILE = 'central_asia.shp'  # Update with actual shapefile path
region = gpd.read_file(REGION_SHAPEFILE) if os.path.exists(REGION_SHAPEFILE) else None


def station_geometry(point):
    """Return the first region polygon containing `point`, else `point` itself."""
    if region is None:
        return point
    return next((g for g in region.geometry.values if g.contains(point)), point)


# Build one geometry per distinct (longitude, latitude) pair instead of one per
# row: hourly data repeats each station many times.
station_coords = list(zip(df['longitude'], df['latitude']))
coord_to_geometry = {xy: station_geometry(Point(xy)) for xy in set(station_coords)}

df_plot = df.copy()
# `.get` yields None (no geometry) for rows whose coordinates are missing/NaN.
df_plot['boundary'] = [coord_to_geometry.get(xy) for xy in station_coords]
# Built once here rather than on every `plot_map` call.
df_plot_geo = gpd.GeoDataFrame(df_plot, geometry='boundary')
basemap_crs = region.crs.to_string() if region is not None else 'EPSG:4326'


def plot_map(date, ax, value_col='aq_value'):
    """Draw one spatial snapshot of `value_col` at timestamp `date` on `ax`.

    Args:
      date: value matched exactly against the `datetime` column.
      ax: axes to draw on (the figure below uses cartopy PlateCarree axes).
      value_col: column used to colour the station geometries.
    """
    snapshot = df_plot_geo[df_plot_geo.datetime == date]
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad='2%', axes_class=plt.matplotlib.axes.Axes)
    # Draw data (and region outline) before the basemap: contextily fetches tiles
    # for the current axis extent, which is only meaningful once something is drawn.
    snapshot.plot(
        column=value_col, alpha=.5, edgecolor='k',
        linewidth=1, legend=True, cmap='jet', cax=cax, ax=ax)
    if region is not None:
        region.plot(color='none', edgecolor='black', linewidth=1, ax=ax)
    ctx.add_basemap(ax, crs=basemap_crs, attribution='', zorder=-1)
    gl = ax.gridlines(draw_labels=True, alpha=0)
    gl.top_labels = False
    gl.right_labels = False
    ax.set_title(str(date))


fig, axes = plt.subplots(
    nrows=2, ncols=2, subplot_kw={'projection': ccrs.PlateCarree()},
    figsize=(12.5, 12.5), tight_layout=True)

unique_dates = df_plot['datetime'].drop_duplicates().sort_values()
# Snapshots 10 time steps apart; skip offsets beyond the available timestamps.
snapshot_positions = [i for i in (0, 10, 20, 30) if i < len(unique_dates)]
dates = unique_dates.iloc[snapshot_positions]
for ax, date in zip(axes.flat, dates):
    plot_map(date, ax)
# Hide panels left empty when there are fewer than four snapshots.
for ax in axes.flat[len(dates):]:
    ax.set_visible(False)

# 6. Plot Time Series for Each Location

Plot the observed time series for each location to visualize temporal patterns and spatial differences.

In [None]:
locations = df.location.unique()
# Size the grid to the number of locations (4 panels per row). The previous
# fixed 5x4 grid silently dropped every location beyond the 20th.
ncols = 4
nrows = max(1, math.ceil(len(locations) / ncols))
fig, axes = plt.subplots(
    ncols=ncols, nrows=nrows, tight_layout=True,
    figsize=(25, 4 * nrows), squeeze=False)
for ax, location in zip(axes.flat, locations):
    # Sort by time so the connecting line follows chronological order.
    df_location = df[df.location == location].sort_values('datetime')
    latitude, longitude = df_location.iloc[0][['latitude', 'longitude']]
    ax.plot(df_location.datetime, df_location.aq_value, marker='.', color='k', linewidth=1)
    ax.set_title(f'Location: {location} ({longitude:.2f}, {latitude:.2f})')
    ax.set_xlabel('Time')
    ax.set_ylabel('AQ Value')
# Hide the unused panels in the last row.
for ax in axes.flat[len(locations):]:
    ax.set_visible(False)

# 7. Build BayesNF Estimator

Construct a BayesNF model (e.g., BayesianNeuralFieldMAP) using relevant feature columns from the Central Asia dataset.

In [None]:
from bayesnf.spatiotemporal import BayesianNeuralFieldMAP

# Time column first, then the spatial coordinates. The pairs in `interactions`
# presumably index into this list (0 = datetime, 1 = latitude, 2 = longitude).
FEATURE_COLS = ['datetime', 'latitude', 'longitude']
TARGET_COL = 'aq_value'  # update to your target column

model = BayesianNeuralFieldMAP(
  feature_cols=FEATURE_COLS,
  target_col=TARGET_COL,
  timetype='index',
  freq='H',                          # hourly observations
  seasonality_periods=['D', 'Y'],    # daily and yearly cycles
  num_seasonal_harmonics=[2, 10],    # presumably one entry per seasonality period
  standardize=['latitude', 'longitude'],
  interactions=[(0, 1), (0, 2), (1, 2)],
  observation_model='NORMAL',
  width=256,
  depth=2,
)

# 8. Fit the Estimator

Train the BayesNF estimator on the dataset using the `.fit` method, specifying ensemble size and number of epochs.

In [None]:
# Fit only on the training portion. The held-out 20% is what the model is
# evaluated on below; training on the full `df` would leak those test rows
# into the fit and make the evaluation meaningless.
# The split must match the one used for evaluation: deterministic, no shuffling.
# NOTE(review): with `shuffle=False` the test set is the last 20% of rows, so
# this is a temporal hold-out only if `df` is ordered by `datetime` — confirm.
df_train, df_test = train_test_split(df, test_size=0.2, shuffle=False)

# Train MAP ensemble
model = model.fit(
    df_train,
    seed=jax.random.PRNGKey(0),
    ensemble_size=64,
    num_epochs=5000,
)

# 9. Plot Training Loss

Plot the training loss for each particle in the ensemble to assess convergence.

In [None]:
# Inspect the training loss for each particle.
# `np.vstack` replaces `np.row_stack`, which is deprecated as of NumPy 2.0.
# The result presumably has one row per ensemble member and one column per epoch.
losses = np.vstack(model.losses_)
fig, ax = plt.subplots(figsize=(5, 3), tight_layout=True)
ax.plot(losses.T)
ax.plot(np.mean(losses, axis=0), color='k', linewidth=3)  # ensemble mean curve
ax.set_xlabel('Epoch')
ax.set_ylabel('Negative Joint Probability')
ax.set_yscale('log', base=10)

# 10. Make Predictions

Use the model's predict method on a test split of the data, obtaining mean predictions and quantiles.

In [None]:
from sklearn.model_selection import train_test_split

# Deterministic split (no shuffling): the last 20% of rows, in their current
# order, form the test set. Recomputing it here keeps this cell self-contained.
split_kwargs = dict(test_size=0.2, shuffle=False)
df_train, df_test = train_test_split(df, **split_kwargs)

# Predict on the test features only (target column removed). Request the median
# and the bounds of a central 95% interval.
prediction_quantiles = (0.025, 0.5, 0.975)
features_test = df_test.drop(columns=['aq_value'])
yhat, yhat_quantiles = model.predict(features_test, quantiles=prediction_quantiles)

# 11. Scatter Plot: True vs Predicted

Plot a scatter plot comparing true values to predicted values on the test data.

In [None]:
# Compare test observations with the median (0.5-quantile) prediction. The red
# diagonal marks perfect agreement across the observed range.
y_true = df_test.aq_value
y_lo, y_hi = y_true.min(), y_true.max()

fig, ax = plt.subplots(figsize=(5, 3), tight_layout=True)
ax.scatter(y_true, yhat_quantiles[1], marker='.', color='k')
ax.plot([y_lo, y_hi], [y_lo, y_hi], color='red')
ax.set_xlabel('True Value')
ax.set_ylabel('Predicted Value')

# 12. Test-Set Predictions by Location

For each location in the test split, plot four things: the last 100 training observations, the held-out test observations, the median prediction, and the 95% prediction interval.

In [None]:
locations = df_test.location.unique()
# Two rows of panels. Rounding the column count up gives every location a panel:
# the previous `len(locations) // 2` dropped the last location when the count was
# odd, and produced `ncols=0` (an error) for a single location.
ncols = max(1, math.ceil(len(locations) / 2))
# Width scales with the number of columns (4 in per column; same as the
# previous 16 in for four columns).
fig, axes = plt.subplots(
    nrows=2, ncols=ncols, tight_layout=True, figsize=(4 * ncols, 8), squeeze=False)
test_location_values = df_test.location.to_numpy()
for ax, location in zip(axes.flat, locations):
    y_train = df_train[df_train.location == location]
    y_test = df_test[df_test.location == location]
    # Explicit positional slice: the last 100 training rows for this location.
    ax.scatter(y_train.datetime.iloc[-100:], y_train.aq_value.iloc[-100:], marker='o', color='k', label='Observations')
    ax.scatter(y_test.datetime, y_test.aq_value, marker='o', edgecolor='k', facecolor='w', label='Test Data')
    # Select this location's predictions. Assumes the prediction arrays are
    # ordered like the rows of `df_test` — TODO confirm against model.predict.
    mask = test_location_values == location
    ax.plot(y_test.datetime, yhat_quantiles[1][mask], color='red', label='Median Prediction')
    ax.fill_between(y_test.datetime, yhat_quantiles[0][mask], yhat_quantiles[2][mask], alpha=0.5, label='95% Prediction Interval')
    ax.set_title('Test Location: %s' % (location,))
    ax.set_xlabel('Time')
    ax.set_ylabel('AQ Value')
# Hide the panel left empty when the number of locations is odd.
for ax in axes.flat[len(locations):]:
    ax.set_visible(False)
axes.flat[0].legend(loc='upper left')