In [None]:
import warnings
from itertools import cycle
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as fnn
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.acquisition.multi_objective.monte_carlo import (
    qExpectedHypervolumeImprovement, qNoisyExpectedHypervolumeImprovement)
from botorch.acquisition.multi_objective.objective import \
    IdentityMCMultiOutputObjective
from botorch.models import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.optim import optimize_acqf
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.test_functions.multi_objective import (
        DTLZ2,
        DTLZ7,
        ZDT3,
        BraninCurrin,
        GMM
)
from botorch.utils.multi_objective.box_decompositions import \
    NondominatedPartitioning
from botorch.utils.multi_objective.hypervolume import (
    Hypervolume, infer_reference_point
)
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from scipy.stats.qmc import LatinHypercube
from sklearn.decomposition import PCA
from torch import Size

from vsd.acquisition import VariationalPreferenceAcquisition
from vsd.condproposals import (ConditionalGaussianProposal,
                               PreferenceSearchDistribution)
from vsd.cpe import PreferenceContinuousCPE, fit_cpe_labels, make_constrastive_alignment_data
from vsd.generation import generate_candidates_iw
from vsd.preferences import EmpiricalPreferences, UnitNormal, MixtureUnitNormal
from vsd.proposals import fit_ml
from vsd.utils import is_non_dominated_strict
from vsd.labellers import ParetoAnnealed

# Silence all warnings so the (already verbose) training logs stay readable.
# NOTE(review): this is a blanket filter for the whole kernel; it also hides
# library numerical / input-scaling warnings that may matter for the GP
# baselines. Consider narrowing it to specific categories or modules.
warnings.filterwarnings("ignore")

In [None]:
# --- Reproducibility: seed torch, numpy's legacy global RNG and `random` ---
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# --- Experiment configuration ---
N_INIT = 64          # size of the initial Latin-hypercube design
BATCH_SIZE = 5       # candidates proposed per round
GP_BOUNDS = 15       # GP baselines search the box [-GP_BOUNDS, GP_BOUNDS]^D
N_ITER = 10          # optimisation rounds per run
N_REPS = 10          # repetitions in the multi-run comparison
USE_EMPIRICAL_PREFERENCES = False
SAMPLES = 128        # MC sample count / raw samples for the acquisition optimisers
# Test problem: one of the keys of `obj_mappings` below
# ("ZDT3", "BraninCurrin", "DTLZ2", "DTLZ2-5", "DTLZ7", "GMM").
FUNCTION = "DTLZ2-5"

# Per-problem settings:
#   f           -- BoTorch test-problem class
#   args        -- keyword arguments for its constructor
#   yscale      -- per-objective divisor applied to outputs (None = no scaling)
#   sigmoid     -- squash inputs through a sigmoid before evaluation
#   start_scale -- half-width of the box used for the initial design
obj_mappings = {
    "DTLZ2": dict(f=DTLZ2, args=dict(dim=3, num_objectives=2, negate=True), yscale=None, sigmoid=True, start_scale=3),
    "DTLZ2-5": dict(f=DTLZ2, args=dict(dim=6, num_objectives=5, negate=True), yscale=None, sigmoid=True, start_scale=2),
    "DTLZ7": dict(f=DTLZ7, args=dict(dim=7, num_objectives=6, negate=True), yscale=None, sigmoid=True, start_scale=2),
    "BraninCurrin": dict(f=BraninCurrin, args=dict(negate=False), yscale=[200., 4.], sigmoid=False, start_scale=2.5),
    "ZDT3": dict(f=ZDT3, args=dict(dim=6, num_objectives=2, negate=True), yscale=None, sigmoid=True, start_scale=2.),
    "GMM": dict(f=GMM, args=dict(num_objectives=2, negate=True), yscale=None, sigmoid=True, start_scale=2),
}

# --- Plot styling shared by every figure in the notebook ---
_PLOT_STYLE = {
    "font.size": 12,
    "axes.labelsize": 14,
    "axes.titlesize": 16,
    "legend.fontsize": 10,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
}
plt.rcParams.update(_PLOT_STYLE)

