Skip to content

Commit

Permalink
wip updating examples.
Browse files Browse the repository at this point in the history
Signed-off-by: Robert Altena <Rob@Ra-ai.com>
  • Loading branch information
RobAltena committed Sep 23, 2019
1 parent 391ec1b commit f2d19a4
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 102 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -61,3 +61,4 @@ Word2vec-index/
!tutorials/*.json
*end.model

arbiterExample/
@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
Expand Down Expand Up @@ -42,11 +42,11 @@
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
Expand Down Expand Up @@ -114,14 +114,16 @@ public static void main(String[] args) throws Exception {
// This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ...
String baseSaveDirectory = "arbiterExample/";
File f = new File(baseSaveDirectory);
if (f.exists()) f.delete();
if (f.exists()) //noinspection ResultOfMethodCallIgnored
f.delete();
//noinspection ResultOfMethodCallIgnored
f.mkdir();
ResultSaver modelSaver = new FileModelSaver(baseSaveDirectory);

// (d) What are we actually trying to optimize?
// In this example, let's use classification accuracy on the test set
// See also ScoreFunctions.testSetF1(), ScoreFunctions.testSetRegression(regressionValue) etc
ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY);
ScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.ACCURACY);


// (e) When should we stop searching? Specify this with termination conditions
Expand Down
@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
Expand All @@ -25,9 +25,10 @@
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.evaluation.classification.Evaluation.Metric;

import java.io.IOException;
import java.util.List;

/**
Expand All @@ -38,23 +39,13 @@
*
* @author Alexandre Boulanger
*/

