In [1]:
# Automatically re-import edited modules (presumably including the local `gaussians`
# module imported below) before each cell runs, so edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import functools
import jax
import scipy.io
from gaussians import *
from jax import numpy as jnp
from matplotlib import pyplot as plt
from tueplots import bundles
from tueplots.constants.color import rgb

# matplotlib style: tueplots' Beamer (MoML) bundle, with figures rendered at 200 dpi.
# The explicit dpi entry comes last so it overrides any dpi set by the bundle.
plt.rcParams.update({**bundles.beamer_moml(), "figure.dpi": 200})


In [3]:
# Load the regression data set stored in MATLAB format.
data = scipy.io.loadmat("nlindata.mat")
X = data["X"]  # training inputs, as stored in the file
Y = data["Y"][:, 0]  # training targets: first column of the stored matrix, as a 1-D array
# Noise level: the [0, 0] entry of the stored matrix, wrapped as a one-element 1-D
# array (presumably a standard deviation; `plot` uses it as the error-bar half-width).
sigma = data["sigma"][0, 0].flatten()


In [4]:
# define the kernel and mean function:
def SE_kernel(x, y, ell=1.0):
    """Squared-exponential (RBF) kernel k(x, y) = exp(-||(x - y) / ell||^2 / 2).

    Args:
        x, y: arrays whose last axis indexes input dimensions. They are broadcast
            against each other, so e.g. x[:, None, :] and y[None, :, :] give a
            Gram matrix.
        ell: length scale; a scalar, or an array broadcastable to the last axis
            for per-dimension length scales.

    Returns:
        Kernel values, with the last (input-dimension) axis reduced.
    """
    # Square each scaled coordinate difference *before* summing over dimensions.
    # Summing the raw differences first and squaring afterwards is only correct
    # for 1-D inputs (opposite-sign differences would cancel otherwise).
    scaled_sq_dist = jnp.sum(((x - y) / ell) ** 2, axis=-1)
    return jnp.exp(-scaled_sq_dist / 2)


def const_mean(x, c=0.0):
    """Constant mean function: the value `c` for every entry of x[:, 0].

    Args:
        x: inputs indexed as x[:, 0], e.g. an (n, d) array with one data point per row.
        c: the constant mean value.

    Returns:
        An array shaped like x[:, 0] filled with `c`.
    """
    template = x[:, 0]
    return jnp.ones_like(template) * c


In [5]:
# Construct a Gaussian-process prior with zero mean (const_mean's default c=0.0)
# and a squared-exponential kernel with length scale 0.5.
kernel = functools.partial(SE_kernel, ell=0.5)
prior = GaussianProcess(const_mean, kernel)
# Condition it on the data: targets Y observed at inputs X with noise level sigma.
posterior = prior.condition(Y, X, sigma)


In [6]:
# Inspect the prior: its repr lists the mean function `m` and the kernel `k`.
prior

GaussianProcess(m=<function const_mean at 0x1271af060>, k=functools.partial(<function SE_kernel at 0x116993a60>, ell=0.5))

In [7]:
# Dense evaluation grid: 300 equally spaced inputs on [-8, 8], as an (n, 1) column.
x = jnp.linspace(-8, 8, 300).reshape(-1, 1)
key = jax.random.PRNGKey(0)

# Evaluate the prior on the grid and read off its pointwise mean and standard deviation.
fitted_prior = prior(x)
prior_mean = fitted_prior.mu
prior_std = fitted_prior.std

In [8]:
num_samples = 10
# Draw `num_samples` functions from the prior on the grid `x`; after the transpose
# the second axis indexes the draws (one draw per column, samples[:, i]).
samples = fitted_prior.sample(key, num_samples=num_samples).T

In [9]:
# Sanity check: the evaluation grid holds 300 points.
x.flatten().shape

(300,)

In [10]:
# Sanity check: each row of `samples` holds one value per draw (num_samples of them).
samples[1].shape

(10,)

In [11]:
def _css_rgba(rgb, alpha):
    """Return the CSS colour string ``rgba(r,g,b,alpha)`` for an (r, g, b) triple."""
    r, g, b = rgb
    return f"rgba({r},{g},{b},{alpha})"