In [None]:
class BBWrapper(obj_mappings[FUNCTION]["f"]):
    """Black-box objective: the configured BoTorch test problem, with optional
    sigmoid squashing of the inputs and per-objective rescaling of the outputs.

    The problem configuration is snapshotted once, when the class is defined
    (the same moment the base class is chosen). A later change of ``FUNCTION``
    in the config cell therefore cannot mix one problem's base class with
    another problem's ``sigmoid`` flag or constructor args. Re-run this cell
    after changing ``FUNCTION``.
    """

    # Evaluated at class-definition time, together with the base class above.
    _cfg = obj_mappings[FUNCTION]

    def __init__(self):
        super().__init__(**self._cfg["args"])
        yscale = self._cfg["yscale"]
        # Divisor applied to every output row; None disables rescaling.
        self.yscale = None if yscale is None else torch.tensor(yscale, dtype=torch.float)

    def __call__(self, X):
        """Evaluate the problem at ``X`` and return float32 objective values.

        When the config asks for it, ``X`` is passed through a sigmoid first,
        mapping unbounded search inputs into (0, 1).
        """
        if self._cfg["sigmoid"]:
            X = torch.sigmoid(X)
        y = super().__call__(X).float()
        return y if self.yscale is None else y / self.yscale


bb = BBWrapper()
# Input dimensionality of the black box (fallback 2 if the problem has no `dim`).
D = getattr(bb, "dim", 2)

In [None]:
# Initial design: a Latin hypercube over [-start_scale, start_scale]^D.
# scipy's QMC engines draw from their own numpy Generator (a freshly,
# OS-seeded one when no seed is given), so the global np.random.seed(SEED)
# call does NOT make this design reproducible -- pass SEED explicitly.
start_half_width = obj_mappings[FUNCTION]["start_scale"]
lhs_unit = LatinHypercube(d=D, seed=SEED).random(n=N_INIT)  # points in [0, 1)^D
X_init = torch.tensor(lhs_unit * (2 * start_half_width) - start_half_width).float()
y_init = bb(X_init)
M = y_init.shape[1]  # number of objectives
# Hypervolume reference point, inferred from the initial observations only.
ref_point = infer_reference_point(y_init)

In [None]:
def sanitize_inputs(X, Y):
    """Return ``(X, Y)`` restricted to the rows whose targets contain no NaN.

    A row is dropped from both tensors if any entry of that row of ``Y`` is NaN.
    """
    row_has_nan = torch.isnan(Y).any(dim=1)
    keep = torch.logical_not(row_has_nan)
    return X[keep], Y[keep]

def fit_gpytorch_model(mll, num_iter=500, lr=1e-3):
    """Fit the model wrapped by ``mll`` by maximising its marginal log likelihood with Adam.

    Parameters
    ----------
    mll : marginal-log-likelihood module
        Exposes ``mll.model`` (with ``train_inputs`` / ``train_targets``) and is
        callable as ``mll(output, targets)``, returning the quantity to maximise.
    num_iter : int
        Number of Adam steps.
    lr : float
        Adam learning rate (default matches the previous hard-coded value).

    The modules are left in eval mode on return. Nothing is returned.

    No manual clamping is applied to the parameters. GPyTorch stores
    hyperparameters as unconstrained ``raw_*`` tensors and maps them through
    their registered constraints, so positivity is already guaranteed. Clamping
    the raw values would instead impose large spurious floors on lengthscale
    and noise, and forbid negative values for parameters such as a constant mean.
    """
    mll.train()
    optimizer = torch.optim.Adam(mll.parameters(), lr=lr)
    model = mll.model

    for _ in range(num_iter):
        optimizer.zero_grad()
        output = model(*model.train_inputs)
        loss = -mll(output, model.train_targets)
        loss.backward()
        optimizer.step()

    mll.eval()


def fit_gp_moo_model(X, Y):
    """Fit one independent ``SingleTaskGP`` per objective column of ``Y``.

    The per-objective GPs are combined into a ``ModelListGP`` and fitted jointly
    on the sum of their marginal log likelihoods. The fitted model is returned.
    """
    per_objective = []
    for col in range(Y.shape[-1]):
        per_objective.append(SingleTaskGP(X, Y[:, col:col + 1]))
    model = ModelListGP(*per_objective)
    fit_gpytorch_model(SumMarginalLogLikelihood(model.likelihood, model))
    return model


