From f8e54143feda348f6910cee026d166dc5926e694 Mon Sep 17 00:00:00 2001 From: makslevental Date: Thu, 25 Sep 2025 09:40:53 -0700 Subject: [PATCH] [MLIR][Python] enable precise registration --- mlir/examples/standalone/pyproject.toml | 5 ++ .../examples/standalone/python/CMakeLists.txt | 5 +- .../python/StandaloneExtensionNanobind.cpp | 2 +- .../python/StandaloneExtensionPybind11.cpp | 2 +- .../dialects/standalone_nanobind.py | 7 +++ .../dialects/standalone_pybind11.py | 7 +++ .../standalone/test/python/smoketest.py | 27 +++++++--- mlir/include/mlir-c/Bindings/Python/Interop.h | 9 ++++ mlir/include/mlir-c/Dialect/Affine.h | 36 +++++++++++++ mlir/include/mlir-c/Dialect/Bufferization.h | 36 +++++++++++++ mlir/include/mlir-c/Dialect/Complex.h | 33 ++++++++++++ mlir/include/mlir-c/Dialect/Tosa.h | 33 ++++++++++++ mlir/include/mlir-c/Dialect/UB.h | 33 ++++++++++++ mlir/include/mlir-c/IR.h | 18 +++++-- mlir/lib/Bindings/Python/IRCore.cpp | 42 ++++++++++++++- mlir/lib/CAPI/Dialect/Affine.cpp | 14 +++++ mlir/lib/CAPI/Dialect/Bufferization.cpp | 14 +++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 45 ++++++++++++++++ mlir/lib/CAPI/Dialect/Complex.cpp | 14 +++++ mlir/lib/CAPI/Dialect/Tosa.cpp | 13 +++++ .../lib/CAPI/Dialect/TransformInterpreter.cpp | 3 -- mlir/lib/CAPI/Dialect/UB.cpp | 13 +++++ mlir/lib/CAPI/IR/IR.cpp | 12 +++++ .../lib/Dialect/Transform/IR/TransformOps.cpp | 3 +- mlir/python/CMakeLists.txt | 13 +++-- mlir/python/mlir/_mlir_libs/_capi.py | 51 +++++++++++++++++++ mlir/python/mlir/dialects/QuantOps.td | 14 +++++ mlir/python/mlir/dialects/quant.py | 1 + mlir/test/python/ir/capi.py | 45 ++++++++++++++++ 29 files changed, 526 insertions(+), 24 deletions(-) create mode 100644 mlir/include/mlir-c/Dialect/Affine.h create mode 100644 mlir/include/mlir-c/Dialect/Bufferization.h create mode 100644 mlir/include/mlir-c/Dialect/Complex.h create mode 100644 mlir/include/mlir-c/Dialect/Tosa.h create mode 100644 mlir/include/mlir-c/Dialect/UB.h create mode 100644 mlir/lib/CAPI/Dialect/Affine.cpp create mode 100644 mlir/lib/CAPI/Dialect/Bufferization.cpp create mode 100644 mlir/lib/CAPI/Dialect/Complex.cpp create mode 100644 mlir/lib/CAPI/Dialect/Tosa.cpp create mode 100644 mlir/lib/CAPI/Dialect/UB.cpp create mode 100644 mlir/python/mlir/_mlir_libs/_capi.py create mode 100644 mlir/python/mlir/dialects/QuantOps.td create mode 100644 mlir/test/python/ir/capi.py diff --git a/mlir/examples/standalone/pyproject.toml b/mlir/examples/standalone/pyproject.toml index 5a1e6e86513c3..9a14611cfb150 100644 --- a/mlir/examples/standalone/pyproject.toml +++ b/mlir/examples/standalone/pyproject.toml @@ -37,6 +37,8 @@ cmake.source-dir = "." # This is for installing/distributing the python bindings target and only the python bindings target. build.targets = ["StandalonePythonModules"] install.components = ["StandalonePythonModules"] +# This is the default but make it explicit to highlight that this option exists (turn off for debug symbols). +install.strip = true [tool.scikit-build.cmake.define] # Optional @@ -51,6 +53,9 @@ LLVM_USE_LINKER = { env = "LLVM_USE_LINKER", default = "" } CMAKE_VISIBILITY_INLINES_HIDDEN = "ON" CMAKE_C_VISIBILITY_PRESET = "hidden" CMAKE_CXX_VISIBILITY_PRESET = "hidden" +# Disables generation of "version soname" (i.e. libFoo.so.), +# which causes pure duplication of various shlibs for Python wheels. +CMAKE_PLATFORM_NO_VERSIONED_SONAME = "ON" # Non-optional (alternatively you could use CMAKE_PREFIX_PATH here). MLIR_DIR = { env = "MLIR_DIR", default = "" } diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt index 905c944939756..6d86d0e4ce3a8 100644 --- a/mlir/examples/standalone/python/CMakeLists.txt +++ b/mlir/examples/standalone/python/CMakeLists.txt @@ -65,7 +65,8 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI DECLARED_SOURCES StandalonePythonSources MLIRPythonSources.Core - MLIRPythonSources.Dialects.builtin + EMBED_LIBS + MLIRCAPIQuant ) ################################################################################ @@ -138,6 +139,8 @@ set(_declared_sources StandalonePythonSources MLIRPythonSources.Core MLIRPythonSources.Dialects.builtin + MLIRPythonSources.Dialects.arith + MLIRPythonSources.Dialects.quant ) # For an external projects build, the MLIRPythonExtension.Core.type_stub_gen # target already exists and can just be added to DECLARED_SOURCES. diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp index 0ec6cdfa7994b..e28744b32ccfe 100644 --- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp +++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp @@ -32,7 +32,7 @@ NB_MODULE(_standaloneDialectsNanobind, m) { mlirDialectHandleRegisterDialect(standaloneHandle, context); if (load) { mlirDialectHandleLoadDialect(arithHandle, context); - mlirDialectHandleRegisterDialect(standaloneHandle, context); + mlirDialectHandleLoadDialect(standaloneHandle, context); } }, nb::arg("context").none() = nb::none(), nb::arg("load") = true, diff --git a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp index da8c2167dc36b..db7645bcf5331 100644 --- a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp +++ b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp @@ -31,7 +31,7 @@ PYBIND11_MODULE(_standaloneDialectsPybind11, m) { mlirDialectHandleRegisterDialect(standaloneHandle, context); if (load) { mlirDialectHandleLoadDialect(arithHandle, context); - mlirDialectHandleRegisterDialect(standaloneHandle, context); + mlirDialectHandleLoadDialect(standaloneHandle, context); } }, py::arg("context") = py::none(), py::arg("load") = true); diff --git a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py index 6218720951c82..3ba94916eba18 100644 --- a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py +++ b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_nanobind.py @@ -4,3 +4,10 @@ from ._standalone_ops_gen import * from .._mlir_libs._standaloneDialectsNanobind.standalone import * + +from .._mlir_libs import get_dialect_registry as _get_dialect_registry +from .._mlir_libs._capi import register_dialect as _register_dialect + +_dialect_registry = _get_dialect_registry() +if "quant" not in _dialect_registry.dialect_names: + _register_dialect("quant", _dialect_registry) diff --git a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py index bfb98e404e13f..1c209e151ae5f 100644 --- a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py +++ b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py @@ -4,3 +4,10 @@ from ._standalone_ops_gen import * from .._mlir_libs._standaloneDialectsPybind11.standalone import * + +from .._mlir_libs import get_dialect_registry as _get_dialect_registry +from .._mlir_libs._capi import register_dialect as _register_dialect + +_dialect_registry = _get_dialect_registry() +if "quant" not in _dialect_registry.dialect_names: + _register_dialect("quant", _dialect_registry) diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py index 26d84fd63e947..3628522360c74 100644 --- a/mlir/examples/standalone/test/python/smoketest.py +++ b/mlir/examples/standalone/test/python/smoketest.py @@ -11,15 +11,26 @@ else: raise ValueError("Expected either pybind11 or nanobind as arguments") +from mlir_standalone.dialects import arith, quant -with Context(): + +with Context(), Location.unknown(): standalone_d.register_dialects() - module = Module.parse( - """ - %0 = arith.constant 2 : i32 - %1 = standalone.foo %0 : i32 - """ + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + i32 = IntegerType.get_signless(32) + uniform = quant.UniformQuantizedType.get( + quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7 ) - # CHECK: %[[C:.*]] = arith.constant 2 : i32 - # CHECK: standalone.foo %[[C]] : i32 + + module = Module.create() + with InsertionPoint(module.body): + two_i32 = arith.constant(i32, 2) + standalone_d.foo(two_i32) + two_f32 = arith.constant(f32, 2.0) + quant.qcast(uniform, two_f32) + # CHECK: %[[TWOI32:.*]] = arith.constant 2 : i32 + # CHECK: standalone.foo %[[TWOI32]] : i32 + # CHECK: %[[TWOF32:.*]] = arith.constant 2.000000e+00 : f32 + # CHECK: quant.qcast %[[TWOF32]] : f32 to !quant.uniform:f32, 9.987200e-01:127> print(str(module)) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index a33190c380d37..89559da689017 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -84,6 +84,8 @@ #define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr") #define MLIR_PYTHON_CAPSULE_TYPEID \ MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_DIALECT_HANDLE \ + MAKE_MLIR_PYTHON_QUALNAME("ir.DialectHandle._CAPIPtr") /** Attribute on MLIR Python objects that expose their C-API pointer. * This will be a type-specific capsule created as per one of the helpers @@ -457,6 +459,13 @@ static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) { return value; } +static inline MlirDialectHandle +mlirPythonCapsuleToDialectHandle(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_HANDLE); + MlirDialectHandle handle = {ptr}; + return handle; +} + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/Dialect/Affine.h b/mlir/include/mlir-c/Dialect/Affine.h new file mode 100644 index 0000000000000..b2bf5aad44de9 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Affine.h @@ -0,0 +1,36 @@ +//===-- mlir-c/Dialect/Affine.h - C API for Affine dialect --------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Affine dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_AFFINE_H +#define MLIR_C_DIALECT_AFFINE_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Affine, affine); + +MLIR_CAPI_EXPORTED void +mlirAffineRegisterTransformDialectExtension(MlirDialectRegistry registry); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_AFFINE_H diff --git a/mlir/include/mlir-c/Dialect/Bufferization.h b/mlir/include/mlir-c/Dialect/Bufferization.h new file mode 100644 index 0000000000000..41af7a294eb5c --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Bufferization.h @@ -0,0 +1,36 @@ +//===-- mlir-c/Dialect/Bufferization.h - C API for Bufferization dialect --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Bufferization dialect. A dialect should be registered with a context to make +// it available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_BUFFERIZATION_H +#define MLIR_C_DIALECT_BUFFERIZATION_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Bufferization, bufferization); + +MLIR_CAPI_EXPORTED void mlirBufferizationRegisterTransformDialectExtension( + MlirDialectRegistry registry); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_BUFFERIZATION_H diff --git a/mlir/include/mlir-c/Dialect/Complex.h b/mlir/include/mlir-c/Dialect/Complex.h new file mode 100644 index 0000000000000..e51a67346a6ee --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Complex.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/Complex.h - C API for Complex dialect ------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Complex dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_COMPLEX_H +#define MLIR_C_DIALECT_COMPLEX_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Complex, complex); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_COMPLEX_H diff --git a/mlir/include/mlir-c/Dialect/Tosa.h b/mlir/include/mlir-c/Dialect/Tosa.h new file mode 100644 index 0000000000000..ed55577996604 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Tosa.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/Tosa.h - C API for Tosa dialect ----------*- C ---*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Tosa dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_TOSA_H +#define MLIR_C_DIALECT_TOSA_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Tosa, tosa); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_TOSA_H diff --git a/mlir/include/mlir-c/Dialect/UB.h b/mlir/include/mlir-c/Dialect/UB.h new file mode 100644 index 0000000000000..74159f0c705de --- /dev/null +++ b/mlir/include/mlir-c/Dialect/UB.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/UB.h - C API for UB dialect ----------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// UB dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_UB_H +#define MLIR_C_DIALECT_UB_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(UB, ub); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_UB_H diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 061d7620ba077..55cc86accb8a0 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -66,6 +66,7 @@ DEFINE_C_API_STRUCT(MlirLocation, const void); DEFINE_C_API_STRUCT(MlirModule, const void); DEFINE_C_API_STRUCT(MlirType, const void); DEFINE_C_API_STRUCT(MlirValue, const void); +DEFINE_C_API_STRUCT(MlirDialectHandle, const void); #undef DEFINE_C_API_STRUCT @@ -207,11 +208,6 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); // registration schemes. //===----------------------------------------------------------------------===// -struct MlirDialectHandle { - const void *ptr; -}; -typedef struct MlirDialectHandle MlirDialectHandle; - #define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__( \ void) @@ -233,6 +229,11 @@ MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, MlirContext); +/// Checks if the dialect handle is null. +static inline bool mlirDialectHandleIsNull(MlirDialectHandle handle) { + return !handle.ptr; +} + //===----------------------------------------------------------------------===// // DialectRegistry API. //===----------------------------------------------------------------------===// @@ -249,6 +250,13 @@ static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) { MLIR_CAPI_EXPORTED void mlirDialectRegistryDestroy(MlirDialectRegistry registry); +MLIR_CAPI_EXPORTED int64_t +mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry); + +MLIR_CAPI_EXPORTED void +mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry, + MlirStringRef *dialectNames); + //===----------------------------------------------------------------------===// // Location API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 83a8757bb72c7..21e60d5550a51 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2897,6 +2897,14 @@ maybeGetTracebackLocation(const std::optional &location) { // Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ +MlirDialectHandle createMlirDialectHandleFromCapsule(nb::object capsule) { + MlirDialectHandle rawRegistry = + mlirPythonCapsuleToDialectHandle(capsule.ptr()); + if (mlirDialectHandleIsNull(rawRegistry)) + throw nb::python_error(); + return rawRegistry; +} + void mlir::python::populateIRCore(nb::module_ &m) { // disable leak warnings which tend to be false positives. nb::set_leak_warnings(false); @@ -3126,14 +3134,46 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::sig("def __repr__(self) -> str")); + //---------------------------------------------------------------------------- + // Mapping of MlirDialectHandle + //---------------------------------------------------------------------------- + + nb::class_(m, "DialectHandle") + .def_prop_ro_static( + "_capsule_name", + [](nb::handle &) { return MLIR_PYTHON_CAPSULE_DIALECT_HANDLE; }, + nb::sig("def _capsule_name(/) -> str")) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &createMlirDialectHandleFromCapsule); + //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- nb::class_(m, "DialectRegistry") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) + .def_prop_ro_static( + "_capsule_name", + [](nb::handle &) { return MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY; }, + nb::sig("def _capsule_name(/) -> str")) .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) - .def(nb::init<>()); + .def(nb::init<>()) + .def("insert_dialect", + [](PyDialectRegistry &self, MlirDialectHandle handle) { + mlirDialectHandleInsertDialect(handle, self.get()); + }) + .def("insert_dialect", + [](PyDialectRegistry &self, intptr_t ptr) { + mlirDialectHandleInsertDialect( + {reinterpret_cast(ptr)}, self.get()); + }) + .def_prop_ro("dialect_names", [](PyDialectRegistry &self) { + int64_t numDialectNames = + mlirDialectRegistryGetNumDialectNames(self.get()); + std::vector dialectNames(numDialectNames); + mlirDialectRegistryGetDialectNames(self.get(), dialectNames.data()); + return dialectNames; + }); //---------------------------------------------------------------------------- // Mapping of Location diff --git a/mlir/lib/CAPI/Dialect/Affine.cpp b/mlir/lib/CAPI/Dialect/Affine.cpp new file mode 100644 index 0000000000000..b796523390a29 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Affine.cpp @@ -0,0 +1,14 @@ +//===- Affine.cpp - C Interface for Affine dialect ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Affine.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Affine, affine, + mlir::affine::AffineDialect) diff --git a/mlir/lib/CAPI/Dialect/Bufferization.cpp b/mlir/lib/CAPI/Dialect/Bufferization.cpp new file mode 100644 index 0000000000000..da50dfaab44b6 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Bufferization.cpp @@ -0,0 +1,14 @@ +//===- Bufferization.cpp - C Interface for Bufferization dialect ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Bufferization.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Bufferization, bufferization, + mlir::bufferization::BufferizationDialect) diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index bb1fdf8be3c8f..c667ca354b266 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -7,6 +7,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIAMDGPU MLIRAMDGPUDialect ) +add_mlir_upstream_c_api_library(MLIRCAPIAffine + Affine.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRAffineDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIArith Arith.cpp @@ -16,6 +25,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIArith MLIRArithDialect ) +add_mlir_upstream_c_api_library(MLIRCAPIBufferization + Bufferization.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRBufferizationDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIAsync Async.cpp AsyncPasses.cpp @@ -31,6 +49,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIAsync MLIRPass ) +add_mlir_upstream_c_api_library(MLIRCAPIComplex + Complex.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRComplexDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIControlFlow ControlFlow.cpp @@ -278,3 +305,21 @@ add_mlir_upstream_c_api_library(MLIRCAPISMT MLIRCAPIIR MLIRSMT ) + +add_mlir_upstream_c_api_library(MLIRCAPITosa + Tosa.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRTosaDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIUB + UB.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRUBDialect +) diff --git a/mlir/lib/CAPI/Dialect/Complex.cpp b/mlir/lib/CAPI/Dialect/Complex.cpp new file mode 100644 index 0000000000000..7063028e5d640 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Complex.cpp @@ -0,0 +1,14 @@ +//===- Complex.cpp - C Interface for Complex dialect ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Complex.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Complex/IR/Complex.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Complex, complex, + mlir::complex::ComplexDialect) diff --git a/mlir/lib/CAPI/Dialect/Tosa.cpp b/mlir/lib/CAPI/Dialect/Tosa.cpp new file mode 100644 index 0000000000000..357717a2ed5a5 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Tosa.cpp @@ -0,0 +1,13 @@ +//===- Tosa.cpp - C Interface for Tosa dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Tosa.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Tosa, tosa, mlir::tosa::TosaDialect) diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp index 145455e1c1b3d..05648de959e7a 100644 --- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp +++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp @@ -23,8 +23,6 @@ using namespace mlir; DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions) -extern "C" { - MlirTransformOptions mlirTransformOptionsCreate() { return wrap(new transform::TransformOptions); } @@ -80,4 +78,3 @@ MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target, unwrap(target), std::move(otherOwning)); return wrap(result); } -} diff --git a/mlir/lib/CAPI/Dialect/UB.cpp b/mlir/lib/CAPI/Dialect/UB.cpp new file mode 100644 index 0000000000000..de989237159c4 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/UB.cpp @@ -0,0 +1,13 @@ +//===- Ub.cpp - C Interface for UB dialect --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/UB.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/UB/IR/UBOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(UB, ub, mlir::ub::UBDialect) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e9844a7cc1909..a81e2a14e5255 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -150,6 +150,18 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { delete unwrap(registry); } +int64_t mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry) { + auto dialectNames = unwrap(registry)->getDialectNames(); + return std::distance(dialectNames.begin(), dialectNames.end()); +} + +void mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry, + MlirStringRef *dialectNames) { + for (auto [i, location] : + llvm::enumerate(unwrap(registry)->getDialectNames())) + dialectNames[i] = wrap(location); +} + //===----------------------------------------------------------------------===// // AsmState API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 132ed815c354e..d8e54bf060bd6 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -2928,7 +2928,8 @@ LogicalResult transform::SequenceOp::verify() { InFlightDiagnostic diag = emitOpError() << "expected children ops to implement TransformOpInterface"; - diag.attachNote(child.getLoc()) << "op without interface"; + diag.attachNote(child.getLoc()) + << "op without interface: " << child.getName(); return diag; } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 9f5246de6bda0..41974a1f97634 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -23,6 +23,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python ADD_TO_PARENT MLIRPythonSources.Core SOURCES _mlir_libs/__init__.py + _mlir_libs/_capi.py _mlir_libs/_mlir/py.typed ir.py passmanager.py @@ -370,14 +371,16 @@ declare_mlir_dialect_python_bindings( dialects/rocdl.py DIALECT_NAME rocdl) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.quant +declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/QuantOps.td GEN_ENUM_BINDINGS SOURCES dialects/quant.py - _mlir_libs/_mlir/dialects/quant.pyi) + _mlir_libs/_mlir/dialects/quant.pyi + DIALECT_NAME quant +) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -858,6 +861,10 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI MLIRPythonSources MLIRPythonExtension.RegisterEverything ${_ADDL_TEST_SOURCES} + EMBED_LIBS + MLIRCAPIArith + MLIRCAPIQuant + MLIRCAPIMath ) ################################################################################ diff --git a/mlir/python/mlir/_mlir_libs/_capi.py b/mlir/python/mlir/_mlir_libs/_capi.py new file mode 100644 index 0000000000000..108b1c030ef29 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_capi.py @@ -0,0 +1,51 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import ctypes +import platform +from pathlib import Path + +from . import _mlir, get_dialect_registry as _get_dialect_registry + +_get_dialect_registry() + +MLIR_PYTHON_CAPSULE_DIALECT_HANDLE = _mlir.ir.DialectHandle._capsule_name.encode() + +MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY = _mlir.ir.DialectRegistry._capsule_name.encode() + +if platform.system() == "Windows": + _ext_suffix = "dll" +elif platform.system() == "Darwin": + _ext_suffix = "dylib" +else: + _ext_suffix = "so" + +for fp in Path(__file__).parent.glob(f"*.{_ext_suffix}"): + if "CAPI" in fp.name: + _capi_dylib = fp + break +else: + raise ValueError("Couldn't find CAPI dylib") + + +_capi = ctypes.CDLL(str(Path(__file__).parent / _capi_dylib)) + +PyCapsule_New = ctypes.pythonapi.PyCapsule_New +PyCapsule_New.restype = ctypes.py_object +PyCapsule_New.argtypes = ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p + +PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer +PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p] +PyCapsule_GetPointer.restype = ctypes.c_void_p + + +def register_dialect(dialect_name, dialect_registry): + dialect_handle_capi = f"mlirGetDialectHandle__{dialect_name}__" + if not hasattr(_capi, dialect_handle_capi): + raise RuntimeError(f"missing {dialect_handle_capi} API") + dialect_handle_capi = getattr(_capi, dialect_handle_capi) + dialect_handle_capi.argtypes = [] + dialect_handle_capi.restype = ctypes.c_void_p + handle = dialect_handle_capi() + dialect_registry.insert_dialect(handle) diff --git a/mlir/python/mlir/dialects/QuantOps.td b/mlir/python/mlir/dialects/QuantOps.td new file mode 100644 index 0000000000000..46385fd00ac0c --- /dev/null +++ b/mlir/python/mlir/dialects/QuantOps.td @@ -0,0 +1,14 @@ +//===-- QuantOps.td - Entry point for QuantOps bind --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_QUANT_OPS +#define PYTHON_BINDINGS_QUANT_OPS + +include "mlir/Dialect/Quant/IR/QuantOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/quant.py b/mlir/python/mlir/dialects/quant.py index bf1fc5f2de378..7a7273d8e26be 100644 --- a/mlir/python/mlir/dialects/quant.py +++ b/mlir/python/mlir/dialects/quant.py @@ -2,4 +2,5 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from ._quant_ops_gen import * from .._mlir_libs._mlirDialectsQuant import * diff --git a/mlir/test/python/ir/capi.py b/mlir/test/python/ir/capi.py new file mode 100644 index 0000000000000..8ef26e4767952 --- /dev/null +++ b/mlir/test/python/ir/capi.py @@ -0,0 +1,45 @@ +# RUN: %PYTHON %s | FileCheck %s + +import ctypes + +from mlir._mlir_libs._capi import ( + _capi, + PyCapsule_New, + MLIR_PYTHON_CAPSULE_DIALECT_HANDLE, + register_dialect, +) +from mlir.ir import DialectHandle, DialectRegistry + +print("success") +# CHECK: success + + +if not hasattr(_capi, "mlirGetDialectHandle__arith__"): + raise Exception("missing API") +_capi.mlirGetDialectHandle__arith__.argtypes = [] +_capi.mlirGetDialectHandle__arith__.restype = ctypes.c_void_p + +if not hasattr(_capi, "mlirGetDialectHandle__quant__"): + raise Exception("missing API") +_capi.mlirGetDialectHandle__quant__.argtypes = [] +_capi.mlirGetDialectHandle__quant__.restype = ctypes.c_void_p + +dialect_registry = DialectRegistry() +# CHECK: ['builtin'] +print(dialect_registry.dialect_names) + +arith_handle = _capi.mlirGetDialectHandle__arith__() +dialect_registry.insert_dialect(arith_handle) +# CHECK: ['arith', 'builtin'] +print(sorted(dialect_registry.dialect_names)) + +quant_handle = _capi.mlirGetDialectHandle__quant__() +capsule = PyCapsule_New(quant_handle, MLIR_PYTHON_CAPSULE_DIALECT_HANDLE, None) +dialect_handle = DialectHandle._CAPICreate(capsule) +dialect_registry.insert_dialect(dialect_handle) +# CHECK: ['arith', 'builtin', 'quant'] +print(sorted(dialect_registry.dialect_names)) + +register_dialect("math", dialect_registry) +# CHECK: ['arith', 'builtin', 'math', 'quant'] +print(sorted(dialect_registry.dialect_names))