In [2]:
import jax
import numpyro
import xarray as xr
from jax import numpy as jnp
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from sklearn.pipeline import Pipeline, FunctionTransformer
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from sklearn.preprocessing import StandardScaler

from src.data_loading import load_data
from src.preprocessing.preprocessing import XArrayStandardScaler, XArrayFeatureUnion, SeasonalFeatures
from src.utils import flatten_array

# Request 4 host (CPU) devices so the 4 NUTS chains configured in the MCMC cell
# (num_chains=4) can run in parallel rather than sequentially.
# NOTE(review): per the NumPyro docs this only takes effect if it runs before the JAX
# backend is initialised, i.e. before any JAX computation — keep it in this first cell.
numpyro.set_host_device_count(4)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Feature variables "precip", "temp" and "evap" for the Basin/Land/Water surface types.
# Any Date with a missing value is dropped *before* the 1980-1990 window is selected.
surface_data = (
    load_data(["precip", "temp", "evap"])
    .sel(type=["Basin", "Land", "Water"])
    .dropna("Date")
)

# `to_array` stacks the data variables into one DataArray (load_data presumably returns
# a Dataset); Date and lake are then moved to the front so the leading axis indexes time.
stacked = surface_data.to_array().transpose("Date", "lake", ...)
X = stacked.sel(Date=slice("1980", "1990"))
X

In [4]:
# Target series: load_data("rnbs") for lake "sup". xr.align (inner join by default)
# restricts the target and the features to their common Dates, so rows of `y` and the
# aligned features `x` correspond one-to-one.
y, x = xr.align(load_data("rnbs").sel(lake="sup"), X)
y = jnp.array(y)

# Feature pipeline: [flatten_array -> XArrayStandardScaler] unioned with SeasonalFeatures,
# then converted to a JAX array for NumPyro.
preprocessor = Pipeline(
    [
        (
            "features",
            XArrayFeatureUnion(
                [
                    (
                        "preprocess",
                        Pipeline(
                            steps=[
                                ("flatten_array", FunctionTransformer(flatten_array)),
                                ("scale", XArrayStandardScaler()),
                            ]
                        ),
                    ),
                    ("seasonal", SeasonalFeatures()),
                ]
            ),
        ),
        ("array", FunctionTransformer(jnp.array)),
    ]
)

# Split the *aligned* features `x` (not the unaligned `X`) so each feature row stays
# paired with its target value. random_state is fixed so the split, and everything
# fitted on it, is reproducible across Restart & Run All.
# NOTE(review): train_test_split shuffles by default, which interleaves train and test
# dates in time; if this model is meant for forecasting, consider shuffle=False or the
# (already imported) TimeSeriesSplit.
train_x, test_x, train_y, test_y = train_test_split(
    x, y, test_size=0.2, random_state=0
)

# Standardise the target with training statistics only; y_scaler is kept, presumably
# to inverse-transform predictions later. test_y is left in original units.
y_scaler = StandardScaler()
train_y = y_scaler.fit_transform(train_y.reshape(-1, 1)).squeeze()

# Fit the feature pipeline on the training split only (no test-set leakage).
scaled_x = preprocessor.fit_transform(train_x)
scaled_x, train_y

