Skip to content

Commit

Permalink
Modernize @jaxtyped code
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590708836
  • Loading branch information
sagipe authored and Copybara-Service committed Dec 13, 2023
1 parent 4b14b9c commit e64328f
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions vizier/_src/algorithms/designers/scalarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class Scalarization(abc.ABC, eqx.Module):
# Weights shape should be broadcastable with objectives when called.
weights: jt.Float[jax.Array, '#Obj'] = eqx.field(converter=jnp.asarray)

@jt.jaxtyped
@typeguard.typechecked
@jt.jaxtyped(typechecker=typeguard.typechecked)
def __call__(
self, objectives: jt.Float[jax.Array, '*Batch Obj']
) -> jt.Float[jax.Array, '*Batch']:
Expand Down

0 comments on commit e64328f

Please sign in to comment.