Skip to content

Commit

Permalink
[TensorFlow] Fix limitation that depth_mult can only be 1 for Depthwi…
Browse files Browse the repository at this point in the history
…seConv2dNative (apache#3676)

* [TensorFlow] Fix limitation that depth_mult can only be 1 for DepthwiseConv2dNative

* Improve code readability
  • Loading branch information
lixiaoquan authored and wweic committed Sep 16, 2019
1 parent 0747cfd commit ac8e7a2
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
3 changes: 0 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,6 @@ def _impl(inputs, attr, params):
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

if opname == 'depthwise':
if depth_mult > 1:
raise tvm.error.OpNotImplemented('depth_mult > 1 of operator DepthwiseConv2dNative'
' is not supported.')
attr['groups'] = attr['channels']

# Fix padding
Expand Down
17 changes: 9 additions & 8 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import absolute_import

import topi
from topi.util import get_const_int, get_const_tuple
from topi.util import get_const_tuple
from .. import op as reg
from ..op import OpPattern, schedule_injective

Expand Down Expand Up @@ -144,19 +144,20 @@ def compute_conv2d(attrs, inputs, out_type, target):
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

def _get_out_depth():
weight_shape = get_const_tuple(inputs[1].shape)
if kernel_layout == "HWOI":
return weight_shape[2] * weight_shape[3]
return weight_shape[0] * weight_shape[1]

if groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype)
elif layout == "NCHW" and \
get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1:
elif layout == "NCHW" and _get_out_depth() == groups:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1:
elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ['NCHW', 'NCHW4c']:
Expand Down
23 changes: 18 additions & 5 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* \brief Convolution operators
*/
#include <tvm/data_layout.h>
#include <tvm/ir_pass.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
Expand Down Expand Up @@ -74,11 +75,23 @@ bool Conv2DRel(const Array<Type>& types,
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape(
{param->channels,
dshape_nchw[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
Array<IndexExpr> wshape;

if (tvm::ir::Equal(param->channels, param->groups)) {
// infer weight's shape for depthwise convolution
wshape = {
{dshape_nchw[1],
param->groups / dshape_nchw[1],
param->kernel_size[0],
param->kernel_size[1]}};
} else {
wshape = {
{param->channels,
dshape_nchw[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]}};
}

wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 19, 17, 17], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')

_test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
Expand All @@ -284,6 +285,7 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')

#######################################################################
# BiasAdd
Expand Down

0 comments on commit ac8e7a2

Please sign in to comment.