In [1]:
import matplotlib.pyplot as plt
import numpy as np

In [2]:
from IPython.display import display
import GPy

In [3]:
# Seed NumPy's legacy global RNG so every sampled point below is reproducible
np.random.seed(seed=123)

# Initialise data

In [4]:
# Graph parameters
lb = -5
ub = 5
y_min = -3
y_max = 3

In [5]:
def initFirstPoint():
    """Return the sampling state used before any training point exists.

    Returns
    -------
    tuple
        ``(firstPoint, x_new, mu_new, s_new)``: a flag that the GP model has
        not been built yet, the x-location of the first sample, and the mean
        and standard deviation of the normal distribution it is drawn from.
    """
    # NOTE(review): s_new = 1 does not match the RBF variance (0.5) used in
    # initGP; confirm whether the first draw should come from the GP prior.
    return True, 0, 0, 1

In [6]:
def initData():
    """Create empty training arrays and the prior over a fixed test grid.

    Returns
    -------
    X_train, y_train : ndarray, shape (0, 1)
        Empty column arrays that training points are appended to.
    X_test : ndarray, shape (100, 1)
        Evenly spaced grid over [lb, ub] used for plotting predictions.
    mu_test, s_test : ndarray, shape (100, 1)
        Zero mean and unit standard deviation shown before any data exist.
    """
    n_test = 100

    # Empty (0, 1) float columns so np.append/reshape in later cells keep 2-D shape
    X_train = np.empty((0, 1))
    y_train = np.empty((0, 1))

    X_test = np.linspace(lb, ub, n_test)[:, np.newaxis]
    mu_test = np.zeros((n_test, 1))
    # NOTE(review): unit std does not match the RBF variance (0.5) in initGP — confirm intent
    s_test = np.sqrt(np.ones((n_test, 1)))
    return X_train, y_train, X_test, mu_test, s_test

# Initialise GP model (RBF kernel + Gaussian-noise regression)

In [7]:
def initGP(X_train, y_train, variance=0.5, lengthscale=1.0, noise_var=np.exp(-6)):
    """Build a GPy regression model with a squared-exponential (RBF) kernel.

    Parameters
    ----------
    X_train, y_train : ndarray
        Training inputs and targets; callers in this notebook pass (n, 1) columns.
    variance : float, default 0.5
        Signal variance of the RBF kernel (previously the local ``sn``).
    lengthscale : float, default 1.0
        RBF lengthscale (previously ``ell``).
    noise_var : float, default exp(-6)
        Gaussian likelihood noise variance (previously ``sy``); kept tiny so
        the posterior nearly interpolates the training points.

    Returns
    -------
    GPy.models.GPRegression
        Model with fixed (not optimised) hyperparameters.
    """
    kernel = GPy.kern.RBF(input_dim=1, variance=variance, lengthscale=lengthscale)
    return GPy.models.GPRegression(X_train, y_train, kernel, noise_var=noise_var)

In [8]:
def addTrainingPoint(X_train, y_train, x_new, y_new):
    """Append one observation and return the training data as (n, 1) columns.

    All inputs are flattened first, so ``x_new``/``y_new`` may be scalars or
    arrays of any shape holding a single value.
    """
    xs = np.concatenate((np.ravel(X_train), np.ravel(x_new)))
    ys = np.concatenate((np.ravel(y_train), np.ravel(y_new)))
    return xs[:, None], ys[:, None]

In [9]:
def addConfidenceBounds(ax, x, mu, s):
    """Shade the band mu ± 2·s on ``ax`` as a translucent grey fill.

    Inputs are squeezed, so (n, 1) column arrays are accepted.
    """
    half_width = 2 * s
    lower = np.squeeze(mu - half_width)
    upper = np.squeeze(mu + half_width)
    ax.fill_between(np.squeeze(x), lower, upper, facecolor='grey', alpha=0.2)

def removeConfidenceBounds(ax):
    """Remove every collection (e.g. fill_between bands) from ``ax``.

    Iterates over a snapshot (``list(...)``) because removing from the
    sequence being iterated would skip every other element. Each artist is
    removed via ``Artist.remove()``: mutating ``ax.collections`` directly is
    deprecated since matplotlib 3.5, where it became a read-only ArtistList.
    """
    for coll in list(ax.collections):
        coll.remove()

