Skip to content

Commit

Permalink
Handle legacy dropconnect from JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Sep 7, 2017
1 parent 8bf7fc0 commit aa5149e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
Expand Up @@ -7,14 +7,13 @@
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.nd4j.shade.jackson.core.JsonLocation;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.node.ArrayNode;
import org.nd4j.shade.jackson.databind.node.ObjectNode;

import java.io.IOException;
Expand Down Expand Up @@ -67,10 +66,12 @@ public ComputationGraphConfiguration deserialize(JsonParser jp, DeserializationC
int layerIdx = 0;
while(iter.hasNext()){
JsonNode next = iter.next();
ObjectNode confNode = null;
if(next.has("LayerVertex")){
next = next.get("LayerVertex");
if(next.has("layerConf")){
next = next.get("layerConf").get("layer").elements().next();
confNode = (ObjectNode) next.get("layerConf");
next = confNode.get("layer").elements().next();
} else {
continue;
}
Expand All @@ -84,7 +85,13 @@ public ComputationGraphConfiguration deserialize(JsonParser jp, DeserializationC
if(next.has("dropOut")){
double d = next.get("dropOut").asDouble();
if(!Double.isNaN(d)){
layers[layerIdx].setIDropout(new Dropout(d));
//Might be dropout or dropconnect...
if(layers[layerIdx] instanceof BaseLayer && confNode.has("useDropConnect")
&& confNode.get("useDropConnect").asBoolean(false)){
((BaseLayer)layers[layerIdx]).setWeightNoise(new DropConnect(d));
} else {
layers[layerIdx].setIDropout(new Dropout(d));
}
}
}
}
Expand Down
Expand Up @@ -2,24 +2,20 @@

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.nd4j.linalg.learning.config.*;
import org.deeplearning4j.nn.conf.weightnoise.DropConnect;
import org.nd4j.shade.jackson.core.JsonLocation;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.node.ArrayNode;
import org.nd4j.shade.jackson.databind.node.ObjectNode;
import org.nd4j.shade.jackson.databind.util.TokenBuffer;

import java.io.IOException;
import java.util.Iterator;

public class MultiLayerConfigurationDeserializer extends BaseNetConfigDeserializer<MultiLayerConfiguration> {

Expand Down Expand Up @@ -53,9 +49,11 @@ public MultiLayerConfiguration deserialize(JsonParser jp, DeserializationContext

for( int i=0; i<layers.length; i++ ){
ObjectNode on = (ObjectNode) confsNode.get(i);
ObjectNode confNode = null;
if(layers[i] instanceof BaseLayer && ((BaseLayer)layers[i]).getIUpdater() == null){
//layer -> (first/only child) -> updater
if(on.has("layer")){
confNode = on;
on = (ObjectNode) on.get("layer");
} else {
continue;
Expand All @@ -66,11 +64,17 @@ public MultiLayerConfiguration deserialize(JsonParser jp, DeserializationContext
}

if(layers[i].getIDropout() == null){
//Check for legacy dropout
//Check for legacy dropout/dropconnect
if(on.has("dropOut")){
double d = on.get("dropOut").asDouble();
if(!Double.isNaN(d)){
layers[i].setIDropout(new Dropout(d));
//Might be dropout or dropconnect...
if(confNode != null && layers[i] instanceof BaseLayer && confNode.has("useDropConnect")
&& confNode.get("useDropConnect").asBoolean(false)){
((BaseLayer)layers[i]).setWeightNoise(new DropConnect(d));
} else {
layers[i].setIDropout(new Dropout(d));
}
}
}
}
Expand Down

0 comments on commit aa5149e

Please sign in to comment.