From f2d19a45e744b0210fc1602cc8a9e46be751c21e Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Mon, 23 Sep 2019 09:05:37 +0900 Subject: [PATCH] wip updating examples. Signed-off-by: Robert Altena --- .gitignore | 1 + ...asicHyperparameterOptimizationExample.java | 10 ++-- ...eticHyperparameterOptimizationExample.java | 46 +++++++++------- ...eticHyperparameterOptimizationExample.java | 53 +++---------------- .../GeneticSearchExampleConfiguration.java | 12 ++--- .../MultiDigitNumberRecognition.java | 34 ++++++------ .../MultiRecordDataSetIterator.java | 16 +++--- 7 files changed, 70 insertions(+), 102 deletions(-) diff --git a/.gitignore b/.gitignore index acff98bbda..969eb63582 100644 --- a/.gitignore +++ b/.gitignore @@ -61,3 +61,4 @@ Word2vec-index/ !tutorials/*.json *end.model +arbiterExample/ diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/BasicHyperparameterOptimizationExample.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/BasicHyperparameterOptimizationExample.java index d3b4adc0c4..20767c166d 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/BasicHyperparameterOptimizationExample.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/BasicHyperparameterOptimizationExample.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -42,11 +42,11 @@ import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.storage.FileStatsStorage; +import org.nd4j.evaluation.classification.Evaluation.Metric; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -114,14 +114,16 @@ public static void main(String[] args) throws Exception { // This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ... String baseSaveDirectory = "arbiterExample/"; File f = new File(baseSaveDirectory); - if (f.exists()) f.delete(); + if (f.exists()) //noinspection ResultOfMethodCallIgnored + f.delete(); + //noinspection ResultOfMethodCallIgnored f.mkdir(); ResultSaver modelSaver = new FileModelSaver(baseSaveDirectory); // (d) What are we actually trying to optimize? // In this example, let's use classification accuracy on the test set // See also ScoreFunctions.testSetF1(), ScoreFunctions.testSetRegression(regressionValue) etc - ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY); + ScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.ACCURACY); // (e) When should we stop searching? Specify this with termination conditions diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/BaseGeneticHyperparameterOptimizationExample.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/BaseGeneticHyperparameterOptimizationExample.java index 418979e5ac..4ac5cc6f48 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/BaseGeneticHyperparameterOptimizationExample.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/BaseGeneticHyperparameterOptimizationExample.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -25,9 +25,10 @@ import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction; -import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.evaluation.classification.Evaluation.Metric; +import java.io.IOException; import java.util.List; /** @@ -38,23 +39,13 @@ * * @author Alexandre Boulanger */ - public class BaseGeneticHyperparameterOptimizationExample { - public static void main(String[] args) throws Exception { - - ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration(); - - EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1); - - // This is where we create the GeneticSearchCandidateGenerator with its default behavior: - // - a population that fits 30 candidates and is culled back to 20 when it overflows - // - new candidates are generated with a probability of 85% of being the result of breeding (a k-point crossover with 1 to 4 points) - // - the new candidate have a probability of 0.5% of sustaining a random mutation on one of its genes. - GeneticSearchCandidateGenerator candidateGenerator = new GeneticSearchCandidateGenerator.Builder(cgs, scoreFunction).build(); - - // Let's have a listener to print the population size after each evaluation. - PopulationModel populationModel = candidateGenerator.getPopulationModel(); + /** + * Common code used by two Arbiter examples. + */ + public static void run(PopulationModel populationModel, GeneticSearchCandidateGenerator candidateGenerator, + EvaluationScoreFunction scoreFunction) throws IOException { populationModel.addListener(new ExamplePopulationListener()); IOptimizationRunner runner = GeneticSearchExampleConfiguration.BuildRunner(candidateGenerator, scoreFunction); @@ -80,13 +71,32 @@ public static void main(String[] args) throws Exception { System.out.println(bestModel.getConfiguration().toJson()); } + public static void main(String[] args) throws Exception { + + ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration(); + + EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.F1); + + // This is where we create the GeneticSearchCandidateGenerator with its default behavior: + // - a population that fits 30 candidates and is culled back to 20 when it overflows + // - new candidates are generated with a probability of 85% of being the result of breeding (a k-point crossover with 1 to 4 points) + // - the new candidate have a probability of 0.5% of sustaining a random mutation on one of its genes. + GeneticSearchCandidateGenerator candidateGenerator = new GeneticSearchCandidateGenerator.Builder(cgs, scoreFunction).build(); + + // Let's have a listener to print the population size after each evaluation. + PopulationModel populationModel = candidateGenerator.getPopulationModel(); + populationModel.addListener(new ExamplePopulationListener()); + run(populationModel, candidateGenerator, scoreFunction); + } + public static class ExamplePopulationListener implements PopulationListener { + @SuppressWarnings("OptionalGetWithoutIsPresent") @Override public void onChanged(List population) { double best = population.get(0).getFitness(); double average = population.stream() - .mapToDouble(c -> c.getFitness()) + .mapToDouble(Chromosome::getFitness) .average() .getAsDouble(); System.out.println(String.format("\nPopulation size is %1$s, best score is %2$s, average score is %3$s", population.size(), best, average)); diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/CustomGeneticHyperparameterOptimizationExample.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/CustomGeneticHyperparameterOptimizationExample.java index 3f5a0aee8e..957ea89684 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/CustomGeneticHyperparameterOptimizationExample.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/CustomGeneticHyperparameterOptimizationExample.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -20,25 +20,19 @@ import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.SynchronizedRandomGenerator; import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover; import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationListener; import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator; import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.nn.graph.ComputationGraph; +import org.nd4j.evaluation.classification.Evaluation.Metric; -import java.util.List; +import static org.deeplearning4j.examples.arbiter.genetic.BaseGeneticHyperparameterOptimizationExample.run; /** * In this hyperparameter optimization example, we change the default behavior of the genetic candidate generator. @@ -55,7 +49,7 @@ public static void main(String[] args) throws Exception { ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration(); - EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1); + EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.F1); // The ExampleCullOperator extends the default cull operator (least fit) to include an artificial predator. CullOperator cullOperator = new ExampleCullOperator(); @@ -85,30 +79,8 @@ public static void main(String[] args) throws Exception { .build(); // Let's have a listener to print the population size after each evaluation. - populationModel.addListener(new ExamplePopulationListener()); - - IOptimizationRunner runner = GeneticSearchExampleConfiguration.BuildRunner(candidateGenerator, scoreFunction); - - //Start the hyperparameter optimization - runner.execute(); - - //Print out some basic stats regarding the optimization procedure - String s = "Best score: " + runner.bestScore() + "\n" + - "Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" + - "Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n"; - System.out.println(s); - - - //Get all results, and print out details of the best result: - int indexOfBestResult = runner.bestScoreCandidateIndex(); - List allResults = runner.getResults(); - - OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult(); - ComputationGraph bestModel = (ComputationGraph) bestResult.getResultReference().getResultModel(); - - System.out.println("\n\nConfiguration of best model:\n"); - System.out.println(bestModel.getConfiguration().toJson()); - + populationModel.addListener(new BaseGeneticHyperparameterOptimizationExample.ExamplePopulationListener()); + run(populationModel, candidateGenerator, scoreFunction); } // This is an example of a custom behavior for the genetic algorithm. We force one of the parent to be one of the @@ -158,17 +130,4 @@ public void cullPopulation() { System.out.println(String.format("Randomly removed %1$s candidate(s).", preyCount)); } } - - public static class ExamplePopulationListener implements PopulationListener { - - @Override - public void onChanged(List population) { - double best = population.get(0).getFitness(); - double average = population.stream() - .mapToDouble(Chromosome::getFitness) - .average() - .getAsDouble(); - System.out.println(String.format("\nPopulation size is %1$s, best score is %2$s, average score is %3$s", population.size(), best, average)); - } - } } diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/GeneticSearchExampleConfiguration.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/GeneticSearchExampleConfiguration.java index 694b0994e1..939ab22181 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/GeneticSearchExampleConfiguration.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/arbiter/genetic/GeneticSearchExampleConfiguration.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -49,22 +49,20 @@ import java.io.File; import java.util.Properties; -public class GeneticSearchExampleConfiguration { +class GeneticSearchExampleConfiguration { - public static ComputationGraphSpace GetGraphConfiguration() { + static ComputationGraphSpace GetGraphConfiguration() { int inputSize = 784; int outputSize = 47; // First, we setup the hyperspace parameters. These are the values which will change, breed and mutate // while attempting to find the best candidate. - DiscreteParameterSpace activationSpace = new DiscreteParameterSpace(new Activation[] { - Activation.ELU, + DiscreteParameterSpace activationSpace = new DiscreteParameterSpace<>(Activation.ELU, Activation.RELU, Activation.LEAKYRELU, Activation.TANH, Activation.SELU, - Activation.HARDSIGMOID - }); + Activation.HARDSIGMOID); IntegerParameterSpace[] layersParametersSpace = new IntegerParameterSpace[] { new IntegerParameterSpace(outputSize, inputSize), new IntegerParameterSpace(outputSize, inputSize), diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiDigitNumberRecognition.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiDigitNumberRecognition.java index 8379e54848..03f11284bf 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiDigitNumberRecognition.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiDigitNumberRecognition.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -44,7 +44,7 @@ import java.io.File; /** - * @Description This is a demo that multi-digit number recognition. The maximum length is 6 digits. + * Description This is a demo that multi-digit number recognition. The maximum length is 6 digits. * If it is less than 6 digits, then zero is added to last * Training set: There were 14108 images, and they were used to train a model. * Testing set: in total 108 images,they copied from the training set,mainly to determine whether it's good that the model fited training data @@ -57,8 +57,6 @@ public class MultiDigitNumberRecognition { private static final Logger log = LoggerFactory.getLogger(MultiDigitNumberRecognition.class); - private static long seed = 123; - private static int epochs = 4; private static int batchSize = 15; private static String rootPath = System.getProperty("user.dir"); @@ -72,7 +70,8 @@ public static void main(String[] args) throws Exception { File modelDir = new File(modelDirPath); // create directory - boolean hasDir = modelDir.exists() || modelDir.mkdirs(); + if (!modelDir.exists()) { //noinspection ResultOfMethodCallIgnored + modelDir.mkdirs(); } log.info( modelPath ); //create model ComputationGraph model = createModel(); @@ -88,6 +87,7 @@ public static void main(String[] args) throws Exception { //fit model.setListeners(new ScoreIterationListener(10), new StatsListener( statsStorage), new EvaluativeListener(testMulIterator, 1, InvocationType.EPOCH_END)); + int epochs = 4; model.fit(trainMulIterator, epochs); //save @@ -104,8 +104,9 @@ public static void main(String[] args) throws Exception { } - public static ComputationGraph createModel() { + private static ComputationGraph createModel() { + long seed = 123; ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() .seed(seed) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) @@ -157,7 +158,7 @@ public static ComputationGraph createModel() { return model; } - public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) { + private static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) { int sumCount = 0; int correctCount = 0; @@ -165,23 +166,24 @@ public static void modelPredict(ComputationGraph model, MultiDataSetIterator ite MultiDataSet mds = iterator.next(); INDArray[] output = model.output(mds.getFeatures()); INDArray[] labels = mds.getLabels(); - int dataNum = batchSize > output[0].rows() ? output[0].rows() : batchSize; + int dataNum = Math.min(batchSize, output[0].rows()); for (int dataIndex = 0; dataIndex < dataNum; dataIndex ++) { - String reLabel = ""; - String peLabel = ""; - INDArray preOutput = null; - INDArray realLabel = null; + StringBuilder reLabel = new StringBuilder(); + StringBuilder peLabel = new StringBuilder(); + INDArray preOutput; + INDArray realLabel; for (int digit = 0; digit < 6; digit ++) { preOutput = output[digit].getRow(dataIndex); - peLabel += Nd4j.argMax(preOutput, 1).getInt(0); + peLabel.append(Nd4j.argMax(preOutput, 1).getInt(0)); realLabel = labels[digit].getRow(dataIndex); - reLabel += Nd4j.argMax(realLabel, 1).getInt(0); + reLabel.append(Nd4j.argMax(realLabel, 1).getInt(0)); } - if (peLabel.equals(reLabel)) { + boolean equals = peLabel.toString().equals(reLabel.toString()); + if (equals) { correctCount ++; } sumCount ++; - log.info("real image {} prediction {} status {}", reLabel,peLabel, peLabel.equals(reLabel)); + log.info("real image {} prediction {} status {}", reLabel.toString(), peLabel.toString(), equals); } } iterator.reset(); diff --git a/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiRecordDataSetIterator.java b/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiRecordDataSetIterator.java index 5f415c694e..273361b833 100644 --- a/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiRecordDataSetIterator.java +++ b/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/captcharecognition/MultiRecordDataSetIterator.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -26,16 +26,16 @@ */ public class MultiRecordDataSetIterator implements MultiDataSetIterator { - private int batchSize = 0; + private int batchSize; private int batchNum = 0; - private int numExample = 0; + private int numExample; private MulRecordDataLoader load; private MultiDataSetPreProcessor preProcessor; - public MultiRecordDataSetIterator(int batchSize, String dataSetType) throws Exception { + MultiRecordDataSetIterator(int batchSize, String dataSetType) throws Exception { this(batchSize, null, dataSetType); } - public MultiRecordDataSetIterator(int batchSize, ImageTransform imageTransform, String dataSetType) throws Exception { + private MultiRecordDataSetIterator(int batchSize, ImageTransform imageTransform, String dataSetType) throws Exception { this.batchSize = batchSize; load = new MulRecordDataLoader(imageTransform, dataSetType); numExample = load.totalExamples(); @@ -80,11 +80,7 @@ public void reset() { @Override public boolean hasNext() { - if(batchNum < numExample){ - return true; - } else { - return false; - } + return batchNum < numExample; } @Override