In [None]:
from __future__ import print_function, division, absolute_import
  
import GPy
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import ipywidgets as widgets # Widget definitions
from IPython import display # Used to display widgets in the notebook

# Import safeopt from system, alternatively add main folder to path
try:
    import safeopt
except ImportError:
    import sys
    import os
    module_path = os.path.abspath('..')
    sys.path.append(module_path)
    import safeopt

## Define a kernel and function

Here we define a kernel. The function is drawn at random from the GP and is corrupted by Gaussian noise.

In [None]:
# Variance of the Gaussian measurement noise (standard deviation 0.05)
noise_var = 0.05 ** 2

# Gaussian likelihood using that noise variance
likelihood = GPy.likelihoods.gaussian.Gaussian(variance=noise_var)

# Bounds on the input variables, one (lower, upper) pair per column of X
bounds = [
    (-5., 5.),
    (-5., 5.),
]

# Product kernel over the two input columns:
# `k_parameters` acts on the first column of X (index 0) and
# `k_context` on the second column (index 1). The name 'context' is
# what makes the second factor reachable as `kern.context` later on.
k_parameters = GPy.kern.RBF(
    input_dim=1,
    variance=2.,
    lengthscale=1.0,
    active_dims=[0],
)
k_context = GPy.kern.RBF(
    input_dim=1,
    variance=2.,
    lengthscale=1.0,
    active_dims=[1],
    name='context',
)
kernel = k_parameters * k_context

# Number of context columns and the set of candidate parameters,
# the latter built from the first input bound only
num_contexts = 1
parameter_set = safeopt.linearly_spaced_combinations(bounds[:1], 1000)

# Initial safe point
x0 = np.array([[0]])
# Generate function with safe initial point at x=0
def sample_safe_fun(context=0):
    """Sample functions from the GP until one is safe at the initial point.

    A candidate is accepted once its noise-free value at `x0`, combined
    with the given context, exceeds 0.5.

    Parameters
    ----------
    context : scalar or array-like
        Context value(s) appended to `x0`; promoted to a 2-D array.

    Returns
    -------
    The accepted function as returned by `safeopt.sample_gp_function`.
    """
    # The evaluation point does not change between attempts
    start_point = np.hstack([x0, np.atleast_2d(context)])

    while True:
        # Joint function over parameters and contexts
        candidate = safeopt.sample_gp_function(kernel.copy(), bounds, noise_var, 10)
        if candidate(start_point, noise=False) > 0.5:
            return candidate

## Interactive run of the algorithm

In [None]:
# Buttons driving the optimization
button_add = widgets.Button(
    description='Add sample',
    tooltip='Sample a new data point for the optimization',
)
button_reset = widgets.Button(
    description='Reset',
    tooltip='Restart the safe optimization algorithm',
)
button_new = widgets.Button(
    description='New function',
    tooltip='Add a new function and reset',
)

# Sliders for the safety threshold and the active context
slider_safety = widgets.FloatSlider(
    description='Safety threshold', value=0, min=-4, max=4, step=0.1)
slider_context = widgets.FloatSlider(
    description='Context', value=0, min=-1, max=1, step=0.05)

# Plot styling: marker size, marker edge width and line width
ms, mew, lw = 20, 7, 5

# Currently selected context, stored as a 1x1 array
context = np.array([[0]])

# Parameter set extended by zero-filled context column(s)
inp = np.hstack([parameter_set, np.zeros((len(parameter_set), num_contexts))])

def plot_gp():
    """Redraw the GP, the last sample, the safety threshold and the true function.

    Reads the globals `gp_opt`, `context`, `true_values`, `bounds` and the
    plot-style constants `ms`, `mew`, `lw`.
    """
    context_kernel = gp_opt.gp.kern.context

    # Make points transparent when they belong to a different context:
    # the normalized context-kernel value between the current context and
    # each data point is mapped to an opacity with a floor of 0.25.
    current_point = np.hstack([[[0]], context])
    opacity = context_kernel.K(current_point, gp_opt.gp.X)
    opacity /= context_kernel.variance
    opacity = np.exp(100 * opacity) / np.exp(100)
    np.maximum(opacity, 0.25, out=opacity)

    # RGBA colors: black points whose alpha channel carries the opacity
    n_points = gp_opt.gp.X.shape[0]
    point_color = np.zeros((n_points, 4))
    point_color[:, 3] = opacity

    # Plot GP
    figure = plt.figure(figsize=(15, 8))
    gp_opt.plot(figure=figure, lw=lw, point_color=point_color)

    # Plot last point red
    plt.plot(gp_opt.gp.X[-1, 0], gp_opt.gp.Y[-1, :],
             'rx', ms=ms, mew=mew, label='Last Point')

    # Plot safe line across the parameter range
    plt.plot(list(bounds[0]), [gp_opt.fmin[0]] * 2,
             'k--', label='Minimum', lw=lw)

    # Plot the noise-free function for the current context
    plt.plot(*true_values, alpha=0.15, label='True function', lw=lw)

    plt.ylim([-4, 4])
    plt.legend(loc=3)

    # Ensure we only get one plot
    display.clear_output(wait=True)

