diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py index 2b6122919d85..c1587ba4e745 100644 --- a/topi/python/topi/arm_cpu/conv2d_gemm.py +++ b/topi/python/topi/arm_cpu/conv2d_gemm.py @@ -24,6 +24,11 @@ from ..nn.util import get_pad_tuple from .tensor_intrin import gemv_quantized, gemv_quantized_impl +def is_aarch64_arm(): + """ Checks whether we are compiling for an AArch64 target. """ + target = tvm.target.Target.current(allow_none=False) + return 'aarch64' in ' '.join(target.options) + # Compute function def compute_conv2d_gemm_without_weight_transform(cfg, @@ -70,8 +75,8 @@ def compute_conv2d_gemm_without_weight_transform(cfg, else: A = te.compute(A_shape, lambda n, x, y: data_pad[n, - HSTR * (x // OW) + dilation_h * (y // IC) // KW, - WSTR * (x % OW) + dilation_w * (y // IC) % KW, y % IC], + HSTR * (x // OW) + dilation_h * ((y // IC) // KW), + WSTR * (x % OW) + dilation_w * ((y // IC) % KW), y % IC], name='data_im2col') N_transformed = B_interleaved_t.shape[0] @@ -157,7 +162,7 @@ def schedule_conv2d_gemm(cfg, s, out): in_type = A_interleaved.dtype out_type = C.dtype - if out_type == 'int32': + if is_aarch64_arm() and out_type == 'int32': K = A_interleaved_input.shape[2] _, M, N = C.shape assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported" diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 59288892ebaa..51de4546663a 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -622,7 +622,7 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols): pad_N = tile_rows - (N % tile_rows) if K % tile_cols != 0: - pad_k = tile_cols - (K % tile_cols) + pad_K = tile_cols - (K % tile_cols) N_padded = N + pad_N K_padded = K + pad_K diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py index 06f930ccdbbd..edf4267ddaee 100644 --- a/topi/tests/python/test_topi_conv2d_int8.py +++ b/topi/tests/python/test_topi_conv2d_int8.py @@ -29,9 +29,97 @@ from common import get_all_backend, Int8Fallback -oc_block_factor = 4 +def verify_conv2d_NHWC_gemm_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, + dilation=1, add_bias=False, add_relu=False): + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, + kernel, stride, padding_sum, dilation)) + + in_height = in_width = in_size + + A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8') + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8') + bias = te.placeholder((num_filter,), name='bias', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_int8.verify_conv2d_nchw") + def get_ref_data(): + a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype) + w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding).astype(dtype) + + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding, + (dilation, dilation), dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + tvm.build(s, [A, W, bias, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation)) + func = tvm.build(s, [A, W, bias, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + check_device("llvm") +oc_block_factor = 4 def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) padding_sum = pad_top + pad_left + pad_bottom + pad_right @@ -285,6 +373,43 @@ def test_conv2d_nchw(): verify_conv2d_nchw_int8(7, 32, 149, 32, 3, 1, 0) verify_conv2d_nchw_int8(1, 32, 35, 64, 7, 2, (0, 0, 1, 1)) +def test_conv2d_nhwc(): + with Int8Fallback(): + # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding) + verify_conv2d_NHWC_gemm_int8(1, 3, 299, 32, 3, 2, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 32, 149, 32, 3, 1, 'SAME', dilation=2) + verify_conv2d_NHWC_gemm_int8(4, 32, 147, 64, 3, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 64, 73, 80, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 80, 73, 192, 3, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 192, 35, 48, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 192, 35, 64, 1, 1, 'VALID') + verify_conv2d_NHWC_gemm_int8(1, 192, 35, 32, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 48, 35, 64, 5, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 256, 35, 48, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 256, 35, 64, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 288, 35, 64, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 288, 35, 48, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 96, 35, 96, 3, 2, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 128, 17, 192, 7, 1, 'SAME', dilation=2) + verify_conv2d_NHWC_gemm_int8(1, 160, 17, 160, 7, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 160, 17, 192, 1, 1, 'VALID') + verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 768, 5, 128, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 192, 17, 320, 3, 2, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 192, 17, 192, 3, 2, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 192, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 384, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 320, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 1280, 8, 448, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 384, 8, 384, 3, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 448, 8, 384, 3, 1, 'VALID') + verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 320, 1, 1, 'SAME') + verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 448, 1, 1, 'SAME', add_bias=True, add_relu=True) + verify_conv2d_NHWC_gemm_int8(1, 2048, 8, 192, 1, 1, 'SAME', add_bias=True) + if __name__ == "__main__": test_conv2d_nchw() + test_conv2d_nhwc()