def fit_gp_sca_model(train_X, train_Y, alpha=0.05):
    """Fit a GP to a random ParEGO (augmented Chebyshev) scalarisation of ``train_Y``.

    Used by the "qNParEGO" baseline. Each call:

    1. draws weights uniformly from the probability simplex (Dirichlet(1));
    2. min-max normalises every objective over the observed data to [0, 1]
       (objectives are maximised, so 1 is best);
    3. scores each point by the negated augmented Chebyshev distance to the
       ideal point: ``-(max_i w_i d_i + alpha * sum_i w_i d_i)`` with
       ``d_i = 1 - y_norm_i``.

    Unlike a linear weighted sum, this scalarisation can reach points on
    non-convex parts of the Pareto front, and the normalisation stops the
    largest-scale objective from dominating.

    Parameters
    ----------
    train_X, train_Y : tensors of inputs (n x D) and objective values (n x M).
    alpha : float
        Weight of the augmentation (linear) term; 0.05 is the usual ParEGO value.

    Returns the fitted single-output ``SingleTaskGP``.
    """
    n_obj = train_Y.shape[-1]
    weights = torch.distributions.Dirichlet(torch.ones(n_obj)).sample()

    y_min = train_Y.min(dim=0).values
    # Guard against objectives that are constant over the observed data.
    y_range = (train_Y.max(dim=0).values - y_min).clamp_min(1e-12)
    y_norm = (train_Y - y_min) / y_range

    # Weighted per-objective distance to the ideal point (0 = best).
    dist = weights * (1.0 - y_norm)
    scalar_Y = -(
        dist.max(dim=-1, keepdim=True).values
        + alpha * dist.sum(dim=-1, keepdim=True)
    )

    scalar_gp = SingleTaskGP(train_X, scalar_Y)
    mll = ExactMarginalLogLikelihood(scalar_gp.likelihood, scalar_gp)
    fit_gpytorch_model(mll)
    return scalar_gp

In [None]:
def _acqf_bounds():
    """Return the 2 x D box ``[-GP_BOUNDS, GP_BOUNDS]^D`` searched by every GP baseline."""
    return torch.stack([-GP_BOUNDS * torch.ones(D), GP_BOUNDS * torch.ones(D)])


def _optimize_batch(acq_func, num_restarts=10):
    """Jointly optimise a batch of ``BATCH_SIZE`` candidates for ``acq_func``.

    Shared by all baseline generators so they use identical optimiser settings.
    Returns only the candidate tensor; the acquisition value is discarded.
    """
    candidates, _ = optimize_acqf(
        acq_function=acq_func,
        bounds=_acqf_bounds(),
        q=BATCH_SIZE,
        num_restarts=num_restarts,
        raw_samples=SAMPLES,
    )
    return candidates


def _identity_objective(n_outcomes):
    """Multi-output MC objective that keeps all ``n_outcomes`` model outputs as-is."""
    return IdentityMCMultiOutputObjective(outcomes=list(range(n_outcomes)))


def generate_candidates_qehvi(model, train_X, train_Y):
    """Propose a batch with qEHVI, using a box partitioning of ``train_Y`` w.r.t. ``ref_point``.

    ``train_X`` is accepted for signature parity with the other generators but is unused.
    """
    sampler = SobolQMCNormalSampler(sample_shape=torch.Size([SAMPLES]))
    partitioning = NondominatedPartitioning(ref_point=ref_point, Y=train_Y)
    acq_func = qExpectedHypervolumeImprovement(
        model=model,
        ref_point=ref_point.tolist(),
        partitioning=partitioning,
        sampler=sampler,
        objective=_identity_objective(train_Y.shape[-1]),
    )
    return _optimize_batch(acq_func)


