[Operator] Add an opaque operator base class #414

Merged 2 commits on Jan 11, 2024
122 changes: 122 additions & 0 deletions python/hidet/graph/ops/opaque.py
@@ -0,0 +1,122 @@
"""
An opaque operator is an operator that does not provide a computation definition, but instead uses a unique
name to identify its computation. Opaque operators represent operators whose computation definition is hard
to express, or too tedious to write out.
"""
from typing import List, Dict, Any, Optional, Union, Sequence
from hidet.graph.tensor import symbol
from .utils import Tensor, Task, Operator, IRModule, Expr, input_like


class OpaqueTask(Task):
def __init__(self, name: str, inputs, outputs, op):
super().__init__(name=name, inputs=inputs, outputs=outputs, attributes={'is_opaque': True})
self.op: OpaqueOperator = op

def allow_prologue(self) -> bool:
return self.op.allow_prologue()

def allow_epilogue(self) -> bool:
        return self.op.allow_epilogue()

def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
return self.op.implement_cuda(self.op.inputs, self.op.outputs)

def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
return self.op.implement_cpu(self.op.inputs, self.op.outputs)


class OpaqueOperator(Operator):
def __init__(self, name: str, inputs: Dict[str, Tensor], attributes: Optional[Dict[str, Any]] = None):
symbol_outputs: Dict[str, Tensor] = self.symbolic_forward(**inputs)
super().__init__(
inputs=list(inputs.values()),
attributes=attributes if attributes is not None else {},
task=OpaqueTask(
name=name,
inputs=[input_like(tensor, name) for name, tensor in inputs.items()],
outputs=[input_like(tensor, name) for name, tensor in symbol_outputs.items()],
op=self,
),
)

def symbol(self, shape: Sequence[Union[int, str, Expr]], dtype='float32', device='cpu'):
return symbol(shape, dtype, device)

def allow_prologue(self):
"""
        Whether to allow a prologue to be fused into this operator by the prologue_epilogue_fusion pass.

Returns
-------
ret: bool
True if allow prologue, False otherwise.
"""
return False

def allow_epilogue(self):
"""
        Whether to allow an epilogue to be fused into this operator by the prologue_epilogue_fusion pass.

Returns
-------
ret: bool
True if allow epilogue, False otherwise.
"""
return False

def symbolic_forward(self, **args: Tensor) -> Dict[str, Tensor]:
"""
Infer the dtype and shape of the output tensors given the input tensors.

Parameters
----------
        **args: Tensor
            The named input tensors, passed as keyword arguments.

Returns
-------
ret: Dict[str, Tensor]
The output tensors.
"""
raise NotImplementedError()

def implement_cuda(self, inputs: List[Tensor], outputs: List[Tensor]) -> Union[IRModule, List[IRModule]]:
"""
Implement this operator on CUDA.

Parameters
----------
inputs: List[Tensor]
The input tensors.
outputs: List[Tensor]
The output tensors.

Returns
-------
ret: Union[IRModule, List[IRModule]]
The IRModule or a list of IRModules that implement this operator. When multiple IRModules are returned,
they must have the same functionality and hidet will pick the most performant one to use.
"""
raise NotImplementedError('Opaque operator {} does not have CUDA implementation'.format(self.name))

def implement_cpu(self, inputs: List[Tensor], outputs: List[Tensor]) -> Union[IRModule, List[IRModule]]:
"""
Implement this operator on CPU.

Parameters
----------
inputs: List[Tensor]
            The input tensors.
        outputs: List[Tensor]
The output tensors.

Returns
-------
ret: Union[IRModule, List[IRModule]]
The IRModule or a list of IRModules that implement this operator. When multiple IRModules are returned,
they must have the same functionality and hidet will pick the most performant one to use.
"""
raise NotImplementedError('Opaque operator {} does not have CPU implementation'.format(self.name))
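
As a quick illustration of the fusion hooks above: a subclass opts into fusion by overriding allow_prologue/allow_epilogue. A minimal sketch (the class name is illustrative, not part of this PR):

    class MyFusableOp(OpaqueOperator):
        def allow_epilogue(self) -> bool:
            # Opt in: let the prologue_epilogue_fusion pass fuse
            # element-wise consumers into this operator's kernel.
            return True
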
2 changes: 1 addition & 1 deletion python/hidet/ir/task.py
@@ -128,7 +128,7 @@ def _sanity_check(self):

         # check all TensorInput used in outputs are placed in inputs
         used_inputs = collect(self.outputs, TensorInput)
-        if any(x not in self.inputs for x in used_inputs):
+        if any(x not in self.inputs + self.outputs for x in used_inputs):
             raise ValueError('Some TensorInput used in outputs are not placed in inputs: {}'.format(used_inputs))

         # check assertions for correctness
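
A note on this one-line relaxation (an inference from the new OpaqueTask above, not stated explicitly in the PR): OpaqueTask builds its task outputs with input_like, so each output placeholder is itself a TensorInput node, which the old check would reject as a TensorInput used in outputs but missing from inputs. A sketch of the situation, under that assumption:

    from hidet.graph.ops.utils import input_like
    import hidet

    x = hidet.symbol([4, 4], dtype='float32')  # symbolic graph-level tensor
    z = input_like(x, 'z')                     # a TensorInput used as a task *output*
    # Old check: rejects z, a TensorInput not listed among the task inputs.
    # New check: also accepts TensorInputs that are themselves task outputs.
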
63 changes: 63 additions & 0 deletions tests/operators/test_opaque.py
@@ -0,0 +1,63 @@
from typing import List, Union
import pytest
import hidet
from hidet import Tensor
from hidet.graph.ops.opaque import OpaqueOperator
from hidet.ir.dtypes import float32
from hidet.ir import IRModule

hidet.option.cache_dir('./outs/cache')


class OpaqueMatmul(OpaqueOperator):
def __init__(self, x: Tensor, y: Tensor):
super().__init__(name='matmul', inputs={'x': x, 'y': y})

def symbolic_forward(self, x: Tensor, y: Tensor):
assert x.dtype == y.dtype == float32
assert x.device.is_cuda()
m, k = x.shape
k, n = y.shape
return {'z': self.symbol(shape=[m, n], dtype=x.dtype, device=x.device)}

def implement_cuda(self, inputs: List[Tensor], outputs: List[Tensor]) -> Union[IRModule, List[IRModule]]:
import hidet
from hidet.lang import attrs
from hidet.lang.types import f32
from hidet.lang.cuda import threadIdx, blockIdx

m_size, k_size = inputs[0].shape
k_size, n_size = inputs[1].shape

with hidet.script_module() as script_module:

@hidet.script
def matmul(x: f32[m_size, k_size], y: f32[k_size, n_size], z: f32[m_size, n_size]):
attrs.func_kind = 'cuda_kernel'
attrs.cuda.block_dim = (32, 32)
attrs.cuda.grid_dim = ((n_size + 31) // 32, (m_size + 31) // 32)

i = threadIdx.x + blockIdx.x * 32
j = threadIdx.y + blockIdx.y * 32
if i < n_size and j < m_size:
z[j, i] = 0.0
for k in range(k_size):
z[j, i] += x[j, k] * y[k, i]

return script_module.ir_module()


def opaque_matmul(x: Tensor, y: Tensor) -> Tensor:
return OpaqueMatmul(x, y).outputs[0]


def test_opaque_operator():
a = hidet.randn([128, 128], dtype='float32', device='cuda')
b = hidet.randn([128, 128], dtype='float32', device='cuda')
c1 = opaque_matmul(a, b)
c2 = a @ b
hidet.utils.assert_close(c1, c2)


if __name__ == '__main__':
pytest.main([__file__])
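
The test exercises eager execution; an opaque operator can also be traced into a FlowGraph like any other operator. A minimal sketch, assuming the same OpaqueMatmul and a CUDA device (hidet.symbol and hidet.trace_from are hidet's standard tracing entry points):

    a = hidet.symbol([128, 128], dtype='float32', device='cuda')
    b = hidet.symbol([128, 128], dtype='float32', device='cuda')
    c = opaque_matmul(a, b)
    graph = hidet.trace_from(c, inputs=[a, b])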