Skip to content

Commit

Permalink
[TFLite] Support depthwise convolution multiplier greater than 1
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene committed Sep 9, 2019
1 parent 2f5b155 commit 06d037d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
25 changes: 16 additions & 9 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,6 @@ def convert_conv(self, op, conv_type):
conv_options = DepthwiseConv2DOptions()
conv_options.Init(op_options.Bytes, op_options.Pos)
depth_multiplier = conv_options.DepthMultiplier()
assert depth_multiplier == 1, "TF frontend transforms it to be 1 regardless of what " \
"original value is set to 0.25, 0.5 or anything else"
else:
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend TFLite.'.format(conv_type))
Expand All @@ -636,11 +634,13 @@ def convert_conv(self, op, conv_type):
padding = conv_options.Padding()
fused_activation_fn = conv_options.FusedActivationFunction()

_, input_h, input_w, _ = input_tensor.tensor.ShapeAsNumpy()
_, input_h, input_w, input_c = input_tensor.tensor.ShapeAsNumpy()

if is_depthwise_conv:
multiplier, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
assert multiplier == depth_multiplier
# TFLite depthwise convolution kernel layout is:
# 1 KH KW C(input_c * depth_multiplier)
_, kernel_h, kernel_w, in_channels = weight_tensor.tensor.ShapeAsNumpy()
assert in_channels == input_c * depth_multiplier
else:
output_channels, kernel_h, kernel_w, _ = weight_tensor.tensor.ShapeAsNumpy()

Expand All @@ -654,7 +654,7 @@ def convert_conv(self, op, conv_type):
'data_layout': 'NHWC'}

if is_depthwise_conv:
params['channels'] = int(in_channels * multiplier)
params['channels'] = int(in_channels)
params['groups'] = int(in_channels)
params['kernel_layout'] = 'HWOI'
else:
Expand All @@ -669,9 +669,16 @@ def convert_conv(self, op, conv_type):
in_expr = self.get_expr(input_tensor_idx)
weight_value = self.get_tensor_value(weight_tensor)

# TFLite is OC/M KH KW IC, we require KH KW IC OC/M
# M means multiplier in depthwise convolution
weight_value = weight_value.transpose((1, 2, 3, 0))
# TFLite kernel layout:
# convolution:
# OC KH KW IC, we require KH KW IC OC (HWIO)
# depthwise convolution:
# 1 KH KW C(input_c * depth_multiplier), we require
# KH KW IC M (depth_multiplier)
if is_depthwise_conv:
weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
else:
weight_value = weight_value.transpose((1, 2, 3, 0))

weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)

Expand Down
1 change: 1 addition & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def test_forward_convolution():
_test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
_test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)


#######################################################################
Expand Down

0 comments on commit 06d037d

Please sign in to comment.