In [309]:
!pip install botorch




[notice] A new release of pip is available: 23.3.1 -> 23.3.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [310]:
import time
import torch
import numpy as np
import botorch
from botorch import fit_gpytorch_model
from botorch.models import HeteroskedasticSingleTaskGP
from gpytorch.mlls.sum_marginal_log_likelihood import ExactMarginalLogLikelihood

torch.set_default_dtype(torch.float64)

In [311]:
lower_bounds = [-2, -2]
upper_bounds = [2, 2]
bounds = [lower_bounds, upper_bounds]
bounds = torch.tensor(bounds, dtype=torch.float)


def blackbox_func(x):
    res = x[0]**3 + x[1]**3
    var = torch.rand(1).item()/2 # var in [0.5, 1.5]
    res_noise = np.random.normal(loc=res, scale=np.sqrt(var))
    
    return res_noise, var

blackbox_func([2.0, 2.0])

(16.034727959706622, 0.04083764369885279)

In [312]:
def generate_initial_data(n=10):
    # generate training data
    train_x = torch.rand(n, 2)
    train_y = []
    train_y_var = []

    for i in range(n):
        t_y, t_y_var = blackbox_func(train_x[i])
        train_y.append(t_y)
        train_y_var.append(t_y_var)

    train_y = torch.tensor(train_y).reshape(-1,1)
    train_y_var = torch.tensor(train_y_var).reshape(-1,1)


    return train_x, train_y, train_y_var
    
    
def initialize_model(train_x, train_y, train_y_var, state_dict=None):
    # define models for objective and constraint
    model = HeteroskedasticSingleTaskGP(train_x, train_y, train_y_var)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)

    # load state dict if it is passed
    if state_dict is not None:
        model.load_state_dict(state_dict)
        
    return mll, model

In [324]:
from botorch import fit_gpytorch_model
from botorch.optim import optimize_acqf
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.exceptions import BadInitialCandidatesWarning
from botorch.sampling import IIDNormalSampler, SobolQMCNormalSampler

def get_next_points(train_x, train_y, train_y_var, best_y, bounds, n_points):    
    train_x_mean = torch.mean(train_x, dim=0)
    train_x_std = torch.std(train_x, dim=0)
    train_x = (train_x-train_x_mean)/train_x_std


    train_y_mean = torch.mean(train_y, dim=0)
    train_y_std = torch.std(train_y, dim=0)
    train_y = (train_y-train_y_mean)/train_y_std

    
    mll, model = initialize_model(train_x, train_y, train_y_var)

    fit_gpytorch_model(mll, retain_graph=True)

    sampler = SobolQMCNormalSampler(2048)
    qNEI = qNoisyExpectedImprovement(model, train_x, sampler)

    candidates = optimize_acqf(
        acq_function=qNEI,
        bounds=bounds,
        q=n_points,
        num_restarts=20,
        raw_samples=500
    )

    return candidates

In [325]:
from botorch.exceptions import InputDataWarning
import warnings
warnings.filterwarnings("ignore", category=InputDataWarning)

NUM_ITERATIONS = 100

train_x, train_y, train_y_var = generate_initial_data(n=3)
best_observed = torch.max(train_y)

for iteration in range(1, NUM_ITERATIONS + 1):    
        print("="*30)
        print("Iteration number = ", iteration)
        t0 = time.time()

        
        candidate = get_next_points(train_x, train_y, train_y_var, 0, bounds, 1)
        candidate_x = candidate[0]
        candidate_y, candidate_y_var = blackbox_func(candidate_x[0])

        train_x = torch.cat([train_x, candidate_x])
        train_y = torch.cat([train_y, torch.tensor(candidate_y).reshape(-1,1)])
        train_y_var = torch.cat([train_y_var, torch.tensor(candidate_y_var).reshape(-1,1)])
        best_observed = torch.max(train_y)

        i = torch.argmax(train_y)
        print(f"{train_x[i]=}")
        print(f"{train_y[i]=}")

        t = time.time()
        print(f"Got in {t-t0} seconds")

Iteration number =  1
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 2.372742176055908 seconds
Iteration number =  2
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 1.9167790412902832 seconds
Iteration number =  3
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 2.078418016433716 seconds
Iteration number =  4
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 1.6053411960601807 seconds
Iteration number =  5
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 1.3969309329986572 seconds
Iteration number =  6
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 1.9166312217712402 seconds
Iteration number =  7
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 1.3122286796569824 seconds
Iteration number =  8
train_x[i]=tensor([0.3235, 0.4979])
train_y[i]=tensor([0.8011])
Got in 1.615483045578003 seconds
Iteration number =  9
train_x[i]=tensor([0.

  warn(


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

ValueError: Input arguments must all be instances of numbers.Number, torch.Tensor or objects implementing __torch_function__.

In [343]:

n=10
lower_bounds = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
upper_bounds = [10, 10, 10, 2, 2, 2]
bounds = [lower_bounds, upper_bounds]
bounds = torch.tensor(bounds)


tensor([[ 0.1000,  0.1000,  0.1000,  0.1000,  0.1000,  0.1000],
        [10.0000, 10.0000, 10.0000,  2.0000,  2.0000,  2.0000]]) 
 [[4.74082044 6.20255942 5.4856321  1.33095741 0.56330245 0.75392889]
 [1.80759191 0.73742934 2.4142889  0.15867303 1.93549639 0.99400486]
 [7.28613889 8.22578667 5.26090643 0.65862728 0.82150666 0.74441587]
 [1.62281353 4.39224778 7.81284271 1.20721258 0.78087423 1.0131707 ]
 [2.40532128 0.84965026 3.62264728 0.67921092 1.54616542 0.40263723]
 [7.74759362 9.76143089 9.2648164  1.52301949 0.63265076 1.99992286]
 [5.88876042 5.51805398 2.13171238 1.0111712  0.80893225 1.94361677]
 [9.17569992 1.20464232 1.07626523 0.98658706 1.05661379 0.82653648]
 [5.03644826 0.178317   2.5595831  0.27262388 0.14766582 0.20054643]
 [6.20438433 9.12564433 9.36366279 1.8085673  1.93668059 0.39850698]] 6
