In [134]:
import torch
from botorch import fit_gpytorch_model
from botorch.acquisition import UpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.kernels import RBFKernel, ScaleKernel
import numpy as np
import matplotlib.pyplot as plt

In [135]:
def styblinski_tang(x):
    """Styblinski-Tang test function, reduced over the last dimension.

    Computes 0.5 * sum_i (x_i**4 - 16 * x_i**2 + 5 * x_i).

    Args:
        x: tensor whose last dimension indexes the input coordinates; any
            leading dimensions are treated as batch dimensions.

    Returns:
        Tensor with the last dimension of ``x`` summed out.
    """
    terms = x ** 4 - 16 * x ** 2 + 5 * x
    return 0.5 * terms.sum(dim=-1)

In [136]:
def rosenbrock(x):
    """Rosenbrock function, reduced over the last dimension.

    Computes sum_i [100 * (x_{i+1} - x_i**2)**2 + (1 - x_i)**2] over
    consecutive coordinate pairs of the last dimension.

    Args:
        x: tensor whose last dimension indexes the input coordinates; any
            leading dimensions are treated as batch dimensions.

    Returns:
        Tensor with the last dimension of ``x`` summed out.
    """
    head = x[..., :-1]  # x_i for i = 0 .. d-2
    tail = x[..., 1:]   # x_{i+1}
    valley = 100.0 * (tail - head ** 2) ** 2
    offset = (1 - head) ** 2
    return torch.sum(valley + offset, dim=-1)

In [137]:
def generate_initial_points(n_initial, dim, bounds):
    """Draw random points uniformly inside an axis-aligned box.

    Args:
        n_initial: number of points to draw.
        dim: number of coordinates per point.
        bounds: tensor whose row 0 holds the lower bounds and row 1 the upper
            bounds (shape (2, dim) as built in this notebook).

    Returns:
        (n_initial, dim) tensor of samples, drawn from the global torch RNG.
    """
    lower, upper = bounds[0], bounds[1]
    unit_samples = torch.rand(n_initial, dim)
    return unit_samples * (upper - lower) + lower

In [138]:
def create_model(train_X, train_Y):
    """Build (but do not fit) a SingleTaskGP with a scaled ARD RBF kernel.

    Args:
        train_X: (n, d) tensor of training inputs.
        train_Y: training targets, either (n, m) or a 1-D tensor of length n.
            A 1-D tensor is reshaped to (n, 1) because SingleTaskGP expects
            targets with an explicit output dimension.

    Returns:
        An unfitted SingleTaskGP; its hyperparameters still need to be fit
        (e.g. by maximizing the exact marginal log likelihood).
    """
    if train_Y.dim() == 1:
        train_Y = train_Y.unsqueeze(-1)
    # ard_num_dims gives every input dimension its own lengthscale (ARD), which
    # lets the GP down-weight irrelevant dimensions.
    kernel = ScaleKernel(RBFKernel(ard_num_dims=train_X.shape[-1]))
    return SingleTaskGP(train_X, train_Y, covar_module=kernel)

In [139]:
# Problem configuration. Seeding torch here makes the random initial design
# (generate_initial_points uses torch.rand) reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

dim = 25          # input dimensionality of the test functions
active_dim = 5    # NOTE(review): unused in the visible cells; presumably the number of "effective" dims — confirm
bounds = torch.tensor([[-5.0] * dim, [5.0] * dim])  # row 0: lower bounds, row 1: upper bounds
n_initial = 200   # number of random points in the initial design
n_iter = 100      # NOTE(review): unused in the visible cells; presumably the number of BO iterations

In [140]:
# Uniform random initial design inside the box described by `bounds`.
X = generate_initial_points(n_initial=n_initial, dim=dim, bounds=bounds)

In [141]:
X  # preview of the initial design; torch elides the middle rows/columns

tensor([[-1.4265,  3.6187,  2.8275,  ..., -2.3944,  2.3231, -3.8500],
        [-2.7506, -0.1478, -0.9576,  ...,  2.9441,  4.1308, -1.5695],
        [-0.9605,  3.6784, -4.4435,  ...,  2.3108,  2.1571,  4.5854],
        ...,
        [-1.7726,  0.2586,  1.3661,  ..., -3.7704, -2.5124, -0.3962],
        [-1.5062, -0.3860, -4.7846,  ...,  1.0546, -1.2210,  0.0640],
        [ 4.6166,  4.0198,  2.4448,  ...,  0.2180,  4.1376, -2.1122]])

In [142]:
X.size()  # expected (n_initial, dim)

torch.Size([200, 25])

In [143]:
bounds  # row 0: lower bounds, row 1: upper bounds, one column per dimension

tensor([[-5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5.,
         -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5.],
        [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
          5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.]])