In [10]:
import matplotlib.animation
# NOTE(review): "html5" renders via a movie writer (ffmpeg); "jshtml" avoids that dependency
plt.rcParams["animation.html"] = "html5"

# Reset data
firstPoint, x_new, mu_new, s_new = initFirstPoint()
X_train, y_train, X_test, mu_test, s_test = initData()

# Initialise plot (explicit OO interface on `ax` throughout)
fig, ax = plt.subplots(1, 1, figsize=(8, 6), dpi=100)
ax.set_xlim(lb, ub)
ax.set_ylim(y_min, y_max)
ax.set_title('Sampling from GP kernel')
ax.set_xlabel('x')
ax.set_ylabel('y')

line_post, = ax.plot(X_test, mu_test, label="posterior mean")
# axvline's ymin/ymax are axes fractions in [0, 1], not data coordinates;
# the defaults (0, 1) already span the full height, so they are omitted.
line_next = ax.axvline(x_new, color='r', linestyle="dashed", label="next sample location")
pts_train, = ax.plot(X_train, y_train, 'rx', label="training points")
addConfidenceBounds(ax, X_test, mu_test, s_test)
# The labels above are only visible once a legend is created
ax.legend(loc="upper left")

def addAndRetrain(b=None):
    """Observe a value at ``x_new``, refit the GP, choose the next x and redraw.

    ``b`` is unused; presumably it lets this double as a widget button
    callback — TODO confirm. All state lives in module-level globals.
    """
    global X_train, y_train, mu_test, s_test, x_new, mu_new, s_new, firstPoint, m

    # Draw the observation from the current predictive distribution at x_new
    y_obs = np.random.normal(mu_new, s_new)
    X_train, y_train = addTrainingPoint(X_train, y_train, x_new, y_obs)

    # The GP model can only be constructed once there is at least one point
    if firstPoint:
        firstPoint = False
        m = initGP(X_train, y_train)

    m.set_XY(X_train, y_train)

    # Predictive mean / std over the plotting grid
    mu_test, var_test = m.predict(X_test)
    s_test = np.sqrt(var_test)

    # Pick the next location uniformly in [lb, ub] and cache its predictive distribution
    x_new = np.random.uniform(low=lb, high=ub, size=1)
    mu_new, var_new = m.predict(x_new.reshape(-1, 1))
    s_new = np.sqrt(var_new)

    updatePlot()
    
def resetPlot(b=None):
    """Discard all observations, restore the initial state and redraw.

    ``b`` is unused (presumably a widget-callback argument — TODO confirm).
    The existing model ``m`` is left in place; setting ``firstPoint`` back to
    True makes the next addAndRetrain() build a fresh one.
    """
    global X_train, y_train, mu_test, s_test, x_new, mu_new, s_new, firstPoint, m
    # initData always rebuilds the same test grid, so its X_test is discarded
    X_train, y_train, _, mu_test, s_test = initData()
    firstPoint, x_new, mu_new, s_new = initFirstPoint()
    updatePlot()
    
def updatePlot():
    """Push the current module-level state into the plot artists and redraw."""
    line_post.set_data(X_test, mu_test)
    pts_train.set_data(X_train, y_train)

    # x_new is the scalar 0 after a reset or a size-1 array after sampling.
    # Line2D.set_xdata needs a sequence (scalars are deprecated since
    # matplotlib 3.7), so give both endpoints of the vertical line explicitly.
    x_pos = float(np.ravel(x_new)[0])
    line_next.set_xdata([x_pos, x_pos])

    # Replace the old confidence band with one for the current posterior
    removeConfidenceBounds(ax)
    addConfidenceBounds(ax, X_test, mu_test, s_test)
    fig.canvas.draw()

def animate(i):
    """FuncAnimation frame callback: frame 0 resets, every other frame adds one point."""
    step = resetPlot if i == 0 else addAndRetrain
    step()

# Frame 0 resets the state; frames 1-9 each add one sampled point (interval is in ms)
ani = matplotlib.animation.FuncAnimation(
    fig,
    animate,
    frames=10,
    interval=1000,
)

In [11]:
# Leaving the animation as the cell's last expression renders it using
# rcParams["animation.html"] (set to "html5" in the plotting cell)
ani