# Modeling Extremes - Numpyro Pt 2 - MAP

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "" # first gpu
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'FALSE'

import jax
jax.config.update('jax_platform_name', 'cpu')

import numpyro
import multiprocessing

num_devices = multiprocessing.cpu_count()
numpyro.set_platform("cpu")
numpyro.set_host_device_count(num_devices)
jax.config.update("jax_enable_x64", True)

In [2]:
import autoroot
from pathlib import Path
import numpy as np
import xarray as xr
import pandas as pd
import pint_xarray
import arviz as az

from st_evt.viz import plot_histogram, plot_density
from st_evt._src.modules.models.aemet import utils_station
from omegaconf import OmegaConf

import jax
import jax.random as jrandom
import jax.numpy as jnp
import pandas as pd

rng_key = jrandom.PRNGKey(123)

from numpyro.infer import Predictive
import arviz as az

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import xarray as xr
import regionmask

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, FuncFormatter
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%config InlineBackend.figure_format = 'retina'
plt.style.use(
    "https://raw.githubusercontent.com/ClimateMatchAcademy/course-content/main/cma.mplstyle"
)

from loguru import logger

# num_devices = 5
# numpyro.set_host_device_count(num_devices)


%matplotlib inline
%load_ext autoreload
%autoreload 2

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Paths

In [3]:
results_root_path = "/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results"
results_file_name = "nonstationary_iid_mcmc_redfeten.zarr"
results_data_path = Path(results_root_path).joinpath(results_file_name)

figures_path = Path(results_root_path).joinpath("figures/stations")

In [4]:
# results_data_path = "/home/juanjohn/pool_data/dynev4eo/temp/results/scratch_pipelines/results/az_nonstationary_gp_lap_redfeten.zarr"
# figures_path = Path("/home/juanjohn/pool_data/dynev4eo/temp/scratch_pipelines/figures/nonstationary_gp_lap_redfeten")

## Data

In [5]:
DATA_URL = autoroot.root.joinpath("data/ml_ready/aemet/t2max_stations_bm_summer.zarr")
variable = "t2max"
covariate = "gmst"
spatial_dim_name = "station_id"


# LOAD DATA
with xr.open_dataset(DATA_URL, engine="zarr") as f:
    ds_bm = f.load()
    # ds_bm = ds_bm.where(ds_bm.red_feten_mask == 1, drop=True)

### Likelihood Statistics

There are some useful statistics that we can use to evaluate how well our model does.

$$
\begin{aligned}
\text{ELPD WAIC}: && && \\
\text{ELPD WAIC SE}: && && \\
\text{P-Value WAIC}: && && \\
\end{aligned}
$$

In [6]:
variable = "t2max"

## Station Selection

### a - Predetermined Station

In [34]:
from st_evt import AEMET_GOOD_STATIONS, AEMET_BAD_STATIONS

In [35]:
AEMET_BAD_STATIONS["red_feten"]

['3407Y', '9677', 'C018J', 'C426R']

In [36]:
candidate_station = '3407Y' #'3129A' # 

In [37]:
figures_path = figures_path.joinpath(f"{candidate_station}")
figures_path.mkdir(parents=True, exist_ok=True)

### MCMC Results

In [38]:
az_ds = az.from_zarr(store=str(results_data_path))
az_ds_station_pred = az_ds.predictions.sel(station_id = candidate_station)
# num_samples = 5_000
# az_ds_station_pred = az.extract(az_ds, group="predictions", num_samples=num_samples).sel(station_id = candidate_station)
y_data = az_ds.posterior_predictive.sel(station_id = candidate_station)[f"{variable}_true"]
ds_station = ds_bm.sel(station_id = candidate_station)

### EDA Stuff

In [39]:
from st_evt.viz import plot_scatter_ts, plot_histogram, plot_density

In [40]:
utils_station.plot_eda(
    da=ds_station[variable].squeeze(),
    variable_label="2m Max Temperature [°C]",
    # figures_path="./", 
    figures_path=figures_path, 
    figure_dpi=300,
)

