Skip to content

Commit

Permalink
Fix x86 depthwise conv2d alter_op_layout (#3264)
Browse files Browse the repository at this point in the history
* Fix x86 depthwise conv2d alter_op_layout

* Small fix

* Add test case

* Fix test

* Assert kernel layout

* Minor fix

* Add get_shape function

* Minor change
  • Loading branch information
kevinthesun authored and yzhliu committed Jun 6, 2019
1 parent 770ac84 commit d7bc4fd
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 9 deletions.
43 changes: 42 additions & 1 deletion tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Test alter op layout pass"""
import tvm

from tvm import relay
from tvm.relay.op import register_alter_op_layout
Expand Down Expand Up @@ -513,6 +514,45 @@ def expected():

assert alpha_equal(a, b), "Actual = \n" + str(a)

def test_alter_layout_depthwise_conv2d():
    """Test alter_op_layout on a depthwise conv2d.

    Registers the TOPI alter-layout hook under an llvm target and checks
    that a plain NCHW depthwise conv2d is rewritten to
    contrib_depthwise_conv2d_nchwc with NCHW8c / OIHW1i8o layouts and the
    matching layout_transform ops inserted around it.
    """
    def before():
        # Depthwise conv2d: groups == in_channels == out_channels == 32.
        x = relay.var("x", shape=(1, 32, 56, 56))
        w = relay.var("w", shape=(32, 1, 3, 3))
        y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3), groups=32)
        y = relay.Function(free_vars(y), y)
        return y

    import topi
    @register_alter_op_layout("nn.conv2d", level=110)
    def alter_conv2d(attrs, inputs, tinfos):
        # Run the generic TOPI alter-layout under an llvm target so the
        # x86 depthwise path is exercised.
        with tvm.target.create("llvm"):
            return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay)

    def expected():
        x = relay.var("x", shape=(1, 32, 56, 56))
        w = relay.var("w", shape=(32, 1, 3, 3))
        x = relay.layout_transform(x, "NCHW", "NCHW8c")
        w = relay.layout_transform(w, "OIHW", "OIHW1i8o")
        y = relay.nn.contrib_depthwise_conv2d_nchwc(x, w, padding=(1, 1), channels=32,
                                                    kernel_size=(3, 3), groups=32,
                                                    data_layout="NCHW8c",
                                                    kernel_layout="OIHW1i8o",
                                                    out_layout="NCHW8c")
        y = relay.layout_transform(y, "NCHW8c", "NCHW")
        y = relay.Function(free_vars(y), y)
        return y

    a = before()
    a = infer_type(a)
    # canonicalize_ops may rewrite ops before the alter pass runs.
    a = canonicalize_ops(a)
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    # Match the sibling tests' assertion style so a failure prints the
    # actual (mis-transformed) expression.
    assert alpha_equal(a, b), "Actual = \n" + str(a)

def test_alter_layout_prelu():
"""Test PRelu operator"""
def before():
Expand All @@ -524,7 +564,7 @@ def before():
y = relay.Function(free_vars(y), y)
return y

@register_alter_op_layout("nn.conv2d", level=110)
@register_alter_op_layout("nn.conv2d", level=111)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
Expand Down Expand Up @@ -571,4 +611,5 @@ def expected():
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
test_alter_layout_strided_slice()
test_alter_layout_depthwise_conv2d()
test_alter_layout_prelu()
6 changes: 3 additions & 3 deletions topi/python/topi/arm_cpu/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
from ..nn.util import get_pad_tuple

# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
depthwise_conv2d_nchw.fdefault)

# register customized schedule for arm cpu.
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'],
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu',
['direct', 'contrib_spatial_pack'])
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
"""Schedule depthwise conv2d
Expand Down Expand Up @@ -151,7 +151,7 @@ def _callback(op):
traverse_inline(s, outs[0].op, _callback)
return s

@autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], ['contrib_spatial_pack'])
@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack'])
def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""TOPI compute callback for depthwise_conv2d nchw
Expand Down
39 changes: 39 additions & 0 deletions topi/python/topi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from numbers import Integral

import tvm
from tvm.api import layout, bijective_layout
from . import tag

def traverse_inline(s, final_op, callback):
Expand Down Expand Up @@ -289,3 +290,41 @@ def get_max_power2_factor(n, max_value=None):
x *= 2
n /= 2
return x


