[Operator] Add hidet.ops.matmul_cublas operator (#405)
Add the `hidet.ops.matmul_cublas` operator, which uses the cuBLAS library.

```python
m, n, k = 1024, 1024, 1024
a = hidet.randn([m, k], dtype='float16', device='cuda') / 32.0
b = hidet.randn([k, n], dtype='float16', device='cuda') / 32.0
c = hidet.ops.matmul_cublas(a, b)
d = hidet.ops.matmul(a, b)

hidet.utils.assert_close(actual=c, expected=d, rtol=1e-2, atol=1e-2)
```
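
Batched inputs are lowered to a single strided GEMM as well; a minimal sketch (the shapes are illustrative, and a batch dimension of 1 broadcasts against the other operand, per `convert_to_strided_gemm` below):

```python
a = hidet.randn([1, 32, 64], dtype='float16', device='cuda')
b = hidet.randn([8, 64, 16], dtype='float16', device='cuda')
c = hidet.ops.matmul_cublas(a, b)  # a is broadcast over b's batch: c.shape == [8, 32, 16]
```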

The generated code in `source.cu` looks like:
```c++
#include <hidet/runtime/cuda/cublas.h>

// ...

DLL void hidet_launch_0(half * __restrict__ a, half * __restrict__ b, half * __restrict__ c) {
  hidet_cublas_strided_gemm(1, 1024, 1024, 1024, 2, 2, 2, a, b, c, 0, 1048576, 1048576, false, false, 64);
}

```
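
For reference, the argument order follows the `cublas.strided_gemm` registration in `regs.py` below: `(bs, m, n, k, type_a, type_b, type_c, a, b, c, stride_a, stride_b, stride_c, trans_a, trans_b, compute_type)`. Here `2` is the `cudaDataType` value `CUDA_R_16F`, `64` is the `cublasComputeType` value `CUBLAS_COMPUTE_16F`, `stride_a` is `0` because the single `a` matrix is reused across the (size-1) batch, and `1048576 = 1024 * 1024` elements.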
yaoyaoding committed Jan 4, 2024
1 parent f70e5e6 commit f4ab1e4
Showing 10 changed files with 283 additions and 27 deletions.
18 changes: 9 additions & 9 deletions python/hidet/cuda/cublas/kernels.py
```diff
@@ -26,9 +26,9 @@ def gemm(
     a,
     b,
     c,
-    trans_a: bool,
-    trans_b: bool,
     compute_type: Union[int, cublasComputeType],
+    trans_a: bool = False,
+    trans_b: bool = False,
 ):
     """
     Matrix multiplication of two matrices using cublas in row major by default.
@@ -54,18 +54,18 @@
         Type of elements in matrix B.
     type_c: Union[int, cudaDataType, DataType]
         Type of elements in matrix C.
-    a: Tensor or int
+    a: hidet.Tensor or int
         Matrix A, can be either a Tensor or an integer (the address of the matrix).
-    b: Tensor or int
+    b: hidet.Tensor or int
         Matrix B, can be either a Tensor or an integer (the address of the matrix).
-    c: Tensor or int
+    c: hidet.Tensor or int
         Matrix C, can be either a Tensor or an integer (the address of the matrix).
+    compute_type: Union[int, cublasComputeType]
+        The compute type of the operation.
     trans_a: bool
         Whether matrix A is transposed.
     trans_b: bool
         Whether matrix B is transposed.
-    compute_type: Union[int, cublasComputeType]
-        The compute type of the operation.
     """
     ffi.gemm(
         m,
@@ -97,9 +97,9 @@ def strided_gemm(
     stride_a: int,
     stride_b: int,
     stride_c: int,
-    trans_a: bool,
-    trans_b: bool,
     compute_type: Union[int, cublasComputeType],
+    trans_a: bool = False,
+    trans_b: bool = False,
 ):
     """
     Batch matrix multiplication of two matrices using cublas in row major order by default.
```
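
With the reordered signature, the transpose flags can now be omitted; a minimal sketch, assuming `gemm` is importable from `hidet.cuda.cublas.kernels` and accepts hidet `DataType` objects for the type arguments, as the docstring states:

```python
import hidet
from hidet.cuda.cublas import cublasComputeType
from hidet.cuda.cublas.kernels import gemm

m, n, k = 256, 256, 256
a = hidet.randn([m, k], dtype='float32', device='cuda')
b = hidet.randn([k, n], dtype='float32', device='cuda')
c = hidet.empty([m, n], device='cuda')

# trans_a and trans_b now default to False and no longer need to be passed:
gemm(m, n, k, a.dtype, b.dtype, c.dtype, a, b, c, cublasComputeType.CUBLAS_COMPUTE_32F)
```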
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
```diff
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=redefined-builtin
-from .matmul import batch_matmul, matmul, matmul_x86
+from .matmul import batch_matmul, matmul, matmul_x86, matmul_cublas
 from .conv1d import conv1d, conv1d_gemm
 from .conv1d_transpose import conv1d_transpose
 from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16
```
1 change: 1 addition & 0 deletions python/hidet/graph/ops/matmul/__init__.py
```diff
@@ -11,6 +11,7 @@
 # limitations under the License.
 from .matmul import matmul, MatmulOp, MatmulTask
 from .batch_matmul import batch_matmul, BatchMatmulOp, BatchMatmulTask
+from .matmul_cublas import matmul_cublas
 from . import resolve
```
159 changes: 159 additions & 0 deletions python/hidet/graph/ops/matmul/matmul_cublas.py
@@ -0,0 +1,159 @@
```python
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List, Optional

import hidet
from hidet.ir.module import IRModule
from hidet.ir.type import DataType
from hidet.ir.expr import Expr, is_true
from hidet.ir.dtypes import f16, f32
from hidet.utils import prod
from hidet.cuda.cublas import cublasComputeType
from ..utils import Task, Operator, Tensor, input_like
from ..utils import TensorInput


class CublasMatmulTask(Task):
    def __init__(self, a: TensorInput, b: TensorInput, compute_type: Optional[Union[int, cublasComputeType]] = None):
        from hidet.ir.compute import cops

        # check
        if a.type.dtype != b.type.dtype:
            raise ValueError('dtype of a and b must be the same, got {} and {}'.format(a.type.dtype, b.type.dtype))

        self.compute_type: cublasComputeType = self.resolve_compute_type(a.type.dtype, a.type.dtype, compute_type)

        c = cops.matmul(a, b, allow_1d=True)
        super().__init__(
            name='cublas_matmul', inputs=[a, b], outputs=[c], attributes={'compute_type': self.compute_type}
        )

    def resolve_compute_type(
        self, in_dtype: DataType, out_dtype: DataType, compute_type: Optional[Union[int, cublasComputeType]]
    ) -> cublasComputeType:
        if compute_type is not None:
            return cublasComputeType(compute_type)
        if in_dtype == out_dtype == f16:
            # use tensor core whenever possible
            return cublasComputeType.CUBLAS_COMPUTE_16F
        elif in_dtype == out_dtype == f32:
            # use tensor core whenever possible
            return cublasComputeType.CUBLAS_COMPUTE_32F
        else:
            raise NotImplementedError(
                'not implemented resolve rules for compute_type with in_dtype={}, out_dtype={}'.format(
                    in_dtype, out_dtype
                )
            )

    def convert_to_strided_gemm(self, a_shape: List[Expr], b_shape: List[Expr], c_shape: List[Expr]):
        a_rank: int = len(a_shape)
        b_rank: int = len(b_shape)

        assert a_rank >= 1 and b_rank >= 1 and (a_rank >= 2 or b_rank >= 2)
        if a_rank == 1:
            bs = prod(b_shape[:-2])
            m = 1
            n = b_shape[-1]
            k = a_shape[0]
            stride_a = 0
            stride_b = b_shape[-2] * b_shape[-1]
            stride_c = c_shape[-2] * c_shape[-1]
        elif b_rank == 1:
            bs = prod(a_shape[:-2])
            m = a_shape[-2]
            n = 1
            k = b_shape[0]
            stride_a = a_shape[-2] * a_shape[-1]
            stride_b = 0
            stride_c = c_shape[-1]
        else:
            if is_true(prod(a_shape[:-2]) == 1):
                bs = prod(b_shape[:-2])
                m = a_shape[-2]
                n = b_shape[-1]
                k = a_shape[-1]
                stride_a = 0
                stride_b = b_shape[-2] * b_shape[-1]
                stride_c = c_shape[-2] * c_shape[-1]
            elif is_true(prod(b_shape[:-2]) == 1):
                bs = prod(a_shape[:-2])
                m = a_shape[-2]
                n = b_shape[-1]
                k = a_shape[-1]
                stride_a = a_shape[-2] * a_shape[-1]
                stride_b = 0
                stride_c = c_shape[-2] * c_shape[-1]
            elif all(is_true(a == b) for a, b in zip(a_shape[:-2], b_shape[:-2])):
                bs = prod(a_shape[:-2])
                m = a_shape[-2]
                n = b_shape[-1]
                k = a_shape[-1]
                stride_a = a_shape[-2] * a_shape[-1]
                stride_b = b_shape[-2] * b_shape[-1]
                stride_c = c_shape[-2] * c_shape[-1]
            else:
                # todo: add cublasGemmBatchedEx to support this case
                # https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmbatchedex
                raise NotImplementedError('Can not convert matmul {} @ {} to strided_gemm'.format(a_shape, b_shape))
        return bs, m, n, k, stride_a, stride_b, stride_c

    def implement_cuda(self, working_dir: str) -> IRModule:
        from hidet.lang import attrs
        from hidet.lang.cuda import cublas

        dtype = self.inputs[0].type.dtype
        c_dtype = self.outputs[0].type.dtype
        a_shape = list(self.inputs[0].type.shape)
        b_shape = list(self.inputs[1].type.shape)
        c_shape = list(self.outputs[0].type.shape)

        with hidet.script_module() as script_module:

            def generate(a: Expr, b: Expr, c: Expr) -> Expr:
                bs, m, n, k, stride_a, stride_b, stride_c = self.convert_to_strided_gemm(a_shape, b_shape, c_shape)
                return cublas.strided_gemm(
                    bs,
                    m,
                    n,
                    k,
                    dtype,
                    dtype,
                    c_dtype,
                    a,
                    b,
                    c,
                    stride_a,
                    stride_b,
                    stride_c,
                    False,
                    False,
                    self.compute_type,
                )

            @hidet.script
            def launch(a: dtype[a_shape], b: dtype[b_shape], c: c_dtype[c_shape]):
                attrs.func_kind = 'public'

                generate(a, b, c)

        return script_module.ir_module()


class CublasMatmulOp(Operator):
    def __init__(self, a: Tensor, b: Tensor):
        task = CublasMatmulTask(input_like(a, 'a'), input_like(b, 'b'))
        super().__init__(inputs=[a, b], attributes={}, task=task)


def matmul_cublas(a: Tensor, b: Tensor) -> Tensor:
    return CublasMatmulOp(a, b).outputs[0]
```
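
When no `compute_type` is passed, `resolve_compute_type` above picks one from the input dtype; a short sketch of the behavior:

```python
x = hidet.randn([64, 64], dtype='float32', device='cuda')
y = hidet.randn([64, 64], dtype='float32', device='cuda')

# f32 inputs resolve to CUBLAS_COMPUTE_32F, f16 inputs to CUBLAS_COMPUTE_16F;
# any other dtype combination currently raises NotImplementedError.
z = hidet.ops.matmul_cublas(x, y)
```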
2 changes: 1 addition & 1 deletion python/hidet/ir/library/cuda/cublas/__init__.py
```diff
@@ -11,5 +11,5 @@
 # limitations under the License.
 from hidet.cuda.cublas.utils import as_type_code
 from hidet.cuda.cublas.kernels import cublasComputeType, cudaDataType
-from .kernels import gemm, bgemm
+from .kernels import gemm, strided_gemm
 from . import regs as _regs  # register functions
```
41 changes: 33 additions & 8 deletions python/hidet/ir/library/cuda/cublas/kernels.py
```diff
@@ -10,44 +10,69 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Union
+from hidet.ir.type import DataType
 from hidet.ir.expr import Expr
 from hidet.ir.primitives.func import call_primitive_func
+from hidet.cuda.cublas.utils import as_type_code
 
 
 def gemm(
     m: Union[Expr, int],
     n: Union[Expr, int],
     k: Union[Expr, int],
-    type_a: Union[Expr, int],
-    type_b: Union[Expr, int],
-    type_c: Union[Expr, int],
+    type_a: Union[Expr, DataType, int],
+    type_b: Union[Expr, DataType, int],
+    type_c: Union[Expr, DataType, int],
     a: Expr,
     b: Expr,
     c: Expr,
     trans_a: Union[Expr, bool],
     trans_b: Union[Expr, bool],
     compute_type: Union[Expr, int],
 ):
+    type_a, type_b, type_c = [as_type_code(t) if isinstance(t, DataType) else t for t in [type_a, type_b, type_c]]
     return call_primitive_func(
         func_name='cublas.gemm', args=[m, n, k, type_a, type_b, type_c, a, b, c, trans_a, trans_b, compute_type]
     )
 
 
-def bgemm(
+def strided_gemm(
     bs: Union[Expr, int],
     m: Union[Expr, int],
     n: Union[Expr, int],
     k: Union[Expr, int],
-    type_a: Union[Expr, int],
-    type_b: Union[Expr, int],
-    type_c: Union[Expr, int],
+    type_a: Union[Expr, DataType, int],
+    type_b: Union[Expr, DataType, int],
+    type_c: Union[Expr, DataType, int],
     a: Expr,
     b: Expr,
     c: Expr,
+    stride_a: Union[Expr, int],
+    stride_b: Union[Expr, int],
+    stride_c: Union[Expr, int],
     trans_a: Union[Expr, bool],
     trans_b: Union[Expr, bool],
     compute_type: Union[Expr, int],
 ):
+    type_a, type_b, type_c = [as_type_code(t) if isinstance(t, DataType) else t for t in [type_a, type_b, type_c]]
     return call_primitive_func(
-        func_name='cublas.bgemm', args=[bs, m, n, k, type_a, type_b, type_c, a, b, c, trans_a, trans_b, compute_type]
+        func_name='cublas.strided_gemm',
+        args=[
+            bs,
+            m,
+            n,
+            k,
+            type_a,
+            type_b,
+            type_c,
+            a,
+            b,
+            c,
+            stride_a,
+            stride_b,
+            stride_c,
+            trans_a,
+            trans_b,
+            compute_type,
+        ],
     )
```
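
These wrappers are what hidet scripts reach through `hidet.lang.cuda.cublas`, as `implement_cuda` in `matmul_cublas.py` does. A minimal sketch with illustrative sizes; the `DataType` arguments are converted to type codes via `as_type_code`:

```python
import hidet
from hidet.ir.dtypes import f16
from hidet.lang import attrs
from hidet.lang.cuda import cublas
from hidet.cuda.cublas import cublasComputeType

with hidet.script_module() as script_module:

    @hidet.script
    def launch(a: f16[8, 32, 64], b: f16[8, 64, 16], c: f16[8, 32, 16]):
        attrs.func_kind = 'public'
        cublas.strided_gemm(
            8, 32, 16, 64,              # bs, m, n, k
            f16, f16, f16,              # type_a, type_b, type_c
            a, b, c,
            32 * 64, 64 * 16, 32 * 16,  # stride_a, stride_b, stride_c
            False, False,               # trans_a, trans_b
            cublasComputeType.CUBLAS_COMPUTE_16F,
        )

module = script_module.ir_module()
```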
9 changes: 6 additions & 3 deletions python/hidet/ir/library/cuda/cublas/regs.py
```diff
@@ -9,7 +9,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from hidet.ir.dtypes import int32, boolean
+from hidet.ir.dtypes import int32, int64, boolean
 from hidet.ir.type import FuncType, void_p, void
 from hidet.ir.primitives.func import register_primitive_function
 from hidet.utils import initialize
@@ -39,7 +39,7 @@ def register_cublas_kernels():
         codegen_name='hidet_cublas_gemm',
     )
     register_primitive_function(
-        name='cublas.bgemm',
+        name='cublas.strided_gemm',
         func_or_type=FuncType(
             param_types=[
                 int32,  # bs
@@ -52,11 +52,14 @@ def register_cublas_kernels():
                 void_p,  # a
                 void_p,  # b
                 void_p,  # c
+                int64,  # stride_a
+                int64,  # stride_b
+                int64,  # stride_c
                 boolean,  # trans_a
                 boolean,  # trans_b
                 int32,  # compute_type (cublasComputeType)
             ],
             ret_type=void,
         ),
-        codegen_name='hidet_cublas_bgemm',
+        codegen_name='hidet_cublas_strided_gemm',
     )
```
7 changes: 6 additions & 1 deletion python/hidet/ir/task.py
```diff
@@ -13,6 +13,7 @@
 from __future__ import annotations
 from typing import Any, Dict, List, Union, Callable, Optional, Tuple
 import os
+import enum
 import pickle
 from hidet.ir.node import Node
 from hidet.ir.type import FuncType, VoidType
@@ -147,7 +148,11 @@ def signature(self) -> str:
             dtype = tensor.type.dtype.name
             params.append('{}={}{}'.format(name, dtype, tensor.type.shape))
         for name, value in self.attrs.items():
-            params.append('{}={}'.format(name, repr(value)))
+            if isinstance(value, enum.Enum):
+                value_str = value.name
+            else:
+                value_str = repr(value)
+            params.append('{}={}'.format(name, value_str))
         param_doc = ', '.join(params)
         fuse_doc = ''
         return ''.join([self.name, '(', param_doc, ')', fuse_doc])
```
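
A before/after sketch of the effect on `Task.signature()` for the new cublas matmul task (the exact old form depends on the enum's `repr`):

```python
# before: cublas_matmul(a=float16[1024, 1024], b=float16[1024, 1024],
#                       compute_type=<cublasComputeType.CUBLAS_COMPUTE_16F: 64>)
# after:  cublas_matmul(a=float16[1024, 1024], b=float16[1024, 1024],
#                       compute_type=CUBLAS_COMPUTE_16F)
```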
