Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Early stopping score functions #4630

Merged
merged 9 commits into from Feb 8, 2018
@@ -1,24 +1,37 @@
package org.deeplearning4j.earlystopping;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.ClassificationScoreCalculator;
import org.deeplearning4j.earlystopping.scorecalc.AutoencoderScoreCalculator;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.scorecalc.RegressionScoreCalculator;
import org.deeplearning4j.earlystopping.scorecalc.VAEReconErrorScoreCalculator;
import org.deeplearning4j.earlystopping.scorecalc.VAEReconProbScoreCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
Expand All @@ -35,14 +48,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.*;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.*;

@Slf4j
public class TestEarlyStopping extends BaseDL4JTest {

@Test
Expand Down Expand Up @@ -157,7 +168,7 @@ public void testEarlyStoppingIrisMultiEpoch() {
//Check that best score actually matches (returned model vs. manually calculated score)
MultiLayerNetwork bestNetwork = result.getBestModel();
irisIter.reset();
double score = bestNetwork.score(irisIter.next());
double score = bestNetwork.score(irisIter.next(), false);
assertEquals(result.getBestModelScore(), score, 1e-2);
}

Expand Down Expand Up @@ -425,4 +436,225 @@ public void onCompletion(EarlyStoppingResult esResult) {
onCompletionCallCount++;
}
}


@Test
public void testRegressionScoreFunctionSimple() throws Exception {

    // Early stopping driven by a regression score calculator: train a small
    // reconstruction network and check that a best model with a positive score is produced.
    for (RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{
            RegressionEvaluation.Metric.MSE, RegressionEvaluation.Metric.MAE}) {
        log.info("Metric: " + metric);

        // Feed-forward net mapping MNIST pixels back onto themselves: 784 -> 32 -> 784
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(784)
                        .activation(Activation.SIGMOID).build())
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();

        // Build a small in-memory dataset where labels == features (reconstruction task)
        DataSetIterator mnist = new MnistDataSetIterator(32, false, 12345);
        List<DataSet> reconstructionData = new ArrayList<>();
        for (int batch = 0; batch < 10; batch++) {
            DataSet next = mnist.next();
            reconstructionData.add(new DataSet(next.getFeatures(), next.getFeatures()));
        }
        DataSetIterator trainIter = new ExistingDataSetIterator(reconstructionData);

        EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfig =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new RegressionScoreCalculator(metric, trainIter))
                        .modelSaver(modelSaver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfig, network, trainIter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}

@Test
public void testAEScoreFunctionSimple() throws Exception {

    // Early stopping scored by autoencoder reconstruction error, for both MSE and MAE.
    for (RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{
            RegressionEvaluation.Metric.MSE, RegressionEvaluation.Metric.MAE}) {
        log.info("Metric: " + metric);

        // Single pretrain-only autoencoder layer: 784 -> 32
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new AutoEncoder.Builder().nIn(784).nOut(32).build())
                .pretrain(true).backprop(false)
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();

        // Small in-memory dataset; labels mirror the features (unsupervised reconstruction)
        DataSetIterator mnist = new MnistDataSetIterator(32, false, 12345);
        List<DataSet> reconstructionData = new ArrayList<>();
        for (int batch = 0; batch < 10; batch++) {
            DataSet next = mnist.next();
            reconstructionData.add(new DataSet(next.getFeatures(), next.getFeatures()));
        }
        DataSetIterator trainIter = new ExistingDataSetIterator(reconstructionData);

        EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfig =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new AutoencoderScoreCalculator(metric, trainIter))
                        .modelSaver(modelSaver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfig, network, trainIter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}

@Test
public void testVAEScoreFunctionSimple() throws Exception {

    // Early stopping scored by VAE reconstruction error, for both MSE and MAE.
    for (RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{
            RegressionEvaluation.Metric.MSE, RegressionEvaluation.Metric.MAE}) {
        log.info("Metric: " + metric);

        // Pretrain-only variational autoencoder: 784 inputs, 32 latent units, 64-unit enc/dec
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new VariationalAutoencoder.Builder()
                        .nIn(784).nOut(32)
                        .encoderLayerSizes(64)
                        .decoderLayerSizes(64)
                        .build())
                .pretrain(true).backprop(false)
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();

        // Small in-memory dataset; labels mirror the features (unsupervised reconstruction)
        DataSetIterator mnist = new MnistDataSetIterator(32, false, 12345);
        List<DataSet> reconstructionData = new ArrayList<>();
        for (int batch = 0; batch < 10; batch++) {
            DataSet next = mnist.next();
            reconstructionData.add(new DataSet(next.getFeatures(), next.getFeatures()));
        }
        DataSetIterator trainIter = new ExistingDataSetIterator(reconstructionData);

        EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfig =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new VAEReconErrorScoreCalculator(metric, trainIter))
                        .modelSaver(modelSaver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfig, network, trainIter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}

@Test
public void testVAEScoreFunctionReconstructionProbSimple() throws Exception {

    // Early stopping scored by VAE reconstruction probability, in both
    // probability and log-probability modes.
    for (boolean logProb : new boolean[]{false, true}) {

        // Pretrain-only VAE with a Bernoulli reconstruction distribution so that
        // reconstruction probabilities are well-defined for binary-ish pixel data.
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new VariationalAutoencoder.Builder()
                        .nIn(784).nOut(32)
                        .encoderLayerSizes(64)
                        .decoderLayerSizes(64)
                        .reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID))
                        .build())
                .pretrain(true).backprop(false)
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();

        // Small in-memory dataset; labels mirror the features (unsupervised reconstruction)
        DataSetIterator mnist = new MnistDataSetIterator(32, false, 12345);
        List<DataSet> reconstructionData = new ArrayList<>();
        for (int batch = 0; batch < 10; batch++) {
            DataSet next = mnist.next();
            reconstructionData.add(new DataSet(next.getFeatures(), next.getFeatures()));
        }
        DataSetIterator trainIter = new ExistingDataSetIterator(reconstructionData);

        EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfig =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new VAEReconProbScoreCalculator(trainIter, 20, logProb))
                        .modelSaver(modelSaver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfig, network, trainIter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}

@Test
public void testClassificationScoreFunctionSimple() throws Exception {

    // Early stopping scored by every available classification metric; only
    // checks that training completes and a best model is saved.
    for (Evaluation.Metric metric : Evaluation.Metric.values()) {
        log.info("Metric: " + metric);

        // Standard MNIST classifier: 784 -> 32 -> 10 with softmax output
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(10)
                        .activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();

        // Snapshot 10 MNIST minibatches (labels kept) into an in-memory iterator
        DataSetIterator mnist = new MnistDataSetIterator(32, false, 12345);
        List<DataSet> classificationData = new ArrayList<>();
        for (int batch = 0; batch < 10; batch++) {
            classificationData.add(mnist.next());
        }
        DataSetIterator trainIter = new ExistingDataSetIterator(classificationData);

        EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfig =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new ClassificationScoreCalculator(metric, trainIter))
                        .modelSaver(modelSaver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfig, network, trainIter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
    }
}
}