This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add NNPACK support for high convolution inference perf #3666

Merged 7 commits on Nov 10, 2016

Conversation

clcarwin
Contributor

This PR uses NNPACK's nnp_convolution_inference to speed up convolution inference when the batch size is 1. It can run 2x to 10x faster.

@piiswrong
Contributor

So it only supports forward? Are you comparing to mxnet default conv or mkldnn's conv?

@clcarwin
Contributor Author

clcarwin commented Nov 1, 2016

@piiswrong Backward is inherited from ConvolutionOp; only Forward uses NNPACK's fast algorithm.
NNPACK supports ARMv7 and ARM64, so embedded devices may also benefit from this code.
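
The split described above (forward overridden, backward inherited) can be sketched in a few lines of Python. This is only an illustration of the pattern, not the PR's C++ code: `BaseConv1D` and `FastForwardConv1D` are hypothetical names, the convolution is 1-D to keep it short, and the "fast" forward is a plain-Python stand-in for the NNPACK call.

```python
# Illustrative stand-in for the operator layout described in the comment.
# BaseConv1D and FastForwardConv1D are hypothetical names, not MXNet classes.


class BaseConv1D(object):
    """Reference operator: direct forward pass plus a backward pass."""

    def __init__(self, kernel):
        self.kernel = list(kernel)

    def forward(self, x):
        # Valid (no padding), stride-1 correlation of x with the kernel.
        k = len(self.kernel)
        return [sum(x[i + j] * self.kernel[j] for j in range(k))
                for i in range(len(x) - k + 1)]

    def backward(self, x, grad_out):
        # Gradient with respect to the kernel weights.
        k = len(self.kernel)
        return [sum(grad_out[i] * x[i + j] for i in range(len(grad_out)))
                for j in range(k)]


class FastForwardConv1D(BaseConv1D):
    """Overrides forward only; backward is inherited unchanged."""

    def forward(self, x):
        # Stand-in for the accelerated path: accumulate one kernel tap at a time.
        out = [0.0] * (len(x) - len(self.kernel) + 1)
        for j, w in enumerate(self.kernel):
            for i in range(len(out)):
                out[i] += w * x[i + j]
        return out


x = [1.0, 2.0, 3.0, 4.0]
base = BaseConv1D([1.0, -1.0])
fast = FastForwardConv1D([1.0, -1.0])
print(fast.forward(x) == base.forward(x))                  # same forward result
print(fast.backward(x, [1.0, 1.0, 1.0]) == base.backward(x, [1.0, 1.0, 1.0]))
```

Because only `forward` is replaced, training code that calls `backward` keeps the reference behaviour, which is the property the comment relies on.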

tests/python/predict/mxnet_predict_example.py

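| OPENMP | NNPACK thread=1 | NNPACK thread=2 | NNPACK thread=4 | MXNET conv |
|--------|-----------------|-----------------|-----------------|------------|
| OFF    | 0.395s          | 0.323s          | 0.290s          | 0.402s     |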

conv test result:

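| input       | f (num_filter) | k (kernel) | s (stride) | NNPACK t=1 | NNPACK t=2 | NNPACK t=4 | MXNET conv | speedup |
|-------------|----------------|------------|------------|------------|------------|------------|------------|---------|
| 3x1000x1000 | 28             | 9x9        | 1x1        | 0.739s     | 0.521s     | 0.441s     | 4.01s      | 9.1x    |
| 20x600x600  | 20             | 5x5        | 1x1        | 0.234s     | 0.176s     | 0.151s     | 2.76s      | 18x     |
| 3x600x600   | 32             | 5x5        | 1x1        | 0.188s     | 0.146s     | 0.129s     | 0.569s     | 4.4x    |
| 128x60x60   | 128            | 5x5        | 1x1        | 0.0533s    | 0.0340s    | 0.0254s    | 0.168s     | 6.6x    |
| 3x60x60     | 3              | 5x5        | 1x1        | 0.00123s   | 0.00134s   | 0.00210s   | 0.00469s   | 2.2x    |
| 3x60x60     | 3              | 3x3        | 1x1        | 0.00118s   | 0.00114s   | 0.00145s   | 0.00242s   | 1.6x    |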
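
The speedup column can be checked from the timings themselves. The sketch below assumes (this is not stated in the comment) that speedup means the MXNet conv time divided by the NNPACK time at 4 threads. The timings are copied from the conv test results; the `speedup` helper and the row labels are illustrative.

```python
# Times in seconds, copied from the conv test results:
# (label, NNPACK t=1, NNPACK t=2, NNPACK t=4, MXNet conv)
rows = [
    ("3x1000x1000", 0.739, 0.521, 0.441, 4.01),
    ("20x600x600", 0.234, 0.176, 0.151, 2.76),
    ("3x600x600", 0.188, 0.146, 0.129, 0.569),
    ("128x60x60", 0.0533, 0.0340, 0.0254, 0.168),
    ("3x60x60 k=5x5", 0.00123, 0.00134, 0.00210, 0.00469),
    ("3x60x60 k=3x3", 0.00118, 0.00114, 0.00145, 0.00242),
]


def speedup(baseline, candidate):
    """How many times faster `candidate` is than `baseline`."""
    return baseline / candidate


for label, t1, t2, t4, mxnet in rows:
    # Assumed definition of the reported column: MXNet conv / NNPACK t=4.
    vs_t4 = speedup(mxnet, t4)
    # Best NNPACK time across the three thread counts.
    vs_best = speedup(mxnet, min(t1, t2, t4))
    print("%-14s vs t=4: %.2fx   vs best thread count: %.2fx" % (label, vs_t4, vs_best))
```

Under that assumption the computed ratios agree with the reported column to within rounding. For the two 3x60x60 inputs the 4-thread run is the slowest NNPACK configuration, so extra threads do not help at that size.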

conv test code:

import time

import mxnet as mx
import numpy as np

# Single convolution layer matching the first row of the results (3x1000x1000, f=28, k=9x9, s=1x1).
data = mx.sym.Variable("data")
net = mx.symbol.Convolution(data=data, num_filter=28, kernel=(9, 9), stride=(1, 1), pad=(0, 0))

# Bind for inference only, with a batch size of 1.
mod = mx.mod.Module(symbol=net)
mod.bind(data_shapes=[("data", (1, 3, 1000, 1000))], for_training=False, inputs_need_grad=False)
mod.init_params()

# Time one prediction, including iterator construction and the copy back to NumPy.
d = np.ones((1, 3, 1000, 1000))
start = time.time()
dataiter = mx.io.NDArrayIter(d)
r = mod.predict(dataiter).asnumpy()
print('time: %s' % (time.time() - start))
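
The script above times a single call. A hedged sketch of a slightly more robust measurement is a small helper that runs a warm-up call and then reports the best and mean of several runs. `time_call` and its parameters are hypothetical names introduced here, not part of MXNet or of this PR.

```python
import time


def time_call(fn, warmup=1, repeats=5):
    """Call fn() `warmup` times untimed, then return (best, mean) seconds over `repeats` timed calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.time()
        fn()
        samples.append(time.time() - start)
    return min(samples), sum(samples) / len(samples)


# Stand-in workload so the sketch runs on its own.
best, mean = time_call(lambda: sum(range(100000)), warmup=1, repeats=3)
print('best: %s  mean: %s' % (best, mean))
```

With the module from the test code, the workload would be something like `lambda: mod.predict(dataiter).asnumpy()`, assuming the iterator is created once outside the timed region.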

@piiswrong
Contributor

Could you fix tests?

@clcarwin
Contributor Author

clcarwin commented Nov 5, 2016

@piiswrong All tests have been fixed.

@piiswrong merged commit 251634f into apache:master on Nov 10, 2016