# AI 4 StoryLines - Bayesian LR

In this example, we are going to showcase how we can do sensitivity analysis using a simple linear regression model.

In [1]:
import autoroot
import numpy as np
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro import handlers
from numpyro.infer import MCMC, NUTS
import jax
import jax.random as jrandom
import jax.numpy as jnp
from jaxtyping import Array, Float
from pathlib import Path
from dataclasses import dataclass
import xarray as xr
import pandas as pd
import einops
import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from bayesevt._src.utils.io import get_list_filenames
# from utils import plot_weights


# Run on CPU and ask numpyro to expose 64 host devices; the MCMC cell further
# down requests chain_method="parallel".
# NOTE(review): 64 is hard-coded and presumably matches the original machine —
# confirm it is appropriate for the host this notebook runs on.
numpyro.set_platform("cpu")
numpyro.set_host_device_count(64)
# Plot styling: reset seaborn to its defaults, then use a slightly scaled-down
# "talk" context.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# IPython magics: inline figures and automatic reloading of edited modules.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Root PRNG key for the notebook (fixed seed); stochastic cells split it
# rather than using it directly.
rng_key = jrandom.PRNGKey(123)

We can see that some of the variance remains unexplained.
Potentially, we can account for it with the regression model.

## Load Data

We have a clean, analysis-ready dataset available from the previous notebook.
It was saved to disk there, so here we simply load it.

First, we will make sure that the models in the covariates and the QoI are the same.

In [3]:
save_dir = "/pool/usuarios/juanjohn/data/ai4storylines/analysis/"

In [4]:
# load covariates
df = pd.read_csv(Path(save_dir).joinpath("covariates.csv"), index_col=0)
# load qoi
ds = xr.open_dataset(Path(save_dir).joinpath("qoi.nc"))
models = ['access_cm2', 'access_esm1_5', 'bcc_csm2_mr', 'cams_csm1_0', 'canesm5',
       'cmcc_esm2', 'cnrm_cm6_1', 'cnrm_cm6_1_hr', 'cnrm_esm2_1', 'e3sm_1_0',
       'ec_earth3', 'ec_earth3_cc', 'ec_earth3_veg', 'ec_earth3_veg_lr',
       'fgoals_g3', 'gfdl_cm4', 'gfdl_esm4', 'ipsl_cm6a_lr', 'kace_1_0_g',
       'miroc_es2l', 'mpi_esm1_2_hr', 'mpi_esm1_2_lr', 'mri_esm2_0',
       'noresm2_lm', 'noresm2_mm', 'taiesm1', 'ukesm1_0_ll']

# only select valid models
# (`ds.sel` raises a KeyError if any requested model is missing from the QoI)
ds = ds.sel(model=models)
df = df[df["model"].isin(models)]

# quick check: the covariates and the QoI must cover exactly the same models.
# Compare sorted *copies*: `ndarray.sort()` sorts in place and returns None, so
# comparing its return values would always pass and could also reorder the
# underlying `model` column independently of the covariate rows.
assert sorted(df["model"].tolist()) == sorted(ds["model"].values.tolist()), (
    "covariates and QoI do not contain the same set of models"
)

## Baseline Model - Linear Regression

#### Input Data

In [5]:
# Names of the regression inputs (covariates) and of the target variable.
covariate_names = ["sst_norm", "sm_norm", "z500_norm"]
qoi_names = ["tasmax_norm"]

# Covariate matrix: rows ordered by model name, one column per covariate.
x = (
    df.sort_values(by=["model"])
    .loc[:, covariate_names]
    .values
)

# Target: models in the same (sorted) order, with the lat/lon grid flattened
# into a single "spatial" axis.
u = (
    ds.sortby("model")["tasmax_norm"]
    .stack(spatial=["lat", "lon"])
    .values
)

In [6]:
# Display the shapes of the covariate matrix and the stacked target.
x.shape, u.shape

((27, 3), (27, 399))

#### Model

In [7]:
# Problem sizes read off the input arrays:
# rows of `u` index the models, columns of `u` index the stacked grid points,
# and columns of `x` index the covariates.
num_models, num_spatial = u.shape[0], u.shape[1]
num_covariates = x.shape[1]

In [8]:
from models import BayesianHierachicalRegression, ModelPredictorMCMC

In [9]:
# Instantiate the regression model from the local `models` module, sized by the
# number of stacked grid points (columns of `u`).
model = BayesianHierachicalRegression(num_spatial=num_spatial)

In [10]:
# Split off a dedicated key for the prior draws; `rng_key` is carried forward.
rng_key, rng_prior = jrandom.split(rng_key, num=2)

# Draw 1000 samples from the model; only the covariates `x` are passed
# (no `y`), and no posterior samples are supplied.
prior_predictive = Predictive(model.model, num_samples=1000)
prior_samples = prior_predictive(rng_prior, x=x)

### Sampling: MCMC

In [11]:
# initialize kernel
nuts_kernel = NUTS(model.model)

# initial mcmc scheme
num_warmup = 2_000
num_samples = 2_000
num_chains = 5
mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, chain_method="parallel")

In [12]:
%%time
# Split off a dedicated key for the sampler; `rng_key` is carried forward.
rng_key, rng_mcmc = jrandom.split(rng_key, 2)

# Run the sampler with the covariates as `x` and the stacked target `u` as `y`.
mcmc.run(rng_key=rng_mcmc, x=x, y=u)

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:00<?, ?it/s]

CPU times: user 4min 55s, sys: 2.03 s, total: 4min 57s
Wall time: 1min 39s


### Posterior Samples

