add adamw to torch backend (#532)
Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
2 people authored and fchollet committed Jul 18, 2023
1 parent 2481069 commit a0d7776
Showing 4 changed files with 20 additions and 84 deletions.
6 changes: 6 additions & 0 deletions keras_core/backend/torch/optimizers/torch_adamw.py
@@ -0,0 +1,6 @@
from keras_core import optimizers
from keras_core.backend.torch.optimizers import torch_adam


class AdamW(torch_adam.Adam, optimizers.AdamW):
pass
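
The backend class is a pure mixin: it combines the torch backend's Adam implementation with the framework-level optimizers.AdamW, which supplies the constructor, config, and weight-decay validation. The mixin only linearizes because optimizers.AdamW now subclasses optimizers.Adam (see the adamw.py change below). A quick way to inspect the resulting method resolution order, assuming keras_core is installed and the torch backend is selected:

# Sketch: print the MRO of the combined class; the exact chain depends on
# the installed keras_core version, so the output is indicative only.
from keras_core.backend.torch.optimizers import torch_adamw

for cls in torch_adamw.AdamW.__mro__:
    print(cls.__module__, cls.__name__)
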
2 changes: 2 additions & 0 deletions keras_core/backend/torch/optimizers/torch_optimizer.py
@@ -8,11 +8,13 @@ class TorchOptimizer(BaseOptimizer):
def __new__(cls, *args, **kwargs):
# Import locally to avoid circular imports.
from keras_core.backend.torch.optimizers import torch_adam
from keras_core.backend.torch.optimizers import torch_adamw
from keras_core.backend.torch.optimizers import torch_sgd

OPTIMIZERS = {
optimizers.SGD: torch_sgd.SGD,
optimizers.Adam: torch_adam.Adam,
optimizers.AdamW: torch_adamw.AdamW,
}
if cls in OPTIMIZERS:
return OPTIMIZERS[cls](*args, **kwargs)
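TorchOptimizer.__new__ reroutes construction of a framework-level optimizer class to its torch-specific counterpart when one is registered, so keras_core.optimizers.AdamW(...) yields the torch_adamw.AdamW mixin under the torch backend. A self-contained toy of the same pattern (class names below are illustrative stand-ins, not keras_core's):

# Toy version of the __new__ dispatch above; names are illustrative.
class DispatchingBase:
    def __new__(cls, *args, **kwargs):
        # Built at call time so the subclasses below already exist.
        registry = {FrameworkAdamW: BackendAdamW}
        if cls in registry:
            # Construct the registered backend subclass instead of `cls`.
            return registry[cls](*args, **kwargs)
        return object.__new__(cls)


class FrameworkAdamW(DispatchingBase):
    def __init__(self, learning_rate=0.001):
        self.learning_rate = learning_rate


class BackendAdamW(FrameworkAdamW):
    pass


opt = FrameworkAdamW(learning_rate=0.01)
print(type(opt).__name__, opt.learning_rate)  # BackendAdamW 0.01
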
87 changes: 7 additions & 80 deletions keras_core/optimizers/adamw.py
@@ -1,10 +1,10 @@
from keras_core import ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import adam
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.AdamW"])
class AdamW(optimizer.Optimizer):
class AdamW(adam.Adam):
"""Optimizer that implements the AdamW algorithm.
AdamW optimization is a stochastic gradient descent method that is based on
@@ -68,99 +68,26 @@ def __init__(
):
super().__init__(
learning_rate=learning_rate,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon,
amsgrad=amsgrad,
name=name,
weight_decay=weight_decay,
clipnorm=clipnorm,
clipvalue=clipvalue,
global_clipnorm=global_clipnorm,
use_ema=use_ema,
ema_momentum=ema_momentum,
ema_overwrite_frequency=ema_overwrite_frequency,
)
self.weight_decay = weight_decay
self.beta_1 = beta_1
self.beta_2 = beta_2
self.epsilon = epsilon
self.amsgrad = amsgrad

if self.weight_decay is None:
raise ValueError(
"Argument `weight_decay` must be a float. Received: "
"weight_decay=None"
)

def build(self, var_list):
"""Initialize optimizer variables.
AdamW optimizer has 3 types of variables: momentums, velocities and
velocity_hat (only set when amsgrad is applied),
Args:
var_list: list of model variables to build AdamW variables on.
"""
if self.built:
return
super().build(var_list)
self._momentums = []
self._velocities = []
for var in var_list:
self._momentums.append(
self.add_variable_from_reference(
reference_variable=var, name="m"
)
)
self._velocities.append(
self.add_variable_from_reference(
reference_variable=var, name="v"
)
)
if self.amsgrad:
self._velocity_hats = []
for var in var_list:
self._velocity_hats.append(
self.add_variable_from_reference(
reference_variable=var, name="vhat"
)
)

def update_step(self, gradient, variable, learning_rate):
"""Update step given gradient and the associated model variable."""
lr = ops.cast(learning_rate, variable.dtype)
gradient = ops.cast(gradient, variable.dtype)
local_step = ops.cast(self.iterations + 1, variable.dtype)
beta_1_power = ops.power(
ops.cast(self.beta_1, variable.dtype), local_step
)
beta_2_power = ops.power(
ops.cast(self.beta_2, variable.dtype), local_step
)

m = self._momentums[self._get_variable_index(variable)]
v = self._velocities[self._get_variable_index(variable)]

alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

m.assign(m + (gradient - m) * (1 - self.beta_1))
v.assign(v + (ops.square(gradient) - v) * (1 - self.beta_2))
if self.amsgrad:
v_hat = self._velocity_hats[self._get_variable_index(variable)]
v_hat.assign(ops.maximum(v_hat, v))
v = v_hat
variable.assign(variable - (m * alpha) / (ops.sqrt(v) + self.epsilon))

def get_config(self):
config = super().get_config()

config.update(
{
"weight_decay": self.weight_decay,
"beta_1": self.beta_1,
"beta_2": self.beta_2,
"epsilon": self.epsilon,
"amsgrad": self.amsgrad,
}
)
return config


AdamW.__doc__ = AdamW.__doc__.replace(
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
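With this refactor, AdamW no longer duplicates Adam's build/update_step/get_config; it forwards every argument to adam.Adam and only insists that weight_decay is not None, since decoupled weight decay is what distinguishes AdamW from Adam. A NumPy sketch of a single step, following the removed update_step math plus decoupled weight decay (the decay form var -= var * weight_decay * lr and the sample values are assumptions for illustration, not taken from this diff):

import numpy as np

lr, beta_1, beta_2, epsilon, weight_decay = 0.5, 0.9, 0.999, 1e-7, 0.004
var = np.array([1.0, 2.0, 3.0, 4.0])
grad = np.array([1.0, 6.0, 7.0, 2.0])
m = np.zeros_like(var)  # first-moment accumulator
v = np.zeros_like(var)  # second-moment accumulator
t = 1                   # local step = iterations + 1

# Decoupled weight decay: shrink the variable directly, not the gradient.
var = var - var * weight_decay * lr

# Bias-corrected step size and moment updates, mirroring update_step.
alpha = lr * np.sqrt(1 - beta_2**t) / (1 - beta_1**t)
m = m + (grad - m) * (1 - beta_1)
v = v + (np.square(grad) - v) * (1 - beta_2)

var = var - (m * alpha) / (np.sqrt(v) + epsilon)
print(var)
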
9 changes: 5 additions & 4 deletions keras_core/optimizers/adamw_test.py
@@ -4,6 +4,7 @@
import numpy as np

from keras_core import backend
from keras_core import ops
from keras_core import testing
from keras_core.optimizers.adamw import AdamW

@@ -22,7 +23,7 @@ def test_config(self):

def test_single_step(self):
optimizer = AdamW(learning_rate=0.5)
grads = np.array([1.0, 6.0, 7.0, 2.0])
grads = ops.array([1.0, 6.0, 7.0, 2.0])
vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
optimizer.apply_gradients(zip([grads], [vars]))
self.assertAllClose(
@@ -31,7 +32,7 @@ def test_single_step(self):

def test_weight_decay(self):
grads, var1, var2, var3 = (
np.zeros(()),
ops.zeros(()),
backend.Variable(2.0),
backend.Variable(2.0, name="exclude"),
backend.Variable(2.0),
@@ -55,8 +56,8 @@ def test_correctness_with_golden(self):
optimizer = AdamW(learning_rate=1.0, weight_decay=0.5, epsilon=2)

x = backend.Variable(np.ones([10]))
grads = np.arange(0.1, 1.1, 0.1)
first_grads = np.full((10,), 0.01)
grads = ops.arange(0.1, 1.1, 0.1)
first_grads = ops.full((10,), 0.01)

# fmt: off
golden = np.array(
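The test changes swap NumPy arrays for keras_core.ops arrays so the gradients handed to the optimizer are backend-native tensors (torch tensors under the torch backend). A minimal usage sketch of the same pattern, assuming keras_core is installed and the torch backend is selected:

from keras_core import backend
from keras_core import ops
from keras_core import optimizers

# Gradients built with ops.array are backend tensors, matching the test.
optimizer = optimizers.AdamW(learning_rate=0.5)
var = backend.Variable([1.0, 2.0, 3.0, 4.0])
grad = ops.array([1.0, 6.0, 7.0, 2.0])
optimizer.apply_gradients(zip([grad], [var]))
print(var)  # the variable is updated in place by one AdamW step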