def _add_gp_traces(fig, x, fitted, key, num_samples, rgb, label, mean_dash, band_alpha):
    """Draw one GP (mean line, mean +/- 2 std band, sample paths) into ``fig``.

    Args:
        fig: plotly figure; traces are appended to it in place.
        x: evaluation grid of shape (n, 1); column 0 is used as the abscissa.
        fitted: the GP evaluated on ``x``; must provide ``.mu``, ``.std`` and
            ``.sample(key, num_samples=...)`` like the GaussianProcess objects
            used in this notebook.
        key: JAX PRNG key for the sample paths.
        num_samples: number of sample paths; nothing is sampled if <= 0.
        rgb: (r, g, b) base colour shared by all traces of this GP.
        label: legend prefix, e.g. ``"prior"``.
        mean_dash: plotly dash style of the mean line, or None for a solid line.
        band_alpha: fill opacity of the +/- 2 std band.
    """
    import plotly.graph_objects as go

    grid = x[:, 0]
    mean, std = fitted.mu, fitted.std

    mean_line = dict(color="rgb({}, {}, {})".format(*rgb))
    if mean_dash is not None:
        mean_line["dash"] = mean_dash
    fig.add_trace(
        go.Scatter(x=grid, y=mean, mode="lines", line=mean_line, name=f"{label} mean")
    )

    # Closed polygon: lower bound left-to-right, then upper bound right-to-left.
    # +/- 2 std covers ~95.4% of a Gaussian, hence the "95% CI" label.
    fig.add_trace(
        go.Scatter(
            x=jnp.concatenate([grid, grid[::-1]]),
            y=jnp.concatenate([mean - 2 * std, (mean + 2 * std)[::-1]]),
            fill="toself",
            fillcolor=_css_rgba(rgb, band_alpha),
            line=dict(color=_css_rgba(rgb, 0.3), dash="dot"),
            name=f"{label} 95% CI",
        )
    )

    if num_samples > 0:
        # Transpose so that each column paths[:, i] is one sample path over the grid.
        paths = fitted.sample(key, num_samples=num_samples).T
        for i in range(paths.shape[1]):
            fig.add_trace(
                go.Scatter(
                    x=grid,
                    y=paths[:, i],
                    mode="lines",
                    line=dict(color=_css_rgba(rgb, 0.3)),
                    showlegend=False,
                    name="sample",
                    opacity=0.2,
                )
            )


def plot(
    prior,
    posterior,
    num_samples=50,
    *,
    x_range=(-8, 8),
    num_points=300,
    seed=0,
    ylim=(-10, 20),
    data_x=None,
    data_y=None,
    data_sigma=None,
):
    """Show data, prior and posterior of a GP regression in one interactive plotly figure.

    Traces, in order: the data with +/- sigma error bars; prior mean, +/- 2 std
    band and sample paths (grey); posterior mean, band and sample paths (red).

    Args:
        prior, posterior: GP objects that, called on an (n, 1) grid, return an
            object with ``.mu``, ``.std`` and ``.sample(key, num_samples=...)``.
        num_samples: sample paths drawn from each of prior and posterior (0 for none).
        x_range: (lo, hi) interval of the evaluation grid.
        num_points: number of grid points.
        seed: seed of the PRNG key; the same key is used for prior and posterior draws.
        ylim: (lo, hi) range of the y axis.
        data_x, data_y, data_sigma: observations to show; each defaults to the
            notebook-level ``X``, ``Y`` and ``sigma``. ``data_sigma`` may be a
            scalar, a one-element array, or one value per observation.
    """
    import plotly.graph_objects as go

    data_x = X if data_x is None else data_x
    data_y = Y if data_y is None else data_y
    data_sigma = sigma if data_sigma is None else data_sigma
    # One error-bar half-width per observation (broadcasts a single noise level).
    error = jnp.broadcast_to(jnp.ravel(jnp.asarray(data_sigma)), jnp.shape(data_y))

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=jnp.ravel(jnp.asarray(data_x)),
            y=data_y,
            mode="markers",
            marker=dict(color="rgb(0, 101, 189)", size=6),
            error_y=dict(
                type="data",
                array=error,
                visible=True,
                color="rgba(0,101,189,0.5)",
            ),
            name="data",
        )
    )

    x = jnp.linspace(x_range[0], x_range[1], num_points)[:, None]
    key = jax.random.PRNGKey(seed)
    _add_gp_traces(fig, x, prior(x), key, num_samples, (51, 51, 51), "prior", "dash", 0.1)
    _add_gp_traces(
        fig, x, posterior(x), key, num_samples, (204, 0, 0), "posterior", None, 0.05
    )

    fig.update_layout(
        xaxis_title="$x$",
        yaxis_title="$f(x)$",
        legend=dict(x=0.01, y=0.99),
        yaxis=dict(range=list(ylim)),
    )
    fig.show()


In [12]:
# Prior (grey) and posterior (red) for the squared-exponential kernel with ell = 0.5.
plot(prior, posterior)


In [13]:
def polynomial_kernel(x, y, degree=2):
    """Kernel of Bayesian regression on finitely many polynomial features.

    Computes
        k(x, y) = (1 / sqrt(degree)) * sum_{j=0}^{degree-1} (x * y)**j / j!
    i.e. the inner product of the feature vectors x**j / sqrt(j!), rescaled by
    1 / sqrt(degree). The powers run from 0 to ``degree - 1``, so the default
    ``degree=2`` uses only a constant and a linear feature.

    The power axis comes from broadcasting ``x * y`` against ``arange(degree)``
    and is then summed over the last axis, so this is meant for inputs whose
    last axis has length 1 (1-D regression).
    NOTE(review): for multi-dimensional inputs the broadcast does not implement
    the formula above.
    """
    powers = jnp.arange(degree)
    factorials = jnp.exp(jax.scipy.special.gammaln(powers + 1))  # j! via log-gamma
    terms = (x * y) ** powers / factorials
    return terms.sum(axis=-1) / jnp.sqrt(degree)


