In [1]:
import numpy as np
import pandas as pd
import os
import imageio
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import norm
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter

from bayesian_opt import BO

In [2]:
def one_dim_function_simple(burnin):
    '''
    Concave quadratic test objective in one variable.

    Evaluates -((burnin - 12000) / 10000) ** 2. The value is 0 at
    burnin == 12000 and negative everywhere else, so the maximum is
    attained at 12000.
    '''
    scaled_offset = (burnin - 12000) / 10000
    return -(scaled_offset ** 2)


def one_dim_function_complex(burnin):
    '''
    Multi-modal 1D test objective: a Gaussian bump plus a sine ripple.

    With u = (burnin - 12000) / 1000 the value is
    norm.pdf(u) + sin(2 * u) / 5.

    The Gaussian bump is centred at 12000, but the sine ripple has a
    positive slope there, so the overall maximum is NOT at 12000: it lies
    somewhat to the right (roughly burnin ~ 12500).
    '''
    centered = burnin - 12000
    gaussian_bump = norm.pdf(centered / 1000)
    sine_ripple = np.sin(centered / 500) / 5
    return gaussian_bump + sine_ripple


def two_dim_function_simple(burnin, learning_rate):
    '''
    Concave quadratic test objective in two variables.

    Sum of two negated squares:
    -((burnin - 12000) / 10000) ** 2 and -(learning_rate - 1.25) ** 2.
    The value is 0 at (burnin, learning_rate) == (12000, 1.25) and
    negative elsewhere, so the maximum is attained at 12000, 1.25.
    '''
    burnin_term = -((burnin - 12000) / 10000) ** 2
    learning_rate_term = -(learning_rate - 1.25) ** 2
    return burnin_term + learning_rate_term


def two_dim_function_complex(burnin, learning_rate):
    '''
    Multi-modal 2D test objective.

    Both inputs are first rescaled:
        u = (burnin - 12000) / 1000
        v = (learning_rate - 1) * 2 - 1.5
    and the value is (1 - (u**2 + v**3)) * exp(-(u**2 + v**2) / 2).

    The global maximum lies at burnin = 12000 with learning_rate
    approximately 0.98 (value about 1.42). At (12000, 1.5) the value is
    only about 0.99, so that point is not the maximum.
    '''
    u = (burnin - 12000) / 1000
    v = (learning_rate - 1) * 2 - 1.5

    polynomial_part = 1 - (u**2 + v**3)
    gaussian_envelope = np.exp(-(u**2 + v**2) / 2)

    return polynomial_part * gaussian_envelope

In [None]:
def generate_images_1D(path_to_images,data,loss_predicted,sigma,search_space,iteration):
    '''
    Plot one frame of the 1D optimisation trace and save it as a PNG.

    The frame shows the true function (one_dim_function_complex), all
    collected measurements, the most recent measurement and, from the
    second iteration on, the predicted function with a +/- sigma band.

    Parameters
    ----------
    path_to_images : str
        Prefix for the output file; the frame is written to
        path_to_images + str(iteration) + '.png'. The directory is created
        if it does not exist.
    data : pandas.DataFrame
        Must provide the columns 'burnin_period' and 'loss'; the last row
        is drawn as the new location.
    loss_predicted, sigma : array-like
        Predicted mean and standard deviation over the search grid. Only
        used when iteration > 0.
        NOTE(review): for iteration > 1 these are assumed to be in
        standardised units and are mapped back with the mean/std of
        data['loss'] -- confirm against bayesian_opt.BO.
    search_space : dict
        search_space['burnin_period'][0] is unpacked into np.linspace,
        i.e. (start, stop, num_points).
    iteration : int
        Index of the current optimisation step; also the frame file name.
    '''
    os.makedirs(path_to_images, exist_ok=True)

    mean_loss = data['loss'].mean()
    std_loss = data['loss'].std()

    # Plot and save the fit
    x = np.linspace(*search_space['burnin_period'][0])
    fig, ax = plt.subplots()
    ax.plot(x, one_dim_function_complex(x), 'b', label="True function")
    ax.plot(data['burnin_period'].values, data['loss'].values, '*r', label="Collected measurements")
    ax.plot(data['burnin_period'].values[-1], data['loss'].values[-1], 'og', label="New location")

    # Use the `iteration` argument, not a notebook-level loop variable, so the
    # function does not depend on hidden global state.
    if iteration > 0:
        if iteration > 1:
            loss_predicted = loss_predicted*std_loss + mean_loss
            # A standard deviation is only scaled by the de-standardisation;
            # adding the mean would shift the width of the uncertainty band.
            sigma = sigma*std_loss

        ax.plot(x, loss_predicted, 'r', label="Predicted function")
        ax.fill_between(x, loss_predicted - sigma,
                        loss_predicted + sigma,
                        alpha=0.5, color='c', label="Standard deviation")
    ax.set_ylim([-.5, .8])
    ax.grid()
    ax.set_ylabel('Loss', fontsize=16)
    ax.set_xlabel('Burn-in period', fontsize=16)
    fig.tight_layout()
    ax.legend(loc='upper left')
    fig.savefig(path_to_images+str(iteration)+'.png', bbox_inches='tight', dpi=300)
    # One figure is created per iteration; close it so figures do not
    # accumulate in memory while the optimisation loop runs.
    plt.close(fig)
    

