Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBuilder] Use last added layer/vertex if no inputs specified #7403

Merged
merged 11 commits into from
Apr 2, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ public void testJSONBasic() {
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 1)).updater(new NoOp())
.graphBuilder().addInputs("input")
.addLayer("firstLayer",
new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.TANH).build(),
"input")
.appendLayer("firstLayer",
new DenseLayer.Builder().nIn(4).nOut(5).activation(Activation.TANH).build())
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5).nOut(3).build(),
Expand Down Expand Up @@ -153,6 +152,18 @@ public void testInvalidConfigurations() {
//e.printStackTrace();
}

// Use appendLayer on first layer
try {
new NeuralNetConfiguration.Builder().graphBuilder()
.appendLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build())
.addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build()).setOutputs("out")
.build();
fail("No exception thrown for invalid configuration");
} catch (IllegalStateException e) {
//OK - exception is good
//e.printStackTrace();
}

//Test no network inputs
try {
new NeuralNetConfiguration.Builder().graphBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,8 @@ public static class GraphBuilder {
protected boolean validateOutputConfig = true;
protected boolean validateTbpttConfig = true;

protected String lastAdded = null;

public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) {
this.globalConfiguration = globalConfiguration;
}
Expand Down Expand Up @@ -754,20 +756,32 @@ public GraphBuilder tBPTTLength(int tbpttLength){
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* on a combination of the two.
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
return addLayer(layerName, layer, null, layerInputs);
}

/**
* Add a layer, with no {@link InputPreProcessor}, with the specified name
* and input from the last added layer/vertex.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
public GraphBuilder appendLayer(String layerName, Layer layer) {
return appendLayer(layerName, layer, null);
}

/**
* Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* on a combination of the two.
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
Expand All @@ -780,7 +794,7 @@ public GraphBuilder layer(int layerName, Layer layer, String... layerInputs) {
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* on a combination of the two.
* @see #addLayer(String, Layer, InputPreProcessor, String...)
*/
Expand All @@ -794,7 +808,7 @@ public GraphBuilder layer(String layerName, Layer layer, String... layerInputs)
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param preProcessor The InputPreProcessor to use with this layer.
* @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* on a combination of the two.
*/
public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor,
Expand All @@ -806,13 +820,31 @@ public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor pr
return this;
}

/**
* Add a layer and an {@link InputPreProcessor}, with the specified name
* and input from the last added layer/vertex.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param preProcessor The InputPreProcessor to use with this layer.
*/
public GraphBuilder appendLayer(String layerName, Layer layer, InputPreProcessor preProcessor) {

if(lastAdded == null){
throw new IllegalStateException("Can not use appendLayer with no previous layers");
}

addLayer(layerName, layer, preProcessor, lastAdded);
return this;
}

/**
* Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
*
* @param layerName Name/label of the layer to add
* @param layer The layer configuration
* @param preProcessor The InputPreProcessor to use with this layer.
* @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
* @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects,
* on a combination of the two.
*/
public GraphBuilder layer(String layerName, Layer layer, InputPreProcessor preProcessor,
Expand Down Expand Up @@ -881,6 +913,7 @@ public GraphBuilder removeVertex(String vertexName, boolean removeConnections) {
*/
public GraphBuilder addInputs(String... inputNames) {
Collections.addAll(networkInputs, inputNames);
lastAdded = networkInputs.get(networkInputs.size() - 1);
return this;
}

Expand All @@ -891,6 +924,7 @@ public GraphBuilder addInputs(String... inputNames) {
*/
public GraphBuilder addInputs(Collection<String> inputNames) {
networkInputs.addAll(inputNames);
lastAdded = networkInputs.get(networkInputs.size() - 1);
return this;
}

Expand Down Expand Up @@ -934,9 +968,10 @@ public GraphBuilder setOutputs(String... outputNames) {
*
* @param vertexName The name of the GraphVertex to add
* @param vertex The GraphVertex to add
* @param vertexInputs The inputs/activations to this GraphVertex
* @param vertexInputs The inputs/activations to this GraphVertex.
*/
public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {

Preconditions.checkState(!vertices.containsKey(vertexName), "Cannot add vertex: a vertex with name \"%s\" already exists", vertexName);
vertices.put(vertexName, vertex);

Expand All @@ -948,6 +983,29 @@ public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... v
} else if (vertexInputs != null) {
this.vertexInputs.put(vertexName, Arrays.asList(vertexInputs));
}

this.lastAdded = vertexName;

return this;
}

/**
* Add a {@link GraphVertex} to the network configuration, with input from the last added vertex/layer. A GraphVertex defines forward and backward pass methods,
* and can contain a {@link LayerVertex}, a {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise
* addition/subtraction, a {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices,
* a {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another layer/GraphVertex.<br>
* Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
*
* @param vertexName The name of the GraphVertex to add
* @param vertex The GraphVertex to add
*/
public GraphBuilder appendVertex(String vertexName, GraphVertex vertex) {

if(lastAdded == null){
throw new IllegalStateException("Can not use appendLayer with no previous layers");
}

addVertex(vertexName, vertex, lastAdded);
return this;
}

Expand Down