This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
optimizer.R
452 lines (415 loc) · 14.2 KB
/
optimizer.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
#' Create an SGD optimizer with respective parameters.
#' Perform SGD with momentum update
#'
#' @param learning.rate float
#'     Step size. Required; there is no default.
#' @param momentum float, default=0
#'     Momentum factor. When 0 no state is created and the plain SGD
#'     update is applied.
#' @param wd float, default=0.0
#'     L2 regularization coefficient add to all the weights.
#' @param rescale.grad float, default=1.0
#'     rescaling factor of gradient.
#' @param clip_gradient float, optional
#'     clip gradient in range [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#'     The learning rate scheduler. It is called with the optimizer's
#'     bookkeeping environment (fields lr, count, num_update) before each
#'     update; the value left in lr is used for that update.
#'
#' @return A list with two functions: create.state(index, weight) and
#'     update(index, weight, grad, state). update returns
#'     list(weight=, state=).
#'
mx.opt.sgd <- function(learning.rate,
                       momentum = 0,
                       wd = 0,
                       rescale.grad = 1,
                       clip_gradient = NULL,
                       lr_scheduler = NULL) {
  # use lr as short for learning rate.
  lr <- learning.rate
  # Bookkeeping environment handed to lr_scheduler on every update.
  sgd <- new.env()
  sgd$lr <- lr
  sgd$count <- 0
  sgd$num_update <- 0

  create.state <- function(index, weight) {
    # Without momentum there is nothing to remember between updates.
    if (momentum == 0) {
      return(NULL)
    }
    mx.nd.zeros(dim(weight), ctx(weight))
  }

  update <- function(index, weight, grad, state) {
    if (!is.null(lr_scheduler)) {
      lr_scheduler(sgd) ## changing lr
      lr <- sgd$lr
      ## update count
      # inherits = FALSE keeps the lookup inside `sgd`; otherwise a variable
      # of the same name in an enclosing/global environment would be taken
      # for this parameter's counter.
      indexKey <- paste0("ik", index)
      if (!exists(indexKey, envir = sgd, inherits = FALSE)) {
        assign(indexKey, 0, envir = sgd)
      } else {
        indexValue <- get(indexKey, envir = sgd, inherits = FALSE) + 1
        assign(indexKey, indexValue, envir = sgd)
        sgd$num_update <- max(sgd$num_update, indexValue)
      }
    }

    grad <- grad * rescale.grad
    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        # Clip on an R array copy, then move the result back to the
        # gradient's original context.
        grad_ctx <- ctx(grad)
        grad <- as.array(grad)
        grad <- pmax(grad, -1 * clip_gradient)
        grad <- pmin(grad, clip_gradient)
        grad <- mx.nd.array(grad, grad_ctx)
      } else {
        stop("Error: clip_gradient should be positive number.")
      }
    }

    if (is.null(state)) {
      weight <- weight - lr * (grad + wd * weight)
    } else {
      mom <- state * momentum
      mom <- mom - lr * (grad + wd * weight)
      weight <- weight + mom
      state <- mom
    }
    list(weight = weight, state = state)
  }

  list(create.state = create.state, update = update)
}
#' Create an RMSProp optimizer with respective parameters.
#' Reference: Tieleman T, Hinton G. Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude[J]. COURSERA: Neural Networks for Machine Learning, 2012, 4(2).
#' The code follows: http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.
#'
#' @param learning.rate float, default=0.002
#'     Step size.
#' @param gamma1 float, default=0.95
#'     decay factor of moving average for gradient, gradient^2.
#' @param gamma2 float, default=0.9
#'     "momentum" factor.
#' @param wd float, default=0.0
#'     L2 regularization coefficient add to all the weights.
#' @param rescale.grad float, default=1.0
#'     rescaling factor of gradient.
#' @param clip_gradient float, optional
#'     clip gradient in range [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#'     The learning rate scheduler.
#'
#' @return A list with two functions: create.state(index, weight) and
#'     update(index, weight, grad, state). The state is a list with
#'     elements n, g and delta.
#'
mx.opt.rmsprop <- function(learning.rate = 0.002,
                           gamma1 = 0.95,
                           gamma2 = 0.9,
                           wd = 0,
                           rescale.grad = 1,
                           clip_gradient = NULL,
                           lr_scheduler = NULL) {
  # use lr as short for learning rate.
  lr <- learning.rate
  # Bookkeeping environment handed to lr_scheduler on every update.
  rmsprop <- new.env()
  rmsprop$lr <- lr
  rmsprop$count <- 0
  rmsprop$num_update <- 0

  create.state <- function(index, weight) {
    zeros <- function() mx.nd.zeros(dim(weight), ctx(weight))
    list(n = zeros(), g = zeros(), delta = zeros())
  }

  update <- function(index, weight, grad, state) {
    if (!is.null(lr_scheduler)) {
      lr_scheduler(rmsprop) ## changing lr
      lr <- rmsprop$lr
      ## update count
      # inherits = FALSE keeps the lookup inside `rmsprop`; otherwise a
      # same-named variable in an enclosing/global environment would be
      # taken for this parameter's counter.
      indexKey <- paste0("ik", index)
      if (!exists(indexKey, envir = rmsprop, inherits = FALSE)) {
        assign(indexKey, 0, envir = rmsprop)
      } else {
        indexValue <- get(indexKey, envir = rmsprop, inherits = FALSE) + 1
        assign(indexKey, indexValue, envir = rmsprop)
        rmsprop$num_update <- max(rmsprop$num_update, indexValue)
      }
    }

    grad <- grad * rescale.grad
    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        # Clip on an R array copy, then move the result back to the
        # gradient's original context.
        grad_ctx <- ctx(grad)
        grad <- as.array(grad)
        grad <- pmax(grad, -1 * clip_gradient)
        grad <- pmin(grad, clip_gradient)
        grad <- mx.nd.array(grad, grad_ctx)
      } else {
        stop("Error: clip_gradient should be positive number.")
      }
    }

    # Running averages of grad^2 (n) and grad (g); 1e-4 keeps the square
    # root away from zero.
    n <- gamma1 * state$n + (1 - gamma1) * (grad * grad)
    g <- gamma1 * state$g + (1 - gamma1) * grad
    delta <- gamma2 * state$delta -
      lr * (grad / mx.nd.sqrt(n - g * g + 1e-4) + wd * weight)
    weight <- weight + delta

    list(weight = weight, state = list(n = n, g = g, delta = delta))
  }

  list(create.state = create.state, update = update)
}
#' Create an Adam optimizer with respective parameters.
#' Adam optimizer as described in [King2014].
#'
#' [King2014] Diederik Kingma, Jimmy Ba,
#' Adam: A Method for Stochastic Optimization,
#' http://arxiv.org/abs/1412.6980
#'
#' @param learning.rate float, default=0.001
#'     Step size.
#' @param beta1 float, default=0.9
#'     Exponential decay rate for the first moment estimates.
#' @param beta2 float, default=0.999
#'     Exponential decay rate for the second moment estimates.
#' @param epsilon float, default=1e-8
#'     Small constant added to the square root of the second moment
#'     estimate in the denominator of the update.
#' @param wd float, default=0.0
#'     L2 regularization coefficient add to all the weights.
#' @param rescale.grad float, default=1.0
#'     rescaling factor of gradient.
#' @param clip_gradient float, optional
#'     clip gradient in range [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#'     The learning rate scheduler.
#'
#' @return A list with two functions: create.state(index, weight) and
#'     update(index, weight, grad, state). The state is a list with
#'     elements mean and variance.
#'
mx.opt.adam <- function(learning.rate = 0.001,
                        beta1 = 0.9,
                        beta2 = 0.999,
                        epsilon = 1e-8,
                        wd = 0,
                        rescale.grad = 1,
                        clip_gradient = NULL,
                        lr_scheduler = NULL) {
  # use lr as short for learning rate.
  lr <- learning.rate
  # Bookkeeping environment: handed to lr_scheduler, and also holds one
  # time-step counter per parameter index under the key "t<index>".
  adam <- new.env()
  adam$lr <- lr
  adam$count <- 0
  adam$num_update <- 0

  create.state <- function(index, weight) {
    zeros <- function() mx.nd.zeros(dim(weight), ctx(weight))
    list(mean = zeros(), variance = zeros())
  }

  update <- function(index, weight, grad, state) {
    if (!is.null(lr_scheduler)) {
      lr_scheduler(adam) ## changing lr
      lr <- adam$lr
      ## update count
      # inherits = FALSE keeps the lookup inside `adam`; otherwise a
      # same-named variable in an enclosing/global environment would be
      # taken for this parameter's counter.
      indexKey <- paste0("ik", index)
      if (!exists(indexKey, envir = adam, inherits = FALSE)) {
        assign(indexKey, 0, envir = adam)
      } else {
        indexValue <- get(indexKey, envir = adam, inherits = FALSE) + 1
        assign(indexKey, indexValue, envir = adam)
        adam$num_update <- max(adam$num_update, indexValue)
      }
    }

    # increment time (per parameter index, first update is step 1).
    # Same inherits = FALSE reasoning: a user variable such as `t1` must
    # not be picked up as the step counter.
    time.key <- paste0("t", index)
    if (exists(time.key, envir = adam, inherits = FALSE)) {
      step <- get(time.key, envir = adam, inherits = FALSE) + 1
    } else {
      step <- 1
    }
    assign(time.key, step, envir = adam)

    first.moment <- state$mean
    second.moment <- state$variance

    grad <- grad * rescale.grad
    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        # Clip on an R array copy, then move the result back to the
        # gradient's original context.
        grad_ctx <- ctx(grad)
        grad <- as.array(grad)
        grad <- pmax(grad, -1 * clip_gradient)
        grad <- pmin(grad, clip_gradient)
        grad <- mx.nd.array(grad, grad_ctx)
      } else {
        stop("Error: clip_gradient should be positive number.")
      }
    }

    first.moment <- beta1 * first.moment + (1 - beta1) * grad
    second.moment <- beta2 * second.moment + (1 - beta2) * (grad * grad)

    # Bias correction is folded into the step size.
    coef1 <- 1 - beta1^step
    coef2 <- 1 - beta2^step
    lr <- lr * sqrt(coef2) / coef1

    weight <- weight - lr * first.moment / (mx.nd.sqrt(second.moment) + epsilon)
    # Weight decay is applied as a separate step with the corrected rate.
    weight <- weight - lr * wd * weight

    list(weight = weight,
         state = list(mean = first.moment, variance = second.moment))
  }

  list(create.state = create.state, update = update)
}
#' Create an AdaGrad optimizer with respective parameters.
#' AdaGrad optimizer of Duchi et al., 2011,
#'
#' This code follows the version in http://arxiv.org/pdf/1212.5701v1.pdf Eq(5)
#' by Matthew D. Zeiler, 2012. AdaGrad will help the network to converge faster
#' in some cases.
#'
#' @param learning.rate float, default=0.05
#'     Step size.
#' @param epsilon float, default=1e-8
#'     Small constant added to the accumulated squared gradient before the
#'     square root is taken.
#' @param wd float, default=0.0
#'     L2 regularization coefficient add to all the weights.
#' @param rescale.grad float, default=1.0
#'     rescaling factor of gradient.
#' @param clip_gradient float, optional
#'     clip gradient in range [-clip_gradient, clip_gradient].
#' @param lr_scheduler function, optional
#'     The learning rate scheduler.
#'
#' @return A list with two functions: create.state(index, weight) and
#'     update(index, weight, grad, state). The state is the accumulated
#'     squared gradient.
#'
mx.opt.adagrad <- function(learning.rate = 0.05,
                           epsilon = 1e-8,
                           wd = 0,
                           rescale.grad = 1,
                           clip_gradient = NULL,
                           lr_scheduler = NULL) {
  # use lr as short for learning rate.
  lr <- learning.rate
  # Bookkeeping environment handed to lr_scheduler on every update.
  adagrad <- new.env()
  adagrad$lr <- lr
  adagrad$count <- 0
  adagrad$num_update <- 0

  create.state <- function(index, weight) {
    mx.nd.zeros(dim(weight), ctx(weight)) # history
  }

  update <- function(index, weight, grad, state) {
    if (!is.null(lr_scheduler)) {
      lr_scheduler(adagrad) ## changing lr
      lr <- adagrad$lr
      ## update count
      # inherits = FALSE keeps the lookup inside `adagrad`; otherwise a
      # same-named variable in an enclosing/global environment would be
      # taken for this parameter's counter.
      indexKey <- paste0("ik", index)
      if (!exists(indexKey, envir = adagrad, inherits = FALSE)) {
        assign(indexKey, 0, envir = adagrad)
      } else {
        indexValue <- get(indexKey, envir = adagrad, inherits = FALSE) + 1
        assign(indexKey, indexValue, envir = adagrad)
        adagrad$num_update <- max(adagrad$num_update, indexValue)
      }
    }

    grad <- grad * rescale.grad
    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        # Clip on an R array copy, then move the result back to the
        # gradient's original context.
        grad_ctx <- ctx(grad)
        grad <- as.array(grad)
        grad <- pmax(grad, -1 * clip_gradient)
        grad <- pmin(grad, clip_gradient)
        grad <- mx.nd.array(grad, grad_ctx)
      } else {
        stop("Error: clip_gradient should be positive number.")
      }
    }

    # The state is the running sum of squared gradients.
    history <- state + (grad * grad)
    weight <- weight - lr * (grad / mx.nd.sqrt(history + epsilon) + wd * weight)

    list(weight = weight, state = history)
  }

  list(create.state = create.state, update = update)
}
#' Create an AdaDelta optimizer with respective parameters.
#'
#' AdaDelta optimizer as described in Zeiler, M. D. (2012).
#' *ADADELTA: An adaptive learning rate method.*
#' http://arxiv.org/abs/1212.5701
#'
#' @param rho float, default=0.90
#'     Decay rate for both squared gradients and delta x.
#' @param epsilon float, default=1e-5
#'     The constant epsilon as described in the paper.
#' @param wd float, default=0.0
#'     L2 regularization coefficient add to all the weights.
#' @param rescale.grad float, default=1.0
#'     rescaling factor of gradient.
#' @param clip_gradient float, optional
#'     clip gradient in range [-clip_gradient, clip_gradient].
#'
#' @return A list with two functions: create.state(index, weight) and
#'     update(index, weight, grad, state). The state is a list with
#'     elements acc.g and acc.delta.
#'
mx.opt.adadelta <- function(rho = 0.90,
                            epsilon = 1e-5,
                            wd = 0,
                            rescale.grad = 1,
                            clip_gradient = NULL) {
  create.state <- function(index, weight) {
    list(
      acc.g = mx.nd.zeros(dim(weight), ctx(weight)),     # accumulated g
      acc.delta = mx.nd.zeros(dim(weight), ctx(weight))  # accumulated delta
    )
  }

  update <- function(index, weight, grad, state) {
    # preprocess grad
    grad <- grad * rescale.grad
    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        # Clip on an R array copy, then move the result back to the
        # gradient's original context.
        grad_ctx <- ctx(grad)
        grad <- as.array(grad)
        grad <- pmax(grad, -1 * clip_gradient)
        grad <- pmin(grad, clip_gradient)
        grad <- mx.nd.array(grad, grad_ctx)
      } else {
        stop("Error: clip_gradient should be positive number.")
      }
    }

    # update g, delta: the step is computed from the previous accumulated
    # delta, which is only refreshed afterwards.
    acc.g <- rho * state$acc.g + (1 - rho) * (grad * grad)
    current.delta <- mx.nd.sqrt(state$acc.delta + epsilon) /
      mx.nd.sqrt(acc.g + epsilon) * grad
    acc.delta <- rho * state$acc.delta +
      (1 - rho) * (current.delta * current.delta)

    weight <- weight - current.delta - wd * weight

    list(weight = weight, state = list(acc.g = acc.g, acc.delta = acc.delta))
  }

  list(create.state = create.state, update = update)
}
#' Create an optimizer by name and parameters
#'
#' @param name The name of the optimizer. One of "sgd", "rmsprop", "adam",
#'     "adagrad" or "adadelta".
#' @param ... Additional arguments, forwarded unchanged to the constructor
#'     of the selected optimizer.
#'
#' @return The optimizer returned by the selected constructor.
#'
#' @export
mx.opt.create <- function(name, ...) {
  # Reject NULL / NA / vector input up front with a readable message instead
  # of letting a comparison inside `if` fail.
  if (length(name) != 1 || is.na(name)) {
    stop("Optimizer name must be a single non-NA string")
  }
  # Only the matching arm is evaluated; the trailing unnamed arm is the
  # fallback for unrecognised names.
  switch(as.character(name),
    sgd = mx.opt.sgd(...),
    rmsprop = mx.opt.rmsprop(...),
    adam = mx.opt.adam(...),
    adagrad = mx.opt.adagrad(...),
    adadelta = mx.opt.adadelta(...),
    stop(paste0("Unknown optimizer ", name))
  )
}
#' Get an updater closure that can take list of weight and gradient
#' and return updated list of weight.
#'
#' @param optimizer The optimizer: a list providing create.state(index, weight)
#'     and update(index, weight, grad, state), where update returns
#'     list(weight=, state=).
#' @param weights The weights to be optimized. NULL entries get no state.
#'
#' @return A function(weight, grad) returning the list of updated weights.
#'     Positions whose gradient is NULL are skipped: their returned weight
#'     is NULL and their optimizer state is carried over unchanged.
#'
#' @export
mx.opt.get.updater <- function(optimizer, weights) {
  # seq_along() is empty for an empty weight list, unlike 1:length(weights).
  positions <- seq_along(weights)

  # This is the list to keep track of internal states of optimizer
  state.list <- lapply(positions, function(i) {
    if (is.null(weights[[i]])) {
      return(NULL)
    }
    optimizer$create.state(i, weights[[i]])
  })
  update <- optimizer$update

  update.closure <- function(weight, grad) {
    ulist <- lapply(positions, function(i) {
      if (is.null(grad[[i]])) {
        return(NULL)
      }
      update(i, weight[[i]], grad[[i]], state.list[[i]])
    })

    # update state list, use mutate assignment. A skipped position keeps
    # its previous state so that a later update does not start from NULL.
    state.list <<- lapply(positions, function(i) {
      if (is.null(ulist[[i]])) {
        state.list[[i]]
      } else {
        ulist[[i]]$state
      }
    })

    # return updated weight list (NULL for skipped positions)
    lapply(ulist, function(x) {
      x$weight
    })
  }

  update.closure
}