/
ModelConfiguration.java
289 lines (266 loc) · 13.8 KB
/
ModelConfiguration.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
package org.deeplearning4j.nn.modelimport.keras;

import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Routines for importing saved Keras model configurations.
*
* @author davekale
*/
public class ModelConfiguration {

    /* Log under this class, not Model.class as the original did — otherwise
     * every message from this importer is misattributed to Model in the logs. */
    private static final Logger log = LoggerFactory.getLogger(ModelConfiguration.class);

    /** Static utility class: not instantiable. */
    private ModelConfiguration() {}

    /**
     * Imports a Keras Sequential model configuration saved using a call to model.to_json().
     *
     * @param configJsonFilename path to a text file storing the Keras configuration as valid JSON
     * @return DL4J MultiLayerConfiguration
     * @throws IOException if the file cannot be read or the JSON cannot be parsed
     * @throws IncompatibleKerasConfigurationException if the config is not a supported Sequential model
     */
    public static MultiLayerConfiguration importSequentialModelConfigFromFile(String configJsonFilename)
            throws IOException {
        /* Keras serializes JSON as UTF-8; read with an explicit charset rather than
         * the platform default, which breaks on non-ASCII layer names on some JVMs. */
        String configJson = new String(Files.readAllBytes(Paths.get(configJsonFilename)), StandardCharsets.UTF_8);
        return importSequentialModelConfig(configJson);
    }

    /**
     * Imports a Keras Functional API model configuration saved using a call to model.to_json().
     *
     * @param configJsonFilename path to a text file storing the Keras configuration as valid JSON
     * @return DL4J ComputationGraphConfiguration
     * @throws IOException if the file cannot be read or the JSON cannot be parsed
     * @throws NotImplementedException always, until Functional API import is implemented
     */
    public static ComputationGraphConfiguration importFunctionalApiConfigFromFile(String configJsonFilename)
            throws IOException {
        /* Explicit UTF-8 for the same reason as the Sequential variant above. */
        String configJson = new String(Files.readAllBytes(Paths.get(configJsonFilename)), StandardCharsets.UTF_8);
        return importFunctionalApiConfig(configJson);
    }

    /**
     * Imports a Keras Sequential model configuration saved using a call to model.to_json().
     *
     * @param configJson String storing the Keras configuration as valid JSON
     * @return DL4J MultiLayerConfiguration
     * @throws IOException if the JSON cannot be parsed
     * @throws IncompatibleKerasConfigurationException if the config is not a supported Sequential model
     */
    public static MultiLayerConfiguration importSequentialModelConfig(String configJson)
            throws IOException {
        Map<String, Object> kerasConfig = parseJsonString(configJson);
        return importSequentialModelConfig(kerasConfig);
    }

    /**
     * Imports a Keras Functional API model configuration saved using a call to model.to_json().
     *
     * @param configJson String storing the Keras configuration as valid JSON
     * @return DL4J ComputationGraphConfiguration
     * @throws IOException if the JSON cannot be parsed
     * @throws NotImplementedException always, until Functional API import is implemented
     */
    public static ComputationGraphConfiguration importFunctionalApiConfig(String configJson)
            throws IOException {
        Map<String, Object> kerasConfig = parseJsonString(configJson);
        return importFunctionalApiConfig(kerasConfig);
    }

