# Modeling Extremes - Numpyro Pt 2 - MAP

In [1]:
import os
# An empty CUDA_VISIBLE_DEVICES exposes no GPU to this process, so the whole
# notebook runs on CPU. These variables must be set before JAX is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# JAX documents the lowercase value "false" for switching off GPU memory
# preallocation; the uppercase spelling is not a documented value.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax
jax.config.update('jax_platform_name', 'cpu')

import numpyro
import multiprocessing

# Request one XLA host (CPU) device per logical core so that several chains
# can be mapped to separate devices later on.
num_devices = multiprocessing.cpu_count()
numpyro.set_platform("cpu")
numpyro.set_host_device_count(num_devices)

In [2]:
# --- Project root discovery, data handling and diagnostics ---
# NOTE: pandas, arviz, xarray, jax and numpyro are imported more than once in
# this cell; the repeated imports are harmless no-ops.
import autoroot
from pathlib import Path
import numpy as np
import xarray as xr
import pandas as pd
import pint_xarray
import arviz as az

from st_evt.viz import plot_histogram, plot_density
from omegaconf import OmegaConf

import jax
import jax.random as jrandom
import jax.numpy as jnp
import pandas as pd

# Root PRNG key for the notebook; split it before every random operation.
rng_key = jrandom.PRNGKey(123)

from numpyro.infer import Predictive
import arviz as az

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import xarray as xr
import regionmask

# --- Plotting defaults ---
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter, FuncFormatter
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%config InlineBackend.figure_format = 'retina'
# The style sheet is fetched from a remote URL, so this call needs network access.
plt.style.use(
    "https://raw.githubusercontent.com/ClimateMatchAcademy/course-content/main/cma.mplstyle"
)

from loguru import logger

# num_devices = 5
# numpyro.set_host_device_count(num_devices)


# Inline figures, and automatic reloading of edited project modules.
%matplotlib inline
%load_ext autoreload
%autoreload 2

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Data

In [3]:
# Names of the target variable, the covariate and the station dimension.
variable, covariate, spatial_dim_name = "t2max", "gmst", "station_id"
DATA_URL = autoroot.root.joinpath("data/ml_ready/aemet/t2max_stations_bm_summer.zarr")

# LOAD DATA
# Read the store fully into memory, then keep only the entries whose
# `red_feten_mask` equals 1 (everything else is dropped).
with xr.open_dataset(DATA_URL, engine="zarr") as f:
    ds_bm = f.load()
    ds_bm = ds_bm.where(ds_bm.red_feten_mask == 1, drop=True)

# Plain NumPy views of the observations and the covariate, with any
# length-1 axes removed.
y = np.squeeze(ds_bm[variable].values)
t = np.squeeze(ds_bm[covariate].values)

# Covariate grid: 100 evenly spaced values between 0.0 and 2.5.
t_pred = jnp.linspace(0.0, 2.5, 100)

# Observations must be 2-D and the covariate 1-D.
assert y.ndim == 2
assert t.ndim == 1

## Model

In [4]:
from st_evt._src.modules.models.aemet.gevd_nonstationary_iid.model import init_t2m_model

In [5]:
from st_evt._src.models.gevd import NonStationaryUnPooledGEVD, CoupledExponentialUnPooledGEVD

# NOTE(review): `t0` and the four priors defined below are not passed to
# `init_t2m_model` in this cell. The debug log emitted by `init_t2m_model`
# reports the same location/scale initialisation, so it presumably builds
# equivalent priors internally - confirm before relying on edits made here.

# Smallest observed covariate value.
t0 = float(t.min())

# Location intercept: Normal centred on the sample mean, spread = sample std.
loc_init, scale_init = np.mean(y), np.std(y)
logger.debug(f"Initial Location: Normal({loc_init:.2f}, {scale_init:.2f})")
intercept_prior = dist.Normal(float(loc_init), float(scale_init))

# Location slope: standard Normal.
slope_prior = dist.Normal(0.0, 1.0)

