From 8399e876f8e91ec3f9431786533bf73e42e9cc19 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 23 Feb 2019 23:36:06 +1100 Subject: [PATCH] MKLDNNConv OpContext --- .../nn/layers/mkldnn/MKLDNNConvHelper.java | 74 +++++++++++-------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java index 5aee91f22adf..106d9c52574f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java @@ -28,11 +28,13 @@ import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; +import org.nd4j.linalg.util.ArrayUtil; import java.util.Collections; import java.util.Map; @@ -43,6 +45,9 @@ * @author Alex Black */ public class MKLDNNConvHelper implements ConvolutionHelper { + + protected OpContext context; + @Override public boolean checkSupported() { return BaseMKLDNNHelper.mklDnnEnabled(); @@ -62,25 +67,32 @@ public Pair backpropGradient(INDArray input, INDArray weight kernel, strides, dilation); } - Conv2DConfig conf = Conv2DConfig.builder() - .dataFormat(Conv2DConfig.NCHW) - .kH(kernel[0]).kW(kernel[1]) - .sH(strides[0]).sW(strides[1]) - .pH(pad[0]).pW(pad[1]) - .dH(dilation[0]).dH(dilation[1]) - .isSameMode(convolutionMode == ConvolutionMode.Same) - .build(); + if(context == null){ + context = Nd4j.getExecutioner().buildContext(); + context.setIArguments(kernel[0], kernel[1], + strides[0], strides[1], + pad[0], pad[1], + dilation[0], dilation[1], + ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), + 0 //0=NCHW + ); + }; INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta}; INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView}; - DynamicCustomOp op = Conv2DDerivative.derivativeBuilder() - .config(conf) - .inputArrays(inputsArr) - .outputs(outputArr) - .build(); - Nd4j.exec(op); + context.getInputArrays().clear(); + context.getOutputArrays().clear(); + for( int i=0; i