def generate_candidates_qnehvi(model, train_X, train_Y):
    """Propose a batch with qNEHVI, using the observed ``train_X`` as the baseline set."""
    sampler = SobolQMCNormalSampler(sample_shape=torch.Size([SAMPLES]))
    acq_func = qNoisyExpectedHypervolumeImprovement(
        model=model,
        X_baseline=train_X,
        ref_point=ref_point.tolist(),
        sampler=sampler,
        objective=_identity_objective(train_Y.shape[-1]),
    )
    return _optimize_batch(acq_func)


def generate_candidates_nparego(model, train_X, train_Y):
    """Propose a batch with qNEI on a single-output (scalarised) GP.

    ``train_Y`` is accepted for signature parity with the other generators but is unused.
    """
    acq_func = qNoisyExpectedImprovement(
        model=model,
        X_baseline=train_X,
        sampler=SobolQMCNormalSampler(sample_shape=torch.Size([SAMPLES])),
    )
    return _optimize_batch(acq_func)

In [None]:
def vsd_callback(i, loss, grads):
    """Every 100 iterations, print the loss and the average of the per-tensor gradient means.

    ``grads`` is a sequence of gradient tensors, where parameters without a
    gradient appear as ``None``. Only real gradients enter the average, so
    ``None`` entries no longer drag it toward zero. When no gradient is
    available, ``nan`` is reported instead of dividing by zero.
    """
    if i % 100 != 0:
        return
    grad_means = [g.detach().mean() for g in grads if g is not None]
    gmean = sum(grad_means) / len(grad_means) if grad_means else float("nan")
    print(f"  {i}: loss = {loss:.3f}, mean grad = {gmean:.3f}")

def callback(i, loss, _):
    """Progress hook: print the loss on every 100th iteration (third argument ignored)."""
    if i % 100:
        return
    print(f"  {i}: loss = {loss:.3f}")

class AGPS:
    """A-GPS optimiser state.

    Holds the preference distribution over directions ``u``, two class
    probability estimators (a Pareto CPE and a preference-alignment CPE), the
    preference-conditional proposal, and the variational acquisition that ties
    them together.

    The modules are built in a fixed order. ``torch.rand`` draws from the
    global torch RNG, and the other constructors presumably do as well
    (unverified), so reordering them would change results.
    """

    def __init__(self):
        if USE_EMPIRICAL_PREFERENCES:
            self.preferences = EmpiricalPreferences()
        else:
            # Five random component locations in [0, 1)^M.
            mixture_locs = torch.rand((5, M))
            self.preferences = MixtureUnitNormal(locs=mixture_locs)
        # Both CPEs share the same architecture.
        cpe_kwargs = dict(x_dim=D, u_dims=M, latent_dim=32, dropoutp=0.2, hidden_layers=2)
        self.pareto_cpe = PreferenceContinuousCPE(**cpe_kwargs)
        self.preference_cpe = PreferenceContinuousCPE(**cpe_kwargs)
        # Broad Gaussian prior over designs (precision 0.01 -> variance 100 per dim).
        prior = torch.distributions.MultivariateNormal(
            loc=torch.zeros((1, D)),
            precision_matrix=torch.eye(D) * 0.01,
        )
        cproposal = ConditionalGaussianProposal(
            x_dims=D, u_dims=M, latent_dim=64, hidden_layers=4, bias=True
        )
        self.proposal = PreferenceSearchDistribution(
            cproposal=cproposal, preference=self.preferences
        )
        self.acq = VariationalPreferenceAcquisition(
            pareto_model=self.pareto_cpe,
            pref_model=self.preference_cpe,
            prior_dist=prior,
        )
        self.labeller = ParetoAnnealed(percentile=0.75, T=N_ITER)

    def fit(self, X, y, round):
        """Refit every component on ``(X, y)`` and sample a new batch.

        ``round`` is the 0-based optimisation round. From round 1 on, the
        preference model is fitted only to the directions of points the
        labeller marks positive (``z == 1``); in round 0 all points are used.
        Returns ``self.proposal.sample`` of size ``BATCH_SIZE``.
        """
        print("Running A-GPS...")
        z = torch.tensor(self.labeller(y), dtype=torch.float)

        # Unit directions from the reference point to each observation.
        U = fnn.normalize(y - ref_point, p=2, dim=1)
        # Augment dataset with misalignments
        Xa, Ua, za = make_constrastive_alignment_data(X, U)
        U_pref = U[z == 1, :] if round > 0 else U

        if USE_EMPIRICAL_PREFERENCES:
            self.preferences.set_preferences(U_pref)
        else:
            print("Fitting preferences.")
            fit_ml(
                self.preferences,
                U_pref,
                optimizer_options=dict(lr=1e-3, weight_decay=1e-8),
                stop_options=dict(n_window=500, maxiter=10000),
                callback=callback,
            )

        # The same option dicts are shared by both CPE fits.
        cpe_opt_options = dict(lr=1e-3, weight_decay=1e-6)
        cpe_stop_options = dict(n_window=1000, maxiter=10000)
        cpe_jobs = (
            ("Fitting Pareto CPE.", self.pareto_cpe, X, z, U),
            ("Fitting Alignment CPE.", self.preference_cpe, Xa, za, Ua),
        )
        for banner, cpe, X_fit, z_fit, U_fit in cpe_jobs:
            print(banner)
            fit_cpe_labels(
                cpe,
                X_fit,
                z_fit,
                U_fit,
                optimizer_options=cpe_opt_options,
                stop_options=cpe_stop_options,
                callback=callback,
                batch_size=32,
            )

        print("Fitting AGPS.")
        generate_candidates_iw(
            self.acq,
            self.proposal,
            optimizer_options=dict(lr=1e-3),
            stop_options=dict(n_window=4000, maxiter=10000),
            gradient_samples=SAMPLES,
            callback=vsd_callback,
        )
        return self.proposal.sample(torch.Size([BATCH_SIZE]))

