Refactor multiobjective setup
PiperOrigin-RevId: 651135860
xingyousong authored and Copybara-Service committed Jul 10, 2024
1 parent 3bc353a commit 47aa1c8
Showing 1 changed file with 56 additions and 75 deletions.
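
For orientation, a minimal usage sketch of what this refactor means for callers. This is a sketch only: it assumes the attrs-generated constructor exposes the new private fields _num_scalarizations, _ref_scaling, and _num_ehvi_samples under the argument names shown below, and the search space and metric names are hypothetical.

# Sketch, not code from this commit. Assumes the attrs-generated constructor
# maps _num_scalarizations, _ref_scaling, and _num_ehvi_samples to the init
# arguments below; the problem setup and metric names are hypothetical.
from vizier import pyvizier as vz
from vizier._src.algorithms.designers import gp_bandit

problem = vz.ProblemStatement()
problem.search_space.root.add_float_param('x', 0.0, 1.0)
for name in ('obj1', 'obj2'):  # Two objectives trigger the multi-objective path.
  problem.metric_information.append(
      vz.MetricInformation(name=name, goal=vz.ObjectiveMetricGoal.MAXIMIZE)
  )

# Previously the multi-objective knobs were keyword arguments of from_problem
# (num_scalarizations, reference_scaling, num_samples); now they are designer
# fields, so direct construction also picks up the multi-objective acquisition
# configured in __attrs_post_init__.
designer = gp_bandit.VizierGPBandit(
    problem,
    num_scalarizations=1000,  # number of sampled scalarization weight vectors
    ref_scaling=0.01,         # scaling for the inferred reference point
    num_ehvi_samples=64,      # if set, sampled EHVI replaces scalarized UCB
)
suggestions = designer.suggest(count=2)

The single-objective path is unchanged: with one objective metric, the overrides in __attrs_post_init__ are skipped.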
131 changes: 56 additions & 75 deletions vizier/_src/algorithms/designers/gp_bandit.py
@@ -108,8 +108,8 @@ class VizierGPBandit(vza.Designer, vza.Predictor):

_problem: vz.ProblemStatement = attr.field(kw_only=False)
_acquisition_optimizer_factory: vb.VectorizedOptimizerFactory = attr.field(
default=default_acquisition_optimizer_factory,
kw_only=True,
factory=lambda: default_acquisition_optimizer_factory,
)
_ard_optimizer: optimizers.Optimizer[types.ParameterDict] = attr.field(
factory=optimizers.default_optimizer,
@@ -123,7 +123,7 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
# still tunes its amplitude. Only used for single-objective.
_linear_coef: Optional[float] = attr.field(default=None, kw_only=True)
_scoring_function_factory: acq_lib.ScoringFunctionFactory = attr.field(
factory=lambda: default_scoring_function_factory,
default=default_scoring_function_factory,
kw_only=True,
)
_scoring_function_is_parallel: bool = attr.field(default=False, kw_only=True)
@@ -144,6 +144,10 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
factory=output_warpers.create_default_warper, kw_only=True
)

# Multi-objective parameters.
_num_scalarizations: int = attr.field(default=1000, kw_only=True)
_ref_scaling: float = attr.field(default=0.01, kw_only=True)
_num_ehvi_samples: Optional[int] = attr.field(default=None, kw_only=True)
# ------------------------------------------------------------------
# Internal attributes which should not be set by callers.
# ------------------------------------------------------------------
@@ -195,6 +199,55 @@ def __attrs_post_init__(self):
),
)

# Multi-objective overrides.
m_info = self._problem.metric_information
if not m_info.is_single_objective:
num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE))

# Create scalarization weights.
self._rng, weights_rng = jax.random.split(self._rng)
weights = jax.random.normal(
weights_rng, shape=(self._num_scalarizations, num_obj)
)
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

if self._num_ehvi_samples: # Sampled EHVI.
reduction_fn = lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1])
acquisition_fn = acq_lib.Sample(self._num_ehvi_samples)
else: # Scalarized UCB.
reduction_fn = lambda x: jnp.mean(x, axis=0)
acquisition_fn = acq_lib.UCB()

def acq_fn_factory(
data: types.ModelData,
) -> acq_lib.AcquisitionFunction:
# Scalarize the acquisition over the sampled weight directions.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, self._ref_scaling)
if has_labels
else None,
)

max_scalarized = (
jnp.max(scalarizer(labels_array), axis=-1) if has_labels else None
)
return acq_lib.ScalarizedAcquisition(
acquisition_fn,
scalarizer,
reduction_fn=reduction_fn,
max_scalarized=max_scalarized,
)

self._scoring_function_factory = (
acq_lib.bayesian_scoring_function_factory(acq_fn_factory)
)
self._scoring_function_is_parallel = True
self._use_trust_region = False

# Additional validations
coroutine = gp_models.get_vizier_gp_coroutine(empty_data)
params = sp.CoroutineWithData(coroutine, empty_data).setup(self._rng)
@@ -580,79 +633,7 @@ def from_problem(
cls,
problem: vz.ProblemStatement,
seed: Optional[int] = None,
*, # Below are multi-objective options for acquisition function.
num_scalarizations: int = 1000,
reference_scaling: float = 0.01,
num_samples: int | None = None,
**kwargs,
) -> 'VizierGPBandit':
rng = jax.random.PRNGKey(seed or 0)
if problem.is_single_objective:
return cls(problem, rng=rng, **kwargs)

# Multi-objective.
num_obj = len(problem.metric_information.of_type(vz.MetricType.OBJECTIVE))
rng, weights_rng = jax.random.split(rng)
weights = jnp.abs(
jax.random.normal(weights_rng, shape=(num_scalarizations, num_obj))
)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

if num_samples is None:

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Scalarized UCB.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, reference_scaling)
if has_labels
else None,
)

max_scalarized = (
jnp.max(scalarizer(labels_array), axis=-1) if has_labels else None
)
return acq_lib.ScalarizedAcquisition(
acq_lib.UCB(),
scalarizer,
reduction_fn=lambda x: jnp.mean(x, axis=0),
max_scalarized=max_scalarized,
)

else:

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Sampled EHVI.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, reference_scaling)
if has_labels
else None,
)

max_scalarized = (
jnp.max(scalarizer(labels_array), axis=-1) if has_labels else None
)
return acq_lib.ScalarizedAcquisition(
acq_lib.Sample(num_samples),
scalarizer,
# We need to reduce across the scalarization and sample axes.
reduction_fn=lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1]),
max_scalarized=max_scalarized,
)

scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
acq_fn_factory
)
return cls(
problem,
scoring_function_factory=scoring_function_factory,
scoring_function_is_parallel=True,
use_trust_region=False,
rng=rng,
**kwargs,
)
return cls(problem, rng=rng, **kwargs)
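
The heart of the moved logic is random scalarization: directions are drawn over the non-negative orthant and a hypervolume scalarization scores objective vectors against an inferred reference point. The standalone sketch below mirrors the weight sampling in the diff above; hv_scalarize is a simplified, hypothetical stand-in for scalarization.HyperVolumeScalarization and does not reproduce its exact behavior.

# Standalone sketch, assuming only JAX. The weight construction mirrors the
# diff above: absolute values of Gaussian draws, normalized to unit length,
# i.e. random directions over the non-negative orthant.
import jax
import jax.numpy as jnp

num_scalarizations, num_obj = 1000, 2
rng = jax.random.PRNGKey(0)
weights = jnp.abs(jax.random.normal(rng, shape=(num_scalarizations, num_obj)))
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)


def hv_scalarize(labels: jnp.ndarray, ref: jnp.ndarray) -> jnp.ndarray:
  """Simplified hypervolume-style scalarization (hypothetical helper).

  For each weight vector w, scores a point y by min_i (y_i - ref_i) / w_i.
  The library's HyperVolumeScalarization is the authoritative implementation.
  """
  # labels: [num_points, num_obj] -> result: [num_scalarizations, num_points].
  return jnp.min(
      (labels[jnp.newaxis, :, :] - ref) / weights[:, jnp.newaxis, :], axis=-1
  )


labels = jnp.array([[1.0, 2.0], [2.0, 1.0], [1.5, 1.5]])
reference_point = jnp.zeros(num_obj)
scores = hv_scalarize(labels, reference_point)  # shape (1000, 3)
max_scalarized = jnp.max(scores, axis=-1)       # best seen value per direction

Averaging acquisition values over these directions is what the reduction_fn in the scalarized-UCB branch does (jnp.mean over the scalarization axis); the sampled-EHVI branch additionally averages a relu over posterior samples, as shown in the diff.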
