# Modeling Extremes - Numpyro Pt 2 - MCMC

In [1]:
import os
# Runtime configuration: the environment variables are set before JAX is imported.
# An empty device list hides every GPU (it does not select "the first GPU").
os.environ["CUDA_VISIBLE_DEVICES"] = "" # no GPU visible -> CPU only
# Ask XLA not to preallocate accelerator memory.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'FALSE'

import jax
# Pin JAX to the CPU platform.
jax.config.update('jax_platform_name', 'cpu')

import numpyro
import multiprocessing

# Request one host (CPU) device per core reported by the OS.
num_devices = multiprocessing.cpu_count()
numpyro.set_platform("cpu")
numpyro.set_host_device_count(num_devices)
# Enable 64-bit precision in JAX.
jax.config.update("jax_enable_x64", True)

In [2]:
import autoroot
from pathlib import Path
import numpy as np
import xarray as xr
import pandas as pd
import pint_xarray
import arviz as az

from st_evt.viz import plot_histogram, plot_density
from omegaconf import OmegaConf

import jax
import jax.random as jrandom
import jax.numpy as jnp
import pandas as pd

rng_key = jrandom.PRNGKey(123)

from numpyro.infer import Predictive
import arviz as az

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import xarray as xr
import regionmask

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, FuncFormatter
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%config InlineBackend.figure_format = 'retina'
plt.style.use(
    "https://raw.githubusercontent.com/ClimateMatchAcademy/course-content/main/cma.mplstyle"
)

from loguru import logger

# num_devices = 5
# numpyro.set_host_device_count(num_devices)


%matplotlib inline
%load_ext autoreload
%autoreload 2

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Paths

In [3]:
# Directory and file name of the saved inference results.
results_root_path = "/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results"
results_file_name = "stationary_iid_mcmc_redfeten.zarr"

# Resolve both derived locations from a single Path object.
_results_root = Path(results_root_path)
results_data_path = _results_root / results_file_name

# Base directory under which station figures are written.
figures_path = _results_root / "figures/stations"

## Data

In [4]:
# Names used to index the dataset in the cells below.
variable = "t2max"
covariate = "gmst"
spatial_dim_name = "station_id"

# Zarr store holding the input data, resolved relative to the project root.
DATA_URL = autoroot.root.joinpath("data/ml_ready/aemet/t2max_stations_bm_summer.zarr")


# LOAD DATA
# The dataset is read fully into memory inside the context manager, so the
# store handle is released as soon as the block exits.
with xr.open_dataset(DATA_URL, engine="zarr") as store:
    ds_bm = store.load()
    # ds_bm = ds_bm.where(ds_bm.red_feten_mask == 1, drop=True)

### Likelihood Statistics

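We use WAIC-based statistics, computed from the pointwise log-likelihood, to evaluate how well the model fits the data.

$$
\begin{aligned}
\text{ELPD WAIC}: && \text{expected log pointwise predictive density (higher is better)} && \\
\text{ELPD WAIC SE}: && \text{standard error of the ELPD WAIC estimate} && \\
\text{p WAIC}: && \text{effective number of parameters (not a p-value)} && \\
\end{aligned}
$$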

In [5]:
# Name of the data variable analysed below (same value as assigned in the data-loading cell).
variable = "t2max"

## Station Selection

### a - Predetermined Station

In [6]:
# Predefined stations provided by `st_evt`; displayed so one can be picked below.
# Each entry maps a short name to [station id, station label].
from st_evt import CANDIDATE_STATIONS
CANDIDATE_STATIONS

{'madrid': ['3129A', 'Madrid (Barajas)'],
 'valencia': ['8414A', 'Valencia (Aeropuerto)'],
 'zaragoza': ['9434', 'Zaragoza (Aeropuerto)'],
 'santiago': ['1475X', 'Santiago De Compostela'],
 'murcia': ['7178I', 'Murcia'],
 'cordoba': ['9434', 'Viallanueva de Cordoba (Sea)']}

In [7]:
# Station analysed in this notebook: id '3129A' is the "madrid" entry
# (Madrid (Barajas)) of the CANDIDATE_STATIONS listing shown above.
# The previous lookup of the Valencia id was dead code: it was overwritten
# immediately by this value.
candidate_station = "3129A"

In [8]:
# Output directory for this station's figures.
# Built from the results root instead of from the current value of
# `figures_path`, so re-running the cell is idempotent and does not nest the
# station folder a second time (".../stations/<id>/<id>").
figures_path = Path(results_root_path).joinpath("figures/stations", candidate_station)
figures_path.mkdir(parents=True, exist_ok=True)

### MCMC Results

In [9]:
# Load the saved inference results (the path is passed to ArviZ as a string).
results_store = str(results_data_path)
az_ds = az.from_zarr(results_store)

# Restrict both the loaded input data and the inference results to the
# station under study.
ds_station = ds_bm.sel(station_id=candidate_station)
az_ds_station = az_ds.sel(station_id=candidate_station)

In [10]:
# WAIC summary for the selected station, computed by ArviZ from the stored log-likelihood.
az.waic(az_ds_station)

Computed from 8000 posterior samples and 59 observations log-likelihood matrix.

          Estimate       SE
elpd_waic   -95.83     4.57
p_waic        1.77        -

### EDA Stuff

In [11]:
from st_evt._src.modules.models.aemet import utils_station

In [12]:
# Series of the target variable at the selected station, singleton dims dropped.
da_station = ds_station[variable].squeeze()

# Exploratory plots; figures are saved under `figures_path`
# (pass "./" instead to save next to the notebook).
utils_station.plot_eda(
    da=da_station,
    variable_label="2m Max Temperature [°C]",
    figures_path=figures_path,
    figure_dpi=300,
)

[32m2025-01-10 06:45:45.710[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m114[0m - [1mPlotting BM Data Time Series...[0m
[32m2025-01-10 06:45:45.982[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m130[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results/figures/stations/3129A/eda/ts_bm_data.pdf[0m
[32m2025-01-10 06:45:45.982[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m132[0m - [1mPlotting BM Data Histogram...[0m
[32m2025-01-10 06:45:46.274[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m146[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results/figures/stations/3129A/eda/hist_bm

## Posterior Calculations

## Model Inspection

In [13]:
# calculate model return periods
# NOTE(review): this replaces the `posterior` group of `az_ds_station` in place.
# The helper presumably adds the return-level variables (e.g. `return_level_100`,
# which the trace-plot cell reads from the posterior) — confirm in `utils_station`.
az_ds_station.posterior = utils_station.calculate_ds_return_periods(az_ds_station.posterior)

### Trace Plot

In [14]:
# Posterior variables to inspect: the three model parameters plus the
# 100-year return level.
variables = ["concentration", "scale", "location", "return_level_100"]

# Parameter diagnostics; figures are saved under `figures_path`
# (pass "./" instead to save next to the notebook).
utils_station.plot_model_params_critique(
    ds=az_ds_station.posterior,
    variables=variables,
    figures_path=figures_path,
)

[32m2025-01-10 06:45:52.439[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m39[0m - [1mPlotting Parameter Traces...[0m
[32m2025-01-10 06:45:53.312[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m53[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results/figures/stations/3129A/params/trace.pdf[0m
[32m2025-01-10 06:45:53.312[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m55[0m - [1mPlotting Parameter Jonts...[0m
[32m2025-01-10 06:45:54.814[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m76[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary

## Model Critique

In [15]:
# Observations and the posterior-predictive draws of the target variable.
y_true = az_ds_station.observed_data[variable]
y_pred = az_ds_station.posterior_predictive[variable].rename("y_pred")

# Random subsample of the posterior-predictive group (up to 10,000 samples requested).
# NOTE(review): `idata` is not referenced by any of the following cells shown
# in this notebook — confirm whether it is still needed.
idata = az.extract(
    az_ds_station,
    group="posterior_predictive",
    num_samples=10_000,
)

In [16]:
# Both residual plots take exactly the same inputs, so define them once.
# (Use figures_path="./" to save next to the notebook instead.)
residual_plot_kwargs = dict(
    y_pred=y_pred,
    y_true=y_true,
    figures_dpi=300,
    figures_path=figures_path,
    units="[°C]",
)

# Signed residuals, then absolute residuals.
utils_station.plot_residual_error_metric(**residual_plot_kwargs)
utils_station.plot_residual_abs_error_metric(**residual_plot_kwargs)

[32m2025-01-10 06:46:02.901[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_error_metric[0m:[36m235[0m - [1mCalculating residual error...[0m
[32m2025-01-10 06:46:03.364[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_error_metric[0m:[36m252[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results/figures/stations/3129A/metrics/density_residuals.pdf[0m
[32m2025-01-10 06:46:03.365[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_abs_error_metric[0m:[36m266[0m - [1mCalculating residual error...[0m
[32m2025-01-10 06:46:03.591[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_abs_error_metric[0m:[36m283[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkth

### QQ-Plot

In [17]:
# Point prediction per observation: the posterior-predictive MEAN over all
# chains and draws. (The variable was previously called `y_pred_median`
# although it has always held the mean; the computation is unchanged.)
y_pred_mean = y_pred.mean(dim=["draw", "chain"])

# QQ-plot of observations against the point predictions; figures are saved
# under `figures_path` (pass "./" instead to save next to the notebook).
utils_station.plot_qq(
    y_true=y_true,
    y_pred=y_pred_mean,
    figures_path=figures_path,
    figures_dpi=300,
)

[32m2025-01-10 06:46:09.119[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m301[0m - [1mCalculating Metrics (RMSE, MAE, MAPD)...[0m
[32m2025-01-10 06:46:09.321[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m314[0m - [1mPlotting QQ-Plot...[0m
[32m2025-01-10 06:46:09.457[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m340[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results/figures/stations/3129A/metrics/qq_plot.pdf[0m


## Return Levels

### Empirical Return Levels

In [18]:
# Observed values with missing entries along the covariate dimension removed.
y_obs = az_ds_station.observed_data.dropna(dim=covariate)[variable]

# Empirical return levels of the observations.
y_clean = utils_station.calculate_empirical_return_level_gevd_ds(
    y_obs,
    covariate=covariate,
)

# Model return periods, written back to the posterior-predictive group.
az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(
    az_ds_station.posterior_predictive
)

# Median and 2.5% / 97.5% quantiles of the model return level over all
# chains and draws.
quantile_levels = [0.025, 0.5, 0.975]
rl_model_quantiles = az_ds_station.posterior_predictive["return_level"].quantile(
    q=quantile_levels,
    dim=["chain", "draw"],
)

[32m2025-01-10 06:46:18.145[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m354[0m - [1mCalculating Return Level...[0m
[32m2025-01-10 06:46:18.146[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m362[0m - [1mSwapping Dims...[0m


### Viz - Return Level 

In [19]:
# Model return-level quantiles against the empirical return levels; figures
# are saved under `figures_path` (pass "./" instead to save next to the notebook).
return_period_plot_kwargs = dict(
    rl_model_quantiles=rl_model_quantiles,
    y=y_clean,
    figures_path=figures_path,
    y_label="2m Max Temperature, $R_a$ [°C]",
)
utils_station.plot_return_periods_ds(**return_period_plot_kwargs)

[32m2025-01-10 06:46:21.672[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_ds[0m:[36m375[0m - [1mIntialize Returns...[0m
[32m2025-01-10 06:46:21.672[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_ds[0m:[36m379[0m - [1mCreating Data structures...[0m
[32m2025-01-10 06:46:21.673[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_ds[0m:[36m389[0m - [1mPlotting...[0m
[32m2025-01-10 06:46:22.653[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_ds[0m:[36m406[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results/figures/stations/3129A/returns/returns_prob_posterior_vs_empirical.pdf[0m


### Viz - 100-Year Return Period

In [20]:
# calculate model return periods
az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(az_ds_station.posterior_predictive)

# Calculate Quantiles
rl_model_quantiles = az_ds_station.posterior_predictive["return_level_100"]

In [21]:
# fig, ax = plot_return_periods_100_ds(
#     rl_model_quantiles=rl_model_quantiles,
#     covariate=covariate,
#     x_label="2m Max Temperature, $R_a$ [°C]"
# )
# plt.show()

utils_station.plot_return_periods_100_ds(
    rl_model=rl_model_quantiles,
    # figures_path="./", 
    figures_path=figures_path, 
    x_label="2m Max Temperature, $R_{100}$ [°C]"
)

[32m2025-01-10 06:46:28.777[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_ds[0m:[36m532[0m - [1mPlotting...[0m
[32m2025-01-10 06:46:28.777[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_ds[0m:[36m542[0m - [1mPlotting...[0m
[32m2025-01-10 06:46:29.067[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_ds[0m:[36m559[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/stationary_iid_mcmc_redfeten/results/figures/stations/3129A/returns/returns_100years_density.pdf[0m
