Vectorize the scalarization method instead of producing multiple class instances. This should run faster + reduce the jitted graph size (prevent RAM blowups)

PiperOrigin-RevId: 648911793
xingyousong authored and Copybara-Service committed Jul 3, 2024
1 parent 94b53ef commit 5d0e992
Showing 4 changed files with 58 additions and 35 deletions.
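
The core idea of the change: instead of building `num_scalarizations` separate `Scalarization` objects and looping over them in Python (each iteration adds ops to the jitted graph), the weight vectors are stacked into a single array with a leading `Num` axis and scalarized in one broadcasted computation. A minimal standalone sketch of the idea, using the hypervolume scalarization `min(y / w)` as in this commit (shapes and values here are illustrative, not taken from the diff):

import jax
import jax.numpy as jnp

num_scalarizations, num_obj = 4, 2
weights = jnp.abs(
    jax.random.normal(jax.random.PRNGKey(0), (num_scalarizations, num_obj))
)
objectives = jnp.array([1.0, 2.0])  # shape [Obj]

# Before: one scalarizer per weight vector; under jit the Python loop is
# unrolled, so the traced graph grows linearly with num_scalarizations.
scores_loop = jnp.mean(
    jnp.stack([jnp.min(objectives / w, axis=-1) for w in weights]), axis=0
)

# After: a single broadcasted computation whose graph size is constant.
scores_vec = jnp.mean(jnp.min(objectives / weights, axis=-1), axis=0)

assert jnp.allclose(scores_loop, scores_vec)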
15 changes: 9 additions & 6 deletions vizier/_src/algorithms/designers/gp/acquisitions.py
@@ -539,23 +539,26 @@ def __call__(
     )()
 
 
+# TODO: What do we end up jitting? If we end up directly jitting this call
+# then we should make it `eqx.Module` and set
+# `reduction_fn=eqx.field(static=True)` instead.
 @struct.dataclass
 class ScalarizedAcquisition(AcquisitionFunction):
   """Wrapper that scalarizes multiple objectives before acquisition eval."""
 
   acquisition_fn: AcquisitionFunction
-  scalarizers: list[scalarization.Scalarization]
+  scalarizer: scalarization.Scalarization
+  reduction_fn: Callable[[jax.Array], jax.Array] = struct.field(
+      pytree_node=False, default=lambda x: x
+  )
 
   def __call__(
       self,
       dist: tfd.Distribution,
       seed: Optional[jax.Array] = None,
   ) -> jax.Array:
-    scores = [
-        jnp.squeeze(scalarizer(self.acquisition_fn(dist, seed)))
-        for scalarizer in self.scalarizers
-    ]
-    return jnp.mean(jnp.stack(scores, axis=0), axis=0)
+    scalarized = self.scalarizer(self.acquisition_fn(dist, seed).squeeze())
+    return self.reduction_fn(scalarized)
 
 
 @struct.dataclass
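For reference, a hypothetical call site for the new API (class and argument names are taken from this diff; the weights, reference point, and distribution are made up):

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
from vizier._src.algorithms.designers import scalarization
from vizier._src.algorithms.designers.gp import acquisitions

tfd = tfp.distributions

# One scalarizer carrying two weight vectors at once ([Num=2, Obj=2]).
scalarizer = scalarization.HyperVolumeScalarization(
    weights=jnp.array([[0.1, 0.2], [0.3, 0.4]]),
    reference_point=jnp.array([-1.0, -1.0]),
)
acq = acquisitions.ScalarizedAcquisition(
    acquisitions.UCB(),
    scalarizer,
    # Average the per-weight scores over the leading Num axis.
    reduction_fn=lambda x: jnp.mean(x, axis=0),
)
score = acq(tfd.Normal(jnp.array([0.1, 0.2]), jnp.array([1.0, 2.0])))

With the default `reduction_fn` (identity), the per-weight scores are returned unreduced.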
3 changes: 2 additions & 1 deletion vizier/_src/algorithms/designers/gp/acquisitions_test.py
@@ -25,6 +25,7 @@
 from vizier._src.algorithms.designers.gp import acquisitions
 from vizier._src.jax import types
 from absl.testing import absltest
+
 tfd = tfp.distributions
 tfpk = tfp.math.psd_kernels
 tfpke = tfp.experimental.psd_kernels
@@ -79,7 +80,7 @@ def test_scalarized_ucb(self):
     scalarizer = scalarization.HyperVolumeScalarization(
         weights=jnp.array([0.1, 0.2]), reference_point=reference_point
     )
-    acq = acquisitions.ScalarizedAcquisition(ucb, [scalarizer, scalarizer])
+    acq = acquisitions.ScalarizedAcquisition(ucb, scalarizer)
     self.assertAlmostEqual(
         acq(tfd.Normal([0.1, 0.2], [1, 2])), jnp.array(20.9), delta=1e-2
     )
25 changes: 14 additions & 11 deletions vizier/_src/algorithms/designers/gp_bandit.py
@@ -31,6 +31,7 @@
 import attr
 import equinox as eqx
 import jax
+import jax.numpy as jnp
 import numpy as np
 from vizier import algorithms as vza
 from vizier import pyvizier as vz
@@ -586,19 +587,21 @@ def from_problem(
     if problem.is_single_objective:
       return cls(problem, linear_coef=1.0, rng=rng, **kwargs)
     else:
-      objectives = problem.metric_information.of_type(vz.MetricType.OBJECTIVE)
-      random_weights = [
-          np.abs(np.random.normal(size=len(objectives)))
-          for _ in range(num_scalarizations)
-      ]
+      num_obj = len(problem.metric_information.of_type(vz.MetricType.OBJECTIVE))
+      rng, weights_rng = jax.random.split(rng)
+      weights = jnp.abs(
+          jax.random.normal(weights_rng, shape=(num_scalarizations, num_obj))
+      )
 
       def _scalarized_ucb(data: types.ModelData) -> acq_lib.AcquisitionFunction:
-        reference_point = acq_lib.get_worst_labels(data.labels)
-        scalarizers = [
-            scalarization.HyperVolumeScalarization(weights, reference_point)
-            for weights in random_weights
-        ]
-        return acq_lib.ScalarizedAcquisition(acq_lib.UCB(), scalarizers)
+        scalarizer = scalarization.HyperVolumeScalarization(
+            weights, acq_lib.get_worst_labels(data.labels)
+        )
+        return acq_lib.ScalarizedAcquisition(
+            acq_lib.UCB(),
+            scalarizer,
+            reduction_fn=lambda x: jnp.mean(x, axis=0),
+        )
 
       scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
           _scalarized_ucb
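Note the RNG change that rides along with the vectorization: weight sampling moves from stateful `np.random` to an explicit JAX key, so `from_problem` becomes reproducible for a fixed `rng`. The pattern in isolation (key value invented):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(42)
rng, weights_rng = jax.random.split(rng)  # split so `rng` is never reused
weights = jnp.abs(jax.random.normal(weights_rng, shape=(3, 2)))  # [Num=3, Obj=2]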
50 changes: 33 additions & 17 deletions vizier/_src/algorithms/designers/scalarization.py
@@ -26,6 +26,17 @@
 import typeguard
 
 
+def _broadcast_multiply(
+    weights: jt.Float[jax.Array, '*Num #Obj'],
+    objs: jt.Float[jax.Array, '*Batch #Obj'],
+) -> jt.Float[jax.Array, '*NumBatch #Obj']:
+  # [*Num, #Obj] -> [*Num, 1, ..., 1, #Obj]
+  broadcasted_weights = jnp.expand_dims(
+      weights, axis=range(-2, -1 - len(objs.shape), -1)
+  )
+  return broadcasted_weights * objs
+
+
 class Scalarization(abc.ABC, eqx.Module):
   """Reduces an array of objectives to a single float.
@@ -35,67 +46,72 @@ class Scalarization(abc.ABC, eqx.Module):
 
   @abc.abstractmethod
   def __call__(
       self, objectives: jt.Float[jax.Array, '*Batch Obj']
-  ) -> jt.Float[jax.Array, '*Batch']:
+  ) -> jt.Float[jax.Array, '*NumBatch']:
     """Computes the scalarization."""
 
 
 # Scalarization factory from weights.
 ScalarizationFromWeights = Callable[
-    [jt.Float[jax.Array, '#Obj']], Scalarization
+    [jt.Float[jax.Array, '*Num #Obj']], Scalarization
 ]
 
 
 class LinearScalarization(Scalarization):
   """Linear Scalarization."""
-  weights: jt.Float[jax.Array, '#Obj'] = eqx.field(converter=jnp.asarray)
+
+  weights: jt.Float[jax.Array, '*Num #Obj'] = eqx.field(converter=jnp.asarray)
 
   @jt.jaxtyped(typechecker=typeguard.typechecked)
   def __call__(
       self, objectives: jt.Float[jax.Array, '*Batch Obj']
-  ) -> jt.Float[jax.Array, '*Batch']:
-    return jnp.sum(self.weights * objectives, axis=-1)
+  ) -> jt.Float[jax.Array, '*NumBatch']:
+    product = _broadcast_multiply(self.weights, objectives)
+    return jnp.sum(product, axis=-1)
 
 
 class ChebyshevScalarization(Scalarization):
   """Chebyshev Scalarization."""
-  weights: jt.Float[jax.Array, '#Obj'] = eqx.field(converter=jnp.asarray)
+
+  weights: jt.Float[jax.Array, '*Num #Obj'] = eqx.field(converter=jnp.asarray)
 
   @jt.jaxtyped(typechecker=typeguard.typechecked)
   def __call__(
       self, objectives: jt.Float[jax.Array, '*Batch Obj']
-  ) -> jt.Float[jax.Array, '*Batch']:
-    return jnp.min(objectives * self.weights, axis=-1)
+  ) -> jt.Float[jax.Array, '*NumBatch']:
+    product = _broadcast_multiply(self.weights, objectives)
+    return jnp.min(product, axis=-1)
 
 
 class HyperVolumeScalarization(Scalarization):
   """HyperVolume Scalarization."""
-  weights: jt.Float[jax.Array, '#Obj'] = eqx.field(converter=jnp.asarray)
+
+  weights: jt.Float[jax.Array, '*Num #Obj'] = eqx.field(converter=jnp.asarray)
   reference_point: Optional[jt.Float[jax.Array, '* #Obj']] = eqx.field(
       default=None
   )
 
   @jt.jaxtyped(typechecker=typeguard.typechecked)
   def __call__(
       self, objectives: jt.Float[jax.Array, '*Batch Obj']
-  ) -> jt.Float[jax.Array, '*Batch']:
+  ) -> jt.Float[jax.Array, '*NumBatch']:
     # Uses scalarizations in https://arxiv.org/abs/2006.04655 for
     # non-convex multiobjective optimization. Removes the exponentiation
     # factor in number of objectives as it is a monotone transformation and
     # removes the non-negativity for easier gradients.
     if self.reference_point is not None:
-      return jnp.min(
-          (objectives - self.reference_point) / self.weights, axis=-1
-      )
-    else:
-      return jnp.min(objectives / self.weights, axis=-1)
+      objectives = objectives - self.reference_point
+
+    product = _broadcast_multiply(1.0 / self.weights, objectives)
+    return jnp.min(product, axis=-1)
 
 
 class LinearAugmentedScalarization(Scalarization):
   """Scalarization augmented with a linear sum.
   See https://arxiv.org/pdf/1904.05760.pdf.
   """
-  weights: jt.Float[jax.Array, '#Obj'] = eqx.field(converter=jnp.asarray)
+
+  weights: jt.Float[jax.Array, '*Num #Obj'] = eqx.field(converter=jnp.asarray)
 
   scalarization_factory: ScalarizationFromWeights = eqx.field(static=True)
   augment_weight: jt.Float[jax.Array, ''] = eqx.field(
@@ -105,7 +121,7 @@ class LinearAugmentedScalarization(Scalarization):
   @jt.jaxtyped(typechecker=typeguard.typechecked)
   def __call__(
       self, objectives: jt.Float[jax.Array, '*Batch Obj']
-  ) -> jt.Float[jax.Array, '*Batch']:
+  ) -> jt.Float[jax.Array, '*NumBatch']:
     return self.scalarization_factory(self.weights)(
         objectives
     ) + self.augment_weight * LinearScalarization(weights=self.weights)(
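
As a hand-checked sanity example for the vectorized `HyperVolumeScalarization` above, which computes `min_j((y_j - r_j) / w_j)` per weight vector (values invented for illustration):

import jax.numpy as jnp
from vizier._src.algorithms.designers import scalarization

scalarizer = scalarization.HyperVolumeScalarization(
    weights=jnp.array([[1.0, 2.0], [2.0, 1.0]]),  # [Num=2, Obj=2]
    reference_point=jnp.array([0.0, 0.0]),
)
objectives = jnp.array([2.0, 2.0])  # [Obj=2]

# First weight vector:  min(2/1, 2/2) = 1.0
# Second weight vector: min(2/2, 2/1) = 1.0
print(scalarizer(objectives))  # -> [1. 1.]  (leading Num axis)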
