Skip to content

Commit

Permalink
fix: decay defaults for optimizers
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Apr 26, 2023
1 parent bcafdef commit 4236cb9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions lib/axon/optimizers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ defmodule Axon.Optimizers do
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.95`
"""
def adamw(learning_rate \\ 1.0e-3, opts \\ []) do
{decay, opts} = Keyword.pop(opts, :decay, 0.0)
{decay, opts} = Keyword.pop(opts, :decay, 0.95)

Updates.scale_by_adam(opts)
|> Updates.add_decayed_weights(decay: decay)
Expand All @@ -137,15 +137,15 @@ defmodule Axon.Optimizers do
* `:b2` - second moment decay. Defaults to `0.999`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
* `:eps_root` - numerical stability term. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.0`
* `:decay` - weight decay. Defaults to `0.95`
* `:min_norm` - minimum norm value. Defaults to `0.0`
## References
* [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)
"""
def lamb(learning_rate \\ 1.0e-2, opts \\ []) do
{decay, opts} = Keyword.pop(opts, :decay, 0.0)
{decay, opts} = Keyword.pop(opts, :decay, 0.95)
{min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0)

Updates.scale_by_adam(opts)
Expand Down Expand Up @@ -197,13 +197,13 @@ defmodule Axon.Optimizers do
to value of this term.
* `:nesterov` - whether or not to use nesterov momentum. Defaults to `false`
* `:initial_scale` - initial value of EMA. Defaults to `0.0`
* `:decay` - EMA decay rate. Defaults to `0.9`
* `:eps` - numerical stability term. Defaults to `1.0e-8`
"""
def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do
{centered, opts} = Keyword.pop(opts, :centered, false)
{nesterov?, opts} = Keyword.pop(opts, :nesterov, false)
{momentum, opts} = Keyword.pop(opts, :momentum, nil)
opts = Keyword.validate!(opts, [:momentum, centered: false, nesterov: false, decay: 0.95])
centered = opts[:centered]
nesterov? = opts[:nesterov]
momentum = opts[:momentum]

combinator =
if centered do
Expand Down
2 changes: 1 addition & 1 deletion lib/axon/updates.ex
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ defmodule Axon.Updates do
Adds random Gaussian noise to the input.
## Options
* `:seed` - Random seed to use. Defaults to the
current system time.
Expand Down

0 comments on commit 4236cb9

Please sign in to comment.