-
Notifications
You must be signed in to change notification settings - Fork 15k
[mlir][python] Add pythonic interface for GPUFuncOp #163596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][python] Add pythonic interface for GPUFuncOp #163596
Conversation
The func dialect provides a more pythonic interface for constructing operations, but the gpu dialect does not; this is the first PR to provide the same conveniences for the gpu dialect, starting with the gpu.func op.
|
@llvm/pr-subscribers-mlir Author: Asher Mancinelli (ashermancinelli) Changes: The func dialect provides a more pythonic interface for constructing operations, but the gpu dialect does not; this is the first PR to provide the same conveniences for the gpu dialect, starting with the gpu.func op. Full diff: https://github.com/llvm/llvm-project/pull/163596.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca8..14b965927e280 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -3,5 +3,121 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._gpu_ops_gen import *
+from .._gpu_ops_gen import _Dialect
from .._gpu_enum_gen import *
from ..._mlir_libs._mlirDialectsGPU import *
+from typing import Callable, Sequence, Union, Optional
+
+try:
+ from ...ir import (
+ FunctionType,
+ TypeAttr,
+ StringAttr,
+ UnitAttr,
+ Block,
+ InsertionPoint,
+ ArrayAttr,
+ Type,
+ DictAttr,
+ Attribute,
+ )
+ from .._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
+KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
+SYM_NAME_ATTRIBUTE_NAME = "sym_name"
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class GPUFuncOp(GPUFuncOp):
    """Pythonic builder for the gpu.func operation.

    Extends the generated GPUFuncOp with the same conveniences the func
    dialect offers: the constructor accepts a plain FunctionType, an
    optional symbol name, a `kernel` flag, and a body-builder callback,
    and properties expose the signature and common attributes.
    """

    def __init__(
        self,
        function_type: Union[FunctionType, TypeAttr],
        sym_name: Optional[Union[str, StringAttr]] = None,
        kernel: Optional[bool] = None,
        body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
        *args,
        loc=None,
        ip=None,
        **kwargs,
    ):
        """Create a gpu.func op.

        Args:
            function_type: the function signature, either a FunctionType
                or a TypeAttr already wrapping one.
            sym_name: optional symbol name, stored in the `sym_name`
                string attribute.
            kernel: when truthy, attach the `gpu.kernel` unit attribute
                marking this function as a kernel entry point.
            body_builder: optional callback invoked with this op while
                the insertion point is inside a freshly created entry
                block.
        """
        # Accept a bare FunctionType for convenience; the underlying op
        # expects a TypeAttr.
        if not isinstance(function_type, TypeAttr):
            function_type = TypeAttr.get(function_type)
        super().__init__(function_type, *args, loc=loc, ip=ip, **kwargs)
        if sym_name is not None:
            self.attributes[SYM_NAME_ATTRIBUTE_NAME] = StringAttr.get(str(sym_name))
        if kernel:
            self.attributes[KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
        if body_builder is not None:
            with InsertionPoint(self.add_entry_block()):
                body_builder(self)

    @property
    def type(self) -> FunctionType:
        """The function signature as a FunctionType."""
        return FunctionType(
            TypeAttr(self.attributes[FUNCTION_TYPE_ATTRIBUTE_NAME]).value
        )

    @property
    def name(self) -> StringAttr:
        """The symbol name (`sym_name`) attribute of the function."""
        return StringAttr(self.attributes[SYM_NAME_ATTRIBUTE_NAME])

    @property
    def is_kernel(self) -> bool:
        """True when the function carries the `gpu.kernel` attribute."""
        return KERNEL_ATTRIBUTE_NAME in self.attributes

    def add_entry_block(self) -> Block:
        """Append and return an entry block matching the signature.

        Each block argument mirrors one function input and is given this
        op's location.
        """
        function_type = self.type
        return self.body.blocks.append(
            *function_type.inputs,
            arg_locs=[self.location for _ in function_type.inputs],
        )

    @property
    def entry_block(self) -> Block:
        """The first block of the body region."""
        return self.body.blocks[0]

    @property
    def arguments(self) -> Sequence[Type]:
        """The input types of the function signature."""
        return self.type.inputs

    @property
    def arg_attrs(self) -> ArrayAttr:
        """Argument attributes, one DictAttr per function input.

        When `arg_attrs` is unset, returns a fresh array of empty
        dictionaries; the fresh array is not stored on the op.
        """
        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    @arg_attrs.setter
    def arg_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
        if isinstance(attribute, ArrayAttr):
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )

    @property
    def result_attrs(self) -> ArrayAttr:
        """Result attributes, one DictAttr per function result.

        When `res_attrs` is unset, returns a fresh array of empty
        dictionaries; the fresh array is not stored on the op.
        """
        if RESULT_ATTRIBUTE_NAME not in self.attributes:
            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.results])
        # Wrap the stored attribute in ArrayAttr so the return type is
        # consistent with arg_attrs (a raw attributes[] lookup yields a
        # generic Attribute).
        return ArrayAttr(self.attributes[RESULT_ATTRIBUTE_NAME])

    @result_attrs.setter
    def result_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
        if isinstance(attribute, ArrayAttr):
            self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 26ee9f34cb332..ce6e3df634e90 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -4,6 +4,7 @@
