diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index ca90151e76268..2d86a9a9d0bc4 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -389,6 +389,8 @@ endfunction() # This file is where the *EnumAttrs are defined, not where the *Enums are defined. # **WARNING**: This arg will shortly be removed when the just-below TODO is satisfied. Use at your # risk. +# MLIR_PYTHON_PACKAGE: Optional name of the current main MLIR package. +# It can be used to build extensions against a main package. # # TODO: Right now `TD_FILE` can't be the actual dialect tablegen file, since we # use its path to determine where to place the generated python file. If @@ -397,7 +399,7 @@ endfunction() function(declare_mlir_dialect_python_bindings) cmake_parse_arguments(ARG "GEN_ENUM_BINDINGS" - "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME" + "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;MLIR_PYTHON_PACKAGE" "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE" ${ARGN}) # Sources. @@ -417,8 +419,12 @@ function(declare_mlir_dialect_python_bindings) file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}") set(dialect_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_ops_gen.py") set(LLVM_TARGET_DEFINITIONS ${td_file}) + if(NOT DEFINED ARG_MLIR_PYTHON_PACKAGE) + set(ARG_MLIR_PYTHON_PACKAGE "mlir") + endif() mlir_tablegen("${dialect_filename}" -gen-python-op-bindings -bind-dialect=${ARG_DIALECT_NAME} + -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE} DEPENDS ${ARG_DEPENDS} ) add_public_tablegen_target(${tblgen_target}) @@ -430,7 +436,8 @@ function(declare_mlir_dialect_python_bindings) set(LLVM_TARGET_DEFINITIONS ${td_file}) endif() set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py") - mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + mlir_tablegen(${enum_filename} -gen-python-enum-bindings + -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE}) list(APPEND _sources ${enum_filename}) endif() @@ -464,10 +471,12 @@ endfunction() # This file is where the *Attrs are defined, not where the *Enums are defined. # **WARNING**: This arg will shortly be removed when the TODO for # declare_mlir_dialect_python_bindings is satisfied. Use at your risk. +# MLIR_PYTHON_PACKAGE: Optional name of the current main MLIR package. +# It can be used to build extensions against a main package. function(declare_mlir_dialect_extension_python_bindings) cmake_parse_arguments(ARG "GEN_ENUM_BINDINGS" - "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME" + "ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME;MLIR_PYTHON_PACKAGE" "SOURCES;SOURCES_GLOB;DEPENDS;GEN_ENUM_BINDINGS_TD_FILE" ${ARGN}) # Source files. @@ -487,9 +496,13 @@ function(declare_mlir_dialect_extension_python_bindings) file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}") set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py") set(LLVM_TARGET_DEFINITIONS ${td_file}) + if(NOT DEFINED ARG_MLIR_PYTHON_PACKAGE) + set(ARG_MLIR_PYTHON_PACKAGE "mlir") + endif() mlir_tablegen("${output_filename}" -gen-python-op-bindings -bind-dialect=${ARG_DIALECT_NAME} - -dialect-extension=${ARG_EXTENSION_NAME}) + -dialect-extension=${ARG_EXTENSION_NAME} + -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE}) add_public_tablegen_target(${tblgen_target}) if(ARG_DEPENDS) add_dependencies(${tblgen_target} ${ARG_DEPENDS}) @@ -502,7 +515,8 @@ function(declare_mlir_dialect_extension_python_bindings) set(LLVM_TARGET_DEFINITIONS ${td_file}) endif() set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py") - mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + mlir_tablegen(${enum_filename} -gen-python-enum-bindings + -python-package-prefix=${ARG_MLIR_PYTHON_PACKAGE}) list(APPEND _sources ${enum_filename}) endif() @@ -601,7 +615,6 @@ function(add_mlir_python_common_capi_library name) # Generate the aggregate .so that everything depends on. add_mlir_aggregate(${name} SHARED - DISABLE_INSTALL EMBED_LIBS ${_embed_libs} ) diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt index edaedf18cc843..91373da0e9377 100644 --- a/mlir/examples/standalone/python/CMakeLists.txt +++ b/mlir/examples/standalone/python/CMakeLists.txt @@ -18,7 +18,8 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/standalone_nanobind.py _mlir_libs/_standaloneDialectsNanobind/py.typed - DIALECT_NAME standalone) + DIALECT_NAME standalone + MLIR_PYTHON_PACKAGE "${MLIR_PYTHON_PACKAGE_PREFIX}") declare_mlir_python_extension(StandalonePythonSources.NanobindExtension MODULE_NAME _standaloneDialectsNanobind diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td index cd23b6a2effb9..5dd002ca21bd3 100644 --- a/mlir/test/mlir-tblgen/enums-python-bindings.td +++ b/mlir/test/mlir-tblgen/enums-python-bindings.td @@ -10,8 +10,8 @@ def Test_Dialect : Dialect { // CHECK: Autogenerated by mlir-tblgen; don't manually edit. // CHECK: from enum import IntEnum, auto, IntFlag -// CHECK: from ._ods_common import _cext as _ods_cext -// CHECK: from ..ir import register_attribute_builder +// CHECK: from mlir.dialects._ods_common import _cext as _ods_cext +// CHECK: from mlir.ir import register_attribute_builder // CHECK: _ods_ir = _ods_cext.ir def One : I32EnumAttrCase<"CaseOne", 1, "one">; diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp index acc9b61d7121c..31211a2094f06 100644 --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -17,6 +17,7 @@ #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/GenInfo.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" @@ -27,16 +28,19 @@ using llvm::Record; using llvm::RecordKeeper; /// File header and includes. +/// {0} is the Python package prefix. constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from enum import IntEnum, auto, IntFlag -from ._ods_common import _cext as _ods_cext -from ..ir import register_attribute_builder +from {0}.dialects._ods_common import _cext as _ods_cext +from {0}.ir import register_attribute_builder _ods_ir = _ods_cext.ir )Py"; +extern llvm::cl::opt clPythonPackagePrefix; + /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. static std::string makePythonEnumCaseName(StringRef name) { if (isPythonReserved(name.str())) @@ -122,7 +126,7 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, /// Emits Python bindings for all enums in the record keeper. Returns /// `false` on success, `true` on failure. static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { - os << fileHeader; + os << formatv(fileHeader, clPythonPackagePrefix); for (const Record *it : records.getAllDerivedDefinitionsIfDefined("EnumInfo")) { EnumInfo enumInfo(*it); diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 2c33f4efac3ac..fae0a5c25435f 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -30,12 +30,12 @@ using llvm::Record; using llvm::RecordKeeper; /// File header and includes. -/// {0} is the dialect namespace. +/// {0} is the Python package prefix. constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. -from ._ods_common import _cext as _ods_cext -from ._ods_common import ( +from {0}.dialects._ods_common import _cext as _ods_cext +from {0}.dialects._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_results_or_values as _get_op_results_or_values, @@ -51,6 +51,7 @@ from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional /// Template for dialect class: /// {0} is the dialect namespace. +/// {1} is the Python package prefix. constexpr const char *dialectClassTemplate = R"Py( @_ods_cext.register_dialect class _Dialect(_ods_ir.Dialect): @@ -58,7 +59,7 @@ class _Dialect(_ods_ir.Dialect): )Py"; constexpr const char *dialectExtensionTemplate = R"Py( -from ._{0}_ops_gen import _Dialect +from {1}.dialects._{0}_ops_gen import _Dialect )Py"; /// Template for operation class: @@ -293,6 +294,15 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]: return results if len(results) > 1 else (results[0] if len(results) == 1 else op) )Py"; +static llvm::cl::OptionCategory + clPythonBindingCat("Options for -gen-python-op-bindings and " + "-gen-python-enum-bindings"); + +llvm::cl::opt clPythonPackagePrefix( + "python-package-prefix", + llvm::cl::desc("The prefix of the MLIR Python package"), + llvm::cl::init("mlir"), llvm::cl::cat(clPythonBindingCat)); + static llvm::cl::OptionCategory clOpPythonBindingCat("Options for -gen-python-op-bindings"); @@ -1222,9 +1232,10 @@ static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { if (clDialectName.empty()) llvm::PrintFatalError("dialect name not provided"); - os << fileHeader; + os << formatv(fileHeader, clPythonPackagePrefix.getValue()); if (!clDialectExtensionName.empty()) - os << formatv(dialectExtensionTemplate, clDialectName.getValue()); + os << formatv(dialectExtensionTemplate, clDialectName.getValue(), + clPythonPackagePrefix.getValue()); else os << formatv(dialectClassTemplate, clDialectName.getValue());