public class BaseGeneticHyperparameterOptimizationExample {

public static void main(String[] args) throws Exception {

ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration();

EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1);

// This is where we create the GeneticSearchCandidateGenerator with its default behavior:
// - a population that fits 30 candidates and is culled back to 20 when it overflows
// - new candidates are generated with a probability of 85% of being the result of breeding (a k-point crossover with 1 to 4 points)
        // - the new candidates have a probability of 0.5% of sustaining a random mutation on one of their genes.
GeneticSearchCandidateGenerator candidateGenerator = new GeneticSearchCandidateGenerator.Builder(cgs, scoreFunction).build();

// Let's have a listener to print the population size after each evaluation.
PopulationModel populationModel = candidateGenerator.getPopulationModel();
/**
* Common code used by two Arbiter examples.
*/
public static void run(PopulationModel populationModel, GeneticSearchCandidateGenerator candidateGenerator,
EvaluationScoreFunction scoreFunction) throws IOException {
populationModel.addListener(new ExamplePopulationListener());

IOptimizationRunner runner = GeneticSearchExampleConfiguration.BuildRunner(candidateGenerator, scoreFunction);
Expand All @@ -80,13 +71,32 @@ public static void main(String[] args) throws Exception {
System.out.println(bestModel.getConfiguration().toJson());
}

public static void main(String[] args) throws Exception {

ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration();

EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.F1);

// This is where we create the GeneticSearchCandidateGenerator with its default behavior:
// - a population that fits 30 candidates and is culled back to 20 when it overflows
// - new candidates are generated with a probability of 85% of being the result of breeding (a k-point crossover with 1 to 4 points)
        // - the new candidates have a probability of 0.5% of sustaining a random mutation on one of their genes.
GeneticSearchCandidateGenerator candidateGenerator = new GeneticSearchCandidateGenerator.Builder(cgs, scoreFunction).build();

// Let's have a listener to print the population size after each evaluation.
PopulationModel populationModel = candidateGenerator.getPopulationModel();
populationModel.addListener(new ExamplePopulationListener());
run(populationModel, candidateGenerator, scoreFunction);
}

public static class ExamplePopulationListener implements PopulationListener {

@SuppressWarnings("OptionalGetWithoutIsPresent")
@Override
public void onChanged(List<Chromosome> population) {
double best = population.get(0).getFitness();
double average = population.stream()
.mapToDouble(c -> c.getFitness())
.mapToDouble(Chromosome::getFitness)
.average()
.getAsDouble();
System.out.println(String.format("\nPopulation size is %1$s, best score is %2$s, average score is %3$s", population.size(), best, average));
Expand Down
@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
Expand All @@ -20,25 +20,19 @@
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationListener;
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.evaluation.classification.Evaluation.Metric;

import java.util.List;
import static org.deeplearning4j.examples.arbiter.genetic.BaseGeneticHyperparameterOptimizationExample.run;

/**
* In this hyperparameter optimization example, we change the default behavior of the genetic candidate generator.
Expand All @@ -55,7 +49,7 @@ public static void main(String[] args) throws Exception {

ComputationGraphSpace cgs = GeneticSearchExampleConfiguration.GetGraphConfiguration();

EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1);
EvaluationScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.F1);

// The ExampleCullOperator extends the default cull operator (least fit) to include an artificial predator.
CullOperator cullOperator = new ExampleCullOperator();
Expand Down Expand Up @@ -85,30 +79,8 @@ public static void main(String[] args) throws Exception {
.build();

// Let's have a listener to print the population size after each evaluation.
populationModel.addListener(new ExamplePopulationListener());

IOptimizationRunner runner = GeneticSearchExampleConfiguration.BuildRunner(candidateGenerator, scoreFunction);

//Start the hyperparameter optimization
runner.execute();

//Print out some basic stats regarding the optimization procedure
String s = "Best score: " + runner.bestScore() + "\n" +
"Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" +
"Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n";
System.out.println(s);


//Get all results, and print out details of the best result:
int indexOfBestResult = runner.bestScoreCandidateIndex();
List<ResultReference> allResults = runner.getResults();

OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
ComputationGraph bestModel = (ComputationGraph) bestResult.getResultReference().getResultModel();

System.out.println("\n\nConfiguration of best model:\n");
System.out.println(bestModel.getConfiguration().toJson());

populationModel.addListener(new BaseGeneticHyperparameterOptimizationExample.ExamplePopulationListener());
run(populationModel, candidateGenerator, scoreFunction);
}

// This is an example of a custom behavior for the genetic algorithm. We force one of the parent to be one of the
Expand Down Expand Up @@ -158,17 +130,4 @@ public void cullPopulation() {
System.out.println(String.format("Randomly removed %1$s candidate(s).", preyCount));
}
}

public static class ExamplePopulationListener implements PopulationListener {

@Override
public void onChanged(List<Chromosome> population) {
double best = population.get(0).getFitness();
double average = population.stream()
.mapToDouble(Chromosome::getFitness)
.average()
.getAsDouble();
System.out.println(String.format("\nPopulation size is %1$s, best score is %2$s, average score is %3$s", population.size(), best, average));
}
}
}
@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
Expand Down Expand Up @@ -49,22 +49,20 @@
import java.io.File;
import java.util.Properties;

public class GeneticSearchExampleConfiguration {
class GeneticSearchExampleConfiguration {

public static ComputationGraphSpace GetGraphConfiguration() {
static ComputationGraphSpace GetGraphConfiguration() {
int inputSize = 784;
int outputSize = 47;

// First, we setup the hyperspace parameters. These are the values which will change, breed and mutate
// while attempting to find the best candidate.
DiscreteParameterSpace<Activation> activationSpace = new DiscreteParameterSpace(new Activation[] {
Activation.ELU,
DiscreteParameterSpace<Activation> activationSpace = new DiscreteParameterSpace<>(Activation.ELU,
Activation.RELU,
Activation.LEAKYRELU,
Activation.TANH,
Activation.SELU,
Activation.HARDSIGMOID
});
Activation.HARDSIGMOID);
IntegerParameterSpace[] layersParametersSpace = new IntegerParameterSpace[] {
new IntegerParameterSpace(outputSize, inputSize),
new IntegerParameterSpace(outputSize, inputSize),
Expand Down
@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
Expand Down Expand Up @@ -44,7 +44,7 @@
import java.io.File;

/**
* @Description This is a demo that multi-digit number recognition. The maximum length is 6 digits.
 * Description: This is a demo of multi-digit number recognition. The maximum length is 6 digits.
 * If a number has fewer than 6 digits, zeros are appended at the end.
* Training set: There were 14108 images, and they were used to train a model.
 * Testing set: 108 images in total; they were copied from the training set, mainly to determine whether the model fit the training data well
Expand All @@ -57,8 +57,6 @@ public class MultiDigitNumberRecognition {

private static final Logger log = LoggerFactory.getLogger(MultiDigitNumberRecognition.class);

private static long seed = 123;
private static int epochs = 4;
private static int batchSize = 15;
private static String rootPath = System.getProperty("user.dir");

Expand All @@ -72,7 +70,8 @@ public static void main(String[] args) throws Exception {
File modelDir = new File(modelDirPath);

// create directory
boolean hasDir = modelDir.exists() || modelDir.mkdirs();
if (!modelDir.exists()) { //noinspection ResultOfMethodCallIgnored
modelDir.mkdirs(); }
log.info( modelPath );
//create model
ComputationGraph model = createModel();
Expand All @@ -88,6 +87,7 @@ public static void main(String[] args) throws Exception {

//fit
model.setListeners(new ScoreIterationListener(10), new StatsListener( statsStorage), new EvaluativeListener(testMulIterator, 1, InvocationType.EPOCH_END));
int epochs = 4;
model.fit(trainMulIterator, epochs);

//save
Expand All @@ -104,8 +104,9 @@ public static void main(String[] args) throws Exception {

}

public static ComputationGraph createModel() {
private static ComputationGraph createModel() {

long seed = 123;
ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
.seed(seed)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
Expand Down Expand Up @@ -157,31 +158,32 @@ public static ComputationGraph createModel() {
return model;
}

public static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
private static void modelPredict(ComputationGraph model, MultiDataSetIterator iterator) {
int sumCount = 0;
int correctCount = 0;

while (iterator.hasNext()) {
MultiDataSet mds = iterator.next();
INDArray[] output = model.output(mds.getFeatures());
INDArray[] labels = mds.getLabels();
int dataNum = batchSize > output[0].rows() ? output[0].rows() : batchSize;
int dataNum = Math.min(batchSize, output[0].rows());
for (int dataIndex = 0; dataIndex < dataNum; dataIndex ++) {
String reLabel = "";
String peLabel = "";
INDArray preOutput = null;
INDArray realLabel = null;
StringBuilder reLabel = new StringBuilder();
StringBuilder peLabel = new StringBuilder();
INDArray preOutput;
INDArray realLabel;
for (int digit = 0; digit < 6; digit ++) {
preOutput = output[digit].getRow(dataIndex);
peLabel += Nd4j.argMax(preOutput, 1).getInt(0);
peLabel.append(Nd4j.argMax(preOutput, 1).getInt(0));
realLabel = labels[digit].getRow(dataIndex);
reLabel += Nd4j.argMax(realLabel, 1).getInt(0);
reLabel.append(Nd4j.argMax(realLabel, 1).getInt(0));
}
if (peLabel.equals(reLabel)) {
boolean equals = peLabel.toString().equals(reLabel.toString());
if (equals) {
correctCount ++;
}
sumCount ++;
log.info("real image {} prediction {} status {}", reLabel,peLabel, peLabel.equals(reLabel));
log.info("real image {} prediction {} status {}", reLabel.toString(), peLabel.toString(), equals);
}
}
iterator.reset();
Expand Down
@@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
Expand Down Expand Up @@ -26,16 +26,16 @@
*/

public class MultiRecordDataSetIterator implements MultiDataSetIterator {
private int batchSize = 0;
private int batchSize;
private int batchNum = 0;
private int numExample = 0;
private int numExample;
private MulRecordDataLoader load;
private MultiDataSetPreProcessor preProcessor;

public MultiRecordDataSetIterator(int batchSize, String dataSetType) throws Exception {
MultiRecordDataSetIterator(int batchSize, String dataSetType) throws Exception {
this(batchSize, null, dataSetType);
}
public MultiRecordDataSetIterator(int batchSize, ImageTransform imageTransform, String dataSetType) throws Exception {
private MultiRecordDataSetIterator(int batchSize, ImageTransform imageTransform, String dataSetType) throws Exception {
this.batchSize = batchSize;
load = new MulRecordDataLoader(imageTransform, dataSetType);
numExample = load.totalExamples();
Expand Down Expand Up @@ -80,11 +80,7 @@ public void reset() {

@Override
public boolean hasNext() {
if(batchNum < numExample){
return true;
} else {
return false;
}
return batchNum < numExample;
}

@Override
Expand Down

0 comments on commit f2d19a4

Please sign in to comment.