-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR] [Python] Added a context manager for enabling traceback-based locations #157562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Sergei Lebedev (superbobry) ChangesI also added the relevant APIs to the .pyi file to make sure they are visible to static analysis tooling. Full diff: https://github.com/llvm/llvm-project/pull/157562.diff 3 Files Affected:
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..1c9705407b232 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -52,9 +52,14 @@ NB_MODULE(_mlir, m) {
[](PyGlobals &self, bool enabled) {
self.getTracebackLoc().setLocTracebacksEnabled(enabled);
})
+ .def("loc_tracebacks_frame_limit",
+ [](PyGlobals &self) {
+ return self.getTracebackLoc().locTracebackFramesLimit();
+ })
.def("set_loc_tracebacks_frame_limit",
- [](PyGlobals &self, int n) {
- self.getTracebackLoc().setLocTracebackFramesLimit(n);
+ [](PyGlobals &self, std::optional<int> n) {
+ self.getTracebackLoc().setLocTracebackFramesLimit(
+ n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
})
.def("register_traceback_file_inclusion",
[](PyGlobals &self, const std::string &filename) {
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index 03449b70b7fa3..1963bde006fce 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -7,6 +7,12 @@ class _Globals:
def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
def append_dialect_search_prefix(self, module_name: str) -> None: ...
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
+ def loc_tracebacks_enabled(self) -> bool: ...
+ def set_loc_tracebacks_enabled(self, enabled: bool, /) -> None: ...
+ def loc_tracebacks_frame_limit(self) -> int: ...
+ def set_loc_tracebacks_frame_limit(self, n: int, /) -> None: ...
+ def register_traceback_file_inclusion(self, filename: str, /) -> None: ...
+ def register_traceback_file_exclusion(self, filename: str, /) -> None: ...
def register_dialect(dialect_class: type) -> type: ...
def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ...
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6f37266d5bf39..9d21c734efb83 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -2,9 +2,16 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from collections.abc import Iterable
+from contextlib import contextmanager
+
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
-from ._mlir_libs._mlir import register_type_caster, register_value_caster
+from ._mlir_libs._mlir import (
+ register_type_caster,
+ register_value_caster,
+ globals,
+)
from ._mlir_libs import (
get_dialect_registry,
append_load_on_create_dialect,
@@ -12,6 +19,30 @@
)
+@contextmanager
+def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]:
+ """Enables automatic traceback-based locations for MLIR operations.
+
+ Operations created within this context will have their location
+ automatically set based on the Python call stack.
+
+ Args:
+ max_depth: Maximum number of frames to include in the location.
+ If None, the default limit is used.
+ """
+ old_enabled = globals.loc_tracebacks_enabled()
+ old_limit = globals.loc_tracebacks_frame_limit()
+ try:
+ globals.set_loc_tracebacks_frame_limit(max_depth)
+ if not old_enabled:
+ globals.set_loc_tracebacks_enabled(True)
+ yield
+ finally:
+ if not old_enabled:
+ globals.set_loc_tracebacks_enabled(False)
+ globals.set_loc_tracebacks_frame_limit(old_limit)
+
+
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind, replace=False):
def decorator_builder(func):
|
40d2c19
to
c078a70
Compare
✅ With the latest revision this PR passed the Python code formatter. |
c078a70
to
9ba89c3
Compare
9ba89c3
to
8d24e67
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Can you merge, please? :) |
Previously this functionality was not surfaced in the public API.