def generate_images_2D(path_to_images,data,loss_predicted,sigma,search_space,iteration):
    '''
    Plot one frame of the 2D optimisation trace and save it as a PNG.

    Left panel: 3D scatter of the collected measurements (latest one in
    red) and, from the second iteration on, the predicted surface.
    Right panel: filled contour plot of the predicted surface with the
    sampled locations; on the first iteration an all-zero placeholder is
    drawn instead.

    Parameters
    ----------
    path_to_images : str
        Prefix for the output file; the frame is written to
        path_to_images + str(iteration) + '.png'. The directory is created
        if it does not exist.
    data : pandas.DataFrame
        Must provide the columns 'burnin_period', 'learning_rate' and
        'loss'; the last row is drawn as the new location.
    loss_predicted : array-like
        Predicted mean on the search grid. Only used when iteration > 0.
        NOTE(review): must have the same shape as the
        (burnin_period x learning_rate) meshgrid built with indexing='ij',
        and for iteration > 1 is assumed to be in standardised units --
        confirm against bayesian_opt.BO.
    sigma : array-like
        Accepted for signature parity with generate_images_1D; not used.
    search_space : dict
        search_space[name][0] is unpacked into np.linspace, i.e.
        (start, stop, num_points), for 'burnin_period' and 'learning_rate'.
    iteration : int
        Index of the current optimisation step; also the frame file name.
    '''
    os.makedirs(path_to_images, exist_ok=True)

    mean_loss = data['loss'].mean()
    std_loss = data['loss'].std()

    burnin_bounds = search_space['burnin_period'][0]
    learning_rate_bounds = search_space['learning_rate'][0]

    # Evaluation grid; indexing='ij' puts burn-in period on the first axis.
    burnin_period = np.linspace(*burnin_bounds)
    learning_rate = np.linspace(*learning_rate_bounds)
    X, Y = np.meshgrid(burnin_period, learning_rate, indexing='ij')

    # Use the `iteration` argument, not a notebook-level loop variable, so the
    # function does not depend on hidden global state.
    has_prediction = iteration > 0
    if iteration > 1:
        # Small epsilon kept from the original rescaling of the prediction.
        loss_predicted = loss_predicted*(std_loss+1e-6) + mean_loss

    fig = plt.figure()

    # --- Left panel: 3D view ---
    ax_3d = fig.add_subplot(1, 2, 1, projection='3d')
    ax_3d.set_zlim(-.5, 1.5)
    ax_3d.set_ylim(learning_rate_bounds[0], learning_rate_bounds[1])
    ax_3d.set_xlim(burnin_bounds[0], burnin_bounds[1])
    ax_3d.set_xlabel('Burn-in period')
    ax_3d.set_ylabel('Learning rate')
    ax_3d.set_zlabel('Loss')

    ax_3d.scatter(data['burnin_period'].values, data['learning_rate'].values,
                  data['loss'].values, color='b', linewidth=0.5)
    ax_3d.scatter(data['burnin_period'].values[-1], data['learning_rate'].values[-1],
                  data['loss'].values[-1], color='r', linewidth=0.5)

    if has_prediction:
        ax_3d.plot_surface(X, Y, loss_predicted, cmap=cm.coolwarm, linewidth=0, antialiased=False)

    # --- Right panel: contour view ---
    ax_2d = fig.add_subplot(1, 2, 2)
    if has_prediction:
        ax_2d.contour(X, Y, loss_predicted)
        ax_2d.contourf(X, Y, loss_predicted, cmap='viridis')
        ax_2d.plot(data['burnin_period'].values, data['learning_rate'].values, 'ko')
        ax_2d.plot(data['burnin_period'].values[-1], data['learning_rate'].values[-1], 'ro')
    else:
        # No prediction yet: draw a flat placeholder so every frame has the same layout.
        ax_2d.contourf(X, Y, np.zeros((np.shape(X))), cmap='viridis')
    # Label the axes on every frame, including the first one.
    ax_2d.set_xlabel('Burn-in period')
    ax_2d.set_ylabel('Learning rate')

    ax_2d.set_aspect(1.0/ax_2d.get_data_ratio(), adjustable='box')
    fig.tight_layout(pad=6.0)
    fig.savefig(path_to_images+str(iteration)+'.png', bbox_inches='tight', dpi=300)
    # One figure is created per iteration; close it so figures do not
    # accumulate in memory while the optimisation loop runs.
    plt.close(fig)



    
