
Commit

Use first-order optimality conditions to measure convergence of stochastic gradient descent reasoning. (#327)

This should make SGDReasoner tests more consistent as the change in the objective is not always an honest measurement of optimality.
dickensc committed Dec 2, 2021
1 parent d990a83 commit edc9621
Showing 5 changed files with 142 additions and 24 deletions.
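
In brief: instead of stopping when the change in the objective is small, SGD now stops when the gradient at the previous iterate is approximately zero under a configurable p-norm. A minimal sketch of that stopping rule follows (an illustrative helper, not the committed code):

// Sketch of a first-order stopping test. `p` and `threshold` play the roles of
// the new sgd.firstordernorm and sgd.firstorderthreshold options.
static boolean firstOrderConverged(float[] gradient, float p, float threshold) {
    float norm = 0.0f;

    if (p == Float.POSITIVE_INFINITY) {
        // Infinity-norm: the largest absolute component of the gradient.
        for (float v : gradient) {
            norm = Math.max(norm, Math.abs(v));
        }
    } else {
        for (float v : gradient) {
            norm += (float)Math.pow(Math.abs(v), p);
        }
        norm = (float)Math.pow(norm, 1.0f / p);
    }

    return norm < threshold;
}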
16 changes: 16 additions & 0 deletions psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -797,6 +797,22 @@ public class Options {
+ " ADAM: Update the learning rate using the Adaptive Moment Estimation (Adam) algorithm."
);

public static final Option SGD_FIRST_ORDER_NORM = new Option(
"sgd.firstordernorm",
Float.POSITIVE_INFINITY,
"The p-norm used to measure the first order optimality condition."
+ " Default is the infinity-norm which is the absolute value of the maximum component of the gradient vector."
+ " Note that the infinity-norm can be explicitly set with the string literal: 'Infinity'.",
Option.FLAG_NON_NEGATIVE
);

public static final Option SGD_FIRST_ORDER_THRESHOLD = new Option(
"sgd.firstorderthreshold",
0.01f,
"Stochastic gradient descent stops when the norm of the gradient is less than this threshold.",
Option.FLAG_NON_NEGATIVE
);

public static final Option SGD_INVERSE_TIME_EXP = new Option(
"sgd.inversescaleexp",
1.0f,
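For orientation, here is a hedged example of how these new options could be set in code, mirroring the style of the tests later in this commit (the values are illustrative, not recommendations):

// Illustrative configuration of the new first-order convergence options.
Options.SGD_FIRST_ORDER_THRESHOLD.set(0.001);  // Stop once the gradient norm falls within this tolerance of zero.
Options.SGD_FIRST_ORDER_NORM.set(2.0);         // Measure with the Euclidean norm instead of the default infinity-norm.
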
60 changes: 53 additions & 7 deletions psl-core/src/main/java/org/linqs/psl/reasoner/sgd/SGDReasoner.java
@@ -68,6 +68,9 @@ public static enum SGDLearningSchedule {

private int maxIterations;

private float firstOrderTolerance;
private float firstOrderNorm;

private boolean watchMovement;
private float movementThreshold;

@@ -85,18 +88,22 @@ public static enum SGDLearningSchedule {
public SGDReasoner() {
maxIterations = Options.SGD_MAX_ITER.getInt();

firstOrderTolerance = Options.SGD_FIRST_ORDER_THRESHOLD.getFloat();
firstOrderNorm = Options.SGD_FIRST_ORDER_NORM.getFloat();

watchMovement = Options.SGD_MOVEMENT.getBoolean();
movementThreshold = Options.SGD_MOVEMENT_THRESHOLD.getFloat();

initialLearningRate = Options.SGD_LEARNING_RATE.getFloat();
learningRateInverseScaleExp = Options.SGD_INVERSE_TIME_EXP.getFloat();
learningSchedule = SGDLearningSchedule.valueOf(Options.SGD_LEARNING_SCHEDULE.getString().toUpperCase());

adamBeta1 = Options.SGD_ADAM_BETA_1.getFloat();
adamBeta2 = Options.SGD_ADAM_BETA_2.getFloat();
accumulatedGradientSquares = null;
accumulatedGradientMean = null;
accumulatedGradientVariance = null;
coordinateStep = Options.SGD_COORDINATE_STEP.getBoolean();
learningSchedule = SGDLearningSchedule.valueOf(Options.SGD_LEARNING_SCHEDULE.getString().toUpperCase());
sgdExtension = SGDExtension.valueOf(Options.SGD_EXTENSION.getString().toUpperCase());
}

@@ -123,6 +130,7 @@ public double optimize(TermStore baseTermStore,
// optimization pass because they are being updated in the variableUpdate() method.
// Note that the number of variables may change in the first iteration (since grounding may happen then).
double oldObjective = Double.POSITIVE_INFINITY;
float[] prevGradient = null;
float[] prevVariableValues = null;
// Save and use the variable values with the lowest computed objective.
double lowestObjective = Double.POSITIVE_INFINITY;
@@ -146,9 +154,15 @@
useNonConvex = true;
}

if (iteration > 1) {
// Reset gradients for next round.
Arrays.fill(prevGradient, 0.0f);
}

for (SGDObjectiveTerm term : termStore) {
if (iteration > 1) {
objective += term.evaluate(prevVariableValues);
addTermGradient(term, prevGradient, prevVariableValues, termStore.getVariableAtoms());
}

termCount++;
@@ -163,13 +177,15 @@
meanMovement /= termCount;
}

- breakSGD = breakOptimization(iteration, objective, oldObjective, meanMovement, termCount);

if (iteration == 1) {
- // Initialize old variables values.
+ // Initialize old variables values and gradient.
+ prevGradient = new float[termStore.getVariableValues().length];
prevVariableValues = Arrays.copyOf(termStore.getVariableValues(), termStore.getVariableValues().length);
lowestVariableValues = Arrays.copyOf(termStore.getVariableValues(), termStore.getVariableValues().length);
} else {
+ clipGradient(prevGradient, prevVariableValues);
+ breakSGD = breakOptimization(iteration, objective, oldObjective, prevGradient, meanMovement, termCount);

// Update lowest objective and variable values.
if (objective < lowestObjective) {
lowestIteration = iteration - 1;
@@ -186,8 +202,8 @@
totalTime += end - start;

if (iteration > 1 && log.isTraceEnabled()) {
log.trace("Iteration {} -- Objective: {}, Normalized Objective: {}, Iteration Time: {}, Total Optimization Time: {}",
iteration - 1, objective, objective / termCount, (end - start), totalTime);
log.trace("Iteration {} -- Objective: {}, Normalized Objective: {}, Gradient Norm: {}, Iteration Time: {}, Total Optimization Time: {}",
iteration - 1, objective, objective / termCount, MathUtils.pNorm(prevGradient, firstOrderNorm), (end - start), totalTime);
}

iteration++;
@@ -237,7 +253,7 @@ private void optimizationComplete() {
accumulatedGradientVariance = null;
}

private boolean breakOptimization(int iteration, double objective, double oldObjective, float movement, long termCount) {
private boolean breakOptimization(int iteration, double objective, double oldObjective, float[] gradient, float movement, long termCount) {
// Always break when the allocated iterations is up.
if (iteration > (int)(maxIterations * budget)) {
return true;
@@ -253,6 +269,11 @@ private boolean breakOptimization(int iteration, double objective, double oldObj
return false;
}

// Break if the norm of the gradient is approximately zero (within the first-order threshold).
if (MathUtils.equals(MathUtils.pNorm(gradient, firstOrderNorm), 0.0f, firstOrderTolerance)) {
return true;
}

// Break if the objective has not changed.
if (objectiveBreak && MathUtils.equals(objective / termCount, oldObjective / termCount, tolerance)) {
return true;
@@ -261,6 +282,31 @@ private boolean breakOptimization(int iteration, double objective, double oldObj
return false;
}

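/**
 * Zero out gradient components that point outside the [0, 1] variable box:
 * at a variable sitting at 0, a positive partial cannot be descended along,
 * and at 1, a negative partial cannot. Clipping first lets the norm test in
 * breakOptimization() measure first-order optimality of the box-constrained problem.
 */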
private void clipGradient(float[] gradient, float[] variableValues) {
for (int i = 0; i < gradient.length; i++) {
if (MathUtils.equals(variableValues[i], 0.0f) && gradient[i] > 0.0f) {
gradient[i] = 0.0f;
} else if (MathUtils.equals(variableValues[i], 1.0f) && gradient[i] < 0.0f) {
gradient[i] = 0.0f;
}
}
}

private void addTermGradient(SGDObjectiveTerm term, float[] gradient, float[] variableValues, GroundAtom[] variableAtoms) {
int size = term.size();
WeightedRule rule = term.getRule();
int[] variableIndexes = term.getVariableIndexes();
float dot = term.dot(variableValues);

for (int i = 0; i < size; i++) {
if (variableAtoms[variableIndexes[i]] instanceof ObservedAtom) {
continue;
}

gradient[variableIndexes[i]] += term.computePartial(i, dot, rule.getWeight());
}
}

private double computeObjective(VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore) {
double objective = 0.0;

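As the commit message notes, a shrinking objective does not by itself certify a near-stationary point. The reasoner therefore accumulates the full gradient at the previous iterate, clips it against the [0, 1] variable box, and breaks when its p-norm is approximately zero. A condensed restatement of that check (mirroring the code above, not new behavior):

// Condensed view of the new convergence test from optimize()/breakOptimization().
clipGradient(prevGradient, prevVariableValues);
boolean converged = MathUtils.equals(
        MathUtils.pNorm(prevGradient, firstOrderNorm), 0.0f, firstOrderTolerance);
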
70 changes: 56 additions & 14 deletions psl-core/src/main/java/org/linqs/psl/util/MathUtils.java
@@ -174,17 +174,11 @@ public static void toUnit(float[] vector) {
* Scale n-dimensional double array to vector with the specified magnitude.
*/
public static void toMagnitude(double[] vector, double magnitude) {
- double norm = 0.0;

if (magnitude <= 0.0) {
throw new ArithmeticException("Cannot scale a vector to a non-positive magnitude.");
}

- for (int i = 0; i < vector.length; i++) {
- norm += Math.pow(vector[i], 2);
- }
-
- norm = Math.sqrt(norm);
+ double norm = pNorm(vector, 2.0f);
if (!((norm != 0.0) || (vector.length == 0))) {
throw new ArithmeticException("Cannot scale a zero vector to a non-zero magnitude.");
}
@@ -198,17 +192,11 @@ public static void toMagnitude(double[] vector, double magnitude) {
* Scale n-dimensional float array to vector with the specified magnitude.
*/
public static void toMagnitude(float[] vector, double magnitude) {
- double norm = 0.0;

if (magnitude <= 0.0) {
throw new ArithmeticException("Cannot scale a vector to a non-positive magnitude.");
}

- for (int i = 0; i < vector.length; i++) {
- norm += Math.pow(vector[i], 2);
- }
-
- norm = Math.sqrt(norm);
+ float norm = pNorm(vector, 2.0f);
if (!((norm != 0.0) || (vector.length == 0))) {
throw new ArithmeticException("Cannot scale a zero vector to a non-zero magnitude.");
}
@@ -217,4 +205,58 @@ public static void toMagnitude(float[] vector, double magnitude) {
vector[i] = (float)(magnitude * (vector[i] / norm));
}
}

/**
* Compute the p-norm of the provided vector.
*/
public static float pNorm(float[] vector, float p) {
float norm = 0.0f;

if (p <= 0.0f) {
throw new ArithmeticException("The p-norm for p <= 0.0 is not defined.");
}

if (p == Float.POSITIVE_INFINITY) {
for (float v : vector) {
if (norm < Math.abs(v)) {
norm = Math.abs(v);
}
}
return norm;
}

for (float v : vector) {
norm += Math.pow(Math.abs(v), p);
}
norm = (float)Math.pow(norm, 1.0f / p);

return norm;
}

/**
* Compute the p-norm of the provided vector.
*/
public static double pNorm(double[] vector, double p) {
double norm = 0.0;

if (p <= 0.0) {
throw new ArithmeticException("The p-norm for p <= 0.0 is not defined.");
}

if (p == Double.POSITIVE_INFINITY) {
for (double v : vector) {
if (norm < Math.abs(v)) {
norm = Math.abs(v);
}
}
return norm;
}

for (double v : vector) {
norm += Math.pow(Math.abs(v), p);
}
norm = Math.pow(norm, 1.0 / p);

return norm;
}
}
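
A quick usage sketch of the new helper with illustrative values, shown for the default infinity-norm and the Euclidean norm:

float[] gradient = new float[] {0.5f, -2.0f, 1.0f};
float infNorm = MathUtils.pNorm(gradient, Float.POSITIVE_INFINITY);  // 2.0f, the largest absolute component.
float twoNorm = MathUtils.pNorm(gradient, 2.0f);                     // sqrt(0.25 + 4.0 + 1.0), about 2.29f.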
SGDInferenceTest.java
@@ -22,15 +22,22 @@
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
+ import org.linqs.psl.reasoner.sgd.SGDReasoner;

import org.junit.After;
- import org.linqs.psl.reasoner.sgd.SGDReasoner;
import org.junit.Before;

import java.util.List;

public class SGDInferenceTest extends InferenceTest {
@Before
public void setup() {
Options.REASONER_OBJECTIVE_BREAK.set(false);
}

@After
public void cleanup() {
Options.REASONER_OBJECTIVE_BREAK.clear();
Options.SGD_LEARNING_RATE.clear();
Options.SGD_COORDINATE_STEP.clear();
Options.SGD_EXTENSION.clear();
Expand Down
SGDStreamingInferenceTest.java
@@ -22,12 +22,19 @@
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.sgd.SGDReasoner;

import org.junit.After;
import org.junit.Before;

import java.util.List;

public class SGDStreamingInferenceTest extends InferenceTest {
@Before
public void setup() {
Options.REASONER_OBJECTIVE_BREAK.set(false);
}

@After
public void cleanup() {
Options.SGD_LEARNING_RATE.clear();
@@ -61,7 +68,7 @@ public void initialValueTest() {
cleanup();

// Adam.
Options.SGD_EXTENSION.set("ADAM");
Options.SGD_EXTENSION.set(SGDReasoner.SGDExtension.ADAM);
// Non-coordinate step.
Options.SGD_LEARNING_RATE.set(1.0);
Options.SGD_INVERSE_TIME_EXP.set(0.5);
@@ -75,7 +82,7 @@ public void initialValueTest() {
cleanup();

// AdaGrad.
Options.SGD_EXTENSION.set("ADAGRAD");
Options.SGD_EXTENSION.set(SGDReasoner.SGDExtension.ADAGRAD);
// Non-coordinate step.
Options.SGD_LEARNING_RATE.set(1.0);
Options.SGD_INVERSE_TIME_EXP.set(0.5);