# Scale must be positive: LogNormal centred on log(sample std).
# `loc_init` is re-bound here and holds log(std(y)) from this point on.
loc_init = np.log(scale_init)
logger.debug(f"Initial Scale: LogNormal({loc_init:.2f}, 0.5)")
scale_prior = dist.LogNormal(loc_init, 0.5)

# Shape (concentration): truncated Normal restricted to negative values.
concentration_prior = dist.TruncatedNormal(-0.3, 0.1, low=-1.0, high=-1e-5)

# Build the model from the observations and the dimension/variable names.
model = init_t2m_model(
    y_values=y,
    t_values=t,
    variable_name=variable,
    time_dim_name=covariate,
    spatial_dim_name=spatial_dim_name,
)

[32m2024-12-16 13:19:09.192[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [34m[1mInitial Location: Normal(36.18, 4.07)[0m
[32m2024-12-16 13:19:09.193[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m16[0m - [34m[1mInitial Scale: LogNormal(1.40, 0.5)[0m
[32m2024-12-16 13:19:09.194[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.gevd_nonstationary_iid.model[0m:[36minit_t2m_model[0m:[36m80[0m - [34m[1mInitial Location: Normal(36.18, 4.07)[0m
[32m2024-12-16 13:19:09.194[0m | [34m[1mDEBUG   [0m | [36mst_evt._src.modules.models.aemet.gevd_nonstationary_iid.model[0m:[36minit_t2m_model[0m:[36m88[0m - [34m[1mInitial Scale: LogNormal(1.40, 0.5)[0m


## Inference

In [6]:
from st_evt._src.models.inference import SVILearner

# Optimisation budget: 200k steps, the first 10% of which are warm-up steps.
num_steps = 200_000
num_warmup_steps = int(0.1 * num_steps)

# Learning-rate values handed to the learner (initial, peak and final).
init_lr = 1e-10
peak_lr = 1e-3
end_lr = 1e-4

# Approximation method requested from the learner.
method = "laplace"

svi_learner = SVILearner(
    model,
    init_lr=init_lr,
    peak_lr=peak_lr,
    end_lr=end_lr,
    num_steps=num_steps,
    num_warmup_steps=num_warmup_steps,
    method=method,
)

# Fit on the observed covariate/target pair.
svi_posterior = svi_learner(t=t, y=y)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [01:43<00:00, 1935.50it/s, init loss: 26929.7051, avg. loss [190001-200000]: 18374.4551]


In [7]:
# Take the `median_params` attribute of the SVI result; it is handed to
# MCMCLearner as `init_params` when the sampler is configured.
init_params = svi_posterior.median_params

In [8]:
from st_evt._src.models.inference import MCMCLearner

# Sampler budget: 8 chains, each configured with 1,000 warm-up iterations
# and 1,000 samples.
num_chains = 8
num_warmup = 1_000
num_samples = 1_000

# The sampler is initialised from the parameters obtained with SVI.
mcmc_learner = MCMCLearner(
    model=model,
    init_params=init_params,
    num_chains=num_chains,
    num_warmup=num_warmup,
    num_samples=num_samples,
)

In [9]:
# Run the configured MCMC learner on the covariate `t` and observations `y`.
mcmc_posterior = mcmc_learner(t=t, y=y)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

## Posterior

In [10]:
logger.info("Grabbing MCMC Samples...")
# Pull the samples from the underlying `mcmc` object (presumably a numpyro
# MCMC instance - confirm).
# NOTE(review): `posterior_samples` is not referenced again in the visible cells.
posterior_samples = mcmc_posterior.mcmc.get_samples()

[32m2024-12-16 13:27:27.013[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mGrabbing MCMC Samples...[0m


In [11]:
logger.info("Creating MCMC Data Structure...")
# Build the ArviZ container for the MCMC run via the project helper.
az_ds = mcmc_posterior.init_arviz_summary()

# correct coordinates
# Attach the covariate and station coordinate values stored in the loaded
# dataset `ds_bm`; the two assignments are applied one after the other.
logger.info("Correcting Coordinates...")
az_ds = az_ds.assign_coords({covariate: ds_bm[covariate]})
az_ds = az_ds.assign_coords({spatial_dim_name: ds_bm[spatial_dim_name]})
az_ds

[32m2024-12-16 13:30:13.056[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mCreating MCMC Data Structure...[0m
[32m2024-12-16 13:30:13.899[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1mCorrecting Coordinates...[0m


### Posterior Predictive

In [12]:
logger.info("Calculating Posterior Predictive Samples...")
# Posterior predictive samples
# Split the notebook key: `rng_key` is carried forward, `rng_subkey` is
# consumed by the sampling call below.
rng_key, rng_subkey = jrandom.split(rng_key)


# Sample at the observed covariate values `t`.
posterior_predictive_samples = mcmc_posterior.posterior_predictive_samples(rng_subkey, t=t)


[32m2024-12-16 13:30:14.325[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mCalculating Posterior Predictive Samples...[0m


In [13]:
logger.info("Creating Posterior Predictive Datastructure...")
# Wrap the posterior predictive samples in an ArviZ container, using the
# model's dimension mapping and the number of MCMC chains.
az_ds_postpred = az.from_numpyro(
    posterior_predictive=posterior_predictive_samples,
    # log_likelihood=nll_postpred_samples,
    dims=model.dimensions,
    num_chains=num_chains,
)
# correct coordinates
# Same coordinate replacement as for the MCMC summary container.
logger.info("Correcting Coordinates...")
az_ds_postpred = az_ds_postpred.assign_coords({covariate: ds_bm[covariate]})
az_ds_postpred = az_ds_postpred.assign_coords({spatial_dim_name: ds_bm[spatial_dim_name]})

az_ds_postpred

[32m2024-12-16 13:30:14.940[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mCreating Posterior Predictive Datastructure...[0m
[32m2024-12-16 13:30:16.836[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mCorrecting Coordinates...[0m


#### Log-Likelihood

In [14]:
from numpyro.infer import log_likelihood

logger.info("Calculating Log-Likelihood for Posterior Predictive...")
# Evaluate the model's log-likelihood of the observations `y` for every sample.
# NOTE(review): the result is named `nll_*`, but numpyro's `log_likelihood`
# returns log-likelihood values (not negated) - confirm the sign convention
# expected downstream.
# NOTE(review): the posterior *predictive* samples are passed as
# `posterior_samples`; presumably they also contain the latent sites the model
# needs - verify against `posterior_predictive_samples`.
nll_postpred_samples = log_likelihood(
    model=model,
    posterior_samples=posterior_predictive_samples,
    parallel=False,
    batch_ndim=1,
    t=t,
    y=y,
)

[32m2024-12-16 13:30:17.168[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mCalculating Log-Likelihood for Posterior Predictive...[0m


In [15]:
# Route the log-likelihood values through `az.from_numpyro` (as if they were
# posterior predictive samples) to obtain a labelled DataArray, then rename
# the target variable's entry to "nll".
az_ds_postpred_nll = az.from_numpyro(
    posterior_predictive=nll_postpred_samples,
    dims=model.dimensions,
    num_chains=num_chains,
).posterior_predictive[variable].rename("nll")

# correct coordinates
# Same coordinate replacement as for the other containers.
az_ds_postpred_nll = az_ds_postpred_nll.assign_coords({covariate: ds_bm[covariate]})
az_ds_postpred_nll = az_ds_postpred_nll.assign_coords({spatial_dim_name: ds_bm[spatial_dim_name]})

In [16]:
# Sanity check: shape of the per-sample log-likelihood array of the target variable.
nll_postpred_samples[variable].shape

(8000, 59, 154)

In [17]:
# The log message mentions coordinates, but the statements below add two data
# variables to the posterior predictive group.
logger.info("Adding extra coordinates")
# Per-sample log-likelihood values, stored under the name "nll".
az_ds_postpred.posterior_predictive["nll"] = az_ds_postpred_nll
# Observed values, stored as "<variable>_true" with dims (covariate, station).
az_ds_postpred.posterior_predictive[f"{variable}_true"] = (
    (covariate, spatial_dim_name),
    np.asarray(y),
)

[32m2024-12-16 13:30:17.552[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mAdding extra coordinates[0m


## Predictions

In [18]:
# Covariate grid for the predictions: 100 evenly spaced values from 0.0 to 2.5.
# This re-binds the `t_pred` created in the data-loading cell (there a JAX
# array, here a NumPy array over the same grid).
t_pred = np.linspace(0.0, 2.5, 100)

In [19]:
# PREDICTIVE POSTERIOR
# Sites to return: the GEVD parameters, the location components, and the
# target variable itself.
return_sites = [
    "location", "location_slope", "location_intercept",
    "scale",
    "concentration",
    variable,
]
# Split off a fresh subkey: the previous `rng_subkey` was already consumed by
# the posterior predictive sampling, and JAX PRNG keys must never be reused.
rng_key, rng_subkey = jrandom.split(rng_key)
# Sample at the prediction grid `t_pred` instead of the observed covariate.
prediction_samples = mcmc_posterior.posterior_predictive_samples(
    rng_subkey,
    return_sites=return_sites,
    t=t_pred,
)

In [20]:
# The log message says "Posterior", but this cell builds the `predictions` group.
logger.info("Constructing Posterior...")
# `np.expand_dims(v, 0)` prepends a length-1 axis to every sample array.
# NOTE(review): presumably this serves as a single "chain" axis for
# `az.from_dict`, so all draws end up in one chain, unlike the groups built
# with `num_chains` - confirm this is intended.
az_ds_preds = az.from_dict(
    predictions={k: np.expand_dims(v, 0) for k, v in prediction_samples.items()},
    pred_dims=model.dimensions,
)

# correct coordinates
# The covariate coordinate is the prediction grid here, not the observed covariate.
logger.info("Correcting Coordinates...")
az_ds_preds = az_ds_preds.assign_coords({covariate: t_pred})
az_ds_preds = az_ds_preds.assign_coords({spatial_dim_name: ds_bm[spatial_dim_name]})


az_ds_preds

[32m2024-12-16 13:30:17.876[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mConstructing Posterior...[0m
[32m2024-12-16 13:30:20.355[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1mCorrecting Coordinates...[0m


In [21]:
# Merge the posterior predictive and predictions groups into the main
# container; `az_ds` is modified in place.
# NOTE(review): re-running this cell without rebuilding `az_ds` presumably
# fails because the groups already exist - confirm.
az_ds.add_groups(az_ds_postpred)
az_ds.add_groups(az_ds_preds)
az_ds

#### MCMC Statistics

In [22]:
# Compute the WAIC statistics for the fitted model from `az_ds`.
stats = az.waic(az_ds)
stats

See http://arxiv.org/abs/1507.04544 for details


Computed from 8000 posterior samples and 9086 observations log-likelihood matrix.

          Estimate       SE
elpd_waic -17794.94    69.99
p_waic      399.34        -


In [23]:

# Store the WAIC summary values as attributes of the log-likelihood group so
# that they are written out together with the results.
az_ds.log_likelihood.attrs["elpd_waic"] = stats.elpd_waic
az_ds.log_likelihood.attrs["se"] = stats.se
az_ds.log_likelihood.attrs["p_waic"] = stats.p_waic

### Save Data

In [24]:
# Directory that receives the results. The default is the original absolute
# location; set the MCMC_RESULTS_DIR environment variable to redirect the
# output on machines where that directory does not exist.
mcmc_results_path = Path(
    os.environ.get(
        "MCMC_RESULTS_DIR",
        "/home/juanjohn/pool_data/dynev4eo/experiments/walkthrough/aemet/t2max/nonstationary_iid_mcmc_redfeten/results",
    )
)
# Create the directory (and any missing parents) if needed.
mcmc_results_path.mkdir(parents=True, exist_ok=True)
# Write every group of the container into a single zarr store.
mcmc_results_path = mcmc_results_path.joinpath("nonstationary_iid_mcmc_redfeten.zarr")
az_ds.to_zarr(store=str(mcmc_results_path))

<zarr.hierarchy.Group '/'>