fix: decay defaults for optimizers #492

Closed
wants to merge 4 commits
54 changes: 42 additions & 12 deletions lib/axon/optimizers.ex
@@ -70,6 +70,8 @@ defmodule Axon.Optimizers do
* [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468)
"""
def adabelief(learning_rate \\ 1.0e-3, opts \\ []) do
opts = Keyword.validate!(opts, b1: 0.9, b2: 0.999, eps: 0.0, eps_root: 1.0e-16)

Updates.scale_by_belief(opts)
|> scale_by_learning_rate(learning_rate)
end
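
For readers unfamiliar with the pattern this PR introduces: `Keyword.validate!/2` is an Elixir standard-library call that fills in defaults and raises on unrecognized keys, so a mistyped option now fails loudly instead of being silently ignored by the downstream update. A minimal sketch of the assumed behavior:

```elixir
# Standard-library behavior, not Axon-specific.
opts = Keyword.validate!([b1: 0.8], b1: 0.9, b2: 0.999, eps: 0.0, eps_root: 1.0e-16)
# opts now holds b1: 0.8 plus the remaining defaults (key order is not guaranteed).

Keyword.validate!([beta1: 0.8], b1: 0.9, b2: 0.999)
# raises ArgumentError, because :beta1 is not an accepted key
```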
@@ -105,6 +107,8 @@ defmodule Axon.Optimizers do
* [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
"""
def adam(learning_rate \\ 1.0e-3, opts \\ []) do
opts = Keyword.validate!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-8, eps_root: 1.0e-15)

Updates.scale_by_adam(opts)
|> scale_by_learning_rate(learning_rate)
end
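
The nonzero `eps_root` default added here matters because the two stability terms enter the update in different places: `eps_root` sits under the square root, while `eps` guards the final division. A rough illustration following the Optax convention these updates mirror (an assumption about the exact formula, not a copy of Axon's implementation):

```elixir
# m_hat / v_hat stand in for bias-corrected first and second moment estimates.
m_hat = Nx.tensor([0.10, -0.20])
v_hat = Nx.tensor([0.01, 0.04])
eps = 1.0e-8
eps_root = 1.0e-15

# update direction before the learning-rate scaling
update = Nx.divide(m_hat, Nx.add(Nx.sqrt(Nx.add(v_hat, eps_root)), eps))
```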
@@ -118,10 +122,10 @@ defmodule Axon.Optimizers do
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.95`
"""
def adamw(learning_rate \\ 1.0e-3, opts \\ []) do
{decay, opts} = Keyword.pop(opts, :decay, 0.0)
{decay, opts} = Keyword.pop(opts, :decay, 0.95)

Updates.scale_by_adam(opts)
|> Updates.add_decayed_weights(decay: decay)
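
Because this flips the default from no weight decay to `0.95`, existing call sites that relied on decay being off would have to opt out explicitly. A hypothetical caller:

```elixir
decaying = Axon.Optimizers.adamw(1.0e-3)              # now applies decay: 0.95
no_decay = Axon.Optimizers.adamw(1.0e-3, decay: 0.0)  # restores the previous behavior
```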
@@ -137,15 +141,25 @@ defmodule Axon.Optimizers do
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.95`
* `:min_norm` - minimum norm value. Defaults to `0.0`

## References

* [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)
"""
def lamb(learning_rate \\ 1.0e-2, opts \\ []) do
{decay, opts} = Keyword.pop(opts, :decay, 0.0)
opts =
Keyword.validate!(opts,
decay: 0.95,
min_norm: 0.0,
b1: 0.9,
b2: 0.999,
eps: 0.0,
eps_root: 1.0e-16
)

{decay, opts} = Keyword.pop(opts, :decay, 0.95)
{min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0)

Updates.scale_by_adam(opts)
@@ -163,6 +177,8 @@ defmodule Axon.Optimizers do
* `:gamma` - used to compute variance of noise distribution. Defaults to `0.55`
"""
def noisy_sgd(learning_rate \\ 1.0e-2, opts \\ []) do
opts = Keyword.validate!(opts, eta: 0.1, gamma: 0.55)

scale_by_learning_rate(learning_rate)
|> Updates.add_noise(opts)
end
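
As a rough guide to what `:eta` and `:gamma` control: in the gradient-noise scheme this follows, the injected noise variance is annealed roughly as eta / (1 + step)^gamma, so it shrinks over training. Treat the exact schedule inside `Updates.add_noise/2` as an assumption here; the sketch only shows the shape of the curve:

```elixir
eta = 0.1
gamma = 0.55
variance_at = fn step -> eta / :math.pow(1 + step, gamma) end

variance_at.(0)    # 0.1
variance_at.(100)  # about 0.008
```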
@@ -183,6 +199,9 @@ defmodule Axon.Optimizers do
* [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf)
"""
def radam(learning_rate \\ 1.0e-3, opts \\ []) do
opts =
Keyword.validate!(opts, b1: 0.9, b2: 0.999, eps: 1.0e-8, eps_root: 1.0e-16, threshold: 5.0)

Updates.scale_by_radam(opts)
|> scale_by_learning_rate(learning_rate)
end
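
Hypothetical usage with the now-validated options; per the RAdam paper, `:threshold` gates when the variance-rectified adaptive term is applied, and below it the update falls back to an un-rectified, bias-corrected momentum step:

```elixir
# Overriding a couple of the validated defaults; an unknown key would now raise.
optimizer = Axon.Optimizers.radam(1.0e-3, b1: 0.85, threshold: 4.0)
```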
@@ -196,20 +215,20 @@ defmodule Axon.Optimizers do
* `:momentum` - momentum term. If set, uses SGD with momentum and decay set
to value of this term.
* `:nesterov` - whether or not to use nesterov momentum. Defaults to `false`
* `:initial_scale` - initial value of EMA. Defaults to `0.0`
* `:decay` - EMA decay rate. Defaults to `0.9`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
"""
def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do
{centered, opts} = Keyword.pop(opts, :centered, false)
{nesterov?, opts} = Keyword.pop(opts, :nesterov, false)
{momentum, opts} = Keyword.pop(opts, :momentum, nil)
opts = Keyword.validate!(opts, [:momentum, centered: false, nesterov: false, decay: 0.95])

centered = opts[:centered]
nesterov? = opts[:nesterov]
momentum = opts[:momentum]

combinator =
if centered do
Updates.scale_by_stddev(opts)
Updates.scale_by_stddev(decay: opts[:decay])
else
Updates.scale_by_rms(opts)
Updates.scale_by_rms(decay: opts[:decay])
end
|> scale_by_learning_rate(learning_rate)
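
Note that `:initial_scale` is no longer in the accepted option set, so passing it now raises through `Keyword.validate!/2` (the tests further down are updated accordingly). Hypothetical call sites with the surviving knobs:

```elixir
Axon.Optimizers.rmsprop(1.0e-2, decay: 0.8, centered: true)
Axon.Optimizers.rmsprop(1.0e-2, decay: 0.8, momentum: 0.9, nesterov: true)
```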

@@ -228,8 +247,10 @@ defmodule Axon.Optimizers do
* `:nesterov` - whether or not to use nesterov momentum. Defaults to `false`
"""
def sgd(learning_rate \\ 1.0e-2, opts \\ []) do
opts = Keyword.validate!(opts, [:momentum, nesterov: false])

momentum = opts[:momentum]
nesterov? = opts[:nesterov] || false
nesterov? = opts[:nesterov]

if momentum do
Updates.trace(decay: momentum, nesterov: nesterov?)
@@ -255,6 +276,15 @@ defmodule Axon.Optimizers do
* [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf)
"""
def yogi(learning_rate \\ 1.0e-2, opts \\ []) do
opts =
Keyword.validate!(opts,
initial_accumulator_value: 0.0,
b1: 0.9,
b2: 0.999,
eps: 1.0e-8,
eps_root: 0.0
)

Updates.scale_by_yogi(opts)
|> scale_by_learning_rate(learning_rate)
end
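
For a sense of how these constructors are consumed, a sketch of driving one directly. The `{init_fn, update_fn}` shape and the `(updates, state, params)` argument order are inferred from the `Axon.Updates` guards in this diff (`is_function(init_fn, 1)`, `is_function(apply_fn, 3)`); the `{updates, new_state}` return shape and `Axon.Updates.apply_updates/2` applying them back onto the params are assumptions:

```elixir
{init_fn, update_fn} = Axon.Optimizers.yogi(1.0e-2, b2: 0.99)

params = %{"x0" => Nx.tensor([1.0])}
optimizer_state = init_fn.(params)

gradients = %{"x0" => Nx.tensor([0.5])}
{updates, _new_state} = update_fn.(gradients, optimizer_state, params)
new_params = Axon.Updates.apply_updates(params, updates)
```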
18 changes: 9 additions & 9 deletions lib/axon/updates.ex
@@ -239,7 +239,7 @@ defmodule Axon.Updates do

## Options

* `:decay` - EMA decay rate. Defaults to `0.9`.
* `:decay` - EMA decay rate. Defaults to `0.95`.

* `:eps` - numerical stability term. Defaults to `1.0e-8`.

@@ -276,7 +276,7 @@ defmodule Axon.Updates do
end

defnp apply_scale_by_rms(x, %{nu: nu}, _params, opts \\ []) do
opts = keyword!(opts, decay: 0.9, eps: 1.0e-8)
opts = keyword!(opts, decay: 0.95, eps: 1.0e-8)
decay = opts[:decay]
eps = opts[:eps]

@@ -355,7 +355,7 @@ defmodule Axon.Updates do

## Options

* `:decay` - EMA decay rate. Defaults to `0.9`.
* `:decay` - EMA decay rate. Defaults to `0.95`.

* `:eps` - numerical stability term. Defaults to `1.0e-8`.

@@ -393,7 +393,7 @@ defmodule Axon.Updates do
end

defnp apply_scale_by_stddev(x, %{mu: mu, nu: nu}, _params, opts \\ []) do
opts = keyword!(opts, decay: 0.9, eps: 1.0e-8)
opts = keyword!(opts, decay: 0.95, eps: 1.0e-8)
decay = opts[:decay]
eps = opts[:eps]

@@ -529,7 +529,7 @@ defmodule Axon.Updates do
## Options

* `:decay` - decay rate for tracing past updates. Defaults
to `0.9`
to `0.95`
* `:nesterov` - whether to use Nesterov momentum. Defaults
to `false`

@@ -560,7 +560,7 @@ defmodule Axon.Updates do
end

defnp apply_trace(x, %{trace: trace}, _params, opts \\ []) do
opts = keyword!(opts, decay: 0.9, nesterov: false)
opts = keyword!(opts, decay: 0.95, nesterov: false)
decay = opts[:decay]

update_trace = deep_merge(x, trace, fn g, t -> t * decay + g end)
@@ -688,7 +688,7 @@ defmodule Axon.Updates do

## Options

* `:decay` - Rate of decay. Defaults to `0.0`.
* `:decay` - Rate of decay. Defaults to `0.95`.
"""
def add_decayed_weights(combinator_or_opts \\ [])

Expand All @@ -704,7 +704,7 @@ defmodule Axon.Updates do
def add_decayed_weights({init_fn, apply_fn} = combinator, opts)
when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
stateless(combinator, fn updates, params ->
opts = Nx.Defn.Kernel.keyword!(opts, decay: 0.0)
opts = Nx.Defn.Kernel.keyword!(opts, decay: 0.95)
# Decay can be a tensor, that's why we preprocess it before-hand
# and pass it as argument to defn instead of as an option.
apply_weight_decay(updates, params, opts[:decay])
Expand Down Expand Up @@ -771,7 +771,7 @@ defmodule Axon.Updates do
Adds random Gaussian noise to the input.
Contributor Author:

The trace function in this file uses a decay-like algorithm with :decay defaulting to 0.9. Should we use this default throughout?


## Options

* `:seed` - Random seed to use. Defaults to the
current system time.

8 changes: 3 additions & 5 deletions test/axon/optimizers_test.exs
@@ -218,8 +218,7 @@ defmodule OptimizersTest do
end

test "correctly optimizes simple loss centered case" do
optimizer =
Axon.Optimizers.rmsprop(@learning_rate, centered: true, initial_scale: 0.1, decay: 0.8)
optimizer = Axon.Optimizers.rmsprop(@learning_rate, centered: true, decay: 0.8)

loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end
num_steps = @iterations
@@ -229,7 +228,7 @@ defmodule OptimizersTest do
end

test "correctly optimizes simple loss rms case" do
optimizer = Axon.Optimizers.rmsprop(@learning_rate, initial_scale: 0.1, decay: 0.8)
optimizer = Axon.Optimizers.rmsprop(@learning_rate, decay: 0.8)
loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end
num_steps = @iterations
x0 = %{"x0" => Nx.tensor([1.0])}
Expand All @@ -238,8 +237,7 @@ defmodule OptimizersTest do
end

test "correctly optimizes simple loss with momentum" do
optimizer =
Axon.Optimizers.rmsprop(@learning_rate, initial_scale: 0.1, decay: 0.8, momentum: 0.9)
optimizer = Axon.Optimizers.rmsprop(@learning_rate, decay: 0.8, momentum: 0.9)

loss_fn = fn %{"x0" => x} -> Nx.multiply(x, x) end
num_steps = @iterations
Expand Down