def d(x1, x2):
    """Euclidean distance ||x1 - x2||_2 over the last axis (inputs are broadcast)."""
    diff = x1 - x2
    return jnp.sqrt((diff * diff).sum(axis=-1))


def Wiener(x, y, shift=0.0):
    """Wiener-process (Brownian motion / Brownian sheet) kernel with shifted origin.

        k(x, y) = prod_i max(min(x_i, y_i) - shift, 0)

    In 1-D this is max(min(x, y) - shift, 0): the variance k(x, x) is zero for
    x <= shift and grows linearly to the right of it.

    Args:
        x, y: arrays whose last axis indexes input dimensions (broadcast together).
        shift: origin of the process in every input dimension.

    Returns:
        Kernel values, with the last axis reduced by the product.
    """
    # Shift and clip each coordinate *before* taking the product, as Integrated_Wiener
    # does. Shifting the product instead lets pairs of negative coordinates yield
    # spurious positive covariance for multi-dimensional inputs; in 1-D both agree.
    clipped = jnp.maximum(jnp.minimum(x, y) - shift, 0)
    return jnp.prod(clipped, axis=-1)


def Integrated_Wiener(x, y, shift=0.0):
    """Once-integrated Wiener-process kernel with origin at `shift`.

    In 1-D, with m = max(min(x, y) - shift, 0):
        k(x, y) = m**3 / 3 + |x - y| * m**2 / 2
    so k(x, x) = m**3 / 3. For inputs with several dimensions, m and |x - y|
    are each replaced by their product over the last axis.
    """
    m = jnp.prod(jnp.maximum(jnp.minimum(x, y) - shift, 0), axis=-1)
    gap = jnp.prod(jnp.abs(x - y), axis=-1)
    return m**2 * (1 / 3 * m + 1 / 2 * gap)


def Matern_1(x1, x2, l=1.0):
    """Matérn-1/2 (exponential) kernel exp(-r) with r = ||x1 - x2||_2 / l."""
    r = d(x1, x2) / l
    return jnp.exp(-r)


def Matern_3(x1, x2, l=1.0):
    """Matérn-3/2 kernel (1 + s) * exp(-s) with s = sqrt(3) * ||x1 - x2||_2 / l."""
    scaled = jnp.sqrt(3) * (d(x1, x2) / l)
    return (1 + scaled) * jnp.exp(-scaled)


def Matern_5(x1, x2, l=1.0):
    """Matérn-5/2 kernel (1 + sqrt(5) r + 5 r**2 / 3) * exp(-sqrt(5) r), r = ||x1 - x2||_2 / l."""
    r = d(x1, x2) / l
    sqrt5_r = jnp.sqrt(5) * r
    polynomial = 1.0 + sqrt5_r + 5.0 / 3.0 * r**2
    return polynomial * jnp.exp(-sqrt5_r)


def RQ_kernel(x, y, ell=1.0, alpha=1.0, theta=1.0):
    """Rational-quadratic kernel.

        k(x, y) = theta**2 * (1 + ||(x - y) / ell||^2 / (2 * alpha)) ** (-alpha)

    Args:
        x, y: arrays whose last axis indexes input dimensions (broadcast together).
        ell: length scale; a scalar or an array broadcastable to the last axis.
        alpha: shape parameter controlling the mixture of length scales.
        theta: output scale; the kernel's value at x == y is theta**2.

    Returns:
        Kernel values, with the last (input-dimension) axis reduced.
    """
    # Sum of squared scaled differences; summing the raw differences before
    # squaring is only correct for 1-D inputs.
    scaled_sq_dist = jnp.sum(((x - y) / ell) ** 2, axis=-1)
    return theta**2 * (1 + scaled_sq_dist / (2 * alpha)) ** (-alpha)


In [14]:
# Wiener-process (Brownian motion) prior with its origin shifted to x = -8, the left
# end of the plotting grid: k(x, x) = max(x + 8, 0), so the prior variance is zero
# there and grows linearly to the right.
kernel = functools.partial(Wiener, shift=-8.0)
prior = GaussianProcess(const_mean, kernel)
posterior = prior.condition(Y, X, sigma)
plot(prior, posterior)


In [17]:
# Once-integrated Wiener-process prior, with origin at x = -8: its prior variance
# k(x, x) = max(x + 8, 0)**3 / 3 grows cubically from the left end of the grid.
kernel = functools.partial(Integrated_Wiener, shift=-8.0)
prior = GaussianProcess(const_mean, kernel)
posterior = prior.condition(Y, X, sigma)
plot(prior, posterior)


In [16]:
# Matérn-3/2 prior with unit length scale.
kernel = functools.partial(Matern_3, l=1)
prior = GaussianProcess(const_mean, kernel)
posterior = prior.condition(Y, X, sigma)
plot(prior, posterior)
