In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [None]:
from tqdm import tqdm
import numpy as np
from attrs import define, field

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

# Publication-style plotting defaults: STIX fonts for both regular text and
# math text, 12 pt tick and axis labels, and 300 dpi figures.
mpl.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.labelsize': 12,
    'figure.dpi': 300,
})

In [None]:
from gpax.gp import ExactGP
from gpax.kernels import MaternKernel

In [None]:
def piecewise1(x, params):
    """Evaluate the two-regime power law used as the ground-truth function.

    Inputs strictly below the breakpoint ``params["t"]`` map to
    ``x ** params["beta1"]``; inputs at or above it map to
    ``x ** params["beta2"]``.

    Parameters
    ----------
    x : array_like
        Input locations (scalar, list or array of any shape).
    params : mapping
        Must provide the keys ``"t"``, ``"beta1"`` and ``"beta2"``.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``. Integer/boolean inputs are
        promoted to float64 before evaluation.
    """
    x = np.asanyarray(x)
    # np.piecewise allocates its result with zeros_like(x), so an integer
    # input would silently truncate the fractional powers; promote it first.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)
    t = params["t"]
    return np.piecewise(
        x,
        [x < t, x >= t],
        [lambda xs: xs ** params["beta1"], lambda xs: xs ** params["beta2"]],
    )

NUM_INIT_POINTS = 15  # number of noisy observations to generate
NOISE_LEVEL = 0.1  # std. dev. of the additive Gaussian observation noise
PARAMS = {"t": 1.7, "beta1": 4.5, "beta2": 2.5}  # ground-truth parameters for piecewise1
SEED = 1
X_MIN, X_MAX = 0.0, 3.0  # input domain shared by the observations and the evaluation grid
NUM_GRID_POINTS = 200  # resolution of the prediction/plotting grid

# Legacy global seeding is kept deliberately: switching to
# np.random.default_rng would change the generated draws.
np.random.seed(SEED)
x = np.random.uniform(X_MIN, X_MAX, NUM_INIT_POINTS)
y = piecewise1(x, PARAMS) + np.random.normal(0., NOISE_LEVEL, NUM_INIT_POINTS)

x_grid = np.linspace(X_MIN, X_MAX, NUM_GRID_POINTS)

In [None]:
# Quick look at the synthetic training data on the [0, 3] input domain.
_, ax = plt.subplots(figsize=(6, 2))
ax.scatter(x, y, c='k', marker='x', alpha=0.5, label="Noisy observations")
ax.legend()
ax.set(xlabel="$x$", ylabel="$y$", xlim=(0, 3))

plt.show()

In [None]:
# Baseline: Matern-kernel GP with no custom mean function, using 500
# hyperparameter samples.
gp = ExactGP(
    kernel=MaternKernel(),
    x=x,
    y=y,
    y_std=None,
    hp_samples=500,
    observation_noise=False,
)
gp.fit()

# `sd` is presumably the predictive standard deviation (TODO confirm with
# gpax); `ci` holds the lower/upper mu -/+ 2*sd band as a two-element list.
mu, sd = gp.predict(x_grid)
ci = [mu - 2 * sd, mu + 2 * sd]

In [None]:
# Draw samples from the baseline GP on the evaluation grid; the next plotting
# cell iterates over `samples.y` (exact array layout set by gpax — TODO confirm).
samples = gp.sample(x_grid)

In [None]:
# Baseline GP (no custom mean): faint red curves are the GP samples,
# black crosses the data and the dashed line the ground truth.
# NOTE(review): this cell is duplicated later for the piecewise-mean GP;
# consider a shared plotting helper.
fig, ax = plt.subplots(1, 1, figsize=(6, 2))
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.scatter(x, y, marker='x', c='k', zorder=1, label="Noisy observations", alpha=0.7)
# One curve per entry of `samples.y`; `.mean(0)` averages each entry over its
# leading axis (assumed to be the draws for one hyperparameter sample — TODO
# confirm against gpax's sample() output).
for y1 in samples.y:
    ax.plot(x_grid, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)