In [144]:
bounds.size()  # expected (2, dim)

torch.Size([2, 25])

In [145]:
Y = styblinski_tang(x=X)  # Styblinski-Tang value for each initial point

In [146]:
Y2 = rosenbrock(x=X)  # Rosenbrock value for each initial point

In [147]:
Y  # prints all n_initial values; consider slicing (e.g. Y[:10]) to keep the output short

tensor([-2.7195e+02, -8.1539e+01, -3.1334e+01, -2.4960e+02, -1.7238e+02,
        -3.6939e+02, -1.7872e+02, -4.2487e+01, -2.1778e+02, -2.7685e+02,
        -9.6267e+01,  5.6655e+01,  5.9745e+01, -2.7247e+02,  1.2927e+01,
         3.8532e+01, -1.7629e+02, -2.0639e+02, -3.3354e+02, -2.9766e+02,
         2.6744e+01, -1.2718e+02, -1.3038e+02, -2.1177e+02, -7.3960e+01,
        -1.3889e+02, -9.6251e+01,  1.4689e+02,  1.0804e+02, -2.2394e+02,
        -9.5190e+01,  4.5726e+01, -1.4050e+02, -3.8173e+02, -9.6370e+01,
        -3.4558e+02, -1.1655e+02, -2.8305e+02, -3.2747e+00, -2.0060e+02,
        -1.4964e+02,  4.5501e+01, -1.9626e+02, -1.4623e+02,  9.8681e+01,
         1.8886e+02, -3.3398e+02, -2.7807e+02, -1.8587e+02, -1.0392e+02,
        -1.1030e+02, -1.8506e+02, -8.0112e+00, -2.6396e+02, -2.8902e+02,
         1.1684e+01, -1.2868e+02,  5.3657e+01, -2.5322e+02, -5.8316e+01,
         1.1291e+02,  7.4419e+01, -1.4484e+02, -1.4255e+02, -1.1948e+02,
         8.7495e+01,  3.0754e+02,  1.1937e+02, -3.1

In [148]:
Y2  # prints all n_initial values; consider slicing (e.g. Y2[:10]) to keep the output short

tensor([288442.5938, 291583.5000, 433717.5938, 281725.0312, 337762.3125,
        244112.7344, 346070.3125, 340325.5312, 380272.7500, 242050.9219,
        335612.8438, 355750.6250, 379318.4688, 179369.1562, 278667.2812,
        290055.8438, 406161.0312, 299486.0938, 170728.0938, 176036.7656,
        307158.9375, 238489.9688, 359647.4375, 321689.1250, 319110.5000,
        244025.3438, 302604.9062, 520015.9062, 446038.5625, 446300.1250,
        294985.0625, 529476.6250, 353563.6562, 192454.8438, 378377.2188,
        214549.3125, 361257.5312, 253727.4219, 257486.2344, 347371.1250,
        241830.4062, 486493.3438, 249011.5000, 357671.5000, 369168.1562,
        372039.3125, 186339.6094, 195782.3125, 378082.6250, 278261.9688,
        251298.0781, 221293.1719, 338563.5938, 269071.8750, 210728.6094,
        364686.7500, 370006.9062, 546665.2500, 277252.0000, 281446.5938,
        439795.2500, 322120.0000, 347244.0938, 254032.0469, 195712.9531,
        386384.6250, 434994.2188, 349442.6875, 1956

In [149]:
Y.size()  # 1-D: one objective value per initial point

torch.Size([200])

In [150]:
Y2.size()  # 1-D: one objective value per initial point

torch.Size([200])

In [151]:
# Smallest Styblinski-Tang value in the initial design, as a Python float.
# NOTE(review): BoTorch acquisition functions maximize by default; if this is a
# minimization study, downstream cells presumably need to negate Y — verify.
best_f = torch.min(Y).item()

In [152]:
best_f  # smallest objective value observed in the initial design

-457.0860290527344

In [153]:
torch.argmin(Y)  # index of the best (lowest) initial point

tensor(184)

In [154]:
X[torch.argmin(Y)]  # coordinates of the best (lowest) initial point

tensor([-3.4853e+00, -6.0002e-01,  1.0805e-03,  1.2545e+00, -8.4942e-01,
        -2.0111e+00, -1.5566e+00,  2.2363e+00,  3.0701e+00,  2.6347e+00,
        -4.4847e-01,  5.0135e-01, -2.4653e+00, -2.7062e+00,  2.3795e+00,
        -3.2034e+00,  3.2251e+00, -3.8593e-01,  1.9421e+00, -2.6673e+00,
        -1.5340e+00, -1.9709e+00, -3.3149e+00, -4.6228e+00, -3.6175e+00])