Optimizer methods can now (optionally) have their stepSize parameter be a function (which takes in the step index and returns the step size to use). This allows for programmatically-defined learning rate schedules.
dritchie committed May 19, 2017
1 parent 641d779 commit c636abb
Showing 2 changed files with 14 additions and 6 deletions.
opt/methods.js (18 changes: 13 additions & 5 deletions)
@@ -11,13 +11,15 @@ var EPS = 1e-8;
function sgd(options) {
options = utils.mergeDefaults(options, { stepSize: 0.1, stepSizeDecay: 1, mu: 0 });
var stepSize = options.stepSize;
+ var stepSizeIsFunction = (typeof stepSize === 'function');
var decay = options.stepSizeDecay;
var mu = options.mu; // mu > 0 yields gradient descent with momentum

// State
var vStruct;

return function(grad, param, step) {
+ var stepSize_ = stepSizeIsFunction ? stepSize(step) : stepSize;
if (!vStruct) vStruct = tstruct.emptyLike(grad);
tstruct.foreach(
grad,
@@ -27,7 +29,7 @@ function sgd(options) {
],
function(g, p, v) {
// v = v * mu - g * stepSize;
- v.muleq(mu).subeq(g.mul(stepSize));
+ v.muleq(mu).subeq(g.mul(stepSize_));
// p = p + v
p.addeq(v);
}
@@ -39,11 +41,13 @@ function sgd(options) {
function adagrad(options) {
options = utils.mergeDefaults(options, { stepSize: 0.1 });
var stepSize = options.stepSize;
+ var stepSizeIsFunction = (typeof stepSize === 'function');

// State
var g2Struct;

return function(grad, param, step) {
+ var stepSize_ = stepSizeIsFunction ? stepSize(step) : stepSize;
if (!g2Struct) g2Struct = tstruct.emptyLike(grad);
tstruct.foreach(
grad,
@@ -55,7 +59,7 @@ function adagrad(options) {
// g2 = g2 + g*g;
g2.addeq(g.mul(g));
// p = p - stepSize * (g / (sqrt(g2) + 1e-8))
- p.subeq(g.div(g2.sqrt().addeq(EPS)).muleq(stepSize));
+ p.subeq(g.div(g2.sqrt().addeq(EPS)).muleq(stepSize_));
}
);
};
@@ -64,12 +68,14 @@ function rmsprop(options) {
function rmsprop(options) {
options = utils.mergeDefaults(options, {stepSize: 0.1, decayRate: 0.9});
var stepSize = options.stepSize;
+ var stepSizeIsFunction = (typeof stepSize === 'function');
var decayRate = options.decayRate;

// State
var g2Struct;

return function(grad, param, step) {
+ var stepSize_ = stepSizeIsFunction ? stepSize(step) : stepSize;
if (!g2Struct) g2Struct = tstruct.emptyLike(grad);
tstruct.foreach(
grad,
@@ -81,7 +87,7 @@ function rmsprop(options) {
// g2 = decayRate*g2 + (1-decayRate)*(g*g)
g2.muleq(decayRate).addeq(g.mul(g).muleq(1-decayRate));
// p = p - stepSize * (g / (sqrt(g2) + 1e-8))
- p.subeq(g.div(g2.sqrt().addeq(EPS)).muleq(stepSize));
+ p.subeq(g.div(g2.sqrt().addeq(EPS)).muleq(stepSize_));
}
);
};
@@ -91,17 +97,19 @@ function adam(options) {
options = utils.mergeDefaults(options, {
stepSize: 0.001, // alpha
decayRate1: 0.9, // beta1
- decayRate2: 0.999, // beta2
+ decayRate2: 0.999, // beta2,
});

var stepSize = options.stepSize;
+ var stepSizeIsFunction = (typeof stepSize === 'function');
var decayRate1 = options.decayRate1;
var decayRate2 = options.decayRate2;

var mStruct;
var vStruct;

return function(grad, param, step) {
+ var stepSize_ = stepSizeIsFunction ? stepSize(step) : stepSize;
var t = step + 1;
if (!mStruct) mStruct = tstruct.emptyLike(grad);
if (!vStruct) vStruct = tstruct.emptyLike(grad);
@@ -118,7 +126,7 @@ function adam(options) {
// v = decayRate2*v + (1-decayRate2)*g*g
v.muleq(decayRate2).addeq(g.mul(g).muleq(1-decayRate2));

- var alpha_t = stepSize * Math.sqrt(1 - Math.pow(decayRate2, t)) / (1 - Math.pow(decayRate1, t));
+ var alpha_t = stepSize_ * Math.sqrt(1 - Math.pow(decayRate2, t)) / (1 - Math.pow(decayRate1, t));

// p = p - alpha_t * (m / (sqrt(v) + 1e-8))
p.subeq(m.div(v.sqrt().addeq(EPS)).muleq(alpha_t));
package.json (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
{
"name": "adnn",
"version": "2.0.6",
"version": "2.0.7",
"description": "Javascript neural networks on top of general scalar/tensor reverse-mode automatic differentiation.",
"author": "dritchie",
"license": "MIT",
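A minimal usage sketch of the new option follows. The require path and the assumption that opt/methods.js exports the four optimizer factories are illustrative only (the commit shows just the factory bodies); the option names and the update-function signature are taken from the diff above.

// Usage sketch only: the require path and export names are assumptions.
var methods = require('adnn/opt/methods');

// As before, stepSize can be a plain number:
var sgdUpdate = methods.sgd({ stepSize: 0.1, mu: 0.9 });

// New in this commit: stepSize can be a function of the step index,
// enabling programmatically-defined learning rate schedules, e.g. an
// exponentially decaying Adam step size:
var adamUpdate = methods.adam({
  stepSize: function(step) { return 0.001 * Math.pow(0.99, step); }
});

// Each factory returns a function(grad, param, step) that updates param in
// place; when stepSize is a function, the step index is forwarded to it on
// each call.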
