Fix small typo in nn.conv2d_gemm_weight_transform (apache#5925)
* Fix small typo in nn.conv2d_gemm_weight_transform

Change-Id: I7844d898ebf82592f78f478982262ef95f83cc3e

* Add TOPI conv2d_gemm unit tests

Change-Id: I9ed82a68acffcf0dd9720781f8be4aada9d8e6e4
Giuseppe Rossini authored and zhiics committed Jul 2, 2020
1 parent a1a8c3c commit 460c0ec
Showing 3 changed files with 135 additions and 5 deletions.
11 changes: 8 additions & 3 deletions topi/python/topi/arm_cpu/conv2d_gemm.py
@@ -24,6 +24,11 @@
from ..nn.util import get_pad_tuple
from .tensor_intrin import gemv_quantized, gemv_quantized_impl

def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
return 'aarch64' in ' '.join(target.options)
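# NOTE (illustrative sketch, not part of this commit): under an explicit
# AArch64 target the new helper returns True. The target string below is an
# assumption, using the pre-0.7 target API seen elsewhere in this commit:
#
#     with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu'):
#         assert is_aarch64_arm()   # 'aarch64' appears in the target options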


# Compute function
def compute_conv2d_gemm_without_weight_transform(cfg,
@@ -70,8 +75,8 @@ def compute_conv2d_gemm_without_weight_transform(cfg,
else:
A = te.compute(A_shape, lambda n, x, y:
data_pad[n,
- HSTR * (x // OW) + dilation_h * (y // IC) // KW,
- WSTR * (x % OW) + dilation_w * (y // IC) % KW, y % IC],
HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
WSTR * (x % OW) + dilation_w * ((y // IC) % KW), y % IC],
name='data_im2col')
N_transformed = B_interleaved_t.shape[0]
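# NOTE (worked example, not part of this commit): * and // share precedence in
# Python and associate left-to-right, so the old index parsed as
# (dilation_h * (y // IC)) // KW rather than the intended
# dilation_h * ((y // IC) // KW). With dilation_h=2, KW=3 and y // IC = 5, the
# old form gives (2 * 5) // 3 = 3, while the fixed form gives the intended
# 2 * (5 // 3) = 2 (kernel row 5 // 3 = 1, dilated by 2).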

@@ -157,7 +162,7 @@ def schedule_conv2d_gemm(cfg, s, out):

in_type = A_interleaved.dtype
out_type = C.dtype
- if out_type == 'int32':
if is_aarch64_arm() and out_type == 'int32':
K = A_interleaved_input.shape[2]
_, M, N = C.shape
assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported"
2 changes: 1 addition & 1 deletion topi/python/topi/nn/conv2d.py
@@ -622,7 +622,7 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
pad_N = tile_rows - (N % tile_rows)

if K % tile_cols != 0:
- pad_k = tile_cols - (K % tile_cols)
pad_K = tile_cols - (K % tile_cols)

N_padded = N + pad_N
K_padded = K + pad_K
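# NOTE (worked example, not part of this commit): with the lowercase pad_k the
# assignment never reached pad_K (presumably initialized to 0 above this
# hunk), so K was silently left unpadded. After the fix, for K=147,
# tile_cols=16: pad_K = 16 - (147 % 16) = 13 and K_padded = 160; likewise
# N=35, tile_rows=4 gives pad_N = 1 and N_padded = 36.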
127 changes: 126 additions & 1 deletion topi/tests/python/test_topi_conv2d_int8.py
@@ -29,9 +29,97 @@

from common import get_all_backend, Int8Fallback

- oc_block_factor = 4
def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding,
dilation=1, add_bias=False, add_relu=False):
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
kernel, stride, padding_sum, dilation))

in_height = in_width = in_size

A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
bias = te.placeholder((num_filter,), name='bias', dtype='int8')

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_NHWC_gemm_int8")
def get_ref_data():
a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype)
w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding).astype(dtype)

if add_bias:
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)

return a_np, w_np, b_np, c_np

a_np, w_np, b_np, c_np = get_ref_data()

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
(dilation, dilation), dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding_sum,
dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
in_channel,
in_size,
num_filter,
kernel,
stride,
padding_sum,
dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

check_device("llvm")
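# NOTE (illustrative, not part of this commit): concrete shapes for one of the
# workloads exercised below. verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1,
# 'VALID') builds A: (1, 8, 8, 448) and W: (3, 3, 448, 384); with no padding
# each spatial dim is (8 - 3) // 1 + 1 = 6, so C: (1, 6, 6, 384).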

oc_block_factor = 4
def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
Expand Down Expand Up @@ -285,6 +373,43 @@ def test_conv2d_nchw():
verify_conv2d_nchw_int8(7, 32, 149, 32, 3, 1, 0)
verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1))

def test_conv2d_nhwc():
with Int8Fallback():
# Subset of inception v3 workloads, extended with dilation > 1, batch > 1 and 'VALID' padding cases
verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, 'SAME', dilation=2)
verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, 'VALID')
verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, 'SAME', dilation=2)
verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, 'VALID')
verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, 'VALID')
verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, 'SAME')
verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, 'SAME', add_bias=True, add_relu=True)
verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, 'SAME', add_bias=True)


if __name__ == "__main__":
test_conv2d_nchw()
test_conv2d_nhwc()
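
For reference, the padding_sum baked into the printed workload and kernel names above comes straight from get_pad_tuple. A minimal sketch, assuming topi's convention that 'SAME' pads kernel - 1 pixels in total per axis and 'VALID' pads nothing:

    from topi.nn.util import get_pad_tuple

    # returns (pad_top, pad_left, pad_bottom, pad_right) for a 3x3 kernel
    assert get_pad_tuple('SAME', (3, 3)) == (1, 1, 1, 1)    # padding_sum = 4
    assert get_pad_tuple('VALID', (3, 3)) == (0, 0, 0, 0)   # padding_sum = 0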