[32m2024-12-16 19:47:14.153[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m114[0m - [1mPlotting BM Data Time Series...[0m
[32m2024-12-16 19:47:14.465[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m130[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/1391/3407Y/eda/ts_bm_data.pdf[0m
[32m2024-12-16 19:47:14.472[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m132[0m - [1mPlotting BM Data Histogram...[0m
[32m2024-12-16 19:47:14.710[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_eda[0m:[36m146[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/

:::{figure}
:label: my-figure
:align: left
:width: 20px

(my-figure-fruit)=
![Here is some fruit 🍏](eda/ts_bm_data.png)

![My vacation pics! 🏝](eda/hist_bm_data.png)

![My vacation pics! 🏝](eda/density_bm_data.png)


Some pictures of fruit and the ocean!
:::

## Posterior Calculations

## Model Inspection

### Trace Plot

In [41]:
variables = [
    "concentration",
    "scale",
    "location_slope",
    "location_intercept",
    ]

utils_station.plot_model_params_critique(
    ds=az_ds_station_pred,
    variables=variables,
    # figures_path="./", 
    figures_path=figures_path, 
    
)

[32m2024-12-16 19:47:26.595[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m39[0m - [1mPlotting Parameter Traces...[0m
[32m2024-12-16 19:47:28.187[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m53[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/1391/3407Y/params/trace.pdf[0m
[32m2024-12-16 19:47:28.189[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m55[0m - [1mPlotting Parameter Jonts...[0m
[32m2024-12-16 19:47:30.764[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_model_params_critique[0m:[36m76[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2

:::{figure}
:label: my-figure
:align: left
:width: 20px

(my-figure-fruit)=
![Here is some fruit 🍏](params/trace.png)

![My vacation pics! 🏝](params/joint.png)


Some pictures of fruit and the ocean!
:::

## Regression Plot

In [42]:
# calculate model return periods
az_ds_station_pred = utils_station.calculate_ds_return_periods(az_ds_station_pred)

az_ds_quantiles = az_ds_station_pred.quantile(q=[0.025, 0.5, 0.975], dim=["chain", "draw"]).squeeze()
locations = az_ds_station_pred["location"].quantile(q=[0.025, 0.5, 0.975], dim=["chain", "draw"]).squeeze()
scales = az_ds_station_pred["scale"].quantile(q=[0.025, 0.5, 0.975], dim=["chain", "draw"]).squeeze()
return_level_100 = az_ds_station_pred["return_level_100"].quantile(q=[0.025, 0.5, 0.975], dim=["chain", "draw"]).squeeze()
observations = ds_station[variable].squeeze()

In [43]:
utils_station.plot_regression_prediction(
    ds_quantiles=az_ds_quantiles,
    observations=observations,
    # figures_path="./", 
    figures_path=figures_path, 
    covariate=covariate,
    figure_dpi=300,
    y_label="2m Max Temperature, $R_a$ [°C]",
    covariate_label="Global Mean Surface Temperature Anomaly [°C]",
    location_only=True
)

utils_station.plot_regression_prediction(
    ds_quantiles=az_ds_quantiles,
    observations=observations,
    # figures_path="./", 
    figures_path=figures_path, 
    covariate=covariate,
    figure_dpi=300,
    y_label="2m Max Temperature, $R_a$ [°C]",
    covariate_label="Global Mean Surface Temperature Anomaly [°C]",
    location_only=False
)

[32m2024-12-16 19:48:07.583[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_regression_prediction[0m:[36m1021[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/1391/3407Y/regression/regression_pred_location.png[0m
[32m2024-12-16 19:48:08.131[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_regression_prediction[0m:[36m1021[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results/figures/stations/3129A/1391/3407Y/regression/regression_pred.png[0m


## Return Levels

#### Empirical Return Levels

In [44]:
# def calculate_empirical_return_level_gevd_ds(
#     da: xr.DataArray,
#     covariate: str,
#     num_samples: int = 1_000,
#     seed: int = 123,
# ):
#     variable = da.name
#     # # resample array
#     # rng = np.random.RandomState(seed)
#     # y_samples = list(
#     #     rng.choice(
#     #         y_clean.values, size=(y_clean.size), replace=True,
#     #     ) for seed in np.arange(0, num_samples)
#     # )
#     # y_samples = np.stack(y_samples, axis=0)

#     # # expand dims
#     # da = da.expand_dims(dim={"samples": np.arange(0, num_samples)})
#     # # assign coordinates
    
#     # da[variable] = (("samples", covariate), y_samples)
    
#     from st_evt.extremes import calculate_exceedence_probs
#     # add as coordinate
#     da["return_level"] = 1/xr.apply_ufunc(
#         calculate_exceedence_probs,
#         da,
#         input_core_dims=[[covariate]],
#         output_core_dims=[[covariate]],
#         vectorize=True
#     )
#     # swap dimensions
#     # 
#     # da_rl = da_rl.rename("return_level")
#     # print(da_rl)
#     # da_rl = da_rl.rename({covariate: "t2max"})
#     # da_rl[variable] = da[variable]
#     da = da.swap_dims({covariate: variable})
#     return da

In [45]:
# select clean data
y_clean = ds_station[variable].squeeze()

# calculate return period
y_clean = utils_station.calculate_empirical_return_level_gevd_ds(y_clean, covariate=covariate)

# calculate model return periods
az_ds_station_pred = utils_station.calculate_ds_return_periods(az_ds_station_pred)

# Calculate Quantiles
rl_model_quantiles = az_ds_station_pred["return_level"].quantile(q=[0.025, 0.5, 0.975], dim=["chain", "draw"])

[32m2024-12-16 19:48:18.431[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m331[0m - [1mCalculating Return Level...[0m
[32m2024-12-16 19:48:18.436[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mcalculate_empirical_return_level_gevd_ds[0m:[36m339[0m - [1mSwapping Dims...[0m


### Viz - Return Level 

In [46]:
# fig, ax = model_eval_station.plot_return_periods_ds(
#     rl_model_quantiles=rl_model_quantiles,
#     y=y_clean,
#     covariate=covariate,
#     y_label="2m Max Temperature, $R_a$ [°C]"
# )
# plt.show()

utils_station.plot_return_periods_gmst_ds(
    rl_model_quantiles=rl_model_quantiles,
    y=y_clean,
    covariate=covariate,
    # figures_path="./", 
    figures_path=figures_path, 
    y_label="2m Max Temperature, $R_a$ [°C]"
)

[32m2024-12-16 19:48:22.535[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_gmst_ds[0m:[36m453[0m - [1mGetting Appropriate Periods...[0m
[32m2024-12-16 19:48:22.537[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_gmst_ds[0m:[36m455[0m - [1mIntialize Returns...[0m
[32m2024-12-16 19:48:22.537[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_gmst_ds[0m:[36m459[0m - [1mCreating Data structures...[0m
[32m2024-12-16 19:48:22.539[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_gmst_ds[0m:[36m482[0m - [1mPlotting...[0m
[32m2024-12-16 19:48:24.158[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_gmst_ds[0m:[36m498[0m - [34m[1mSaved Figure:
/home/juanjohn/pool_data/dynev4eo/expe

:::{figure}
:label: my-figure
:align: left
:width: 20px

(my-figure-fruit)=
![Here is some fruit 🍏](returns/returns_gmst_prob_posterior_vs_empirical.png)

Some pictures of fruit and the ocean!
:::

### Viz - 100-Year Return Period

In [47]:
# calculate model return periods
az_ds_station_pred = utils_station.calculate_ds_return_periods(az_ds_station_pred)

# Calculate Quantiles
rl_model_quantiles = az_ds_station_pred["return_level_100"]

In [48]:
# fig, ax = plot_return_periods_100_ds(
#     rl_model_quantiles=rl_model_quantiles,
#     covariate=covariate,
#     x_label="2m Max Temperature, $R_a$ [°C]"
# )
# plt.show()

utils_station.plot_return_periods_100_gmst_ds(
    rl_model=rl_model_quantiles,
    covariate=covariate,
    # figures_path="./", 
    figures_path=figures_path, 
    x_label=r"2m Max Temperature, $R_{100}$ [°C]"
)

[32m2024-12-16 19:48:25.875[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_gmst_ds[0m:[36m609[0m - [1mGetting Appropriate Periods...[0m
[32m2024-12-16 19:48:25.876[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_gmst_ds[0m:[36m611[0m - [1mIntialize Returns...[0m
[32m2024-12-16 19:48:25.877[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_gmst_ds[0m:[36m615[0m - [1mCreating Data structures...[0m
[32m2024-12-16 19:48:25.880[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_gmst_ds[0m:[36m638[0m - [1mPlotting...[0m
[32m2024-12-16 19:48:26.360[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_gmst_ds[0m:[36m654[0m - [34m[1mSaved Figure:
/home/juanjohn/poo

:::{figure}
:label: my-figure
:align: left
:width: 20px

(my-figure-fruit)=
![Here is some fruit 🍏](returns/returns_100years_gmst_density.png)



Some pictures of fruit and the ocean!
:::

### Viz - 100-Year Return Period Difference

In [49]:
# calculate model return periods
az_ds_station_pred = utils_station.calculate_ds_return_periods(az_ds_station_pred)

# Calculate Quantiles
rl_model = az_ds_station_pred["return_level_100"]
rl_model

In [None]:
# fig, ax = plot_return_periods_100_difference_ds(
#     rl_model=rl_model,
#     covariate=covariate,
#     x_label="2m Max Temperature, $R_a$ [°C]",
#     units="[°C]",
#     color="black"
# )
# plt.show()

utils_station.plot_return_periods_100_difference_gmst_ds(
    rl_model=rl_model,
    covariate=covariate,
    # figures_path="./", 
    figures_path=figures_path, 
    x_label="2m Max Temperature, $\Delta R_{100}$ [°C]",
    units="[°C]",
)

[32m2024-12-16 19:48:27.934[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_gmst_ds[0m:[36m784[0m - [1mGetting Appropriate Periods...[0m
[32m2024-12-16 19:48:27.935[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_gmst_ds[0m:[36m786[0m - [1mCalculating Difference...[0m
[32m2024-12-16 19:48:27.942[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_gmst_ds[0m:[36m794[0m - [1mIntialize Returns...[0m
[32m2024-12-16 19:48:27.942[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_gmst_ds[0m:[36m797[0m - [1mCreating Data structures...[0m
[32m2024-12-16 19:48:27.943[0m | [1mINFO    [0m | [36mst_evt._src.modules.models.aemet.utils_station[0m:[36mplot_return_periods_100_difference_gmst

:::{figure}
:label: my-figure
:align: left
:width: 20px

(my-figure-fruit)=
![Here is some fruit 🍏](returns/returns_100years_difference_gmst_density.png)

Some pictures of fruit and the ocean!
:::

In [None]:
# fig, ax = plot_return_periods_100_difference_ds(
#     rl_model=rl_model,
#     covariate=covariate,
#     x_label="2m Max Temperature, $R_a$ [°C]",
#     units="[°C]",
#     color="black"
# )
# plt.show()

utils_station.plot_return_periods_100_difference_prct_gmst_ds(
    rl_model=rl_model,
    covariate=covariate,
    # figures_path="./", 
    figures_path=figures_path, 
    x_label="2m Max Temperature, $\Delta R_{100}$ [%]",
)

:::{figure}
:label: my-figure
:align: left
:width: 20px

(my-figure-fruit)=
![Here is some fruit 🍏](returns/returns_100years_difference_prcnt_gmst_density.png)


Some pictures of fruit and the ocean!
:::