/
optimizers.ex
273 lines (206 loc) · 8.25 KB
/
optimizers.ex
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
defmodule Axon.Optimizers do
  @moduledoc """
  Implementations of common gradient-based optimization algorithms.

  All of the functions in this module are written in terms of
  the update functions defined in `Axon.Updates`. Axon treats
  optimizers as the tuple:

      {init_fn, update_fn}

  where `init_fn` returns an initial optimizer state and `update_fn`
  scales input gradients. `init_fn` accepts a model's parameters
  and attaches state to each parameter. `update_fn` accepts
  gradients, optimizer state, and current model parameters and
  returns the scaled gradients (the updates) together with the
  updated optimizer state, as `{updates, new_optimizer_state}`.

  Custom optimizers are often created via the `Axon.Updates` API.

  ## Learning rates

  Every optimizer takes a `learning_rate` as its first argument. It may
  be either a number or a schedule: a function of arity 1 that receives
  the update count tracked by `Axon.Updates.scale_by_schedule/2` and
  returns the learning rate to use.

  ## Example

  Consider the following usage of the Adam optimizer in a basic
  update function (assuming `objective`, `inputs`, and `targets` are
  defined elsewhere):

      defmodule Learning do
        import Nx.Defn

        defn init(params, init_fn) do
          init_fn.(params)
        end

        defn update(params, optimizer_state, inputs, targets, update_fn) do
          {loss, gradient} = value_and_grad(params, &objective(&1, inputs, targets))
          {scaled_updates, new_optimizer_state} = update_fn.(gradient, optimizer_state, params)
          {Axon.Updates.apply_updates(params, scaled_updates), new_optimizer_state, loss}
        end
      end

      params = Nx.random_uniform({784, 10})
      {init_fn, update_fn} = Axon.Optimizers.adam(0.005)

      optimizer_state =
        Learning.init(params, init_fn)

      {new_params, new_optimizer_state, loss} =
        Learning.update(params, optimizer_state, inputs, targets, update_fn)

  For a simpler approach, you can also use optimizers with the training API:

      model
      |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005))
      |> Axon.Loop.run(data, epochs: 10, compiler: EXLA)
  """
  alias Axon.Updates

  @doc """
  Adabelief optimizer.

  ## Options

    * `:b1` - first moment decay. Defaults to `0.9`
    * `:b2` - second moment decay. Defaults to `0.999`
    * `:eps` - numerical stability term. Defaults to `0.0`
    * `:eps_root` - numerical stability term. Defaults to `1.0e-16`

  ## References

    * [AdaBelief Optimizer: Adapting Stepsizes by the Belief in Observed Gradients](https://arxiv.org/abs/2010.07468)
  """
  def adabelief(learning_rate \\ 1.0e-3, opts \\ []) do
    Updates.scale_by_belief(opts)
    |> scale_by_learning_rate(learning_rate)
  end

  @doc """
  Adagrad optimizer.

  ## Options

    * `:eps` - numerical stability term. Defaults to `1.0e-7`

  ## References

    * [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
  """
  def adagrad(learning_rate \\ 1.0e-3, opts \\ []) do
    Updates.scale_by_rss(opts)
    |> scale_by_learning_rate(learning_rate)
  end

  @doc """
  Adam optimizer.

  ## Options

    * `:b1` - first moment decay. Defaults to `0.9`
    * `:b2` - second moment decay. Defaults to `0.999`
    * `:eps` - numerical stability term. Defaults to `1.0e-8`
    * `:eps_root` - numerical stability term. Defaults to `1.0e-15`

  ## References

    * [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
  """
  def adam(learning_rate \\ 1.0e-3, opts \\ []) do
    Updates.scale_by_adam(opts)
    |> scale_by_learning_rate(learning_rate)
  end

  @doc """
  Adam with weight decay optimizer.

  Weight decay is applied to the Adam-scaled updates before the
  learning rate is applied.

  ## Options

    * `:b1` - first moment decay. Defaults to `0.9`
    * `:b2` - second moment decay. Defaults to `0.999`
    * `:eps` - numerical stability term. Defaults to `1.0e-8`
    * `:eps_root` - numerical stability term. Defaults to `1.0e-15`
      (the same default as `adam/2`, which uses the same Adam scaling)
    * `:decay` - weight decay. Defaults to `0.0`
  """
  def adamw(learning_rate \\ 1.0e-3, opts \\ []) do
    # `:decay` belongs to the weight-decay step, not to the Adam scaling, so it
    # is removed before the remaining options are forwarded.
    {decay, opts} = Keyword.pop(opts, :decay, 0.0)

    Updates.scale_by_adam(opts)
    |> Updates.add_decayed_weights(decay: decay)
    |> scale_by_learning_rate(learning_rate)
  end

  @doc """
  Lamb optimizer.

  ## Options

    * `:b1` - first moment decay. Defaults to `0.9`
    * `:b2` - second moment decay. Defaults to `0.999`
    * `:eps` - numerical stability term. Defaults to `1.0e-8`
    * `:eps_root` - numerical stability term. Defaults to `1.0e-15`
      (the same default as `adam/2`, which uses the same Adam scaling)
    * `:decay` - weight decay. Defaults to `0.0`
    * `:min_norm` - minimum norm value. Defaults to `0.0`

  ## References

    * [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)
  """
  def lamb(learning_rate \\ 1.0e-2, opts \\ []) do
    # Strip the options consumed by the decay and trust-ratio steps so only
    # Adam options reach `scale_by_adam/1`.
    {decay, opts} = Keyword.pop(opts, :decay, 0.0)
    {min_norm, opts} = Keyword.pop(opts, :min_norm, 0.0)

    Updates.scale_by_adam(opts)
    |> Updates.add_decayed_weights(decay: decay)
    |> Updates.scale_by_trust_ratio(min_norm: min_norm)
    |> scale_by_learning_rate(learning_rate)
  end

  @doc """
  Noisy SGD optimizer.

  Noise is added to the incoming gradients, and the noisy gradients are
  then scaled by the learning rate.

  ## Options

    * `:eta` - used to compute variance of noise distribution. Defaults to `0.1`
    * `:gamma` - used to compute variance of noise distribution. Defaults to `0.55`

  ## References

    * [Adding Gradient Noise Improves Learning for Very Deep Networks](https://arxiv.org/abs/1511.06807)
  """
  def noisy_sgd(learning_rate \\ 1.0e-2, opts \\ []) do
    # Noise must be added *before* learning-rate scaling. If it were added
    # afterwards, its magnitude would be independent of the learning rate,
    # unlike the gradient step it perturbs.
    Updates.identity()
    |> Updates.add_noise(opts)
    |> scale_by_learning_rate(learning_rate)
  end

  @doc """
  Rectified Adam optimizer.

  ## Options

    * `:b1` - first moment decay. Defaults to `0.9`
    * `:b2` - second moment decay. Defaults to `0.999`
    * `:eps` - numerical stability term. Defaults to `1.0e-8`
    * `:eps_root` - numerical stability term. Defaults to `0.0`
    * `:threshold` - threshold term. Defaults to `5.0`

  ## References

    * [On the Variance of Adaptive Learning Rate and Beyond](https://arxiv.org/pdf/1908.03265.pdf)
  """
  def radam(learning_rate \\ 1.0e-3, opts \\ []) do
    Updates.scale_by_radam(opts)
    |> scale_by_learning_rate(learning_rate)
  end

  @doc """
  RMSProp optimizer.

  ## Options

    * `:centered` - whether to scale by centered root of EMA of squares. Defaults to `false`
    * `:momentum` - momentum term. If set, uses SGD with momentum and decay set
      to value of this term.
    * `:nesterov` - whether or not to use nesterov momentum. Only used when
      `:momentum` is set. Defaults to `false`
    * `:initial_scale` - initial value of EMA. Defaults to `0.0`
    * `:decay` - EMA decay rate. Defaults to `0.9`
    * `:eps` - numerical stability term. Defaults to `1.0e-8`
  """
  def rmsprop(learning_rate \\ 1.0e-2, opts \\ []) do
    # Pop the options that select the variant; what remains is forwarded to
    # the RMS/stddev scaling.
    {centered, opts} = Keyword.pop(opts, :centered, false)
    {nesterov?, opts} = Keyword.pop(opts, :nesterov, false)
    {momentum, opts} = Keyword.pop(opts, :momentum, nil)

    combinator =
      if centered do
        Updates.scale_by_stddev(opts)
      else
        Updates.scale_by_rms(opts)
      end
      |> scale_by_learning_rate(learning_rate)

    # Momentum is accumulated over the already learning-rate-scaled updates.
    if momentum,
      do: Updates.trace(combinator, decay: momentum, nesterov: nesterov?),
      else: combinator
  end

  @doc """
  SGD optimizer.

  ## Options

    * `:momentum` - momentum term. If set, uses SGD with momentum and decay set
      to value of this term.
    * `:nesterov` - whether or not to use nesterov momentum. Only used when
      `:momentum` is set. Defaults to `false`
  """
  def sgd(learning_rate \\ 1.0e-2, opts \\ []) do
    momentum = opts[:momentum]
    nesterov? = opts[:nesterov] || false

    if momentum do
      Updates.trace(decay: momentum, nesterov: nesterov?)
      |> scale_by_learning_rate(learning_rate)
    else
      scale_by_learning_rate(learning_rate)
    end
  end

  @doc """
  Yogi optimizer.

  ## Options

    * `:initial_accumulator_value` - initial value for first and second moment. Defaults to `0.0`
    * `:b1` - first moment decay. Defaults to `0.9`
    * `:b2` - second moment decay. Defaults to `0.999`
    * `:eps` - numerical stability term. Defaults to `1.0e-8`
    * `:eps_root` - numerical stability term. Defaults to `0.0`

  ## References

    * [Adaptive Methods for Nonconvex Optimization](https://papers.nips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf)
  """
  def yogi(learning_rate \\ 1.0e-2, opts \\ []) do
    Updates.scale_by_yogi(opts)
    |> scale_by_learning_rate(learning_rate)
  end

  ## Helpers

  # Appends a learning-rate step to `combinator`. When called with only a
  # learning rate, the step is appended to `Updates.identity()`.
  #
  # The learning rate is negated so that the resulting updates point against
  # the gradient. This presumes `Axon.Updates.apply_updates/2` adds updates
  # to parameters; confirm there.
  defp scale_by_learning_rate(combinator \\ Updates.identity(), lr)

  # Schedule: an arity-1 function evaluated on the count supplied by
  # `scale_by_schedule/2`; its result is negated inside the graph.
  defp scale_by_learning_rate(combinator, schedule) when is_function(schedule, 1) do
    Updates.scale_by_schedule(combinator, fn count -> Nx.negate(schedule.(count)) end)
  end

  # Constant learning rate: negated eagerly and passed to `scale_by_state/2`.
  defp scale_by_learning_rate(combinator, lr) do
    Updates.scale_by_state(combinator, -lr)
  end
end