Skip to content

Commit

Permalink
[mlir][python] Capture error diagnostics in exceptions
Browse files Browse the repository at this point in the history
This updates most (all?) error-diagnostic-emitting python APIs to
capture error diagnostics and include them in the raised exception's
message:
```
>>> Operation.parse('"arith.addi"() : () -> ()'))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
mlir._mlir_libs.MLIRError: Unable to parse operation assembly:
error: "-":1:1: 'arith.addi' op requires one result
 note: "-":1:1: see current operation: "arith.addi"() : () -> ()
```

The diagnostic information is available on the exception for users who
may want to customize the error message:
```
>>> try:
...   Operation.parse('"arith.addi"() : () -> ()')
... except MLIRError as e:
...   print(e.message)
...   print(e.error_diagnostics)
...   print(e.error_diagnostics[0].message)
...
Unable to parse operation assembly
[<mlir._mlir_libs._mlir.ir.DiagnosticInfo object at 0x7fed32bd6b70>]
'arith.addi' op requires one result
```

Error diagnostics captured in exceptions aren't propagated to diagnostic
handlers, to avoid double-reporting of errors. The context-level
`emit_error_diagnostics` option can be used to revert to the old
behaviour, causing error diagnostics to be reported to handlers instead
of as part of exceptions.

API changes:
- `Operation.verify` now raises an exception on verification failure,
  instead of returning `false`
- The exception raised by the following methods has been changed to
  `MLIRError`:
  - `PassManager.run`
  - `{Module,Operation,Type,Attribute}.parse`
  - `{RankedTensorType,UnrankedTensorType}.get`
  - `{MemRefType,UnrankedMemRefType}.get`
  - `VectorType.get`
  - `FloatAttr.get`

closes #60595

depends on D144804, D143830

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D143869
  • Loading branch information
rkayaith committed Mar 7, 2023
1 parent 8200848 commit 3ea4c50
Show file tree
Hide file tree
Showing 13 changed files with 360 additions and 168 deletions.
11 changes: 3 additions & 8 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Expand Up @@ -344,15 +344,10 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
c.def_static(
"get",
[](PyType &type, double value, DefaultingPyLocation loc) {
PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(attr)) {
throw SetPyError(PyExc_ValueError,
Twine("invalid '") +
py::repr(py::cast(type)).cast<std::string>() +
"' and expected floating point type.");
}
if (mlirAttributeIsNull(attr))
throw MLIRError("Invalid attribute", errors.take());
return PyFloatAttribute(type.getContext(), attr);
},
py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
Expand Down
122 changes: 82 additions & 40 deletions mlir/lib/Bindings/Python/IRCore.cpp
Expand Up @@ -15,6 +15,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
//#include "mlir-c/Registration.h"
#include "llvm/ADT/ArrayRef.h"
Expand All @@ -38,7 +39,7 @@ using llvm::Twine;
static const char kContextParseTypeDocstring[] =
R"(Parses the assembly form of a type.
Returns a Type object or raises a ValueError if the type cannot be parsed.
Returns a Type object or raises an MLIRError if the type cannot be parsed.
See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";
Expand All @@ -58,7 +59,7 @@ static const char kContextGetNameLocationDocString[] =
static const char kModuleParseDocstring[] =
R"(Parses a module's assembly format from a string.
Returns a new MlirModule or raises a ValueError if the parsing fails.
Returns a new MlirModule or raises an MLIRError if the parsing fails.
See also: https://mlir.llvm.org/docs/LangRef/
)";
Expand Down Expand Up @@ -654,6 +655,20 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
return pyHandlerObject;
}

MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
void *userData) {
auto *self = static_cast<ErrorCapture *>(userData);
// Check if the context requested we emit errors instead of capturing them.
if (self->ctx->emitErrorDiagnostics)
return mlirLogicalResultFailure();

if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
return mlirLogicalResultFailure();

self->errors.emplace_back(PyDiagnostic(diag).getInfo());
return mlirLogicalResultSuccess();
}

PyMlirContext &DefaultingPyMlirContext::resolve() {
PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
if (!context) {
Expand Down Expand Up @@ -870,6 +885,13 @@ py::tuple PyDiagnostic::getNotes() {
return *materializedNotes;
}

PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
std::vector<DiagnosticInfo> notes;
for (py::handle n : getNotes())
notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
}

//------------------------------------------------------------------------------
// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -1062,13 +1084,12 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
const std::string &sourceStr,
const std::string &sourceName) {
PyMlirContext::ErrorCapture errors(contextRef);
MlirOperation op =
mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
toMlirStringRef(sourceName));
// TODO: Include error diagnostic messages in the exception message
if (mlirOperationIsNull(op))
throw py::value_error(
"Unable to parse operation assembly (see diagnostics)");
throw MLIRError("Unable to parse operation assembly", errors.take());
return PyOperation::createDetached(std::move(contextRef), op);
}

Expand Down Expand Up @@ -1155,6 +1176,14 @@ void PyOperationBase::moveBefore(PyOperationBase &other) {
operation.parentKeepAlive = otherOp.parentKeepAlive;
}

