/
UCISequenceClassification.java
227 lines (196 loc) · 11.3 KB
/
UCISequenceClassification.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.examples.quickstart.modeling.recurrent;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.common.primitives.Pair;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
/**
 * Sequence Classification Example Using an LSTM Recurrent Neural Network
*
* This example learns how to classify univariate time series as belonging to one of six categories.
* Categories are: Normal, Cyclic, Increasing trend, Decreasing trend, Upward shift, Downward shift
*
* Data is the UCI Synthetic Control Chart Time Series Data Set
* Details: https://archive.ics.uci.edu/ml/datasets/Synthetic+Control+Chart+Time+Series
* Data: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data
* Image: https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/data.jpeg
*
* This example proceeds as follows:
* 1. Download and prepare the data (in downloadUCIData() method)
* (a) Split the 600 sequences into train set of size 450, and test set of size 150
* (b) Write the data into a format suitable for loading using the CSVSequenceRecordReader for sequence classification
* This format: one time series per file, and a separate file for the labels.
 * For example, train/features/0.csv contains the features used with the labels file train/labels/0.csv
* Because the data is a univariate time series, we only have one column in the CSV files. Normally, each column
* would contain multiple values - one time step per row.
* Furthermore, because we have only one label for each time series, the labels CSV files contain only a single value
*
* 2. Load the training data using CSVSequenceRecordReader (to load/parse the CSV files) and SequenceRecordReaderDataSetIterator
* (to convert it to DataSet objects, ready to train)
* For more details on this step, see: https://deeplearning4j.konduit.ai/models/recurrent#data-for-rnns
*
* 3. Normalize the data. The raw data contain values that are too large for effective training, and need to be normalized.
* Normalization is conducted using NormalizerStandardize, based on statistics (mean, st.dev) collected on the training
* data only. Note that both the training data and test data are normalized in the same way.
*
* 4. Configure the network
* The data set here is very small, so we can't afford to use a large network with many parameters.
* We are using one small LSTM layer and one RNN output layer
*
* 5. Train the network for 40 epochs
* At each epoch, evaluate and print the accuracy and f1 on the test set
*
* @author Alex Black
*/
@SuppressWarnings("ResultOfMethodCallIgnored")
public class UCISequenceClassification {

    private static final Logger log = LoggerFactory.getLogger(UCISequenceClassification.class);

    //'baseDir': Base directory for the data. Change this if you want to save the data somewhere else
    private static final File baseDir = new File("src/main/resources/uci/");
    private static final File baseTrainDir = new File(baseDir, "train");
    private static final File featuresDirTrain = new File(baseTrainDir, "features");
    private static final File labelsDirTrain = new File(baseTrainDir, "labels");
    private static final File baseTestDir = new File(baseDir, "test");
    private static final File featuresDirTest = new File(baseTestDir, "features");
    private static final File labelsDirTest = new File(baseTestDir, "labels");

    /**
     * Downloads/prepares the UCI data, trains a small LSTM classifier for 40 epochs, and prints
     * evaluation statistics (accuracy, F1, etc.) on the held-out test set.
     *
     * @param args unused
     * @throws Exception on download, I/O, or training failure
     */
    public static void main(String[] args) throws Exception {
        downloadUCIData();

        // ----- Load the training data -----
        //Note that we have 450 training files for features: train/features/0.csv through train/features/449.csv
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
        trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));
        SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
        trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, 449));

        int miniBatchSize = 10;
        int numLabelClasses = 6;
        //ALIGN_END: each sequence has a single label (at the last time step), so align labels to the sequence end
        DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        //Normalize the training data
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainData);              //Collect training data statistics (mean, st.dev)
        trainData.reset();                      //fit() exhausted the iterator; rewind before training
        //Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized
        trainData.setPreProcessor(normalizer);

        // ----- Load the test data -----
        //Same process as for the training data.
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader();
        testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, 149));
        SequenceRecordReader testLabels = new CSVSequenceRecordReader();
        testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, 149));

        DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
            false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        testData.setPreProcessor(normalizer);   //Note that we are using the exact same normalization process as the training data

        // ----- Configure the network -----
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)    //Random number generator seed for improved repeatability. Optional.
            .weightInit(WeightInit.XAVIER)
            .updater(new Nadam())
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)  //Not always required, but helps with this data set
            .gradientNormalizationThreshold(0.5)
            .list()
            .layer(new LSTM.Builder().activation(Activation.TANH).nIn(1).nOut(10).build())
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        log.info("Starting training...");
        //Print the score (loss function value) every 20 iterations; evaluate on the test set at the end of each epoch
        net.setListeners(new ScoreIterationListener(20), new EvaluativeListener(testData, 1, InvocationType.EPOCH_END));

        int nEpochs = 40;
        net.fit(trainData, nEpochs);

        log.info("Evaluating...");
        Evaluation eval = net.evaluate(testData);
        log.info(eval.stats());

        log.info("----- Example Complete -----");
    }

    /**
     * Downloads the UCI Synthetic Control data set (one time series per line, whitespace-separated)
     * and converts it into the format CSVSequenceRecordReader/DataVec can read: one sequence per
     * file (one value per row in features/i.csv) plus a separate single-value labels/i.csv file.
     * Splits the 600 sequences into 450 train / 150 test, shuffled with a fixed seed.
     *
     * @throws Exception on download or file-writing failure
     */
    private static void downloadUCIData() throws Exception {
        if (baseDir.exists()) return;   //Data already exists; don't download it again

        String url = "https://archive.ics.uci.edu/ml/machine-learning-databases/synthetic_control-mld/synthetic_control.data";
        //Data is plain ASCII; request UTF-8 explicitly rather than relying on the platform default charset
        String data = IOUtils.toString(new URL(url), StandardCharsets.UTF_8);
        String[] lines = data.split("\n");

        //Create output directories (mkdirs creates any missing parent directories as well)
        featuresDirTrain.mkdirs();
        labelsDirTrain.mkdirs();
        featuresDirTest.mkdirs();
        labelsDirTest.mkdirs();

        int lineCount = 0;
        List<Pair<String, Integer>> contentAndLabels = new ArrayList<>();
        for (String line : lines) {
            //Transpose: "v1 v2 v3 ..." on one line -> one value per line (one time step per row)
            String transposed = line.replaceAll(" +", "\n");

            //Labels: first 100 examples (lines) are label 0, second 100 examples are label 1, and so on
            contentAndLabels.add(new Pair<>(transposed, lineCount++ / 100));
        }

        //Randomize (fixed seed for repeatability) and do a train/test split:
        Collections.shuffle(contentAndLabels, new Random(12345));

        int nTrain = 450;   //75% train, 25% test
        int trainCount = 0;
        int testCount = 0;
        for (Pair<String, Integer> p : contentAndLabels) {
            //Write output in a format we can read, in the appropriate locations
            File outPathFeatures;
            File outPathLabels;
            if (trainCount < nTrain) {
                outPathFeatures = new File(featuresDirTrain, trainCount + ".csv");
                outPathLabels = new File(labelsDirTrain, trainCount + ".csv");
                trainCount++;
            } else {
                outPathFeatures = new File(featuresDirTest, testCount + ".csv");
                outPathLabels = new File(labelsDirTest, testCount + ".csv");
                testCount++;
            }

            FileUtils.writeStringToFile(outPathFeatures, p.getFirst(), StandardCharsets.UTF_8);
            FileUtils.writeStringToFile(outPathLabels, p.getSecond().toString(), StandardCharsets.UTF_8);
        }
    }
}