[MODELS] Changed the adaptive AMSGrad optimizer.
eaplatanios committed Dec 13, 2018
1 parent 9dfccf3 commit 72134b2
Showing 1 changed file with 67 additions and 29 deletions.
@@ -17,7 +17,7 @@ package org.platanios.symphony.mt.models.helpers

import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.ops.training.optimizers.schedules.{FixedSchedule, Schedule}
-import org.platanios.tensorflow.api.ops.variables.ZerosInitializer
+import org.platanios.tensorflow.api.ops.variables.{ConstantInitializer, ZerosInitializer}

/** Optimizer that implements the AMSGrad optimization algorithm, presented in
* [On the Convergence of Adam and Beyond](https://openreview.net/pdf?id=ryQu7f-RZ).
@@ -124,6 +124,8 @@ class AdaptiveAMSGrad protected (
zerosSlot("V", v, name)(TF.fromDataType(v.dataType))
zerosSlot("Vhat", v, name)(TF.fromDataType(v.dataType))
getSlot("NonZerosCount", v, INT64, ZerosInitializer, Shape(v.shape(0)), name)
getSlot("Beta1", v, FLOAT32, ConstantInitializer(beta1), Shape(v.shape(0)), name)
getSlot("Beta2", v, FLOAT32, ConstantInitializer(beta2), Shape(v.shape(0)), name)
getSlot("Beta1Power", v, FLOAT32, ZerosInitializer, Shape(v.shape(0)), name)
getSlot("Beta2Power", v, FLOAT32, ZerosInitializer, Shape(v.shape(0)), name)
})
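
The two added slots give every row of a variable its own adaptive beta, seeded from the scalar hyperparameters. A minimal plain-Scala sketch (no TensorFlow; the array representation and sizes are illustrative only, not the optimizer's API) of what ConstantInitializer(beta1) produces for a variable with four rows:

  // Each per-row slot starts out filled with the scalar hyperparameter;
  // the sparse update path later overwrites only the rows it touches.
  val numRows = 4
  val beta1 = 0.9f
  val beta2 = 0.999f
  val beta1Slot = Array.fill(numRows)(beta1) // analogue of the "Beta1" slot
  val beta2Slot = Array.fill(numRows)(beta2) // analogue of the "Beta2" slot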
@@ -153,23 +155,8 @@ class AdaptiveAMSGrad protected (
val beta2Power = getSlot[T, Float]("Beta2Power", variable)

val betaShape = Shape(variable.shape(0)) ++ Shape(Array.fill(variable.rank - 1)(1))

-// val nonZerosCount = getSlot[T, Long]("NonZerosCount", variable)
-// val nonZerosCountValue = nonZerosCount.assignAdd(tf.notEqual(gradient, tf.zeros[T](Shape())).toLong)

-// val initialBeta1 = getBeta1(variable).toFloat
-// val initialBeta2 = getBeta2(variable).toFloat
-// val aBeta1 = 1.0f - initialBeta1
-// val aBeta2 = 1.0f - initialBeta2
-// val bBeta1 = initialBeta1
-// val bBeta2 = initialBeta2
-// val rate = 1.0f - nonZerosCountValue.toFloat / (iteration.get.toFloat + 1.0f)
-// val beta1 = (aBeta1 * rate + bBeta1).castTo[T]
-// val beta2 = (aBeta2 * rate + bBeta2).castTo[T]

val beta1 = getBeta1(variable)
val beta2 = getBeta2(variable)

val epsilon = getEpsilon(variable)

var learningRate = getLearningRate(variable, iteration)
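
For reference, both update paths follow the standard AMSGrad recursion from the paper cited above, with beta1 and beta2 now supplied per row rather than as global constants:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2,$$
$$\hat{v}_t = \max(\hat{v}_{t-1}, v_t), \qquad \theta_t = \theta_{t-1} - \eta\, \frac{m_t}{\sqrt{\hat{v}_t} + \epsilon}.$$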
@@ -210,20 +197,32 @@ class AdaptiveAMSGrad protected (
val beta2Power = getSlot[T, Float]("Beta2Power", variable)

val nonZerosCount = getSlot[T, Long]("NonZerosCount", variable)
-val nonZerosCountValue = nonZerosCount.assignScatterAdd(gradient.indices, 1L).toFloat
-// nonZerosCountValue = tf.print(nonZerosCountValue, Seq(nonZerosCountValue), s"Sparse ${variable.name}: ", 10000, 10000)
+val nonZerosCountValue = nonZerosCount.assignScatterAdd(gradient.indices, 1L)

val betaShape = Shape(variable.shape(0)) ++ Shape(Array.fill(variable.rank - 1)(1))

-val initialBeta1 = getBeta1(variable).toFloat
-val initialBeta2 = getBeta2(variable).toFloat
-val aBeta1 = 1.0f - initialBeta1
-val aBeta2 = 1.0f - initialBeta2
-val bBeta1 = initialBeta1
-val bBeta2 = initialBeta2
-val rate = 1.0f - nonZerosCountValue.toFloat / (iteration.get.toFloat + 1.0f)
-val beta1 = (aBeta1 * rate + bBeta1).castTo[T]
-val beta2 = (aBeta2 * rate + bBeta2).castTo[T]
+val nz = nonZerosCountValue.gather(gradient.indices).toFloat
+val beta1 = computeBeta(
+  step = iteration.get.toFloat,
+  initialBeta = getBeta1(variable).toFloat,
+  previousBeta = getSlot[T, Float]("Beta1", variable),
+  nonZeroIndices = gradient.indices,
+  nonZerosCount = nz,
+  g = gradient.values.toFloat,
+  a = m.gather(gradient.indices).toFloat,
+  epsilon = epsilonTensor
+).castTo[T]

+val gradientValuesSquare = gradient.values * gradient.values
+val beta2 = computeBeta(
+  step = iteration.get.toFloat,
+  initialBeta = getBeta2(variable).toFloat,
+  previousBeta = getSlot[T, Float]("Beta2", variable),
+  nonZeroIndices = gradient.indices,
+  nonZerosCount = nz,
+  g = gradientValuesSquare.toFloat,
+  a = v.gather(gradient.indices).toFloat,
+  epsilon = epsilonTensor
+).castTo[T]

val epsilon = getEpsilon(variable)

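Both computeBeta calls (the method itself is added further down) apply the same per-row rule. Writing t for the step, n for a row's nonzero-update count, s = (t + 1) / n for the average gap between that row's updates, g for the incoming values (the raw gradient for beta1, its square for beta2), a for the row's current moment accumulator, and beta0 for the configured scalar:

$$\beta = \operatorname{clip}\!\left( \exp\!\left( \frac{\log(\lVert g \rVert^2 + \epsilon) - \log(\lVert a \rVert^2 + \epsilon) + \log \beta_0}{s} \right),\, 0,\, 1 \right) = \operatorname{clip}\!\left( \left( \beta_0\, \frac{\lVert g \rVert^2 + \epsilon}{\lVert a \rVert^2 + \epsilon} \right)^{1/s},\, 0,\, 1 \right).$$

Rows whose incoming values outweigh their accumulated moment get a larger beta, rows dominated by their history get a smaller one, and the 1/s exponent drives beta toward 1 for rarely updated rows.
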
@@ -240,7 +239,7 @@
}

// v_t = beta2 * v + (1 - beta2) * gradient * gradient
-val vScaledGradient = gradient.values * gradient.values * (one - beta2.reshape(betaShape).gather(gradient.indices))
+val vScaledGradient = gradientValuesSquare * (one - beta2.reshape(betaShape).gather(gradient.indices))
var vT = v.assign(v.value * beta2.reshape(betaShape))
vT = tf.createWith(controlDependencies = Set(vT.op)) {
v.assignScatterAdd(gradient.indices, vScaledGradient)
@@ -257,6 +256,45 @@
tf.group(Set(update.op, mT.op, vT.op, updateBeta1Power.op, updateBeta2Power.op))
}
}

+  def computeBetaOld(
+      step: Output[Float],
+      initialBeta: Output[Float],
+      nonZerosCount: Output[Float]
+  ): Output[Float] = {
+    val aBeta = 1.0f - initialBeta
+    val bBeta = initialBeta
+    val rate = 1.0f - nonZerosCount / (step + 1.0f)
+    aBeta * rate + bBeta
+  }
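
computeBetaOld preserves the earlier frequency-only rule, beta = (1 - beta0)(1 - n / (t + 1)) + beta0: a row updated at every step (n = t + 1) keeps beta = beta0, while rarely updated rows drift toward beta = 1. For example, with beta0 = 0.9, t = 99, and n = 10:

$$\beta = (1 - 0.9)\left(1 - \tfrac{10}{100}\right) + 0.9 = 0.1 \times 0.9 + 0.9 = 0.99.$$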

+  def computeBeta(
+      step: Output[Float],
+      initialBeta: Output[Float],
+      previousBeta: Variable[Float],
+      nonZeroIndices: Output[Int],
+      nonZerosCount: Output[Float],
+      g: Output[Float],
+      a: Output[Float],
+      epsilon: Output[Float]
+  ): Output[Float] = {
+    // initialBeta is a scalar
+    // previousBeta has shape [numRows]
+    // g has shape [numNonZeroRows, ...]
+    // a has shape [numNonZeroRows, ...]
+    val stepsTillNextNonZero = (step + 1.0f) / nonZerosCount
+    val logInitialBeta = tf.log(initialBeta)
+    val gNorm = tf.sum(g * g, 1 until g.rank) // has shape [numNonZeroRows]
+    val aNorm = tf.sum(a * a, 1 until a.rank) // has shape [numNonZeroRows]
+    val logGNorm = tf.log(gNorm + epsilon)
+    val logANorm = tf.log(aNorm + epsilon)
+    val logBeta = (logGNorm - logANorm + logInitialBeta) / stepsTillNextNonZero
+    val beta = tf.maximum(0.0f, tf.minimum(1.0f, tf.exp(logBeta)))
+    // beta = tf.print(beta, Seq(gNorm), "GNorm: ", 1000, 1000)
+    // beta = tf.print(beta, Seq(logANorm), "LogANorm: ", 1000, 1000)
+    // beta = tf.print(beta, Seq(beta), "Beta: ", 1000, 1000)
+    previousBeta.assignScatter(nonZeroIndices, beta)
+  }
}
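
For intuition, here is a self-contained plain-Scala sketch of the same rule on ordinary arrays; the function name, the array-based signature, and the default epsilon are illustrative assumptions, not part of the optimizer's API:

  // Plain-Scala sketch of computeBeta for a single row's update. Assumes the
  // row has received `nonZeros` nonzero gradients over `step + 1` steps, `g`
  // is the incoming gradient slice (or its square, for beta2), and `a` is the
  // row's accumulated moment slice.
  def adaptiveBeta(
      step: Float,
      initialBeta: Float,
      nonZeros: Float,
      g: Array[Float],
      a: Array[Float],
      epsilon: Float = 1e-8f
  ): Float = {
    val stepsTillNextNonZero = (step + 1.0f) / nonZeros
    val gNorm = g.map(x => x * x).sum
    val aNorm = a.map(x => x * x).sum
    val logBeta = (math.log(gNorm + epsilon) - math.log(aNorm + epsilon) +
      math.log(initialBeta)).toFloat / stepsTillNextNonZero
    math.max(0.0f, math.min(1.0f, math.exp(logBeta).toFloat))
  }

For equal norms the rule reduces to initialBeta raised to 1/s: for example, adaptiveBeta(step = 99.0f, initialBeta = 0.9f, nonZeros = 10.0f, g = Array(0.1f), a = Array(0.1f)) returns roughly 0.9895, i.e. 0.9^(1/10).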

object AdaptiveAMSGrad {