    /**
     * Imports a Keras Sequential model configuration from the parsed JSON tree.
     *
     * <p>Processing happens in two passes: a first pass merges standalone Dropout layers
     * into the following layer and standalone Activation layers into the preceding layer
     * (neither has a direct DL4J layer equivalent here); a second pass builds each DL4J
     * layer, extracts the input shape from {@code batch_input_shape}, and determines the
     * Keras backend dim ordering so the correct {@link InputType} can be set.
     *
     * @param kerasConfig nested Map storing the Keras configuration read from valid JSON
     * @return DL4J MultiLayerConfiguration
     * @throws IOException if a nested parse step fails
     * @throws NotImplementedException for Keras layer types without a DL4J mapping
     * @throws IncompatibleKerasConfigurationException if the config violates Sequential-model assumptions
     */
    private static MultiLayerConfiguration importSequentialModelConfig(Map<String, Object> kerasConfig)
            throws IOException, IncompatibleKerasConfigurationException {
        String arch = (String) kerasConfig.get("class_name");
        if (!arch.equals("Sequential"))
            throw new IncompatibleKerasConfigurationException("Expected \"Sequential\" model config, found " + arch);

        /* First pass over the raw layer configs:
         * - merge Dropout layers into the subsequent layer's "dropout" field
         * - merge Activation layers into the previous layer's "activation" field
         * TODO: remove this once a standalone Dropout layer is added to DL4J.
         */
        double prevDropout = 0.0;
        List<Map<String, Object>> layerConfigs = new ArrayList<>();
        for (Object o : (List<Object>) kerasConfig.get("config")) {
            String kerasLayerName = (String) ((Map<String, Object>) o).get("class_name");
            Map<String, Object> layerConfig = (Map<String, Object>) ((Map<String, Object>) o).get("config");
            switch (kerasLayerName) {
                case "Dropout":
                    /* Remember the dropout rate so it can be merged into the next layer.
                     * TODO: remove once a Dropout layer is added to DL4J.
                     */
                    prevDropout = (double) layerConfig.get("p");
                    continue;
                case "Activation":
                    /* Merge the activation function into the previous layer.
                     * TODO: DL4J now has an Activation layer, so maybe remove this.
                     */
                    if (layerConfigs.size() == 0)
                        throw new IncompatibleKerasConfigurationException("Plain activation layer applied to input not supported.");
                    String activation = LayerConfiguration.mapActivation((String) layerConfig.get("activation"));
                    layerConfigs.get(layerConfigs.size() - 1).put("activation", activation);
                    continue;
            }
            layerConfig.put("keras_class", kerasLayerName);
            /* Merge pending dropout from a preceding Dropout layer. Rates compose as
             * keep probabilities: 1 - (1-p1)(1-p2).
             * TODO: remove once a Dropout layer is added to DL4J.
             */
            if (prevDropout > 0) {
                double oldDropout = layerConfig.containsKey("dropout") ? (double) layerConfig.get("dropout") : 0.0;
                double newDropout = 1.0 - (1.0 - prevDropout) * (1.0 - oldDropout);
                layerConfig.put("dropout", newDropout);
                if (oldDropout != newDropout)
                    /* Fixed: the original logged newDropout as the "previous Dropout" rate;
                     * the merged-in rate is prevDropout. */
                    log.warn("Changed layer-defined dropout {} to {} because of previous Dropout={} layer",
                            oldDropout, newDropout, prevDropout);
                prevDropout = 0.0;
            }
            layerConfigs.add(layerConfig);
        }

        /* Second pass: build each DL4J layer. In addition:
         * - read the input shape from the input layer's "batch_input_shape" field
         * - read the dim ordering (determined by the Keras backend)
         * - detect whether the model contains recurrent or convolutional layers
         */
        List<Integer> batchInputShape = null;
        String dimOrdering = null;
        boolean isRecurrent = false;
        boolean isConvolutional = false;
        NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder();
        NeuralNetConfiguration.ListBuilder listBuilder = modelBuilder.list();
        int layerIndex = 0;
        for (Map<String, Object> layerConfig : layerConfigs) {
            String kerasLayerName = (String) layerConfig.get("keras_class");
            /* "batch_input_shape" must be set on the input layer and ONLY the input layer. */
            if (layerConfig.containsKey("batch_input_shape")) {
                if (layerIndex > 0)
                    throw new IncompatibleKerasConfigurationException("Non-input layer should not specify \"batch_input_shape\" field");
                else
                    batchInputShape = (List<Integer>) layerConfig.get("batch_input_shape");
            } else if (layerIndex == 0)
                throw new IncompatibleKerasConfigurationException("Input layer must specify \"batch_input_shape\" field");
            /* "dim_ordering" generally appears only in convolutional and pooling layers;
             * all layers that carry it must agree. */
            if (layerConfig.containsKey("dim_ordering")) {
                String layerDimOrdering = (String) layerConfig.get("dim_ordering");
                if (!layerDimOrdering.equals("th") && !layerDimOrdering.equals("tf"))
                    throw new IncompatibleKerasConfigurationException("Unknown Keras backend: " + layerDimOrdering);
                if (dimOrdering != null && !layerDimOrdering.equals(dimOrdering))
                    throw new IncompatibleKerasConfigurationException("Found layers with conflicting Keras backends.");
                dimOrdering = layerDimOrdering;
            }
            /* Build the DL4J layer from the Keras class name, config, and position
             * (the final layer may map to an output layer variant). */
            Layer layer = LayerConfiguration.buildLayer(kerasLayerName, layerConfig, (layerIndex == layerConfigs.size() - 1));
            if (layer == null)
                continue;
            /* Track recurrent/convolutional layers to choose the InputType below. */
            if (layer instanceof BaseRecurrentLayer)
                isRecurrent = true;
            else if (layer instanceof ConvolutionLayer)
                isConvolutional = true;
            if (layer.getL1() > 0 || layer.getL2() > 0)
                modelBuilder.regularization(true);
            listBuilder.layer(layerIndex, layer);
            layerIndex++;
        }

        /* Guard against an empty layer list: without an input layer, batchInputShape
         * is still null and the code below would throw a bare NullPointerException. */
        if (batchInputShape == null)
            throw new IncompatibleKerasConfigurationException("Model config contains no input layer with \"batch_input_shape\" field.");

        /* Choose the InputType from the detected architecture and "batch_input_shape". */
        if (isRecurrent && isConvolutional) {
            throw new IncompatibleKerasConfigurationException("Recurrent convolutional architecture not supported.");
        } else if (isRecurrent) {
            listBuilder.setInputType(InputType.recurrent(batchInputShape.get(2)));
            if (batchInputShape.get(1) == null)
                log.warn("Input sequence length must be specified manually for truncated BPTT!");
            else {
                int sequenceLength = batchInputShape.get(1);
                listBuilder.tBPTTForwardLength(sequenceLength).tBPTTBackwardLength(sequenceLength);
            }
        } else if (isConvolutional) {
            int[] imageSize = new int[3];
            if (dimOrdering.equals("tf")) {
                /* TensorFlow convolutional input: # examples, # rows, # cols, # channels */
                imageSize[0] = batchInputShape.get(1);
                imageSize[1] = batchInputShape.get(2);
                imageSize[2] = batchInputShape.get(3);
            } else if (dimOrdering.equals("th")) {
                /* Theano convolutional input: # examples, # channels, # rows, # cols */
                imageSize[0] = batchInputShape.get(2);
                imageSize[1] = batchInputShape.get(3);
                imageSize[2] = batchInputShape.get(1);
            } else {
                /* Message unified with the per-layer check above. */
                throw new IncompatibleKerasConfigurationException("Unknown Keras backend: " + dimOrdering);
            }
            listBuilder.setInputType(InputType.convolutional(imageSize[0], imageSize[1], imageSize[2]));
        } else {
            listBuilder.setInputType(InputType.feedForward(batchInputShape.get(1)));
        }
        return listBuilder.build();
    }