In [None]:
# Single A-GPS run: starting from the shared initial design, each round
# proposes a batch, evaluates it on the black box, and records the hypervolume
# of the non-dominated set observed so far.
X_vsd = X_init.clone()
y_vsd = y_init.clone()
hv_vsd = []
hv_comp = Hypervolume(ref_point=ref_point)
agps = AGPS()
for t in range(N_ITER):
    Xc, _ = agps.fit(X_vsd, y_vsd, round=t)  # second returned value is unused here
    yc = bb(Xc)
    X_vsd = torch.cat((X_vsd, Xc))
    y_vsd = torch.cat((y_vsd, yc))
    pareto_front = y_vsd[is_non_dominated_strict(y_vsd)]
    hv_vsd.append(hv_comp.compute(pareto_front))

In [None]:
# Baselines: one run of each BoTorch method from the same initial design.
# The hypervolume-based methods are only attempted for few objectives.
if M <= 4:
    methods = ["qEHVI", "qNEHVI", "qNParEGO"]
elif M <= 5:
    methods = ["qNEHVI", "qNParEGO"]
else:
    methods = ["qNParEGO"]
results = {}

# method -> (model fitter, candidate generator); any other name uses ParEGO.
_BASELINE_STEPS = {
    "qEHVI": (fit_gp_moo_model, generate_candidates_qehvi),
    "qNEHVI": (fit_gp_moo_model, generate_candidates_qnehvi),
}
_PAREGO_STEP = (fit_gp_sca_model, generate_candidates_nparego)

for method in methods:
    fit_model, propose = _BASELINE_STEPS.get(method, _PAREGO_STEP)
    X, Y = X_init.clone(), y_init.clone()
    hvs = []
    hv_comp = Hypervolume(ref_point=ref_point)
    for i in range(N_ITER):
        print(f"Training {method}: {i} ... ", end="")
        model = fit_model(X, Y)
        Xc = propose(model, X, Y)
        yc = bb(Xc)
        X = torch.cat([X, Xc], dim=0)
        Y = torch.cat([Y, yc], dim=0)
        hvs.append(hv_comp.compute(Y[is_non_dominated_strict(Y)]))
        print("Done!")
    results[method] = hvs

# Hypervolume per round: A-GPS against every baseline.
fig, ax = plt.subplots(dpi=150, figsize=(8, 5))
ax.plot(hv_vsd, label="A-GPS", marker='o')
for method, hvs in results.items():
    ax.plot(hvs, label=method, marker='x')
