diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 8d30051d615f4..fdff79f8b67dd 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -363,6 +363,12 @@ MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location); /// Creates a location with unknown position owned by the given context. MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context); +/// TypeID Getter for Unknown. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationUnknownGetTypeID(void); + +/// Checks whether the given location is an Unknown. +MLIR_CAPI_EXPORTED bool mlirLocationIsAUnknown(MlirLocation location); + /// Gets the context that a location was created with. MLIR_CAPI_EXPORTED MlirContext mlirLocationGetContext(MlirLocation location); diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h index b2edcace7298e..0b758d4061d67 100644 --- a/mlir/include/mlir/Bindings/Python/IRCore.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -337,6 +337,10 @@ class MLIR_PYTHON_API_EXPORTED PyLocation : public BaseContextObject { /// is taken by calling this function. static PyLocation createFromCapsule(nanobind::object capsule); + /// Returns the most-derived Location subclass registered for this TypeID, + /// or self. + nanobind::typed maybeDownCast(); + private: MlirLocation loc; }; @@ -1164,6 +1168,131 @@ class MLIR_PYTHON_API_EXPORTED PyStringAttribute static void bindDerived(ClassTy &c); }; +/// CRTP base class for Python classes that subclass Location and should be +/// castable from it (i.e. via something like FileLineColLoc(loc)). +template +class MLIR_PYTHON_API_EXPORTED PyConcreteLocation : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = nanobind::class_; + using IsAFunctionTy = bool (*)(MlirLocation); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; + using Base = PyConcreteLocation; + + PyConcreteLocation() = default; + PyConcreteLocation(PyMlirContextRef contextRef, MlirLocation loc) + : BaseTy(std::move(contextRef), loc) {} + PyConcreteLocation(PyLocation &orig) + : PyConcreteLocation(orig.getContext(), castFrom(orig)) {} + + static MlirLocation castFrom(PyLocation &orig) { + if (!DerivedTy::isaFunction(orig.get())) { + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((std::string("Cannot cast location to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .c_str()); + } + return orig.get(); + } + + static void bind(nanobind::module_ &m) { + ClassTy cls(m, DerivedTy::pyClassName, nanobind::is_generic()); + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_loc")); + cls.def_prop_ro_static("static_typeid", [](nanobind::object & /*class*/) { + if (DerivedTy::getTypeIdFunction) + return PyTypeID(DerivedTy::getTypeIdFunction()); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + std::string(" has no typeid.")).c_str()); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirLocationPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + nanobind::cast(nanobind::cpp_function( + [](PyLocation pyLoc) -> DerivedTy { return pyLoc; })), + /*replace*/ true); + } + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +class MLIR_PYTHON_API_EXPORTED PyUnknownLocation + : public PyConcreteLocation { +public: + static constexpr IsAFunctionTy isaFunction = mlirLocationIsAUnknown; + static constexpr const char *pyClassName = "UnknownLoc"; + using PyConcreteLocation::PyConcreteLocation; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLocationUnknownGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +class MLIR_PYTHON_API_EXPORTED PyFileLineColLocation + : public PyConcreteLocation { +public: + static constexpr IsAFunctionTy isaFunction = mlirLocationIsAFileLineColRange; + static constexpr const char *pyClassName = "FileLineColLoc"; + using PyConcreteLocation::PyConcreteLocation; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLocationFileLineColRangeGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +class MLIR_PYTHON_API_EXPORTED PyNameLocation + : public PyConcreteLocation { +public: + static constexpr IsAFunctionTy isaFunction = mlirLocationIsAName; + static constexpr const char *pyClassName = "NameLoc"; + using PyConcreteLocation::PyConcreteLocation; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLocationNameGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +class MLIR_PYTHON_API_EXPORTED PyCallSiteLocation + : public PyConcreteLocation { +public: + static constexpr IsAFunctionTy isaFunction = mlirLocationIsACallSite; + static constexpr const char *pyClassName = "CallSiteLoc"; + using PyConcreteLocation::PyConcreteLocation; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLocationCallSiteGetTypeID; + + static void bindDerived(ClassTy &c); +}; + +class MLIR_PYTHON_API_EXPORTED PyFusedLocation + : public PyConcreteLocation { +public: + static constexpr IsAFunctionTy isaFunction = mlirLocationIsAFused; + static constexpr const char *pyClassName = "FusedLoc"; + using PyConcreteLocation::PyConcreteLocation; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLocationFusedGetTypeID; + + static void bindDerived(ClassTy &c); +}; + /// Wrapper around the generic MlirValue. /// Values are managed completely by the operation that resulted in their /// definition. For op result value, this is the operation that defines the diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0fa84a11f2f35..f92d4c14ceb16 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1888,6 +1888,23 @@ nb::typed PyAttribute::maybeDownCast() { return typeCaster.value()(thisObj); } +//------------------------------------------------------------------------------ +// PyLocation::maybeDownCast. +//------------------------------------------------------------------------------ + +nb::typed PyLocation::maybeDownCast() { + MlirAttribute locAttr = mlirLocationGetAttribute(this->get()); + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(locAttr); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = PyGlobals::get().lookupTypeCaster( + mlirTypeID, mlirAttributeGetDialect(locAttr)); + nb::object thisObj = nb::cast(this, nb::rv_policy::move); + if (!typeCaster) + return thisObj; + return typeCaster.value()(thisObj); +} + //------------------------------------------------------------------------------ // PyNamedAttribute. //------------------------------------------------------------------------------ @@ -3051,6 +3068,191 @@ void populateRoot(nb::module_ &m) { "Register a value caster for casting MLIR values to custom user values."); } +//------------------------------------------------------------------------------ +// Location subclass bindDerived implementations. +//------------------------------------------------------------------------------ + +void PyUnknownLocation::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return PyUnknownLocation(context->getRef(), + mlirLocationUnknownGet(context->get())); + }, + "context"_a = nb::none(), + "Gets a Location representing an unknown location."); +} + +void PyFileLineColLocation::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string filename, int line, int col, + DefaultingPyMlirContext context) { + return PyFileLineColLocation( + context->getRef(), + mlirLocationFileLineColGet(context->get(), + toMlirStringRef(filename), line, col)); + }, + "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(), + "Gets a FileLineColLoc for a file, line, and column."); + c.def_static( + "get", + [](std::string filename, int startLine, int startCol, int endLine, + int endCol, DefaultingPyMlirContext context) { + return PyFileLineColLocation( + context->getRef(), mlirLocationFileLineColRangeGet( + context->get(), toMlirStringRef(filename), + startLine, startCol, endLine, endCol)); + }, + "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a, "end_col"_a, + "context"_a = nb::none(), + "Gets a FileLineColLoc spanning a file and line/column range."); + c.def_prop_ro( + "filename", + [](PyFileLineColLocation &self) { + return mlirIdentifierStr( + mlirLocationFileLineColRangeGetFilename(self.get())); + }, + "Gets the filename from a `FileLineColLoc`."); + c.def_prop_ro( + "start_line", + [](PyFileLineColLocation &self) { + return mlirLocationFileLineColRangeGetStartLine(self.get()); + }, + "Gets the start line number from a `FileLineColLoc`."); + c.def_prop_ro( + "start_col", + [](PyFileLineColLocation &self) { + return mlirLocationFileLineColRangeGetStartColumn(self.get()); + }, + "Gets the start column number from a `FileLineColLoc`."); + c.def_prop_ro( + "end_line", + [](PyFileLineColLocation &self) { + return mlirLocationFileLineColRangeGetEndLine(self.get()); + }, + "Gets the end line number from a `FileLineColLoc`."); + c.def_prop_ro( + "end_col", + [](PyFileLineColLocation &self) { + return mlirLocationFileLineColRangeGetEndColumn(self.get()); + }, + "Gets the end column number from a `FileLineColLoc`."); +} + +void PyNameLocation::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string name, std::optional childLoc, + DefaultingPyMlirContext context) { + return PyNameLocation( + context->getRef(), + mlirLocationNameGet(context->get(), toMlirStringRef(name), + childLoc + ? childLoc->get() + : mlirLocationUnknownGet(context->get()))); + }, + "name"_a, "child_loc"_a = nb::none(), "context"_a = nb::none(), + "Gets a NameLoc with an optional child location."); + c.def_prop_ro( + "name_str", + [](PyNameLocation &self) { + return mlirIdentifierStr(mlirLocationNameGetName(self.get())); + }, + "Gets the name string from a `NameLoc`."); + c.def_prop_ro( + "child_loc", + [](PyNameLocation &self) { + return PyLocation(self.getContext(), + mlirLocationNameGetChildLoc(self.get())) + .maybeDownCast(); + }, + "Gets the child location from a `NameLoc`."); +} + +void PyCallSiteLocation::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyLocation callee, const std::vector &frames, + DefaultingPyMlirContext context) { + if (frames.empty()) + throw nb::value_error("No caller frames provided."); + MlirLocation caller = frames.back().get(); + for (size_t index = frames.size() - 1; index-- > 0;) { + caller = mlirLocationCallSiteGet(frames[index].get(), caller); + } + return PyCallSiteLocation( + context->getRef(), mlirLocationCallSiteGet(callee.get(), caller)); + }, + "callee"_a, "frames"_a, "context"_a = nb::none(), + "Gets a CallSiteLoc chaining a callee and one or more caller frames."); + c.def_prop_ro( + "callee", + [](PyCallSiteLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCallee(self.get())) + .maybeDownCast(); + }, + "Gets the callee location from a `CallSiteLoc`."); + c.def_prop_ro( + "caller", + [](PyCallSiteLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCaller(self.get())) + .maybeDownCast(); + }, + "Gets the caller location from a `CallSiteLoc`."); +} + +void PyFusedLocation::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::vector &pyLocations, + std::optional metadata, DefaultingPyMlirContext context) { + std::vector locations; + locations.reserve(pyLocations.size()); + for (const PyLocation &pyLocation : pyLocations) + locations.push_back(pyLocation.get()); + MlirLocation location = mlirLocationFusedGet( + context->get(), locations.size(), locations.data(), + metadata ? metadata->get() : MlirAttribute{0}); + // Strict: `Location.fused(...)` handles the collapse case. + if (!mlirLocationIsAFused(location)) + throw nb::value_error( + "FusedLoc.get would collapse to a non-fused location; use " + "Location.fused(...) for the permissive variant."); + return PyFusedLocation(context->getRef(), location); + }, + "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(), + "Gets a FusedLoc from an array of locations and optional metadata. " + "Raises if the fuse would collapse to a non-fused location; use " + "`Location.fused(...)` for the permissive variant."); + c.def_prop_ro( + "locations", + [](PyFusedLocation &self) { + unsigned numLocations = mlirLocationFusedGetNumLocations(self.get()); + std::vector locations(numLocations); + if (numLocations) + mlirLocationFusedGetLocations(self.get(), locations.data()); + std::vector pyLocations; + pyLocations.reserve(numLocations); + for (unsigned i = 0; i < numLocations; ++i) + pyLocations.push_back( + PyLocation(self.getContext(), locations[i]).maybeDownCast()); + return pyLocations; + }, + "Gets the list of locations from a `FusedLoc`."); + c.def_prop_ro( + "metadata", + [](PyFusedLocation &self) -> std::optional { + MlirAttribute metadata = mlirLocationFusedGetMetadata(self.get()); + if (mlirAttributeIsNull(metadata)) + return std::nullopt; + return PyAttribute(self.getContext(), metadata); + }, + "Gets the metadata attribute from a `FusedLoc`, or None if absent."); +} + //------------------------------------------------------------------------------ // Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ @@ -3423,122 +3625,52 @@ void populateIRCore(nb::module_ &m) { // clang-format on "Gets the Location bound to the current thread or raises ValueError.") .def_static( - "unknown", - [](DefaultingPyMlirContext context) { + "from_attr", + [](PyAttribute &attribute, DefaultingPyMlirContext context) { return PyLocation(context->getRef(), - mlirLocationUnknownGet(context->get())); + mlirLocationFromAttribute(attribute)) + .maybeDownCast(); }, - "context"_a = nb::none(), - "Gets a Location representing an unknown location.") + "attribute"_a, "context"_a = nb::none(), + "Gets a Location from a `LocationAttr`.") + // Factory shims kept for backward compatibility; return the concrete + // subclass. New code should use the subclass `.get()` directly. .def_static( - "callsite", - [](PyLocation callee, const std::vector &frames, - DefaultingPyMlirContext context) { - if (frames.empty()) - throw nb::value_error("No caller frames provided."); - MlirLocation caller = frames.back().get(); - for (size_t index = frames.size() - 1; index-- > 0;) { - caller = mlirLocationCallSiteGet(frames[index].get(), caller); - } - return PyLocation(context->getRef(), - mlirLocationCallSiteGet(callee.get(), caller)); - }, - "callee"_a, "frames"_a, "context"_a = nb::none(), - "Gets a Location representing a caller and callsite.") - .def("is_a_callsite", mlirLocationIsACallSite, - "Returns True if this location is a CallSiteLoc.") - .def_prop_ro( - "callee", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationCallSiteGetCallee(self)); - }, - "Gets the callee location from a CallSiteLoc.") - .def_prop_ro( - "caller", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationCallSiteGetCaller(self)); + "unknown", + [](DefaultingPyMlirContext context) { + return PyUnknownLocation(context->getRef(), + mlirLocationUnknownGet(context->get())); }, - "Gets the caller location from a CallSiteLoc.") + "context"_a = nb::none(), "Alias for `UnknownLoc.get()`.") .def_static( "file", [](std::string filename, int line, int col, DefaultingPyMlirContext context) { - return PyLocation( + return PyFileLineColLocation( context->getRef(), mlirLocationFileLineColGet( context->get(), toMlirStringRef(filename), line, col)); }, "filename"_a, "line"_a, "col"_a, "context"_a = nb::none(), - "Gets a Location representing a file, line and column.") + "Alias for `FileLineColLoc.get()`.") .def_static( "file", [](std::string filename, int startLine, int startCol, int endLine, int endCol, DefaultingPyMlirContext context) { - return PyLocation(context->getRef(), - mlirLocationFileLineColRangeGet( - context->get(), toMlirStringRef(filename), - startLine, startCol, endLine, endCol)); + return PyFileLineColLocation( + context->getRef(), + mlirLocationFileLineColRangeGet( + context->get(), toMlirStringRef(filename), startLine, + startCol, endLine, endCol)); }, "filename"_a, "start_line"_a, "start_col"_a, "end_line"_a, "end_col"_a, "context"_a = nb::none(), - "Gets a Location representing a file, line and column range.") - .def("is_a_file", mlirLocationIsAFileLineColRange, - "Returns True if this location is a FileLineColLoc.") - .def_prop_ro( - "filename", - [](PyLocation loc) { - return mlirIdentifierStr( - mlirLocationFileLineColRangeGetFilename(loc)); - }, - "Gets the filename from a FileLineColLoc.") - .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine, - "Gets the start line number from a `FileLineColLoc`.") - .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn, - "Gets the start column number from a `FileLineColLoc`.") - .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine, - "Gets the end line number from a `FileLineColLoc`.") - .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn, - "Gets the end column number from a `FileLineColLoc`.") - .def_static( - "fused", - [](const std::vector &pyLocations, - std::optional metadata, - DefaultingPyMlirContext context) { - std::vector locations; - locations.reserve(pyLocations.size()); - for (const PyLocation &pyLocation : pyLocations) - locations.push_back(pyLocation.get()); - MlirLocation location = mlirLocationFusedGet( - context->get(), locations.size(), locations.data(), - metadata ? metadata->get() : MlirAttribute{0}); - return PyLocation(context->getRef(), location); - }, - "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(), - "Gets a Location representing a fused location with optional " - "metadata.") - .def("is_a_fused", mlirLocationIsAFused, - "Returns True if this location is a `FusedLoc`.") - .def_prop_ro( - "locations", - [](PyLocation &self) { - unsigned numLocations = mlirLocationFusedGetNumLocations(self); - std::vector locations(numLocations); - if (numLocations) - mlirLocationFusedGetLocations(self, locations.data()); - std::vector pyLocations{}; - pyLocations.reserve(numLocations); - for (unsigned i = 0; i < numLocations; ++i) - pyLocations.emplace_back(self.getContext(), locations[i]); - return pyLocations; - }, - "Gets the list of locations from a `FusedLoc`.") + "Alias for `FileLineColLoc.get()` over a range.") .def_static( "name", [](std::string name, std::optional childLoc, DefaultingPyMlirContext context) { - return PyLocation( + return PyNameLocation( context->getRef(), mlirLocationNameGet( context->get(), toMlirStringRef(name), @@ -3546,31 +3678,38 @@ void populateIRCore(nb::module_ &m) { : mlirLocationUnknownGet(context->get()))); }, "name"_a, "childLoc"_a = nb::none(), "context"_a = nb::none(), - "Gets a Location representing a named location with optional child " - "location.") - .def("is_a_name", mlirLocationIsAName, - "Returns True if this location is a `NameLoc`.") - .def_prop_ro( - "name_str", - [](PyLocation loc) { - return mlirIdentifierStr(mlirLocationNameGetName(loc)); - }, - "Gets the name string from a `NameLoc`.") - .def_prop_ro( - "child_loc", - [](PyLocation &self) { - return PyLocation(self.getContext(), - mlirLocationNameGetChildLoc(self)); + "Alias for `NameLoc.get()`.") + .def_static( + "callsite", + [](PyLocation callee, const std::vector &frames, + DefaultingPyMlirContext context) { + if (frames.empty()) + throw nb::value_error("No caller frames provided."); + MlirLocation caller = frames.back().get(); + for (size_t index = frames.size() - 1; index-- > 0;) + caller = mlirLocationCallSiteGet(frames[index].get(), caller); + return PyCallSiteLocation( + context->getRef(), + mlirLocationCallSiteGet(callee.get(), caller)); }, - "Gets the child location from a `NameLoc`.") + "callee"_a, "frames"_a, "context"_a = nb::none(), + "Alias for `CallSiteLoc.get()`.") .def_static( - "from_attr", - [](PyAttribute &attribute, DefaultingPyMlirContext context) { - return PyLocation(context->getRef(), - mlirLocationFromAttribute(attribute)); + "fused", + [](const std::vector &pyLocations, + std::optional metadata, + DefaultingPyMlirContext context) { + std::vector locations; + locations.reserve(pyLocations.size()); + for (const PyLocation &pyLocation : pyLocations) + locations.push_back(pyLocation.get()); + MlirLocation location = mlirLocationFusedGet( + context->get(), locations.size(), locations.data(), + metadata ? metadata->get() : MlirAttribute{0}); + return PyLocation(context->getRef(), location).maybeDownCast(); }, - "attribute"_a, "context"_a = nb::none(), - "Gets a Location from a `LocationAttr`.") + "locations"_a, "metadata"_a = nb::none(), "context"_a = nb::none(), + "Alias for `FusedLoc.get()` (may collapse to a non-fused location).") .def_prop_ro( "context", [](PyLocation &self) -> nb::typed { @@ -3584,6 +3723,16 @@ void populateIRCore(nb::module_ &m) { mlirLocationGetAttribute(self)); }, "Get the underlying `LocationAttr`.") + .def_prop_ro( + "typeid", + [](PyLocation &self) { + MlirTypeID mlirTypeID = + mlirAttributeGetTypeID(mlirLocationGetAttribute(self.get())); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return PyTypeID(mlirTypeID); + }, + "Gets the `TypeID` of the underlying LocationAttr.") .def( "emit_error", [](PyLocation &self, std::string message) { @@ -3595,6 +3744,15 @@ void populateIRCore(nb::module_ &m) { Args: message: The error message to emit.)") + .def( + "__str__", + [](PyLocation &self) { + PyPrintAccumulator printAccum; + mlirLocationPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + "Returns the assembly form of the Location.") .def( "__repr__", [](PyLocation &self) { @@ -3605,6 +3763,12 @@ void populateIRCore(nb::module_ &m) { }, "Returns the assembly representation of the location."); + PyUnknownLocation::bind(m); + PyFileLineColLocation::bind(m); + PyNameLocation::bind(m); + PyCallSiteLocation::bind(m); + PyFusedLocation::bind(m); + //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- @@ -3813,7 +3977,8 @@ void populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); return PyLocation(operation.getContext(), - mlirOperationGetLocation(operation.get())); + mlirOperationGetLocation(operation.get())) + .maybeDownCast(); }, [](PyOperationBase &self, const PyLocation &location) { PyOperation &operation = self.getOperation(); @@ -4991,8 +5156,9 @@ void populateIRCore(nb::module_ &m) { "location", [](PyValue self) { return PyLocation( - PyMlirContext::forContext(mlirValueGetContext(self)), - mlirValueGetLocation(self)); + PyMlirContext::forContext(mlirValueGetContext(self)), + mlirValueGetLocation(self)) + .maybeDownCast(); }, "Returns the source location of the value."); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 481130e5069db..28ca830270366 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -403,6 +403,14 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) { return wrap(Location(UnknownLoc::get(unwrap(context)))); } +MlirTypeID mlirLocationUnknownGetTypeID() { + return wrap(UnknownLoc::getTypeID()); +} + +bool mlirLocationIsAUnknown(MlirLocation location) { + return isa(unwrap(location)); +} + bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { return unwrap(l1) == unwrap(l2); } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index e66c931383f89..721d68ab781b7 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -2376,6 +2376,24 @@ void testExplicitThreadPools(void) { mlirLlvmThreadPoolDestroy(threadPool); } +void testLocation(void) { + MlirContext ctx = mlirContextCreate(); + fprintf(stderr, "@test_location\n"); + + MlirLocation unknownLoc = mlirLocationUnknownGet(ctx); + MlirLocation fileLoc = mlirLocationFileLineColGet( + ctx, mlirStringRefCreateFromCString("foo.c"), 1, 2); + + // CHECK-LABEL: @test_location + // CHECK: unknown is_a_unknown: 1 + fprintf(stderr, "unknown is_a_unknown: %d\n", + mlirLocationIsAUnknown(unknownLoc)); + // CHECK: file is_a_unknown: 0 + fprintf(stderr, "file is_a_unknown: %d\n", mlirLocationIsAUnknown(fileLoc)); + + mlirContextDestroy(ctx); +} + void testDiagnostics(void) { MlirContext ctx = mlirContextCreate(); MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( @@ -2552,6 +2570,7 @@ int main(void) { return 16; testExplicitThreadPools(); + testLocation(); testDiagnostics(); if (testBlockPredecessorsSuccessors(ctx)) diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py index 3e54dc922cd67..33a4ffce48b9c 100644 --- a/mlir/test/python/ir/location.py +++ b/mlir/test/python/ir/location.py @@ -14,15 +14,22 @@ def run(f): # CHECK-LABEL: TEST: testUnknown def testUnknown(): with Context() as ctx: - loc = Location.unknown() + loc = UnknownLoc.get() assert loc.context is ctx ctx = None gc.collect() # CHECK: unknown str: loc(unknown) print("unknown str:", str(loc)) - # CHECK: unknown repr: loc(unknown) + # CHECK: unknown repr: UnknownLoc(loc(unknown)) print("unknown repr:", repr(loc)) + assert isinstance(loc, UnknownLoc) + assert isinstance(loc, Location) + assert not isinstance(loc, FileLineColLoc) + assert not isinstance(loc, NameLoc) + assert not isinstance(loc, CallSiteLoc) + assert not isinstance(loc, FusedLoc) + run(testUnknown) @@ -30,7 +37,7 @@ def testUnknown(): # CHECK-LABEL: TEST: testLocationAttr def testLocationAttr(): with Context() as ctxt: - loc = Location.unknown() + loc = UnknownLoc.get() attr = loc.attr clone = Location.from_attr(attr) gc.collect() @@ -39,6 +46,7 @@ def testLocationAttr(): # CHECK: clone: loc(unknown) print("clone:", str(clone)) assert loc == clone + assert isinstance(clone, UnknownLoc) run(testLocationAttr) @@ -47,25 +55,25 @@ def testLocationAttr(): # CHECK-LABEL: TEST: testFileLineCol def testFileLineCol(): with Context() as ctx: - loc = Location.file("foo1.txt", 123, 56) - range = Location.file("foo2.txt", 123, 56, 124, 100) + loc = FileLineColLoc.get("foo1.txt", 123, 56) + range = FileLineColLoc.get("foo2.txt", 123, 56, 124, 100) ctx = None gc.collect() # CHECK: file str: loc("foo1.txt":123:56) print("file str:", str(loc)) - # CHECK: file repr: loc("foo1.txt":123:56) + # CHECK: file repr: FileLineColLoc(loc("foo1.txt":123:56)) print("file repr:", repr(loc)) # CHECK: file range str: loc("foo2.txt":123:56 to 124:100) print("file range str:", str(range)) - # CHECK: file range repr: loc("foo2.txt":123:56 to 124:100) + # CHECK: file range repr: FileLineColLoc(loc("foo2.txt":123:56 to 124:100)) print("file range repr:", repr(range)) - assert loc.is_a_file() - assert not loc.is_a_name() - assert not loc.is_a_callsite() - assert not loc.is_a_fused() + assert isinstance(loc, FileLineColLoc) + assert not isinstance(loc, NameLoc) + assert not isinstance(loc, CallSiteLoc) + assert not isinstance(loc, FusedLoc) # CHECK: file filename: foo1.txt print("file filename:", loc.filename) @@ -78,7 +86,7 @@ def testFileLineCol(): # CHECK: file end_col: 56 print("file end_col:", loc.end_col) - assert range.is_a_file() + assert isinstance(range, FileLineColLoc) # CHECK: file filename: foo2.txt print("file filename:", range.filename) # CHECK: file start_line: 123 @@ -92,14 +100,20 @@ def testFileLineCol(): with Context() as ctx: ctx.allow_unregistered_dialects = True - loc = Location.file("foo3.txt", 127, 61) + loc = FileLineColLoc.get("foo3.txt", 127, 61) with loc: i32 = IntegerType.get_signless(32) module = Module.create() with InsertionPoint(module.body): - new_value = Operation.create("custom.op1", results=[i32]).result + op = Operation.create("custom.op1", results=[i32]) + new_value = op.result # CHECK: new_value location: loc("foo3.txt":127:61) print("new_value location: ", new_value.location) + # `op.location` and `value.location` both downcast to the + # concrete subclass. + assert isinstance(op.location, FileLineColLoc) + assert isinstance(new_value.location, FileLineColLoc) + assert op.location.typeid == FileLineColLoc.static_typeid run(testFileLineCol) @@ -108,32 +122,34 @@ def testFileLineCol(): # CHECK-LABEL: TEST: testName def testName(): with Context() as ctx: - loc = Location.name("nombre") - loc_with_child_loc = Location.name("naam", loc) + loc = NameLoc.get("nombre") + loc_with_child_loc = NameLoc.get("naam", loc) ctx = None gc.collect() # CHECK: name str: loc("nombre") print("name str:", str(loc)) - # CHECK: name repr: loc("nombre") + # CHECK: name repr: NameLoc(loc("nombre")) print("name repr:", repr(loc)) # CHECK: name str: loc("naam"("nombre")) print("name str:", str(loc_with_child_loc)) - # CHECK: name repr: loc("naam"("nombre")) + # CHECK: name repr: NameLoc(loc("naam"("nombre"))) print("name repr:", repr(loc_with_child_loc)) - assert loc.is_a_name() + assert isinstance(loc, NameLoc) # CHECK: name name_str: nombre print("name name_str:", loc.name_str) # CHECK: name child_loc: loc(unknown) print("name child_loc:", loc.child_loc) + assert isinstance(loc.child_loc, UnknownLoc) - assert loc_with_child_loc.is_a_name() + assert isinstance(loc_with_child_loc, NameLoc) # CHECK: name name_str: naam print("name name_str:", loc_with_child_loc.name_str) # CHECK: name child_loc_with_child_loc: loc("nombre") print("name child_loc_with_child_loc:", loc_with_child_loc.child_loc) + assert isinstance(loc_with_child_loc.child_loc, NameLoc) run(testName) @@ -142,22 +158,26 @@ def testName(): # CHECK-LABEL: TEST: testCallSite def testCallSite(): with Context() as ctx: - loc = Location.callsite( - Location.file("foo.text", 123, 45), - [Location.file("util.foo", 379, 21), Location.file("main.foo", 100, 63)], + loc = CallSiteLoc.get( + FileLineColLoc.get("foo.text", 123, 45), + [ + FileLineColLoc.get("util.foo", 379, 21), + FileLineColLoc.get("main.foo", 100, 63), + ], ) ctx = None # CHECK: callsite str: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63)) print("callsite str:", str(loc)) - # CHECK: callsite repr: loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63)) + # CHECK: callsite repr: CallSiteLoc(loc(callsite("foo.text":123:45 at callsite("util.foo":379:21 at "main.foo":100:63))) print("callsite repr:", repr(loc)) - assert loc.is_a_callsite() - + assert isinstance(loc, CallSiteLoc) # CHECK: callsite callee: loc("foo.text":123:45) print("callsite callee:", loc.callee) + assert isinstance(loc.callee, FileLineColLoc) # CHECK: callsite caller: loc(callsite("util.foo":379:21 at "main.foo":100:63)) print("callsite caller:", loc.caller) + assert isinstance(loc.caller, CallSiteLoc) run(testCallSite) @@ -166,74 +186,103 @@ def testCallSite(): # CHECK-LABEL: TEST: testFused def testFused(): with Context() as ctx: - loc_single = Location.fused([Location.name("apple")]) - loc = Location.fused([Location.name("apple"), Location.name("banana")]) - attr = Attribute.parse('"sauteed"') - loc_attr = Location.fused( - [Location.name("carrot"), Location.name("potatoes")], attr - ) + loc_single = Location.fused([NameLoc.get("apple")]) loc_empty = Location.fused([]) - loc_empty_attr = Location.fused([], attr) - loc_single_attr = Location.fused([Location.name("apple")], attr) + loc = FusedLoc.get([NameLoc.get("apple"), NameLoc.get("banana")]) + attr = Attribute.parse('"sauteed"') + loc_attr = FusedLoc.get([NameLoc.get("carrot"), NameLoc.get("potatoes")], attr) + loc_empty_attr = FusedLoc.get([], attr) + loc_single_attr = FusedLoc.get([NameLoc.get("apple")], attr) + + try: + FusedLoc.get([NameLoc.get("x")]) + except ValueError as e: + # CHECK: fused strict error: FusedLoc.get would collapse + print("fused strict error:", str(e)[:35]) + else: + assert False, "expected ValueError from strict FusedLoc.get" ctx = None - assert not loc_single.is_a_fused() + assert not isinstance(loc_single, FusedLoc) + assert isinstance(loc_single, NameLoc) # CHECK: fused str: loc("apple") print("fused str:", str(loc_single)) - # CHECK: fused repr: loc("apple") + # CHECK: fused repr: NameLoc(loc("apple")) print("fused repr:", repr(loc_single)) - # # CHECK: fused locations: [] - print("fused locations:", loc_single.locations) - assert loc.is_a_fused() + assert isinstance(loc, FusedLoc) # CHECK: fused str: loc(fused["apple", "banana"]) print("fused str:", str(loc)) - # CHECK: fused repr: loc(fused["apple", "banana"]) + # CHECK: fused repr: FusedLoc(loc(fused["apple", "banana"])) print("fused repr:", repr(loc)) - # CHECK: fused locations: [loc("apple"), loc("banana")] + # CHECK: fused locations: [NameLoc(loc("apple")), NameLoc(loc("banana"))] print("fused locations:", loc.locations) + # CHECK: fused metadata: None + print("fused metadata:", loc.metadata) - assert loc_attr.is_a_fused() + assert isinstance(loc_attr, FusedLoc) + # CHECK: fused metadata: "sauteed" + print("fused metadata:", loc_attr.metadata) # CHECK: fused str: loc(fused<"sauteed">["carrot", "potatoes"]) print("fused str:", str(loc_attr)) - # CHECK: fused repr: loc(fused<"sauteed">["carrot", "potatoes"]) + # CHECK: fused repr: FusedLoc(loc(fused<"sauteed">["carrot", "potatoes"])) print("fused repr:", repr(loc_attr)) - # CHECK: fused locations: [loc("carrot"), loc("potatoes")] + # CHECK: fused locations: [NameLoc(loc("carrot")), NameLoc(loc("potatoes"))] print("fused locations:", loc_attr.locations) - assert not loc_empty.is_a_fused() + assert not isinstance(loc_empty, FusedLoc) + assert isinstance(loc_empty, UnknownLoc) # CHECK: fused str: loc(unknown) print("fused str:", str(loc_empty)) - # CHECK: fused repr: loc(unknown) + # CHECK: fused repr: UnknownLoc(loc(unknown)) print("fused repr:", repr(loc_empty)) - # CHECK: fused locations: [] - print("fused locations:", loc_empty.locations) - assert loc_empty_attr.is_a_fused() + assert isinstance(loc_empty_attr, FusedLoc) # CHECK: fused str: loc(fused<"sauteed">[unknown]) print("fused str:", str(loc_empty_attr)) - # CHECK: fused repr: loc(fused<"sauteed">[unknown]) + # CHECK: fused repr: FusedLoc(loc(fused<"sauteed">[unknown])) print("fused repr:", repr(loc_empty_attr)) - # CHECK: fused locations: [loc(unknown)] + # CHECK: fused locations: [UnknownLoc(loc(unknown))] print("fused locations:", loc_empty_attr.locations) - assert loc_single_attr.is_a_fused() + assert isinstance(loc_single_attr, FusedLoc) # CHECK: fused str: loc(fused<"sauteed">["apple"]) print("fused str:", str(loc_single_attr)) - # CHECK: fused repr: loc(fused<"sauteed">["apple"]) + # CHECK: fused repr: FusedLoc(loc(fused<"sauteed">["apple"])) print("fused repr:", repr(loc_single_attr)) - # CHECK: fused locations: [loc("apple")] + # CHECK: fused locations: [NameLoc(loc("apple"))] print("fused locations:", loc_single_attr.locations) run(testFused) +# CHECK-LABEL: TEST: testCast +def testCast(): + with Context() as ctx: + unknown = UnknownLoc.get() + as_unknown = UnknownLoc(unknown) + assert isinstance(as_unknown, UnknownLoc) + + try: + FileLineColLoc(unknown) + except ValueError as e: + # CHECK: cast error: Cannot cast location to FileLineColLoc (from loc(unknown)) + print("cast error:", str(e)) + else: + assert False, "expected ValueError" + + ctx = None + + +run(testCast) + + # CHECK-LABEL: TEST: testLocationCapsule def testLocationCapsule(): with Context() as ctx: - loc1 = Location.file("foo.txt", 123, 56) + loc1 = FileLineColLoc.get("foo.txt", 123, 56) # CHECK: mlir.ir.Location._CAPIPtr loc_capsule = loc1._CAPIPtr print(loc_capsule)