# Modeling Extremes - NumPyro Pt 2 - MCMC (Station Analysis)

In [1]:
import multiprocessing
import os

# Force CPU execution: an empty CUDA_VISIBLE_DEVICES hides *every* GPU from JAX
# (it does not select the first one). Also disable XLA's up-front GPU memory
# preallocation, which matters only if a GPU is made visible again.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'FALSE'

import jax
import numpyro

# These settings only take effect if applied before JAX initialises its backends.
# numpyro.set_platform("cpu") also sets jax's `jax_platform_name`, so a separate
# jax.config.update('jax_platform_name', 'cpu') is not needed.
numpyro.set_platform("cpu")

# Expose one virtual host device per CPU core so MCMC chains can run in parallel.
num_devices = multiprocessing.cpu_count()
numpyro.set_host_device_count(num_devices)

# Use 64-bit floats throughout.
jax.config.update("jax_enable_x64", True)

In [2]:
import autoroot
from pathlib import Path
import numpy as np
import xarray as xr
import pandas as pd
import pint_xarray
import arviz as az

from st_evt.viz import plot_histogram, plot_density
from omegaconf import OmegaConf

import jax
import jax.random as jrandom
import jax.numpy as jnp
import pandas as pd

rng_key = jrandom.PRNGKey(123)

from numpyro.infer import Predictive
import arviz as az

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import xarray as xr
import regionmask

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, FuncFormatter
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%config InlineBackend.figure_format = 'retina'
plt.style.use(
    "https://raw.githubusercontent.com/ClimateMatchAcademy/course-content/main/cma.mplstyle"
)

from loguru import logger

# num_devices = 5
# numpyro.set_host_device_count(num_devices)


%matplotlib inline
%load_ext autoreload
%autoreload 2

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Paths

In [3]:
# Location of the experiment outputs. Set the ST_EVT_RESULTS_ROOT environment
# variable to point at another machine's copy; the default is the original path.
experiment_name = "nonstationary_iid_mcmc_redfeten"
results_root_path = os.environ.get(
    "ST_EVT_RESULTS_ROOT",
    f"/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/{experiment_name}/results",
)
results_file_name = f"{experiment_name}.zarr"
results_data_path = Path(results_root_path).joinpath(results_file_name)

# Root directory for per-station figures (the station id is appended later).
figures_path = Path(results_root_path).joinpath("figures/stations")

In [4]:
# Display the resolved path of the results zarr store as a quick sanity check.
results_data_path

PosixPath('/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/nonstationary_iid_mcmc_redfeten.zarr')

In [5]:
# Display the root directory where station figures will be written.
figures_path

PosixPath('/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations')

## Data

In [6]:
# Station block-maxima dataset for t2max (summer), located relative to the project root.
DATA_URL = autoroot.root.joinpath(
    "data", "ml_ready", "aemet", "t2max_stations_bm_summer.zarr"
)

# Names reused throughout the notebook: response, covariate, spatial dimension.
variable, covariate, spatial_dim_name = "t2max", "gmst", "station_id"

# Read everything into memory so the zarr store can be closed right away.
with xr.open_dataset(DATA_URL, engine="zarr") as f:
    ds_bm = f.load()

### Likelihood Statistics

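Several likelihood-based statistics help evaluate how well the model fits the data. Given $S$ posterior draws $\theta^{(s)}$ and $n$ observations $y_i$, the widely applicable information criterion (WAIC) quantities are:

$$
\begin{aligned}
\text{lppd} &= \sum_{i=1}^{n} \log\left(\frac{1}{S}\sum_{s=1}^{S} p\left(y_i \mid \theta^{(s)}\right)\right) \\
p_{\text{WAIC}} &= \sum_{i=1}^{n} \operatorname{Var}_{s}\left[\log p\left(y_i \mid \theta^{(s)}\right)\right] \\
\text{ELPD WAIC} &= \text{lppd} - p_{\text{WAIC}} = \sum_{i=1}^{n} \text{elpd}_{\text{WAIC},i} \\
\text{ELPD WAIC SE} &= \sqrt{n \operatorname{Var}_{i}\left[\text{elpd}_{\text{WAIC},i}\right]}
\end{aligned}
$$

Here $p_{\text{WAIC}}$ (reported by ArviZ as `p_waic`) is the effective number of parameters, not a p-value. A higher ELPD WAIC indicates better expected out-of-sample predictive accuracy, and the SE quantifies its uncertainty.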

In [7]:
# NOTE(review): re-declares `variable` with the same value set in the data-loading
# cell; redundant but harmless.
variable = "t2max"

## Station Selection

### a - Predetermined Station

In [8]:
# Predefined stations of interest; each entry maps a key to [station id, station name].
from st_evt import CANDIDATE_STATIONS
CANDIDATE_STATIONS

{'madrid': ['3129A', 'Madrid (Barajas)'],
 'valencia': ['8414A', 'Valencia (Aeropuerto)'],
 'zaragoza': ['9434', 'Zaragoza (Aeropuerto)'],
 'santiago': ['1475X', 'Santiago De Compostela'],
 'murcia': ['7178I', 'Murcia'],
 'cordoba': ['9434', 'Viallanueva de Cordoba (Sea)']}

In [9]:
# Station analysed in the rest of the notebook: Madrid (Barajas), id '3129A'.
# Change the key (e.g. "valencia") to analyse another entry of CANDIDATE_STATIONS.
candidate_station = CANDIDATE_STATIONS["madrid"][0]

In [10]:
# Per-station figure directory. Rebuild it from results_root_path instead of
# appending to the current figures_path, so re-running this cell does not nest
# the station id a second time (…/stations/3129A/3129A).
figures_path = Path(results_root_path).joinpath("figures", "stations", str(candidate_station))
figures_path.mkdir(parents=True, exist_ok=True)

### MCMC Results

In [11]:
# Load the multi-station inference results and slice out the selected station
# from both the inference data and the block-maxima observations. The dimension
# name comes from `spatial_dim_name` (set in the data-loading cell) rather than
# being repeated as a literal.
az_ds = az.from_zarr(str(results_data_path))
station_selector = {spatial_dim_name: candidate_station}
az_ds_station = az_ds.sel(**station_selector)
ds_station = ds_bm.sel(station_selector)

### EDA Stuff

In [12]:
from st_evt._src.modules.models.aemet import utils_station

In [13]:
# Exploratory plots of the station's block-maxima series, saved under figures_path.
eda_series = ds_station[variable].squeeze()
utils_station.plot_eda(
    da=eda_series,
    variable_label="2m Max Temperature [°C]",
    figures_path=figures_path,
    figure_dpi=300,
)

[32m2025-01-10 06:33:06.278[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m114[0m - [1mPlotting BM Data Time Series...[0m
[32m2025-01-10 06:33:06.535[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m130[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/eda/ts_bm_data.pdf[0m
[32m2025-01-10 06:33:06.536[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m132[0m - [1mPlotting BM Data Histogram...[0m
[32m2025-01-10 06:33:06.807[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m146[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/eda/h

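:::{figure}
:label: fig-station-eda
:align: left
:width: 100%

(fig-station-eda-ts)=
![Time series of the block-maxima observations.](eda/ts_bm_data.png)

![Histogram of the block-maxima observations.](eda/hist_bm_data.png)

![Density of the block-maxima observations.](eda/density_bm_data.png)

Exploratory plots of the station's block-maxima data: time series, histogram, and density.
:::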

## Posterior Calculations

## Model Inspection

### Trace Plot

In [14]:
# Posterior parameters to inspect: concentration and scale, plus the slope and
# intercept of the linear location term.
variables = ["concentration", "scale", "location_slope", "location_intercept"]

utils_station.plot_model_params_critique(
    ds=az_ds_station.posterior,
    variables=variables,
    figures_path=figures_path,
)

[32m2025-01-10 06:33:10.029[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m39[0m - [1mPlotting Parameter Traces...[0m
[32m2025-01-10 06:33:11.455[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m53[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/params/trace.pdf[0m
[32m2025-01-10 06:33:11.456[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m55[0m - [1mPlotting Parameter Jonts...[0m
[32m2025-01-10 06:33:12.840[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m76[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstat

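:::{figure}
:label: fig-station-params
:align: left
:width: 100%

(fig-station-params-trace)=
![Trace plots of the posterior parameters.](params/trace.png)

![Joint posterior distributions of the parameters.](params/joint.png)


Posterior diagnostics for the model parameters: per-chain traces and joint distributions.
:::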

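:::{figure}
:label: fig-station-regression-pp
:align: left
:width: 100%

(fig-station-regression-pp-img)=
![Posterior-predictive regression on the covariate.](posterior_predictive/regression.png)


Posterior-predictive regression of the station's block maxima on the covariate.
:::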

## Model Critique

In [15]:
# Pooled subsample of posterior-predictive draws.
# NOTE(review): `idata` is not used in the cells visible here, and az.extract may
# subsample randomly — pass a seed if it is used downstream.
idata = az.extract(az_ds_station, group="posterior_predictive", num_samples=10_000)

# Observations and posterior-predictive values for the selected station.
y_true = az_ds_station.observed_data[variable]
y_pred = az_ds_station.posterior_predictive[variable].rename("y_pred")

In [16]:
# Residual diagnostics (signed and absolute errors) of the posterior predictive
# against the observations. Both plots share the same inputs and settings.
residual_plot_kwargs = dict(
    y_pred=y_pred,
    y_true=y_true,
    figures_path=figures_path,
    figures_dpi=300,
    units="[°C]",
)
utils_station.plot_residual_error_metric(**residual_plot_kwargs)
utils_station.plot_residual_abs_error_metric(**residual_plot_kwargs)

[32m2025-01-10 06:33:28.645[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_error_metric[0m:[36m235[0m - [1mCalculating residual error...[0m
[32m2025-01-10 06:33:29.291[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_error_metric[0m:[36m252[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/metrics/density_residuals.pdf[0m
[32m2025-01-10 06:33:29.292[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_abs_error_metric[0m:[36m266[0m - [1mCalculating residual error...[0m
[32m2025-01-10 06:33:29.505[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_residual_abs_error_metric[0m:[36m283[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/wal

:::{figure}
:label: fig-station-residuals
:align: left
:width: 100%

(fig-station-residuals-density)=
![Density of the residuals.](metrics/density_residuals.png)

![Density of the absolute residuals.](metrics/density_residuals_abs.png)


Residual diagnostics of the posterior predictive against the observations: signed and absolute errors.
:::

### QQ-Plot

In [17]:
# Point prediction per observation: the posterior-predictive *median* over all
# chains and draws. The original computed the mean despite the variable's name.
# The median is also less sensitive to the heavy upper tail a GEV predictive can have.
y_pred_median = y_pred.median(dim=["draw", "chain"])

utils_station.plot_qq(
    y_true=y_true,
    y_pred=y_pred_median,
    figures_path=figures_path,
    figures_dpi=300,
)

[32m2025-01-10 06:33:29.620[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m301[0m - [1mCalculating Metrics (RMSE, MAE, MAPD)...[0m
[32m2025-01-10 06:33:29.806[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m314[0m - [1mPlotting QQ-Plot...[0m
[32m2025-01-10 06:33:29.990[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_qq[0m:[36m340[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/metrics/qq_plot.pdf[0m


## Regression Plot

In [18]:
# Covariate values, observations, posterior predictive, and the posterior of the
# linear location term: slope * x + intercept.
x = az_ds_station.posterior.gmst
y = az_ds_station.observed_data[variable]
y_hat = az_ds_station.posterior_predictive[variable]
y_model = (
    az_ds_station.posterior["location_slope"] * x
    + az_ds_station.posterior["location_intercept"]
)

utils_station.plot_regression_posterior(
    x=x,
    y=y,
    y_hat=y_hat,
    y_model=y_model,
    covariate_label="Global Mean Surface Temperature Anomaly [°C]",
    y_label="2m Max Temperature [°C]",
    figures_path=figures_path,
    figure_dpi=300,
)

[32m2025-01-10 06:33:30.312[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_regression_posterior[0m:[36m983[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/regression/regression.pdf[0m


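:::{figure}
:label: fig-station-regression
:align: left
:width: 100%

(fig-station-regression-img)=
![Posterior regression of the block maxima on GMST.](regression/regression.png)



Observations, posterior predictive, and posterior location trend plotted against the global mean surface temperature anomaly.
:::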

## Return Levels

### Empirical Return Levels

In [19]:
def calculate_empirical_return_level_gevd_ds(
    da: xr.DataArray,
    covariate: str,
    num_samples: int = 1_000,
    seed: int = 123,
):
    """Attach empirical return periods to a series and re-index it by its values.

    NOTE(review): the cells below call
    ``utils_station.calculate_empirical_return_level_gevd_ds`` instead, so this
    local definition is not used by the notebook as shown.

    Parameters
    ----------
    da : xr.DataArray
        Named series with a ``covariate`` dimension. Its ``name`` becomes the
        new dimension name.
    covariate : str
        Dimension along which ``calculate_exceedence_probs`` is applied.
    num_samples, seed : int
        Currently unused. Kept for signature compatibility with a bootstrap
        variant that is not implemented.

    Returns
    -------
    xr.DataArray
        A new array with a ``return_level`` coordinate equal to
        ``1 / exceedance probability`` and with the ``covariate`` dimension
        swapped for ``da.name``. The input array is not modified.
    """
    from st_evt.extremes import calculate_exceedence_probs

    variable = da.name

    exceedance_probs = xr.apply_ufunc(
        calculate_exceedence_probs,
        da,
        input_core_dims=[[covariate]],
        output_core_dims=[[covariate]],
        vectorize=True,
    )
    # assign_coords returns a new object. The previous `da["return_level"] = ...`
    # assigned the coordinate in place on the caller's array.
    da = da.assign_coords(return_level=1 / exceedance_probs)
    return da.swap_dims({covariate: variable})

In [20]:
# select clean data
y_clean = az_ds_station.observed_data.dropna(dim=covariate)[variable]

# calculate return period
y_clean = utils_station.calculate_empirical_return_level_gevd_ds(y_clean, covariate=covariate)

# calculate model return periods
az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(az_ds_station.posterior_predictive)

# Calculate Quantiles
rl_model_quantiles = az_ds_station.posterior_predictive["return_level"].quantile(q=[0.025, 0.5, 0.975], dim=["chain", "draw"])

[32m2025-01-10 06:33:34.096[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m354[0m - [1mCalculating Return Level...[0m
[32m2025-01-10 06:33:34.098[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m362[0m - [1mSwapping Dims...[0m


### Viz - Return Level 

In [21]:
# Model return-level curves (median and 95% band) against the empirical return
# levels of the observations.
utils_station.plot_return_periods_dyn_ds(
    y=y_clean,
    rl_model_quantiles=rl_model_quantiles,
    covariate=covariate,
    y_label="2m Max Temperature, $R_a$ [°C]",
    figures_path=figures_path,
)

[32m2025-01-10 06:33:37.350[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m421[0m - [1mGetting Appropriate Periods...[0m
[32m2025-01-10 06:33:37.351[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m425[0m - [1mIntialize Returns...[0m
[32m2025-01-10 06:33:37.352[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m429[0m - [1mCreating Data structures...[0m
[32m2025-01-10 06:33:37.353[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m445[0m - [1mPlotting...[0m
[32m2025-01-10 06:33:38.250[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_dyn_ds[0m:[36m461[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experimen

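:::{figure}
:label: fig-station-return-levels
:align: left
:width: 100%

(fig-station-return-levels-img)=
![Posterior vs. empirical return levels.](returns/returns_prob_posterior_vs_empirical.png)



Model return levels (posterior median and 95% interval) compared with the empirical return levels of the observations.
:::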

### Viz - 100-Year Return Period

In [22]:
# Return levels are normally already added by the "Empirical Return Levels" cell.
# Compute them only if they are missing, instead of recomputing and re-assigning
# the posterior-predictive group every time.
if "return_level_100" not in az_ds_station.posterior_predictive:
    az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(
        az_ds_station.posterior_predictive
    )

# NOTE: despite the name, this is the full posterior of the 100-year return level
# (all chains/draws), not a set of quantiles. It replaces the quantiles computed earlier.
rl_model_quantiles = az_ds_station.posterior_predictive["return_level_100"]

In [23]:
# Posterior distribution of the 100-year return level, R_100.
utils_station.plot_return_periods_100_dyn_ds(
    rl_model=rl_model_quantiles,
    covariate=covariate,
    x_label="2m Max Temperature, $R_{100}$ [°C]",
    figures_path=figures_path,
)

[32m2025-01-10 06:33:43.712[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m575[0m - [1mGetting Appropriate Periods...[0m
[32m2025-01-10 06:33:43.713[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m579[0m - [1mIntialize Returns...[0m
[32m2025-01-10 06:33:43.714[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m583[0m - [1mCreating Data structures...[0m
[32m2025-01-10 06:33:43.715[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m601[0m - [1mPlotting...[0m
[32m2025-01-10 06:33:43.970[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_dyn_ds[0m:[36m617[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_dat

:::{figure}
:label: fig-station-rl100
:align: left
:width: 100%

(fig-station-rl100-density)=
![Posterior density of the 100-year return level.](returns/returns_100years_density.png)



Posterior distribution of the 100-year return level, $R_{100}$.
:::

### Viz - 100-Year Return Period Difference

In [25]:
# Return levels are normally already present (see the "Empirical Return Levels"
# cell). Compute them only if missing, to avoid redundant recomputation.
if "return_level_100" not in az_ds_station.posterior_predictive:
    az_ds_station.posterior_predictive = utils_station.calculate_ds_return_periods(
        az_ds_station.posterior_predictive
    )

# Full posterior of the 100-year return level (all chains/draws).
rl_model = az_ds_station.posterior_predictive["return_level_100"]

In [26]:
# Differences in the 100-year return level, as computed inside utils_station:
# first in absolute units, then as a percentage. Shared settings are passed once.
rl100_diff_kwargs = dict(
    rl_model=rl_model,
    covariate=covariate,
    figures_path=figures_path,
    color="black",
)
utils_station.plot_return_periods_100_difference_dyn_ds(
    **rl100_diff_kwargs,
    x_label="2m Max Temperature, $R_{100}$ [°C]",
    units="[°C]",
)
utils_station.plot_return_periods_100_difference_prct_dyn_ds(
    **rl100_diff_kwargs,
    x_label="2m Max Temperature, $R_{100}$ [%]",
)

[32m2025-01-10 06:33:54.169[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m693[0m - [1mGetting Appropriate Periods...[0m
[32m2025-01-10 06:33:54.170[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m697[0m - [1mCalculating Difference...[0m
[32m2025-01-10 06:33:54.173[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m703[0m - [1mIntialize Returns...[0m
[32m2025-01-10 06:33:54.174[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[0m:[36m707[0m - [1mCreating Data structures...[0m
[32m2025-01-10 06:33:54.174[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_dyn_ds[