From 4005e94ca3128354e10ce8d9585d4df2e8903d34 Mon Sep 17 00:00:00 2001 From: makslevental Date: Thu, 2 Oct 2025 22:19:58 -0700 Subject: [PATCH] [MLIR][Python] demo building bindings with no CAPI aggregate and using mlir aggregate --- mlir/examples/standalone/CMakeLists.txt | 4 + .../standalone/lib/Standalone/CMakeLists.txt | 6 +- .../standalone/really_alone/CMakeLists.txt | 32 ++ .../StandReallyAloneExtensionNanobind.cpp | 43 +++ .../dialects/_ods_common.py | 307 ++++++++++++++++++ .../dialects/standalonereallyalone.py | 10 + .../test/python/smoketest_really_alone.py | 37 +++ mlir/lib/Bindings/Python/IRModule.h | 1 - mlir/python/CMakeLists.txt | 1 + mlir/test/Examples/standalone/test.really.toy | 15 + mlir/test/Examples/standalone/test.toy | 4 +- mlir/test/lit.cfg.py | 9 + 12 files changed, 461 insertions(+), 8 deletions(-) create mode 100644 mlir/examples/standalone/really_alone/CMakeLists.txt create mode 100644 mlir/examples/standalone/really_alone/StandReallyAloneExtensionNanobind.cpp create mode 100644 mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/_ods_common.py create mode 100644 mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/standalonereallyalone.py create mode 100644 mlir/examples/standalone/test/python/smoketest_really_alone.py create mode 100644 mlir/test/Examples/standalone/test.really.toy diff --git a/mlir/examples/standalone/CMakeLists.txt b/mlir/examples/standalone/CMakeLists.txt index c6c49fde12d2e..6f39d2907e8c8 100644 --- a/mlir/examples/standalone/CMakeLists.txt +++ b/mlir/examples/standalone/CMakeLists.txt @@ -70,6 +70,10 @@ if(MLIR_ENABLE_BINDINGS_PYTHON) set(MLIR_BINDINGS_PYTHON_INSTALL_PREFIX "python_packages/standalone/${MLIR_PYTHON_PACKAGE_PREFIX}" CACHE STRING "" FORCE) endif() add_subdirectory(python) + option(MLIR_STANDALONE_REALLY "" OFF) + if (MLIR_STANDALONE_REALLY) + add_subdirectory(really_alone) + endif() endif() add_subdirectory(test) add_subdirectory(standalone-opt) diff --git a/mlir/examples/standalone/lib/Standalone/CMakeLists.txt b/mlir/examples/standalone/lib/Standalone/CMakeLists.txt index 0f1705a25c8c8..f2e023461329a 100644 --- a/mlir/examples/standalone/lib/Standalone/CMakeLists.txt +++ b/mlir/examples/standalone/lib/Standalone/CMakeLists.txt @@ -10,9 +10,5 @@ add_mlir_dialect_library(MLIRStandalone DEPENDS MLIRStandaloneOpsIncGen MLIRStandalonePassesIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRInferTypeOpInterface - MLIRFuncDialect ) +target_link_options(obj.MLIRStandalone PUBLIC --unresolved-symbols=ignore-all) diff --git a/mlir/examples/standalone/really_alone/CMakeLists.txt b/mlir/examples/standalone/really_alone/CMakeLists.txt new file mode 100644 index 0000000000000..91b54a8c33a19 --- /dev/null +++ b/mlir/examples/standalone/really_alone/CMakeLists.txt @@ -0,0 +1,32 @@ +add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir.") + +declare_mlir_python_sources(StandReallyAlonePythonSources) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT StandReallyAlonePythonSources + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir_standreallyalone" + SOURCES + dialects/standalonereallyalone.py + dialects/_ods_common.py +) + +declare_mlir_python_extension(StandReallyAlonePythonSources.NanobindExtension + MODULE_NAME _standReallyAloneDialectsNanobind + ADD_TO_PARENT StandReallyAlonePythonSources + SOURCES + StandReallyAloneExtensionNanobind.cpp + PRIVATE_LINK_LIBS + StandaloneCAPI + PYTHON_BINDINGS_LIBRARY nanobind +) + +set(StandReallyAlonePythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}") +add_mlir_python_modules(StandReallyAlonePythonModules + ROOT_PREFIX "${StandReallyAlonePythonModules_ROOT_PREFIX}/../mlir_standreallyalone" + INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/../mlir_standreallyalone" + DECLARED_SOURCES + StandReallyAlonePythonSources + StandReallyAlonePythonSources.NanobindExtension + StandalonePythonSources.standalone.ops_gen + StandalonePythonSources.standalone.tablegen + MLIRPythonSources.Dialects.builtin +) diff --git a/mlir/examples/standalone/really_alone/StandReallyAloneExtensionNanobind.cpp b/mlir/examples/standalone/really_alone/StandReallyAloneExtensionNanobind.cpp new file mode 100644 index 0000000000000..8044b7a540e27 --- /dev/null +++ b/mlir/examples/standalone/really_alone/StandReallyAloneExtensionNanobind.cpp @@ -0,0 +1,43 @@ +//===- StandaloneExtension.cpp - Extension module -------------------------===// +// +// This is the nanobind version of the example module. There is also a pybind11 +// example in StandaloneExtensionPybind11.cpp. +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Standalone-c/Dialects.h" +#include "mlir-c/Dialect/Arith.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; + +NB_MODULE(_standReallyAloneDialectsNanobind, m) { + //===--------------------------------------------------------------------===// + // standalone dialect + //===--------------------------------------------------------------------===// + auto standaloneM = m.def_submodule("standalone"); + + standaloneM.def( + "register_dialects", + [](MlirContext context, bool load) { + MlirDialectHandle arithHandle = mlirGetDialectHandle__arith__(); + MlirDialectHandle standaloneHandle = + mlirGetDialectHandle__standalone__(); + mlirDialectHandleRegisterDialect(arithHandle, context); + mlirDialectHandleRegisterDialect(standaloneHandle, context); + if (load) { + mlirDialectHandleLoadDialect(arithHandle, context); + mlirDialectHandleRegisterDialect(standaloneHandle, context); + } + }, + nb::arg("context").none() = nb::none(), nb::arg("load") = true, + // clang-format off + nb::sig("def register_dialects(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None") + // clang-format on + ); +} diff --git a/mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/_ods_common.py b/mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/_ods_common.py new file mode 100644 index 0000000000000..aeaa533e0a1f4 --- /dev/null +++ b/mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/_ods_common.py @@ -0,0 +1,307 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import ( + List as _List, + Optional as _Optional, + Sequence as _Sequence, + Tuple as _Tuple, + Type as _Type, + Union as _Union, +) + +from mlir._mlir_libs import _mlir as _cext +from mlir.ir import ( + ArrayAttr, + Attribute, + BoolAttr, + DenseI64ArrayAttr, + IntegerAttr, + IntegerType, + OpView, + Operation, + ShapedType, + Value, +) + +__all__ = [ + "equally_sized_accessor", + "get_default_loc_context", + "get_op_result_or_value", + "get_op_results_or_values", + "get_op_result_or_op_results", + "segmented_accessor", +] + + +def segmented_accessor(elements, raw_segments, idx): + """ + Returns a slice of elements corresponding to the idx-th segment. + + elements: a sliceable container (operands or results). + raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing + sizes of the segments. + idx: index of the segment. + """ + segments = _cext.ir.DenseI32ArrayAttr(raw_segments) + start = sum(segments[i] for i in range(idx)) + end = start + segments[idx] + return elements[start:end] + + +def equally_sized_accessor( + elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic +): + """ + Returns a starting position and a number of elements per variadic group + assuming equally-sized groups and the given numbers of preceding groups. + + elements: a sequential container. + n_simple: the number of non-variadic groups in the container. + n_variadic: the number of variadic groups in the container. + n_preceding_simple: the number of non-variadic groups preceding the current + group. + n_preceding_variadic: the number of variadic groups preceding the current + group. + """ + + total_variadic_length = len(elements) - n_simple + # This should be enforced by the C++-side trait verifier. + assert total_variadic_length % n_variadic == 0 + + elements_per_group = total_variadic_length // n_variadic + start = n_preceding_simple + n_preceding_variadic * elements_per_group + return start, elements_per_group + + +def get_default_loc_context(location=None): + """ + Returns a context in which the defaulted location is created. If the location + is None, takes the current location from the stack. + """ + if location is None: + if _cext.ir.Location.current: + return _cext.ir.Location.current.context + return None + return location.context + + +def get_op_result_or_value( + arg: _Union[ + _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList + ], +) -> _cext.ir.Value: + """Returns the given value or the single result of the given op. + + This is useful to implement op constructors so that they can take other ops as + arguments instead of requiring the caller to extract results for every op. + Raises ValueError if provided with an op that doesn't have a single result. + """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.result + elif isinstance(arg, _cext.ir.Operation): + return arg.result + elif isinstance(arg, _cext.ir.OpResultList): + return arg[0] + else: + assert isinstance(arg, _cext.ir.Value), f"expects Value, got {type(arg)}" + return arg + + +def get_op_results_or_values( + arg: _Union[ + _cext.ir.OpView, + _cext.ir.Operation, + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + ], +) -> _Union[ + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + _cext.ir.OpResultList, +]: + """Returns the given sequence of values or the results of the given op. + + This is useful to implement op constructors so that they can take other ops as + lists of arguments instead of requiring the caller to extract results for + every op. + """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.results + elif isinstance(arg, _cext.ir.Operation): + return arg.results + else: + return arg + + +def get_op_result_or_op_results( + op: _Union[_cext.ir.OpView, _cext.ir.Operation], +) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]: + results = op.results + num_results = len(results) + if num_results == 1: + return results[0] + elif num_results > 1: + return results + elif isinstance(op, _cext.ir.OpView): + return op.operation + else: + return op + + +ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value +ResultValueT = _Union[ResultValueTypeTuple] +VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] + +StaticIntLike = _Union[int, IntegerAttr] +ValueLike = _Union[Operation, OpView, Value] +MixedInt = _Union[StaticIntLike, ValueLike] + +IntOrAttrList = _Sequence[_Union[IntegerAttr, int]] +OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]] + +BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]] +OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]] + +MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] + +DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]] + + +def _dispatch_dynamic_index_list( + indices: _Union[DynamicIndexList, ArrayAttr], +) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]: + """Dispatches a list of indices to the appropriate form. + + This is similar to the custom `DynamicIndexList` directive upstream: + provided indices may be in the form of dynamic SSA values or static values, + and they may be scalable (i.e., as a singleton list) or not. This function + dispatches each index into its respective form. It also extracts the SSA + values and static indices from various similar structures, respectively. + """ + dynamic_indices = [] + static_indices = [ShapedType.get_dynamic_size()] * len(indices) + scalable_indices = [False] * len(indices) + + # ArrayAttr: Extract index values. + if isinstance(indices, ArrayAttr): + indices = [idx for idx in indices] + + def process_nonscalable_index(i, index): + """Processes any form of non-scalable index. + + Returns False if the given index was scalable and thus remains + unprocessed; True otherwise. + """ + if isinstance(index, int): + static_indices[i] = index + elif isinstance(index, IntegerAttr): + static_indices[i] = index.value # pytype: disable=attribute-error + elif isinstance(index, (Operation, Value, OpView)): + dynamic_indices.append(index) + else: + return False + return True + + # Process each index at a time. + for i, index in enumerate(indices): + if not process_nonscalable_index(i, index): + # If it wasn't processed, it must be a scalable index, which is + # provided as a _Sequence of one value, so extract and process that. + scalable_indices[i] = True + assert len(index) == 1 + ret = process_nonscalable_index(i, index[0]) + assert ret + + return dynamic_indices, static_indices, scalable_indices + + +# Dispatches `MixedValues` that all represents integers in various forms into +# the following three categories: +# - `dynamic_values`: a list of `Value`s, potentially from op results; +# - `packed_values`: a value handle, potentially from an op result, associated +# to one or more payload operations of integer type; +# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python +# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. +# The input is in the form for `packed_values`, only that result is set and the +# other two are empty. Otherwise, the input can be a mix of the other two forms, +# and for each dynamic value, a special value is added to the `static_values`. +def _dispatch_mixed_values( + values: MixedValues, +) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]: + dynamic_values = [] + packed_values = None + static_values = None + if isinstance(values, ArrayAttr): + static_values = values + elif isinstance(values, (Operation, Value, OpView)): + packed_values = values + else: + static_values = [] + for size in values or []: + if isinstance(size, int): + static_values.append(size) + else: + static_values.append(ShapedType.get_dynamic_size()) + dynamic_values.append(size) + static_values = DenseI64ArrayAttr.get(static_values) + + return (dynamic_values, packed_values, static_values) + + +def _get_value_or_attribute_value( + value_or_attr: _Union[any, Attribute, ArrayAttr], +) -> any: + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): + return value_or_attr.value + if isinstance(value_or_attr, ArrayAttr): + return _get_value_list(value_or_attr) + return value_or_attr + + +def _get_value_list( + sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr], +) -> _Sequence[any]: + return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] + + +def _get_int_array_attr( + values: _Optional[_Union[ArrayAttr, IntOrAttrList]], +) -> ArrayAttr: + if values is None: + return None + + # Turn into a Python list of Python ints. + values = _get_value_list(values) + + # Make an ArrayAttr of IntegerAttrs out of it. + return ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] + ) + + +def _get_int_array_array_attr( + values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]], +) -> ArrayAttr: + """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. + + The input has to be a collection of a collection of integers, where any + Python _Sequence and ArrayAttr are admissible collections and Python ints and + any IntegerAttr are admissible integers. Both levels of collections are + turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. + If the input is None, an empty ArrayAttr is returned. + """ + if values is None: + return None + + # Make sure the outer level is a list. + values = _get_value_list(values) + + # The inner level is now either invalid or a mixed sequence of ArrayAttrs and + # Sequences. Make sure the nested values are all lists. + values = [_get_value_list(nested) for nested in values] + + # Turn each nested list into an ArrayAttr. + values = [_get_int_array_attr(nested) for nested in values] + + # Turn the outer list into an ArrayAttr. + return ArrayAttr.get(values) diff --git a/mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/standalonereallyalone.py b/mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/standalonereallyalone.py new file mode 100644 index 0000000000000..edbcbb49ee036 --- /dev/null +++ b/mlir/examples/standalone/really_alone/mlir_standreallyalone/dialects/standalonereallyalone.py @@ -0,0 +1,10 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._ods_common import _cext + +_cext.globals.append_dialect_search_prefix("mlir_standreallyalone.dialects") + +from ._standalone_ops_gen import * +from .._mlir_libs._standReallyAloneDialectsNanobind.standalone import * diff --git a/mlir/examples/standalone/test/python/smoketest_really_alone.py b/mlir/examples/standalone/test/python/smoketest_really_alone.py new file mode 100644 index 0000000000000..528389dc1f1ae --- /dev/null +++ b/mlir/examples/standalone/test/python/smoketest_really_alone.py @@ -0,0 +1,37 @@ +# RUN: echo "do nothing" +# just so lit doesn't complain about a missing RUN line +# noinspection PyUnusedImports +import contextlib +import ctypes +import sys + + +@contextlib.contextmanager +def dl_open_guard(): + old_flags = sys.getdlopenflags() + sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) + yield + sys.setdlopenflags(old_flags) + + +with dl_open_guard(): + # noinspection PyUnresolvedReferences + from mlir._mlir_libs import _mlir + from mlir import ir + +from mlir_standreallyalone.dialects import standalonereallyalone as standalone_d + +with ir.Context() as ctx: + standalone_d.register_dialects() + module = ir.Module.parse( + """ + %0 = arith.constant 2 : i32 + %1 = standalone.foo %0 : i32 + """ + ) + # CHECK: %[[C:.*]] = arith.constant 2 : i32 + # CHECK: standalone.foo %[[C]] : i32 + print(str(module)) + +# just so lit doesn't complain about this file +# UNSUPPORTED: target={{.*}} diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index edbd73eade906..e712fa7780e26 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -23,7 +23,6 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" -#include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 9f5246de6bda0..aae8ebb54cb73 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -528,6 +528,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core # Dialects MLIRCAPIFunc + MLIRCAPIArith ) # This extension exposes an API to register all dialects, extensions, and passes diff --git a/mlir/test/Examples/standalone/test.really.toy b/mlir/test/Examples/standalone/test.really.toy new file mode 100644 index 0000000000000..7ed6ef765bf22 --- /dev/null +++ b/mlir/test/Examples/standalone/test.really.toy @@ -0,0 +1,15 @@ +# RUN: "%cmake_exe" "%mlir_src_root/examples/standalone" -G "%cmake_generator" \ +# RUN: -DCMAKE_BUILD_TYPE=%cmake_build_type \ +# RUN: -DCMAKE_CXX_COMPILER=%host_cxx -DCMAKE_C_COMPILER=%host_cc \ +# RUN: -DLLVM_ENABLE_LIBCXX=%enable_libcxx -DMLIR_DIR=%mlir_cmake_dir \ +# RUN: -DLLVM_USE_LINKER=%llvm_use_linker \ +# RUN: -DMLIR_PYTHON_PACKAGE_PREFIX=mlir \ +# RUN: -DMLIR_STANDALONE_REALLY=ON \ +# RUN: -DPython3_EXECUTABLE=%python \ +# RUN: -DPython_EXECUTABLE=%python | tee %t +# RUN: "%cmake_exe" --build . --target StandReallyAlonePythonModules | tee -a %t +# RUN: %python "%mlir_src_root/examples/standalone/test/python/smoketest_really_alone.py" | tee -a %t +# RUN: FileCheck --input-file=%t %s + +# CHECK: %[[C:.*]] = arith.constant 2 : i32 +# CHECK: standalone.foo %[[C]] : i32 diff --git a/mlir/test/Examples/standalone/test.toy b/mlir/test/Examples/standalone/test.toy index a88c115ebf197..f00dd38c55d50 100644 --- a/mlir/test/Examples/standalone/test.toy +++ b/mlir/test/Examples/standalone/test.toy @@ -5,8 +5,8 @@ # RUN: -DLLVM_USE_LINKER=%llvm_use_linker \ # RUN: -DMLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone \ # RUN: -DPython3_EXECUTABLE=%python \ -# RUN: -DPython_EXECUTABLE=%python -# RUN: "%cmake_exe" --build . --target check-standalone | tee %t +# RUN: -DPython_EXECUTABLE=%python | tee %t +# RUN: "%cmake_exe" --build . --target check-standalone | tee -a %t # RUN: FileCheck --input-file=%t %s # Note: The number of checked tests is not important. The command will fail diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py index f99c24d6e299a..138d66d84ab74 100644 --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -339,6 +339,14 @@ def find_real_python_interpreter(): [ os.path.join(config.mlir_obj_root, "python_packages", "mlir_core"), os.path.join(config.mlir_obj_root, "python_packages", "mlir_test"), + os.path.join( + config.mlir_obj_root, + "test", + "Examples", + "standalone", + "python_packages", + "standalone", + ), ], append_path=True, ) @@ -348,6 +356,7 @@ def find_real_python_interpreter(): else: config.available_features.add("noasserts") + def have_host_jit_feature_support(feature_name): mlir_runner_exe = lit.util.which("mlir-runner", config.mlir_tools_dir)