diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index 8c301faf0941a..f4d078dfe7118 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -228,7 +228,7 @@ endfunction() # aggregate dylib that is linked against. function(declare_mlir_python_extension name) cmake_parse_arguments(ARG - "" + "_PRIVATE_SUPPORT_LIB" "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT" "SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS" ${ARGN}) @@ -236,6 +236,11 @@ function(declare_mlir_python_extension name) if(NOT ARG_ROOT_DIR) set(ARG_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}") endif() + if(ARG__PRIVATE_SUPPORT_LIB) + set(SOURCES_TYPE "support") + else() + set(SOURCES_TYPE "extension") + endif() set(_install_destination "src/python/${name}") add_library(${name} INTERFACE) @@ -243,7 +248,7 @@ function(declare_mlir_python_extension name) # Yes: Leading-lowercase property names are load bearing and the recommended # way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261 EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS" - mlir_python_SOURCES_TYPE extension + mlir_python_SOURCES_TYPE "${SOURCES_TYPE}" mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}" mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}" mlir_python_DEPENDS "" @@ -297,6 +302,58 @@ function(_mlir_python_install_sources name source_root_dir destination) ) endfunction() +function(build_nanobind_lib) + cmake_parse_arguments(ARG + "" + "INSTALL_COMPONENT;INSTALL_DESTINATION;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN" + "" + ${ARGN}) + + # Only build in free-threaded mode if the Python ABI supports it. + # See https://github.com/wjakob/nanobind/blob/4ba51fcf795971c5d603d875ae4184bc0c9bd8e6/cmake/nanobind-config.cmake#L363-L371. + if (NB_ABI MATCHES "[0-9]t") + set(_ft "-ft") + endif() + # nanobind does a string match on the suffix to figure out whether to build + # the lib with free threading... + set(NB_LIBRARY_TARGET_NAME "nanobind${_ft}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}") + set(NB_LIBRARY_TARGET_NAME "${NB_LIBRARY_TARGET_NAME}" PARENT_SCOPE) + nanobind_build_library(${NB_LIBRARY_TARGET_NAME} AS_SYSINCLUDE) + target_compile_definitions(${NB_LIBRARY_TARGET_NAME} + PRIVATE + NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN} + ) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols + # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym) + # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs + # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much + # maintenance). + target_link_options(${NB_LIBRARY_TARGET_NAME} PRIVATE "LINKER:-z,undefs") + endif() + # nanobind configures with LTO for shared build which doesn't work everywhere + # (see https://github.com/llvm/llvm-project/issues/139602). + if(NOT LLVM_ENABLE_LTO) + set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES + INTERPROCEDURAL_OPTIMIZATION_RELEASE OFF + INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL OFF + ) + endif() + set_target_properties(${NB_LIBRARY_TARGET_NAME} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" + BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" + # Needed for windows (and doesn't hurt others). + RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" + ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" + ) + mlir_python_setup_extension_rpath(${NB_LIBRARY_TARGET_NAME}) + install(TARGETS ${NB_LIBRARY_TARGET_NAME} + COMPONENT ${ARG_INSTALL_COMPONENT} + LIBRARY DESTINATION "${ARG_INSTALL_DESTINATION}" + RUNTIME DESTINATION "${ARG_INSTALL_DESTINATION}" + ) +endfunction() + # Function: add_mlir_python_modules # Adds python modules to a project, building them from a list of declared # source groupings (see declare_mlir_python_sources and @@ -308,6 +365,11 @@ endfunction() # for non-relocatable modules or a deeper directory tree for relocatable. # INSTALL_PREFIX: Prefix into the install tree for installing the package. # Typically mirrors the path above but without an absolute path. +# MLIR_BINDINGS_PYTHON_NB_DOMAIN: nanobind (and MLIR) domain within which +# extensions will be compiled. This determines whether this package +# will share nanobind types with other bindings packages. Expected to be unique +# per project (and per specific set of bindings, for projects with multiple +# bindings packages). # DECLARED_SOURCES: List of declared source groups to include. The entire # DAG of source modules is included. # COMMON_CAPI_LINK_LIBS: List of dylibs (typically one) to make every @@ -315,11 +377,32 @@ endfunction() function(add_mlir_python_modules name) cmake_parse_arguments(ARG "" - "ROOT_PREFIX;INSTALL_PREFIX" + "ROOT_PREFIX;INSTALL_PREFIX;MLIR_BINDINGS_PYTHON_NB_DOMAIN" "COMMON_CAPI_LINK_LIBS;DECLARED_SOURCES" ${ARGN}) + + # TODO(max): do the same for MLIR_PYTHON_PACKAGE_PREFIX? + if((NOT ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN) AND MLIR_BINDINGS_PYTHON_NB_DOMAIN) + set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}) + endif() + if((NOT ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN) OR ("${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}" STREQUAL "")) + message(WARNING "MLIR_BINDINGS_PYTHON_NB_DOMAIN CMake var is not set - setting to a default `mlir`.\ + It is highly recommend to set this to something unique so that your project's Python bindings do not collide with\ + others'. You also pass explicitly to `add_mlir_python_modules`.\ + See https://github.com/llvm/llvm-project/pull/171775 for more information.") + set(ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir") + endif() + + # This call sets NB_LIBRARY_TARGET_NAME. + build_nanobind_lib( + INSTALL_COMPONENT ${name} + INSTALL_DESTINATION "${ARG_INSTALL_PREFIX}/_mlir_libs" + OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs" + MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN} + ) + # Helper to process an individual target. - function(_process_target modules_target sources_target) + function(_process_target modules_target sources_target support_libs) get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE) if(_source_type STREQUAL "pure") @@ -337,16 +420,20 @@ function(add_mlir_python_modules name) get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME) # Transform relative source to based on root dir. set(_extension_target "${modules_target}.extension.${_module_name}.dso") - add_mlir_python_extension(${_extension_target} "${_module_name}" + add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME} INSTALL_COMPONENT ${modules_target} INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs" OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs" + MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN} LINK_LIBS PRIVATE ${sources_target} ${ARG_COMMON_CAPI_LINK_LIBS} + ${support_libs} ) add_dependencies(${modules_target} ${_extension_target}) mlir_python_setup_extension_rpath(${_extension_target}) + elseif(_source_type STREQUAL "support") + # do nothing because already built else() message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}") return() @@ -356,8 +443,36 @@ function(add_mlir_python_modules name) # Build the modules target. add_custom_target(${name} ALL) _flatten_mlir_python_targets(_flat_targets ${ARG_DECLARED_SOURCES}) + + # Build all support libs first. + set(_mlir_python_support_libs) + foreach(sources_target ${_flat_targets}) + get_target_property(_source_type ${sources_target} mlir_python_SOURCES_TYPE) + if(_source_type STREQUAL "support") + get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME) + # Use a similar mechanism as nanobind to help the runtime loader pick the correct lib. + set(_module_name "${_module_name}-${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN}") + set(_extension_target "${name}.extension.${_module_name}.so") + add_mlir_python_extension(${_extension_target} "${_module_name}" ${NB_LIBRARY_TARGET_NAME} + INSTALL_COMPONENT ${name} + INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs" + OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs" + MLIR_BINDINGS_PYTHON_NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN} + _PRIVATE_SUPPORT_LIB + LINK_LIBS PRIVATE + LLVMSupport + ${sources_target} + ${ARG_COMMON_CAPI_LINK_LIBS} + ) + add_dependencies(${name} ${_extension_target}) + mlir_python_setup_extension_rpath(${_extension_target}) + list(APPEND _mlir_python_support_libs "${_extension_target}") + endif() + endforeach() + + # Build extensions. foreach(sources_target ${_flat_targets}) - _process_target(${name} ${sources_target}) + _process_target(${name} ${sources_target} "${_mlir_python_support_libs}") endforeach() # Create an install target. @@ -622,7 +737,7 @@ function(add_mlir_python_common_capi_library name) set_target_properties(${name} PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" BINARY_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" - # Needed for windows (and don't hurt others). + # Needed for windows (and doesn't hurt others). RUNTIME_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" ARCHIVE_OUTPUT_DIRECTORY "${ARG_OUTPUT_DIRECTORY}" ) @@ -742,10 +857,10 @@ endfunction() ################################################################################ # Build python extension ################################################################################ -function(add_mlir_python_extension libname extname) +function(add_mlir_python_extension libname extname nb_library_target_name) cmake_parse_arguments(ARG - "" - "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY" + "_PRIVATE_SUPPORT_LIB" + "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;MLIR_BINDINGS_PYTHON_NB_DOMAIN" "SOURCES;LINK_LIBS" ${ARGN}) if(ARG_UNPARSED_ARGUMENTS) @@ -761,10 +876,41 @@ function(add_mlir_python_extension libname extname) set(eh_rtti_enable -frtti -fexceptions) endif () - nanobind_add_module(${libname} - NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN} - FREE_THREADED - ${ARG_SOURCES} + if(ARG__PRIVATE_SUPPORT_LIB) + add_library(${libname} SHARED ${ARG_SOURCES}) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # nanobind handles this correctly for MacOS by explicitly setting -U for all the necessary Python symbols + # (see https://github.com/wjakob/nanobind/blob/master/cmake/darwin-ld-cpython.sym) + # but since we set -z,defs in llvm/cmake/modules/HandleLLVMOptions.cmake:340 for all Linux shlibs + # we need to negate it here (we could have our own linux-ld-cpython.sym but that would be too much + # maintenance). + target_link_options(${libname} PRIVATE "LINKER:-z,undefs") + endif() + nanobind_link_options(${libname}) + target_compile_definitions(${libname} + PRIVATE + NB_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN} + MLIR_CAPI_BUILDING_LIBRARY=1 + ) + if(MSVC) + set_property(TARGET ${libname} PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON) + endif() + else() + nanobind_add_module(${libname} + NB_DOMAIN ${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN} + FREE_THREADED + NB_SHARED + ${ARG_SOURCES} + ) + target_compile_definitions(${libname} + PRIVATE + MLIR_BINDINGS_PYTHON_DOMAIN=${MLIR_BINDINGS_PYTHON_NB_DOMAIN} + ) + endif() + target_link_libraries(${libname} PRIVATE ${nb_library_target_name}) + target_compile_definitions(${libname} + PRIVATE + MLIR_BINDINGS_PYTHON_DOMAIN=${ARG_MLIR_BINDINGS_PYTHON_NB_DOMAIN} ) if(APPLE) # In llvm/cmake/modules/HandleLLVMOptions.cmake:268 we set -Wl,-flat_namespace which breaks @@ -778,29 +924,28 @@ function(add_mlir_python_extension libname extname) # Avoid some warnings from upstream nanobind. # If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let # the super project handle compile options as it wishes. - get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES) - target_compile_options(${NB_LIBRARY_TARGET_NAME} + target_compile_options(${nb_library_target_name} PRIVATE -Wno-c++98-compat-extra-semi - -Wno-cast-qual - -Wno-covered-switch-default - -Wno-deprecated-literal-operator - -Wno-nested-anon-types - -Wno-unused-parameter - -Wno-zero-length-array - -Wno-missing-field-initializers + -Wno-cast-qual + -Wno-covered-switch-default + -Wno-deprecated-literal-operator + -Wno-nested-anon-types + -Wno-unused-parameter + -Wno-zero-length-array + -Wno-missing-field-initializers ${eh_rtti_enable}) target_compile_options(${libname} PRIVATE -Wno-c++98-compat-extra-semi - -Wno-cast-qual - -Wno-covered-switch-default - -Wno-deprecated-literal-operator - -Wno-nested-anon-types - -Wno-unused-parameter - -Wno-zero-length-array - -Wno-missing-field-initializers + -Wno-cast-qual + -Wno-covered-switch-default + -Wno-deprecated-literal-operator + -Wno-nested-anon-types + -Wno-unused-parameter + -Wno-zero-length-array + -Wno-missing-field-initializers ${eh_rtti_enable}) endif() @@ -818,12 +963,26 @@ function(add_mlir_python_extension libname extname) target_compile_options(${libname} PRIVATE ${eh_rtti_enable}) + # Quoting CMake: + # + # "If you use it on normal shared libraries which other targets link against, on some platforms a + # linker will insert a full path to the library (as specified at link time) into the dynamic section of the + # dependent binary. Therefore, once installed, dynamic loader may eventually fail to locate the library + # for the binary." + # + # So for support libs we do need an SO name but for extensions we do not (they're MODULEs anyway - + # i.e., can't be linked against, only loaded). + if (ARG__PRIVATE_SUPPORT_LIB) + set(_no_soname OFF) + else () + set(_no_soname ON) + endif () # Configure the output to match python expectations. set_target_properties( ${libname} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${ARG_OUTPUT_DIRECTORY} OUTPUT_NAME "${extname}" - NO_SONAME ON + NO_SONAME ${_no_soname} ) if(WIN32) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index 4f4f531f7723c..4278774933a4a 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -37,6 +37,13 @@ LLVM ERROR: ... unregistered/uninitialized dialect/type/pass ...` ``` +* **`MLIR_BINDINGS_PYTHON_NB_DOMAIN`**: `STRING` + + nanobind (and MLIR) domain within which extensions will be compiled. + This determines whether this package will share nanobind types with other bindings packages. + Expected to be unique per project (and per specific set of bindings, for projects with multiple bindings packages). + Can also be passed explicitly to `add_mlir_python_modules`. + ### Recommended development practices It is recommended to use a Python virtual environment. Many ways exist for this, diff --git a/mlir/examples/standalone/include/Standalone-c/Dialects.h b/mlir/examples/standalone/include/Standalone-c/Dialects.h index b3e47752ccc69..5aa9e004cb9fe 100644 --- a/mlir/examples/standalone/include/Standalone-c/Dialects.h +++ b/mlir/examples/standalone/include/Standalone-c/Dialects.h @@ -17,6 +17,13 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Standalone, standalone); +MLIR_CAPI_EXPORTED MlirType mlirStandaloneCustomTypeGet(MlirContext ctx, + MlirStringRef value); + +MLIR_CAPI_EXPORTED bool mlirStandaloneTypeIsACustomType(MlirType t); + +MLIR_CAPI_EXPORTED MlirTypeID mlirStandaloneCustomTypeGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/examples/standalone/lib/CAPI/Dialects.cpp b/mlir/examples/standalone/lib/CAPI/Dialects.cpp index 98006e81a3d26..4de55ba485490 100644 --- a/mlir/examples/standalone/lib/CAPI/Dialects.cpp +++ b/mlir/examples/standalone/lib/CAPI/Dialects.cpp @@ -9,7 +9,20 @@ #include "Standalone-c/Dialects.h" #include "Standalone/StandaloneDialect.h" +#include "Standalone/StandaloneTypes.h" #include "mlir/CAPI/Registration.h" MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Standalone, standalone, mlir::standalone::StandaloneDialect) + +MlirType mlirStandaloneCustomTypeGet(MlirContext ctx, MlirStringRef value) { + return wrap(mlir::standalone::CustomType::get(unwrap(ctx), unwrap(value))); +} + +bool mlirStandaloneTypeIsACustomType(MlirType t) { + return llvm::isa(unwrap(t)); +} + +MlirTypeID mlirStandaloneCustomTypeGetTypeID() { + return wrap(mlir::standalone::CustomType::getTypeID()); +} diff --git a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp index 0ec6cdfa7994b..c568369913595 100644 --- a/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp +++ b/mlir/examples/standalone/python/StandaloneExtensionNanobind.cpp @@ -11,17 +11,44 @@ #include "Standalone-c/Dialects.h" #include "mlir-c/Dialect/Arith.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; +struct PyCustomType + : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirStandaloneTypeIsACustomType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStandaloneCustomTypeGetTypeID; + static constexpr const char *pyClassName = "CustomType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &value, + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext + context) { + return PyCustomType( + context->getRef(), + mlirStandaloneCustomTypeGet( + context.get()->get(), + mlirStringRefCreateFromCString(value.c_str()))); + }, + nb::arg("value"), nb::arg("context").none() = nb::none()); + } +}; + NB_MODULE(_standaloneDialectsNanobind, m) { //===--------------------------------------------------------------------===// // standalone dialect //===--------------------------------------------------------------------===// auto standaloneM = m.def_submodule("standalone"); + PyCustomType::bind(standaloneM); + standaloneM.def( "register_dialects", [](MlirContext context, bool load) { diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py index 09040eb2f45dc..fe4e40e6e8a99 100644 --- a/mlir/examples/standalone/test/python/smoketest.py +++ b/mlir/examples/standalone/test/python/smoketest.py @@ -19,6 +19,10 @@ # CHECK: standalone.foo %[[C2]] : i32 print(str(standalone_module), file=sys.stderr) + custom_type = standalone_d.CustomType.get("foo") + # CHECK: !standalone.custom<"foo"> + print(custom_type, file=sys.stderr) + # CHECK: Testing mlir package print("Testing mlir package", file=sys.stderr) diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h index 78fc94f93439e..6abd8894227c3 100644 --- a/mlir/include/mlir-c/Support.h +++ b/mlir/include/mlir-c/Support.h @@ -46,6 +46,8 @@ #define MLIR_CAPI_EXPORTED __attribute__((visibility("default"))) #endif +#define MLIR_PYTHON_API_EXPORTED MLIR_CAPI_EXPORTED + #ifdef __cplusplus extern "C" { #endif diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h similarity index 96% rename from mlir/lib/Bindings/Python/Globals.h rename to mlir/include/mlir/Bindings/Python/Globals.h index 1e81f53e465ac..5548a716cbe21 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/include/mlir/Bindings/Python/Globals.h @@ -15,10 +15,11 @@ #include #include -#include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/NanobindUtils.h" #include "mlir/CAPI/Support.h" + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -27,19 +28,16 @@ namespace mlir { namespace python { - +namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// Globals that are always accessible once the extension has been initialized. /// Methods of this class are thread-safe. -class PyGlobals { +class MLIR_PYTHON_API_EXPORTED PyGlobals { public: PyGlobals(); ~PyGlobals(); /// Most code should get the globals via this static accessor. - static PyGlobals &get() { - assert(instance && "PyGlobals is null"); - return *instance; - } + static PyGlobals &get(); /// Get and set the list of parent modules to search for dialect /// implementation classes. @@ -119,7 +117,7 @@ class PyGlobals { std::optional lookupOperationClass(llvm::StringRef operationName); - class TracebackLoc { + class MLIR_PYTHON_API_EXPORTED TracebackLoc { public: bool locTracebacksEnabled(); @@ -199,7 +197,7 @@ class PyGlobals { TracebackLoc tracebackLoc; TypeIDAllocator typeIDAllocator; }; - +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/include/mlir/Bindings/Python/IRCore.h similarity index 69% rename from mlir/lib/Bindings/Python/IRModule.h rename to mlir/include/mlir/Bindings/Python/IRCore.h index e706be3b4d32a..d8662137b60e7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -1,4 +1,4 @@ -//===- IRModules.h - IR Submodules of pybind module -----------------------===// +//===- IRCore.h - IR helpers of python bindings ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,8 +7,8 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception //===----------------------------------------------------------------------===// -#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H -#define MLIR_BINDINGS_PYTHON_IRMODULES_H +#ifndef MLIR_BINDINGS_PYTHON_IRCORE_H +#define MLIR_BINDINGS_PYTHON_IRCORE_H #include #include @@ -20,17 +20,21 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" + #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ThreadPool.h" namespace mlir { namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { class PyBlock; class PyDiagnostic; @@ -47,10 +51,20 @@ class PyType; class PySymbolTable; class PyValue; +/// Wrapper for the global LLVM debugging flag. +struct MLIR_PYTHON_API_EXPORTED PyGlobalDebugFlag { + static void set(nanobind::object &o, bool enable); + static bool get(const nanobind::object &); + static void bind(nanobind::module_ &m); + +private: + static nanobind::ft_mutex mutex; +}; + /// Template for a reference to a concrete type which captures a python /// reference to its underlying python object. template -class PyObjectRef { +class MLIR_PYTHON_API_EXPORTED PyObjectRef { public: PyObjectRef(T *referrent, nanobind::object object) : referrent(referrent), object(std::move(object)) { @@ -109,7 +123,7 @@ class PyObjectRef { /// Context. Pushing a Context will not modify the Location or InsertionPoint /// unless if they are from a different context, in which case, they are /// cleared. -class PyThreadContextEntry { +class MLIR_PYTHON_API_EXPORTED PyThreadContextEntry { public: enum class FrameKind { Context, @@ -165,22 +179,16 @@ class PyThreadContextEntry { /// Wrapper around MlirLlvmThreadPool /// Python object owns the C++ thread pool -class PyThreadPool { +class MLIR_PYTHON_API_EXPORTED PyThreadPool { public: - PyThreadPool() { - ownedThreadPool = std::make_unique(); - } + PyThreadPool(); PyThreadPool(const PyThreadPool &) = delete; PyThreadPool(PyThreadPool &&) = delete; int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } - std::string _mlir_thread_pool_ptr() const { - std::stringstream ss; - ss << ownedThreadPool.get(); - return ss.str(); - } + std::string _mlir_thread_pool_ptr() const; private: std::unique_ptr ownedThreadPool; @@ -188,7 +196,7 @@ class PyThreadPool { /// Wrapper around MlirContext. using PyMlirContextRef = PyObjectRef; -class PyMlirContext { +class MLIR_PYTHON_API_EXPORTED PyMlirContext { public: PyMlirContext() = delete; PyMlirContext(MlirContext context); @@ -205,9 +213,7 @@ class PyMlirContext { /// Gets a strong reference to this context, which will ensure it is kept /// alive for the life of the reference. - PyMlirContextRef getRef() { - return PyMlirContextRef(this, nanobind::cast(this)); - } + PyMlirContextRef getRef(); /// Gets a capsule wrapping the void* within the MlirContext. nanobind::object getCapsule(); @@ -269,7 +275,7 @@ class PyMlirContext { /// Used in function arguments when None should resolve to the current context /// manager set instance. -class DefaultingPyMlirContext +class MLIR_PYTHON_API_EXPORTED DefaultingPyMlirContext : public Defaulting { public: using Defaulting::Defaulting; @@ -281,7 +287,7 @@ class DefaultingPyMlirContext /// MlirContext. The lifetime of the context will extend at least to the /// lifetime of these instances. /// Immutable objects that depend on a context extend this directly. -class BaseContextObject { +class MLIR_PYTHON_API_EXPORTED BaseContextObject { public: BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { assert(this->contextRef && @@ -296,7 +302,7 @@ class BaseContextObject { }; /// Wrapper around an MlirLocation. -class PyLocation : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject { public: PyLocation(PyMlirContextRef contextRef, MlirLocation loc) : BaseContextObject(std::move(contextRef)), loc(loc) {} @@ -323,16 +329,35 @@ class PyLocation : public BaseContextObject { MlirLocation loc; }; +enum PyDiagnosticSeverity : std::underlying_type_t { + MlirDiagnosticError = MlirDiagnosticError, + MlirDiagnosticWarning = MlirDiagnosticWarning, + MlirDiagnosticNote = MlirDiagnosticNote, + MlirDiagnosticRemark = MlirDiagnosticRemark +}; + +enum PyWalkResult : std::underlying_type_t { + MlirWalkResultAdvance = MlirWalkResultAdvance, + MlirWalkResultInterrupt = MlirWalkResultInterrupt, + MlirWalkResultSkip = MlirWalkResultSkip +}; + +/// Traversal order for operation walk. +enum PyWalkOrder : std::underlying_type_t { + MlirWalkPreOrder = MlirWalkPreOrder, + MlirWalkPostOrder = MlirWalkPostOrder +}; + /// Python class mirroring the C MlirDiagnostic struct. Note that these structs /// are only valid for the duration of a diagnostic callback and attempting /// to access them outside of that will raise an exception. This applies to /// nested diagnostics (in the notes) as well. -class PyDiagnostic { +class MLIR_PYTHON_API_EXPORTED PyDiagnostic { public: PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} void invalidate(); bool isValid() { return valid; } - MlirDiagnosticSeverity getSeverity(); + PyDiagnosticSeverity getSeverity(); PyLocation getLocation(); nanobind::str getMessage(); nanobind::tuple getNotes(); @@ -340,7 +365,7 @@ class PyDiagnostic { /// Materialized diagnostic information. This is safe to access outside the /// diagnostic callback. struct DiagnosticInfo { - MlirDiagnosticSeverity severity; + PyDiagnosticSeverity severity; PyLocation location; std::string message; std::vector notes; @@ -377,7 +402,7 @@ class PyDiagnostic { /// The object may remain live from a Python perspective for an arbitrary time /// after detachment, but there is nothing the user can do with it (since there /// is no way to attach an existing handler object). -class PyDiagnosticHandler { +class MLIR_PYTHON_API_EXPORTED PyDiagnosticHandler { public: PyDiagnosticHandler(MlirContext context, nanobind::object callback); ~PyDiagnosticHandler(); @@ -405,7 +430,7 @@ class PyDiagnosticHandler { /// RAII object that captures any error diagnostics emitted to the provided /// context. -struct PyMlirContext::ErrorCapture { +struct MLIR_PYTHON_API_EXPORTED PyMlirContext::ErrorCapture { ErrorCapture(PyMlirContextRef ctx) : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( ctx->get(), handler, /*userData=*/this, @@ -432,7 +457,7 @@ struct PyMlirContext::ErrorCapture { /// plugins which extend dialect functionality through extension python code. /// This should be seen as the "low-level" object and `Dialect` as the /// high-level, user facing object. -class PyDialectDescriptor : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyDialectDescriptor : public BaseContextObject { public: PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) : BaseContextObject(std::move(contextRef)), dialect(dialect) {} @@ -445,7 +470,7 @@ class PyDialectDescriptor : public BaseContextObject { /// User-level object for accessing dialects with dotted syntax such as: /// ctx.dialect.std -class PyDialects : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyDialects : public BaseContextObject { public: PyDialects(PyMlirContextRef contextRef) : BaseContextObject(std::move(contextRef)) {} @@ -456,7 +481,7 @@ class PyDialects : public BaseContextObject { /// User-level dialect object. For dialects that have a registered extension, /// this will be the base class of the extension dialect type. For un-extended, /// objects of this type will be returned directly. -class PyDialect { +class MLIR_PYTHON_API_EXPORTED PyDialect { public: PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} @@ -469,7 +494,7 @@ class PyDialect { /// Wrapper around an MlirDialectRegistry. /// Upon construction, the Python wrapper takes ownership of the /// underlying MlirDialectRegistry. -class PyDialectRegistry { +class MLIR_PYTHON_API_EXPORTED PyDialectRegistry { public: PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} @@ -495,7 +520,7 @@ class PyDialectRegistry { /// Used in function arguments when None should resolve to the current context /// manager set instance. -class DefaultingPyLocation +class MLIR_PYTHON_API_EXPORTED DefaultingPyLocation : public Defaulting { public: using Defaulting::Defaulting; @@ -509,7 +534,7 @@ class DefaultingPyLocation /// This is the top-level, user-owned object that contains regions/ops/blocks. class PyModule; using PyModuleRef = PyObjectRef; -class PyModule : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyModule : public BaseContextObject { public: /// Returns a PyModule reference for the given MlirModule. This always returns /// a new object. @@ -549,7 +574,7 @@ class PyAsmState; /// Base class for PyOperation and PyOpView which exposes the primary, user /// visible methods for manipulating it. -class PyOperationBase { +class MLIR_PYTHON_API_EXPORTED PyOperationBase { public: virtual ~PyOperationBase() = default; /// Implements the bound 'print' method and helps with others. @@ -571,8 +596,8 @@ class PyOperationBase { std::optional bytecodeVersion); // Implement the walk method. - void walk(std::function callback, - MlirWalkOrder walkOrder); + void walk(std::function callback, + PyWalkOrder walkOrder); /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); @@ -602,7 +627,8 @@ class PyOperationBase { class PyOperation; class PyOpView; using PyOperationRef = PyObjectRef; -class PyOperation : public PyOperationBase, public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyOperation : public PyOperationBase, + public BaseContextObject { public: ~PyOperation() override; PyOperation &getOperation() override { return *this; } @@ -627,32 +653,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Detaches the operation from its parent block and updates its state /// accordingly. - void detachFromParent() { - mlirOperationRemoveFromParent(getOperation()); - setDetached(); - parentKeepAlive = nanobind::object(); - } + void detachFromParent(); /// Gets the backing operation. operator MlirOperation() const { return get(); } - MlirOperation get() const { - checkValid(); - return operation; - } + MlirOperation get() const; - PyOperationRef getRef() { - return PyOperationRef(this, nanobind::borrow(handle)); - } + PyOperationRef getRef(); bool isAttached() { return attached; } - void setAttached(const nanobind::object &parent = nanobind::object()) { - assert(!attached && "operation already attached"); - attached = true; - } - void setDetached() { - assert(attached && "operation already detached"); - attached = false; - } + void setAttached(const nanobind::object &parent = nanobind::object()); + void setDetached(); void checkValid() const; /// Gets the owning block or raises an exception if the operation has no @@ -720,7 +731,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// custom ODS-style operation classes. Since this class is subclass on the /// python side, it must present an __init__ method that operates in pure /// python types. -class PyOpView : public PyOperationBase { +class MLIR_PYTHON_API_EXPORTED PyOpView : public PyOperationBase { public: PyOpView(const nanobind::object &operationObject); PyOperation &getOperation() override { return operation; } @@ -756,7 +767,7 @@ class PyOpView : public PyOperationBase { /// Wrapper around an MlirRegion. /// Regions are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached regions. -class PyRegion { +class MLIR_PYTHON_API_EXPORTED PyRegion { public: PyRegion(PyOperationRef parentOperation, MlirRegion region) : parentOperation(std::move(parentOperation)), region(region) { @@ -775,26 +786,10 @@ class PyRegion { }; /// Wrapper around an MlirAsmState. -class PyAsmState { +class MLIR_PYTHON_API_EXPORTED PyAsmState { public: - PyAsmState(MlirValue value, bool useLocalScope) { - flags = mlirOpPrintingFlagsCreate(); - // The OpPrintingFlags are not exposed Python side, create locally and - // associate lifetime with the state. - if (useLocalScope) - mlirOpPrintingFlagsUseLocalScope(flags); - state = mlirAsmStateCreateForValue(value, flags); - } - - PyAsmState(PyOperationBase &operation, bool useLocalScope) { - flags = mlirOpPrintingFlagsCreate(); - // The OpPrintingFlags are not exposed Python side, create locally and - // associate lifetime with the state. - if (useLocalScope) - mlirOpPrintingFlagsUseLocalScope(flags); - state = - mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); - } + PyAsmState(MlirValue value, bool useLocalScope); + PyAsmState(PyOperationBase &operation, bool useLocalScope); ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } // Delete copy constructors. PyAsmState(PyAsmState &other) = delete; @@ -810,7 +805,7 @@ class PyAsmState { /// Wrapper around an MlirBlock. /// Blocks are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached blocks. -class PyBlock { +class MLIR_PYTHON_API_EXPORTED PyBlock { public: PyBlock(PyOperationRef parentOperation, MlirBlock block) : parentOperation(std::move(parentOperation)), block(block) { @@ -834,7 +829,7 @@ class PyBlock { /// Calls to insert() will insert a new operation before the /// reference operation. If the reference operation is null, then appends to /// the end of the block. -class PyInsertionPoint { +class MLIR_PYTHON_API_EXPORTED PyInsertionPoint { public: /// Creates an insertion point positioned after the last operation in the /// block, but still inside the block. @@ -873,9 +868,10 @@ class PyInsertionPoint { std::optional refOperation; PyBlock block; }; + /// Wrapper around the generic MlirType. /// The lifetime of a type is bound by the PyContext that created it. -class PyType : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyType : public BaseContextObject { public: PyType(PyMlirContextRef contextRef, MlirType type) : BaseContextObject(std::move(contextRef)), type(type) {} @@ -901,7 +897,7 @@ class PyType : public BaseContextObject { /// A TypeID provides an efficient and unique identifier for a specific C++ /// type. This allows for a C++ type to be compared, hashed, and stored in an /// opaque context. This class wraps around the generic MlirTypeID. -class PyTypeID { +class MLIR_PYTHON_API_EXPORTED PyTypeID { public: PyTypeID(MlirTypeID typeID) : typeID(typeID) {} // Note, this tests whether the underlying TypeIDs are the same, @@ -927,7 +923,7 @@ class PyTypeID { /// concrete type class extends PyType); however, intermediate python-visible /// base classes can be modeled by specifying a BaseTy. template -class PyConcreteType : public BaseTy { +class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy { public: // Derived classes must define statics for: // IsAFunctionTy isaFunction @@ -1005,7 +1001,7 @@ class PyConcreteType : public BaseTy { /// Wrapper around the generic MlirAttribute. /// The lifetime of a type is bound by the PyContext that created it. -class PyAttribute : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyAttribute : public BaseContextObject { public: PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) : BaseContextObject(std::move(contextRef)), attr(attr) {} @@ -1031,7 +1027,7 @@ class PyAttribute : public BaseContextObject { /// Represents a Python MlirNamedAttr, carrying an optional owned name. /// TODO: Refactor this and the C-API to be based on an Identifier owned /// by the context so as to avoid ownership issues here. -class PyNamedAttribute { +class MLIR_PYTHON_API_EXPORTED PyNamedAttribute { public: /// Constructs a PyNamedAttr that retains an owned name. This should be /// used in any code that originates an MlirNamedAttribute from a python @@ -1057,7 +1053,7 @@ class PyNamedAttribute { /// concrete attribute class extends PyAttribute); however, intermediate /// python-visible base classes can be modeled by specifying a BaseTy. template -class PyConcreteAttribute : public BaseTy { +class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy { public: // Derived classes must define statics for: // IsAFunctionTy isaFunction @@ -1147,7 +1143,8 @@ class PyConcreteAttribute : public BaseTy { static void bindDerived(ClassTy &m) {} }; -class PyStringAttribute : public PyConcreteAttribute { +class MLIR_PYTHON_API_EXPORTED PyStringAttribute + : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; static constexpr const char *pyClassName = "StringAttr"; @@ -1164,7 +1161,7 @@ class PyStringAttribute : public PyConcreteAttribute { /// value. For block argument values, this is the operation that contains the /// block to which the value is an argument (blocks cannot be detached in Python /// bindings so such operation always exists). -class PyValue { +class MLIR_PYTHON_API_EXPORTED PyValue { public: // The virtual here is "load bearing" in that it enables RTTI // for PyConcreteValue CRTP classes that support maybeDownCast. @@ -1194,7 +1191,7 @@ class PyValue { }; /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. -class PyAffineExpr : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyAffineExpr : public BaseContextObject { public: PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} @@ -1221,7 +1218,7 @@ class PyAffineExpr : public BaseContextObject { MlirAffineExpr affineExpr; }; -class PyAffineMap : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyAffineMap : public BaseContextObject { public: PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} @@ -1242,7 +1239,7 @@ class PyAffineMap : public BaseContextObject { MlirAffineMap affineMap; }; -class PyIntegerSet : public BaseContextObject { +class MLIR_PYTHON_API_EXPORTED PyIntegerSet : public BaseContextObject { public: PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} @@ -1263,7 +1260,7 @@ class PyIntegerSet : public BaseContextObject { }; /// Bindings for MLIR symbol tables. -class PySymbolTable { +class MLIR_PYTHON_API_EXPORTED PySymbolTable { public: /// Constructs a symbol table for the given operation. explicit PySymbolTable(PyOperationBase &operation); @@ -1315,7 +1312,7 @@ class PySymbolTable { /// Custom exception that allows access to error diagnostic information. This is /// converted to the `ir.MLIRError` python exception when thrown. -struct MLIRError { +struct MLIR_PYTHON_API_EXPORTED MLIRError { MLIRError(llvm::Twine message, std::vector &&errorDiagnostics = {}) : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} @@ -1323,12 +1320,492 @@ struct MLIRError { std::vector errorDiagnostics; }; -void populateIRAffine(nanobind::module_ &m); -void populateIRAttributes(nanobind::module_ &m); -void populateIRCore(nanobind::module_ &m); -void populateIRInterfaces(nanobind::module_ &m); -void populateIRTypes(nanobind::module_ &m); +//------------------------------------------------------------------------------ +// Utilities. +//------------------------------------------------------------------------------ + +inline MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +inline MlirStringRef toMlirStringRef(std::string_view s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +inline MlirStringRef toMlirStringRef(const nanobind::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + +/// Create a block, using the current location context if no locations are +/// specified. +MlirBlock MLIR_PYTHON_API_EXPORTED +createBlock(const nanobind::sequence &pyArgTypes, + const std::optional &pyArgLocs); + +struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap { + static bool dunderContains(const std::string &attributeKind); + static nanobind::callable + dunderGetItemNamed(const std::string &attributeKind); + static void dunderSetItemNamed(const std::string &attributeKind, + nanobind::callable func, bool replace); + + static void bind(nanobind::module_ &m); +}; + +//------------------------------------------------------------------------------ +// Collections. +//------------------------------------------------------------------------------ + +class MLIR_PYTHON_API_EXPORTED PyRegionIterator { +public: + PyRegionIterator(PyOperationRef operation, int nextIndex) + : operation(std::move(operation)), nextIndex(nextIndex) {} + + PyRegionIterator &dunderIter() { return *this; } + + PyRegion dunderNext(); + + static void bind(nanobind::module_ &m); + +private: + PyOperationRef operation; + intptr_t nextIndex = 0; +}; + +/// Regions of an op are fixed length and indexed numerically so are represented +/// with a sequence-like container. +class MLIR_PYTHON_API_EXPORTED PyRegionList + : public Sliceable { +public: + static constexpr const char *pyClassName = "RegionSequence"; + + PyRegionList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1); + + PyRegionIterator dunderIter(); + + static void bindDerived(ClassTy &c); + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements(); + + PyRegion getRawElement(intptr_t pos); + + PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) const; + + PyOperationRef operation; +}; + +class MLIR_PYTHON_API_EXPORTED PyBlockIterator { +public: + PyBlockIterator(PyOperationRef operation, MlirBlock next) + : operation(std::move(operation)), next(next) {} + + PyBlockIterator &dunderIter() { return *this; } + + PyBlock dunderNext(); + + static void bind(nanobind::module_ &m); + +private: + PyOperationRef operation; + MlirBlock next; +}; + +/// Blocks are exposed by the C-API as a forward-only linked list. In Python, +/// we present them as a more full-featured list-like container but optimize +/// it for forward iteration. Blocks are always owned by a region. +class MLIR_PYTHON_API_EXPORTED PyBlockList { +public: + PyBlockList(PyOperationRef operation, MlirRegion region) + : operation(std::move(operation)), region(region) {} + + PyBlockIterator dunderIter(); + + intptr_t dunderLen(); + + PyBlock dunderGetItem(intptr_t index); + + PyBlock appendBlock(const nanobind::args &pyArgTypes, + const std::optional &pyArgLocs); + + static void bind(nanobind::module_ &m); + +private: + PyOperationRef operation; + MlirRegion region; +}; + +class MLIR_PYTHON_API_EXPORTED PyOperationIterator { +public: + PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) + : parentOperation(std::move(parentOperation)), next(next) {} + + PyOperationIterator &dunderIter() { return *this; } + + nanobind::typed dunderNext(); + + static void bind(nanobind::module_ &m); + +private: + PyOperationRef parentOperation; + MlirOperation next; +}; + +/// Operations are exposed by the C-API as a forward-only linked list. In +/// Python, we present them as a more full-featured list-like container but +/// optimize it for forward iteration. Iterable operations are always owned +/// by a block. +class MLIR_PYTHON_API_EXPORTED PyOperationList { +public: + PyOperationList(PyOperationRef parentOperation, MlirBlock block) + : parentOperation(std::move(parentOperation)), block(block) {} + + PyOperationIterator dunderIter(); + + intptr_t dunderLen(); + + nanobind::typed dunderGetItem(intptr_t index); + + static void bind(nanobind::module_ &m); + +private: + PyOperationRef parentOperation; + MlirBlock block; +}; + +class MLIR_PYTHON_API_EXPORTED PyOpOperand { +public: + PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} + + nanobind::typed getOwner() const; + + size_t getOperandNumber() const; + + static void bind(nanobind::module_ &m); + +private: + MlirOpOperand opOperand; +}; + +class MLIR_PYTHON_API_EXPORTED PyOpOperandIterator { +public: + PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} + + PyOpOperandIterator &dunderIter() { return *this; } + + PyOpOperand dunderNext(); + + static void bind(nanobind::module_ &m); + +private: + MlirOpOperand opOperand; +}; + +/// CRTP base class for Python MLIR values that subclass Value and should be +/// castable from it. The value hierarchy is one level deep and is not supposed +/// to accommodate other levels unless core MLIR changes. +template +class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived. + using ClassTy = nanobind::class_; + using IsAFunctionTy = bool (*)(MlirValue); + + PyConcreteValue() = default; + PyConcreteValue(PyOperationRef operationRef, MlirValue value) + : PyValue(operationRef, value) {} + PyConcreteValue(PyValue &orig) + : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} + + /// Attempts to cast the original value to the derived type and throws on + /// type mismatches. + static MlirValue castFrom(PyValue &orig) { + if (!DerivedTy::isaFunction(orig.get())) { + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((Twine("Cannot cast value to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); + } + return orig.get(); + } + + /// Binds the Python module objects to functions of this class. + static void bind(nanobind::module_ &m) { + auto cls = ClassTy( + m, DerivedTy::pyClassName, nanobind::is_generic(), + nanobind::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])") + .str() + .c_str())); + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("value")); + cls.def_static( + "isinstance", + [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }, + nanobind::arg("other_value")); + cls.def( + MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) -> nanobind::typed { + return self.maybeDownCast(); + }); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +/// Python wrapper for MlirOpResult. +class MLIR_PYTHON_API_EXPORTED PyOpResult : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; + static constexpr const char *pyClassName = "OpResult"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c); +}; + +/// A list of operation results. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) result list is associated +/// with the operation whose results these are, and thus extends the lifetime of +/// this operation. +class MLIR_PYTHON_API_EXPORTED PyOpResultList + : public Sliceable { +public: + static constexpr const char *pyClassName = "OpResultList"; + using SliceableT = Sliceable; + + PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1); + + static void bindDerived(ClassTy &c); + + PyOperationRef &getOperation() { return operation; } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements(); + + PyOpResult getRawElement(intptr_t index); + + PyOpResultList slice(intptr_t startIndex, intptr_t length, + intptr_t step) const; + + PyOperationRef operation; +}; + +/// Python wrapper for MlirBlockArgument. +class MLIR_PYTHON_API_EXPORTED PyBlockArgument + : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; + static constexpr const char *pyClassName = "BlockArgument"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c); +}; + +/// A list of block arguments. Internally, these are stored as consecutive +/// elements, random access is cheap. The argument list is associated with the +/// operation that contains the block (detached blocks are not allowed in +/// Python bindings) and extends its lifetime. +class MLIR_PYTHON_API_EXPORTED PyBlockArgumentList + : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockArgumentList"; + using SliceableT = Sliceable; + + PyBlockArgumentList(PyOperationRef operation, MlirBlock block, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1); + + static void bindDerived(ClassTy &c); + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + /// Returns the number of arguments in the list. + intptr_t getRawNumElements(); + + /// Returns `pos`-the element in the list. + PyBlockArgument getRawElement(intptr_t pos) const; + + /// Returns a sublist of this list. + PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, + intptr_t step) const; + + PyOperationRef operation; + MlirBlock block; +}; + +/// A list of operation operands. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) operand list is associated +/// with the operation whose operands these are, and thus extends the lifetime +/// of this operation. +class MLIR_PYTHON_API_EXPORTED PyOpOperandList + : public Sliceable { +public: + static constexpr const char *pyClassName = "OpOperandList"; + using SliceableT = Sliceable; + + PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1); + + void dunderSetItem(intptr_t index, PyValue value); + + static void bindDerived(ClassTy &c); + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + intptr_t getRawNumElements(); + + PyValue getRawElement(intptr_t pos); + + PyOpOperandList slice(intptr_t startIndex, intptr_t length, + intptr_t step) const; + + PyOperationRef operation; +}; + +/// A list of operation successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation whose successors these are, and thus extends +/// the lifetime of this operation. +class MLIR_PYTHON_API_EXPORTED PyOpSuccessors + : public Sliceable { +public: + static constexpr const char *pyClassName = "OpSuccessors"; + + PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1); + + void dunderSetItem(intptr_t index, PyBlock block); + + static void bindDerived(ClassTy &c); + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements(); + + PyBlock getRawElement(intptr_t pos); + + PyOpSuccessors slice(intptr_t startIndex, intptr_t length, + intptr_t step) const; + + PyOperationRef operation; +}; + +/// A list of block successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation and block whose successors these are, and thus +/// extends the lifetime of this operation and block. +class MLIR_PYTHON_API_EXPORTED PyBlockSuccessors + : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockSuccessors"; + + PyBlockSuccessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1); + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements(); + + PyBlock getRawElement(intptr_t pos); + + PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, + intptr_t step) const; + + PyOperationRef operation; + PyBlock block; +}; + +/// A list of block predecessors. The (returned) predecessor list is +/// associated with the operation and block whose predecessors these are, and +/// thus extends the lifetime of this operation and block. +/// +/// WARNING: This Sliceable is more expensive than the others here because +/// mlirBlockGetPredecessor actually iterates the use-def chain (of block +/// operands) anew for each indexed access. +class MLIR_PYTHON_API_EXPORTED PyBlockPredecessors + : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockPredecessors"; + + PyBlockPredecessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1); + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements(); + + PyBlock getRawElement(intptr_t pos); + + PyBlockPredecessors slice(intptr_t startIndex, intptr_t length, + intptr_t step) const; + + PyOperationRef operation; + PyBlock block; +}; + +/// A list of operation attributes. Can be indexed by name, producing +/// attributes, or by index, producing named attributes. +class MLIR_PYTHON_API_EXPORTED PyOpAttributeMap { +public: + PyOpAttributeMap(PyOperationRef operation) + : operation(std::move(operation)) {} + + nanobind::typed + dunderGetItemNamed(const std::string &name); + + PyNamedAttribute dunderGetItemIndexed(intptr_t index); + + void dunderSetItem(const std::string &name, const PyAttribute &attr); + + void dunderDelItem(const std::string &name); + + intptr_t dunderLen(); + + bool dunderContains(const std::string &name); + + static void + forEachAttr(MlirOperation op, + llvm::function_ref fn); + + static void bind(nanobind::module_ &m); + +private: + PyOperationRef operation; +}; + +MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation); +MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m); +MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m); +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir @@ -1336,13 +1813,18 @@ namespace nanobind { namespace detail { template <> -struct type_caster - : MlirDefaultingCaster {}; +struct type_caster< + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext> + : MlirDefaultingCaster< + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext> { +}; template <> -struct type_caster - : MlirDefaultingCaster {}; +struct type_caster< + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation> + : MlirDefaultingCaster< + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation> {}; } // namespace detail } // namespace nanobind -#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H +#endif // MLIR_BINDINGS_PYTHON_IRCORE_H diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h index ba9642cf2c6a2..87e0e10764bd8 100644 --- a/mlir/include/mlir/Bindings/Python/IRTypes.h +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -12,9 +12,11 @@ #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace mlir { - +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// Shaped Type Interface - ShapedType -class PyShapedType : public python::PyConcreteType { +class MLIR_PYTHON_API_EXPORTED PyShapedType + : public PyConcreteType { public: static const IsAFunctionTy isaFunction; static constexpr const char *pyClassName = "ShapedType"; @@ -25,7 +27,8 @@ class PyShapedType : public python::PyConcreteType { private: void requireHasRank(); }; - +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python } // namespace mlir #endif // MLIR_BINDINGS_PYTHON_IRTYPES_H diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/include/mlir/Bindings/Python/NanobindUtils.h similarity index 100% rename from mlir/lib/Bindings/Python/NanobindUtils.h rename to mlir/include/mlir/Bindings/Python/NanobindUtils.h diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp index 0d1d9e89f92f6..a87918a05b126 100644 --- a/mlir/lib/Bindings/Python/DialectSMT.cpp +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "NanobindUtils.h" +#include "mlir/Bindings/Python/NanobindUtils.h" #include "mlir-c/Dialect/SMT.h" #include "mlir-c/IR.h" diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/Globals.cpp similarity index 95% rename from mlir/lib/Bindings/Python/IRModule.cpp rename to mlir/lib/Bindings/Python/Globals.cpp index 0de2f1711829b..e2e8693ba45f3 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/Globals.cpp @@ -6,25 +6,29 @@ // //===----------------------------------------------------------------------===// -#include "IRModule.h" +#include "mlir/Bindings/Python/IRCore.h" #include #include -#include "Globals.h" -#include "NanobindUtils.h" +#include "mlir/Bindings/Python/Globals.h" +// clang-format off +#include "mlir/Bindings/Python/NanobindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" +// clang-format on #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" namespace nb = nanobind; using namespace mlir; -using namespace mlir::python; // ----------------------------------------------------------------------------- // PyGlobals // ----------------------------------------------------------------------------- +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { PyGlobals *PyGlobals::instance = nullptr; PyGlobals::PyGlobals() { @@ -37,6 +41,11 @@ PyGlobals::PyGlobals() { PyGlobals::~PyGlobals() { instance = nullptr; } +PyGlobals &PyGlobals::get() { + assert(instance && "PyGlobals is null"); + return *instance; +} + bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { { nb::ft_lock_guard lock(mutex); @@ -265,3 +274,6 @@ bool PyGlobals::TracebackLoc::isUserTracebackFilename( } return isUserTracebackFilenameCache[file]; } +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 7147f2cbad149..ce235470bbdc7 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -13,11 +13,13 @@ #include #include -#include "IRModule.h" -#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" +#include "mlir/Bindings/Python/IRCore.h" +// clang-format off +#include "mlir/Bindings/Python/NanobindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on #include "mlir-c/IntegerSet.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Support/LLVM.h" @@ -28,7 +30,7 @@ namespace nb = nanobind; using namespace mlir; -using namespace mlir::python; +using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; using llvm::SmallVector; using llvm::StringRef; @@ -78,7 +80,9 @@ static bool isPermutation(const std::vector &permutation) { return true; } -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr /// and should be castable from it. Intermediate hierarchy classes can be @@ -356,7 +360,9 @@ class PyAffineCeilDivExpr } }; -} // namespace +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir bool PyAffineExpr::operator==(const PyAffineExpr &other) const { return mlirAffineExprEqual(affineExpr, other.affineExpr); @@ -378,7 +384,9 @@ PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) { //------------------------------------------------------------------------------ // PyAffineMap and utilities. //------------------------------------------------------------------------------ -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// A list of expressions contained in an affine map. Internally these are /// stored as a consecutive array leading to inexpensive random access. Both @@ -414,7 +422,9 @@ class PyAffineMapExprList PyAffineMap affineMap; }; -} // namespace +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir bool PyAffineMap::operator==(const PyAffineMap &other) const { return mlirAffineMapEqual(affineMap, other.affineMap); @@ -436,7 +446,9 @@ PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) { //------------------------------------------------------------------------------ // PyIntegerSet and utilities. //------------------------------------------------------------------------------ -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { class PyIntegerSetConstraint { public: @@ -490,7 +502,9 @@ class PyIntegerSetConstraintList PyIntegerSet set; }; -} // namespace +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir bool PyIntegerSet::operator==(const PyIntegerSet &other) const { return mlirIntegerSetEqual(integerSet, other.integerSet); @@ -509,7 +523,10 @@ PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) { rawIntegerSet); } -void mlir::python::populateIRAffine(nb::module_ &m) { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +void populateIRAffine(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineExpr and derived classes. //---------------------------------------------------------------------------- @@ -995,3 +1012,6 @@ void mlir::python::populateIRAffine(nb::module_ &m) { PyIntegerSetConstraint::bind(m); PyIntegerSetConstraintList::bind(m); } +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index c0a945e3f4f3b..f0f0ae9ba741e 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -12,19 +12,19 @@ #include #include -#include "IRModule.h" -#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/NanobindUtils.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" namespace nb = nanobind; using namespace nanobind::literals; using namespace mlir; -using namespace mlir::python; +using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; using llvm::SmallVector; @@ -121,7 +121,9 @@ subsequent processing. type or if the buffer does not meet expectations. )"; -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { struct nb_buffer_info { void *ptr = nullptr; @@ -228,14 +230,6 @@ struct nb_format_descriptor { static const char *format() { return "d"; } }; -static MlirStringRef toMlirStringRef(const std::string &s) { - return mlirStringRefCreate(s.data(), s.size()); -} - -static MlirStringRef toMlirStringRef(const nb::bytes &s) { - return mlirStringRefCreate(static_cast(s.data()), s.size()); -} - class PyAffineMapAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; @@ -1753,7 +1747,9 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { throw nb::type_error(msg.c_str()); } -} // namespace +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir void PyStringAttribute::bindDerived(ClassTy &c) { c.def_static( @@ -1799,7 +1795,10 @@ void PyStringAttribute::bindDerived(ClassTy &c) { "Returns the value of the string attribute as `bytes`"); } -void mlir::python::populateIRAttributes(nb::module_ &m) { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +void populateIRAttributes(nb::module_ &m) { PyAffineMapAttribute::bind(m); PyDenseBoolArrayAttribute::bind(m); PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); @@ -1852,3 +1851,6 @@ void mlir::python::populateIRAttributes(nb::module_ &m) { PyStridedLayoutAttribute::bind(m); } +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 168c57955af07..45e2cda4c91e2 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,10 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "Globals.h" -#include "IRModule.h" -#include "NanobindUtils.h" +// clang-format off +#include "mlir/Bindings/Python/Globals.h" +#include "mlir/Bindings/Python/IRCore.h" +#include "mlir/Bindings/Python/NanobindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" @@ -27,7 +29,6 @@ namespace nb = nanobind; using namespace nb::literals; using namespace mlir; -using namespace mlir::python; using llvm::SmallVector; using llvm::StringRef; @@ -64,44 +65,41 @@ static nb::object classmethod(Func f, Args... args) { static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, nb::object dialectDescriptor) { - auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); + auto dialectClass = + python::MLIR_BINDINGS_PYTHON_DOMAIN::PyGlobals::get().lookupDialectClass( + dialectNamespace); if (!dialectClass) { // Use the base class. - return nb::cast(PyDialect(std::move(dialectDescriptor))); + return nb::cast(python::MLIR_BINDINGS_PYTHON_DOMAIN::PyDialect( + std::move(dialectDescriptor))); } // Create the custom implementation. return (*dialectClass)(std::move(dialectDescriptor)); } -static MlirStringRef toMlirStringRef(const std::string &s) { - return mlirStringRefCreate(s.data(), s.size()); -} - -static MlirStringRef toMlirStringRef(std::string_view s) { - return mlirStringRefCreate(s.data(), s.size()); -} - -static MlirStringRef toMlirStringRef(const nb::bytes &s) { - return mlirStringRefCreate(static_cast(s.data()), s.size()); -} +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { -/// Create a block, using the current location context if no locations are -/// specified. -static MlirBlock createBlock(const nb::sequence &pyArgTypes, - const std::optional &pyArgLocs) { +MlirBlock createBlock(const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { SmallVector argTypes; argTypes.reserve(nb::len(pyArgTypes)); for (const auto &pyType : pyArgTypes) - argTypes.push_back(nb::cast(pyType)); + argTypes.push_back( + nb::cast(pyType)); SmallVector argLocs; if (pyArgLocs) { argLocs.reserve(nb::len(*pyArgLocs)); for (const auto &pyLoc : *pyArgLocs) - argLocs.push_back(nb::cast(pyLoc)); + argLocs.push_back( + nb::cast(pyLoc)); } else if (!argTypes.empty()) { - argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); + argLocs.assign( + argTypes.size(), + python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyLocation::resolve()); } if (argTypes.size() != argLocs.size()) @@ -112,82 +110,77 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes, return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); } -/// Wrapper for the global LLVM debugging flag. -struct PyGlobalDebugFlag { - static void set(nb::object &o, bool enable) { - nb::ft_lock_guard lock(mutex); - mlirEnableGlobalDebug(enable); - } - - static bool get(const nb::object &) { - nb::ft_lock_guard lock(mutex); - return mlirIsGlobalDebugEnabled(); - } +void PyGlobalDebugFlag::set(nb::object &o, bool enable) { + nb::ft_lock_guard lock(mutex); + mlirEnableGlobalDebug(enable); +} - static void bind(nb::module_ &m) { - // Debug flags. - nb::class_(m, "_GlobalDebug") - .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag.") - .def_static( - "set_types", - [](const std::string &type) { - nb::ft_lock_guard lock(mutex); - mlirSetGlobalDebugType(type.c_str()); - }, - "types"_a, "Sets specific debug types to be produced by LLVM.") - .def_static( - "set_types", - [](const std::vector &types) { - std::vector pointers; - pointers.reserve(types.size()); - for (const std::string &str : types) - pointers.push_back(str.c_str()); - nb::ft_lock_guard lock(mutex); - mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); - }, - "types"_a, - "Sets multiple specific debug types to be produced by LLVM."); - } +bool PyGlobalDebugFlag::get(const nb::object &) { + nb::ft_lock_guard lock(mutex); + return mlirIsGlobalDebugEnabled(); +} -private: - static nb::ft_mutex mutex; -}; +void PyGlobalDebugFlag::bind(nb::module_ &m) { + // Debug flags. + nb::class_(m, "_GlobalDebug") + .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, + &PyGlobalDebugFlag::set, "LLVM-wide debug flag.") + .def_static( + "set_types", + [](const std::string &type) { + nb::ft_lock_guard lock(mutex); + mlirSetGlobalDebugType(type.c_str()); + }, + "types"_a, "Sets specific debug types to be produced by LLVM.") + .def_static( + "set_types", + [](const std::vector &types) { + std::vector pointers; + pointers.reserve(types.size()); + for (const std::string &str : types) + pointers.push_back(str.c_str()); + nb::ft_lock_guard lock(mutex); + mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); + }, + "types"_a, + "Sets multiple specific debug types to be produced by LLVM."); +} nb::ft_mutex PyGlobalDebugFlag::mutex; -struct PyAttrBuilderMap { - static bool dunderContains(const std::string &attributeKind) { - return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); - } - static nb::callable dunderGetItemNamed(const std::string &attributeKind) { - auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); - if (!builder) - throw nb::key_error(attributeKind.c_str()); - return *builder; - } - static void dunderSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { - PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), - replace); - } +bool PyAttrBuilderMap::dunderContains(const std::string &attributeKind) { + return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); +} - static void bind(nb::module_ &m) { - nb::class_(m, "AttrBuilder") - .def_static("contains", &PyAttrBuilderMap::dunderContains, - "attribute_kind"_a, - "Checks whether an attribute builder is registered for the " - "given attribute kind.") - .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed, - "attribute_kind"_a, - "Gets the registered attribute builder for the given " - "attribute kind.") - .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, - "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, - "Register an attribute builder for building MLIR " - "attributes from Python values."); - } -}; +nb::callable +PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) { + auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); + if (!builder) + throw nb::key_error(attributeKind.c_str()); + return *builder; +} + +void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind, + nb::callable func, bool replace) { + PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), + replace); +} + +void PyAttrBuilderMap::bind(nb::module_ &m) { + nb::class_(m, "AttrBuilder") + .def_static("contains", &PyAttrBuilderMap::dunderContains, + "attribute_kind"_a, + "Checks whether an attribute builder is registered for the " + "given attribute kind.") + .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed, + "attribute_kind"_a, + "Gets the registered attribute builder for the given " + "attribute kind.") + .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, + "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, + "Register an attribute builder for building MLIR " + "attributes from Python values."); +} //------------------------------------------------------------------------------ // PyBlock @@ -201,335 +194,252 @@ nb::object PyBlock::getCapsule() { // Collections. //------------------------------------------------------------------------------ -namespace { - -class PyRegionIterator { -public: - PyRegionIterator(PyOperationRef operation, int nextIndex) - : operation(std::move(operation)), nextIndex(nextIndex) {} - - PyRegionIterator &dunderIter() { return *this; } - - PyRegion dunderNext() { - operation->checkValid(); - if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { - throw nb::stop_iteration(); - } - MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); - return PyRegion(operation, region); +PyRegion PyRegionIterator::dunderNext() { + operation->checkValid(); + if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { + throw nb::stop_iteration(); } + MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); + return PyRegion(operation, region); +} - static void bind(nb::module_ &m) { - nb::class_(m, "RegionIterator") - .def("__iter__", &PyRegionIterator::dunderIter, - "Returns an iterator over the regions in the operation.") - .def("__next__", &PyRegionIterator::dunderNext, - "Returns the next region in the iteration."); - } +void PyRegionIterator::bind(nb::module_ &m) { + nb::class_(m, "RegionIterator") + .def("__iter__", &PyRegionIterator::dunderIter, + "Returns an iterator over the regions in the operation.") + .def("__next__", &PyRegionIterator::dunderNext, + "Returns the next region in the iteration."); +} -private: - PyOperationRef operation; - intptr_t nextIndex = 0; -}; - -/// Regions of an op are fixed length and indexed numerically so are represented -/// with a sequence-like container. -class PyRegionList : public Sliceable { -public: - static constexpr const char *pyClassName = "RegionSequence"; - - PyRegionList(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirOperationGetNumRegions(operation->get()) - : length, - step), - operation(std::move(operation)) {} - - PyRegionIterator dunderIter() { - operation->checkValid(); - return PyRegionIterator(operation, startIndex); - } +PyRegionList::PyRegionList(PyOperationRef operation, intptr_t startIndex, + intptr_t length, intptr_t step) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumRegions(operation->get()) + : length, + step), + operation(std::move(operation)) {} - static void bindDerived(ClassTy &c) { - c.def("__iter__", &PyRegionList::dunderIter, - "Returns an iterator over the regions in the sequence."); - } +PyRegionIterator PyRegionList::dunderIter() { + operation->checkValid(); + return PyRegionIterator(operation, startIndex); +} -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; +void PyRegionList::bindDerived(ClassTy &c) { + c.def("__iter__", &PyRegionList::dunderIter, + "Returns an iterator over the regions in the sequence."); +} - intptr_t getRawNumElements() { - operation->checkValid(); - return mlirOperationGetNumRegions(operation->get()); - } +intptr_t PyRegionList::getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumRegions(operation->get()); +} - PyRegion getRawElement(intptr_t pos) { - operation->checkValid(); - return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos)); - } +PyRegion PyRegionList::getRawElement(intptr_t pos) { + operation->checkValid(); + return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos)); +} - PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyRegionList(operation, startIndex, length, step); - } +PyRegionList PyRegionList::slice(intptr_t startIndex, intptr_t length, + intptr_t step) const { + return PyRegionList(operation, startIndex, length, step); +} - PyOperationRef operation; -}; +PyBlock PyBlockIterator::dunderNext() { + operation->checkValid(); + if (mlirBlockIsNull(next)) { + throw nb::stop_iteration(); + } -class PyBlockIterator { -public: - PyBlockIterator(PyOperationRef operation, MlirBlock next) - : operation(std::move(operation)), next(next) {} + PyBlock returnBlock(operation, next); + next = mlirBlockGetNextInRegion(next); + return returnBlock; +} - PyBlockIterator &dunderIter() { return *this; } +void PyBlockIterator::bind(nb::module_ &m) { + nb::class_(m, "BlockIterator") + .def("__iter__", &PyBlockIterator::dunderIter, + "Returns an iterator over the blocks in the operation's region.") + .def("__next__", &PyBlockIterator::dunderNext, + "Returns the next block in the iteration."); +} - PyBlock dunderNext() { - operation->checkValid(); - if (mlirBlockIsNull(next)) { - throw nb::stop_iteration(); - } +PyBlockIterator PyBlockList::dunderIter() { + operation->checkValid(); + return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); +} - PyBlock returnBlock(operation, next); - next = mlirBlockGetNextInRegion(next); - return returnBlock; +intptr_t PyBlockList::dunderLen() { + operation->checkValid(); + intptr_t count = 0; + MlirBlock block = mlirRegionGetFirstBlock(region); + while (!mlirBlockIsNull(block)) { + count += 1; + block = mlirBlockGetNextInRegion(block); } + return count; +} - static void bind(nb::module_ &m) { - nb::class_(m, "BlockIterator") - .def("__iter__", &PyBlockIterator::dunderIter, - "Returns an iterator over the blocks in the operation's region.") - .def("__next__", &PyBlockIterator::dunderNext, - "Returns the next block in the iteration."); +PyBlock PyBlockList::dunderGetItem(intptr_t index) { + operation->checkValid(); + if (index < 0) { + index += dunderLen(); } - -private: - PyOperationRef operation; - MlirBlock next; -}; - -/// Blocks are exposed by the C-API as a forward-only linked list. In Python, -/// we present them as a more full-featured list-like container but optimize -/// it for forward iteration. Blocks are always owned by a region. -class PyBlockList { -public: - PyBlockList(PyOperationRef operation, MlirRegion region) - : operation(std::move(operation)), region(region) {} - - PyBlockIterator dunderIter() { - operation->checkValid(); - return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); + if (index < 0) { + throw nb::index_error("attempt to access out of bounds block"); } - - intptr_t dunderLen() { - operation->checkValid(); - intptr_t count = 0; - MlirBlock block = mlirRegionGetFirstBlock(region); - while (!mlirBlockIsNull(block)) { - count += 1; - block = mlirBlockGetNextInRegion(block); + MlirBlock block = mlirRegionGetFirstBlock(region); + while (!mlirBlockIsNull(block)) { + if (index == 0) { + return PyBlock(operation, block); } - return count; - } - - PyBlock dunderGetItem(intptr_t index) { - operation->checkValid(); - if (index < 0) { - index += dunderLen(); - } - if (index < 0) { - throw nb::index_error("attempt to access out of bounds block"); - } - MlirBlock block = mlirRegionGetFirstBlock(region); - while (!mlirBlockIsNull(block)) { - if (index == 0) { - return PyBlock(operation, block); - } - block = mlirBlockGetNextInRegion(block); - index -= 1; - } - throw nb::index_error("attempt to access out of bounds block"); + block = mlirBlockGetNextInRegion(block); + index -= 1; } + throw nb::index_error("attempt to access out of bounds block"); +} - PyBlock appendBlock(const nb::args &pyArgTypes, - const std::optional &pyArgLocs) { - operation->checkValid(); - MlirBlock block = - createBlock(nb::cast(pyArgTypes), pyArgLocs); - mlirRegionAppendOwnedBlock(region, block); - return PyBlock(operation, block); - } +PyBlock PyBlockList::appendBlock(const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { + operation->checkValid(); + MlirBlock block = createBlock(nb::cast(pyArgTypes), pyArgLocs); + mlirRegionAppendOwnedBlock(region, block); + return PyBlock(operation, block); +} - static void bind(nb::module_ &m) { - nb::class_(m, "BlockList") - .def("__getitem__", &PyBlockList::dunderGetItem, - "Returns the block at the specified index.") - .def("__iter__", &PyBlockList::dunderIter, - "Returns an iterator over blocks in the operation's region.") - .def("__len__", &PyBlockList::dunderLen, - "Returns the number of blocks in the operation's region.") - .def("append", &PyBlockList::appendBlock, - R"( +void PyBlockList::bind(nb::module_ &m) { + nb::class_(m, "BlockList") + .def("__getitem__", &PyBlockList::dunderGetItem, + "Returns the block at the specified index.") + .def("__iter__", &PyBlockList::dunderIter, + "Returns an iterator over blocks in the operation's region.") + .def("__len__", &PyBlockList::dunderLen, + "Returns the number of blocks in the operation's region.") + .def("append", &PyBlockList::appendBlock, + R"( Appends a new block, with argument types as positional args. Returns: The created block. )", - nb::arg("args"), nb::kw_only(), - nb::arg("arg_locs") = std::nullopt); - } - -private: - PyOperationRef operation; - MlirRegion region; -}; - -class PyOperationIterator { -public: - PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) - : parentOperation(std::move(parentOperation)), next(next) {} + "args"_a, nb::kw_only(), "arg_locs"_a = std::nullopt); +} - PyOperationIterator &dunderIter() { return *this; } +nb::typed PyOperationIterator::dunderNext() { + parentOperation->checkValid(); + if (mlirOperationIsNull(next)) { + throw nb::stop_iteration(); + } - nb::typed dunderNext() { - parentOperation->checkValid(); - if (mlirOperationIsNull(next)) { - throw nb::stop_iteration(); - } + PyOperationRef returnOperation = + PyOperation::forOperation(parentOperation->getContext(), next); + next = mlirOperationGetNextInBlock(next); + return returnOperation->createOpView(); +} - PyOperationRef returnOperation = - PyOperation::forOperation(parentOperation->getContext(), next); - next = mlirOperationGetNextInBlock(next); - return returnOperation->createOpView(); - } +void PyOperationIterator::bind(nb::module_ &m) { + nb::class_(m, "OperationIterator") + .def("__iter__", &PyOperationIterator::dunderIter, + "Returns an iterator over the operations in an operation's block.") + .def("__next__", &PyOperationIterator::dunderNext, + "Returns the next operation in the iteration."); +} - static void bind(nb::module_ &m) { - nb::class_(m, "OperationIterator") - .def("__iter__", &PyOperationIterator::dunderIter, - "Returns an iterator over the operations in an operation's block.") - .def("__next__", &PyOperationIterator::dunderNext, - "Returns the next operation in the iteration."); - } +PyOperationIterator PyOperationList::dunderIter() { + parentOperation->checkValid(); + return PyOperationIterator(parentOperation, + mlirBlockGetFirstOperation(block)); +} -private: - PyOperationRef parentOperation; - MlirOperation next; -}; - -/// Operations are exposed by the C-API as a forward-only linked list. In -/// Python, we present them as a more full-featured list-like container but -/// optimize it for forward iteration. Iterable operations are always owned -/// by a block. -class PyOperationList { -public: - PyOperationList(PyOperationRef parentOperation, MlirBlock block) - : parentOperation(std::move(parentOperation)), block(block) {} - - PyOperationIterator dunderIter() { - parentOperation->checkValid(); - return PyOperationIterator(parentOperation, - mlirBlockGetFirstOperation(block)); +intptr_t PyOperationList::dunderLen() { + parentOperation->checkValid(); + intptr_t count = 0; + MlirOperation childOp = mlirBlockGetFirstOperation(block); + while (!mlirOperationIsNull(childOp)) { + count += 1; + childOp = mlirOperationGetNextInBlock(childOp); } + return count; +} - intptr_t dunderLen() { - parentOperation->checkValid(); - intptr_t count = 0; - MlirOperation childOp = mlirBlockGetFirstOperation(block); - while (!mlirOperationIsNull(childOp)) { - count += 1; - childOp = mlirOperationGetNextInBlock(childOp); - } - return count; +nb::typed PyOperationList::dunderGetItem(intptr_t index) { + parentOperation->checkValid(); + if (index < 0) { + index += dunderLen(); } - - nb::typed dunderGetItem(intptr_t index) { - parentOperation->checkValid(); - if (index < 0) { - index += dunderLen(); - } - if (index < 0) { - throw nb::index_error("attempt to access out of bounds operation"); - } - MlirOperation childOp = mlirBlockGetFirstOperation(block); - while (!mlirOperationIsNull(childOp)) { - if (index == 0) { - return PyOperation::forOperation(parentOperation->getContext(), childOp) - ->createOpView(); - } - childOp = mlirOperationGetNextInBlock(childOp); - index -= 1; - } + if (index < 0) { throw nb::index_error("attempt to access out of bounds operation"); } - - static void bind(nb::module_ &m) { - nb::class_(m, "OperationList") - .def("__getitem__", &PyOperationList::dunderGetItem, - "Returns the operation at the specified index.") - .def("__iter__", &PyOperationList::dunderIter, - "Returns an iterator over operations in the list.") - .def("__len__", &PyOperationList::dunderLen, - "Returns the number of operations in the list."); - } - -private: - PyOperationRef parentOperation; - MlirBlock block; -}; - -class PyOpOperand { -public: - PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} - - nb::typed getOwner() { - MlirOperation owner = mlirOpOperandGetOwner(opOperand); - PyMlirContextRef context = - PyMlirContext::forContext(mlirOperationGetContext(owner)); - return PyOperation::forOperation(context, owner)->createOpView(); + MlirOperation childOp = mlirBlockGetFirstOperation(block); + while (!mlirOperationIsNull(childOp)) { + if (index == 0) { + return PyOperation::forOperation(parentOperation->getContext(), childOp) + ->createOpView(); + } + childOp = mlirOperationGetNextInBlock(childOp); + index -= 1; } + throw nb::index_error("attempt to access out of bounds operation"); +} - size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } +void PyOperationList::bind(nb::module_ &m) { + nb::class_(m, "OperationList") + .def("__getitem__", &PyOperationList::dunderGetItem, + "Returns the operation at the specified index.") + .def("__iter__", &PyOperationList::dunderIter, + "Returns an iterator over operations in the list.") + .def("__len__", &PyOperationList::dunderLen, + "Returns the number of operations in the list."); +} - static void bind(nb::module_ &m) { - nb::class_(m, "OpOperand") - .def_prop_ro("owner", &PyOpOperand::getOwner, - "Returns the operation that owns this operand.") - .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber, - "Returns the operand number in the owning operation."); - } +nb::typed PyOpOperand::getOwner() const { + MlirOperation owner = mlirOpOperandGetOwner(opOperand); + PyMlirContextRef context = + PyMlirContext::forContext(mlirOperationGetContext(owner)); + return PyOperation::forOperation(context, owner)->createOpView(); +} -private: - MlirOpOperand opOperand; -}; +size_t PyOpOperand::getOperandNumber() const { + return mlirOpOperandGetOperandNumber(opOperand); +} -class PyOpOperandIterator { -public: - PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} +void PyOpOperand::bind(nb::module_ &m) { + nb::class_(m, "OpOperand") + .def_prop_ro("owner", &PyOpOperand::getOwner, + "Returns the operation that owns this operand.") + .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber, + "Returns the operand number in the owning operation."); +} - PyOpOperandIterator &dunderIter() { return *this; } +PyOpOperand PyOpOperandIterator::dunderNext() { + if (mlirOpOperandIsNull(opOperand)) + throw nb::stop_iteration(); - PyOpOperand dunderNext() { - if (mlirOpOperandIsNull(opOperand)) - throw nb::stop_iteration(); + PyOpOperand returnOpOperand(opOperand); + opOperand = mlirOpOperandGetNextUse(opOperand); + return returnOpOperand; +} - PyOpOperand returnOpOperand(opOperand); - opOperand = mlirOpOperandGetNextUse(opOperand); - return returnOpOperand; - } +void PyOpOperandIterator::bind(nb::module_ &m) { + nb::class_(m, "OpOperandIterator") + .def("__iter__", &PyOpOperandIterator::dunderIter, + "Returns an iterator over operands.") + .def("__next__", &PyOpOperandIterator::dunderNext, + "Returns the next operand in the iteration."); +} - static void bind(nb::module_ &m) { - nb::class_(m, "OpOperandIterator") - .def("__iter__", &PyOpOperandIterator::dunderIter, - "Returns an iterator over operands.") - .def("__next__", &PyOpOperandIterator::dunderNext, - "Returns the next operand in the iteration."); - } +//------------------------------------------------------------------------------ +// PyThreadPool +//------------------------------------------------------------------------------ -private: - MlirOpOperand opOperand; -}; +PyThreadPool::PyThreadPool() { + ownedThreadPool = std::make_unique(); +} -} // namespace +std::string PyThreadPool::_mlir_thread_pool_ptr() const { + std::stringstream ss; + ss << ownedThreadPool.get(); + return ss.str(); +} //------------------------------------------------------------------------------ // PyMlirContext @@ -554,6 +464,10 @@ PyMlirContext::~PyMlirContext() { mlirContextDestroy(context); } +PyMlirContextRef PyMlirContext::getRef() { + return PyMlirContextRef(this, nb::cast(this)); +} + nb::object PyMlirContext::getCapsule() { return nb::steal(mlirPythonContextToCapsule(get())); } @@ -662,7 +576,8 @@ MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, if (self->ctx->emitErrorDiagnostics) return mlirLogicalResultFailure(); - if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) + if (mlirDiagnosticGetSeverity(diag) != + MlirDiagnosticSeverity::MlirDiagnosticError) return mlirLogicalResultFailure(); self->errors.emplace_back(PyDiagnostic(diag).getInfo()); @@ -849,9 +764,10 @@ void PyDiagnostic::checkValid() { } } -MlirDiagnosticSeverity PyDiagnostic::getSeverity() { +PyDiagnosticSeverity PyDiagnostic::getSeverity() { checkValid(); - return mlirDiagnosticGetSeverity(diagnostic); + return static_cast( + mlirDiagnosticGetSeverity(diagnostic)); } PyLocation PyDiagnostic::getLocation() { @@ -1088,6 +1004,31 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, return PyOperation::createDetached(std::move(contextRef), op); } +void PyOperation::detachFromParent() { + mlirOperationRemoveFromParent(getOperation()); + setDetached(); + parentKeepAlive = nb::object(); +} + +MlirOperation PyOperation::get() const { + checkValid(); + return operation; +} + +PyOperationRef PyOperation::getRef() { + return PyOperationRef(this, nb::borrow(handle)); +} + +void PyOperation::setAttached(const nb::object &parent) { + assert(!attached && "operation already attached"); + attached = true; +} + +void PyOperation::setDetached() { + assert(attached && "operation already detached"); + attached = false; +} + void PyOperation::checkValid() const { if (!valid) { throw std::runtime_error("the operation has been invalidated"); @@ -1164,13 +1105,12 @@ void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject, .c_str()); } -void PyOperationBase::walk( - std::function callback, - MlirWalkOrder walkOrder) { +void PyOperationBase::walk(std::function callback, + PyWalkOrder walkOrder) { PyOperation &operation = getOperation(); operation.checkValid(); struct UserData { - std::function callback; + std::function callback; bool gotException; std::string exceptionWhat; nb::object exceptionType; @@ -1180,7 +1120,7 @@ void PyOperationBase::walk( void *userData) { UserData *calleeUserData = static_cast(userData); try { - return (calleeUserData->callback)(op); + return static_cast((calleeUserData->callback)(op)); } catch (nb::python_error &e) { calleeUserData->gotException = true; calleeUserData->exceptionWhat = std::string(e.what()); @@ -1188,7 +1128,8 @@ void PyOperationBase::walk( return MlirWalkResult::MlirWalkResultInterrupt; } }; - mlirOperationWalk(operation, walkCallback, &userData, walkOrder); + mlirOperationWalk(operation, walkCallback, &userData, + static_cast(walkOrder)); if (userData.gotException) { std::string message("Exception raised in callback: "); message.append(userData.exceptionWhat); @@ -1448,93 +1389,22 @@ void PyOperation::erase() { mlirOperationDestroy(operation); } -namespace { -/// CRTP base class for Python MLIR values that subclass Value and should be -/// castable from it. The value hierarchy is one level deep and is not supposed -/// to accommodate other levels unless core MLIR changes. -template -class PyConcreteValue : public PyValue { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. - using ClassTy = nb::class_; - using IsAFunctionTy = bool (*)(MlirValue); - - PyConcreteValue() = default; - PyConcreteValue(PyOperationRef operationRef, MlirValue value) - : PyValue(operationRef, value) {} - PyConcreteValue(PyValue &orig) - : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} - - /// Attempts to cast the original value to the derived type and throws on - /// type mismatches. - static MlirValue castFrom(PyValue &orig) { - if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = nb::cast(nb::repr(nb::cast(orig))); - throw nb::value_error((Twine("Cannot cast value to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str() - .c_str()); - } - return orig.get(); - } - - /// Binds the Python module objects to functions of this class. - static void bind(nb::module_ &m) { - auto cls = ClassTy( - m, DerivedTy::pyClassName, nb::is_generic(), - nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])") - .str() - .c_str())); - cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); - cls.def_static( - "isinstance", - [](PyValue &otherValue) -> bool { - return DerivedTy::isaFunction(otherValue); - }, - nb::arg("other_value")); - cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](DerivedTy &self) -> nb::typed { - return self.maybeDownCast(); - }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - -} // namespace - -/// Python wrapper for MlirOpResult. -class PyOpResult : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; - static constexpr const char *pyClassName = "OpResult"; - using PyConcreteValue::PyConcreteValue; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "owner", - [](PyOpResult &self) -> nb::typed { - assert(mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in " - "the IR"); - return self.getParentOperation()->createOpView(); - }, - "Returns the operation that produces this result."); - c.def_prop_ro( - "result_number", - [](PyOpResult &self) { - return mlirOpResultGetResultNumber(self.get()); - }, - "Returns the position of this result in the operation's result list."); - } -}; +void PyOpResult::bindDerived(ClassTy &c) { + c.def_prop_ro( + "owner", + [](PyOpResult &self) -> nb::typed { + assert(mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in " + "the IR"); + return self.getParentOperation()->createOpView(); + }, + "Returns the operation that produces this result."); + c.def_prop_ro( + "result_number", + [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); }, + "Returns the position of this result in the operation's result list."); +} /// Returns the list of types of the values held by container. template @@ -1550,60 +1420,43 @@ getValueTypes(Container &container, PyMlirContextRef &context) { return result; } -/// A list of operation results. Internally, these are stored as consecutive -/// elements, random access is cheap. The (returned) result list is associated -/// with the operation whose results these are, and thus extends the lifetime of -/// this operation. -class PyOpResultList : public Sliceable { -public: - static constexpr const char *pyClassName = "OpResultList"; - using SliceableT = Sliceable; - - PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirOperationGetNumResults(operation->get()) - : length, - step), - operation(std::move(operation)) {} - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "types", - [](PyOpResultList &self) { - return getValueTypes(self, self.operation->getContext()); - }, - "Returns a list of types for all results in this result list."); - c.def_prop_ro( - "owner", - [](PyOpResultList &self) -> nb::typed { - return self.operation->createOpView(); - }, - "Returns the operation that owns this result list."); - } - - PyOperationRef &getOperation() { return operation; } - -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; - - intptr_t getRawNumElements() { - operation->checkValid(); - return mlirOperationGetNumResults(operation->get()); - } +PyOpResultList::PyOpResultList(PyOperationRef operation, intptr_t startIndex, + intptr_t length, intptr_t step) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumResults(operation->get()) + : length, + step), + operation(std::move(operation)) {} + +void PyOpResultList::bindDerived(ClassTy &c) { + c.def_prop_ro( + "types", + [](PyOpResultList &self) { + return getValueTypes(self, self.operation->getContext()); + }, + "Returns a list of types for all results in this result list."); + c.def_prop_ro( + "owner", + [](PyOpResultList &self) -> nb::typed { + return self.operation->createOpView(); + }, + "Returns the operation that owns this result list."); +} - PyOpResult getRawElement(intptr_t index) { - PyValue value(operation, mlirOperationGetResult(operation->get(), index)); - return PyOpResult(value); - } +intptr_t PyOpResultList::getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumResults(operation->get()); +} - PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyOpResultList(operation, startIndex, length, step); - } +PyOpResult PyOpResultList::getRawElement(intptr_t index) { + PyValue value(operation, mlirOperationGetResult(operation->get(), index)); + return PyOpResult(value); +} - PyOperationRef operation; -}; +PyOpResultList PyOpResultList::slice(intptr_t startIndex, intptr_t length, + intptr_t step) const { + return PyOpResultList(operation, startIndex, length, step); +} //------------------------------------------------------------------------------ // PyOpView @@ -1706,7 +1559,7 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, } } -static MlirValue getUniqueResult(MlirOperation operation) { +MlirValue getUniqueResult(MlirOperation operation) { auto numResults = mlirOperationGetNumResults(operation); if (numResults != 1) { auto name = mlirIdentifierStr(mlirOperationGetName(operation)); @@ -1938,6 +1791,28 @@ PyOpView::PyOpView(const nb::object &operationObject) : operation(nb::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} +//------------------------------------------------------------------------------ +// PyAsmState +//------------------------------------------------------------------------------ + +PyAsmState::PyAsmState(MlirValue value, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = mlirAsmStateCreateForValue(value, flags); +} + +PyAsmState::PyAsmState(PyOperationBase &operation, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); +} + //------------------------------------------------------------------------------ // PyInsertionPoint. //------------------------------------------------------------------------------ @@ -2319,420 +2194,318 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, } } -namespace { - -/// Python wrapper for MlirBlockArgument. -class PyBlockArgument : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; - static constexpr const char *pyClassName = "BlockArgument"; - using PyConcreteValue::PyConcreteValue; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "owner", - [](PyBlockArgument &self) { - return PyBlock(self.getParentOperation(), - mlirBlockArgumentGetOwner(self.get())); - }, - "Returns the block that owns this argument."); - c.def_prop_ro( - "arg_number", - [](PyBlockArgument &self) { - return mlirBlockArgumentGetArgNumber(self.get()); - }, - "Returns the position of this argument in the block's argument list."); - c.def( - "set_type", - [](PyBlockArgument &self, PyType type) { - return mlirBlockArgumentSetType(self.get(), type); - }, - nb::arg("type"), "Sets the type of this block argument."); - c.def( - "set_location", - [](PyBlockArgument &self, PyLocation loc) { - return mlirBlockArgumentSetLocation(self.get(), loc); - }, - nb::arg("loc"), "Sets the location of this block argument."); - } -}; - -/// A list of block arguments. Internally, these are stored as consecutive -/// elements, random access is cheap. The argument list is associated with the -/// operation that contains the block (detached blocks are not allowed in -/// Python bindings) and extends its lifetime. -class PyBlockArgumentList - : public Sliceable { -public: - static constexpr const char *pyClassName = "BlockArgumentList"; - using SliceableT = Sliceable; - - PyBlockArgumentList(PyOperationRef operation, MlirBlock block, - intptr_t startIndex = 0, intptr_t length = -1, - intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirBlockGetNumArguments(block) : length, - step), - operation(std::move(operation)), block(block) {} - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "types", - [](PyBlockArgumentList &self) { - return getValueTypes(self, self.operation->getContext()); - }, - "Returns a list of types for all arguments in this argument list."); - } - -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; +void PyBlockArgument::bindDerived(ClassTy &c) { + c.def_prop_ro( + "owner", + [](PyBlockArgument &self) { + return PyBlock(self.getParentOperation(), + mlirBlockArgumentGetOwner(self.get())); + }, + "Returns the block that owns this argument."); + c.def_prop_ro( + "arg_number", + [](PyBlockArgument &self) { + return mlirBlockArgumentGetArgNumber(self.get()); + }, + "Returns the position of this argument in the block's argument list."); + c.def( + "set_type", + [](PyBlockArgument &self, PyType type) { + return mlirBlockArgumentSetType(self.get(), type); + }, + "type"_a, "Sets the type of this block argument."); + c.def( + "set_location", + [](PyBlockArgument &self, PyLocation loc) { + return mlirBlockArgumentSetLocation(self.get(), loc); + }, + "loc"_a, "Sets the location of this block argument."); +} - /// Returns the number of arguments in the list. - intptr_t getRawNumElements() { - operation->checkValid(); - return mlirBlockGetNumArguments(block); - } +PyBlockArgumentList::PyBlockArgumentList(PyOperationRef operation, + MlirBlock block, intptr_t startIndex, + intptr_t length, intptr_t step) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumArguments(block) : length, step), + operation(std::move(operation)), block(block) {} + +void PyBlockArgumentList::bindDerived(ClassTy &c) { + c.def_prop_ro( + "types", + [](PyBlockArgumentList &self) { + return getValueTypes(self, self.operation->getContext()); + }, + "Returns a list of types for all arguments in this argument list."); +} - /// Returns `pos`-the element in the list. - PyBlockArgument getRawElement(intptr_t pos) { - MlirValue argument = mlirBlockGetArgument(block, pos); - return PyBlockArgument(operation, argument); - } +intptr_t PyBlockArgumentList::getRawNumElements() { + operation->checkValid(); + return mlirBlockGetNumArguments(block); +} - /// Returns a sublist of this list. - PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyBlockArgumentList(operation, block, startIndex, length, step); - } +PyBlockArgument PyBlockArgumentList::getRawElement(intptr_t pos) const { + MlirValue argument = mlirBlockGetArgument(block, pos); + return PyBlockArgument(operation, argument); +} - PyOperationRef operation; - MlirBlock block; -}; - -/// A list of operation operands. Internally, these are stored as consecutive -/// elements, random access is cheap. The (returned) operand list is associated -/// with the operation whose operands these are, and thus extends the lifetime -/// of this operation. -class PyOpOperandList : public Sliceable { -public: - static constexpr const char *pyClassName = "OpOperandList"; - using SliceableT = Sliceable; - - PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirOperationGetNumOperands(operation->get()) - : length, - step), - operation(operation) {} - - void dunderSetItem(intptr_t index, PyValue value) { - index = wrapIndex(index); - mlirOperationSetOperand(operation->get(), index, value.get()); - } +PyBlockArgumentList PyBlockArgumentList::slice(intptr_t startIndex, + intptr_t length, + intptr_t step) const { + return PyBlockArgumentList(operation, block, startIndex, length, step); +} - static void bindDerived(ClassTy &c) { - c.def("__setitem__", &PyOpOperandList::dunderSetItem, nb::arg("index"), - nb::arg("value"), - "Sets the operand at the specified index to a new value."); - } +PyOpOperandList::PyOpOperandList(PyOperationRef operation, intptr_t startIndex, + intptr_t length, intptr_t step) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumOperands(operation->get()) + : length, + step), + operation(operation) {} + +void PyOpOperandList::dunderSetItem(intptr_t index, PyValue value) { + index = wrapIndex(index); + mlirOperationSetOperand(operation->get(), index, value.get()); +} -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; +void PyOpOperandList::bindDerived(ClassTy &c) { + c.def("__setitem__", &PyOpOperandList::dunderSetItem, "index"_a, "value"_a, + "Sets the operand at the specified index to a new value."); +} - intptr_t getRawNumElements() { - operation->checkValid(); - return mlirOperationGetNumOperands(operation->get()); - } +intptr_t PyOpOperandList::getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumOperands(operation->get()); +} - PyValue getRawElement(intptr_t pos) { - MlirValue operand = mlirOperationGetOperand(operation->get(), pos); - MlirOperation owner; - if (mlirValueIsAOpResult(operand)) - owner = mlirOpResultGetOwner(operand); - else if (mlirValueIsABlockArgument(operand)) - owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); - else - assert(false && "Value must be an block arg or op result."); - PyOperationRef pyOwner = - PyOperation::forOperation(operation->getContext(), owner); - return PyValue(pyOwner, operand); - } +PyValue PyOpOperandList::getRawElement(intptr_t pos) { + MlirValue operand = mlirOperationGetOperand(operation->get(), pos); + MlirOperation owner; + if (mlirValueIsAOpResult(operand)) + owner = mlirOpResultGetOwner(operand); + else if (mlirValueIsABlockArgument(operand)) + owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); + else + assert(false && "Value must be an block arg or op result."); + PyOperationRef pyOwner = + PyOperation::forOperation(operation->getContext(), owner); + return PyValue(pyOwner, operand); +} - PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyOpOperandList(operation, startIndex, length, step); - } +PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length, + intptr_t step) const { + return PyOpOperandList(operation, startIndex, length, step); +} - PyOperationRef operation; -}; - -/// A list of operation successors. Internally, these are stored as consecutive -/// elements, random access is cheap. The (returned) successor list is -/// associated with the operation whose successors these are, and thus extends -/// the lifetime of this operation. -class PyOpSuccessors : public Sliceable { -public: - static constexpr const char *pyClassName = "OpSuccessors"; - - PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirOperationGetNumSuccessors(operation->get()) - : length, - step), - operation(operation) {} - - void dunderSetItem(intptr_t index, PyBlock block) { - index = wrapIndex(index); - mlirOperationSetSuccessor(operation->get(), index, block.get()); - } +PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex, + intptr_t length, intptr_t step) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumSuccessors(operation->get()) + : length, + step), + operation(operation) {} + +void PyOpSuccessors::dunderSetItem(intptr_t index, PyBlock block) { + index = wrapIndex(index); + mlirOperationSetSuccessor(operation->get(), index, block.get()); +} - static void bindDerived(ClassTy &c) { - c.def("__setitem__", &PyOpSuccessors::dunderSetItem, nb::arg("index"), - nb::arg("block"), "Sets the successor block at the specified index."); - } +void PyOpSuccessors::bindDerived(ClassTy &c) { + c.def("__setitem__", &PyOpSuccessors::dunderSetItem, "index"_a, "block"_a, + "Sets the successor block at the specified index."); +} -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; +intptr_t PyOpSuccessors::getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumSuccessors(operation->get()); +} - intptr_t getRawNumElements() { - operation->checkValid(); - return mlirOperationGetNumSuccessors(operation->get()); - } +PyBlock PyOpSuccessors::getRawElement(intptr_t pos) { + MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); + return PyBlock(operation, block); +} - PyBlock getRawElement(intptr_t pos) { - MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); - return PyBlock(operation, block); - } +PyOpSuccessors PyOpSuccessors::slice(intptr_t startIndex, intptr_t length, + intptr_t step) const { + return PyOpSuccessors(operation, startIndex, length, step); +} - PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyOpSuccessors(operation, startIndex, length, step); - } +PyBlockSuccessors::PyBlockSuccessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex, intptr_t length, + intptr_t step) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumSuccessors(block.get()) : length, + step), + operation(operation), block(block) {} + +intptr_t PyBlockSuccessors::getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumSuccessors(block.get()); +} - PyOperationRef operation; -}; - -/// A list of block successors. Internally, these are stored as consecutive -/// elements, random access is cheap. The (returned) successor list is -/// associated with the operation and block whose successors these are, and thus -/// extends the lifetime of this operation and block. -class PyBlockSuccessors : public Sliceable { -public: - static constexpr const char *pyClassName = "BlockSuccessors"; - - PyBlockSuccessors(PyBlock block, PyOperationRef operation, - intptr_t startIndex = 0, intptr_t length = -1, - intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirBlockGetNumSuccessors(block.get()) - : length, - step), - operation(operation), block(block) {} - -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; - - intptr_t getRawNumElements() { - block.checkValid(); - return mlirBlockGetNumSuccessors(block.get()); - } +PyBlock PyBlockSuccessors::getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos); + return PyBlock(operation, block); +} - PyBlock getRawElement(intptr_t pos) { - MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos); - return PyBlock(operation, block); - } +PyBlockSuccessors PyBlockSuccessors::slice(intptr_t startIndex, intptr_t length, + intptr_t step) const { + return PyBlockSuccessors(block, operation, startIndex, length, step); +} - PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyBlockSuccessors(block, operation, startIndex, length, step); - } +PyBlockPredecessors::PyBlockPredecessors(PyBlock block, + PyOperationRef operation, + intptr_t startIndex, intptr_t length, + intptr_t step) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumPredecessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +intptr_t PyBlockPredecessors::getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumPredecessors(block.get()); +} - PyOperationRef operation; - PyBlock block; -}; - -/// A list of block predecessors. The (returned) predecessor list is -/// associated with the operation and block whose predecessors these are, and -/// thus extends the lifetime of this operation and block. -/// -/// WARNING: This Sliceable is more expensive than the others here because -/// mlirBlockGetPredecessor actually iterates the use-def chain (of block -/// operands) anew for each indexed access. -class PyBlockPredecessors : public Sliceable { -public: - static constexpr const char *pyClassName = "BlockPredecessors"; - - PyBlockPredecessors(PyBlock block, PyOperationRef operation, - intptr_t startIndex = 0, intptr_t length = -1, - intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirBlockGetNumPredecessors(block.get()) - : length, - step), - operation(operation), block(block) {} - -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; - - intptr_t getRawNumElements() { - block.checkValid(); - return mlirBlockGetNumPredecessors(block.get()); - } +PyBlock PyBlockPredecessors::getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos); + return PyBlock(operation, block); +} - PyBlock getRawElement(intptr_t pos) { - MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos); - return PyBlock(operation, block); - } +PyBlockPredecessors PyBlockPredecessors::slice(intptr_t startIndex, + intptr_t length, + intptr_t step) const { + return PyBlockPredecessors(block, operation, startIndex, length, step); +} - PyBlockPredecessors slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyBlockPredecessors(block, operation, startIndex, length, step); +nb::typed +PyOpAttributeMap::dunderGetItemNamed(const std::string &name) { + MlirAttribute attr = + mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) { + throw nb::key_error("attempt to access a non-existent attribute"); } + return PyAttribute(operation->getContext(), attr).maybeDownCast(); +} - PyOperationRef operation; - PyBlock block; -}; - -/// A list of operation attributes. Can be indexed by name, producing -/// attributes, or by index, producing named attributes. -class PyOpAttributeMap { -public: - PyOpAttributeMap(PyOperationRef operation) - : operation(std::move(operation)) {} - - nb::typed - dunderGetItemNamed(const std::string &name) { - MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), - toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) { - throw nb::key_error("attempt to access a non-existent attribute"); - } - return PyAttribute(operation->getContext(), attr).maybeDownCast(); +PyNamedAttribute PyOpAttributeMap::dunderGetItemIndexed(intptr_t index) { + if (index < 0) { + index += dunderLen(); } - - PyNamedAttribute dunderGetItemIndexed(intptr_t index) { - if (index < 0) { - index += dunderLen(); - } - if (index < 0 || index >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds attribute"); - } - MlirNamedAttribute namedAttr = - mlirOperationGetAttribute(operation->get(), index); - return PyNamedAttribute( - namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data, - mlirIdentifierStr(namedAttr.name).length)); + if (index < 0 || index >= dunderLen()) { + throw nb::index_error("attempt to access out of bounds attribute"); } + MlirNamedAttribute namedAttr = + mlirOperationGetAttribute(operation->get(), index); + return PyNamedAttribute( + namedAttr.attribute, + std::string(mlirIdentifierStr(namedAttr.name).data, + mlirIdentifierStr(namedAttr.name).length)); +} - void dunderSetItem(const std::string &name, const PyAttribute &attr) { - mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), - attr); - } +void PyOpAttributeMap::dunderSetItem(const std::string &name, + const PyAttribute &attr) { + mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), + attr); +} - void dunderDelItem(const std::string &name) { - int removed = mlirOperationRemoveAttributeByName(operation->get(), - toMlirStringRef(name)); - if (!removed) - throw nb::key_error("attempt to delete a non-existent attribute"); - } +void PyOpAttributeMap::dunderDelItem(const std::string &name) { + int removed = mlirOperationRemoveAttributeByName(operation->get(), + toMlirStringRef(name)); + if (!removed) + throw nb::key_error("attempt to delete a non-existent attribute"); +} - intptr_t dunderLen() { - return mlirOperationGetNumAttributes(operation->get()); - } +intptr_t PyOpAttributeMap::dunderLen() { + return mlirOperationGetNumAttributes(operation->get()); +} - bool dunderContains(const std::string &name) { - return !mlirAttributeIsNull(mlirOperationGetAttributeByName( - operation->get(), toMlirStringRef(name))); - } +bool PyOpAttributeMap::dunderContains(const std::string &name) { + return !mlirAttributeIsNull( + mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name))); +} - static void - forEachAttr(MlirOperation op, - llvm::function_ref fn) { - intptr_t n = mlirOperationGetNumAttributes(op); - for (intptr_t i = 0; i < n; ++i) { - MlirNamedAttribute na = mlirOperationGetAttribute(op, i); - MlirStringRef name = mlirIdentifierStr(na.name); - fn(name, na.attribute); - } +void PyOpAttributeMap::forEachAttr( + MlirOperation op, + llvm::function_ref fn) { + intptr_t n = mlirOperationGetNumAttributes(op); + for (intptr_t i = 0; i < n; ++i) { + MlirNamedAttribute na = mlirOperationGetAttribute(op, i); + MlirStringRef name = mlirIdentifierStr(na.name); + fn(name, na.attribute); } +} - static void bind(nb::module_ &m) { - nb::class_(m, "OpAttributeMap") - .def("__contains__", &PyOpAttributeMap::dunderContains, nb::arg("name"), - "Checks if an attribute with the given name exists in the map.") - .def("__len__", &PyOpAttributeMap::dunderLen, - "Returns the number of attributes in the map.") - .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, - nb::arg("name"), "Gets an attribute by name.") - .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, - nb::arg("index"), "Gets a named attribute by index.") - .def("__setitem__", &PyOpAttributeMap::dunderSetItem, nb::arg("name"), - nb::arg("attr"), "Sets an attribute with the given name.") - .def("__delitem__", &PyOpAttributeMap::dunderDelItem, nb::arg("name"), - "Deletes an attribute with the given name.") - .def( - "__iter__", - [](PyOpAttributeMap &self) { - nb::list keys; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute) { - keys.append(nb::str(name.data, name.length)); - }); - return nb::iter(keys); - }, - "Iterates over attribute names.") - .def( - "keys", - [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute) { - out.append(nb::str(name.data, name.length)); - }); - return out; - }, - "Returns a list of attribute names.") - .def( - "values", - [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef, MlirAttribute attr) { - out.append(PyAttribute(self.operation->getContext(), attr) - .maybeDownCast()); - }); - return out; - }, - "Returns a list of attribute values.") - .def( - "items", - [](PyOpAttributeMap &self) { - nb::list out; - PyOpAttributeMap::forEachAttr( - self.operation->get(), - [&](MlirStringRef name, MlirAttribute attr) { - out.append(nb::make_tuple( - nb::str(name.data, name.length), - PyAttribute(self.operation->getContext(), attr) - .maybeDownCast())); - }); - return out; - }, - "Returns a list of `(name, attribute)` tuples."); - } +void PyOpAttributeMap::bind(nb::module_ &m) { + nb::class_(m, "OpAttributeMap") + .def("__contains__", &PyOpAttributeMap::dunderContains, "name"_a, + "Checks if an attribute with the given name exists in the map.") + .def("__len__", &PyOpAttributeMap::dunderLen, + "Returns the number of attributes in the map.") + .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed, "name"_a, + "Gets an attribute by name.") + .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed, "index"_a, + "Gets a named attribute by index.") + .def("__setitem__", &PyOpAttributeMap::dunderSetItem, "name"_a, "attr"_a, + "Sets an attribute with the given name.") + .def("__delitem__", &PyOpAttributeMap::dunderDelItem, "name"_a, + "Deletes an attribute with the given name.") + .def( + "__iter__", + [](PyOpAttributeMap &self) { + nb::list keys; + PyOpAttributeMap::forEachAttr( + self.operation->get(), [&](MlirStringRef name, MlirAttribute) { + keys.append(nb::str(name.data, name.length)); + }); + return nb::iter(keys); + }, + "Iterates over attribute names.") + .def( + "keys", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), [&](MlirStringRef name, MlirAttribute) { + out.append(nb::str(name.data, name.length)); + }); + return out; + }, + "Returns a list of attribute names.") + .def( + "values", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), [&](MlirStringRef, MlirAttribute attr) { + out.append(PyAttribute(self.operation->getContext(), attr) + .maybeDownCast()); + }); + return out; + }, + "Returns a list of attribute values.") + .def( + "items", + [](PyOpAttributeMap &self) { + nb::list out; + PyOpAttributeMap::forEachAttr( + self.operation->get(), + [&](MlirStringRef name, MlirAttribute attr) { + out.append(nb::make_tuple( + nb::str(name.data, name.length), + PyAttribute(self.operation->getContext(), attr) + .maybeDownCast())); + }); + return out; + }, + "Returns a list of `(name, attribute)` tuples."); +} -private: - PyOperationRef operation; -}; +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir +namespace { // see // https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h @@ -2799,6 +2572,8 @@ PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) { #endif // Python 3.9.0b1 +using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; + MlirLocation tracebackToLocation(MlirContext ctx) { size_t framesLimit = PyGlobals::get().getTracebackLoc().locTracebackFramesLimit(); @@ -2882,30 +2657,151 @@ maybeGetTracebackLocation(const std::optional &location) { PyMlirContextRef ref = PyMlirContext::forContext(ctx.get()); return {ref, mlirLoc}; } - } // namespace +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { + +void populateRoot(nb::module_ &m) { + m.attr("T") = nb::type_var("T"); + m.attr("U") = nb::type_var("U"); + + nb::class_(m, "_Globals") + .def_prop_rw("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) + .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, + "module_name"_a) + .def( + "_check_dialect_module_loaded", + [](PyGlobals &self, const std::string &dialectNamespace) { + return self.loadDialectModule(dialectNamespace); + }, + "dialect_namespace"_a) + .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, + "dialect_namespace"_a, "dialect_class"_a, + "Testing hook for directly registering a dialect") + .def("_register_operation_impl", &PyGlobals::registerOperationImpl, + "operation_name"_a, "operation_class"_a, nb::kw_only(), + "replace"_a = false, + "Testing hook for directly registering an operation") + .def("loc_tracebacks_enabled", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebacksEnabled(); + }) + .def("set_loc_tracebacks_enabled", + [](PyGlobals &self, bool enabled) { + self.getTracebackLoc().setLocTracebacksEnabled(enabled); + }) + .def("loc_tracebacks_frame_limit", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebackFramesLimit(); + }) + .def("set_loc_tracebacks_frame_limit", + [](PyGlobals &self, std::optional n) { + self.getTracebackLoc().setLocTracebackFramesLimit( + n.value_or(PyGlobals::TracebackLoc::kMaxFrames)); + }) + .def("register_traceback_file_inclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileInclusion(filename); + }) + .def("register_traceback_file_exclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileExclusion(filename); + }); + + // Aside from making the globals accessible to python, having python manage + // it is necessary to make sure it is destroyed (and releases its python + // resources) properly. + m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); + + // Registration decorators. + m.def( + "register_dialect", + [](nb::type_object pyClass) { + std::string dialectNamespace = + nb::cast(pyClass.attr("DIALECT_NAMESPACE")); + PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); + return pyClass; + }, + "dialect_class"_a, + "Class decorator for registering a custom Dialect wrapper"); + m.def( + "register_operation", + [](const nb::type_object &dialectClass, bool replace) -> nb::object { + return nb::cpp_function( + [dialectClass, + replace](nb::type_object opClass) -> nb::type_object { + std::string operationName = + nb::cast(opClass.attr("OPERATION_NAME")); + PyGlobals::get().registerOperationImpl(operationName, opClass, + replace); + // Dict-stuff the new opClass by name onto the dialect class. + nb::object opClassName = opClass.attr("__name__"); + dialectClass.attr(opClassName) = opClass; + return opClass; + }); + }, + // clang-format off + nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) " + "-> typing.Callable[[type[T]], type[T]]"), + // clang-format on + "dialect_class"_a, nb::kw_only(), "replace"_a = false, + "Produce a class decorator for registering an Operation class as part of " + "a dialect"); + m.def( + MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function([mlirTypeID, replace]( + nb::callable typeCaster) -> nb::object { + PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); + return typeCaster; + }); + }, + // clang-format off + nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on + "typeid"_a, nb::kw_only(), "replace"_a = false, + "Register a type caster for casting MLIR types to custom user types."); + m.def( + MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function( + [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { + PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, + replace); + return valueCaster; + }); + }, + // clang-format off + nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " + "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), + // clang-format on + "typeid"_a, nb::kw_only(), "replace"_a = false, + "Register a value caster for casting MLIR values to custom user values."); +} + //------------------------------------------------------------------------------ // Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ - -void mlir::python::populateIRCore(nb::module_ &m) { - // disable leak warnings which tend to be false positives. - nb::set_leak_warnings(false); +void populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Enums. //---------------------------------------------------------------------------- - nb::enum_(m, "DiagnosticSeverity") + nb::enum_(m, "DiagnosticSeverity") .value("ERROR", MlirDiagnosticError) .value("WARNING", MlirDiagnosticWarning) .value("NOTE", MlirDiagnosticNote) .value("REMARK", MlirDiagnosticRemark); - nb::enum_(m, "WalkOrder") + nb::enum_(m, "WalkOrder") .value("PRE_ORDER", MlirWalkPreOrder) .value("POST_ORDER", MlirWalkPostOrder); - nb::enum_(m, "WalkResult") + nb::enum_(m, "WalkResult") .value("ADVANCE", MlirWalkResultAdvance) .value("INTERRUPT", MlirWalkResultInterrupt) .value("SKIP", MlirWalkResultSkip); @@ -2961,9 +2857,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { "handling.") .def("__enter__", &PyDiagnosticHandler::contextEnter, "Enters the diagnostic handler as a context manager.") - .def("__exit__", &PyDiagnosticHandler::contextExit, - nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none(), + .def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(), + "exc_value"_a.none(), "traceback"_a.none(), "Exits the diagnostic handler context manager."); // Expose DefaultThreadPool to python @@ -3008,8 +2903,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Creates a Context from a capsule wrapping MlirContext.") .def("__enter__", &PyMlirContext::contextEnter, "Enters the context as a context manager.") - .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none(), + .def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(), + "exc_value"_a.none(), "traceback"_a.none(), "Exits the context manager.") .def_prop_ro_static( "current", @@ -3041,7 +2936,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { } return PyDialectDescriptor(self.getRef(), dialect); }, - nb::arg("dialect_name"), + "dialect_name"_a, "Gets or loads a dialect by name, returning its descriptor object.") .def_prop_rw( "allow_unregistered_dialects", @@ -3053,14 +2948,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Controls whether unregistered dialects are allowed in this context.") .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, - nb::arg("callback"), + "callback"_a, "Attaches a diagnostic handler that will receive callbacks.") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { mlirContextEnableMultithreading(self.get(), enable); }, - nb::arg("enable"), + "enable"_a, R"( Enables or disables multi-threading support in the context. @@ -3105,7 +3000,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - nb::arg("operation_name"), + "operation_name"_a, R"( Checks whether an operation with the given name is registered. @@ -3119,7 +3014,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyMlirContext &self, PyDialectRegistry ®istry) { mlirContextAppendDialectRegistry(self.get(), registry); }, - nb::arg("registry"), + "registry"_a, R"( Appends the contents of a dialect registry to the context. @@ -3195,7 +3090,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Mapping of PyDialect //---------------------------------------------------------------------------- nb::class_(m, "Dialect") - .def(nb::init(), nb::arg("descriptor"), + .def(nb::init(), "descriptor"_a, "Creates a Dialect from a DialectDescriptor.") .def_prop_ro( "descriptor", [](PyDialect &self) { return self.getDescriptor(); }, @@ -3234,8 +3129,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Creates a Location from a capsule wrapping MlirLocation.") .def("__enter__", &PyLocation::contextEnter, "Enters the location as a context manager.") - .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none(), + .def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(), + "exc_value"_a.none(), "traceback"_a.none(), "Exits the location context manager.") .def( "__eq__", @@ -3264,7 +3159,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), mlirLocationUnknownGet(context->get())); }, - nb::arg("context") = nb::none(), + "context"_a = nb::none(), "Gets a Location representing an unknown location.") .def_static( "callsite", @@ -3279,7 +3174,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), mlirLocationCallSiteGet(callee.get(), caller)); }, - nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(), + "callee"_a, "frames"_a, "context"_a = nb::none(), "Gets a Location representing a caller and callsite.") .def("is_a_callsite", mlirLocationIsACallSite, "Returns True if this location is a CallSiteLoc.") @@ -3306,8 +3201,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationFileLineColGet( context->get(), toMlirStringRef(filename), line, col)); }, - nb::arg("filename"), nb::arg("line"), nb::arg("col"), - nb::arg("context") = nb::none(), + "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(), "Gets a Location representing a file, line and column.") .def_static( "file", @@ -3318,9 +3212,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { context->get(), toMlirStringRef(filename), startLine, startCol, endLine, endCol)); }, - nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), - nb::arg("end_line"), nb::arg("end_col"), - nb::arg("context") = nb::none(), + "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a, + "end_col"_a, "context"_a = nb::none(), "Gets a Location representing a file, line and column range.") .def("is_a_file", mlirLocationIsAFileLineColRange, "Returns True if this location is a FileLineColLoc.") @@ -3353,8 +3246,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { metadata ? metadata->get() : MlirAttribute{0}); return PyLocation(context->getRef(), location); }, - nb::arg("locations"), nb::arg("metadata") = nb::none(), - nb::arg("context") = nb::none(), + "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(), "Gets a Location representing a fused location with optional " "metadata.") .def("is_a_fused", mlirLocationIsAFused, @@ -3384,8 +3276,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { childLoc ? childLoc->get() : mlirLocationUnknownGet(context->get()))); }, - nb::arg("name"), nb::arg("childLoc") = nb::none(), - nb::arg("context") = nb::none(), + "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(), "Gets a Location representing a named location with optional child " "location.") .def("is_a_name", mlirLocationIsAName, @@ -3409,7 +3300,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), mlirLocationFromAttribute(attribute)); }, - nb::arg("attribute"), nb::arg("context") = nb::none(), + "attribute"_a, "context"_a = nb::none(), "Gets a Location from a `LocationAttr`.") .def_prop_ro( "context", @@ -3429,7 +3320,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyLocation &self, std::string message) { mlirEmitError(self, message.c_str()); }, - nb::arg("message"), + "message"_a, R"( Emits an error diagnostic at this location. @@ -3474,8 +3365,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("asm"), nb::arg("context") = nb::none(), - kModuleParseDocstring) + "asm"_a, "context"_a = nb::none(), kModuleParseDocstring) .def_static( "parse", [](nb::bytes moduleAsm, DefaultingPyMlirContext context) @@ -3487,8 +3377,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("asm"), nb::arg("context") = nb::none(), - kModuleParseDocstring) + "asm"_a, "context"_a = nb::none(), kModuleParseDocstring) .def_static( "parseFile", [](const std::string &path, DefaultingPyMlirContext context) @@ -3500,8 +3389,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("path"), nb::arg("context") = nb::none(), - kModuleParseDocstring) + "path"_a, "context"_a = nb::none(), kModuleParseDocstring) .def_static( "create", [](const std::optional &loc) @@ -3510,7 +3398,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("loc") = nb::none(), "Creates an empty module.") + "loc"_a = nb::none(), "Creates an empty module.") .def_prop_ro( "context", [](PyModule &self) -> nb::typed { @@ -3689,8 +3577,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def("print", nb::overload_cast( &PyOperationBase::print), - nb::arg("state"), nb::arg("file") = nb::none(), - nb::arg("binary") = false, + "state"_a, "file"_a = nb::none(), "binary"_a = false, R"( Prints the assembly form of the operation to a file like object. @@ -3703,15 +3590,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { bool, bool, bool, bool, bool, bool, nb::object, bool, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. - nb::arg("large_elements_limit") = nb::none(), - nb::arg("large_resource_limit") = nb::none(), - nb::arg("enable_debug_info") = false, - nb::arg("pretty_debug_info") = false, - nb::arg("print_generic_op_form") = false, - nb::arg("use_local_scope") = false, - nb::arg("use_name_loc_as_prefix") = false, - nb::arg("assume_verified") = false, nb::arg("file") = nb::none(), - nb::arg("binary") = false, nb::arg("skip_regions") = false, + "large_elements_limit"_a = nb::none(), + "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false, + "pretty_debug_info"_a = false, "print_generic_op_form"_a = false, + "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false, + "assume_verified"_a = false, "file"_a = nb::none(), + "binary"_a = false, "skip_regions"_a = false, R"( Prints the assembly form of the operation to a file like object. @@ -3743,8 +3627,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { file: The file like object to write to. Defaults to sys.stdout. binary: Whether to write bytes (True) or str (False). Defaults to False. skip_regions: Whether to skip printing regions. Defaults to False.)") - .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), - nb::arg("desired_version") = nb::none(), + .def("write_bytecode", &PyOperationBase::writeBytecode, "file"_a, + "desired_version"_a = nb::none(), R"( Write the bytecode form of the operation to a file like object. @@ -3755,15 +3639,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { The bytecode writer status.)") .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. - nb::arg("binary") = false, - nb::arg("large_elements_limit") = nb::none(), - nb::arg("large_resource_limit") = nb::none(), - nb::arg("enable_debug_info") = false, - nb::arg("pretty_debug_info") = false, - nb::arg("print_generic_op_form") = false, - nb::arg("use_local_scope") = false, - nb::arg("use_name_loc_as_prefix") = false, - nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, + "binary"_a = false, "large_elements_limit"_a = nb::none(), + "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false, + "pretty_debug_info"_a = false, "print_generic_op_form"_a = false, + "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false, + "assume_verified"_a = false, "skip_regions"_a = false, R"( Gets the assembly form of the operation with all options available. @@ -3778,14 +3658,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def("verify", &PyOperationBase::verify, "Verify the operation. Raises MLIRError if verification fails, and " "returns true otherwise.") - .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), + .def("move_after", &PyOperationBase::moveAfter, "other"_a, "Puts self immediately after the other operation in its parent " "block.") - .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), + .def("move_before", &PyOperationBase::moveBefore, "other"_a, "Puts self immediately before the other operation in its parent " "block.") - .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, - nb::arg("other"), + .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, "other"_a, R"( Checks if this operation is before another in the same block. @@ -3800,7 +3679,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { const nb::object &ip) -> nb::typed { return self.getOperation().clone(ip); }, - nb::arg("ip") = nb::none(), + "ip"_a = nb::none(), R"( Creates a deep copy of the operation. @@ -3838,8 +3717,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { Note: After erasing, any Python references to the operation become invalid.)") - .def("walk", &PyOperationBase::walk, nb::arg("callback"), - nb::arg("walk_order") = MlirWalkPostOrder, + .def("walk", &PyOperationBase::walk, "callback"_a, + "walk_order"_a = PyWalkOrder::MlirWalkPostOrder, // clang-format off nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"), // clang-format on @@ -3877,11 +3756,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { successors, regions, pyLoc, maybeIp, inferType); }, - nb::arg("name"), nb::arg("results") = nb::none(), - nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), - nb::arg("successors") = nb::none(), nb::arg("regions") = 0, - nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(), - nb::arg("infer_type") = false, + "name"_a, "results"_a = nb::none(), "operands"_a = nb::none(), + "attributes"_a = nb::none(), "successors"_a = nb::none(), + "regions"_a = 0, "loc"_a = nb::none(), "ip"_a = nb::none(), + "infer_type"_a = false, R"( Creates a new operation. @@ -3905,8 +3783,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyOperation::parse(context->getRef(), sourceStr, sourceName) ->createOpView(); }, - nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", - nb::arg("context") = nb::none(), + "source"_a, nb::kw_only(), "source_name"_a = "", + "context"_a = nb::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule, @@ -3952,8 +3830,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { auto opViewClass = nb::class_(m, "OpView") - .def(nb::init>(), - nb::arg("operation")) + .def(nb::init>(), "operation"_a) .def( "__init__", [](PyOpView *self, std::string_view name, @@ -3972,14 +3849,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { resultSegmentSpecObj, resultTypeList, operandList, attributes, successors, regions, pyLoc, maybeIp)); }, - nb::arg("name"), nb::arg("opRegionSpec"), - nb::arg("operandSegmentSpecObj") = nb::none(), - nb::arg("resultSegmentSpecObj") = nb::none(), - nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(), - nb::arg("attributes") = nb::none(), - nb::arg("successors") = nb::none(), - nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(), - nb::arg("ip") = nb::none()) + "name"_a, "opRegionSpec"_a, + "operandSegmentSpecObj"_a = nb::none(), + "resultSegmentSpecObj"_a = nb::none(), "results"_a = nb::none(), + "operands"_a = nb::none(), "attributes"_a = nb::none(), + "successors"_a = nb::none(), "regions"_a = nb::none(), + "loc"_a = nb::none(), "ip"_a = nb::none()) .def_prop_ro( "operation", [](PyOpView &self) -> nb::typed { @@ -4025,10 +3900,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { operandList, attributes, successors, regions, pyLoc, maybeIp); }, - nb::arg("cls"), nb::arg("results") = nb::none(), - nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), - nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(), - nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(), + "cls"_a, "results"_a = nb::none(), "operands"_a = nb::none(), + "attributes"_a = nb::none(), "successors"_a = nb::none(), + "regions"_a = nb::none(), "loc"_a = nb::none(), "ip"_a = nb::none(), "Builds a specific, generated OpView based on class level attributes."); opViewClass.attr("parse") = classmethod( [](const nb::object &cls, const std::string &sourceStr, @@ -4052,8 +3926,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { parsedOpName + "'"); return PyOpView::constructDerived(cls, parsed.getObject()); }, - nb::arg("cls"), nb::arg("source"), nb::kw_only(), - nb::arg("source_name") = "", nb::arg("context") = nb::none(), + "cls"_a, "source"_a, nb::kw_only(), "source_name"_a = "", + "context"_a = nb::none(), "Parses a specific, generated OpView based on class level attributes."); //---------------------------------------------------------------------------- @@ -4136,7 +4010,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyBlock &self, unsigned index) { return mlirBlockEraseArgument(self.get(), index); }, - nb::arg("index"), + "index"_a, R"( Erases the argument at the specified index. @@ -4157,8 +4031,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, - nb::arg("parent"), nb::arg("arg_types") = nb::list(), - nb::arg("arg_locs") = std::nullopt, + "parent"_a, "arg_types"_a = nb::list(), "arg_locs"_a = std::nullopt, "Creates and returns a new Block at the beginning of the given " "region (with given argument types and locations).") .def( @@ -4169,7 +4042,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirBlockDetach(b); mlirRegionAppendOwnedBlock(region.get(), b); }, - nb::arg("region"), + "region"_a, R"( Appends this block to a region. @@ -4188,8 +4061,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - nb::arg("arg_types"), nb::kw_only(), - nb::arg("arg_locs") = std::nullopt, + "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt, "Creates and returns a new Block before this block " "(with given argument types and locations).") .def( @@ -4203,8 +4075,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - nb::arg("arg_types"), nb::kw_only(), - nb::arg("arg_locs") = std::nullopt, + "arg_types"_a, nb::kw_only(), "arg_locs"_a = std::nullopt, "Creates and returns a new Block after this block " "(with given argument types and locations).") .def( @@ -4253,7 +4124,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { operation.getOperation().setAttached( self.getParentOperation().getObject()); }, - nb::arg("operation"), + "operation"_a, R"( Appends an operation to this block. @@ -4279,13 +4150,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "InsertionPoint") - .def(nb::init(), nb::arg("block"), + .def(nb::init(), "block"_a, "Inserts after the last operation but still inside the block.") .def("__enter__", &PyInsertionPoint::contextEnter, "Enters the insertion point as a context manager.") - .def("__exit__", &PyInsertionPoint::contextExit, - nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none(), + .def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(), + "exc_value"_a.none(), "traceback"_a.none(), "Exits the insertion point context manager.") .def_prop_ro_static( "current", @@ -4298,10 +4168,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::sig("def current(/) -> InsertionPoint"), "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set.") - .def(nb::init(), nb::arg("beforeOperation"), + .def(nb::init(), "beforeOperation"_a, "Inserts before a referenced operation.") - .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - nb::arg("block"), + .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, "block"_a, R"( Creates an insertion point at the beginning of a block. @@ -4311,7 +4180,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { Returns: An InsertionPoint at the block's beginning.)") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - nb::arg("block"), + "block"_a, R"( Creates an insertion point before a block's terminator. @@ -4323,7 +4192,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { Raises: ValueError: If the block has no terminator.)") - .def_static("after", &PyInsertionPoint::after, nb::arg("operation"), + .def_static("after", &PyInsertionPoint::after, "operation"_a, R"( Creates an insertion point immediately after an operation. @@ -4332,7 +4201,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { Returns: An InsertionPoint after the operation.)") - .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), + .def("insert", &PyInsertionPoint::insert, "operation"_a, R"( Inserts an operation at this insertion point. @@ -4360,7 +4229,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_(m, "Attribute") // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. - .def(nb::init(), nb::arg("cast_from_type"), + .def(nb::init(), "cast_from_type"_a, "Casts the passed attribute to the generic `Attribute`.") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule, "Gets a capsule wrapping the MlirAttribute.") @@ -4378,7 +4247,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse attribute", errors.take()); return PyAttribute(context.get()->getRef(), attr).maybeDownCast(); }, - nb::arg("asm"), nb::arg("context") = nb::none(), + "asm"_a, "context"_a = nb::none(), "Parses an attribute from an assembly form. Raises an `MLIRError` on " "failure.") .def_prop_ro( @@ -4504,7 +4373,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_(m, "Type") // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. - .def(nb::init(), nb::arg("cast_from_type"), + .def(nb::init(), "cast_from_type"_a, "Casts the passed type to the generic `Type`.") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule, "Gets a capsule wrapping the `MlirType`.") @@ -4521,7 +4390,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse type", errors.take()); return PyType(context.get()->getRef(), type).maybeDownCast(); }, - nb::arg("asm"), nb::arg("context") = nb::none(), + "asm"_a, "context"_a = nb::none(), R"( Parses the assembly form of a type. @@ -4539,7 +4408,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Compares two types for equality.") .def( "__eq__", [](PyType &self, nb::object &other) { return false; }, - nb::arg("other").none(), + "other"_a.none(), "Compares type with non-type object (always returns False).") .def( "__hash__", @@ -4625,11 +4494,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of Value. //---------------------------------------------------------------------------- - m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type")); + m.attr("_T") = nb::type_var("_T", "bound"_a = m.attr("Type")); nb::class_(m, "Value", nb::is_generic(), nb::sig("class Value(Generic[_T])")) - .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value"), + .def(nb::init(), nb::keep_alive<0, 1>(), "value"_a, "Creates a Value reference from another `Value`.") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule, "Gets a capsule wrapping the `MlirValue`.") @@ -4724,8 +4593,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirAsmStateDestroy(valueState); return printAccum.join(); }, - nb::arg("use_local_scope") = false, - nb::arg("use_name_loc_as_prefix") = false, + "use_local_scope"_a = false, "use_name_loc_as_prefix"_a = false, R"( Returns the string form of value as an operand. @@ -4745,7 +4613,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.getUserData()); return printAccum.join(); }, - nb::arg("state"), + "state"_a, "Returns the string form of value as an operand (i.e., the ValueID).") .def_prop_ro( "type", @@ -4760,7 +4628,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyValue &self, const PyType &type) { mlirValueSetType(self.get(), type); }, - nb::arg("type"), "Sets the type of the value.", + "type"_a, "Sets the type of the value.", nb::sig("def set_type(self, type: _T)")) .def( "replace_all_uses_with", @@ -4775,8 +4643,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, - nb::arg("with_"), nb::arg("exceptions"), - kValueReplaceAllUsesExceptDocstring) + "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", [](PyValue &self, PyValue &with, const nb::list &exceptions) { @@ -4790,16 +4657,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { self, with, static_cast(exceptionOps.size()), exceptionOps.data()); }, - nb::arg("with_"), nb::arg("exceptions"), - kValueReplaceAllUsesExceptDocstring) + "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", [](PyValue &self, PyValue &with, PyOperation &exception) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, - nb::arg("with_"), nb::arg("exceptions"), - kValueReplaceAllUsesExceptDocstring) + "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", [](PyValue &self, PyValue &with, @@ -4812,8 +4677,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { self, with, static_cast(exceptionOps.size()), exceptionOps.data()); }, - nb::arg("with_"), nb::arg("exceptions"), - kValueReplaceAllUsesExceptDocstring) + "with_"_a, "exceptions"_a, kValueReplaceAllUsesExceptDocstring) .def( MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyValue &self) -> nb::typed { @@ -4834,16 +4698,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyOpOperand::bind(m); nb::class_(m, "AsmState") - .def(nb::init(), nb::arg("value"), - nb::arg("use_local_scope") = false, + .def(nb::init(), "value"_a, "use_local_scope"_a = false, R"( Creates an `AsmState` for consistent SSA value naming. Args: value: The value to create state for. use_local_scope: Whether to use local scope for naming.)") - .def(nb::init(), nb::arg("op"), - nb::arg("use_local_scope") = false, + .def(nb::init(), "op"_a, + "use_local_scope"_a = false, R"( Creates an AsmState for consistent SSA value naming. @@ -4881,7 +4744,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { Raises: KeyError: If the symbol is not found.)") - .def("insert", &PySymbolTable::insert, nb::arg("operation"), + .def("insert", &PySymbolTable::insert, "operation"_a, R"( Inserts a symbol operation into the symbol table. @@ -4893,7 +4756,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { Raises: ValueError: If the operation does not have a symbol name.)") - .def("erase", &PySymbolTable::erase, nb::arg("operation"), + .def("erase", &PySymbolTable::erase, "operation"_a, R"( Erases a symbol operation from the symbol table. @@ -4912,26 +4775,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Checks if a symbol with the given name exists in the table.") // Static helpers. - .def_static("set_symbol_name", &PySymbolTable::setSymbolName, - nb::arg("symbol"), nb::arg("name"), - "Sets the symbol name for a symbol operation.") - .def_static("get_symbol_name", &PySymbolTable::getSymbolName, - nb::arg("symbol"), + .def_static("set_symbol_name", &PySymbolTable::setSymbolName, "symbol"_a, + "name"_a, "Sets the symbol name for a symbol operation.") + .def_static("get_symbol_name", &PySymbolTable::getSymbolName, "symbol"_a, "Gets the symbol name from a symbol operation.") - .def_static("get_visibility", &PySymbolTable::getVisibility, - nb::arg("symbol"), + .def_static("get_visibility", &PySymbolTable::getVisibility, "symbol"_a, "Gets the visibility attribute of a symbol operation.") - .def_static("set_visibility", &PySymbolTable::setVisibility, - nb::arg("symbol"), nb::arg("visibility"), + .def_static("set_visibility", &PySymbolTable::setVisibility, "symbol"_a, + "visibility"_a, "Sets the visibility attribute of a symbol operation.") .def_static("replace_all_symbol_uses", - &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), - nb::arg("new_symbol"), nb::arg("from_op"), + &PySymbolTable::replaceAllSymbolUses, "old_symbol"_a, + "new_symbol"_a, "from_op"_a, "Replaces all uses of a symbol with a new symbol name within " "the given operation.") .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, - nb::arg("from_op"), nb::arg("all_sym_uses_visible"), - nb::arg("callback"), + "from_op"_a, "all_sym_uses_visible"_a, "callback"_a, "Walks symbol tables starting from an operation with a " "callback function."); @@ -4956,18 +4815,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Attribute builder getter. PyAttrBuilderMap::bind(m); - - nb::register_exception_translator([](const std::exception_ptr &p, - void *payload) { - // We can't define exceptions with custom fields through pybind, so instead - // the exception class is defined in python and imported here. - try { - if (p) - std::rethrow_exception(p); - } catch (const MLIRError &e) { - nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("MLIRError")(e.message, e.errorDiagnostics); - PyErr_SetObject(PyExc_Exception, obj.ptr()); - } - }); } +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 31d4798ffb906..09112d4989ae4 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -12,11 +12,11 @@ #include #include -#include "IRModule.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/IR.h" #include "mlir-c/Interfaces.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -25,7 +25,7 @@ namespace nb = nanobind; namespace mlir { namespace python { - +namespace MLIR_BINDINGS_PYTHON_DOMAIN { constexpr static const char *constructorDoc = R"(Creates an interface from a given operation/opview object or from a subclass of OpView. Raises ValueError if the operation does not implement the @@ -469,6 +469,6 @@ void populateIRInterfaces(nb::module_ &m) { PyShapedTypeComponents::bind(m); PyInferShapedTypeOpInterface::bind(m); } - +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 34c5b8dd86a66..7350046f428c7 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -7,26 +7,27 @@ //===----------------------------------------------------------------------===// // clang-format off -#include "IRModule.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/IRTypes.h" // clang-format on #include -#include "IRModule.h" -#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace nb = nanobind; using namespace mlir; -using namespace mlir::python; +using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; using llvm::SmallVector; using llvm::Twine; -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// Checks whether the given type is an integer or float type. static int mlirTypeIsAIntegerOrFloat(MlirType type) { @@ -509,10 +510,12 @@ class PyComplexType : public PyConcreteType { } }; -} // namespace +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir // Shaped Type Interface - ShapedType -void mlir::PyShapedType::bindDerived(ClassTy &c) { +void PyShapedType::bindDerived(ClassTy &c) { c.def_prop_ro( "element_type", [](PyShapedType &self) -> nb::typed { @@ -617,17 +620,18 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { "shaped types."); } -void mlir::PyShapedType::requireHasRank() { +void PyShapedType::requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { throw nb::value_error( "calling this method requires that the type has a rank."); } } -const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = - mlirTypeIsAShaped; +const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped; -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// Vector Type subclass - VectorType. class PyVectorType : public PyConcreteType { @@ -1099,10 +1103,6 @@ class PyFunctionType : public PyConcreteType { } }; -static MlirStringRef toMlirStringRef(const std::string &s) { - return mlirStringRefCreate(s.data(), s.size()); -} - /// Opaque Type subclass - OpaqueType. class PyOpaqueType : public PyConcreteType { public: @@ -1142,9 +1142,14 @@ class PyOpaqueType : public PyConcreteType { } }; -} // namespace +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir -void mlir::python::populateIRTypes(nb::module_ &m) { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +void populateIRTypes(nb::module_ &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); @@ -1176,3 +1181,6 @@ void mlir::python::populateIRTypes(nb::module_ &m) { PyFunctionType::bind(m); PyOpaqueType::bind(m); } +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index ba767ad6692cf..b2c9380bc1d73 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,143 +6,37 @@ // //===----------------------------------------------------------------------===// -#include "Globals.h" -#include "IRModule.h" -#include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" +#include "mlir/Bindings/Python/Globals.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" namespace nb = nanobind; -using namespace mlir; -using namespace nb::literals; -using namespace mlir::python; +using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; + +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { +void populateIRAffine(nb::module_ &m); +void populateIRAttributes(nb::module_ &m); +void populateIRInterfaces(nb::module_ &m); +void populateIRTypes(nb::module_ &m); +void populateIRCore(nb::module_ &m); +void populateRoot(nb::module_ &m); +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- - NB_MODULE(_mlir, m) { - m.doc() = "MLIR Python Native Extension"; - m.attr("T") = nb::type_var("T"); - m.attr("U") = nb::type_var("U"); - - nb::class_(m, "_Globals") - .def_prop_rw("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) - .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, - "module_name"_a) - .def( - "_check_dialect_module_loaded", - [](PyGlobals &self, const std::string &dialectNamespace) { - return self.loadDialectModule(dialectNamespace); - }, - "dialect_namespace"_a) - .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, - "dialect_namespace"_a, "dialect_class"_a, - "Testing hook for directly registering a dialect") - .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, nb::kw_only(), - "replace"_a = false, - "Testing hook for directly registering an operation") - .def("loc_tracebacks_enabled", - [](PyGlobals &self) { - return self.getTracebackLoc().locTracebacksEnabled(); - }) - .def("set_loc_tracebacks_enabled", - [](PyGlobals &self, bool enabled) { - self.getTracebackLoc().setLocTracebacksEnabled(enabled); - }) - .def("loc_tracebacks_frame_limit", - [](PyGlobals &self) { - return self.getTracebackLoc().locTracebackFramesLimit(); - }) - .def("set_loc_tracebacks_frame_limit", - [](PyGlobals &self, std::optional n) { - self.getTracebackLoc().setLocTracebackFramesLimit( - n.value_or(PyGlobals::TracebackLoc::kMaxFrames)); - }) - .def("register_traceback_file_inclusion", - [](PyGlobals &self, const std::string &filename) { - self.getTracebackLoc().registerTracebackFileInclusion(filename); - }) - .def("register_traceback_file_exclusion", - [](PyGlobals &self, const std::string &filename) { - self.getTracebackLoc().registerTracebackFileExclusion(filename); - }); - - // Aside from making the globals accessible to python, having python manage - // it is necessary to make sure it is destroyed (and releases its python - // resources) properly. - m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); - - // Registration decorators. - m.def( - "register_dialect", - [](nb::type_object pyClass) { - std::string dialectNamespace = - nanobind::cast(pyClass.attr("DIALECT_NAMESPACE")); - PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); - return pyClass; - }, - "dialect_class"_a, - "Class decorator for registering a custom Dialect wrapper"); - m.def( - "register_operation", - [](const nb::type_object &dialectClass, bool replace) -> nb::object { - return nb::cpp_function( - [dialectClass, - replace](nb::type_object opClass) -> nb::type_object { - std::string operationName = - nanobind::cast(opClass.attr("OPERATION_NAME")); - PyGlobals::get().registerOperationImpl(operationName, opClass, - replace); - // Dict-stuff the new opClass by name onto the dialect class. - nb::object opClassName = opClass.attr("__name__"); - dialectClass.attr(opClassName) = opClass; - return opClass; - }); - }, - // clang-format off - nb::sig("def register_operation(dialect_class: type, *, replace: bool = False) " - "-> typing.Callable[[type[T]], type[T]]"), - // clang-format on - "dialect_class"_a, nb::kw_only(), "replace"_a = false, - "Produce a class decorator for registering an Operation class as part of " - "a dialect"); - m.def( - MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> nb::object { - return nb::cpp_function([mlirTypeID, replace]( - nb::callable typeCaster) -> nb::object { - PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); - return typeCaster; - }); - }, - // clang-format off - nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " - "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), - // clang-format on - "typeid"_a, nb::kw_only(), "replace"_a = false, - "Register a type caster for casting MLIR types to custom user types."); - m.def( - MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> nb::object { - return nb::cpp_function( - [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { - PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, - replace); - return valueCaster; - }); - }, - // clang-format off - nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, replace: bool = False) " - "-> typing.Callable[[typing.Callable[[T], U]], typing.Callable[[T], U]]"), - // clang-format on - "typeid"_a, nb::kw_only(), "replace"_a = false, - "Register a value caster for casting MLIR values to custom user values."); + // disable leak warnings which tend to be false positives. + nb::set_leak_warnings(false); + m.doc() = "MLIR Python Native Extension"; + populateRoot(m); // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); populateIRCore(irModule); @@ -158,4 +52,18 @@ NB_MODULE(_mlir, m) { auto passManagerModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); populatePassManagerSubmodule(passManagerModule); + nanobind::register_exception_translator( + [](const std::exception_ptr &p, void *payload) { + // We can't define exceptions with custom fields through pybind, so + // instead the exception class is defined in python and imported here. + try { + if (p) + std::rethrow_exception(p); + } catch (const MLIRError &e) { + nanobind::object obj = + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("MLIRError")(e.message, e.errorDiagnostics); + PyErr_SetObject(PyExc_Exception, obj.ptr()); + } + }); } diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index cdf01fff28cf2..b4a256d847ad5 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,9 +8,9 @@ #include "Pass.h" -#include "Globals.h" -#include "IRModule.h" #include "mlir-c/Pass.h" +#include "mlir/Bindings/Python/Globals.h" +#include "mlir/Bindings/Python/IRCore.h" // clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. @@ -19,9 +19,11 @@ namespace nb = nanobind; using namespace nb::literals; using namespace mlir; -using namespace mlir::python; +using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// Owning Wrapper around a PassManager. class PyPassManager { @@ -53,23 +55,29 @@ class PyPassManager { MlirPassManager passManager; }; -} // namespace +enum PyMlirPassDisplayMode : std::underlying_type_t { + MLIR_PASS_DISPLAY_MODE_LIST = MLIR_PASS_DISPLAY_MODE_LIST, + MLIR_PASS_DISPLAY_MODE_PIPELINE = MLIR_PASS_DISPLAY_MODE_PIPELINE +}; + +struct PyMlirExternalPass : MlirExternalPass {}; /// Create the `mlir.passmanager` here. -void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { +void populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of enumerated types //---------------------------------------------------------------------------- - nb::enum_(m, "PassDisplayMode") + nb::enum_(m, "PassDisplayMode") .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST) .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE); //---------------------------------------------------------------------------- // Mapping of MlirExternalPass //---------------------------------------------------------------------------- - nb::class_(m, "ExternalPass") - .def("signal_pass_failure", - [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); + nb::class_(m, "ExternalPass") + .def("signal_pass_failure", [](PyMlirExternalPass pass) { + mlirExternalPassSignalFailure(pass); + }); //---------------------------------------------------------------------------- // Mapping of the top-level PassManager @@ -148,11 +156,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "Enable pass timing.") .def( "enable_statistics", - [](PyPassManager &passManager, MlirPassDisplayMode displayMode) { - mlirPassManagerEnableStatistics(passManager.get(), displayMode); + [](PyPassManager &passManager, PyMlirPassDisplayMode displayMode) { + mlirPassManagerEnableStatistics( + passManager.get(), + static_cast(displayMode)); }, - "displayMode"_a = - MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE, + "displayMode"_a = MLIR_PASS_DISPLAY_MODE_PIPELINE, "Enable pass statistics.") .def_static( "parse", @@ -211,7 +220,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { }; callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) { - nb::handle(static_cast(userData))(op, pass); + nb::handle(static_cast(userData))( + op, PyMlirExternalPass{pass.ptr}); }; auto externalPass = mlirCreateExternalPass( passID, mlirStringRefCreate(name->data(), name->length()), @@ -267,3 +277,6 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "Print the textual representation for this PassManager, suitable to " "be passed to `parse` for round-tripping."); } +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index bc40943521829..1a311666ebecd 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,12 +9,13 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "NanobindUtils.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace mlir { namespace python { - +namespace MLIR_BINDINGS_PYTHON_DOMAIN { void populatePassManagerSubmodule(nanobind::module_ &m); +} } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index dc6dc7f7c9b72..c282f4b6996e5 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -8,10 +8,10 @@ #include "Rewrite.h" -#include "IRModule.h" #include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/IRCore.h" // clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. @@ -22,9 +22,11 @@ namespace nb = nanobind; using namespace mlir; using namespace nb::literals; -using namespace mlir::python; +using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; -namespace { +namespace mlir { +namespace python { +namespace MLIR_BINDINGS_PYTHON_DOMAIN { class PyPatternRewriter { public: @@ -60,6 +62,8 @@ class PyPatternRewriter { PyMlirContextRef ctx; }; +struct PyMlirPDLResultList : MlirPDLResultList {}; + #if MLIR_ENABLE_PDL_IN_PATTERNMATCH static nb::object objectFromPDLValue(MlirPDLValue value) { if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) @@ -118,7 +122,7 @@ class PyPDLPatternModule { void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( - f(PyPatternRewriter(rewriter), results, + f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); @@ -133,7 +137,7 @@ class PyPDLPatternModule { void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( - f(PyPatternRewriter(rewriter), results, + f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); @@ -223,6 +227,25 @@ class PyRewritePatternSet { MlirContext ctx; }; +enum PyGreedyRewriteStrictness : std::underlying_type_t< + MlirGreedyRewriteStrictness> { + MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP, + MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS = + MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS, + MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS = + MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS, +}; + +enum PyGreedySimplifyRegionLevel : std::underlying_type_t< + MlirGreedySimplifyRegionLevel> { + MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED = + MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED, + MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL = + MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL, + MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE = + MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE +}; + /// Owning Wrapper around a GreedyRewriteDriverConfig. class PyGreedyRewriteDriverConfig { public: @@ -255,12 +278,14 @@ class PyGreedyRewriteDriverConfig { mlirGreedyRewriteDriverConfigEnableFolding(config, enable); } - void setStrictness(MlirGreedyRewriteStrictness strictness) { - mlirGreedyRewriteDriverConfigSetStrictness(config, strictness); + void setStrictness(PyGreedyRewriteStrictness strictness) { + mlirGreedyRewriteDriverConfigSetStrictness( + config, static_cast(strictness)); } - void setRegionSimplificationLevel(MlirGreedySimplifyRegionLevel level) { - mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(config, level); + void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) { + mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( + config, static_cast(level)); } void enableConstantCSE(bool enable) { @@ -283,12 +308,14 @@ class PyGreedyRewriteDriverConfig { return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config); } - MlirGreedyRewriteStrictness getStrictness() { - return mlirGreedyRewriteDriverConfigGetStrictness(config); + PyGreedyRewriteStrictness getStrictness() { + return static_cast( + mlirGreedyRewriteDriverConfigGetStrictness(config)); } - MlirGreedySimplifyRegionLevel getRegionSimplificationLevel() { - return mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config); + PyGreedySimplifyRegionLevel getRegionSimplificationLevel() { + return static_cast( + mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config)); } bool isConstantCSEEnabled() { @@ -299,22 +326,19 @@ class PyGreedyRewriteDriverConfig { MlirGreedyRewriteDriverConfig config; }; -} // namespace - /// Create the `mlir.rewrite` here. -void mlir::python::populateRewriteSubmodule(nb::module_ &m) { +void populateRewriteSubmodule(nb::module_ &m) { // Enum definitions - nb::enum_(m, "GreedyRewriteStrictness") + nb::enum_(m, "GreedyRewriteStrictness") .value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP) .value("EXISTING_AND_NEW_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS) .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS); - nb::enum_(m, "GreedySimplifyRegionLevel") + nb::enum_(m, "GreedySimplifyRegionLevel") .value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED) .value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL) .value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE); - //---------------------------------------------------------------------------- // Mapping of the PatternRewriter //---------------------------------------------------------------------------- @@ -403,10 +427,10 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH - nb::class_(m, "PDLResultList") + nb::class_(m, "PDLResultList") .def( "append", - [](MlirPDLResultList results, const PyValue &value) { + [](PyMlirPDLResultList results, const PyValue &value) { mlirPDLResultListPushBackValue(results, value); }, // clang-format off @@ -415,7 +439,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { ) .def( "append", - [](MlirPDLResultList results, const PyOperation &op) { + [](PyMlirPDLResultList results, const PyOperation &op) { mlirPDLResultListPushBackOperation(results, op); }, // clang-format off @@ -424,7 +448,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { ) .def( "append", - [](MlirPDLResultList results, const PyType &type) { + [](PyMlirPDLResultList results, const PyType &type) { mlirPDLResultListPushBackType(results, type); }, // clang-format off @@ -433,7 +457,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { ) .def( "append", - [](MlirPDLResultList results, const PyAttribute &attr) { + [](PyMlirPDLResultList results, const PyAttribute &attr) { mlirPDLResultListPushBackAttribute(results, attr); }, // clang-format off @@ -443,9 +467,9 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { nb::class_(m, "PDLModule") .def( "__init__", - [](PyPDLPatternModule &self, MlirModule module) { - new (&self) - PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); + [](PyPDLPatternModule &self, PyModule &module) { + new (&self) PyPDLPatternModule( + mlirPDLPatternModuleFromModule(module.get())); }, // clang-format off nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"), @@ -533,22 +557,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { // clang-format on "Applys the given patterns to the given module greedily while folding " "results.") - .def( - "apply_patterns_and_fold_greedily", - [](PyModule &module, MlirFrozenRewritePatternSet set) { - auto status = mlirApplyPatternsAndFoldGreedily( - module.get(), set, mlirGreedyRewriteDriverConfigCreate()); - if (mlirLogicalResultIsFailure(status)) - throw std::runtime_error( - "pattern application failed to converge"); - }, - "module"_a, "set"_a, - // clang-format off - nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"), - // clang-format on - "Applys the given patterns to the given module greedily while " - "folding " - "results.") .def( "apply_patterns_and_fold_greedily", [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { @@ -565,21 +573,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { // clang-format on "Applys the given patterns to the given op greedily while folding " "results.") - .def( - "apply_patterns_and_fold_greedily", - [](PyOperationBase &op, MlirFrozenRewritePatternSet set) { - auto status = mlirApplyPatternsAndFoldGreedilyWithOp( - op.getOperation(), set, mlirGreedyRewriteDriverConfigCreate()); - if (mlirLogicalResultIsFailure(status)) - throw std::runtime_error( - "pattern application failed to converge"); - }, - "op"_a, "set"_a, - // clang-format off - nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"), - // clang-format on - "Applys the given patterns to the given op greedily while folding " - "results.") .def( "walk_and_apply_patterns", [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { @@ -592,3 +585,6 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { "Applies the given patterns to the given op by a fast walk-based " "driver."); } +} // namespace MLIR_BINDINGS_PYTHON_DOMAIN +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h index ae89e2b9589f1..d287f19187708 100644 --- a/mlir/lib/Bindings/Python/Rewrite.h +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -9,13 +9,13 @@ #ifndef MLIR_BINDINGS_PYTHON_REWRITE_H #define MLIR_BINDINGS_PYTHON_REWRITE_H -#include "NanobindUtils.h" +#include "mlir/Bindings/Python/NanobindUtils.h" namespace mlir { namespace python { - +namespace MLIR_BINDINGS_PYTHON_DOMAIN { void populateRewriteSubmodule(nanobind::module_ &m); - +} } // namespace python } // namespace mlir diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 1e9f1e11d4d06..4a9fb127ee08c 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -3,6 +3,8 @@ include(AddMLIRPython) # Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.` # top level package (the API has been embedded in a relocatable way). add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.") +set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}") +set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python") ################################################################################ # Structural groupings. @@ -524,7 +526,6 @@ declare_mlir_dialect_python_bindings( # dependencies. ################################################################################ -set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python") declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core @@ -533,18 +534,13 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MainModule.cpp IRAffine.cpp IRAttributes.cpp - IRCore.cpp IRInterfaces.cpp - IRModule.cpp IRTypes.cpp Pass.cpp Rewrite.cpp # Headers must be included explicitly so they are installed. - Globals.h - IRModule.h Pass.h - NanobindUtils.h Rewrite.h PRIVATE_LINK_LIBS LLVMSupport @@ -752,8 +748,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES DialectSMT.cpp - # Headers must be included explicitly so they are installed. - NanobindUtils.h PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -790,7 +784,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Nanobind MODULE_NAME _mlirDialectsAMDGPU ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu ROOT_DIR "${PYTHON_SOURCE_DIR}" - PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectAMDGPU.cpp PRIVATE_LINK_LIBS @@ -847,6 +840,16 @@ if(MLIR_INCLUDE_TESTS) ) endif() +declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport + _PRIVATE_SUPPORT_LIB + MODULE_NAME MLIRPythonSupport + ADD_TO_PARENT MLIRPythonSources.Core + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + IRCore.cpp + Globals.cpp +) + ################################################################################ # Common CAPI dependency DSO. # All python extensions must link through one DSO which exports the CAPI, and @@ -860,7 +863,6 @@ endif() # once ready. ################################################################################ -set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}") add_mlir_python_common_capi_library(MLIRPythonCAPI INSTALL_COMPONENT MLIRPythonModules INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" diff --git a/mlir/test/Examples/standalone/test.wheel.toy b/mlir/test/Examples/standalone/test.wheel.toy index b60347ba687d0..46f170579a977 100644 --- a/mlir/test/Examples/standalone/test.wheel.toy +++ b/mlir/test/Examples/standalone/test.wheel.toy @@ -37,6 +37,8 @@ # CHECK: %[[V0:.*]] = standalone.foo %[[C2]] : i32 # CHECK: } +# CHECK: !standalone.custom<"foo"> + # CHECK: Testing mlir package # CHECK-NOT: RuntimeWarning: nanobind: type '{{.*}}' was already registered! diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 7bba20931e675..9c0966d2d8798 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -586,9 +586,18 @@ def testCustomAttribute(): try: TestAttr(42) except TypeError as e: - assert "Expected an MLIR object (got 42)" in str(e) - except ValueError as e: - assert "Cannot cast attribute to TestAttr (from 42)" in str(e) + assert ( + "__init__(): incompatible function arguments. The following argument types are supported" + in str(e) + ) + assert ( + "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None" + in str(e) + ) + assert ( + "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int" + in str(e) + ) else: raise @@ -613,12 +622,6 @@ def testCustomType(): b = TestType(a) # Instance custom types should have typeids assert isinstance(b.typeid, TypeID) - # Subclasses of ir.Type should not have a static_typeid - # CHECK: 'TestType' object has no attribute 'static_typeid' - try: - b.static_typeid - except AttributeError as e: - print(e) i8 = IntegerType.get_signless(8) try: @@ -633,9 +636,18 @@ def testCustomType(): try: TestType(42) except TypeError as e: - assert "Expected an MLIR object (got 42)" in str(e) - except ValueError as e: - assert "Cannot cast type to TestType (from 42)" in str(e) + assert ( + "__init__(): incompatible function arguments. The following argument types are supported" + in str(e) + ) + assert ( + "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None" + in str(e) + ) + assert ( + "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int" + in str(e) + ) else: raise diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp index a497fcccf13d7..43573cbc305fa 100644 --- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp +++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp @@ -14,6 +14,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "nanobind/nanobind.h" @@ -26,6 +27,49 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) { mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); } +struct PyTestType + : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirPythonTestTestTypeGetTypeID; + static constexpr const char *pyClassName = "TestType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext + context) { + return PyTestType(context->getRef(), + mlirPythonTestTestTypeGet(context.get()->get())); + }, + nb::arg("context").none() = nb::none()); + } +}; + +class PyTestAttr + : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute< + PyTestAttr> { +public: + static constexpr IsAFunctionTy isaFunction = + mlirAttributeIsAPythonTestTestAttribute; + static constexpr const char *pyClassName = "TestAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirPythonTestTestAttributeGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext + context) { + return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet( + context.get()->get())); + }, + nb::arg("context").none() = nb::none()); + } +}; + NB_MODULE(_mlirPythonTestNanobind, m) { m.def( "register_python_test_dialect", @@ -65,30 +109,8 @@ NB_MODULE(_mlirPythonTestNanobind, m) { nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None")); // clang-format on - mlir_attribute_subclass(m, "TestAttr", - mlirAttributeIsAPythonTestTestAttribute, - mlirPythonTestTestAttributeGetTypeID) - .def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirPythonTestTestAttributeGet(ctx)); - }, - // clang-format off - nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"), - // clang-format on - nb::arg("cls"), nb::arg("context").none() = nb::none()); - - mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, - mlirPythonTestTestTypeGetTypeID) - .def_classmethod( - "get", - [](const nb::object &cls, MlirContext ctx) { - return cls(mlirPythonTestTestTypeGet(ctx)); - }, - // clang-format off - nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"), - // clang-format on - nb::arg("cls"), nb::arg("context").none() = nb::none()); + PyTestAttr::bind(m); + PyTestType::bind(m); auto typeCls = mlir_type_subclass(m, "TestIntegerRankedTensorType", diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 41223b72a7d10..2d98c9ce376b1 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1172,11 +1172,11 @@ PYBIND11_FEATURES = [ filegroup( name = "MLIRBindingsPythonSourceFiles", srcs = [ + "lib/Bindings/Python/Globals.cpp", "lib/Bindings/Python/IRAffine.cpp", "lib/Bindings/Python/IRAttributes.cpp", "lib/Bindings/Python/IRCore.cpp", "lib/Bindings/Python/IRInterfaces.cpp", - "lib/Bindings/Python/IRModule.cpp", "lib/Bindings/Python/IRTypes.cpp", "lib/Bindings/Python/Pass.cpp", "lib/Bindings/Python/Rewrite.cpp",