In [13]:
# Split off the key used by the posterior-predictive, prediction and gradient
# cells below.
rng_key, rng_predict = jrandom.split(rng_key)

In [14]:
# Predictive object driven by the draws collected from the MCMC run.
posterior_predictive = Predictive(
    model.model,
    posterior_samples=mcmc.get_samples(),
    parallel=True,
)

# Evaluate it at the training covariates.
posterior_samples = posterior_predictive(rng_predict, x=x)

### Predictions

In [15]:
# Wrap the model function and the MCMC draws in the predictor helper from the
# local `models` module; it is used below for predictions and gradients.
model_inference = ModelPredictorMCMC(
    model=model.model,
    posterior_params=mcmc.get_samples(),
)

$$
\begin{aligned}
\text{Sample Posterior Parameters}: && &&
\boldsymbol{\theta}_n &\sim p(\boldsymbol{\theta}|\mathcal{D}) \\
\text{Sample Covariates}: && &&
\mathbf{x}_m &\sim p(\mathbf{x}) \\
\text{Data Likelihood}: && &&
\mathbf{u}_{nm} &\sim p(\mathbf{u}_{nm}|\mathbf{z}_{nm})
p(\mathbf{z}_{nm}|\mathbf{x}_m,\boldsymbol{\theta}_n)
\end{aligned}
$$

In [16]:
# pred = model_inference.predict(x, rng_key=jrandom.PRNGKey(10))
# pred.shape

It can be faster to predict for all covariate rows at once by vectorizing the call with `jax.vmap`.

In [17]:
# Map `predict` over the rows of `x` (axis 0) while sharing one PRNG key across
# all rows; the mapped axis is placed at position 1 of the output, and
# singleton dimensions are squeezed away.
# NOTE(review): `rng_predict` was already used for the posterior-predictive
# draw above — confirm that reusing the same key here is intended.
pred = jax.vmap(
    model_inference.predict,
    in_axes=(0, None),
    out_axes=1,
)(x, rng_predict).squeeze()

pred.shape

(10000, 27, 399)

### Gradients

$$
\begin{aligned}
\text{Sample Posterior Parameters}: && &&
\boldsymbol{\theta}_n &\sim p(\boldsymbol{\theta}|\mathcal{D}) \\
\text{Sample Covariates}: && &&
\mathbf{x}_m &\sim p(\mathbf{x}) \\
\text{Data Likelihood (Gradient)}: && &&
\partial_{\mathbf{x}_m}\mathbf{u}_{nm}, \quad \mathbf{u}_{nm} &\sim p(\mathbf{u}_{nm}|\mathbf{z}_{nm})
p(\mathbf{z}_{nm}|\mathbf{x}_m,\boldsymbol{\theta}_n)
\end{aligned}
$$

In [18]:
# grads = model_inference.gradient(x, jrandom.PRNGKey(10))
# grads.shape

In [19]:
# Map `gradient` over the rows of `x` (axis 0) with a shared PRNG key; the
# mapped axis is placed at position 2 of the output.
# NOTE(review): the same `rng_predict` key is used for the predictions above —
# presumably so gradients and predictions share their random draws; confirm.
grads = jax.vmap(
    model_inference.gradient,
    in_axes=(0, None),
    out_axes=2,
)(x, rng_predict)

grads.shape

(10000, 399, 27, 3)

In [20]:
import arviz as az

# Label the "model" axis in the same order as the data: `x` and `u` (and hence
# `pred` and `grads`) were built from model-sorted inputs, so the coordinate
# must be sorted the same way or labels could be attached to the wrong rows.
model_labels = ds.sortby("model")["model"].values

# lat/lon lookup on the stacked spatial axis, computed once and reused below
spatial_coords = ds.stack(spatial=["lat", "lon"])[["lat", "lon"]]

# Bundle the posterior, posterior-predictive draws, predictions and gradients
# into a single labelled ArviZ InferenceData object.
mcmc_arviz = az.from_numpyro(
    posterior=mcmc,
    coords={
        "chain": np.arange(0, num_chains),
        "draw": np.arange(0, num_samples),
        # reuse the list that was used to build `x` so the labels cannot drift
        "covariate": covariate_names,
        "model": model_labels,
    },
    dims={
        "bias": ["chain", "draw", "spatial"],
        "weight": ["chain", "draw", "spatial", "covariate"],
        "z_scale": ["chain", "draw", "spatial"],
        "obs": ["model", "spatial"],
    },
    posterior_predictive=posterior_samples,
    predictions={"grads": grads, "pred": pred},
    pred_dims={
        "grads": ["spatial", "model", "covariate"],
        "pred": ["model", "spatial"],
    },
)

# assign the lat-lon coordinates to data structure
mcmc_arviz = mcmc_arviz.assign_coords({
    "spatial": spatial_coords.spatial,
    "lon": spatial_coords.lon,
    "lat": spatial_coords.lat,
})

mcmc_arviz

In [21]:
# Directory the results file is written to.
# NOTE(review): machine-specific absolute path — adjust for your environment.
results_save_dir = "/pool/usuarios/juanjohn/data/ai4storylines/results/"

In [22]:
# Unstack the stacked dimensions and write everything to a single netCDF file.
# NOTE(review): presumably unstacked because the stacked (lat, lon) index
# cannot be written to netCDF directly — confirm.
mcmc_arviz.unstack().to_netcdf(Path(results_save_dir).joinpath("bhm_mcmc_norm.nc"), engine="netcdf4")

PosixPath('/pool/usuarios/juanjohn/data/ai4storylines/results/bhm_mcmc_norm.nc')