(Array([[-8.8181239e-01, -1.2943277e+00, -3.2870948e-01, ...,
          3.5071561e-01,  5.0000000e-01,  8.6602539e-01],
        [-3.6393335e-01, -6.1166090e-01, -8.4454484e-02, ...,
         -1.5951632e-01,  8.6602539e-01,  5.0000000e-01],
        [ 2.3567970e+00,  2.3335588e+00,  2.2822921e+00, ...,
         -1.7870008e+00,  1.0000000e+00,  6.1232343e-17],
        ...,
        [ 2.6358845e+00,  2.9742155e+00,  2.0380373e+00, ...,
          1.3054587e+00, -1.0000000e+00, -1.8369701e-16],
        [ 8.8972187e-01,  9.0370947e-01,  8.1676221e-01, ...,
          3.3936384e-01, -5.0000000e-01,  8.6602539e-01],
        [-8.3554649e-01, -7.8270268e-01, -8.5090983e-01, ...,
          1.5952274e+00, -2.4492937e-16,  1.0000000e+00]], dtype=float32),
 array([-1.1611121 , -1.3931339 ,  0.03214244, -1.0237932 , -0.06592076,
        -0.4792523 , -1.3647231 , -0.12549183,  1.0934243 , -0.39371437,
         0.23101808, -1.1800526 ,  1.7761997 ,  0.2973101 ,  1.3087968 ,
         0.6098288 , -0.6260416

In [7]:
# squared exponential kernel with diagonal noise term
def rbf_kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    """Squared-exponential (RBF) covariance between the rows of ``X`` and ``Z``.

    Computes ``k[i, j] = var * exp(-0.5 * ||X[i] - Z[j]||**2 / length**2)`` and, when
    ``include_noise`` is True, adds ``noise + jitter`` to the diagonal.

    Args:
        X: Inputs of shape ``(n, d)``; a 1-D array of shape ``(n,)`` is treated as ``(n, 1)``.
        Z: Inputs of shape ``(m, d)`` (or ``(m,)``) with the same ``d`` as ``X``.
        var: Signal variance (kernel amplitude).
        length: Single isotropic length scale shared by all input dimensions.
        noise: Noise variance added to the diagonal when ``include_noise`` is True.
        jitter: Small extra diagonal term that keeps the matrix numerically positive
            definite in floating point.
        include_noise: Whether to add ``(noise + jitter) * I``; only meaningful for a
            square matrix, so it requires ``n == m``.

    Returns:
        Array of shape ``(n, m)``.

    Raises:
        ValueError: If ``include_noise`` is True but ``X`` and ``Z`` have a different
            number of rows (previously this either failed obscurely or, for ``m == 1``,
            silently broadcast to an ``(n, n)`` result).
    """
    X = jnp.asarray(X)
    Z = jnp.asarray(Z)
    # Promote 1-D inputs to a single feature column so the broadcast below yields (n, m).
    if X.ndim == 1:
        X = X[:, None]
    if Z.ndim == 1:
        Z = Z[:, None]

    # Sum of squared differences instead of norm(...)**2: the same value, but it avoids the
    # sqrt inside jnp.linalg.norm, whose gradient is NaN at zero distance (every diagonal
    # entry when Z is X).
    sq_dist = jnp.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    k = var * jnp.exp(-0.5 * sq_dist / length**2)

    if include_noise:
        if X.shape[0] != Z.shape[0]:
            raise ValueError(
                "include_noise=True requires X and Z to have the same number of rows, "
                f"got {X.shape[0]} and {Z.shape[0]}"
            )
        k = k + (noise + jitter) * jnp.eye(X.shape[0])
    return k


# Sanity check: evaluate the kernel on the training inputs with var = length = noise = 1.
# Diagonal entries should equal var + noise + jitter (2.000001 with the default jitter).
rbf_kernel(scaled_x, scaled_x, 1.0, 1.0, 1.0)

Array([[2.0000010e+00, 6.6748007e-13, 3.0386267e-20, ..., 7.5107534e-35,
        9.7776908e-12, 6.1262499e-06],
       [6.6748007e-13, 2.0000010e+00, 5.1026266e-18, ..., 0.0000000e+00,
        2.2463520e-14, 3.0344793e-09],
       [3.0386267e-20, 5.1026266e-18, 2.0000010e+00, ..., 7.4600902e-22,
        2.2501119e-11, 1.0754462e-21],
       ...,
       [7.5107534e-35, 0.0000000e+00, 7.4600902e-22, ..., 2.0000010e+00,
        3.7958928e-11, 6.8224276e-23],
       [9.7776908e-12, 2.2463520e-14, 2.2501119e-11, ..., 3.7958928e-11,
        2.0000010e+00, 1.7577025e-06],
       [6.1262499e-06, 3.0344793e-09, 1.0754462e-21, ..., 6.8224276e-23,
        1.7577025e-06, 2.0000010e+00]], dtype=float32)

In [8]:
def model(X, Y):
    """Zero-mean Gaussian-process regression model with an RBF kernel.

    The hyperparameter sites ``kernel_var``, ``kernel_noise`` and ``kernel_length`` each
    get a LogNormal(0, 10) prior. These priors are very diffuse rather than truly
    "uninformative": the log of each hyperparameter is Normal(0, 10), so the central
    ~95% of prior mass spans roughly exp(-20) to exp(20).

    Args:
        X: Inputs passed to ``rbf_kernel``; rows are data points.
        Y: Targets of length ``X.shape[0]``. When ``None``, the ``"Y"`` site is sampled
            rather than conditioned on (e.g. under ``Predictive``).
    """
    hyper_prior = dist.LogNormal(0.0, 10.0)
    var = numpyro.sample("kernel_var", hyper_prior)
    noise = numpyro.sample("kernel_noise", hyper_prior)
    length = numpyro.sample("kernel_length", hyper_prior)

    n_points = X.shape[0]
    # Training covariance: RBF kernel on X with (noise + jitter) added on the diagonal.
    gram = rbf_kernel(X, X, var, length, noise)
    likelihood = dist.MultivariateNormal(
        loc=jnp.zeros(n_points),
        covariance_matrix=gram,
    )
    numpyro.sample("Y", likelihood, obs=Y)

In [9]:
# Infer the GP hyperparameters with NUTS: 4 chains x (1000 warmup + 1000 kept draws).
# The chains can run in parallel because 4 host devices were requested in the import cell.
kernel = NUTS(model)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
)

rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, X=scaled_x, Y=train_y)

# Keep the raw posterior draws for downstream cells, but display only the compact
# per-parameter summary (convergence diagnostics included) instead of dumping
# all 4 x 1000 draws into the notebook output.
samples = mcmc.get_samples()
mcmc.print_summary()

Compiling.. :   0%|          | 0/2000 [00:00<?, ?it/s]
[A
[A

[A[A

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


[A[A
Running chain 0:   0%|          | 0/2000 [00:01<?, ?it/s]