ax.set_xlabel("Iteration")
ax.set_ylabel("Hypervolume")
ax.set_title(f"{FUNCTION}: Hypervolume Comparison")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Demonstrate conditioning: query the learned conditional proposal with three
# hand-picked preference directions (toward each end of the observed Pareto
# front and toward its centre) and plot where the resulting designs land.
np.set_printoptions(suppress=True, precision=2)  # print preference vectors with 2 decimals
condsamples = 100  # designs drawn per preference direction

# Non-dominated observations relative to the reference point. Boolean indexing
# returns a copy, so the in-place shift leaves y_vsd untouched.
yplt = y_vsd[is_non_dominated_strict(y_vsd)]
yplt -= ref_point

if M > 2:
    # Project the objectives to 2-D for plotting.
    pca = PCA(n_components=2)
    yplt = torch.tensor(pca.fit_transform(yplt.numpy()))
    pref = pca.transform(ref_point.numpy()[np.newaxis, :]).flatten()
else:
    pref = ref_point

# Preference targets in the (projected, ref-point-relative) objective plane.
uf1 = torch.tensor([torch.quantile(yplt[:, 0], q=0.9), torch.quantile(yplt[:, 1], q=0.1)])
uf2 = torch.tensor([yplt[:, 0].mean(), yplt[:, 1].mean()])
uf3 = torch.tensor([torch.quantile(yplt[:, 0], q=0.1), torch.quantile(yplt[:, 1], q=0.9)])

start = (pref[0].item(), pref[1].item())
cols = plt.cm.inferno([0.1, 0.5, 0.9])
plt.figure(dpi=150, figsize=(8, 5))
plt.plot(*start, "ks", label="Reference point")
for uf, marker, col in zip([uf1, uf2, uf3], ["x", '+', '.'], cols):
    if M > 2:
        # Lift the 2-D target back to the M-dim ref-point-relative space.
        # sklearn expects a 2-D array; cast to float32 to match the proposal.
        uf = torch.tensor(
            pca.inverse_transform(uf.numpy()[np.newaxis, :])[0], dtype=torch.float
        )
    uf = fnn.normalize(uf, p=2, dim=-1)
    ufs = torch.tile(uf, (condsamples, 1))
    xcand, _ = agps.proposal.cproposal(ufs)
    # Detach in case the proposal returns samples that track gradients;
    # .numpy() and matplotlib cannot handle those.
    ycand = bb(xcand.detach())
    if M > 2:
        ycand = torch.tensor(pca.transform(ycand.numpy()))
    plt.plot(*ycand.T, marker, label=f"u = {uf.numpy()}", c=col, markersize=10,
             alpha=0.7, mew=2)
    if M > 2:
        # uf is a direction, not a point: map it with the PCA loadings only.
        # pca.transform would also subtract the data mean and skew the arrow.
        arrow = uf.numpy() @ pca.components_.T
    else:
        arrow = uf.numpy()
    dx, dy = float(arrow[0]) / 2, float(arrow[1]) / 2
    end = (start[0] + dx, start[1] + dy)
    plt.annotate(
        "",  # no text
        xy=end,  # arrow head at 'end'
        xytext=start,  # arrow tail at 'start'
        arrowprops={
            "arrowstyle": "->",
            "color": col,
            "lw": 2,
            "shrinkA": 0,
            "shrinkB": 0,
        },
    )
plt.title(FUNCTION)
plt.xlabel("$f_1$")
plt.ylabel("$f_2$")
plt.legend()
plt.grid()
plt.tight_layout()  # after titles/labels so they are included in the layout
plt.show()

In [None]:
# Repeated comparison: N_REPS runs of A-GPS and of every baseline. Each run's
# trace is the hypervolume after every round divided by the initial design's
# hypervolume, so every trace starts at 1.
# NOTE(review): every repetition starts from the same X_init, so the spread
# across runs reflects only randomness inside model fitting / candidate
# generation, not the initial design.
# NOTE(review): hv_init is assumed > 0; a zero initial hypervolume would divide by zero.
results = {m: [] for m in methods + ["A-GPS"]}

