Update SameDiff training to use new Regularization interface (same L1/L2/WD as DL4J)
AlexDBlack committed Feb 8, 2019
1 parent b9d7184 commit 7a05134
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 100 deletions.
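
For orientation before the diff: the commit removes the scalar l1/l2 handling from SameDiff training and routes everything through a list of Regularization instances, mirroring DL4J. Below is a minimal sketch, pieced together from the test changes in this commit, of how a TrainingConfig is assembled under the new interface; the TrainingConfig import location, the Adam learning rate, and the exact meaning of the WeightDecay boolean flag are assumptions rather than facts taken from the diff.

import java.util.ArrayList;
import java.util.List;

import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;

public class RegularizationConfigSketch {

    public static TrainingConfig buildConfig() {
        // Build the regularization list explicitly, as the updated test does
        List<Regularization> r = new ArrayList<>();
        r.add(new L1Regularization(0.01));   // L1: loss += 0.01 * sum_i |param_i|
        r.add(new L2Regularization(0.02));   // L2: loss += 0.5 * 0.02 * sum_i param_i^2
        r.add(new WeightDecay(0.03, true));  // weight decay, applied after the updater; flag assumed to scale by the LR

        return new TrainingConfig.Builder()
                .updater(new Adam(1e-3))             // illustrative updater and learning rate
                .regularization(r)                   // replaces the old .l1(...)/.l2(...) calls
                .dataSetFeatureMapping("input")      // placeholder names as used in the test below
                .dataSetLabelMapping("label")
                .build();
    }
}
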
@@ -40,12 +40,14 @@
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.weightinit.impl.XavierInitScheme;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.*;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -64,17 +66,18 @@ public void testCompareMlpTrainingIris(){
INDArray f = ds.getFeatures();
INDArray l = ds.getLabels();

//TODO 2019-02-01: SameDiff needs to be updated with same L1/L2/WeightDecay changes
double[] l1 = new double[]{0.0 /*, 0.0, 0.01, 0.01*/};
double[] l2 = new double[]{0.0 /*, 0.02, 0.00, 0.02*/};
double[] l1 = new double[]{0.0, 0.0, 0.01, 0.01, 0.0};
double[] l2 = new double[]{0.0, 0.02, 0.00, 0.02, 0.0};
double[] wd = new double[]{0.0, 0.0, 0.0, 0.0, 0.03};

for (String u : new String[]{"sgd", "adam", "nesterov", "adamax", "amsgrad"}) {
for(int i=0; i<l1.length; i++ ) {
Nd4j.getRandom().setSeed(12345);
double l1Val = l1[i];
double l2Val = l2[i];
double wdVal = wd[i];

String testName = u + ", l1=" + l1Val + ", l2=" + l2Val;
String testName = u + ", l1=" + l1Val + ", l2=" + l2Val + ", wd=" + wdVal;

log.info("Starting: {}", testName);
SameDiff sd = SameDiff.create();
@@ -123,9 +126,19 @@ public void testCompareMlpTrainingIris(){
throw new RuntimeException();
}

List<Regularization> r = new ArrayList<>();
if(l2Val > 0){
r.add(new L2Regularization(l2Val));
}
if(l1Val > 0){
r.add(new L1Regularization(l1Val));
}
if(wdVal > 0){
r.add(new WeightDecay(wdVal, true));
}
TrainingConfig conf = new TrainingConfig.Builder()
.l2(1e-4)
.updater(updater)
.regularization(r)
.dataSetFeatureMapping("input")
.dataSetLabelMapping("label")
.l1(l1Val)
@@ -139,6 +152,7 @@ public void testCompareMlpTrainingIris(){
.weightInit(WeightInit.XAVIER).seed(12345)
.l1(l1Val).l2(l2Val)
.l1Bias(l1Val).l2Bias(l2Val)
.weightDecay(wdVal, true).weightDecayBias(wdVal, true)
.updater(new Sgd(1.0))
.list()
.layer(new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.TANH).build())
@@ -174,24 +188,21 @@ public void testCompareMlpTrainingIris(){

//Check score
double scoreDl4j = net.score();
double scoreSd = lossMse.getArr().getDouble(0) + sd.calculateL1Loss() + sd.calculateL2Loss();
double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore();
assertEquals(testName, scoreDl4j, scoreSd, 1e-6);

double l1Sd = sd.calculateL1Loss();
double l2Sd = sd.calculateL2Loss();

double r = net.calcRegularizationScore(true);
double lossRegScoreSD = sd.calcRegularizationScore();
double lossRegScoreDL4J = net.calcRegularizationScore(true);

// assertEquals(l1Dl4j, l1Sd, 1e-6);
// assertEquals(l2Dl4j, l2Sd, 1e-6);
assertEquals(lossRegScoreDL4J, lossRegScoreSD, 1e-6);

//Check gradients (before updater applied)
Map<String,INDArray> grads = net.gradient().gradientForVariable();
sd.execBackwards(placeholders);

//Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added later
//Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only
//We can check correctness though with training param checks later
if(l1Val == 0 && l2Val == 0) {
if(l1Val == 0 && l2Val == 0 && wdVal == 0) {
assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr());
assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr());
assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr());
@@ -213,9 +224,6 @@ public void testCompareMlpTrainingIris(){
net.init();
net.setParamTable(oldParams);

// System.out.println("0_W before:\n" + oldParams.get("0_W"));
// System.out.println("0_W grad:\n" + grads.get("0_W"));

for( int j=0; j<3; j++ ) {
net.fit(ds);
sd.fit(ds);
@@ -241,7 +249,5 @@ public void testCompareMlpTrainingIris(){
System.out.println("---------------------------------");
}
}

}

}
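
The score check in the test above compares DL4J's net.score() against the SameDiff MSE plus sd.calcRegularizationScore(). For the pure L1/L2 configurations this amounts to the following (formulas restated from the code comments removed in the SameDiff changes below; this is a summary, not new behaviour):

\text{score} = \text{MSE}(y, \hat{y}) + \lambda_1 \sum_i |w_i| + \tfrac{1}{2}\,\lambda_2 \sum_i w_i^{2}

For the weight-decay configuration, both the score contribution and the gradient handling are delegated to the WeightDecay implementation of the Regularization interface; the SameDiff changes below only decide whether each regularization is applied before or after the updater.
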
@@ -91,6 +91,7 @@
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
@@ -1534,6 +1535,19 @@ protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolea
//Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients,
// which should flow through to here

//Pre-apply regularization (L1, L2)
List<Regularization> r = trainingConfig.getRegularization();
int iterCount = trainingConfig.getIterationCount();
int epochCount = trainingConfig.getEpochCount();
double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
if(r != null && r.size() > 0){
for(Regularization reg : r){
if(reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER){
reg.apply(param, grad, lr, iterCount, epochCount);
}
}
}

//Apply updater. Note that we need to reshape to [1,length] for updater
INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order!
Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", s);
@@ -1545,18 +1559,13 @@ protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolea
+ "\": either parameter size is inconsistent between iterations, or \"" + s + "\" should not be a trainable parameter?", t);
}

//L1 and L2 regularization:
if (trainingConfig.getL1() > 0) {
//L1: loss += lambda * sum_i |param_i|
//dL/dp_i: lambda * sgn(param_i)
INDArray signProd = Transforms.sign(param, true).muli(trainingConfig.getL1());
grad.addi(signProd);
}
if (trainingConfig.getL2() > 0) {
//L2: loss += 0.5 * lambda * sum_i param_i^2
//dL/dp_i: lambda * param_i
//TODO axpy optimization = safe/possible?
grad.addi(param.mul(trainingConfig.getL2()));
//Post-apply regularization (weight decay)
if(r != null && r.size() > 0){
for(Regularization reg : r){
if(reg.applyStep() == Regularization.ApplyStep.POST_UPDATER){
reg.apply(param, grad, lr, iterCount, epochCount);
}
}
}

if (trainingConfig.isMinimize()) {
Expand All @@ -1579,59 +1588,32 @@ protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolea
}

/**
* Calculate the L2 regularization component of the loss: {@code 0.5 * sum_i (weights_i)}<br>
* Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters.
* Note that the training configuration must be set (via {@link #setTrainingConfig(TrainingConfig)}) before this
* method can be called
*
* @return The L2 regularization component of the score
* @return The regularization component of the score/loss function
*/
public double calculateL2Loss() {
public double calcRegularizationScore() {
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");

if(trainingConfig.getL2() == 0){
if(trainingConfig.getRegularization() == null || trainingConfig.getRegularization().isEmpty()){
return 0.0;
}

if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().isEmpty())
initializeTraining();

double l2 = trainingConfig.getL2();
double l2Loss = 0.0;
List<Regularization> l = trainingConfig.getRegularization();
double loss = 0.0;
for (String s : trainingConfig.getTrainableParams()) {
//L2: loss += 0.5 * lambda * sum_i param_i^2
double norm2 = getVariable(s).getArr().norm2Number().doubleValue();
l2Loss += 0.5 * l2 * norm2 * norm2;
}
return l2Loss;
}

/**
* Calculate the L1 regularization component of the loss: {@code 0sum_i (abs(weights_i))}<br>
* Note that the training configuration must be set (via {@link #setTrainingConfig(TrainingConfig)}) before this
* method can be called
*
* @return The L1 regularization component of the score
*/
public double calculateL1Loss(){
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before calculating the L1 loss. Use setTrainingConfig(TrainingConfig)");

if(trainingConfig.getL1() == 0){
return 0.0;
}

if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().isEmpty())
initializeTraining();

double l1 = trainingConfig.getL1();
double l1Loss = 0.0;
for (String s : trainingConfig.getTrainableParams()) {
//L1: loss += lambda * sum_i |param_i|
double norm1 = getVariable(s).getArr().norm1Number().doubleValue();
l1Loss += l1 * norm1;
for(Regularization r : l){
INDArray arr = getVariable(s).getArr();
loss += r.score(arr, trainingConfig.getIterationCount(), trainingConfig.getEpochCount());
}
}
return l1Loss;
return loss;
}

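For reference, this is the shape of the Regularization contract that the new fit() and calcRegularizationScore() code relies on, reconstructed only from the calls visible in the diff (applyStep(), apply(param, grad, lr, iteration, epoch), score(param, iteration, epoch), and the nested ApplyStep enum). The exact signatures, the clone() member, and the toy penalty itself are assumptions, not the library's actual source:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Toy L1-style penalty, written against the interface as it is used above:
 * applyStep() tells fit() whether apply() runs before or after the updater,
 * apply() modifies the gradient view in place, and score() returns this
 * penalty's contribution to the loss for calcRegularizationScore().
 */
public class ToyL1Penalty implements Regularization {

    private final double lambda;

    public ToyL1Penalty(double lambda) {
        this.lambda = lambda;
    }

    @Override
    public ApplyStep applyStep() {
        // L1/L2-style penalties run at the BEFORE_UPDATER step in the diff above
        return ApplyStep.BEFORE_UPDATER;
    }

    @Override
    public void apply(INDArray param, INDArray gradView, double lr, int iteration, int epoch) {
        // dL/dp_i = lambda * sgn(p_i), added to the existing gradient (same maths as the removed inline code)
        gradView.addi(Transforms.sign(param, true).muli(lambda));
    }

    @Override
    public double score(INDArray param, int iteration, int epoch) {
        // loss += lambda * sum_i |p_i|
        return lambda * param.norm1Number().doubleValue();
    }

    @Override
    public ToyL1Penalty clone() {
        return new ToyL1Penalty(lambda);
    }
}

The distinction applyStep() encodes is the usual one between L2 and decoupled weight decay: a BEFORE_UPDATER penalty has its gradient contribution transformed by adaptive updaters such as Adam, while a POST_UPDATER penalty (WeightDecay in this commit) is applied only after the updater has run.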
