Add AdaDelta optimizer in R
ziyeqinghan committed Jul 26, 2016
1 parent 095d742 commit 3b9e0b9
Showing 2 changed files with 94 additions and 0 deletions.
63 changes: 63 additions & 0 deletions R-package/R/optimizer.R
@@ -329,6 +329,66 @@ mx.opt.adagrad <- function(learning.rate=0.05,
  return(list(create.state=create.state, update=update))
}

#' Create an AdaDelta optimizer with respective parameters.
#'
#' AdaDelta optimizer as described in Zeiler, M. D. (2012).
#' *ADADELTA: An adaptive learning rate method.*
#' http://arxiv.org/abs/1212.5701
#'
#' @param rho float, default=0.90
#' Decay rate for both squared gradients and delta x.
#' @param epsilon float, default=1e-5
#' The epsilon constant as described in the paper.
#' @param wd float, default=0.0
#' L2 regularization coefficient added to all the weights.
#' @param rescale.grad float, default=1.0
#' Rescaling factor applied to the gradient.
#' @param clip_gradient float, optional
#' Clip the gradient to the range [-clip_gradient, clip_gradient].
#'
mx.opt.adadelta <- function(rho=0.90,
                            epsilon=1e-5,
                            wd=0,
                            rescale.grad=1,
                            clip_gradient=NULL) {
  adadelta <- new.env()

  create.state <- function(index, weight) {
    return(list(acc.g=mx.nd.zeros(dim(weight), ctx(weight)),       # accumulated E[g^2]
                acc.delta=mx.nd.zeros(dim(weight), ctx(weight))))  # accumulated E[delta^2]
  }

  update <- function(index, weight, grad, state) {
    # preprocess the gradient: rescale, then optionally clip
    grad <- grad * rescale.grad
    if (!is.null(clip_gradient)) {
      if (clip_gradient >= 0) {
        # clipping is done on the host: copy to an R array and back
        grad_ctx <- ctx(grad)
        grad <- as.array(grad)
        grad <- pmax(grad, -1 * clip_gradient)
        grad <- pmin(grad, clip_gradient)
        grad <- mx.nd.array(grad, grad_ctx)
      } else {
        stop("clip_gradient should be a positive number.")
      }
    }

    # read the accumulators from the optimizer state
    acc.g <- state$acc.g
    acc.delta <- state$acc.delta

    # decay the accumulators, compute the scaled update, then apply it with weight decay
    acc.g <- rho * acc.g + (1 - rho) * (grad * grad)
    current.delta <- mx.nd.sqrt(acc.delta + epsilon) / mx.nd.sqrt(acc.g + epsilon) * grad
    acc.delta <- rho * acc.delta + (1 - rho) * (current.delta * current.delta)
    weight <- weight - current.delta - wd * weight
    state <- list(acc.g=acc.g, acc.delta=acc.delta)

    return(list(weight=weight, state=state))
  }
  return(list(create.state=create.state, update=update))
}

#' Create an optimizer by name and parameters
#'
#' @param name The name of the optimizer
@@ -348,6 +408,9 @@ mx.opt.create <- function(name, ...) {
else if (name == "adagrad") {
return (mx.opt.adagrad(...))
}
else if (name == "adadelta") {
return (mx.opt.adadelta(...))
}
stop(paste("Unknown optimizer ", name))
}
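For reference, the step that mx.opt.adadelta implements above is the AdaDelta rule from Zeiler (2012), plus the wd weight-decay term this port adds; in LaTeX:

\begin{align*}
E[g^2]_t &= \rho\,E[g^2]_{t-1} + (1-\rho)\,g_t^2 \\
\Delta w_t &= \frac{\sqrt{E[\Delta w^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_t + \epsilon}}\,g_t \\
E[\Delta w^2]_t &= \rho\,E[\Delta w^2]_{t-1} + (1-\rho)\,\Delta w_t^2 \\
w_{t+1} &= w_t - \Delta w_t - \mathrm{wd}\,w_t
\end{align*}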

31 changes: 31 additions & 0 deletions R-package/man/mx.opt.adadelta.Rd
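A quick usage sketch of the new optimizer — hypothetical dummy values, assuming an mxnet R build that includes this commit (mx.opt.create, create.state, and update are the functions added or extended above):

library(mxnet)

# create the optimizer through the dispatcher added in this commit
optimizer <- mx.opt.create("adadelta", rho=0.90, epsilon=1e-5, wd=1e-4)

# one manual update step on a dummy 2x2 weight
weight <- mx.nd.array(matrix(0.5, 2, 2))
grad   <- mx.nd.array(matrix(0.1, 2, 2))
state  <- optimizer$create.state(index=1, weight=weight)
result <- optimizer$update(index=1, weight=weight, grad=grad, state=state)
weight <- result$weight  # updated weight
state  <- result$state   # updated accumulators for the next step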