    /**
     * Imports a Keras Functional API model configuration from the parsed JSON tree.
     *
     * @param kerasConfig nested Map storing the Keras configuration read from valid JSON
     * @return DL4J ComputationGraphConfiguration
     * @throws IOException never thrown at present; declared for interface symmetry
     * @throws NotImplementedException always — Functional API import is not yet supported
     * @throws IncompatibleKerasConfigurationException never thrown at present
     */
    private static ComputationGraphConfiguration importFunctionalApiConfig(Map<String, Object> kerasConfig)
            throws IOException, NotImplementedException, IncompatibleKerasConfigurationException {
        throw new NotImplementedException("Import of Keras Functional API model configs not supported.");
    }

    /**
     * Extracts Keras configuration properties that are not relevant for configuring DL4J
     * layers or models but may be important when importing stored model weights. The only
     * relevant property at this time is the Keras backend (stored as "dim_ordering" in
     * convolutional and pooling layers); the first one encountered wins.
     *
     * @param configJson String storing the Keras configuration as valid JSON
     * @return Map from metadata fields to relevant values
     * @throws IOException if the JSON cannot be parsed
     */
    public static Map<String, Object> extractWeightsMetadataFromConfig(String configJson) throws IOException {
        Map<String, Object> weightsMetadata = new HashMap<>();
        /* Reuse the shared JSON helper instead of duplicating ObjectMapper setup. */
        Map<String, Object> kerasConfig = parseJsonString(configJson);
        List<Map<String, Object>> layers = (List<Map<String, Object>>) kerasConfig.get("config");
        for (Map<String, Object> layer : layers) {
            Map<String, Object> layerConfig = (Map<String, Object>) layer.get("config");
            if (layerConfig.containsKey("dim_ordering") && !weightsMetadata.containsKey("keras_backend"))
                weightsMetadata.put("keras_backend", layerConfig.get("dim_ordering"));
        }
        return weightsMetadata;
    }

    /**
     * Convenience function for parsing JSON strings into a nested Map.
     *
     * @param json String containing valid JSON
     * @return nested Map with arbitrary depth
     * @throws IOException if the JSON cannot be parsed
     */
    private static Map<String, Object> parseJsonString(String json) throws IOException {
        ObjectMapper mapper = new ObjectMapper();
        TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {};
        return mapper.readValue(json, typeRef);
    }
}