
Commit

Merge pull request #328 from dirkschumacher/optimrefactor
Refactor most of the optimizer step functions.
dfalbel committed Oct 26, 2020
2 parents 7e28f8e + a04f9fd commit 75e58d2
Showing 15 changed files with 288 additions and 398 deletions.
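
The refactor replaces each optimizer's hand-written parameter loop with a call to a shared private$step_helper(closure, fn) that evaluates the closure, iterates over the parameter groups, skips parameters without gradients, and applies an optimizer-specific update callback. The helper's own definition is not part of this excerpt, so the sketch below is only a reconstruction from the removed loop code in the diffs that follow; the free-standing signature and the param_groups argument are assumptions, and is_undefined_tensor() is the internal helper the old code itself uses.

library(torch)

# Hypothetical free-standing reconstruction of the shared step helper.
step_helper <- function(param_groups, closure, update_fn) {
  loss <- NULL
  with_no_grad({
    if (!is.null(closure)) {
      # re-evaluate the model under grad mode so the closure can produce a loss
      with_enable_grad({
        loss <- closure()
      })
    }
    for (g in seq_along(param_groups)) {
      group <- param_groups[[g]]
      for (p in seq_along(group$params)) {
        param <- group$params[[p]]
        # parameters that received no gradient are skipped
        if (is.null(param$grad) || is_undefined_tensor(param$grad))
          next
        # optimizer-specific update, e.g. the Adadelta body below
        update_fn(group, param, g, p)
      }
    }
  })
  loss
}

With this in place, each step() method reduces to private$step_helper(closure, function(group, param, g, p) { ... }), as the three file diffs below show.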
85 changes: 30 additions & 55 deletions R/optim-adadelta.R
@@ -29,63 +29,38 @@ optim_Adadelta <- R6::R6Class(
},

step = function(closure = NULL){

with_no_grad({

loss <- NULL

if (!is.null(closure)) {
with_enable_grad({
loss <- closure()
})
}

for (g in seq_along(self$param_groups)) {

group <- self$param_groups[[g]]

for (p in seq_along(group$params)) {

param <- group$params[[p]]

if (is.null(param$grad) || is_undefined_tensor(param$grad))
next

grad <- param$grad

# if (grad$is_sparse) {
# runtime_error("Adadelta does not support sparse gradients")
# }

# state initialization
if (length(param$state) == 0) {
param$state <- list()
param$state[["step"]] <- 0
param$state[["square_avg"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
param$state[["acc_delta"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
}

square_avg <- param$state[["square_avg"]]
acc_delta <- param$state[["acc_delta"]]

rho <- group[["rho"]]
eps <- group[["eps"]]

param$state[["step"]] <- param$state[["step"]] + 1

if (group[["weight_decay"]] != 0)
grad <- grad$add(param, alpha=group[["weight_decay"]])

square_avg$mul_(rho)$addcmul_(grad, grad, value=1 - rho)
std <- square_avg$add(eps)$sqrt_()
delta <- acc_delta$add(eps)$sqrt_()$div_(std)$mul_(grad)
param$add_(delta, alpha=-group[["lr"]])
acc_delta$mul_(rho)$addcmul_(delta, delta, value=1 - rho)

}
private$step_helper(closure, function(group, param, g, p) {
grad <- param$grad

# if (grad$is_sparse) {
# runtime_error("Adadelta does not support sparse gradients")
# }

# state initialization
if (length(param$state) == 0) {
param$state <- list()
param$state[["step"]] <- 0
param$state[["square_avg"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
param$state[["acc_delta"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
}

square_avg <- param$state[["square_avg"]]
acc_delta <- param$state[["acc_delta"]]

rho <- group[["rho"]]
eps <- group[["eps"]]

param$state[["step"]] <- param$state[["step"]] + 1

if (group[["weight_decay"]] != 0)
grad <- grad$add(param, alpha=group[["weight_decay"]])

square_avg$mul_(rho)$addcmul_(grad, grad, value=1 - rho)
std <- square_avg$add(eps)$sqrt_()
delta <- acc_delta$add(eps)$sqrt_()$div_(std)$mul_(grad)
param$add_(delta, alpha=-group[["lr"]])
acc_delta$mul_(rho)$addcmul_(delta, delta, value=1 - rho)
})
loss
}
)
)
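
For reference, the arithmetic the Adadelta body above performs, written as a minimal plain-R scalar sketch (made-up starting values, no torch tensors involved; weight decay omitted):

rho <- 0.9; eps <- 1e-6; lr <- 1

param      <- 0.5    # parameter value
grad       <- 0.1    # its gradient
square_avg <- 0      # running average of squared gradients
acc_delta  <- 0      # running average of squared updates

square_avg <- rho * square_avg + (1 - rho) * grad^2
std        <- sqrt(square_avg + eps)
delta      <- sqrt(acc_delta + eps) / std * grad
param      <- param - lr * delta
acc_delta  <- rho * acc_delta + (1 - rho) * delta^2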
85 changes: 32 additions & 53 deletions R/optim-adagrad.R
@@ -41,7 +41,7 @@ optim_Adagrad <- R6::R6Class(
}
}
},

# It's implemented in PyTorch, but it's not necessary at the moment
# share_memory = function(){
# for (group in self$param_groups){
@@ -52,62 +52,41 @@ optim_Adagrad <- R6::R6Class(
# }
# },

step = function(closure = NULL){
with_no_grad({
step = function(closure = NULL) {
private$step_helper(closure, function(group, param, g, p) {
param$state[['step']] <- param$state[['step']] + 1

loss <- NULL
if (!is.null(closure)) {
with_enable_grad({
loss <- closure()
})
}
grad <- param$grad
state_sum <- param$state[['sum']]
state_step <- param$state[['step']]

for (group in self$param_groups){

for (p in seq_along(group$params)) {
param <- group$params[[p]]

if (is.null(param$grad) || is_undefined_tensor(param$grad))
next

param$state[['step']] <- param$state[['step']] + 1

grad <- param$grad
state_sum <- param$state[['sum']]
state_step <- param$state[['step']]

if (group$weight_decay != 0) {
# if (grad$is_sparse) {
# runtime_error("weight_decay option is not compatible with sparse gradients")
# }
grad <- grad$add(param, alpha = group$weight_decay)
}

clr <- group$lr / (1 + (param$state[['step']] - 1) * group$lr_decay)

# Sparse tensors handling will be added in future
# if (grad$is_sparse) {
# grad <- grad$coalesce()
# grad_indices <- grad$`_indices`()
# grad_values <- grad$`_values`()
# size <- grad$size()

# state_sum$add_(`_make_sparse`(grad, grad_indices, grad_values.pow(2)))
# std <- param$state[['sum']]$sparse_mask(grad)
# std_values <- std$`_values()`$sqrt_()$add_(group$eps)
# param$add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#} else {

param$state[['sum']]$addcmul_(grad, grad, value = 1)
std <- param$state[['sum']]$sqrt()$add_(group$eps)
param$addcdiv_(grad, std, value = -clr)

#}

}
if (group$weight_decay != 0) {
# if (grad$is_sparse) {
# runtime_error("weight_decay option is not compatible with sparse gradients")
# }
grad <- grad$add(param, alpha = group$weight_decay)
}

clr <- group$lr / (1 + (param$state[['step']] - 1) * group$lr_decay)

# Sparse tensors handling will be added in future
# if (grad$is_sparse) {
# grad <- grad$coalesce()
# grad_indices <- grad$`_indices`()
# grad_values <- grad$`_values`()
# size <- grad$size()

# state_sum$add_(`_make_sparse`(grad, grad_indices, grad_values.pow(2)))
# std <- param$state[['sum']]$sparse_mask(grad)
# std_values <- std$`_values()`$sqrt_()$add_(group$eps)
# param$add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
#} else {

param$state[['sum']]$addcmul_(grad, grad, value = 1)
std <- param$state[['sum']]$sqrt()$add_(group$eps)
param$addcdiv_(grad, std, value = -clr)

})
loss
}
)
)
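
For reference, the arithmetic the Adagrad body above performs, as a minimal plain-R scalar sketch (made-up values; the settings shown for lr_decay, weight_decay and eps are assumptions for the example):

lr <- 0.01; lr_decay <- 0; weight_decay <- 0; eps <- 1e-10

param <- 0.5    # parameter value
grad  <- 0.1    # its gradient
sum_  <- 0      # accumulated squared gradients (the "sum" state)
step  <- 0

step  <- step + 1
if (weight_decay != 0) grad <- grad + weight_decay * param
clr   <- lr / (1 + (step - 1) * lr_decay)    # decayed learning rate
sum_  <- sum_ + grad^2
std   <- sqrt(sum_) + eps
param <- param - clr * grad / std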
119 changes: 49 additions & 70 deletions R/optim-adam.R
@@ -32,82 +32,61 @@ optim_Adam <- R6::R6Class(
},

step = function(closure = NULL) {
with_no_grad({
private$step_helper(closure, function(group, param, g, p) {

loss <- NULL
if (!is.null(closure)) {
with_enable_grad({
loss <- closure()
})
grad <- param$grad

# if (grad$is_sparse) {
# runtime_error("Adam does not support sparse gradients, please consider",
# "SparseAdam instead")
# }
amsgrad <- group$amsgrad

# state initialization
if (length(param$state) == 0) {
param$state <- list()
param$state[["step"]] <- 0
param$state[["exp_avg"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
param$state[["exp_avg_sq"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
if (amsgrad) {
param$state[['max_exp_avg_sq']] <- torch_zeros_like(param, memory_format=torch_preserve_format())
}
}

for (g in seq_along(self$param_groups)) {

group <- self$param_groups[[g]]
exp_avg <- param$state[["exp_avg"]]
exp_avg_sq <- param$state[["exp_avg_sq"]]
if (amsgrad) {
max_exp_avg_sq <- param$state[['max_exp_avg_sq']]
}
beta1 <- group$betas[[1]]
beta2 <- group$betas[[2]]

param$state[["step"]] <- param$state[["step"]] + 1
bias_correction1 <- 1 - beta1 ^ param$state[['step']]
bias_correction2 <- 1 - beta2 ^ param$state[['step']]

if (group$weight_decay != 0) {
grad$add_(param, alpha=group$weight_decay)
}

# Decay the first and second moment running average coefficient
exp_avg$mul_(beta1)$add_(grad, alpha=1 - beta1)
exp_avg_sq$mul_(beta2)$addcmul_(grad, grad, value=1 - beta2)

if (amsgrad) {

for (p in seq_along(group$params)) {

param <- group$params[[p]]

if (is.null(param$grad) || is_undefined_tensor(param$grad))
next

grad <- param$grad

# if (grad$is_sparse) {
# runtime_error("Adam does not support sparse gradients, please consider",
# "SparseAdam instead")
# }
amsgrad <- group$amsgrad

# state initialization
if (length(param$state) == 0) {
param$state <- list()
param$state[["step"]] <- 0
param$state[["exp_avg"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
param$state[["exp_avg_sq"]] <- torch_zeros_like(param, memory_format=torch_preserve_format())
if (amsgrad) {
param$state[['max_exp_avg_sq']] <- torch_zeros_like(param, memory_format=torch_preserve_format())
}
}

exp_avg <- param$state[["exp_avg"]]
exp_avg_sq <- param$state[["exp_avg_sq"]]
if (amsgrad) {
max_exp_avg_sq <- param$state[['max_exp_avg_sq']]
}
beta1 <- group$betas[[1]]
beta2 <- group$betas[[2]]

param$state[["step"]] <- param$state[["step"]] + 1
bias_correction1 <- 1 - beta1 ^ param$state[['step']]
bias_correction2 <- 1 - beta2 ^ param$state[['step']]

if (group$weight_decay != 0) {
grad$add_(param, alpha=group$weight_decay)
}

# Decay the first and second moment running average coefficient
exp_avg$mul_(beta1)$add_(grad, alpha=1 - beta1)
exp_avg_sq$mul_(beta2)$addcmul_(grad, grad, value=1 - beta2)

if (amsgrad) {

# Maintains the maximum of all 2nd moment running avg. till now
max_exp_avg_sq$set_data(max_exp_avg_sq$max(other = exp_avg_sq))
# Use the max. for normalizing running avg. of gradient
denom <- (max_exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
} else {
denom <- (exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
}

step_size <- group$lr / bias_correction1

param$addcdiv_(exp_avg, denom, value=-step_size)
}
# Maintains the maximum of all 2nd moment running avg. till now
max_exp_avg_sq$set_data(max_exp_avg_sq$max(other = exp_avg_sq))
# Use the max. for normalizing running avg. of gradient
denom <- (max_exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
} else {
denom <- (exp_avg_sq$sqrt() / sqrt(bias_correction2))$add_(group$eps)
}

step_size <- group$lr / bias_correction1

param$addcdiv_(exp_avg, denom, value=-step_size)
})
loss
}
)
)
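
For reference, the arithmetic the Adam body above performs, as a minimal plain-R scalar sketch (made-up values; the weight-decay term is written against param here, matching the other optimizers):

lr <- 0.001; beta1 <- 0.9; beta2 <- 0.999; eps <- 1e-8
weight_decay <- 0; amsgrad <- FALSE

param <- 0.5; grad <- 0.1
exp_avg <- 0; exp_avg_sq <- 0; max_exp_avg_sq <- 0; step <- 0

step <- step + 1
bias_correction1 <- 1 - beta1^step
bias_correction2 <- 1 - beta2^step

if (weight_decay != 0) grad <- grad + weight_decay * param

exp_avg    <- beta1 * exp_avg    + (1 - beta1) * grad      # 1st moment estimate
exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad^2    # 2nd moment estimate

if (amsgrad) {
  # keep the maximum of all 2nd moment running averages so far
  max_exp_avg_sq <- max(max_exp_avg_sq, exp_avg_sq)
  denom <- sqrt(max_exp_avg_sq) / sqrt(bias_correction2) + eps
} else {
  denom <- sqrt(exp_avg_sq) / sqrt(bias_correction2) + eps
}

step_size <- lr / bias_correction1
param     <- param - step_size * exp_avg / denom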