def new_sample(b=None):
    """Ask the optimizer for the next point, evaluate it and add it to the GP.

    `b` is the argument supplied by the button click callback; it is unused.
    """
    next_x = gp_opt.optimize(context=context)
    measurement = fun(np.hstack([next_x, gp_opt.context]))
    gp_opt.add_new_data_point(next_x, measurement, context)
    plot_gp()


button_add.on_click(new_sample)

def reset(b=None):
    """Reset the SafeOpt algorithm.

    Rebuilds the GP from a single evaluation of `fun` at the initial point
    `x0` in the current context, stores a fresh `SafeOpt` instance in the
    global `gp_opt` and redraws the plot.

    `b` is the argument supplied by the button click callback; it is unused.
    """
    global gp_opt
    x = np.hstack([x0, context])
    gp = GPy.core.GP(x, fun(x), kernel, likelihood)
    # Take the safety threshold from the slider instead of hard-coding 0,
    # so the optimizer and the UI still agree after a reset. The number of
    # contexts comes from the module-level setting for the same reason.
    gp_opt = safeopt.SafeOpt(gp, parameter_set, slider_safety.value,
                             num_contexts=num_contexts, threshold=0.5)
    gp_opt.context = context
    plot_gp()
button_reset.on_click(reset)

def new_fun(b=None):
    """Draw a new function from the GP and restart the optimization.

    Stores the sampled function in the global `fun`, refreshes the cached
    true-function values and then resets the algorithm, forwarding `b`.
    """
    global fun

    fun = sample_safe_fun(context)
    get_true_fun()
    reset(b)


button_new.on_click(new_fun)

def get_true_fun():
    """Recompute the noise-free function values for the current context.

    Overwrites the context column(s) of the global `inp` with `context` and
    stores `(parameter_set, values)` in the global `true_values`.
    """
    global true_values

    # The trailing `num_contexts` columns of `inp` hold the context
    inp[:, -num_contexts:] = context
    noise_free_values = fun(inp, noise=False)
    true_values = (parameter_set, noise_free_values)

def change_safety(b=None):
    """Apply the safety-threshold slider value to the optimizer and redraw.

    `b` is either the change dictionary passed by `observe` or, for
    backward compatibility with `on_trait_change`, the trait name as a
    string. Anything other than a change of the `value` trait is ignored.
    """
    trait_name = b.get('name') if isinstance(b, dict) else b
    if trait_name == 'value':
        gp_opt.fmin[0] = slider_safety.value
        plot_gp()
# `on_trait_change` is deprecated since traitlets 4.1; `observe` replaces it
slider_safety.observe(change_safety, names='value')

def change_context(b=None):
    """Switch to the context selected on the slider and redraw.

    Updates the global `context`, recomputes the true function values for
    it and passes the new context on to the optimizer.

    `b` is either the change dictionary passed by `observe` or, for
    backward compatibility with `on_trait_change`, the trait name as a
    string. Anything other than a change of the `value` trait is ignored.
    """
    global context
    trait_name = b.get('name') if isinstance(b, dict) else b
    if trait_name == 'value':
        context = np.array([[slider_context.value]])
        get_true_fun()
        gp_opt.context = context
        plot_gp()
# `on_trait_change` is deprecated since traitlets 4.1; `observe` replaces it
slider_context.observe(change_context, names='value')

# Show the controls in the order: buttons first, then sliders
display.display(
    button_add,
    button_new,
    button_reset,
    slider_safety,
    slider_context,
)

# Sample an initial function, which also resets the optimizer and plots it
new_fun()