
Merge pull request #4349 from deeplearning4j/ab_4347
Workspace + preprocessor fixes
AlexDBlack committed Nov 30, 2017
2 parents 62b5b88 + bff92b3 commit 71cc4a8
Showing 4 changed files with 248 additions and 50 deletions.
WorkspaceTests.java
@@ -1,10 +1,13 @@
package org.deeplearning4j.nn.misc;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
@@ -16,6 +19,7 @@
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.primitives.Pair;

@Slf4j
public class WorkspaceTests {
@@ -102,4 +106,102 @@ public static ComputationGraph createNet() throws Exception {
return model;
}


@Test
public void testWithPreprocessorsCG(){
//https://github.com/deeplearning4j/deeplearning4j/issues/4347
//Cause for the above issue was layerVertex.setInput() applying the preprocessor, with the result
// not being detached properly from the workspace...

for(WorkspaceMode wm : WorkspaceMode.values()) {
System.out.println(wm);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wm)
.inferenceWorkspaceMode(wm)
.graphBuilder()
.addInputs("in")
.addLayer("e", new GravesLSTM.Builder().nIn(10).nOut(5).build(), new DupPreProcessor(), "in")
// .addLayer("e", new GravesLSTM.Builder().nIn(10).nOut(5).build(), "in") //Note that no preprocessor is OK
.addLayer("rnn", new GravesLSTM.Builder().nIn(5).nOut(8).build(), "e")
.addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.SIGMOID).nOut(3).build(), "rnn")
.setInputTypes(InputType.recurrent(10))
.setOutputs("out")
.build();

ComputationGraph cg = new ComputationGraph(conf);
cg.init();


INDArray[] input = new INDArray[]{Nd4j.zeros(1, 10, 5)};

for( boolean train : new boolean[]{false, true}){
cg.clear();
cg.feedForward(input, train);
}

cg.setInputs(input);
cg.setLabels(Nd4j.rand(1, 3, 5));
cg.computeGradientAndScore();
}
}

@Test
public void testWithPreprocessorsMLN(){
for(WorkspaceMode wm : WorkspaceMode.values()) {
System.out.println(wm);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(wm)
.inferenceWorkspaceMode(wm)
.list()
.layer(new GravesLSTM.Builder().nIn(10).nOut(5).build())
.layer(new GravesLSTM.Builder().nIn(5).nOut(8).build())
.layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(3).build())
.inputPreProcessor(0, new DupPreProcessor())
.setInputType(InputType.recurrent(10))
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();


INDArray input = Nd4j.zeros(1, 10, 5);

for( boolean train : new boolean[]{false, true}){
net.clear();
net.feedForward(input, train);
}

net.setInput(input);
net.setLabels(Nd4j.rand(1, 3, 5));
net.computeGradientAndScore();
}
}

public static class DupPreProcessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize) {
return input.dup();
}

@Override
public INDArray backprop(INDArray output, int miniBatchSize) {
return output.dup();
}

@Override
public InputPreProcessor clone() {
return new DupPreProcessor();
}

@Override
public InputType getOutputType(InputType inputType) {
return inputType;
}

@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return new Pair<>(maskArray, currentMaskState);
}
}
}
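For context on what these tests exercise: dup() inside an active workspace allocates the copy in that workspace's buffer, so the copy is only valid while the scope is open; an array that must outlive the scope has to be detached (or leveraged out) first. A minimal standalone sketch of that behavior, assuming the ND4J workspace API used above ("WS_SKETCH" is a hypothetical workspace id):

    import org.nd4j.linalg.api.memory.MemoryWorkspace;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class WorkspaceDupSketch {
        public static void main(String[] args) {
            INDArray escaped;
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS_SKETCH")) {
                INDArray in = Nd4j.zeros(1, 10, 5);
                // dup() allocates inside the active workspace, so this copy is
                // backed by workspace memory that is recycled when the scope ends...
                INDArray copy = in.dup();
                // ...detach() moves it off-workspace so it can safely escape the scope.
                escaped = copy.detach();
            }
            System.out.println(escaped.shapeInfoToString());
        }
    }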
ComputationGraph.java
@@ -766,17 +766,24 @@ public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
ComputationGraph.workspaceConfigurationCache,
ComputationGraph.workspaceCache);

MemoryWorkspace wsFF = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
: configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal)
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationFeedForward, workspaceFeedForward);

MemoryWorkspace wsPTR = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
: configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal)
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationFeedForward, workspacePretrain);
MemoryWorkspace wsFF;
MemoryWorkspace wsPTR;
switch (configuration.getTrainingWorkspaceMode()){
case NONE:
wsFF = new DummyWorkspace();
wsPTR = new DummyWorkspace();
break;
case SINGLE:
wsFF = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
wsPTR = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
break;
case SEPARATE:
wsFF = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, workspaceFeedForward);
wsPTR = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, workspacePretrain);
break;
default:
throw new RuntimeException();
}

while (iter.hasNext()) {
MultiDataSet multiDataSet = iter.next();
@@ -1459,12 +1466,20 @@ protected Map<String, INDArray> feedForward(boolean train, boolean excludeOutputLayers,
boolean includeNonLayerVertexActivations, boolean publicApi) {
Map<String, INDArray> layerActivations = new HashMap<>();

MemoryWorkspace workspace = configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE
? new DummyWorkspace()
: configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal)
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationFeedForward, workspaceFeedForward);
MemoryWorkspace workspace;
switch(configuration.getTrainingWorkspaceMode()){
case NONE:
workspace = new DummyWorkspace();
break;
case SINGLE:
workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
break;
case SEPARATE:
workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward, workspaceFeedForward);
break;
default:
throw new RuntimeException();
}

//Do forward pass according to the topological ordering of the network
for (int i = 0; i < topologicalOrder.length; i++) {
@@ -1666,16 +1681,20 @@ protected void calcBackpropGradients(boolean truncatedBPTT, INDArray... externalEpsilons) {
initGradientsView();
}


MemoryWorkspace workspace =
configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
: configuration.getTrainingWorkspaceMode() == WorkspaceMode.SINGLE
? Nd4j.getWorkspaceManager()
.getWorkspaceForCurrentThread(workspaceExternal)
//: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(wsConf, workspaceBackProp);
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationFeedForward,
workspaceFeedForward);
MemoryWorkspace workspace;
switch (configuration.getTrainingWorkspaceMode()){
case NONE:
workspace = new DummyWorkspace();
break;
case SINGLE:
workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
break;
case SEPARATE:
workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceConfigurationFeedForward,workspaceFeedForward);
break;
default:
throw new RuntimeException();
}


LinkedList<Triple<String, INDArray, Character>> gradients = new LinkedList<>();
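This is the third copy of the same NONE/SINGLE/SEPARATE selection in this file. A hedged sketch of how the repetition could be factored out (the helper below is hypothetical, not part of this commit; it reuses only the calls that appear in the switch blocks above):

    // Hypothetical helper, not in this commit: one switch shared by
    // pretrainLayer, feedForward and calcBackpropGradients.
    private MemoryWorkspace selectWorkspace(WorkspaceMode mode, String separateWorkspaceName) {
        switch (mode) {
            case NONE:
                return new DummyWorkspace();
            case SINGLE:
                return Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(workspaceExternal);
            case SEPARATE:
                return Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                        workspaceConfigurationFeedForward, separateWorkspaceName);
            default:
                throw new RuntimeException("Unknown workspace mode: " + mode);
        }
    }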
@@ -2030,7 +2049,8 @@ public double score(MultiDataSet dataSet, boolean training) {

int i = 0;
for (String s : configuration.getNetworkOutputs()) {
Layer outLayer = verticesMap.get(s).getLayer();
GraphVertex gv = verticesMap.get(s);
Layer outLayer = gv.getLayer();
if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
log.warn("Cannot calculate score: vertex \"" + s + "\" is not an output layer");
return 0.0;
@@ -2039,7 +2059,7 @@ public double score(MultiDataSet dataSet, boolean training) {
IOutputLayer ol = (IOutputLayer) outLayer;
ol.setLabels(labels[i++]);

score += ol.computeScore(l1, l2, training);
score += ((LayerVertex)gv).computeScore(l1, l2, training);

//Only want to add l1/l2 once...
l1 = 0.0;
@@ -2095,7 +2115,8 @@ public INDArray scoreExamples(MultiDataSet data, boolean addRegularizationTerms)
double l2 = (addRegularizationTerms ? calcL2() : 0.0);
int i = 0;
for (String s : configuration.getNetworkOutputs()) {
Layer outLayer = verticesMap.get(s).getLayer();
GraphVertex gv = verticesMap.get(s);
Layer outLayer = gv.getLayer();
if (outLayer == null || !(outLayer instanceof IOutputLayer)) {
throw new UnsupportedOperationException(
"Cannot calculate score: vertex \"" + s + "\" is not an output layer");
@@ -2104,7 +2125,7 @@ public INDArray scoreExamples(MultiDataSet data, boolean addRegularizationTerms)
IOutputLayer ol = (IOutputLayer) outLayer;
ol.setLabels(labels[i++]);

INDArray scoreCurrLayer = ol.computeScoreForExamples(l1, l2);
INDArray scoreCurrLayer = ((LayerVertex)gv).computeScoreForExamples(l1, l2);
if (out == null)
out = scoreCurrLayer;
else
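Why the casts above matter: scoring now goes through LayerVertex so the vertex can apply its preprocessor lazily before the output layer sees the input; setInput() no longer runs the preprocessor eagerly (see the LayerVertex diff below). A minimal sketch of the lazy-input pattern, with the field and method names taken from that diff:

    // Sketch of the lazy preprocessor-application pattern this commit introduces.
    class LazyInputSketch {
        private Object input;           // raw input, as stored by setInput()
        private boolean setLayerInput;  // false until the preprocessor has run

        void setInput(Object in) {
            input = in;
            setLayerInput = false;      // defer preprocessing to first use
        }

        void applyPreprocessorAndSetInput() {
            // preprocess `input` and hand the result to the wrapped layer, then:
            setLayerInput = true;
        }

        double computeScore() {
            if (!setLayerInput) {
                applyPreprocessorAndSetInput();  // edge case: no forward pass was run
            }
            return 0.0;  // then delegate to the wrapped IOutputLayer
        }
    }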
LayerVertex.java
@@ -31,7 +31,10 @@
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.util.Arrays;
@@ -47,6 +50,7 @@ public class LayerVertex extends BaseGraphVertex {

private Layer layer;
private final InputPreProcessor layerPreProcessor;
private boolean setLayerInput;

/**
* Create a network input vertex:
@@ -100,9 +104,31 @@ public INDArray doForward(boolean training) {
if (!canDoForward())
throw new IllegalStateException("Cannot do forward pass: all inputs not set");

applyPreprocessorAndSetInput();
return layer.activate(training);
}

protected void applyPreprocessorAndSetInput(){
//Apply preprocessor
INDArray currInput = inputs[0];
if (layerPreProcessor != null) {
if (Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(ComputationGraph.workspaceExternal)
&& Nd4j.getMemoryManager().getCurrentWorkspace() != Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(ComputationGraph.workspaceExternal)) {
//WS single, or FF as part of backprop
//NOTE: we *could* leverage instead (less memory, worse performance), but most preprocessors will only
//allocate 1 array (i.e., the new output), so this is usually preferable in practice
try (MemoryWorkspace wsB = Nd4j.getWorkspaceManager()
.getWorkspaceForCurrentThread(ComputationGraph.workspaceExternal).notifyScopeBorrowed()) {
currInput = layerPreProcessor.preProcess(currInput, graph.batchSize());
}
} else {
currInput = layerPreProcessor.preProcess(currInput, graph.batchSize());
}
}
layer.setInput(currInput);
setLayerInput = true;
}

@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
if (!canDoBackward()) {
@@ -116,6 +142,11 @@ public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
}
}

//Edge case: output layer - never did forward pass hence layer.setInput was never called...
if(!setLayerInput){
applyPreprocessorAndSetInput();
}

Pair<Gradient, INDArray> pair;
if (tbptt && layer instanceof RecurrentLayer) {
//Truncated BPTT for recurrent layers
@@ -143,12 +174,7 @@ public void setInput(int inputNumber, INDArray input) {
"Invalid input number: LayerVertex instances have only 1 input (got inputNumber = "
+ inputNumber + ")");
inputs[inputNumber] = input;

INDArray currInput = inputs[0];
if (layerPreProcessor != null) {
currInput = layerPreProcessor.preProcess(currInput, graph.batchSize());
}
layer.setInput(currInput);
setLayerInput = false;
}

@Override
@@ -213,4 +239,32 @@ public boolean canDoBackward() {

return true;
}

public double computeScore(double l1, double l2, boolean training){
if(!(layer instanceof IOutputLayer)){
throw new UnsupportedOperationException("Cannot compute score: layer is not an output layer (layer class: "
+ layer.getClass().getSimpleName());
}
//Edge case: output layer - never did forward pass hence layer.setInput was never called...
if(!setLayerInput){
applyPreprocessorAndSetInput();
}

IOutputLayer ol = (IOutputLayer)layer;
return ol.computeScore(l1, l2, training);
}

public INDArray computeScoreForExamples(double l1, double l2){
if(!(layer instanceof IOutputLayer)){
throw new UnsupportedOperationException("Cannot compute score: layer is not an output layer (layer class: "
+ layer.getClass().getSimpleName());
}
//Edge case: output layer - never did forward pass hence layer.setInput was never called...
if(!setLayerInput){
applyPreprocessorAndSetInput();
}

IOutputLayer ol = (IOutputLayer)layer;
return ol.computeScoreForExamples(l1, l2);
}
}
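On the trade-off noted in applyPreprocessorAndSetInput above: the commit borrows the external workspace so the preprocessor output is allocated there directly. The leveraging alternative the code comment alludes to would compute the output in whatever workspace is current and then migrate it, roughly as follows (a sketch, assuming INDArray.leverageTo(String) is available in this ND4J version; this is not what the commit does):

    // Alternative sketched in the code comment: preprocess in the current
    // workspace, then migrate ("leverage") the result into workspaceExternal.
    // Less memory held in the borrowed scope, but an extra copy per array.
    INDArray currInput = layerPreProcessor.preProcess(inputs[0], graph.batchSize())
            .leverageTo(ComputationGraph.workspaceExternal);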