def get_shape(src_shape, src_layout, dst_layout):
    """Given a source shape, a source layout and a destination layout, infer
    the destination shape.

    Parameters
    ----------
    src_shape : tuple of int or IntImm
        Source shape

    src_layout : str or Layout
        Source layout

    dst_layout : str or Layout
        Destination layout

    Returns
    -------
    dst_shape : tuple of int
        Destination shape
    """
    # Fast path: identical layouts need no axis remapping.
    if src_layout == dst_layout:
        return get_const_tuple(src_shape)

    # Normalize string layouts to Layout objects so axes can be mapped.
    if isinstance(src_layout, str):
        src_layout = layout(src_layout)
    if isinstance(dst_layout, str):
        dst_layout = layout(dst_layout)

    assert len(src_layout) == len(dst_layout), \
        "Incompatible layout %s vs %s" % (src_layout, dst_layout)

    # forward_index maps each destination axis to the source axis it
    # draws from; use it to permute the source dimensions.
    layout_mapping = bijective_layout(src_layout, dst_layout)
    dst_indices = layout_mapping.forward_index(
        tvm.convert(list(range(len(src_layout)))))

    return get_const_tuple(tuple(src_shape[i.value] for i in dst_indices))
9 changes: 6 additions & 3 deletions topi/python/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm.autotvm.task import get_config
from .. import generic, tag
from .. import nn
from ..util import get_const_tuple
from ..util import get_const_tuple, get_shape
from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_alter_layout, conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
Expand Down Expand Up @@ -415,11 +415,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):

dtype = data.dtype
out_dtype = dtype if out_dtype in ("same", "") else out_dtype
is_depthwise = groups == in_channel and groups == out_channel

kshape = get_shape(kernel.shape, attrs["kernel_layout"], "OIHW")
is_depthwise = groups == kshape[0] and kshape[1] == 1

# only optimize for NCHW
if layout != 'NCHW':
if layout != 'NCHW' or attrs["kernel_layout"] != "OIHW":
return None

if groups != 1 and not is_depthwise:
return None

Expand Down
11 changes: 9 additions & 2 deletions topi/python/topi/x86/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from tvm.autotvm.task.space import SplitEntity
from tvm.autotvm.task.topi_integration import deserialize_args
from .. import generic, tag
from ..generic import schedule_depthwise_conv2d_nchw
from ..nn.pad import pad
from ..util import get_const_tuple
from ..nn.util import get_pad_tuple
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, _get_workload, \
depthwise_conv2d_infer_layout
from ..nn.depthwise_conv2d import depthwise_conv2d_nchw, depthwise_conv2d_NCHWc, \
_get_workload, depthwise_conv2d_infer_layout

from .util import get_fp32_len

Expand Down Expand Up @@ -70,6 +71,12 @@ def _fallback_schedule(cfg, wkl):
cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])


autotvm.register_topi_compute(depthwise_conv2d_nchw, 'cpu', 'direct',
depthwise_conv2d_nchw.fdefault)
autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'cpu', 'direct',
schedule_depthwise_conv2d_nchw.fdefault)


@autotvm.register_topi_compute(depthwise_conv2d_NCHWc, 'cpu', 'direct')
def _depthwise_conv2d_NCHWc_cpu(cfg, data, kernel, strides, padding, dilation,
layout, out_layout, out_dtype=None):
Expand Down
35 changes: 35 additions & 0 deletions topi/tests/python/test_topi_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for util"""

import topi


def verify_get_shape(src_shape, src_layout, dst_layout, expect_shape):
    """Assert that get_shape maps src_shape under the layout pair to expect_shape."""
    actual_shape = topi.util.get_shape(src_shape, src_layout, dst_layout)
    assert actual_shape == expect_shape, \
        "Shape mismatch: expecting %s but got %s" % (expect_shape, actual_shape)


def test_get_shape():
    """Exercise topi.util.get_shape on identity and axis-permuting layouts."""
    cases = [
        ((1, 3, 224, 224), "NCHW", "NCHW", (1, 3, 224, 224)),
        ((1, 3, 224, 224), "NCHW", "NHWC", (1, 224, 224, 3)),
        ((3, 2, 32, 48, 16), "NCHW16c", "NC16cWH", (3, 2, 16, 48, 32)),
        ((2, 3, 32, 32, 16, 8), "OIHW16i8o", "HWO8oI16i", (32, 32, 2, 8, 3, 16)),
    ]
    for src_shape, src_layout, dst_layout, expect_shape in cases:
        verify_get_shape(src_shape, src_layout, dst_layout, expect_shape)

if __name__ == "__main__":
test_get_shape()

0 comments on commit d7bc4fd

Please sign in to comment.