# AI 4 StoryLines - Bayesian LR

In this example, we showcase how to do sensitivity analysis using a simple Bayesian linear regression model.

In [1]:
import autoroot
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import jax
import jax.random as jrandom
import jax.numpy as jnp
from jaxtyping import Array, Float
from pathlib import Path
from dataclasses import dataclass
import xarray as xr
import pandas as pd
import einops
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from bayesevt._src.utils.io import get_list_filenames
# run everything on the CPU backend
numpyro.set_platform("cpu")
# ask numpyro for 64 host devices (the MCMC below uses chain_method="parallel")
numpyro.set_host_device_count(64)
# plotting defaults: reset seaborn, then use the "talk" context with smaller fonts
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# notebook magics: inline figures and automatic reloading of edited modules
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# root PRNG key with a fixed seed; later cells split sub-keys off it
rng_key = jrandom.PRNGKey(123)

We can see that some of the variance remains unexplained.
Potentially, we can account for it with the regression model.

## Load Data

We have a clean, analysis-ready dataset available from the previous notebook.
We load it here and use it throughout the rest of this notebook.

First, we will make sure that the models in the covariates and in the quantity of interest (QoI) are the same.

In [3]:
# directory holding the analysis-ready inputs (covariates.csv and qoi.nc are read from here)
save_dir = "/pool/usuarios/juanjohn/data/ai4storylines/analysis/"

In [48]:
# load covariates (first CSV column is used as the index)
df = pd.read_csv(Path(save_dir).joinpath("covariates.csv"), index_col=0)
# load qoi
ds = xr.open_dataset(Path(save_dir).joinpath("qoi.nc"))
# quick check: both sources must contain exactly the same models.
# `ndarray.sort()` sorts in place and returns None, so comparing its return
# values is always `None == None` (and it may reorder the underlying `model`
# data in place). Compare sorted *copies* instead.
models_in_covariates = sorted(df.model.values.tolist())
models_in_qoi = sorted(ds.model.values.tolist())
assert models_in_covariates == models_in_qoi, (
    "covariates.csv and qoi.nc do not contain the same set of models"
)

## Baseline Model - Linear Regression

#### Input Data

In [49]:
covariate_names = ["sst", "sm_sur", "t2m", "z500_zonal"]
qoi_names = ["tasmax"]

# covariate matrix: rows sorted by model, one column per covariate
df_by_model = df.sort_values(by=["model"])
x = df_by_model[covariate_names].values

# QoI array: same model ordering, with (lat, lon) flattened into one "spatial" axis
ds_by_model = ds.sortby("model")
u = ds_by_model.stack(spatial=["lat", "lon"]).tasmax.values

In [13]:
# sanity check: shapes of the covariate array and the QoI array
x.shape, u.shape

((28, 4), (28, 399))

#### Model

In [14]:
# problem sizes, read off the input arrays
num_models, num_spatial = u.shape[0], u.shape[1]
num_covariates = x.shape[1]

In [24]:
from models import BayesianLinearRegression, ModelPredictorMCMC

# regression model sized by the number of spatial points; its `.model`
# attribute is what the numpyro utilities below are given
model = BayesianLinearRegression(num_spatial=num_spatial)

In [25]:
# dedicated sub-key for the prior draws
rng_key, rng_prior = jrandom.split(rng_key, 2)

# draw from the prior predictive at the observed covariates
num_prior_samples = 1000
prior_predictive = Predictive(model.model, num_samples=num_prior_samples)
prior_samples = prior_predictive(rng_prior, x=x)

### Sampling: MCMC

In [26]:
# NUTS kernel over the model
nuts_kernel = NUTS(model.model)

# sampling budget for each chain
num_warmup = 500
num_samples = 2_000
num_chains = 5

# MCMC driver: all chains are run in parallel
mcmc = MCMC(
    nuts_kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=num_chains,
    chain_method="parallel",
)

In [27]:
%%time
# dedicated sub-key for the sampler
rng_key, rng_mcmc = jrandom.split(rng_key, 2)

# fit: condition the model on covariates `x` and observations `u`
mcmc.run(rng_key=rng_mcmc, x=x, y=u)

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

CPU times: user 2min 45s, sys: 838 ms, total: 2min 46s
Wall time: 36.5 s


### Posterior Samples

In [39]:
# sub-key used by the posterior-predictive, prediction and gradient cells below
rng_key, rng_predict = jrandom.split(rng_key)

In [40]:
# posterior draws collected from the MCMC run
posterior_params = mcmc.get_samples()

# push those draws through the model at the observed covariates
posterior_predictive = Predictive(
    model.model,
    posterior_samples=posterior_params,
    parallel=True,
)
posterior_samples = posterior_predictive(rng_predict, x=x)

In [63]:
# shape of the posterior-predictive draws at the "obs" site
posterior_samples["obs"].shape

(10000, 28, 399)

### Predictions

In [42]:
# wrap the model together with the MCMC posterior draws; the cells below
# call its `predict` and `gradient` methods
model_inference = ModelPredictorMCMC(model=model.model, posterior_params=mcmc.get_samples())

