Skip to content

Conversation

@ashermancinelli
Copy link
Contributor

The func dialect provides a more pythonic interface for constructing operations, but the gpu dialect does not; this is the first PR to provide the same conveniences for the gpu dialect, starting with the gpu.func op.

The func dialect provides a more pythonic interface for constructing
operations, but the gpu dialect does not; this is the first PR to
provide the same conveniences for the gpu dialect, starting with the
gpu.func op.
@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir

Author: Asher Mancinelli (ashermancinelli)

Changes

The func dialect provides a more pythonic interface for constructing operations, but the gpu dialect does not; this is the first PR to provide the same conveniences for the gpu dialect, starting with the gpu.func op.


Full diff: https://github.com/llvm/llvm-project/pull/163596.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/gpu/init.py (+116)
  • (modified) mlir/test/python/dialects/gpu/dialect.py (+63)
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca8..14b965927e280 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -3,5 +3,121 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .._gpu_ops_gen import *
+from .._gpu_ops_gen import _Dialect
 from .._gpu_enum_gen import *
 from ..._mlir_libs._mlirDialectsGPU import *
+from typing import Callable, Sequence, Union, Optional
+
+try:
+    from ...ir import (
+        FunctionType,
+        TypeAttr,
+        StringAttr,
+        UnitAttr,
+        Block,
+        InsertionPoint,
+        ArrayAttr,
+        Type,
+        DictAttr,
+        Attribute,
+    )
+    from .._ods_common import (
+        get_default_loc_context as _get_default_loc_context,
+        _cext as _ods_cext,
+    )
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+
+FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
+KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
+SYM_NAME_ATTRIBUTE_NAME = "sym_name"
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GPUFuncOp(GPUFuncOp):
+    def __init__(
+        self,
+        function_type: Union[FunctionType, TypeAttr],
+        sym_name: Optional[Union[str, StringAttr]] = None,
+        kernel: Optional[bool] = None,
+        body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
+        *args,
+        loc=None,
+        ip=None,
+        **kwargs,
+    ):
+        function_type = (
+            TypeAttr.get(function_type)
+            if not isinstance(function_type, TypeAttr)
+            else function_type
+        )
+        super().__init__(function_type, *args, loc=loc, ip=ip, **kwargs)
+        if sym_name is not None:
+            self.attributes[SYM_NAME_ATTRIBUTE_NAME] = StringAttr.get(str(sym_name))
+        if kernel:
+            self.attributes[KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
+        if body_builder is not None:
+            with InsertionPoint(self.add_entry_block()):
+                body_builder(self)
+
+    @property
+    def type(self) -> FunctionType:
+        return FunctionType(
+            TypeAttr(self.attributes[FUNCTION_TYPE_ATTRIBUTE_NAME]).value
+        )
+
+    @property
+    def name(self) -> StringAttr:
+        return StringAttr(self.attributes[SYM_NAME_ATTRIBUTE_NAME])
+
+    @property
+    def is_kernel(self) -> bool:
+        return KERNEL_ATTRIBUTE_NAME in self.attributes
+
+    def add_entry_block(self) -> Block:
+        function_type = self.type
+        return self.body.blocks.append(
+            *function_type.inputs,
+            arg_locs=[self.location for _ in function_type.inputs],
+        )
+
+    @property
+    def entry_block(self) -> Block:
+        return self.body.blocks[0]
+
+    @property
+    def arguments(self) -> Sequence[Type]:
+        return self.type.inputs
+
+    @property
+    def arg_attrs(self):
+        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
+            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
+        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+    @arg_attrs.setter
+    def arg_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+        else:
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
+
+    @property
+    def result_attrs(self) -> Optional[ArrayAttr]:
+        if RESULT_ATTRIBUTE_NAME not in self.attributes:
+            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.results])
+        return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+    @result_attrs.setter
+    def result_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+        else:
+            self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 26ee9f34cb332..ce6e3df634e90 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -4,6 +4,7 @@
 import mlir.dialects.gpu as gpu
 import mlir.dialects.gpu.passes
 from mlir.passmanager import *
+import mlir.ir as ir
 
 
 def run(f):
