Skip to content

Commit

Permalink
#4223 LayerVertex memory report: apply preprocessor first
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Nov 24, 2017
1 parent 6f5c4d2 commit 2eebb37
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
@@ -1,5 +1,6 @@
package org.deeplearning4j.nn.misc;

import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
Expand All @@ -18,8 +19,11 @@
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.primitives.Pair;

import java.io.File;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -235,4 +239,15 @@ public void validateSimple() {
assertEquals(0, mr.getMemoryBytes(MemoryType.WORKING_MEMORY_VARIABLE, 1, MemoryUseMode.INFERENCE,
CacheMode.NONE, DataBuffer.Type.FLOAT));
}

@Test
public void testPreprocessors() throws Exception {
//https://github.com/deeplearning4j/deeplearning4j/issues/4223
File f = new ClassPathResource("4223/CompGraphConfig.json").getTempFileFromArchive();
String s = FileUtils.readFileToString(f, Charset.defaultCharset());

ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(s);

conf.getMemoryReport(InputType.convolutional(17,19,19));
}
}
Expand Up @@ -491,6 +491,8 @@ public NetworkMemoryReport getMemoryReport(InputType... inputTypes) {
}
}



InputType outputFromVertex =
gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
vertexOutputs.put(s, outputFromVertex);
Expand Down
Expand Up @@ -129,7 +129,17 @@ public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws

@Override
public MemoryReport getMemoryReport(InputType... inputTypes) {
if(inputTypes.length != 1){
throw new IllegalArgumentException("Only one input supported for layer vertices: got "
+ Arrays.toString(inputTypes));
}
InputType it;
if(preProcessor != null){
it = preProcessor.getOutputType(inputTypes[0]);
} else {
it = inputTypes[0];
}
//TODO preprocessor memory
return layerConf.getLayer().getMemoryReport(inputTypes[0]);
return layerConf.getLayer().getMemoryReport(it);
}
}

0 comments on commit 2eebb37

Please sign in to comment.