# Run A-GPS multiple times
for run in range(N_REPS):
    X_vsd, y_vsd = X_init.clone(), y_init.clone()
    hv_vsd = []
    hv_comp = Hypervolume(ref_point=ref_point)
    agps = AGPS()  # fresh model state for each repetition
    hv_init = hv_comp.compute(y_vsd[is_non_dominated_strict(y_vsd)])
    hv_vsd.append(1.)  # relative hypervolume of the initial design
    for t in range(N_ITER):
        Xc, _ = agps.fit(X_vsd, y_vsd, t)
        yc = bb(Xc)
        X_vsd = torch.cat([X_vsd, Xc], dim=0)
        y_vsd = torch.cat([y_vsd, yc], dim=0)
        hv = hv_comp.compute(y_vsd[is_non_dominated_strict(y_vsd)]) / hv_init
        hv_vsd.append(hv)
    results["A-GPS"].append(hv_vsd)

# Run baselines
for method in methods:
    for run in range(N_REPS):
        X, Y, hvs = X_init.clone(), y_init.clone(), []
        hv_comp = Hypervolume(ref_point=ref_point)
        hv_init = hv_comp.compute(Y[is_non_dominated_strict(Y)])
        hvs.append(1.)  # relative hypervolume of the initial design
        for i in range(N_ITER):
            print(f"Training {method}: {i} ... ", end="")
            # Same method dispatch as the single-run baseline cell: any method
            # other than qEHVI/qNEHVI falls through to the scalarised ParEGO path.
            if method == "qEHVI":
                moo_model = fit_gp_moo_model(X, Y)
                Xc = generate_candidates_qehvi(moo_model, X, Y)
            elif method == "qNEHVI":
                moo_model = fit_gp_moo_model(X, Y)
                Xc = generate_candidates_qnehvi(moo_model, X, Y)
            else:
                sca_model = fit_gp_sca_model(X, Y)
                Xc = generate_candidates_nparego(sca_model, X, Y)
            yc = bb(Xc)
            X = torch.cat([X, Xc], dim=0)
            Y = torch.cat([Y, yc], dim=0)
            print("Done!")
            hv = hv_comp.compute(Y[is_non_dominated_strict(Y)]) / hv_init
            hvs.append(hv)
        results[method].append(hvs)

In [None]:
# Aggregate comparison: per-round mean +/- one (population) standard deviation
# of the relative hypervolume across the N_REPS repetitions of each method.
# The colour/linestyle cycle is sized to the number of plotted methods and is
# applied only to this figure (rc_context), rather than changed kernel-wide.
n_series = max(len(results), 1)
LINECYCLE = cycle(
    [
        "-",
        "--",
        ":",
        "-.",
        (0, (3, 1, 1, 1)),
        (0, (5, 1)),
        (0, (3, 1, 1, 1, 1, 1)),
        (0, (1, 1)),
    ]
)
cycler = plt.cycler(
        color=plt.cm.viridis(np.linspace(0.05, 0.95, n_series))
    ) + plt.cycler(linestyle=[next(LINECYCLE) for _ in range(n_series)])

with plt.rc_context({"axes.prop_cycle": cycler}):
    plt.figure(dpi=150, figsize=(8, 5))

    for method, hvs in results.items():
        hvs_array = np.array(hvs)  # (n_runs, n_rounds + 1)
        mean = hvs_array.mean(axis=0)
        std = hvs_array.std(axis=0)
        lower = mean - std
        upper = mean + std
        rounds = np.arange(len(mean))  # round 0 is the initial design
        print(f"{method} :")
        for i, (m, s) in enumerate(zip(mean, std)):
            print(f"  {i}: {m:.3f} ({s:.3f})")

        plt.plot(
            rounds,
            mean,
            label=method,
            marker='o' if method == 'A-GPS' else 'x',
            markersize=10,
            linewidth=3,
        )
        plt.fill_between(rounds, lower, upper, alpha=0.2)

    plt.xlabel("Round")
    plt.ylabel("Relative Hyper-volume")
    plt.title(FUNCTION)
    plt.legend()
    plt.grid()
    plt.show()