Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Operator] Add an opaque operator base class (#414)
When we do not want to give the computation for some operator, because it is too tedious or cannot be expressed using our computation definition DSL, we can define an opaque operator that only gives: 1. the dtype and shape inference function that infers the output dtype and shape given the inputs'; 2. the implement function that implements the operator given the input/output dtype and shape. An example of defining an opaque operator that performs matrix multiplication: ```python from typing import List, Union import hidet from hidet import Tensor from hidet.graph.ops.opaque import OpaqueOperator from hidet.ir.dtypes import float32 from hidet.ir import IRModule hidet.option.cache_dir('./outs/cache') class OpaqueMatmul(OpaqueOperator): def __init__(self, x: Tensor, y: Tensor): super().__init__( name='matmul', inputs={ 'x': x, 'y': y }, ) def symbolic_forward(self, x: Tensor, y: Tensor): assert x.dtype == y.dtype == float32 assert x.device.is_cuda() m, k = x.shape k, n = y.shape return { 'z': self.symbol( shape=[m, n], dtype=x.dtype, device=x.device ) } def implement_cuda(self, inputs: List[Tensor], outputs: List[Tensor]) -> Union[IRModule, List[IRModule]]: import hidet from hidet.lang import attrs from hidet.lang.types import f32 from hidet.lang.cuda import threadIdx, blockIdx m_size, k_size = inputs[0].shape k_size, n_size = inputs[1].shape with hidet.script_module() as script_module: @hidet.script def matmul(x: f32[m_size, k_size], y: f32[k_size, n_size], z: f32[m_size, n_size]): attrs.func_kind = 'cuda_kernel' attrs.cuda.block_dim = (32, 32) attrs.cuda.grid_dim = ((n_size + 31) // 32, (m_size + 31) // 32) i = threadIdx.x + blockIdx.x * 32 j = threadIdx.y + blockIdx.y * 32 if i < n_size and j < m_size: z[j, i] = 0.0 for k in range(k_size): z[j, i] += x[j, k] * y[k, i] return script_module.ir_module() def opaque_matmul(x: Tensor, y: Tensor) -> Tensor: return OpaqueMatmul(x, y).outputs[0] def test_opaque_operator(): a = hidet.randn([128, 128], dtype='float32', device='cuda') b = 
hidet.randn([128, 128], dtype='float32', device='cuda') c1 = opaque_matmul(a, b) c2 = a @ b print(hidet.ops.max(hidet.ops.abs(c1 - c2), dims=[0, 1])) ```
- Loading branch information