In [4]:
from time import time

import torch
from botorch.utils.multi_objective import infer_reference_point
from botorch.utils.transforms import unnormalize
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.multi_objective.box_decompositions import NondominatedPartitioning

from robust_mobo.input_transform import InputPerturbation
from robust_mobo.multi_objective_risk_measures import MVaR
from robust_mobo.experiment_utils import get_perturbations, get_problem

import matplotlib.pyplot as plt
%matplotlib inline

# Problem under study, looked up by name through get_problem.
func = get_problem("gmm4")
# Toggle: True evaluates the MVaR objective on perturbed inputs,
# False evaluates the raw function outputs with no perturbation.
use_mvar: bool = True

# Keyword arguments handed to get_perturbations (left empty here).
tkwargs = dict()
# Number of perturbation samples; the same value is given to
# get_perturbations and to MVaR so the two stay in sync.
n_w = 32
perturbations = get_perturbations(
    n_w=n_w,
    dim=func.dim,
    bounds=func.bounds,
    method="sobol-normal",
    std_dev=0.05,
    tkwargs=tkwargs,
)

if not use_mvar:
    # Pass-through stand-ins so downstream code can always call
    # obj(func(perturb(x))) regardless of the toggle.
    def perturb(x):
        return x

    def obj(x):
        return x

else:
    obj = MVaR(n_w=n_w, alpha=0.8)
    # .eval() puts the input transform in evaluation mode before use.
    perturb = InputPerturbation(perturbation_set=perturbations).eval()

In [5]:
def get_ref(num_samples: int):
    """Estimate a reference point and hypervolume from random samples.

    Draws ``num_samples`` random designs, maps them into ``func.bounds``,
    evaluates ``obj(func(perturb(x)))`` on them, keeps only the
    non-dominated outcomes, infers a reference point from those, and
    computes the hypervolume they dominate with respect to that point.
    When there are exactly two objectives the non-dominated outcomes are
    also shown as a scatter plot.

    Reads the notebook-level globals ``func``, ``obj``, ``perturb`` and
    ``tkwargs``.

    Args:
        num_samples: Number of random designs to draw.

    Returns:
        None. The elapsed time, the hypervolume and the reference point
        are printed.
    """
    start = time()
    # Shape (num_samples, 1, dim): one design per batch entry.
    test_x = unnormalize(torch.rand(num_samples, 1, func.dim, **tkwargs), func.bounds)
    # Flatten everything to 2-D with one column per objective.
    y = obj(func(perturb(test_x))).view(-1, func.num_objectives)
    mask = is_non_dominated(y)
    pareto_y = y[mask]
    ref_pt = infer_reference_point(pareto_y)
    hv = NondominatedPartitioning(ref_pt, pareto_y).compute_hypervolume()
    if pareto_y.shape[-1] == 2:
        # Matplotlib converts its inputs to NumPy, which fails for tensors
        # that live on an accelerator or that track gradients; hand it a
        # detached host-side array instead.
        plot_y = pareto_y.detach().cpu().numpy()
        plt.scatter(plot_y[:, 0], plot_y[:, 1])
        plt.show()
    print("time ", time() - start)
    print("hv", hv)
    print(f"ref pt {ref_pt}")


In [6]:
# Run the estimate with 10000 random samples. This is expensive: the saved
# output below reports roughly 1h 26min of wall time for this call.
%time get_ref(10000)

time  5137.311141014099
hv tensor(0.0290)
ref pt tensor([ 0.0322, -0.0398,  0.1168, -0.0023])
CPU times: user 3h 46min 27s, sys: 5min 45s, total: 3h 52min 13s
Wall time: 1h 25min 37s
