Skip to content

Commit

Permalink
MKLDNNConv OpContext
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Feb 25, 2019
1 parent 72a0251 commit 8399e87
Showing 1 changed file with 45 additions and 29 deletions.
Expand Up @@ -28,11 +28,13 @@
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.Collections;
import java.util.Map;
Expand All @@ -43,6 +45,9 @@
* @author Alex Black
*/
public class MKLDNNConvHelper implements ConvolutionHelper {

protected OpContext context;

@Override
public boolean checkSupported() {
return BaseMKLDNNHelper.mklDnnEnabled();
Expand All @@ -62,25 +67,32 @@ public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weight
kernel, strides, dilation);
}

Conv2DConfig conf = Conv2DConfig.builder()
.dataFormat(Conv2DConfig.NCHW)
.kH(kernel[0]).kW(kernel[1])
.sH(strides[0]).sW(strides[1])
.pH(pad[0]).pW(pad[1])
.dH(dilation[0]).dH(dilation[1])
.isSameMode(convolutionMode == ConvolutionMode.Same)
.build();
if(context == null){
context = Nd4j.getExecutioner().buildContext();
context.setIArguments(kernel[0], kernel[1],
strides[0], strides[1],
pad[0], pad[1],
dilation[0], dilation[1],
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
0 //0=NCHW
);
};

INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());

INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta};
INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView};
DynamicCustomOp op = Conv2DDerivative.derivativeBuilder()
.config(conf)
.inputArrays(inputsArr)
.outputs(outputArr)
.build();
Nd4j.exec(op);
context.getInputArrays().clear();
context.getOutputArrays().clear();
for( int i=0; i<inputsArr.length; i++ ){
context.setInputArray(i, inputsArr[i]);
}
for( int i=0; i<outputArr.length; i++ ){
context.setOutputArray(i, outputArr[i]);
}

Conv2DDerivative op = new Conv2DDerivative();
Nd4j.exec(op, context);

Gradient g = new DefaultGradient();
if(biasGradView != null) {
Expand All @@ -104,28 +116,32 @@ public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[]
outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation
}

if(context == null){
context = Nd4j.getExecutioner().buildContext();
context.setIArguments(kernel[0], kernel[1],
strides[0], strides[1],
pad[0], pad[1],
dilation[0], dilation[1],
ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
0 //0=NCHW
);
};

int outDepth = (int) weights.size(0);
INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), input.size(0), outDepth, outSize[0], outSize[1]);

//Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW]
weights = weights.permute(2,3,1,0);

Conv2DConfig conf = Conv2DConfig.builder()
.dataFormat(Conv2DConfig.NCHW)
.kH(kernel[0]).kW(kernel[1])
.sH(strides[0]).sW(strides[1])
.pH(pad[0]).pW(pad[1])
.dH(dilation[0]).dH(dilation[1])
.isSameMode(convolutionMode == ConvolutionMode.Same)
.build();

INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
DynamicCustomOp op = Conv2D.builder()
.config(conf)
.inputArrays(inputsArr)
.outputs(new INDArray[]{out})
.build();
Nd4j.exec(op);
context.getInputArrays().clear();
for( int i=0; i<inputsArr.length; i++ ){
context.setInputArray(i, inputsArr[i]);
}
context.getOutputArrays().clear();
context.setOutputArray(0, out);
Conv2D op = new Conv2D();
Nd4j.exec(op, context);

return out;
}
Expand Down

0 comments on commit 8399e87

Please sign in to comment.