import mlir.dialects.gpu as gpu
import mlir.dialects.gpu.passes
from mlir.passmanager import *
+import mlir.ir as ir
def run(f):
@@ -64,3 +65,65 @@ def testObjectAttr():
# CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
print(o)
assert o.kernels == kernelTable
+
+
+# CHECK-LABEL: testGPUFuncOp
+@run
+def testGPUFuncOp():
+ # Exercises both construction paths of gpu.GPUFuncOp: the manual
+ # attribute/block setup (kernel0) and the pythonic constructor that
+ # accepts sym_name/kernel/body_builder (kernel1, non_kernel_func).
+ module = Module.create()
+ with InsertionPoint(module.body):
+ gpu_module_name = StringAttr.get("gpu_module")
+ gpumodule = gpu.GPUModuleOp(gpu_module_name)
+ block = gpumodule.bodyRegion.blocks.append()
+
+ # Shared body: emit a global_id op and a terminating gpu.return.
+ def builder(func: gpu.GPUFuncOp) -> None:
+ _ = gpu.GlobalIdOp(gpu.Dimension.x)
+ _ = gpu.ReturnOp([])
+
+ # kernel0: built by hand — attributes and entry block set explicitly.
+ with InsertionPoint(block):
+ name = StringAttr.get("kernel0")
+ func_type = ir.FunctionType.get(inputs=[], results=[])
+ type_attr = TypeAttr.get(func_type)
+ func = gpu.GPUFuncOp(type_attr, name)
+ func.attributes[gpu.SYM_NAME_ATTRIBUTE_NAME] = name
+ func.attributes[gpu.KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
+ block = func.body.blocks.append()
+ with InsertionPoint(block):
+ builder(func)
+
+ # kernel1: built via the pythonic constructor — TypeAttr wrapping,
+ # sym_name, gpu.kernel attribute and entry block are handled for us.
+ func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="kernel1",
+ kernel=True,
+ body_builder=builder,
+ )
+
+ assert func.name.value == "kernel1"
+ assert func.arg_attrs == ArrayAttr.get([])
+ assert func.result_attrs == ArrayAttr.get([])
+ assert func.arguments == []
+ assert func.entry_block == func.body.blocks[0]
+ assert func.is_kernel
+
+ # Omitting kernel=True must not attach the gpu.kernel attribute.
+ non_kernel_func = gpu.GPUFuncOp(
+ func_type,
+ sym_name="non_kernel_func",
+ body_builder=builder,
+ )
+ assert not non_kernel_func.is_kernel
+
+ print(module)
+
+ # CHECK: gpu.module @gpu_module
+ # CHECK: gpu.func @kernel0() kernel {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @kernel1() kernel {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
+ # CHECK: gpu.func @non_kernel_func() {
+ # CHECK: %[[VAL_0:.*]] = gpu.global_id x
+ # CHECK: gpu.return
+ # CHECK: }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Maybe let's also wait @makslevental
a308dfc to
380a694
Compare
|
✅ With the latest revision this PR passed the Python code formatter. |
|
So in python-extras I have both of these, which I would have liked to upstream.
|
I'm quite happy to help with that, though your FuncBase seems to do quite a lot that doesn't necessarily help with the use cases that motivate this PR; would you like to chat about upstreaming some of that stuff in a future PR elsewhere (discord)? |
yea you're 100% right which also why i hadn't/haven't upstreamed yet - yea hit me up on discord under this same username if you wanna chat |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for the PR!
The func dialect provides a more pythonic interface for constructing operations, but the gpu dialect does not; this is the first PR to provide the same conveniences for the gpu dialect, starting with the gpu.func op.