diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 278847e7ac7f5..1c9705407b232 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -52,9 +52,14 @@ NB_MODULE(_mlir, m) { [](PyGlobals &self, bool enabled) { self.getTracebackLoc().setLocTracebacksEnabled(enabled); }) + .def("loc_tracebacks_frame_limit", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebackFramesLimit(); + }) .def("set_loc_tracebacks_frame_limit", - [](PyGlobals &self, int n) { - self.getTracebackLoc().setLocTracebackFramesLimit(n); + [](PyGlobals &self, std::optional n) { + self.getTracebackLoc().setLocTracebackFramesLimit( + n.value_or(PyGlobals::TracebackLoc::kMaxFrames)); }) .def("register_traceback_file_inclusion", [](PyGlobals &self, const std::string &filename) { diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 6f37266d5bf39..7ddc70a35af96 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -2,9 +2,18 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from __future__ import annotations + +from collections.abc import Iterable +from contextlib import contextmanager + from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug -from ._mlir_libs._mlir import register_type_caster, register_value_caster +from ._mlir_libs._mlir import ( + register_type_caster, + register_value_caster, + globals, +) from ._mlir_libs import ( get_dialect_registry, append_load_on_create_dialect, @@ -12,6 +21,30 @@ ) +@contextmanager +def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]: + """Enables automatic traceback-based locations for MLIR operations. + + Operations created within this context will have their location + automatically set based on the Python call stack. + + Args: + max_depth: Maximum number of frames to include in the location. + If None, the default limit is used. + """ + old_enabled = globals.loc_tracebacks_enabled() + old_limit = globals.loc_tracebacks_frame_limit() + try: + globals.set_loc_tracebacks_frame_limit(max_depth) + if not old_enabled: + globals.set_loc_tracebacks_enabled(True) + yield + finally: + if not old_enabled: + globals.set_loc_tracebacks_enabled(False) + globals.set_loc_tracebacks_frame_limit(old_limit) + + # Convenience decorator for registering user-friendly Attribute builders. def register_attribute_builder(kind, replace=False): def decorator_builder(func):