Add dot product support for quantized convolution. (apache#6445)
* Add dot product support for quantized convolution.

We added two new intrinsics in topi/arm_cpu/tensor_intrin.py, namely:
- mmla4x4: computes a matrix multiplication between a tile A(4,4) and a tile
  B(4,4)
- mmla16x4: computes a matrix multiplication between a tile A(rows,4) and a
  tile B(4,16) (a reference sketch of both follows below)
Then we used those intrinsics in two separate strategies: we added the
strategies in topi/arm_cpu/conv2d_int8.py and implemented the schedules in
topi/arm_cpu/conv2d_gemm.py. In particular:
- schedule_conv2d_gemm, when accelerated, packs matrix A, computes the GEMM,
  and unpacks the resulting matrix. This uses the mmla4x4 intrinsic.
- schedule_conv2d_gemm_hybrid does not do any packing on A and C, which stay
  in their native form. This uses the mmla16x4 intrinsic.

Please note that, due to the limitations of `tensorize`, we need to pad
matrix A in both cases (when its dimensions are not multiples of the tile
shape); a minimal sketch of this padding follows below.
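
A minimal sketch of that padding, assuming zero padding and a 4x4 tile (the
helper name is illustrative, not part of this commit):

import numpy as np

def pad_to_tile(A, tile_rows=4, tile_cols=4):
    # Pad with zeros so both dimensions become multiples of the tile shape
    # required by `tensorize`.
    rows, cols = A.shape
    pad_rows = (-rows) % tile_rows
    pad_cols = (-cols) % tile_cols
    return np.pad(A, ((0, pad_rows), (0, pad_cols)))

# e.g. a (5, 6) int8 matrix is padded to (8, 8) before the 4x4 tiles run
A = np.arange(30, dtype="int8").reshape(5, 6)
print(pad_to_tile(A).shape)  # (8, 8)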

Change-Id: Id0d818d84ffc458c6dad7983fd350a0f3d5db395

* Add back nhwc_spatial_pack strategy as default

Change-Id: I8b1826a7ae1d742956296e8d157da19955a4942c

* Fix linting through Black

Change-Id: Ic74ef5461a90bca9f4d4980a214137e384d5f923

* Fix python linting

Change-Id: I5fb8a2ae4467a87bd3470f6b3753c074f9b7cc78

* Addressing review comments

Change-Id: I284b1f2c121051e672f548d6c6ee2a3267854e31

* Fix black linting issues

Change-Id: I1813b0226b536aedee0dce9eeeba27aa2d95518b

* Fixing failing test and adding tests for dot-product compilation

Change-Id: Ic040722abd5538fccb85af4de922394c939e7000

* Fixing linting and review comments

Change-Id: If09e3baa514c85dc78d3c27c2ac2fa2e01773d89

* Fixing black linting and address comments

Change-Id: I857b28b6f9b23307d8c1eebc509de6ad2783c756

* Address review comments

Change-Id: I63d1a639d4a72abeb33148fd2868cd356ef84122
Giuseppe Rossini authored and trevor-m committed Oct 19, 2020
1 parent b4db42d commit 7e3bfbd
Showing 8 changed files with 874 additions and 183 deletions.
44 changes: 30 additions & 14 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="conv2d_direct_simd.micro_dev",
)
elif kernel_layout == "HWIO":
is_aarch64 = "aarch64" in str(isa.target)

is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available()
if has_dot_prod and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
name="conv2d_NHWC_quantized_native.arm_cpu",
)
if is_aarch64 and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized.arm_cpu",
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
)
if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
# TODO(@giuseros)
# This strategy errors out for quantized data types when tuning.
# Let's use this only for non-aarch64 or non-quantized cases
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
name="conv2d_nhwc_spatial_pack.arm_cpu",
)

strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
name="conv2d_nhwc_spatial_pack.arm_cpu",
)
else:
raise RuntimeError(
"Unsupported kernel layout {} for conv2d NHWC".format(kernel_layout)
@@ -328,11 +337,18 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
data = inputs[0]
strategy = _op.OpStrategy()

interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized_without_transform.arm_cpu",
wrap_compute_conv2d_gemm(native_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
)
strategy.add_implementation(
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
else:
raise RuntimeError(
100 changes: 100 additions & 0 deletions python/tvm/topi/arm_cpu/arm_utils.py
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Arm target utility functions"""

import re
import tvm


def get_arch_version(target_mattr):
"""Parse the LLVM target -mattr, and return
the architecture version in a decimal representation
(e.g., if -mattr=v8.4a, return 8.4)
"""

arch_version = 8.0
m = re.compile(r"\+v(.*)\.(.*)a")
for attr in target_mattr:
match_obj = m.match(attr)
if match_obj:
major = int(match_obj.group(1))
minor = int(match_obj.group(2))
decimal = 10
if minor >= 10:
decimal = 100
arch_version = major + float(minor) / decimal

return arch_version


def is_dotprod_available():
""" Checks whether the hardware has support for fast Int8 arithmetic operations. """
target = tvm.target.Target.current(allow_none=False)
arch_version = get_arch_version(target.mattr)
return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)


def is_aarch64_arm():
""" Checks whether we are compiling for an AArch64 target. """
target = tvm.target.Target.current(allow_none=False)
return "aarch64" in target.attrs.get("mtriple", "")


def get_tiling_B_interleaved_t(interleave_A):
"""Compute the tiling information for matrix B', where B'
is the transposed and interleaved version of matrix B in C=A*B.
The tiling information is chosen to maximize register usage during the
tile computation.
Please refer to:
- https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
- Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
In order to have more information
Parameters
----------
interleave_A: bool
determines if A is expected to be interleaved
Returns
----------
tile_rows_B: the output tile rows of B'
tile_cols_B: the output tile columns of B'
"""
if is_dotprod_available():
# The number of tile rows of B' vary depending on the
# strategy:
# * If we are interleaving A, then we select 12 columns from B'(i.e.,
# 12 rows from B).
# * If we are not interleaving A, then we select 16 columns from B'(i.e.,
# 16 rows from B).
tile_rows_B = 12 if interleave_A else 16

# Dot product instruction groups 2 (u)int16x8 vectors in
# groups of 4 and compute the dot product among those groups
# This means that the number of columns in a tile of B' (i.e., the
# rows of the original matrix B) need to be 4.
tile_cols_B = 4
else:
# If dot product is not available, A must be interleaved. In this case
# we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
tile_rows_B = 4
tile_cols_B = 16

return tile_rows_B, tile_cols_B
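
As a hedged usage sketch of the helpers above (the target string is an
example, not taken from this commit), an Armv8.2-A AArch64 target that
advertises the dot-product extension enables the dot-product tiling:

import tvm
from tvm import topi

target = tvm.target.Target(
    "llvm -mtriple=aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod"
)
with target:
    print(topi.arm_cpu.arm_utils.is_aarch64_arm())                   # True
    print(topi.arm_cpu.arm_utils.is_dotprod_available())             # True
    print(topi.arm_cpu.arm_utils.get_tiling_B_interleaved_t(True))   # (12, 4)
    print(topi.arm_cpu.arm_utils.get_tiling_B_interleaved_t(False))  # (16, 4)
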
105 changes: 75 additions & 30 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -27,10 +27,64 @@
from ..nn import conv2d_alter_layout
from ..util import get_const_tuple
from ..x86.conv2d import _get_default_config as _get_x86_default_config
from .arm_utils import get_tiling_B_interleaved_t

logger = logging.getLogger("topi")


def interleave_transpose_weights(inputs, data, kernel, interleave_A):
"""Transform the weight matrix by reshaping, interleaving and transposing it
Parameters
----------
inputs : tvm.relay.Expr
Grouped input symbols
data :
Input shape and dtype
kernel :
Input shape and dtype
interleave_A: indicates if we expect matrix A to be interleaved
Returns
----------
new_kernel : tvm.te.placeholder
A placeholder with the new shape
new_kernel_expr : tvm.relay.Expr
The relay expression of the weights
"""
assert (
data.dtype == "int8"
and kernel.dtype == "int8"
or data.dtype == "uint8"
and kernel.dtype == "uint8"
)

KH, KW, IC, OC = get_const_tuple(kernel.shape)
K = KH * KW * IC
N = OC

# Get tiling information for the interleaved transposed version of B
tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A)

pad_K = 0
pad_N = 0

if N % tile_rows_B != 0:
pad_N = tile_rows_B - (N % tile_rows_B)
if K % tile_cols_B != 0:
pad_K = tile_cols_B - (K % tile_cols_B)

N_padded = N + pad_N
K_padded = K + pad_K
new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(
inputs[1], tile_rows_B, tile_cols_B
)
new_kernel = te.placeholder(
(N_padded // tile_rows_B, K_padded // tile_cols_B, tile_rows_B, tile_cols_B), kernel.dtype
)
return new_kernel, new_kernel_expr


@conv2d_alter_layout.register(["arm_cpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
@@ -279,44 +333,35 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
assert (
data.dtype == "int8"
and kernel.dtype == "int8"
or data.dtype == "uint8"
and kernel.dtype == "uint8"
)
if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, IC, OC = get_const_tuple(kernel.shape)
K = KH * KW * IC
N = OC

tile_rows = 4
tile_cols = 16
pad_K = 0
pad_N = 0

if N % tile_rows != 0:
pad_N = tile_rows - (N % tile_rows)
if K % tile_cols != 0:
pad_K = tile_cols - (K % tile_cols)

N_padded = N + pad_N
K_padded = K + pad_K
kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols)
new_kernel = te.placeholder(
(N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols), kernel.dtype
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
new_kernel, new_kernel_expr = interleave_transpose_weights(
inputs, data, kernel, interleave_A=True
)

new_workload_name = "conv2d_NHWC_quantized_without_transform.arm_cpu"
new_workload = autotvm.task.args_to_workload(
[data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
new_workload_name,
)
dispatch_ctx.update(target, new_workload, cfg)

return relay.nn.contrib_conv2d_gemm_without_weight_transform(
inputs[0], kernel_expr, **new_attrs
inputs[0], new_kernel_expr, **new_attrs
)
if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
new_kernel, new_kernel_expr = interleave_transpose_weights(
inputs, data, kernel, interleave_A=False
)
new_workload = autotvm.task.args_to_workload(
[data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
new_workload_name,
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_conv2d_gemm_without_weight_transform(
inputs[0], new_kernel_expr, **new_attrs
)

return None
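
For reference, a hedged worked example of the padding arithmetic performed by
interleave_transpose_weights (the kernel shape is illustrative; the tiling
values are the non-dot-product ones returned by get_tiling_B_interleaved_t):

# 3x3 kernel, IC=64, OC=42, interleaved (non-dotprod) tiling of B'
KH, KW, IC, OC = 3, 3, 64, 42
tile_rows_B, tile_cols_B = 4, 16

K, N = KH * KW * IC, OC                                  # K = 576, N = 42
pad_N = (tile_rows_B - N % tile_rows_B) % tile_rows_B    # 2 -> N_padded = 44
pad_K = (tile_cols_B - K % tile_cols_B) % tile_cols_B    # 0 -> K_padded = 576
N_padded, K_padded = N + pad_N, K + pad_K

# Shape of the transformed weight placeholder:
print((N_padded // tile_rows_B, K_padded // tile_cols_B, tile_rows_B, tile_cols_B))
# (11, 36, 4, 16)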