def generate_gif(path_to_images,number_of_iterations):
    '''
    Assemble the saved per-iteration frames into an animated GIF.

    Reads path_to_images + '<k>.png' for k = 0 .. number_of_iterations - 1
    (in that order) and writes path_to_images + 'movie.gif'.
    '''
    frame_paths = [path_to_images+str(frame_index)+'.png'
                   for frame_index in range(0, number_of_iterations)]
    frames = [imageio.imread(frame_path) for frame_path in frame_paths]
    gif_path = path_to_images+'movie.gif'
    imageio.mimsave(gif_path, frames, duration = 1) # modify duration as needed




# 1D Example

In [None]:
   
number_of_iterations = 15

# Dictionary:
# Key: parameter name
# Entry: tuple with (i) a list for the search space interval (start_val, end_val, num_points) and (ii) type of parameter
search_space = {'burnin_period':([8000, 16000, 100],int)}
maximize=True
list_of_parameters_names = ['burnin_period']


generate_images_flag = True
path_to_images = 'figures_1D/'
generate_gif_flag = True

bayes_optimizer = BO(search_space, list_of_parameters_names, maximize)


for ii in range(0, number_of_iterations):
    print("Iteration: ", ii)
    
    # Fit and get the next point to sample at
    next_parameter_values, loss_predicted, sigma, expected_improvement = bayes_optimizer.bayes_opt()
    
    # Generate the prediction from the true function
    res = one_dim_function_complex(next_parameter_values['burnin_period'])
    
    bayes_optimizer.update_loss(res)
    

    # Rebuild the trace table from the optimizer's dictionary (one column per key).
    data_trace = pd.DataFrame.from_dict(bayes_optimizer.parameters_and_loss_dict,orient='index').transpose()  
    if generate_images_flag:
        # Pass the iteration index explicitly: it is a required argument and
        # names the frame file that generate_gif reads back later.
        generate_images_1D(path_to_images,data_trace,loss_predicted,sigma,search_space,ii)
    
if generate_gif_flag:
    generate_gif(path_to_images,number_of_iterations)

    
print(data_trace)

# 2D Example

In [None]:

number_of_iterations = 15


# Dictionary:
# Key: parameter name
# Entry: tuple with (i) a list for the search space interval (start_val, end_val, num_points) and (ii) type of parameter
search_space = {'burnin_period':([8000, 16000, 100],int),
                'learning_rate':([0.5, 2.0, 100],float)}
maximize=True
list_of_parameters_names = ['burnin_period','learning_rate']



generate_images_flag = True
path_to_images = 'figures_2D/'
generate_gif_flag = True

bayes_optimizer = BO(search_space, list_of_parameters_names, maximize)


for ii in range(0, number_of_iterations):
    print("Iteration: ", ii)
    
    # Fit and get the next point to sample at
    next_parameter_values, loss_predicted, sigma, expected_improvement = bayes_optimizer.bayes_opt()
    
    # Generate the prediction from the true function.
    # Evaluate the 2D objective with BOTH parameters; using the 1D function
    # here would make the loss independent of the learning rate.
    res = two_dim_function_complex(next_parameter_values['burnin_period'],
                                   next_parameter_values['learning_rate'])
    
    
    bayes_optimizer.update_loss(res)
    
    # Rebuild the trace table from the optimizer's dictionary (one column per key).
    data_trace = pd.DataFrame.from_dict(bayes_optimizer.parameters_and_loss_dict,orient='index').transpose()
    
    if generate_images_flag:
          
        generate_images_2D(path_to_images,data_trace,loss_predicted,sigma,search_space,ii)
    
if generate_gif_flag:
    generate_gif(path_to_images,number_of_iterations)

print(data_trace)


Iteration:  0
Iteration:  1
Iteration:  2
Iteration:  3
Iteration:  4
Iteration:  5
Iteration:  6
Iteration:  7
Iteration:  8
Iteration:  9


ABNORMAL_TERMINATION_IN_LNSRCH.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  _check_optimize_result("lbfgs", opt_res)


Iteration:  10


ABNORMAL_TERMINATION_IN_LNSRCH.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  _check_optimize_result("lbfgs", opt_res)


Iteration:  11


ABNORMAL_TERMINATION_IN_LNSRCH.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  _check_optimize_result("lbfgs", opt_res)


Iteration:  12
