[Operator] Add cublas to matmul tune space (#422)
Also fix bug in call graph generation.
hjjq committed Jan 24, 2024
1 parent 1fa14a5 commit 072a606
Showing 5 changed files with 238 additions and 88 deletions.
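The pattern added in this commit is the same in all three schedules below: each tunable schedule gains a use_cublas parameter, and when it is True the schedule short-circuits to a cuBLAS-backed IRModule while a tune.check filter keeps only one representative combination of the now-irrelevant tiling parameters. A minimal, self-contained sketch of that de-duplication idea (plain Python; the parameter names and values are illustrative placeholders, not hidet's tuning API):

from itertools import product

# Illustrative tuning space; names and values are placeholders, not the ones
# used in batch_matmul.py / matmul_f16.py.
space = {
    'block_k': [8, 16, 32],
    'warp_shape': [(32, 32), (32, 64), (64, 64)],
    'use_cublas': [True, False],
}

candidates = []
for block_k, warp_shape, use_cublas in product(*space.values()):
    if use_cublas:
        # The other parameters do not affect the cuBLAS-backed schedule, so keep
        # exactly one representative candidate (this mirrors the
        # tune.check(schedule_filter) calls in the diff below).
        if not (block_k == 8 and warp_shape == (32, 32)):
            continue
    candidates.append((block_k, warp_shape, use_cublas))

print(len(candidates))  # 9 hand-written candidates + 1 cuBLAS candidate = 10, not 18

In the diff itself the same effect is achieved by returning get_cublas_matmul_schedule(...) early and rejecting every other parameter combination with tune.check.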
80 changes: 73 additions & 7 deletions python/hidet/graph/ops/matmul/batch_matmul.py
@@ -76,6 +76,7 @@ def implement_cuda(self, working_dir: str) -> List[IRModule]:
spatial(2, 1) * spatial(1, 8) * spatial(2, 1),
],
warp_inner=[(4, 4)],
use_cublas=[True, False],
)
@tune.space(
1,
@@ -84,12 +85,41 @@ def implement_cuda(self, working_dir: str) -> List[IRModule]:
warp_outer=[(1, 1), (1, 2), (2, 1), (2, 2)],
warp_mid=[spatial(4, 8)],
warp_inner=[(4, 4), (4, 8), (8, 4)],
use_cublas=[True, False],
)
def schedule_simt(
self, block_warps_k=8, block_warps=(4, 2), warp_outer=(2, 2), warp_mid=spatial(4, 8), warp_inner=(4, 4)
self,
block_warps_k=8,
block_warps=(4, 2),
warp_outer=(2, 2),
warp_mid=spatial(4, 8),
warp_inner=(4, 4),
use_cublas=False,
) -> IRModule:
task = self
dtype = task.inputs[0].type.dtype

if use_cublas:
from hidet.graph.ops.utils.schedule_utils import get_cublas_matmul_schedule

a_shape = task.inputs[0].type.shape
b_shape = task.inputs[1].type.shape
c_shape = task.outputs[0].type.shape
# Hack to reduce redundant schedules. When use_cublas == True, the other tuning params are irrelevant
# and we only need one copy of the schedule.
from hidet.ir.mapping import SpatialTaskMapping

schedule_filter = (
block_warps_k == 8
and block_warps == (1, 1)
and warp_outer == (1, 1)
and isinstance(warp_mid, SpatialTaskMapping)
and warp_mid.task_shape == (4, 8)
and warp_inner == (4, 4)
)
tune.check(schedule_filter)
return get_cublas_matmul_schedule(a_shape, b_shape, c_shape, dtype, dtype, dtype)

warp_k = 1

# Task Layouts
@@ -369,6 +399,7 @@ def batch_matmul_kernel(
warp_n=[16, 32, 64],
warp_k=[8, 16, 32],
mma_config=MmaConfig.all(),
use_cublas=[True, False],
)
@tune.space(
1,
@@ -379,9 +410,18 @@ def batch_matmul_kernel(
warp_n=[32, 64],
warp_k=[8, 16, 32],
mma_config=MmaConfig.all(),
use_cublas=[True, False],
)
def schedule_mma(
self, block_m=64, block_n=64, block_k=16, warp_m=32, warp_n=32, warp_k=16, mma_config: MmaConfig = None
self,
block_m=64,
block_n=64,
block_k=16,
warp_m=32,
warp_n=32,
warp_k=16,
mma_config: MmaConfig = None,
use_cublas=False,
) -> IRModule:
def resolve_mma_type(a_dtype: DataType, b_dtype: DataType, c_dtype: DataType):
dtype_rank = {'float16': 0, 'bfloat16': 1, 'tfloat32': 2, 'float32': 4}
@@ -398,9 +438,35 @@ def resolve_mma_type(a_dtype: DataType, b_dtype: DataType, c_dtype: DataType):

task = self

input_a, input_b, input_c = task.inputs[0], task.inputs[1], task.outputs[0]
input_a_dtype, input_b_dtype, input_c_dtype = [t.type.dtype for t in [input_a, input_b, input_c]]
mma_type = resolve_mma_type(input_a_dtype, input_b_dtype, input_c_dtype)
input_a, input_b, output_c = task.inputs[0], task.inputs[1], task.outputs[0]
input_a_dtype, input_b_dtype, output_c_dtype = [t.type.dtype for t in [input_a, input_b, output_c]]
input_a_shape, input_b_shape, output_c_shape = [t.type.shape for t in [input_a, input_b, output_c]]

if use_cublas:
from hidet.graph.ops.utils.schedule_utils import get_cublas_matmul_schedule

# Hack to reduce redundant schedules. When use_cublas == True, the other tuning params are irrelevant
# and we only need one copy of the schedule.
schedule_filter = (
block_m == 64
and block_n == 64
and block_k == 8
and warp_m == 32
and warp_n == 32
and warp_k == 8
and mma_config
and mma_config.m == 16
and mma_config.n == 8
and mma_config.k == 8
and mma_config.input_dtype == 'f16'
and mma_config.output_dtype == 'f16'
)
tune.check(schedule_filter)
return get_cublas_matmul_schedule(
input_a_shape, input_b_shape, output_c_shape, input_a_dtype, input_b_dtype, output_c_dtype
)

mma_type = resolve_mma_type(input_a_dtype, input_b_dtype, output_c_dtype)

# Resolve parameters when space level is 0
if mma_config is None:
@@ -549,7 +615,7 @@ def copy_b_s2r(
@hidet.script
def copy_c_r2g(
regs_c: TensorType(dtype=c_dtype, layout=regs_c_layout),
c: input_c_dtype[bs, m_size, n_size],
c: output_c_dtype[bs, m_size, n_size],
offset_m: i32,
offset_n: i32,
smem: void_p,
@@ -610,7 +676,7 @@ def mma(
def batch_matmul_kernel(
a: input_a_dtype[bs, m_size, k_size],
b: input_b_dtype[bs, k_size, n_size],
c: input_c_dtype[bs, m_size, n_size],
c: output_c_dtype[bs, m_size, n_size],
):
attrs.cuda.grid_dim = (m_tiles * n_tiles, bs)
attrs.cuda.block_dim = block_size
82 changes: 5 additions & 77 deletions python/hidet/graph/ops/matmul/matmul_cublas.py
@@ -9,17 +9,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List, Optional
from typing import Union, Optional

import hidet
from hidet.ir.module import IRModule
from hidet.ir.type import DataType
from hidet.ir.expr import Expr, is_true
from hidet.ir.dtypes import f16, f32
from hidet.utils import prod
from hidet.ir.expr import Expr
from hidet.cuda.cublas import cublasComputeType
from ..utils import Task, Operator, Tensor, input_like
from ..utils import TensorInput
from ..utils.schedule_utils import convert_to_cublas_strided_gemm, resolve_cublas_compute_type


class CublasMatmulTask(Task):
@@ -30,83 +28,13 @@ def __init__(self, a: TensorInput, b: TensorInput, compute_type: Optional[Union[
if a.type.dtype != b.type.dtype:
raise ValueError('dtype of a and b must be the same, got {} and {}'.format(a.type.dtype, b.type.dtype))

self.compute_type: cublasComputeType = self.resolve_compute_type(a.type.dtype, a.type.dtype, compute_type)
self.compute_type: cublasComputeType = resolve_cublas_compute_type(a.type.dtype, a.type.dtype, compute_type)

c = cops.matmul(a, b, allow_1d=True)
super().__init__(
name='cublas_matmul', inputs=[a, b], outputs=[c], attributes={'compute_type': self.compute_type}
)

def resolve_compute_type(
self, in_dtype: DataType, out_dtype: DataType, compute_type: Optional[Union[int, cublasComputeType]]
) -> cublasComputeType:
if compute_type is not None:
return cublasComputeType(compute_type)
if in_dtype == out_dtype == f16:
# use tensor core whenever possible
return cublasComputeType.CUBLAS_COMPUTE_16F
elif in_dtype == out_dtype == f32:
# use tensor core whenever possible
return cublasComputeType.CUBLAS_COMPUTE_32F
else:
raise NotImplementedError(
'not implemented resolve rules for compute_type with in_dtype={}, out_dtype={}'.format(
in_dtype, out_dtype
)
)

def convert_to_strided_gemm(self, a_shape: List[Expr], b_shape: List[Expr], c_shape: List[Expr]):
a_rank: int = len(a_shape)
b_rank: int = len(b_shape)

assert a_rank >= 1 and b_rank >= 1 and (a_rank >= 2 or b_rank >= 2)
if a_rank == 1:
bs = prod(b_shape[:-2])
m = 1
n = b_shape[-1]
k = a_shape[0]
stride_a = 0
stride_b = b_shape[-2] * b_shape[-1]
stride_c = c_shape[-2] * c_shape[-1]
elif b_rank == 1:
bs = prod(a_shape[:-2])
m = a_shape[-2]
n = 1
k = b_shape[0]
stride_a = a_shape[-2] * a_shape[-1]
stride_b = 0
stride_c = c_shape[-1]
else:
if is_true(prod(a_shape[:-2]) == 1):
bs = prod(b_shape[:-2])
m = a_shape[-2]
n = b_shape[-1]
k = a_shape[-1]
stride_a = 0
stride_b = b_shape[-2] * b_shape[-1]
stride_c = c_shape[-2] * c_shape[-1]
elif is_true(prod(b_shape[:-2]) == 1):
bs = prod(a_shape[:-2])
m = a_shape[-2]
n = b_shape[-1]
k = a_shape[-1]
stride_a = a_shape[-2] * a_shape[-1]
stride_b = 0
stride_c = c_shape[-2] * c_shape[-1]
elif all(is_true(a == b) for a, b in zip(a_shape[:-2], b_shape[:-2])):
bs = prod(a_shape[:-2])
m = a_shape[-2]
n = b_shape[-1]
k = a_shape[-1]
stride_a = a_shape[-2] * a_shape[-1]
stride_b = b_shape[-2] * b_shape[-1]
stride_c = c_shape[-2] * c_shape[-1]
else:
# todo: add cublasGemmBatchedEx to support this case
# https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmbatchedex
raise NotImplementedError('Can not convert matmul {} @ {} to strided_gemm'.format(a_shape, b_shape))
return bs, m, n, k, stride_a, stride_b, stride_c

def implement_cuda(self, working_dir: str) -> IRModule:
from hidet.lang import attrs
from hidet.lang.cuda import cublas
@@ -120,7 +48,7 @@ def implement_cuda(self, working_dir: str) -> IRModule:
with hidet.script_module() as script_module:

def generate(a: Expr, b: Expr, c: Expr) -> Expr:
bs, m, n, k, stride_a, stride_b, stride_c = self.convert_to_strided_gemm(a_shape, b_shape, c_shape)
bs, m, n, k, stride_a, stride_b, stride_c = convert_to_cublas_strided_gemm(a_shape, b_shape, c_shape)
return cublas.strided_gemm(
bs,
m,
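For reference, a quick numeric check of the stride resolution that this commit moves from CublasMatmulTask.convert_to_strided_gemm into the shared convert_to_cublas_strided_gemm helper, for the branch where prod(a_shape[:-2]) == 1 and the batch dimension comes from b. The shapes are illustrative, not taken from the commit:

from math import prod

a_shape = [1, 64, 128]   # leading batch of 1 is broadcast against b
b_shape = [8, 128, 256]
c_shape = [8, 64, 256]   # shape of a @ b

bs = prod(b_shape[:-2])                            # 8 strided-gemm batches
m, n, k = a_shape[-2], b_shape[-1], a_shape[-1]    # 64, 256, 128
stride_a = 0                                       # a is reused for every batch
stride_b = b_shape[-2] * b_shape[-1]               # 128 * 256 = 32768
stride_c = c_shape[-2] * c_shape[-1]               # 64 * 256 = 16384
print(bs, m, n, k, stride_a, stride_b, stride_c)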
41 changes: 39 additions & 2 deletions python/hidet/graph/ops/matmul/matmul_f16.py
@@ -92,10 +92,29 @@ def implement_cuda(self, working_dir: str) -> List[IRModule]:
warp_n=[16, 32, 48, 64],
warp_k=[8, 16, 32, 64],
mma=['m16n8k16'],
use_cublas=[True, False],
)
@tune.space(
1,
block_m=[128],
block_n=[128],
block_k=[16],
warp_m=[64],
warp_n=[64],
warp_k=[16],
mma=['m16n8k16'],
use_cublas=[True, False],
)
@tune.space(1, block_m=[128], block_n=[128], block_k=[16], warp_m=[64], warp_n=[64], warp_k=[16], mma=['m16n8k16'])
def schedule(
self, block_m=64, block_n=128, block_k=16, warp_m=32, warp_n=64, warp_k=16, mma: str = 'm16n8k16'
self,
block_m=64,
block_n=128,
block_k=16,
warp_m=32,
warp_n=64,
warp_k=16,
mma: str = 'm16n8k16',
use_cublas=False,
) -> IRModule:
# pylint: disable=unused-variable
import hidet
@@ -117,6 +136,24 @@ def schedule(
k_parts = self.attrs['parallel_k_parts']
k_part_extent = cdiv(cdiv(k_size, k_parts), 8) * 8

if use_cublas:
from hidet.graph.ops.utils.schedule_utils import get_cublas_matmul_schedule

dtype = self.inputs[0].type.dtype
# Hack to reduce redundant schedules. When use_cublas == True, the other tuning params are irrelevant
# and we only need one copy of the schedule.
schedule_filter = (
block_m == 128
and block_n == 128
and block_k == 16
and warp_m == 64
and warp_n == 64
and warp_k == 16
and mma == 'm16n8k16'
)
tune.check(schedule_filter)
return get_cublas_matmul_schedule(a_shape, b_shape, c_shape, dtype, dtype, dtype)

# schedule parameters
mma_configs = {'m16n8k8': MmaConfig.m16n8k8_f16_f16(), 'm16n8k16': MmaConfig.m16n8k16_f16_f16()}
tune.check(mma in mma_configs)
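As a side note on matmul_f16.py's schedule() above: k_part_extent = cdiv(cdiv(k_size, k_parts), 8) * 8 rounds each parallel-k partition up to a multiple of 8. A standalone arithmetic check with illustrative values:

def cdiv(a: int, b: int) -> int:
    # ceiling division (what the schedule's cdiv helper is assumed to compute)
    return (a + b - 1) // b

k_size, k_parts = 1000, 3
k_part_extent = cdiv(cdiv(k_size, k_parts), 8) * 8
print(k_part_extent)  # cdiv(1000, 3) = 334, cdiv(334, 8) = 42, 42 * 8 = 336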
