-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[mlir][amdgpu] Add Python bindings for TDM types #172309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
@llvm/pr-subscribers-mlir Author: Tim Gymnich (tgymnich) ChangesAdd bindings for:
Full diff: https://github.com/llvm/llvm-project/pull/172309.diff 6 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/AMDGPU.h b/mlir/include/mlir-c/Dialect/AMDGPU.h
index 142044f7f3afe..950dca3f2fa1c 100644
--- a/mlir/include/mlir-c/Dialect/AMDGPU.h
+++ b/mlir/include/mlir-c/Dialect/AMDGPU.h
@@ -18,6 +18,33 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu);
+//===---------------------------------------------------------------------===//
+// TDMBaseType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx,
+ MlirType elementType);
+
+//===---------------------------------------------------------------------===//
+// TDMDescriptorType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx);
+
+//===---------------------------------------------------------------------===//
+// TDMGatherBaseType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx,
+ MlirType elementType,
+ MlirType indexType);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
new file mode 100644
index 0000000000000..d593513427b3e
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -0,0 +1,62 @@
+//===--- DialectAMDGPU.cpp - Pybind module for AMDGPU dialect API support -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/AMDGPU.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+
+namespace nb = nanobind;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
+ auto amdgpuTDMBaseType =
+ mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType);
+
+ amdgpuTDMBaseType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
+ return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
+ },
+ "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
+ nb::arg("element_type"), nb::arg("ctx") = nb::none());
+
+ auto amdgpuTDMDescriptorType = mlir_type_subclass(
+ m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType);
+
+ amdgpuTDMDescriptorType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
+ },
+ "Gets an instance of TDMDescriptorType in the same context",
+ nb::arg("cls"), nb::arg("ctx") = nb::none());
+
+ auto amdgpuTDMGatherBaseType = mlir_type_subclass(
+ m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType);
+
+ amdgpuTDMGatherBaseType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirType elementType, MlirType indexType,
+ MlirContext ctx) {
+ return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
+ },
+ "Gets an instance of TDMGatherBaseType in the same context",
+ nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
+ nb::arg("ctx") = nb::none());
+};
+
+NB_MODULE(_mlirDialectsAMDGPU, m) {
+ m.doc() = "MLIR AMDGPU dialect.";
+
+ populateDialectAMDGPUSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
index d877ca2dff375..77536e822c0ac 100644
--- a/mlir/lib/CAPI/Dialect/AMDGPU.cpp
+++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
@@ -12,3 +12,44 @@
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu,
mlir::amdgpu::AMDGPUDialect)
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+//===---------------------------------------------------------------------===//
+// TDMBaseType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type) {
+ return isa<amdgpu::TDMBaseType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) {
+ return wrap(amdgpu::TDMBaseType::get(unwrap(ctx), unwrap(elementType)));
+}
+
+//===---------------------------------------------------------------------===//
+// TDMDescriptorType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type) {
+ return isa<amdgpu::TDMDescriptorType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx) {
+ return wrap(amdgpu::TDMDescriptorType::get(unwrap(ctx)));
+}
+
+//===---------------------------------------------------------------------===//
+// TDMGatherBaseType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type) {
+ return isa<amdgpu::TDMGatherBaseType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType,
+ MlirType indexType) {
+ return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx), unwrap(elementType),
+ unwrap(indexType)));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..6e449e275f782 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -804,6 +804,21 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
MLIRCAPITransformDialectTransforms
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Pybind
+ MODULE_NAME _mlirDialectsAMDGPU
+ ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
+ SOURCES
+ DialectAMDGPU.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPIAMDGPU
+)
+
+
# TODO: Figure out how to put this in the test tree.
# This should not be included in the main Python extension. However,
# putting it into MLIRPythonTestSources along with the dialect declaration
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481c..1c4d274bc31af 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -4,3 +4,4 @@
from ._amdgpu_ops_gen import *
from ._amdgpu_enum_gen import *
+from .._mlir_libs._mlirDialectsAMDGPU import *
diff --git a/mlir/test/python/dialects/amdgpu.py b/mlir/test/python/dialects/amdgpu.py
index b479576dac093..c126a6d201eb0 100644
--- a/mlir/test/python/dialects/amdgpu.py
+++ b/mlir/test/python/dialects/amdgpu.py
@@ -2,7 +2,7 @@
# This is just a smoke test that the dialect is functional.
from mlir.ir import *
-from mlir.dialects import amdgpu, func
+from mlir.dialects import amdgpu, func, memref
def constructAndPrintInModule(f):
@@ -43,3 +43,22 @@ def testFatRawBufferCastOpParams():
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+
+
+# CHECK-LABEL: testTDMTypes
+@constructAndPrintInModule
+def testTDMTypes():
+ f32 = F32Type.get()
+ i32 = IntegerType.get_signless(32)
+
+ # CHECK: !amdgpu.tdm_base<f32>
+ tdm_base = amdgpu.TDMBaseType.get(f32)
+ print(tdm_base)
+
+ # CHECK: !amdgpu.tdm_descriptor
+ tdm_descriptor = amdgpu.TDMDescriptorType.get()
+ print(tdm_descriptor)
+
+ # CHECK: !amdgpu.tdm_gather_base<f32, i32>`
+ tdm_gather_base = amdgpu.TDMGatherBaseType.get(f32, i32)
+ print(tdm_gather_base)
|
|
@llvm/pr-subscribers-backend-amdgpu Author: Tim Gymnich (tgymnich) ChangesAdd bindings for:
Full diff: https://github.com/llvm/llvm-project/pull/172309.diff 6 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/AMDGPU.h b/mlir/include/mlir-c/Dialect/AMDGPU.h
index 142044f7f3afe..950dca3f2fa1c 100644
--- a/mlir/include/mlir-c/Dialect/AMDGPU.h
+++ b/mlir/include/mlir-c/Dialect/AMDGPU.h
@@ -18,6 +18,33 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu);
+//===---------------------------------------------------------------------===//
+// TDMBaseType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx,
+ MlirType elementType);
+
+//===---------------------------------------------------------------------===//
+// TDMDescriptorType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx);
+
+//===---------------------------------------------------------------------===//
+// TDMGatherBaseType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx,
+ MlirType elementType,
+ MlirType indexType);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
new file mode 100644
index 0000000000000..d593513427b3e
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -0,0 +1,62 @@
+//===--- DialectAMDGPU.cpp - Pybind module for AMDGPU dialect API support -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/AMDGPU.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+
+namespace nb = nanobind;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
+ auto amdgpuTDMBaseType =
+ mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType);
+
+ amdgpuTDMBaseType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
+ return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
+ },
+ "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
+ nb::arg("element_type"), nb::arg("ctx") = nb::none());
+
+ auto amdgpuTDMDescriptorType = mlir_type_subclass(
+ m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType);
+
+ amdgpuTDMDescriptorType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
+ },
+ "Gets an instance of TDMDescriptorType in the same context",
+ nb::arg("cls"), nb::arg("ctx") = nb::none());
+
+ auto amdgpuTDMGatherBaseType = mlir_type_subclass(
+ m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType);
+
+ amdgpuTDMGatherBaseType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirType elementType, MlirType indexType,
+ MlirContext ctx) {
+ return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
+ },
+ "Gets an instance of TDMGatherBaseType in the same context",
+ nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
+ nb::arg("ctx") = nb::none());
+};
+
+NB_MODULE(_mlirDialectsAMDGPU, m) {
+ m.doc() = "MLIR AMDGPU dialect.";
+
+ populateDialectAMDGPUSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
index d877ca2dff375..77536e822c0ac 100644
--- a/mlir/lib/CAPI/Dialect/AMDGPU.cpp
+++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
@@ -12,3 +12,44 @@
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu,
mlir::amdgpu::AMDGPUDialect)
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+//===---------------------------------------------------------------------===//
+// TDMBaseType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type) {
+ return isa<amdgpu::TDMBaseType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) {
+ return wrap(amdgpu::TDMBaseType::get(unwrap(ctx), unwrap(elementType)));
+}
+
+//===---------------------------------------------------------------------===//
+// TDMDescriptorType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type) {
+ return isa<amdgpu::TDMDescriptorType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx) {
+ return wrap(amdgpu::TDMDescriptorType::get(unwrap(ctx)));
+}
+
+//===---------------------------------------------------------------------===//
+// TDMGatherBaseType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type) {
+ return isa<amdgpu::TDMGatherBaseType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType,
+ MlirType indexType) {
+ return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx), unwrap(elementType),
+ unwrap(indexType)));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..6e449e275f782 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -804,6 +804,21 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
MLIRCAPITransformDialectTransforms
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Pybind
+ MODULE_NAME _mlirDialectsAMDGPU
+ ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
+ SOURCES
+ DialectAMDGPU.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPIAMDGPU
+)
+
+
# TODO: Figure out how to put this in the test tree.
# This should not be included in the main Python extension. However,
# putting it into MLIRPythonTestSources along with the dialect declaration
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481c..1c4d274bc31af 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -4,3 +4,4 @@
from ._amdgpu_ops_gen import *
from ._amdgpu_enum_gen import *
+from .._mlir_libs._mlirDialectsAMDGPU import *
diff --git a/mlir/test/python/dialects/amdgpu.py b/mlir/test/python/dialects/amdgpu.py
index b479576dac093..c126a6d201eb0 100644
--- a/mlir/test/python/dialects/amdgpu.py
+++ b/mlir/test/python/dialects/amdgpu.py
@@ -2,7 +2,7 @@
# This is just a smoke test that the dialect is functional.
from mlir.ir import *
-from mlir.dialects import amdgpu, func
+from mlir.dialects import amdgpu, func, memref
def constructAndPrintInModule(f):
@@ -43,3 +43,22 @@ def testFatRawBufferCastOpParams():
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+
+
+# CHECK-LABEL: testTDMTypes
+@constructAndPrintInModule
+def testTDMTypes():
+ f32 = F32Type.get()
+ i32 = IntegerType.get_signless(32)
+
+ # CHECK: !amdgpu.tdm_base<f32>
+ tdm_base = amdgpu.TDMBaseType.get(f32)
+ print(tdm_base)
+
+ # CHECK: !amdgpu.tdm_descriptor
+ tdm_descriptor = amdgpu.TDMDescriptorType.get()
+ print(tdm_descriptor)
+
+ # CHECK: !amdgpu.tdm_gather_base<f32, i32>`
+ tdm_gather_base = amdgpu.TDMGatherBaseType.get(f32, i32)
+ print(tdm_gather_base)
|
amd-eochoalo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I think it looks good. I just had one question. I think that memref import may be removed.
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
amd-eochoalo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
1c36fc6 to
9e638c9
Compare
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/18075 Here is the relevant piece of the build log for the reference |
Add bindings for:
TDMBaseTypeTDMDescriptorTypeTDMGatherBaseType