Add schedule free Adam optimizer.
PiperOrigin-RevId: 634460555
ColCarroll authored and The bayeux Authors committed May 16, 2024
1 parent 316ccd0 commit 8192095
Showing 3 changed files with 31 additions and 7 deletions.
bayeux/_src/mcmc/blackjax.py  (12 additions, 5 deletions)
@@ -38,6 +38,13 @@
 }
 
 
+def _convert_algorithm(algorithm):
+  # Remove this after blackjax is stable
+  if hasattr(algorithm, "differentiable"):
+    return algorithm.differentiable
+  return algorithm
+
+
 def get_extra_kwargs(kwargs):
   defaults = {
       "chain_method": "vectorized",
@@ -64,8 +71,8 @@ def get_kwargs(self, **kwargs):
         adapt_fn, algorithm, constrained_log_density, extra_parameters | kwargs)
     return {adapt_fn: adaptation_kwargs,
             "adapt.run": run_kwargs,
-            algorithm: get_algorithm_kwargs(
-                algorithm, constrained_log_density, kwargs),
+            _convert_algorithm(algorithm): get_algorithm_kwargs(
+                _convert_algorithm(algorithm), constrained_log_density, kwargs),
             "extra_parameters": extra_parameters}
 
   def __call__(self, seed, **kwargs):
@@ -171,7 +178,7 @@ def _blackjax_inference(
       (states, infos), adaptation_parameters
   """
 
-  algorithm_kwargs = kwargs[algorithm] | adapt_parameters
+  algorithm_kwargs = kwargs[_convert_algorithm(algorithm)] | adapt_parameters
   inference_algorithm = algorithm(**algorithm_kwargs)
   _, states, infos = blackjax.util.run_inference_algorithm(
       rng_key=seed,
@@ -257,8 +264,8 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs):
   adaptation_required.remove("algorithm")
   adaptation_kwargs["algorithm"] = algorithm
   adaptation_kwargs = (
-      get_algorithm_kwargs(algorithm, log_density, kwargs) | adaptation_kwargs
-  )
+      get_algorithm_kwargs(_convert_algorithm(algorithm), log_density, kwargs)
+      | adaptation_kwargs)
 
   adaptation_required = adaptation_required - adaptation_kwargs.keys()
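The _convert_algorithm shim deals with an in-flight blackjax API change: recent blackjax releases wrap each top-level sampler (blackjax.nuts, for example) in a small API object whose plain constructor is exposed as a .differentiable attribute, while older releases expose the constructor directly. Because bayeux keys its kwargs dictionary by the algorithm itself, it unwraps when the attribute is present and passes the algorithm through otherwise. A minimal sketch of the guard, illustration only and not part of the commit, assuming a blackjax version where blackjax.nuts carries the .differentiable attribute:

# Illustration only: how the guard behaves on old and new blackjax.
# Assumes a recent blackjax where blackjax.nuts is a wrapper object with a
# .differentiable attribute holding the plain sampler constructor; on older
# releases the attribute is absent and the sampler is used as-is.
import blackjax


def _convert_algorithm(algorithm):
  if hasattr(algorithm, "differentiable"):
    return algorithm.differentiable
  return algorithm


# The unwrapped constructor is what bayeux uses as a dictionary key, so the
# key written by get_kwargs matches the key read back in _blackjax_inference.
algorithm_key = _convert_algorithm(blackjax.nuts)
kwargs = {algorithm_key: {"step_size": 0.5}}
assert kwargs[_convert_algorithm(blackjax.nuts)] == {"step_size": 0.5}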
bayeux/_src/optimize/optax.py  (17 additions, 2 deletions)
@@ -18,22 +18,24 @@
 from bayeux._src.optimize import shared
 import jax
 import optax
+import optax.contrib
 
 
 class _OptaxOptimizer(shared.Optimizer):
   """Base class for optax optimizers."""
+  _base = optax
 
   def get_kwargs(self, **kwargs):
     kwargs = self.default_kwargs() | kwargs
-    optimizer = getattr(optax, self.optimizer)
+    optimizer = getattr(self._base, self.optimizer)
     return {optimizer: shared.get_optimizer_kwargs(optimizer, kwargs),
             "extra_parameters": shared.get_extra_kwargs(kwargs)}
 
   def __call__(self, seed, **kwargs):
     kwargs = self.get_kwargs(**kwargs)
     fun, initial_state, apply_transform = self._prep_args(seed, kwargs)
 
-    optimizer_fn = getattr(optax, self.optimizer)
+    optimizer_fn = getattr(self._base, self.optimizer)
     optimizer = optimizer_fn(**kwargs[optimizer_fn])
     num_iters = kwargs["extra_parameters"]["num_iters"]
     optimizer = functools.partial(
@@ -183,6 +185,19 @@ def default_kwargs(self) -> dict[str, float]:
     return kwargs
 
 
+class ScheduleFree(_OptaxOptimizer):
+  _base = optax.contrib
+  name = "optax_schedule_free"
+  optimizer = "schedule_free"
+
+  def default_kwargs(self) -> dict[str, float]:
+    kwargs = super().default_kwargs()
+    base_optimizer = optax.adam(
+        **shared.get_optimizer_kwargs(optax.adam, kwargs))
+    kwargs["base_optimizer"] = base_optimizer
+    return kwargs
+
+
 class Sgd(_OptaxOptimizer):
   name = "optax_sgd"
   optimizer = "sgd"
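Two small pieces let the new optimizer reuse the existing machinery: the _base class attribute makes getattr(self._base, self.optimizer) resolve schedule_free from optax.contrib rather than from the top-level optax namespace, and default_kwargs pre-builds an Adam transformation that is handed to the wrapper as base_optimizer. The construction bayeux ends up performing looks roughly like the sketch below; base_optimizer comes straight from the diff, while the remaining argument names (learning_rate in particular) are assumptions about the installed optax version and may differ:

# Sketch, not part of the diff: schedule-free Adam built by hand with optax.
# Assumes an optax recent enough to ship optax.contrib.schedule_free; argument
# names other than base_optimizer are best-effort guesses for that version.
import jax.numpy as jnp
import optax
import optax.contrib

base = optax.adam(learning_rate=0.1)
opt = optax.contrib.schedule_free(base_optimizer=base, learning_rate=0.1)

params = jnp.zeros(3)
state = opt.init(params)
grads = jnp.ones(3)  # stand-in for a real gradient
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)

Because ScheduleFree overrides only the lookup namespace and default_kwargs, the inherited get_kwargs and __call__ from _OptaxOptimizer are reused unchanged.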
bayeux/optimize/__init__.py  (2 additions, 0 deletions)
@@ -44,6 +44,7 @@
 # from bayeux._src.optimize.optax import OptimisticGradientDescent  # pylint: disable=line-too-long
 from bayeux._src.optimize.optax import Radam
 from bayeux._src.optimize.optax import Rmsprop
+from bayeux._src.optimize.optax import ScheduleFree
 from bayeux._src.optimize.optax import Sgd
 from bayeux._src.optimize.optax import Sm3
 from bayeux._src.optimize.optax import Yogi
@@ -66,6 +67,7 @@
     # "Dpsgd",
     "Radam",
     "Rmsprop",
+    "ScheduleFree",
     "Sgd",
     "Sm3",
     "Yogi",
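With the class exported and listed in __all__, the optimizer shows up next to the other optax methods on a model's optimize namespace under the name optax_schedule_free. A hypothetical end-to-end call in bayeux's usual style; the toy log density, the test point, and the result attribute read at the end are illustrative assumptions rather than anything taken from the commit:

# Hypothetical usage sketch, assuming bayeux's Model API with an .optimize
# namespace as used by the other optax-backed methods.
import bayeux as bx
import jax
import jax.numpy as jnp


def log_density(x):
  return -jnp.sum(jnp.square(x - 2.0))  # toy quadratic with its mode at 2.0


model = bx.Model(log_density=log_density, test_point=jnp.zeros(3))
result = model.optimize.optax_schedule_free(seed=jax.random.key(0))
print(result.params)  # assumed attribute name on the returned results object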
