[WIP] Update SameDiff training to use new Regularization interface (same L1/L2/WD as DL4J) #7128

Merged: 4 commits, merged Feb 11, 2019
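This PR replaces the single l1/l2 coefficients on SameDiff's TrainingConfig with a list of Regularization instances (L1Regularization, L2Regularization, WeightDecay), matching DL4J's behaviour. A minimal configuration sketch in the new style, mirroring the test changes below; the coefficient values and the Adam updater are illustrative only and not prescribed by this PR, and `sd` is assumed to be an already-built SameDiff graph with "input" and "label" placeholders:

import java.util.*;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.regularization.*;

// Illustrative coefficients only; add just the terms you actually want
List<Regularization> regularization = new ArrayList<>();
regularization.add(new L1Regularization(0.01));   // L1: lambda * sum_i |w_i|
regularization.add(new L2Regularization(0.02));   // L2: 0.5 * lambda * sum_i w_i^2
regularization.add(new WeightDecay(0.03, true));  // weight decay; 'true' scales the decay by the learning rate

TrainingConfig conf = new TrainingConfig.Builder()
        .updater(new Adam(1e-3))                   // illustrative updater choice
        .regularization(regularization)            // replaces the old .l1(...) / .l2(...) builder methods
        .dataSetFeatureMapping("input")
        .dataSetLabelMapping("label")
        .build();
sd.setTrainingConfig(conf);

An empty (or unset) regularization list simply disables regularization, which is what the zero-coefficient branches in the updated test below rely on.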
@@ -40,12 +40,14 @@
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.weightinit.impl.XavierInitScheme;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.*;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -64,17 +66,21 @@ public void testCompareMlpTrainingIris(){
INDArray f = ds.getFeatures();
INDArray l = ds.getLabels();

//TODO 2019-02-01: SameDiff needs to be updated with same L1/L2/WeightDecay changes
double[] l1 = new double[]{0.0 /*, 0.0, 0.01, 0.01*/};
double[] l2 = new double[]{0.0 /*, 0.02, 0.00, 0.02*/};
double[] l1 = new double[]{0.0, 0.0, 0.01, 0.01, 0.0};
double[] l2 = new double[]{0.0, 0.02, 0.00, 0.02, 0.0};
double[] wd = new double[]{0.0, 0.0, 0.0, 0.0, 0.03};
// double[] l1 = new double[]{0.0};
// double[] l2 = new double[]{0.0};
// double[] wd = new double[]{0.03};

for (String u : new String[]{"sgd", "adam", "nesterov", "adamax", "amsgrad"}) {
for(int i=0; i<l1.length; i++ ) {
Nd4j.getRandom().setSeed(12345);
double l1Val = l1[i];
double l2Val = l2[i];
double wdVal = wd[i];

String testName = u + ", l1=" + l1Val + ", l2=" + l2Val;
String testName = u + ", l1=" + l1Val + ", l2=" + l2Val + ", wd=" + wdVal;

log.info("Starting: {}", testName);
SameDiff sd = SameDiff.create();
@@ -123,13 +129,21 @@ public void testCompareMlpTrainingIris(){
throw new RuntimeException();
}

List<Regularization> r = new ArrayList<>();
if(l2Val > 0){
r.add(new L2Regularization(l2Val));
}
if(l1Val > 0){
r.add(new L1Regularization(l1Val));
}
if(wdVal > 0){
r.add(new WeightDecay(wdVal, true));
}
TrainingConfig conf = new TrainingConfig.Builder()
.l2(1e-4)
.updater(updater)
.regularization(r)
.dataSetFeatureMapping("input")
.dataSetLabelMapping("label")
.l1(l1Val)
.l2(l2Val)
.build();
sd.setTrainingConfig(conf);

@@ -139,7 +153,8 @@ public void testCompareMlpTrainingIris(){
.weightInit(WeightInit.XAVIER).seed(12345)
.l1(l1Val).l2(l2Val)
.l1Bias(l1Val).l2Bias(l2Val)
.updater(new Sgd(1.0))
.weightDecay(wdVal, true).weightDecayBias(wdVal, true)
.updater(new Sgd(1.0)) //Explicitly use SGD(1.0) for comparing PRE-UPDATE GRADIENTS (but with l1/l2/wd component added)
.list()
.layer(new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.TANH).build())
.layer(new OutputLayer.Builder().nIn(10).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MSE).build())
@@ -174,24 +189,21 @@ public void testCompareMlpTrainingIris(){

//Check score
double scoreDl4j = net.score();
double scoreSd = lossMse.getArr().getDouble(0) + sd.calculateL1Loss() + sd.calculateL2Loss();
double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore();
assertEquals(testName, scoreDl4j, scoreSd, 1e-6);

double l1Sd = sd.calculateL1Loss();
double l2Sd = sd.calculateL2Loss();

double r = net.calcRegularizationScore(true);
double lossRegScoreSD = sd.calcRegularizationScore();
double lossRegScoreDL4J = net.calcRegularizationScore(true);

// assertEquals(l1Dl4j, l1Sd, 1e-6);
// assertEquals(l2Dl4j, l2Sd, 1e-6);
assertEquals(lossRegScoreDL4J, lossRegScoreSD, 1e-6);

//Check gradients (before updater applied)
Map<String,INDArray> grads = net.gradient().gradientForVariable();
sd.execBackwards(placeholders);

//Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added later
//Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only
//We can check correctness though with training param checks later
if(l1Val == 0 && l2Val == 0) {
if(l1Val == 0 && l2Val == 0 && wdVal == 0) {
assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr());
assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr());
assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr());
@@ -204,6 +216,7 @@ public void testCompareMlpTrainingIris(){
.weightInit(WeightInit.XAVIER).seed(12345)
.l1(l1Val).l2(l2Val)
.l1Bias(l1Val).l2Bias(l2Val)
.weightDecay(wdVal, true).weightDecayBias(wdVal, true)
.updater(updater.clone())
.list()
.layer(new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.TANH).build())
@@ -213,15 +226,14 @@ public void testCompareMlpTrainingIris(){
net.init();
net.setParamTable(oldParams);

// System.out.println("0_W before:\n" + oldParams.get("0_W"));
// System.out.println("0_W grad:\n" + grads.get("0_W"));

for( int j=0; j<3; j++ ) {
net.fit(ds);
sd.fit(ds);

String s = testName + " - " + j;
assertEquals(s, net.getParam("0_W"), w0.getArr());
INDArray dl4j_0W = net.getParam("0_W");
INDArray sd_0W = w0.getArr();
assertEquals(s, dl4j_0W, sd_0W);
assertEquals(s, net.getParam("0_b"), b0.getArr());
assertEquals(s, net.getParam("1_W"), w1.getArr());
assertEquals(s, net.getParam("1_b"), b1.getArr());
@@ -241,7 +253,5 @@ public void testCompareMlpTrainingIris(){
System.out.println("---------------------------------");
}
}

}

}
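For orientation before the SameDiff.java hunks further down: each Regularization reports an ApplyStep, and SameDiff.fit() now applies gradient-modifying terms before the updater and weight decay after it, replacing the hard-coded l1/l2 math that is removed below. The following is a condensed, illustrative sketch of that ordering, not the verbatim fit() code; paramArr, gradArr, updater, lr, iterCount and epochCount stand in for the corresponding variables in the real per-parameter loop:

// Condensed sketch of the per-parameter regularization ordering added to SameDiff.fit() (see hunks below)
for (Regularization reg : regularizations) {
    if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
        reg.apply(paramArr, gradArr, lr, iterCount, epochCount);    // e.g. L1Regularization, L2Regularization
    }
}

updater.applyUpdater(gradArr, iterCount, epochCount);               // the configured updater (SGD, Adam, ...)

for (Regularization reg : regularizations) {
    if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
        reg.apply(paramArr, gradArr, lr, iterCount, epochCount);    // e.g. WeightDecay
    }
}

Splitting the steps this way is what distinguishes WeightDecay from L2: L2 changes the gradient the updater sees, while weight decay is applied after the updater has transformed the gradient.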
@@ -16,6 +16,7 @@

package org.deeplearning4j.spark.models.sequencevectors.learning.elements;

import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchSequences;
import org.deeplearning4j.models.embeddings.learning.impl.elements.RandomUtils;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
@@ -45,6 +46,11 @@ public String getCodeName() {
return "Spark-CBOW";
}

@Override
public double learnSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<ShallowSequenceElement> batchSequences) {
throw new UnsupportedOperationException();
}

@Override
public Frame<? extends TrainingMessage> frameSequence(Sequence<ShallowSequenceElement> sequence,
AtomicLong nextRandom, double learningRate) {
@@ -17,6 +17,7 @@
package org.deeplearning4j.spark.models.sequencevectors.learning.elements;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchSequences;
import org.deeplearning4j.models.embeddings.learning.impl.elements.RandomUtils;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
@@ -39,6 +40,11 @@ public String getCodeName() {
return "Spark-SkipGram";
}

@Override
public double learnSequence(Sequence<ShallowSequenceElement> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<ShallowSequenceElement> batchSequences) {
throw new UnsupportedOperationException();
}

protected transient AtomicLong counter;
protected transient ThreadLocal<Frame<SkipGramRequestMessage>> frame;

@@ -91,6 +91,7 @@
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
@@ -1534,6 +1535,19 @@ protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolea
//Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients,
// which should flow through to here

//Pre-apply regularization (L1, L2)
List<Regularization> r = trainingConfig.getRegularization();
int iterCount = trainingConfig.getIterationCount();
int epochCount = trainingConfig.getEpochCount();
double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
if(r != null && r.size() > 0){
for(Regularization reg : r){
if(reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER){
reg.apply(param, grad, lr, iterCount, epochCount);
}
}
}

//Apply updater. Note that we need to reshape to [1,length] for updater
INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order!
Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", s);
@@ -1545,18 +1559,13 @@ protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolea
+ "\": either parameter size is inconsistent between iterations, or \"" + s + "\" should not be a trainable parameter?", t);
}

//L1 and L2 regularization:
if (trainingConfig.getL1() > 0) {
//L1: loss += lambda * sum_i |param_i|
//dL/dp_i: lambda * sgn(param_i)
INDArray signProd = Transforms.sign(param, true).muli(trainingConfig.getL1());
grad.addi(signProd);
}
if (trainingConfig.getL2() > 0) {
//L2: loss += 0.5 * lambda * sum_i param_i^2
//dL/dp_i: lambda * param_i
//TODO axpy optimization = safe/possible?
grad.addi(param.mul(trainingConfig.getL2()));
//Post-apply regularization (weight decay)
if(r != null && r.size() > 0){
for(Regularization reg : r){
if(reg.applyStep() == Regularization.ApplyStep.POST_UPDATER){
reg.apply(param, grad, lr, iterCount, epochCount);
}
}
}

if (trainingConfig.isMinimize()) {
@@ -1579,59 +1588,32 @@ protected synchronized void fit(MultiDataSetIterator iter, int numEpochs, boolea
}

/**
* Calculate the L2 regularization component of the loss: {@code 0.5 * sum_i (weights_i)}<br>
* Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters.
* Note that the training configuration must be set (via {@link #setTrainingConfig(TrainingConfig)}) before this
* method can be called
*
* @return The L2 regularization component of the score
* @return The regularization component of the score/loss function
*/
public double calculateL2Loss() {
public double calcRegularizationScore() {
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");

if(trainingConfig.getL2() == 0){
if(trainingConfig.getRegularization() == null || trainingConfig.getRegularization().isEmpty()){
return 0.0;
}

if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().isEmpty())
initializeTraining();

double l2 = trainingConfig.getL2();
double l2Loss = 0.0;
List<Regularization> l = trainingConfig.getRegularization();
double loss = 0.0;
for (String s : trainingConfig.getTrainableParams()) {
//L2: loss += 0.5 * lambda * sum_i param_i^2
double norm2 = getVariable(s).getArr().norm2Number().doubleValue();
l2Loss += 0.5 * l2 * norm2 * norm2;
}
return l2Loss;
}

/**
* Calculate the L1 regularization component of the loss: {@code 0sum_i (abs(weights_i))}<br>
* Note that the training configuration must be set (via {@link #setTrainingConfig(TrainingConfig)}) before this
* method can be called
*
* @return The L1 regularization component of the score
*/
public double calculateL1Loss(){
Preconditions.checkState(trainingConfig != null, "No training configuration has been set. A training configuration must " +
"be set before calculating the L1 loss. Use setTrainingConfig(TrainingConfig)");

if(trainingConfig.getL1() == 0){
return 0.0;
}

if(trainingConfig.getTrainableParams() == null || trainingConfig.getTrainableParams().isEmpty())
initializeTraining();

double l1 = trainingConfig.getL1();
double l1Loss = 0.0;
for (String s : trainingConfig.getTrainableParams()) {
//L1: loss += lambda * sum_i |param_i|
double norm1 = getVariable(s).getArr().norm1Number().doubleValue();
l1Loss += l1 * norm1;
for(Regularization r : l){
INDArray arr = getVariable(s).getArr();
loss += r.score(arr, trainingConfig.getIterationCount(), trainingConfig.getEpochCount());
}
}
return l1Loss;
return loss;
}
