/* *****************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.examples.arbiter;

import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.TimeUnit;

/**
 * This is a basic hyperparameter optimization example using Arbiter to conduct random search on two network hyperparameters.
 * The two hyperparameters are learning rate and layer size, and the search is conducted for a simple multi-layer perceptron
 * on MNIST data.
 * <p>
 * Note that this example is set up to use Arbiter's UI: http://localhost:9000/arbiter
 *
 * @author Alex Black
 */
public class BasicHyperparameterOptimizationExample {

    public static void main(String[] args) throws Exception {

        //First: Set up the hyperparameter configuration space. This is like a MultiLayerConfiguration, but can have either
        // fixed values or values to optimize, for each hyperparameter

        ParameterSpace<Double> learningRateHyperparam = new ContinuousParameterSpace(0.0001, 0.1);  //Values will be generated uniformly at random between 0.0001 and 0.1 (inclusive)
        ParameterSpace<Integer> layerSizeHyperparam = new IntegerParameterSpace(16, 256);           //Integer values will be generated uniformly at random between 16 and 256 (inclusive)
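        //Other parameter space types exist as well. A hedged sketch (not used in this example): a discrete
        //space could search over a fixed set of activation functions, assuming your Arbiter version provides
        //org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace:
        //ParameterSpace<Activation> activationHyperparam =
        //    new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH, Activation.LEAKYRELU);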
        MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder()
            //These next few options: fixed values for all models
            .weightInit(WeightInit.XAVIER)
            .l2(0.0001)
            //Learning rate hyperparameter: search over different values, applied to all models
            .updater(new SgdSpace(learningRateHyperparam))
            .addLayer(new DenseLayerSpace.Builder()
                //Fixed values for this layer:
                .nIn(784)  //Fixed input: 28x28=784 pixels for MNIST
                .activation(Activation.LEAKYRELU)
                //One hyperparameter to infer: layer size
                .nOut(layerSizeHyperparam)
                .build())
            .addLayer(new OutputLayerSpace.Builder()
                .nOut(10)
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunctions.LossFunction.MCXENT)
                .build())
            .numEpochs(2)
            .build();
        //Now: We need to define a few configuration options
        // (a) How are we going to generate candidates? (random search or grid search)
        CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperparameterSpace, null);    //Alternatively: new GridSearchCandidateGenerator<>(hyperparameterSpace, 5, GridSearchCandidateGenerator.Mode.RandomOrder);
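        //The grid-search alternative from the comment above, expanded into a commented-out sketch
        //(requires the GridSearchCandidateGenerator import; check the constructor signature for your Arbiter version):
        //CandidateGenerator candidateGenerator =
        //    new GridSearchCandidateGenerator(hyperparameterSpace, 5, GridSearchCandidateGenerator.Mode.RandomOrder);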
        // (b) How are we going to provide data? We'll use a simple data source that returns MNIST data
        //     Note that we set the number of epochs in MultiLayerSpace above
        Class<? extends DataSource> dataSourceClass = ExampleDataSource.class;
        Properties dataSourceProperties = new Properties();
        dataSourceProperties.setProperty("minibatchSize", "64");
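        //These key/value pairs are passed to DataSource.configure(Properties) - see ExampleDataSource.configure(...)
        //below, which reads the "minibatchSize" key set here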
        // (c) How are we going to save the models that are generated and tested?
        //     In this example, let's save them to disk in the working directory
        //     This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ...
        String baseSaveDirectory = "arbiterExample/";
        File f = new File(baseSaveDirectory);
        if (f.exists()) //noinspection ResultOfMethodCallIgnored
            f.delete();    //Note: File.delete() only removes an empty directory; clear arbiterExample/ manually between runs if it contains saved models
        //noinspection ResultOfMethodCallIgnored
        f.mkdir();
        ResultSaver modelSaver = new FileModelSaver(baseSaveDirectory);
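        //A hedged alternative sketch: if persisting models to disk is not needed, Arbiter versions that include
        //org.deeplearning4j.arbiter.optimize.api.saving.InMemoryResultSaver allow keeping results in memory:
        //ResultSaver modelSaver = new InMemoryResultSaver();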
        // (d) What are we actually trying to optimize?
        //     In this example, let's use classification accuracy on the test set
        //     See also ScoreFunctions.testSetF1(), ScoreFunctions.testSetRegression(regressionValue) etc
        ScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.ACCURACY);
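        //Other metrics from the same enum can be substituted directly, for example:
        //ScoreFunction scoreFunction = new EvaluationScoreFunction(Metric.F1);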
        // (e) When should we stop searching? Specify this with termination conditions
        //     For this example, we are stopping the search at 15 minutes or 10 candidates - whichever comes first
        TerminationCondition[] terminationConditions = {
            new MaxTimeCondition(15, TimeUnit.MINUTES),
            new MaxCandidatesCondition(10)};

        //Given these configuration options, let's put them all together:
        OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
            .candidateGenerator(candidateGenerator)
            .dataSource(dataSourceClass, dataSourceProperties)
            .modelSaver(modelSaver)
            .scoreFunction(scoreFunction)
            .terminationConditions(terminationConditions)
            .build();
        //And set up execution locally on this machine:
        IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());

        //Start the UI. Arbiter uses the same storage and persistence approach as DL4J's UI
        //Access at http://localhost:9000/arbiter
        StatsStorage ss = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "arbiterExampleUiStats.dl4j"));
        runner.addListeners(new ArbiterStatusListener(ss));
        UIServer.getInstance().attach(ss);

        //Start the hyperparameter optimization
        runner.execute();
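        //Note: execute() blocks until one of the termination conditions above is satisfied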
        //Print out some basic stats regarding the optimization procedure
        String s = "Best score: " + runner.bestScore() + "\n" +
            "Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" +
            "Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n";
        System.out.println(s);
        //Get all results, and print out details of the best result:
        int indexOfBestResult = runner.bestScoreCandidateIndex();
        List<ResultReference> allResults = runner.getResults();

        OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
        MultiLayerNetwork bestModel = (MultiLayerNetwork) bestResult.getResultReference().getResultModel();

        System.out.println("\n\nConfiguration of best model:\n");
        System.out.println(bestModel.getLayerWiseConfigurations().toJson());
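        //A hedged follow-up sketch (not part of the original flow): the best model could be re-evaluated
        //on the MNIST test set using DL4J's standard evaluate(...) method, for example:
        //DataSetIterator testIter = new MnistDataSetIterator(64, false, 12345);
        //System.out.println(bestModel.evaluate(testIter).stats());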
        //Wait a while before exiting
        Thread.sleep(60000);
        UIServer.getInstance().stop();
    }

    public static class ExampleDataSource implements DataSource {

        private int minibatchSize;
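        //Note: Arbiter instantiates the DataSource reflectively (it is passed as a Class object plus Properties
        //above), so a public no-argument constructor is required; configuration happens via configure(Properties)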
        public ExampleDataSource() {
        }

        @Override
        public void configure(Properties properties) {
            this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "16"));
        }

        @Override
        public Object trainData() {
            try {
                return new MnistDataSetIterator(minibatchSize, true, 12345);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Object testData() {
            try {
                return new MnistDataSetIterator(minibatchSize, false, 12345);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Class<?> getDataType() {
            return DataSetIterator.class;
        }
    }
}