bool PyOperationBase::verify() {
PyOperation &op = getOperation();
PyMlirContext::ErrorCapture errors(op.getContext());
if (!mlirOperationVerify(op.get()))
throw MLIRError("Verification failed", errors.take());
return true;
}

std::optional<PyOperationRef> PyOperation::getParentOperation() {
checkValid();
if (!isAttached())
Expand Down Expand Up @@ -2287,6 +2316,16 @@ void mlir::python::populateIRCore(py::module &m) {
return self.getMessage();
});

py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
py::module_local())
.def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
.def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
.def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
.def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
.def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
.def("__str__",
[](PyDiagnostic::DiagnosticInfo &self) { return self.message; });

py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
.def("detach", &PyDiagnosticHandler::detach)
.def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
Expand Down Expand Up @@ -2375,6 +2414,11 @@ void mlir::python::populateIRCore(py::module &m) {
mlirContextAppendDialectRegistry(self.get(), registry);
},
py::arg("registry"))
.def_property("emit_error_diagnostics", nullptr,
&PyMlirContext::setEmitErrorDiagnostics,
"Emit error diagnostics to diagnostic handlers. By default "
"error diagnostics are captured and reported through "
"MLIRError exceptions.")
.def("load_all_available_dialects", [](PyMlirContext &self) {
mlirContextLoadAllAvailableDialects(self.get());
});
Expand Down Expand Up @@ -2566,16 +2610,12 @@ void mlir::python::populateIRCore(py::module &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
.def_static(
"parse",
[](const std::string moduleAsm, DefaultingPyMlirContext context) {
[](const std::string &moduleAsm, DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirModuleIsNull(module)) {
throw SetPyError(
PyExc_ValueError,
"Unable to parse module assembly (see diagnostics)");
}
if (mlirModuleIsNull(module))
throw MLIRError("Unable to parse module assembly", errors.take());
return PyModule::forModule(module).releaseObject();
},
py::arg("asm"), py::arg("context") = py::none(),
Expand Down Expand Up @@ -2724,13 +2764,9 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, kOperationGetAsmDocstring)
.def(
"verify",
[](PyOperationBase &self) {
return mlirOperationVerify(self.getOperation());
},
"Verify the operation and return true if it passes, false if it "
"fails.")
.def("verify", &PyOperationBase::verify,
"Verify the operation. Raises MLIRError if verification fails, and "
"returns true otherwise.")
.def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
"Puts self immediately after the other operation in its parent "
"block.")
Expand Down Expand Up @@ -2833,12 +2869,12 @@ void mlir::python::populateIRCore(py::module &m) {
// directly.
std::string clsOpName =
py::cast<std::string>(cls.attr("OPERATION_NAME"));
MlirStringRef parsedOpName =
MlirStringRef identifier =
mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName)))
throw py::value_error(
"Expected a '" + clsOpName + "' op, got: '" +
std::string(parsedOpName.data, parsedOpName.length) + "'");
std::string_view parsedOpName(identifier.data, identifier.length);
if (clsOpName != parsedOpName)
throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
parsedOpName + "'");
return PyOpView::constructDerived(cls, *parsed.get());
},
py::arg("cls"), py::arg("source"), py::kw_only(),
Expand Down Expand Up @@ -3071,19 +3107,16 @@ void mlir::python::populateIRCore(py::module &m) {
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute type = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
Twine("Unable to parse attribute: '") +
attrSpec + "'");
}
if (mlirAttributeIsNull(type))
throw MLIRError("Unable to parse attribute", errors.take());
return PyAttribute(context->getRef(), type);
},
py::arg("asm"), py::arg("context") = py::none(),
"Parses an attribute from an assembly form")
"Parses an attribute from an assembly form. Raises an MLIRError on "
"failure.")
.def_property_readonly(
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
Expand Down Expand Up @@ -3182,15 +3215,11 @@ void mlir::python::populateIRCore(py::module &m) {
.def_static(
"parse",
[](std::string typeSpec, DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirType type =
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(type)) {
throw SetPyError(PyExc_ValueError,
Twine("Unable to parse type: '") + typeSpec +
"'");
}
if (mlirTypeIsNull(type))
throw MLIRError("Unable to parse type", errors.take());
return PyType(context->getRef(), type);
},
py::arg("asm"), py::arg("context") = py::none(),
Expand Down Expand Up @@ -3342,4 +3371,17 @@ void mlir::python::populateIRCore(py::module &m) {

// Attribute builder getter.
PyAttrBuilderMap::bind(m);

py::register_local_exception_translator([](std::exception_ptr p) {
// We can't define exceptions with custom fields through pybind, so instead
// the exception class is defined in python and imported here.
try {
if (p)
std::rethrow_exception(p);
} catch (const MLIRError &e) {
py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("MLIRError")(e.message, e.errorDiagnostics);
PyErr_SetObject(PyExc_Exception, obj.ptr());
}
});
}

0 comments on commit 3ea4c50

Please sign in to comment.