[CUDA] batch_matmul tensorcore schedule (apache#7146)
* add batch_matmul_tensorcore

* add bmm cublas autotune

* add bmm tests

* out_shape for bmm_tensorcore

* fix comments

* code format

* add todos for tensorcore datatype checking

* fix lint

* fix have_tensorcore

* add dtype check for batch_matmul_tensorcore
Meteorix authored and trevor-m committed Jan 21, 2021
1 parent a78d88d commit 40e2e90
Showing 8 changed files with 422 additions and 2 deletions.
16 changes: 16 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -732,6 +732,22 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
            name="batch_matmul_cublas.cuda",
            plevel=15,
        )
    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
        x, y = inputs
        _, M, K = get_const_tuple(x.shape)
        _, N, K = get_const_tuple(y.shape)
        if x.dtype in ["float16", "int8", "uint8"] and (
            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
        ):
            strategy.add_implementation(
                wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
                wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
                name="batch_matmul_tensorcore.cuda",
                plevel=20,
            )

    return strategy


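The tensorcore implementation is only registered when the input dtype is float16/int8/uint8 and (M, N, K) match one of the three WMMA fragment shapes, 8x32x16, 16x16x16 or 32x8x16 in (m, n, k) order; its plevel of 20 lets it win over the cuBLAS implementation (plevel 15) when both apply. For illustration, the same dispatch test as a standalone sketch — the helper name is hypothetical, not something this commit adds:

def tensorcore_batch_matmul_applicable(m, n, k, dtype):
    """Return True if batch_matmul of (B, M, K) x (B, N, K) can take the tensorcore path."""
    if dtype not in ("float16", "int8", "uint8"):
        return False
    return (
        (m % 8 == 0 and k % 16 == 0 and n % 32 == 0)
        or (m % 16 == 0 and k % 16 == 0 and n % 16 == 0)
        or (m % 32 == 0 and k % 16 == 0 and n % 8 == 0)
    )

# e.g. tensorcore_batch_matmul_applicable(64, 64, 64, "float16") is True,
# while M=63 would fall back to the default or cuBLAS schedules.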
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
@@ -42,6 +42,7 @@
from .pooling import *
from .nn import schedule_lrn
from .batch_matmul import *
from .batch_matmul_tensorcore import *
from .vision import *
from .ssd import *
from .nms import get_valid_counts, non_max_suppression
14 changes: 12 additions & 2 deletions python/tvm/topi/cuda/batch_matmul.py
@@ -21,7 +21,7 @@
from tvm import te
from tvm.contrib import cublas
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn
from .. import nn, generic
from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor


@@ -138,7 +138,8 @@ def _callback(op):
    return s


def batch_matmul_cublas(x, y, out_shape=None):
@autotvm.register_topi_compute("batch_matmul_cublas.cuda")
def batch_matmul_cublas(cfg, x, y, out_shape=None):
    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
    data in batch.
@@ -158,4 +159,13 @@ def batch_matmul_cublas(x, y, out_shape=None):
    output : tvm.te.Tensor
        3-D with shape [batch, M, N]
    """
    b, m, k = x.shape
    b, n, k = y.shape
    cfg.add_flop(b * m * k * n * 2)
    return cublas.batch_matmul(x, y, False, True)


@autotvm.register_topi_schedule("batch_matmul_cublas.cuda")
def schedule_batch_matmul_cublas(_, outs):
    """Schedule batch_matmul operator using CUBLAS"""
    return generic.schedule_extern(outs)
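The new cfg.add_flop call tells autotvm how much work one invocation performs, so the tuner can report a meaningful GFLOPS figure; for a batched matmul this is 2 * batch * M * N * K (one multiply and one add per inner-product term). A rough illustration with made-up shapes, not taken from the commit:

batch, M, N, K = 16, 128, 128, 64
flops = 2 * batch * M * N * K
print(flops)  # 33554432, about 0.034 GFLOP per call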
315 changes: 315 additions & 0 deletions python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -0,0 +1,315 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
"""cuda batch_matmul operators"""
import tvm
from tvm import autotvm
from tvm import te
from ..utils import traverse_inline, get_const_tuple
from .tensor_intrin import (
intrin_wmma_load_matrix_A,
intrin_wmma_load_matrix_W,
intrin_wmma_store_matrix,
intrin_wmma_gemm,
)


@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
"""batch matmul tensorcore operator on cuda"""
# todo: deal with out_shape for broadcast, liuxin.ai
return batch_matmul_tensorcore_cuda(x, y)


@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
def schedule_batch_matmul_tensorcore(cfg, outs):
"""Schedule for batch_matmul operator using Tensorcore
Parameters
----------
outs: Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

    def _schedule(cfg, s, C):
        A, B = s[C].op.input_tensors
        batch, m_dim, k_dim = get_const_tuple(A.shape)
        batch, n_dim, k_dim = get_const_tuple(B.shape)
        out_dtype = C.dtype
        # inline astype fp16
        s[A].compute_inline()
        s[B].compute_inline()

        # Explicit memory access
        AS = s.cache_read(A, "shared", [C])
        BS = s.cache_read(B, "shared", [C])
        AF = s.cache_read(AS, "wmma.matrix_a", [C])
        BF = s.cache_read(BS, "wmma.matrix_b", [C])
        CF = s.cache_write(C, "wmma.accumulator")
        CS = s.cache_read(CF, "shared", [C])

        # fallback support
        target = tvm.target.Target.current()
        if cfg.is_fallback:
            ref_log = autotvm.tophub.load_reference_log(
                target.kind.name, target.model, "batch_matmul_tensorcore.cuda"
            )
            cfg.fallback_with_reference_log(ref_log)

        # Deal with op fusion, such as bias/relu and slice after padding
        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
            s[C].compute_inline()
            C = s.outputs[0].output(0)

        # create tuning space
        cfg.define_knob("block_row_warps", [1, 2, 4])
        cfg.define_knob("block_col_warps", [1, 2, 4])
        cfg.define_knob("warp_row_tiles", [1, 2, 4])
        cfg.define_knob("warp_col_tiles", [1, 2, 4])
        cfg.define_knob("chunk", [1, 2, 4, 8])
        cfg.define_knob("offset", [0, 8])
        cfg.define_knob("offsetCS", [0, 8])
        cfg.define_knob("vec", [1, 2, 4, 8])

        # Ensure that the default parameters are applicable when autotvm is not in use
        if m_dim % 32 == 0 and n_dim % 8 == 0:
            cfg.define_knob("wmma_m", [32, 16, 8])
        elif m_dim % 16 == 0 and n_dim % 16 == 0:
            cfg.define_knob("wmma_m", [16, 8, 32])
        elif m_dim % 8 == 0 and n_dim % 32 == 0:
            cfg.define_knob("wmma_m", [8, 16, 32])

        warp_size = 32
        wmma_k = 16
        block_row_warps = cfg["block_row_warps"].val
        block_col_warps = cfg["block_col_warps"].val
        warp_row_tiles = cfg["warp_row_tiles"].val
        warp_col_tiles = cfg["warp_col_tiles"].val
        chunk = cfg["chunk"].val
        offset = cfg["offset"].val
        offsetCS = cfg["offsetCS"].val
        wmma_m = cfg["wmma_m"].val
        vec = cfg["vec"].val

        if wmma_m == 16:
            wmma_n = 16
        elif wmma_m == 8:
            wmma_n = 32
        elif wmma_m == 32:
            wmma_n = 8

        # Define the stride of intrin functions
        AS_align = chunk * wmma_k + offset
        BS_align = chunk * wmma_k + offset
        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
        AS_stride = [AS_align, 1]
        BS_stride = [BS_align, 1]
        AF_stride = [wmma_k, 1]
        BF_stride = [wmma_k, 1]
        CF_stride = [warp_col_tiles * wmma_n, 1]
        CS_stride = [CS_align, 1]

        block_x = te.thread_axis("blockIdx.x")
        block_y = te.thread_axis("blockIdx.y")
        block_z = te.thread_axis("blockIdx.z")
        thread_x = te.thread_axis("threadIdx.x")
        thread_y = te.thread_axis("threadIdx.y")
        thread_z = te.thread_axis("threadIdx.z")

        # Schedule for dense computation
        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
        b, m, n = C.op.axis
        block_i, bc = s[C].split(m, factor=block_factor_m)
        block_j, oc = s[C].split(n, factor=block_factor_n)
        s[C].reorder(b, block_i, block_j, bc, oc)
        t = s[C].fuse(bc, oc)
        t, vi = s[C].split(t, factor=vec)
        t, tx = s[C].split(t, factor=warp_size)
        t, ty = s[C].split(t, factor=block_row_warps)
        t, tz = s[C].split(t, factor=block_col_warps)
        s[C].bind(block_i, block_x)
        s[C].bind(block_j, block_y)
        s[C].bind(b, block_z)
        s[C].bind(tz, thread_z)
        s[C].bind(ty, thread_y)
        s[C].bind(tx, thread_x)
        s[C].vectorize(vi)

        # Schedule for wmma store
        s[CS].compute_at(s[C], block_j)
        bs, bb, oo = CS.op.axis
        s[CS].storage_align(bb, CS_align - 1, CS_align)
        bb, bbi = s[CS].split(bb, factor=wmma_m)
        oo, ooi = s[CS].split(oo, factor=wmma_n)
        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)

        # Schedule for wmma computation
        s[CF].compute_at(s[CS], oo)
        bs, warp_i, warp_j = CF.op.axis
        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
        (k,) = CF.op.reduce_axis
        k, _k = s[CF].split(k, factor=wmma_k)
        ko, ki = s[CF].split(k, factor=chunk)
        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)

        # Schedule for wmma_matrix_a load
        s[AF].compute_at(s[CF], ki)
        bs, b, i = AF.op.axis
        b, b_ii = s[AF].split(b, factor=wmma_m)
        i, i_jj = s[AF].split(i, factor=wmma_k)
        s[AF].reorder(bs, b, i, b_ii, i_jj)

        # Schedule for wmma_matrix_b load
        s[BF].compute_at(s[CF], ki)
        bs, o, i = BF.op.axis
        o, o_ii = s[BF].split(o, factor=wmma_n)
        i, i_ii = s[BF].split(i, factor=wmma_k)
        s[BF].reorder(bs, o, i, o_ii, i_ii)

        # Schedule for A's(B's) shared memory load
        def shared_shedule(stage, strides):
            s[stage].compute_at(s[CF], ko)
            bs, xo, yo = stage.op.axis
            s[stage].storage_align(xo, strides - 1, strides)
            t = s[stage].fuse(xo, yo)
            t, vi = s[stage].split(t, factor=vec)
            t, tx = s[stage].split(t, factor=warp_size)
            t, ty = s[stage].split(t, factor=block_row_warps)
            _, tz = s[stage].split(t, factor=block_col_warps)
            s[stage].bind(ty, thread_y)
            s[stage].bind(tz, thread_z)
            s[stage].bind(tx, thread_x)
            s[stage].vectorize(vi)

        shared_shedule(AS, AS_align)
        shared_shedule(BS, BS_align)

        shape = (wmma_m, wmma_n, wmma_k)
        # TODO: add checking here, datatype casting may cause precision loss
        in_dtype = "float16"
        AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
        BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
        k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
        CL_compute = te.compute(
            (wmma_m, wmma_n),
            lambda ii, jj: te.sum(
                AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, k_gemm].astype(out_dtype),
                axis=k_gemm,
            ),
            name="CL_compute",
        )

        # lower the computation loops down to TensorCore hardware intrinsics
        # by mapping the dense tensorcore to tensor intrinsics
        s[AF].tensorize(
            b_ii,
            intrin_wmma_load_matrix_A(
                AF_stride,
                AS_stride,
                shape,
                "row_major",
                (wmma_m, wmma_k),
                (wmma_m, wmma_k),
                "float16",
            ),
        )
        s[BF].tensorize(
            o_ii,
            intrin_wmma_load_matrix_W(
                BF_stride,
                BS_stride,
                shape,
                "col_major",
                (wmma_n, wmma_k),
                (wmma_n, wmma_k),
                "float16",
            ),
        )
        s[CF].tensorize(
            _ii,
            intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape),
        )
        s[CS].tensorize(
            bbi,
            intrin_wmma_store_matrix(
                CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)
            ),
        )

    def _callback(op):
        if "batch_matmul_tensorcore" in op.tag:
            _schedule(cfg, s, op.output(0))

    traverse_inline(s, outs[0].op, _callback)
    return s


def batch_matmul_tensorcore_cuda(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
x : tvm.te.Tensor
3-D with shape [batch, M, K]
y : tvm.te.Tensor
3-D with shape [batch, N, K]
Returns
-------
output : tvm.te.Tensor
3-D with shape [batch, M, N]
"""
    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
    x_shape = get_const_tuple(x.shape)
    y_shape = get_const_tuple(y.shape)
    assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
    assert x_shape[2] == y_shape[2], "shapes of x and y are inconsistent"
    batch, M, K = x.shape
    N = y.shape[1]
    out_dtype = x.dtype

    assert (
        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
    ), "The shape of (M, K, N) must be a multiple of (16, 16, 16), (32, 16, 8) or (8, 16, 32)"

    x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16"))
    y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16"))

    k = te.reduce_axis((0, K), name="k")
    return te.compute(
        (batch, M, N),
        lambda b, i, j: te.sum(
            x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k
        ),
        tag="batch_matmul_tensorcore",
    )
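For context, a hedged sketch (not part of the diff) of how this kernel is reached from Relay: with float16 inputs whose (M, N, K) satisfy the divisibility checks above, batch_matmul_strategy_cuda adds batch_matmul_tensorcore.cuda at plevel 20, so it should be picked on a tensorcore-capable GPU (sm_70 or newer). The shapes and target string below are illustrative.

import tvm
from tvm import relay

x = relay.var("x", shape=(8, 64, 64), dtype="float16")
y = relay.var("y", shape=(8, 64, 64), dtype="float16")
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.nn.batch_matmul(x, y)))

# Assumes a tensorcore-capable GPU; on anything older the strategy falls back
# to the default or cuBLAS batch_matmul schedules.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=tvm.target.Target("cuda -arch=sm_70"))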
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/conv2d_nhwc_tensorcore.py
@@ -72,6 +72,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")
    # convert data type of input feature maps and weights
    # TODO: add checking here, datatype casting may cause precision loss
    TransPaddedInput = te.compute(
        PaddedInput.shape, lambda n, h, w, c: PaddedInput[n, h, w, c].astype("float16")
    )
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py
@@ -75,6 +75,7 @@ def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dty
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")
    # convert data type of input feature maps and weights
    # TODO: add checking here, datatype casting may cause precision loss
    TransPaddedInput = te.compute(
        PaddedInput.shape, lambda n, d, h, w, c: PaddedInput[n, d, h, w, c].astype("float16")
    )
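The TODO comments added to the conv2d/conv3d tensorcore kernels (and the in_dtype = "float16" cast in the new batch_matmul schedule) flag that unconditionally casting inputs to float16 can lose precision. A small numpy illustration of the effect, not taken from the commit:

import numpy as np

x = np.float32(1.0001)   # representable (approximately) in float32
y = np.float16(x)        # nearest float16 value is 1.0
print(float(y))          # 1.0 -- the 1e-4 difference is below float16 resolution near 1.0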
