# Modeling Extremes - Numpyro Pt 2 - MAP

In [1]:
import os
# Hide every CUDA device (an empty list means "no GPUs visible"), forcing CPU execution.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Do not let XLA pre-allocate a large fraction of accelerator memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'FALSE'

import jax
jax.config.update('jax_platform_name', 'cpu')

import numpyro
import multiprocessing

# Expose one XLA host (CPU) device per core, e.g. for parallel MCMC chains.
# numpyro documents that this must be set before JAX initialises its backends,
# which is why it lives in the very first cell.
num_devices = multiprocessing.cpu_count()
numpyro.set_platform("cpu")
numpyro.set_host_device_count(num_devices)
# Use 64-bit floating point in JAX.
jax.config.update("jax_enable_x64", True)

In [2]:
import autoroot
from pathlib import Path
import numpy as np
import xarray as xr
import pandas as pd
import pint_xarray
import arviz as az

from st_evt.viz import plot_histogram, plot_density
from omegaconf import OmegaConf

import jax
import jax.random as jrandom
import jax.numpy as jnp
import pandas as pd

rng_key = jrandom.PRNGKey(123)

import arviz as az

import xarray as xr
import regionmask

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, FuncFormatter
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%config InlineBackend.figure_format = 'retina'
plt.style.use(
    "https://raw.githubusercontent.com/ClimateMatchAcademy/course-content/main/cma.mplstyle"
)

from loguru import logger

# num_devices = 5
# numpyro.set_host_device_count(num_devices)


%matplotlib inline
%load_ext autoreload
%autoreload 2

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Paths

In [3]:
# Root directory of the saved model run. Set the STEVT_RESULTS_ROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
results_root_path = os.environ.get(
    "STEVT_RESULTS_ROOT",
    "/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/",
)
results_file_name = "results.nc"
results_data_path = Path(results_root_path).joinpath(results_file_name)

# Base directory for the posterior figures (a per-station subfolder is added later).
figures_path = Path(results_root_path).joinpath("figures/stations/posterior")

In [4]:
figures_path  # display the base figures directory

PosixPath('/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures/stations/posterior')

## Data

In [5]:
# Names reused throughout the notebook: target variable, covariate, station dimension.
variable = "t2max"
covariate = "gmst"
spatial_dim_name = "station_id"

# Station block-maxima dataset (summer, per the file name) stored as a zarr store.
DATA_URL = autoroot.root.joinpath("data/ml_ready/aemet/t2max_stations_bm_summer.zarr")

# Read everything into memory; the store is closed when the `with` block exits.
# Optional (disabled) filter restricting to stations flagged by `red_feten_mask`:
#   ds_bm = ds_bm.where(ds_bm.red_feten_mask == 1, drop=True)
with xr.open_dataset(DATA_URL, engine="zarr") as f:
    ds_bm = f.load()

### Likelihood Statistics

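There are some useful statistics that we can use to evaluate how well our model fits the data. `az.waic` (used below) reports the widely applicable information criterion (WAIC) on the log scale, where $S$ is the number of posterior draws $\theta^{(s)}$ and $n$ is the number of observations $y_i$:

$$
\begin{aligned}
\text{ELPD WAIC}: && \widehat{\text{elpd}}_{\text{WAIC}} = \sum_{i=1}^{n} \widehat{\text{elpd}}_{i}, \quad \widehat{\text{elpd}}_{i} = \log\left(\frac{1}{S}\sum_{s=1}^{S} p\left(y_i \mid \theta^{(s)}\right)\right) - \mathrm{Var}_{s}\left[\log p\left(y_i \mid \theta^{(s)}\right)\right] \\
\text{ELPD WAIC SE}: && \text{SE} = \sqrt{n \, \mathrm{Var}_{i}\left[\widehat{\text{elpd}}_{i}\right]} \\
\text{P WAIC}: && p_{\text{WAIC}} = \sum_{i=1}^{n} \mathrm{Var}_{s}\left[\log p\left(y_i \mid \theta^{(s)}\right)\right]
\end{aligned}
$$

A higher (less negative) ELPD indicates better expected out-of-sample predictive accuracy. Note that $p_{\text{WAIC}}$ is the *effective number of parameters* (a complexity penalty), not a p-value.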

In [6]:
variable = "t2max"  # NOTE(review): redundant — `variable` is already "t2max" from the Data cell

## Station Selection

In [7]:
from st_evt import AEMET_GOOD_STATIONS, AEMET_BAD_STATIONS

In [8]:
AEMET_BAD_STATIONS["red_feten"]  # station ids listed as "bad" under the red_feten key (display only)

['3407Y', '9677', 'C018J', 'C426R']

### a - Predetermined Station

In [9]:
from st_evt import CANDIDATE_STATIONS
# Predefined stations of interest: name -> [station_id, display name].
CANDIDATE_STATIONS

{'madrid': ['3129A', 'Madrid (Barajas)'],
 'valencia': ['8414A', 'Valencia (Aeropuerto)'],
 'zaragoza': ['9434', 'Zaragoza (Aeropuerto)'],
 'santiago': ['1475X', 'Santiago De Compostela'],
 'murcia': ['7178I', 'Murcia'],
 'cordoba': ['9434', 'Viallanueva de Cordoba (Sea)']}

In [10]:
# Station to analyse. `STATION_OVERRIDE` takes precedence; set it to None to fall
# back to the predefined Valencia candidate from `CANDIDATE_STATIONS`.
STATION_OVERRIDE = "8354X"
candidate_station = STATION_OVERRIDE or CANDIDATE_STATIONS["valencia"][0]

In [11]:
# Per-station figure directory. Rebuilt from `results_root_path` rather than
# appended to the current `figures_path`, so re-running this cell does not nest
# the station folder again (…/8354X/8354X/…).
figures_path = Path(results_root_path).joinpath("figures/stations/posterior", f"{candidate_station}")
figures_path.mkdir(parents=True, exist_ok=True)

### MCMC Results

In [12]:
# Load the saved InferenceData and restrict it, and the raw dataset, to the selected
# station along the configured spatial dimension. Keyword-unpacking is required:
# the first positional argument of `InferenceData.sel` is `groups`, not indexers.
az_ds = az.from_netcdf(str(results_data_path))
az_ds_station = az_ds.sel(**{spatial_dim_name: candidate_station})
ds_station = ds_bm.sel(**{spatial_dim_name: candidate_station})

### Exploratory Data Analysis (EDA)

In [13]:
from st_evt._src.modules.models.aemet import utils_station

In [14]:
# EDA figures (time series and histogram) for the selected station, saved under `figures_path`.
station_series = ds_station[variable].squeeze()
utils_station.plot_eda(
    da=station_series,
    figures_path=figures_path,
    figure_dpi=300,
    variable_label="2m Max Temperature [°C]",
)

[32m2025-01-09 20:17:27.194[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m114[0m - [1mPlotting BM Data Time Series...[0m
[32m2025-01-09 20:17:27.807[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m130[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures/stations/posterior/8354X/eda/ts_bm_data.pdf[0m
[32m2025-01-09 20:17:27.808[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m132[0m - [1mPlotting BM Data Histogram...[0m
[32m2025-01-09 20:17:27.936[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m146[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures/stations/posterior/8354X/eda/hist_bm_data.pdf[0m
[32m2025-

## Posterior Calculations

## Model Inspection

In [15]:
az.waic(az_ds_station)  # WAIC summary for this station: elpd_waic, its SE, and p_waic

Computed from 1000 posterior samples and 59 observations log-likelihood matrix.

          Estimate       SE
elpd_waic  -104.25     5.01
p_waic        1.07        -

### Trace Plot

In [16]:
# Posterior trace and joint plots for the model parameters, saved under `figures_path`.
variables = ["concentration", "scale", "location_slope", "location_intercept"]

utils_station.plot_model_params_critique(
    ds=az_ds_station.posterior,
    variables=variables,
    figures_path=figures_path,
)

[32m2025-01-09 20:17:28.311[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m39[0m - [1mPlotting Parameter Traces...[0m
[32m2025-01-09 20:17:28.826[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m53[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures/stations/posterior/8354X/params/trace.pdf[0m
[32m2025-01-09 20:17:28.827[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m55[0m - [1mPlotting Parameter Jonts...[0m
[32m2025-01-09 20:17:29.913[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m76[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures

## Model Critique

In [17]:
# Posterior-predictive draws pooled over chains. NOTE(review): the WAIC output above
# reports 1000 posterior samples, so `num_samples=10_000` presumably just returns all
# available draws — TODO confirm. `idata` is not used in the cells shown here.
idata = az.extract(az_ds_station, group="posterior_predictive", num_samples=10_000)

# Posterior-predictive draws and observations for the selected variable
# (uses `variable` rather than a hard-coded "t2max").
y_pred = az_ds_station.posterior_predictive["obs"].sel(variable=variable).rename("y_pred")
y_true = az_ds_station.observed_data["obs"].sel(variable=variable)

In [18]:
# Residual and absolute-residual error plots; both share the same options.
residual_plot_kwargs = dict(
    y_pred=y_pred,
    y_true=y_true,
    figures_dpi=300,
    figures_path=figures_path,
    units="[°C]",
)
utils_station.plot_residual_error_metric(**residual_plot_kwargs)
utils_station.plot_residual_abs_error_metric(**residual_plot_kwargs)

[32m2025-01-09 20:17:32.113[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_error_metric[0m:[36m235[0m - [1mCalculating residual error...[0m
[32m2025-01-09 20:17:32.354[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_error_metric[0m:[36m252[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures/stations/posterior/8354X/metrics/density_residuals.pdf[0m
[32m2025-01-09 20:17:32.354[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_abs_error_metric[0m:[36m266[0m - [1mCalculating residual error...[0m
[32m2025-01-09 20:17:32.577[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_abs_error_metric[0m:[36m283[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonst

### QQ-Plot

In [19]:
# Point prediction for the QQ-plot: the posterior-predictive median per observation,
# reduced over all chains and draws. (Previously computed with `.mean()`, which
# contradicted the variable name.)
y_pred_median = y_pred.median(dim=["chain", "draw"])

utils_station.plot_qq(
    y_true=y_true,
    y_pred=y_pred_median,
    figures_path=figures_path,
    figures_dpi=300,
)

[32m2025-01-09 20:17:32.689[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m301[0m - [1mCalculating Metrics (RMSE, MAE, MAPD)...[0m
[32m2025-01-09 20:17:32.887[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m314[0m - [1mPlotting QQ-Plot...[0m
[32m2025-01-09 20:17:33.073[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m340[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures/stations/posterior/8354X/metrics/qq_plot.pdf[0m


## Regression Plot

In [20]:
# Linear location trend against the covariate: y_model = slope * x + intercept,
# evaluated with the posterior samples of slope and intercept (no reduction applied).
x = az_ds_station.posterior[covariate]
y = az_ds_station.observed_data["obs"].sel(variable=variable)
y_hat = az_ds_station.posterior_predictive["obs"].sel(variable=variable)
slope = az_ds_station.posterior["location_slope"].sel(variable=variable)
intercept = az_ds_station.posterior["location_intercept"].sel(variable=variable)
y_model = slope * x + intercept

utils_station.plot_regression_posterior(
    x=x,
    y=y,
    y_hat=y_hat,
    y_model=y_model,
    figures_path=figures_path,
    figure_dpi=300,
    covariate_label="Global Mean Surface Temperature Anomaly [°C]",
    y_label="2m Max Temperature [°C]"
)

[32m2025-01-09 20:17:33.376[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_regression_posterior[0m:[36m983[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_tutorial/models/nonstationary_gp_lap_demo/figures/stations/posterior/8354X/regression/regression.pdf[0m


## Return Levels

#### Empirical Return Levels

In [21]:
# Observations for the selected variable, re-indexed by the covariate instead of
# time (uses `covariate` rather than a hard-coded "gmst").
y_clean = az_ds_station.observed_data["obs"].sel(variable=variable)
y_clean = y_clean.assign_coords({covariate: az_ds_station.observed_data[covariate]})
y_clean = y_clean.swap_dims({"time": covariate})

# Empirical return levels of the observations.
y_clean = utils_station.calculate_empirical_return_level_gevd_ds(y_clean, covariate=covariate)


[32m2025-01-09 20:17:33.422[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m354[0m - [1mCalculating Return Level...[0m
[32m2025-01-09 20:17:33.426[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m362[0m - [1mSwapping Dims...[0m


In [22]:
# Add model return levels to the posterior-predictive group.
# NOTE(review): this reassigns the group on `az_ds_station` in place.
az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(az_ds_station.posterior_predictive)

# 2.5% / 50% / 97.5% quantiles of the model return level over all chains and draws.
rl_model_quantiles = (
    az_ds_station.posterior_predictive["return_level"]
    .sel(variable=variable)
    .quantile(q=[0.025, 0.5, 0.975], dim=["chain", "draw"])
)

# Index by the covariate instead of time (uses `covariate`, not a hard-coded "gmst").
rl_model_quantiles = rl_model_quantiles.assign_coords({covariate: az_ds_station.observed_data[covariate]})
rl_model_quantiles = rl_model_quantiles.swap_dims({"time": covariate})

### Viz - Return Level 

In [23]:
# Model return-level quantile band against the empirical return levels of the observations.
utils_station.plot_return_periods_dyn_ds(
    rl_model_quantiles=rl_model_quantiles,
    y=y_clean,
    covariate=covariate,
    figures_path=figures_path,
    y_label="2m Max Temperature, $R_a$ [°C]",
)

[32m2025-01-09 20:17:34.108[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m421[0m - [1mGetting Appropriate Periods...[0m
[32m2025-01-09 20:17:34.109[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m425[0m - [1mIntialize Returns...[0m
[32m2025-01-09 20:17:34.110[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m429[0m - [1mCreating Data structures...[0m
[32m2025-01-09 20:17:34.111[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m445[0m - [1mPlotting...[0m
[32m2025-01-09 20:17:35.027[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m461[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_projects/scratch/stevt_

### Viz - 100-Year Return Period

In [24]:
# (Re)compute model return levels so this section does not rely on the previous one.
az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(az_ds_station.posterior_predictive)

# 100-year return level for the selected variable.
# NOTE(review): despite its name, `rl_model_quantiles` here is not reduced to
# quantiles (no reduction is applied), and it overwrites the quantiles computed
# in the "Return Levels" section.
rl_model_quantiles = az_ds_station.posterior_predictive["return_level_100"].sel(variable=variable)

# Index by the covariate instead of time (uses `covariate`, not a hard-coded "gmst").
rl_model_quantiles = rl_model_quantiles.assign_coords({covariate: az_ds_station.observed_data[covariate]})
rl_model_quantiles = rl_model_quantiles.swap_dims({"time": covariate})

In [25]:
# 100-year return level R_100 of the model along the covariate.
utils_station.plot_return_periods_100_dyn_ds(
    rl_model=rl_model_quantiles,
    covariate=covariate,
    figures_path=figures_path,
    x_label="2m Max Temperature, $R_{100}$ [°C]",
)

[32m2025-01-09 20:17:35.235[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m575[0m - [1mGetting Appropriate Periods...[0m
[32m2025-01-09 20:17:35.236[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m579[0m - [1mIntialize Returns...[0m
[32m2025-01-09 20:17:35.237[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m583[0m - [1mCreating Data structures...[0m
[32m2025-01-09 20:17:35.238[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m601[0m - [1mPlotting...[0m
[32m2025-01-09 20:17:35.542[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m617[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_pro

### Viz - 100-Year Return Period Difference

In [26]:
# (Re)compute model return levels so this section does not rely on earlier ones.
az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(az_ds_station.posterior_predictive)

# 100-year return level for the selected variable, indexed by the covariate
# instead of time (uses `covariate`, not a hard-coded "gmst").
rl_model = az_ds_station.posterior_predictive["return_level_100"].sel(variable=variable)
rl_model = rl_model.assign_coords({covariate: az_ds_station.observed_data[covariate]})
rl_model = rl_model.swap_dims({"time": covariate})

In [27]:
# Difference in the 100-year return level along the covariate, shown in absolute
# units and in percent; both plots share these options.
rl100_diff_kwargs = dict(
    rl_model=rl_model,
    covariate=covariate,
    figures_path=figures_path,
    color="black",
)
utils_station.plot_return_periods_100_difference_dyn_ds(
    **rl100_diff_kwargs,
    x_label="2m Max Temperature, $R_{100}$ [°C]",
    units="[°C]",
)
utils_station.plot_return_periods_100_difference_prct_dyn_ds(
    **rl100_diff_kwargs,
    x_label="2m Max Temperature, $R_{100}$ [%]",
)

[32m2025-01-09 20:17:35.812[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m693[0m - [1mGetting Appropriate Periods...[0m
[32m2025-01-09 20:17:35.813[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m697[0m - [1mCalculating Difference...[0m
[32m2025-01-09 20:17:35.822[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m703[0m - [1mIntialize Returns...[0m
[32m2025-01-09 20:17:35.823[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m707[0m - [1mCreating Data structures...[0m
[32m2025-01-09 20:17:35.824[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[