/
LeNetMNIST.java
133 lines (119 loc) · 6.03 KB
/
LeNetMNIST.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
///usr/bin/env jbang "$0" "$@" ; exit $?
//DEPS org.deeplearning4j:deeplearning4j-core:1.0.0-M2
//DEPS org.nd4j:nd4j-native:1.0.0-M2
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.examples.sample;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
/**
 * Example that builds a LeNet-style convolutional network, trains it on the MNIST
 * iterators and writes the trained model to the JVM temp directory.
 *
 * <p>Created by agibsonccc on 9/16/15.
 */
public class LeNetMNIST {
    private static final Logger log = LoggerFactory.getLogger(LeNetMNIST.class);

    /** Height in pixels of the (flattened) MNIST images fed to the network. */
    private static final int IMAGE_HEIGHT = 28;
    /** Width in pixels of the (flattened) MNIST images fed to the network. */
    private static final int IMAGE_WIDTH = 28;

    /**
     * Loads the MNIST train/test iterators, builds and trains the network, then saves it
     * as {@code lenetmnist.zip} under {@code java.io.tmpdir}.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if loading the data, training, or saving the model fails
     */
    public static void main(String[] args) throws Exception {
        final int nChannels = 1;  // Number of input channels
        final int outputNum = 10; // The number of possible outcomes
        final int batchSize = 64; // Batch size used by both the train and the test iterator
        final int nEpochs = 1;    // Number of training epochs
        final int seed = 123;     // Seed handed to the network configuration builder

        /*
            Create an iterator using the batch size for one iteration
         */
        log.info("Load data....");
        // NOTE(review): the iterators use their own literal seed (12345), not `seed` above;
        // kept as-is so the example's data order does not change.
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        /*
            Construct the neural network
         */
        log.info("Build model....");
        MultiLayerConfiguration conf = buildConfiguration(nChannels, outputNum, seed);

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model...");
        // Print score every 10 iterations and evaluate on test set every epoch
        model.setListeners(
                new ScoreIterationListener(10),
                new EvaluativeListener(mnistTest, 1, InvocationType.EPOCH_END));
        model.fit(mnistTrain, nEpochs);

        String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "lenetmnist.zip");

        // Parameterized logging instead of string concatenation; the emitted text is unchanged.
        log.info("Saving model to tmp folder: {}", path);
        // NOTE(review): the `true` flag presumably also persists the updater state -
        // confirm against MultiLayerNetwork#save.
        model.save(new File(path), true);

        log.info("****************Example finished********************");
    }

    /**
     * Builds the LeNet-style layer stack: conv(5x5, 20) - maxpool(2x2) - conv(5x5, 50) -
     * maxpool(2x2) - dense(500, ReLU) - softmax output.
     *
     * @param nChannels number of input channels; used both as {@code nIn} of the first
     *                  convolution and as the depth of the declared input type
     * @param outputNum number of output classes
     * @param seed      seed passed to the configuration builder
     * @return the assembled multi-layer configuration
     */
    private static MultiLayerConfiguration buildConfiguration(int nChannels, int outputNum, int seed) {
        /*
            Regarding the .setInputType(InputType.convolutionalFlat(...)) line: This does a few things.
            (a) It adds preprocessors, which handle things like the transition between the
                convolutional/subsampling layers and the dense layer
            (b) Does some additional configuration validation
            (c) Where necessary, sets the nIn (number of input neurons, or input depth in the case of
                CNNs) values for each layer based on the size of the previous layer (but it won't
                override values manually set by the user)

            InputTypes can be used with other layer types too (RNNs, MLPs etc) not just CNNs.
            For normal images (when using ImageRecordReader) use
            InputType.convolutional(height,width,depth).
            MNIST record reader is a special case, that outputs 28x28 pixel grayscale (nChannels=1)
            images, in a "flattened" row vector format (i.e., 1x784 vectors), hence the
            "convolutionalFlat" input type used here.
         */
        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .l2(0.0005)
                .weightInit(WeightInit.XAVIER)
                .updater(new Adam(1e-3))
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        // nIn and nOut specify depth. nIn here is the nChannels and nOut is the
                        // number of filters to be applied
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        // Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                // Depth comes from nChannels so it cannot drift from the first layer's nIn.
                .setInputType(InputType.convolutionalFlat(IMAGE_HEIGHT, IMAGE_WIDTH, nChannels))
                .build();
    }
}