ax.plot(x_grid, piecewise1(x_grid, PARAMS), c='k', linestyle='--', label='True function', alpha=0.5)
ax.legend(loc='upper left')
# Explicit show, as in the data-plot cell, so the Legend object's repr is not
# echoed as the cell's output.
plt.show()

In [None]:
import jax.numpy as jnp
from gpax.means import Mean
from gpax.utils.prior_utils import Parameter
import numpyro

In [None]:
@define
class PiecewiseMeanPrior(Mean):
    """Piecewise power-law prior mean with a learnable breakpoint and exponents.

    Evaluates ``x ** m_beta1`` where ``x < m_t`` and ``x ** m_beta2`` where
    ``x >= m_t`` -- the same functional form as the ground-truth
    ``piecewise1`` -- with numpyro priors on all three parameters.
    """

    # Priors: breakpoint m_t ~ Uniform(0.5, 2.5); exponents ~ LogNormal(0, 1),
    # i.e. strictly positive. The ``m_`` prefix presumably follows gpax's
    # naming convention for mean-function parameters -- TODO confirm.
    # ``factory=`` builds a fresh Parameter for every instance; a plain
    # ``default=Parameter(...)`` would share one object between all instances.
    m_t = field(factory=lambda: Parameter(numpyro.distributions.Uniform(0.5, 2.5)))
    m_beta1 = field(factory=lambda: Parameter(numpyro.distributions.LogNormal(0, 1)))
    m_beta2 = field(factory=lambda: Parameter(numpyro.distributions.LogNormal(0, 1)))

    def _mean_function(self, x, **params):
        """Evaluate the piecewise power law at ``x``.

        ``params`` must contain ``m_t``, ``m_beta1`` and ``m_beta2``; presumably
        gpax supplies values drawn from the field priors declared above
        (TODO confirm). Points exactly at ``m_t`` take the ``m_beta2`` branch.
        """
        t = params["m_t"]
        return jnp.piecewise(
            x,
            [x < t, x >= t],
            [lambda xs: xs ** params["m_beta1"], lambda xs: xs ** params["m_beta2"]],
        )

In [None]:
# Same Matern-kernel GP, now with the piecewise power-law prior mean. Input and
# output transforms are switched off -- presumably so the mean's parameters
# (breakpoint, exponents) act on raw x/y units; TODO confirm gpax defaults.
gp = ExactGP(
    kernel=MaternKernel(),
    mean=PiecewiseMeanPrior(),
    input_transform=None,
    output_transform=None,
    x=x,
    y=y,
    y_std=None,
    hp_samples=500,
    observation_noise=False,
)
gp.fit()

In [None]:
# Predictions on the grid; `ci` is the [lower, upper] mu -/+ 2*sd band
# (`sd` presumably the predictive standard deviation -- TODO confirm).
mu, sd = gp.predict(x_grid)
ci = [mu - 2 * sd, mu + 2 * sd]

In [None]:
# Draw samples from the piecewise-mean GP on the evaluation grid; the next
# plotting cell iterates over `samples.y` (layout set by gpax — TODO confirm).
samples = gp.sample(x_grid)

In [None]:
# GP with the piecewise power-law prior mean: faint red curves are the GP
# samples, black crosses the data and the dashed line the ground truth.
# NOTE(review): this cell duplicates the earlier baseline sample plot;
# consider a shared plotting helper.
fig, ax = plt.subplots(1, 1, figsize=(6, 2))
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.scatter(x, y, marker='x', c='k', zorder=1, label="Noisy observations", alpha=0.7)
# One curve per entry of `samples.y`; `.mean(0)` averages each entry over its
# leading axis (assumed to be the draws for one hyperparameter sample — TODO
# confirm against gpax's sample() output).
for y1 in samples.y:
    ax.plot(x_grid, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)
ax.plot(x_grid, piecewise1(x_grid, PARAMS), c='k', linestyle='--', label='True function', alpha=0.5)
ax.legend(loc='upper left')
# Explicit show, as in the data-plot cell, so the Legend object's repr is not
# echoed as the cell's output.
plt.show()