$$
\begin{aligned}
\text{Sample Posterior Parameters}: && &&
\boldsymbol{\theta}_n &\sim p(\boldsymbol{\theta}|\mathcal{D}) \\
\text{Sample Covariates}: && &&
\mathbf{x}_m &\sim p(\mathbf{x}) \\
\text{Data Likelihood}: && &&
\mathbf{u}_{nm} &\sim p(\mathbf{u}_{nm}|\mathbf{z}_{nm})
p(\mathbf{z}_{nm}|\mathbf{x}_m,\boldsymbol{\theta}_n)
\end{aligned}
$$

In [43]:
# pred = model_inference.predict(x, rng_key=jrandom.PRNGKey(10))
# pred.shape

It can be faster to predict using vectorized operations, e.g. by mapping the predictor over all covariate rows with `jax.vmap`.

In [44]:
# NOTE(review): `rng_predict` was already used by the posterior-predictive cell;
# confirm that reusing the same key here is intended.
# Vectorise `predict` over the rows of `x` (axis 0) with a shared key; the
# per-row results are stacked along axis 1 and singleton axes are dropped.
predict_over_rows = jax.vmap(model_inference.predict, in_axes=(0, None), out_axes=1)
pred = predict_over_rows(x, rng_predict).squeeze()
pred.shape

(10000, 28, 399)

### Gradients

$$
\begin{aligned}
\text{Sample Posterior Parameters}: && &&
\boldsymbol{\theta}_n &\sim p(\boldsymbol{\theta}|\mathcal{D}) \\
\text{Sample Covariates}: && &&
\mathbf{x}_m &\sim p(\mathbf{x}) \\
\text{Data Likelihood (Gradient)}: && &&
\partial_{\mathbf{x}_m}\mathbf{u}_{nm}, \quad \mathbf{u}_{nm} &\sim p(\mathbf{u}_{nm}|\mathbf{z}_{nm})
p(\mathbf{z}_{nm}|\mathbf{x}_m,\boldsymbol{\theta}_n)
\end{aligned}
$$

In [45]:
# grads = model_inference.gradient(x, jrandom.PRNGKey(10))
# grads.shape

In [46]:
# NOTE(review): this reuses `rng_predict` (also used for `pred` and the
# posterior predictive); confirm that sharing the key is intended.
# Vectorise `gradient` over the rows of `x` (axis 0) with a shared key; the
# per-row results are stacked along axis 2.
gradient_over_rows = jax.vmap(model_inference.gradient, in_axes=(0, None), out_axes=2)
grads = gradient_over_rows(x, rng_predict)
grads.shape

(10000, 399, 28, 4)

In [64]:
import arviz as az
import numpy as np

# `x`, `u` and everything derived from them (posterior predictive, `pred`,
# `grads`) follow the model order of `ds.sortby("model")`. Take the "model"
# labels from the same sorted dataset so each label lines up with its data;
# plain `ds.model` is in file order and may not match.
model_labels = ds.sortby("model").model.values

# Collect the posterior, posterior predictive and out-of-sample style
# predictions (predictions + gradients) into one ArviZ object with named
# dimensions and coordinates.
mcmc_arviz = az.from_numpyro(
    posterior=mcmc,
    coords={
        "chain": np.arange(0, num_chains),
        "draw": np.arange(0, num_samples),
        # "spatial": pd.MultiIndex.from_arrays([lat, lon], names=["lat", "lon"]),
        # short labels, listed in the same order as `covariate_names`
        "covariate": ["sst", "sm", "t2m", "z500"],
        "model": model_labels,
    },
    dims={
        "bias": ["chain", "draw", "spatial"],
        "loc": ["chain", "draw", "spatial", "covariate"],
        "obs": ["model", "spatial"],
    },
    posterior_predictive=posterior_samples,
    predictions={"grads": grads, "pred": pred},
    pred_dims={
        "grads": ["spatial", "model", "covariate"],
        "pred": ["model", "spatial"],
    },
)


# # assign the lat-lon coordinates to data structure
# mcmc_arviz = mcmc_arviz.assign_coords({
#     "spatial": ds.stack(spatial=["lat","lon"])[["lat", "lon"]].spatial,
#     "lon": ds.stack(spatial=["lat","lon"])[["lat", "lon"]].lon,
#     "lat": ds.stack(spatial=["lat","lon"])[["lat", "lon"]].lat,
# })

# mcmc_arviz = mcmc_arviz.unstack()

In [65]:
# display the assembled ArviZ object
mcmc_arviz

In [61]:
# output directory for the fitted-model results (blr_mcmc.nc is written here)
results_save_dir = "/pool/usuarios/juanjohn/data/ai4storylines/results/"

In [62]:
# write the (unstacked) results to a single NetCDF file
results_path = Path(results_save_dir).joinpath("blr_mcmc.nc")
mcmc_arviz.unstack().to_netcdf(results_path, engine="netcdf4")

PosixPath('/pool/usuarios/juanjohn/data/ai4storylines/results/blr_mcmc.nc')