Solver API (wip) #777

Merged
merged 1 commit into from
Feb 5, 2024
3 changes: 3 additions & 0 deletions optax/experimental/README.md
@@ -0,0 +1,3 @@
As the name suggests, the functions implemented here are subject to
modification. We are currently developing a new Solver API that could
cover a broader range of optimizers, such as those relying on line searches.
58 changes: 58 additions & 0 deletions optax/experimental/api_test.ipynb
@@ -0,0 +1,58 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vfWSk55u5_E-"
},
"source": [
"# Example of Solver API usage\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vzyIF6NW6Dwd"
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import optax\n",
"from optax.experimental import gradient_solver\n",
"\n",
"def obj_fun(x):\n",
" return jnp.sum(x**2)\n",
"\n",
"init, step = gradient_solver.gradient_solver(\n",
" obj_fun, optax.adam(learning_rate=1.)\n",
" )\n",
"\n",
"params = jnp.arange(16, dtype=jnp.float32)\n",
"state = init(params)\n",
"for _ in range(10):\n",
" params, state = step(params, state)\n",
" print(f'Objective value: {obj_fun(params)}')"
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
64 changes: 64 additions & 0 deletions optax/experimental/gradient_solver.py
@@ -0,0 +1,64 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps a GradientTransform into a Solver."""

from typing import Any, NamedTuple, Union

import jax
import optax
import optax.experimental.solver as optax_solver
import optax.experimental.utils as exp_utils


class GradientSolverState(NamedTuple):
  gt_state: optax.OptState = None


def gradient_solver(obj_fn, gradient_transform, obj_fun_has_aux=False):
  """Wraps a GradientTransformation into a Solver."""

  def init(init_params: optax.Params) -> optax_solver.SolverState:
    init_gt_state = gradient_transform.init(init_params)
    init_opt_state = GradientSolverState(init_gt_state)
    return init_opt_state

  def step(
      params: optax.Params,
      state: optax_solver.SolverState,
      **extra_kwargs: dict[str, Any]
  ) -> tuple[
      Union[optax.Params, tuple[optax.Params, Any]], optax_solver.SolverState
  ]:
    obj_kwargs, gt_kwargs = exp_utils.split_kwargs(
        (obj_fn, gradient_transform.update), extra_kwargs
    )
    if obj_fun_has_aux:
      grad, aux = jax.grad(obj_fn, has_aux=obj_fun_has_aux)(
          params, **obj_kwargs
      )
    else:
      grad = jax.grad(obj_fn)(params, **obj_kwargs)
      aux = None
    update, gt_state = gradient_transform.update(
        grad, state.gt_state, params, **gt_kwargs
    )
    next_params = optax.apply_updates(params, update)
    next_state = GradientSolverState(gt_state)
    if obj_fun_has_aux:
      return (next_params, aux), next_state
    else:
      return next_params, next_state

  return optax_solver.Solver(init, step)
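
As a quick illustration (not included in this PR), here is a minimal sketch of how the wrapper above might be called when the objective returns an auxiliary output; `loss_with_aux` is a hypothetical objective used only for this example.

# Minimal sketch (not part of the diff): gradient_solver with an objective
# that returns (loss, aux). `loss_with_aux` is a hypothetical example.
import jax.numpy as jnp
import optax
from optax.experimental import gradient_solver

def loss_with_aux(x):
  loss = jnp.sum(x ** 2)
  return loss, {'loss': loss}  # auxiliary diagnostics returned with the loss

init, step = gradient_solver.gradient_solver(
    loss_with_aux, optax.sgd(learning_rate=0.1), obj_fun_has_aux=True
)

params = jnp.ones(4)
state = init(params)
for _ in range(5):
  # With obj_fun_has_aux=True, step returns ((params, aux), state).
  (params, aux), state = step(params, state)
  print(aux['loss'])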
110 changes: 110 additions & 0 deletions optax/experimental/gradient_solver_test.py
@@ -0,0 +1,110 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `alias.py`."""

from absl.testing import absltest
from absl.testing import parameterized

import chex
import jax
import jax.numpy as jnp

from optax._src import alias
from optax._src import numerics
from optax.experimental import gradient_solver


_GRAD_TRANSFORMS_UNDER_TEST = (
    dict(gt_name='sgd', gt_kwargs=dict(learning_rate=1e-3, momentum=0.9)),
    dict(gt_name='adafactor', gt_kwargs=dict(learning_rate=5e-3)),
    dict(gt_name='adagrad', gt_kwargs=dict(learning_rate=1.0)),
    dict(gt_name='adam', gt_kwargs=dict(learning_rate=1e-1)),
    dict(gt_name='adamw', gt_kwargs=dict(learning_rate=1e-1)),
    dict(gt_name='adamax', gt_kwargs=dict(learning_rate=1e-1)),
    dict(gt_name='adamaxw', gt_kwargs=dict(learning_rate=1e-1)),
    dict(gt_name='amsgrad', gt_kwargs=dict(learning_rate=1e-1)),
    dict(gt_name='lars', gt_kwargs=dict(learning_rate=1.0)),
    dict(gt_name='lamb', gt_kwargs=dict(learning_rate=1e-3)),
    dict(
        gt_name='lion', gt_kwargs=dict(learning_rate=1e-2, weight_decay=1e-4),
    ),
    dict(gt_name='nadam', gt_kwargs=dict(learning_rate=1e-2)),
    dict(gt_name='nadamw', gt_kwargs=dict(learning_rate=1e-2)),
    dict(gt_name='noisy_sgd', gt_kwargs=dict(learning_rate=1e-3, eta=1e-4)),
    dict(gt_name='novograd', gt_kwargs=dict(learning_rate=1e-3)),
    dict(
        gt_name='optimistic_gradient_descent',
        gt_kwargs=dict(learning_rate=2e-3, alpha=0.7, beta=0.1),
    ),
    dict(gt_name='rmsprop', gt_kwargs=dict(learning_rate=5e-3)),
    dict(gt_name='rmsprop', gt_kwargs=dict(learning_rate=5e-3, momentum=0.9)),
    dict(gt_name='fromage', gt_kwargs=dict(learning_rate=5e-3)),
    dict(gt_name='adabelief', gt_kwargs=dict(learning_rate=1e-2)),
    dict(gt_name='radam', gt_kwargs=dict(learning_rate=5e-3)),
    dict(gt_name='rprop', gt_kwargs=dict(learning_rate=1e-1)),
    dict(gt_name='sm3', gt_kwargs=dict(learning_rate=1.0)),
    dict(gt_name='yogi', gt_kwargs=dict(learning_rate=1e-1)),
)


def _setup_parabola(dtype):
  """Quadratic function as an optimization target."""
  initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
  final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype)

  def obj_fun(params):
    return jnp.sum(numerics.abs_sq(params - final_params))

  return initial_params, final_params, obj_fun


def _setup_rosenbrock(dtype):
  """Rosenbrock function as an optimization target."""
  a = 1.0
  b = 100.0

  initial_params = jnp.array([0.0, 0.0], dtype=dtype)
  final_params = jnp.array([a, a**2], dtype=dtype)

  def obj_fun(params):
    return (numerics.abs_sq(a - params[0]) +
            b * numerics.abs_sq(params[1] - params[0]**2))

  return initial_params, final_params, obj_fun


class SolverWrapperTest(chex.TestCase):

  @parameterized.product(
      _GRAD_TRANSFORMS_UNDER_TEST,
      target=(_setup_parabola, _setup_rosenbrock),
      dtype=(jnp.float32,),
  )
  def test_optimization(self, gt_name, gt_kwargs, target, dtype):
    opt = getattr(alias, gt_name)(**gt_kwargs)
    initial_params, final_params, obj_fun = target(dtype)

    init, step = gradient_solver.gradient_solver(obj_fun, opt)

    params = initial_params
    state = init(params)
    step = jax.jit(step)
    for _ in range(10_000):
      params, state = step(params, state)

    chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)


if __name__ == '__main__':
  absltest.main()
89 changes: 89 additions & 0 deletions optax/experimental/solver.py
@@ -0,0 +1,89 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Solver API."""

from typing import Any, NamedTuple, Protocol, Union

import optax

Params = optax.Params
SolverState = Any


class SolverInitFn(Protocol):
  """A callable type for the `init` function of a `Solver`.

  The `init` function takes a tree of `params` and uses these to construct an
  arbitrary structured initial `state` for the solver. This may hold
  statistics of past updates or any other non-static information.
  """

  def __call__(self, params: Params) -> SolverState:
    """Initializes the solver.

    Args:
      params: The initial value of the parameters.

    Returns:
      The initial state of the solver.
    """


class SolverStepFn(Protocol):
  """A callable type for the `step` function of a `Solver`.

  The `step` function takes a tree of candidate parameters `params` and an
  arbitrary structured `state`, and returns a new tree of candidate parameters
  and a new state. Additional arguments can be passed as keyword arguments.
  """

  def __call__(
      self, params: Params, state: SolverState, **extra_kwargs: dict[str, Any]
  ) -> tuple[Union[Params, tuple[Params, Any]], SolverState]:
    """Performs a step of the solver.

    Args:
      params: A tree of candidate parameters.
      state: The state of the solver.
      **extra_kwargs: Additional keyword arguments for the objective function
        or the solver.

    Returns:
      The updated parameters, possibly together with an auxiliary output, and
      the updated state.
    """


class Solver(NamedTuple):
  """A pair of pure functions implementing a solver.

  The init function initializes the state of the solver given an initial tree
  of parameters. The step function updates the parameters and the state of the
  solver given the current parameters and state.
  Unlike a GradientTransformation, this API has direct access to the function
  being optimized, which it uses to compute gradients, then update directions,
  and finally the updated parameters.

  Attributes:
    init_fn: A pure function which, when called with an example instance of
      the parameters, returns an arbitrary structured initial `state`.
    step_fn: A pure function which takes as input a tree of parameters and the
      previous solver state (which may have been initialized using the init
      function), and returns the updated parameters and a new solver state.
  """

  init_fn: SolverInitFn
  step_fn: SolverStepFn
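
For context (not included in this PR), a minimal sketch of what a hand-written Solver following this init/step contract could look like, here plain gradient descent with a fixed step size; `constant_step_gd` is a hypothetical helper introduced only for illustration.

# Illustrative sketch (not part of the diff): a Solver implemented by hand.
import jax
import jax.numpy as jnp
from optax.experimental import solver as optax_solver

def constant_step_gd(obj_fn, stepsize=0.1):
  """Plain gradient descent expressed as a Solver (hypothetical helper)."""

  def init(params):
    return ()  # this solver carries no state

  def step(params, state, **extra_kwargs):
    grad = jax.grad(obj_fn)(params, **extra_kwargs)
    next_params = jax.tree_util.tree_map(
        lambda p, g: p - stepsize * g, params, grad
    )
    return next_params, state

  return optax_solver.Solver(init, step)

init, step = constant_step_gd(lambda x: jnp.sum(x ** 2))
params = jnp.arange(4.0)
state = init(params)
params, state = step(params, state)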
65 changes: 65 additions & 0 deletions optax/experimental/utils.py
@@ -0,0 +1,65 @@
# Copyright 2023 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for solvers.
"""
import functools
import inspect
import operator

from typing import Any, Callable, Sequence


def split_kwargs(
    funs: Sequence[Callable[..., Any]],
    fun_kwargs: dict[str, Any],
) -> Sequence[dict[str, Any]]:
  """Split fun_kwargs into kwargs of the input functions funs.

  Raises an error if a keyword argument of fun_kwargs does not match any
  argument name of funs.

  Args:
    funs: sequence of functions to feed fun_kwargs to.
    fun_kwargs: dictionary of keyword variables to be fed to funs.

  Returns:
    (fun_1_kwargs, ..., fun_n_kwargs): keyword arguments for each function,
    taken from fun_kwargs.

  Examples:
    >>> def fun1(a, b): return a+b
    >>> def fun2(c, d): return c+d
    >>> fun_kwargs = {'b':1., 'd':2.}
    >>> funs_kwargs = split_kwargs((fun1, fun2), fun_kwargs)
    >>> print(funs_kwargs)
    [{'b': 1.0}, {'d': 2.0}]
  """
  funs_arg_names = [
      list(inspect.signature(fun).parameters.keys()) for fun in funs
  ]
  funs_kwargs = [
      {k: v for k, v in fun_kwargs.items() if k in fun_arg_names}
      for fun_arg_names in funs_arg_names
  ]
  all_possible_arg_names = functools.reduce(operator.add, funs_arg_names)
  remaining_keys = [
      k for k in fun_kwargs.keys() if k not in all_possible_arg_names
  ]
  if remaining_keys:
    raise ValueError(
        f'{remaining_keys} are not valid arguments for any of the functions'
        f' {funs}'
    )
  return funs_kwargs
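
For context (not included in this PR), a minimal sketch of how `split_kwargs` lets extra keyword arguments flow through a solver's `step`: keywords are routed to whichever function declares them. The objective's `targets` keyword is a hypothetical example.

# Sketch (not part of the diff): routing step(**extra_kwargs) via split_kwargs.
import jax.numpy as jnp
import optax
from optax.experimental import gradient_solver

def obj_fun(params, targets):
  # 'targets' only appears in this signature, so split_kwargs routes it to
  # the objective rather than to the gradient transformation's update.
  return jnp.sum((params - targets) ** 2)

init, step = gradient_solver.gradient_solver(
    obj_fun, optax.sgd(learning_rate=0.1)
)

params = jnp.zeros(3)
state = init(params)
params, state = step(params, state, targets=jnp.array([1.0, 2.0, 3.0]))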