@@ -64,3 +65,65 @@ def testObjectAttr():
     # CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
     print(o)
     assert o.kernels == kernelTable
+
+
+# CHECK-LABEL: testGPUFuncOp
+@run
+def testGPUFuncOp():
+    module = Module.create()
+    with InsertionPoint(module.body):
+        gpu_module_name = StringAttr.get("gpu_module")
+        gpumodule = gpu.GPUModuleOp(gpu_module_name)
+        block = gpumodule.bodyRegion.blocks.append()
+
+        def builder(func: gpu.GPUFuncOp) -> None:
+            _ = gpu.GlobalIdOp(gpu.Dimension.x)
+            _ = gpu.ReturnOp([])
+
+        with InsertionPoint(block):
+            name = StringAttr.get("kernel0")
+            func_type = ir.FunctionType.get(inputs=[], results=[])
+            type_attr = TypeAttr.get(func_type)
+            func = gpu.GPUFuncOp(type_attr, name)
+            func.attributes[gpu.SYM_NAME_ATTRIBUTE_NAME] = name
+            func.attributes[gpu.KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
+            block = func.body.blocks.append()
+            with InsertionPoint(block):
+                builder(func)
+
+            func = gpu.GPUFuncOp(
+                func_type,
+                sym_name="kernel1",
+                kernel=True,
+                body_builder=builder,
+            )
+
+            assert func.name.value == "kernel1"
+            assert func.arg_attrs == ArrayAttr.get([])
+            assert func.result_attrs == ArrayAttr.get([])
+            assert func.arguments == []
+            assert func.entry_block == func.body.blocks[0]
+            assert func.is_kernel
+
+            non_kernel_func = gpu.GPUFuncOp(
+                func_type,
+                sym_name="non_kernel_func",
+                body_builder=builder,
+            )
+            assert not non_kernel_func.is_kernel
+
+    print(module)
+
+    # CHECK: gpu.module @gpu_module
+    # CHECK: gpu.func @kernel0() kernel {
+    # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
+    # CHECK:   gpu.return
+    # CHECK: }
+    # CHECK: gpu.func @kernel1() kernel {
+    # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
+    # CHECK:   gpu.return
+    # CHECK: }
+    # CHECK: gpu.func @non_kernel_func() {
+    # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
+    # CHECK:   gpu.return
+    # CHECK: }

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Maybe let's also wait @makslevental

@ashermancinelli ashermancinelli force-pushed the ajm/gpu-func-py-wrapper branch from a308dfc to 380a694 Compare October 15, 2025 17:41
@github-actions
Copy link

github-actions bot commented Oct 15, 2025

✅ With the latest revision this PR passed the Python code formatter.

@makslevental
Copy link
Contributor

So in python-extras I have both gpu.func and func.func supported using the same base class FuncBase.

I would have liked to upstream FuncBase because lots of people end up defining their own dialect.func ops (for whatever reasons...) but I didn't get around to it and currently I'm waiting on legal at $JOB to approve me upstreaming that code here........... But I can say that

  1. This PR doesn't actually contravene such a layering as I currently have in python-extras (because it's simply providing a custom builder);
  2. I'd be more than happy if you took my code in python-extras and copy-pasted it into a PR here where I could help adjust it for upstreaming 😉 (but then again, no pressure to use my formulation).

@ashermancinelli
Copy link
Contributor Author

@makslevental I'd be more than happy if you took my code in python-extras

I'm quite happy to help with that, though your FuncBase seems to do quite a lot that doesn't necessarily help with the use cases that motivate this PR; would you like to chat about upstreaming some of that stuff in a future PR elsewhere (discord)?

@makslevental
Copy link
Contributor

I'm quite happy to help with that, though your FuncBase seems to do quite a lot that doesn't necessarily help with the use cases that motivate this PR; would you like to chat about upstreaming some of that stuff in a future PR elsewhere (discord)?

yea you're 100% right which also why i hadn't/haven't upstreamed yet - yea hit me up on discord under this same username if you wanna chat

@ashermancinelli ashermancinelli self-assigned this Oct 15, 2025
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the PR!

@ashermancinelli ashermancinelli merged commit fbdd98f into llvm:main Oct 16, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants