-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] remove nb::typed to fix bazel build #160183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Python] remove nb::typed to fix bazel build #160183
Conversation
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) Changes#157930 broke bazel build (see #157930 (comment)) because bazel is stricter on implicit conversion (some difference in flags passed to clang). This PR fixes by moving/removing Full diff: https://github.com/llvm/llvm-project/pull/160183.diff 6 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 7818caf2e8a55..212228fbac91e 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
PyArrayAttributeIterator &dunderIter() { return *this; }
- nb::typed<nb::object, PyAttribute> dunderNext() {
+ nb::object dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 609502041f4ae..81386f2227a7f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -513,7 +513,7 @@ class PyOperationIterator {
PyOperationIterator &dunderIter() { return *this; }
- nb::typed<nb::object, PyOpView> dunderNext() {
+ nb::object dunderNext() {
parentOperation->checkValid();
if (mlirOperationIsNull(next)) {
throw nb::stop_iteration();
@@ -562,7 +562,7 @@ class PyOperationList {
return count;
}
- nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
+ nb::object dunderGetItem(intptr_t index) {
parentOperation->checkValid();
if (index < 0) {
index += dunderLen();
@@ -1534,7 +1534,7 @@ nb::object PyOperation::create(std::string_view name,
return created.getObject();
}
-nb::typed<nb::object, PyOpView> PyOperation::clone(const nb::object &maybeIp) {
+nb::object PyOperation::clone(const nb::object &maybeIp) {
MlirOperation clonedOperation = mlirOperationClone(operation);
PyOperationRef cloned =
PyOperation::createDetached(getContext(), clonedOperation);
@@ -1543,7 +1543,7 @@ nb::typed<nb::object, PyOpView> PyOperation::clone(const nb::object &maybeIp) {
return cloned->createOpView();
}
-nb::typed<nb::object, PyOpView> PyOperation::createOpView() {
+nb::object PyOperation::createOpView() {
checkValid();
MlirIdentifier ident = mlirOperationGetName(get());
MlirStringRef identStr = mlirIdentifierStr(ident);
@@ -1638,9 +1638,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<nb::typed<nb::object, PyType>>
-getValueTypes(Container &container, PyMlirContextRef &context) {
- std::vector<nb::typed<nb::object, PyType>> result;
+static std::vector<nb::object> getValueTypes(Container &container,
+ PyMlirContextRef &context) {
+ std::vector<nb::object> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(PyType(context->getRef(),
@@ -2133,7 +2133,7 @@ PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
}
-nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
+nb::object PyAttribute::maybeDownCast() {
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
@@ -2179,7 +2179,7 @@ PyType PyType::createFromCapsule(nb::object capsule) {
rawType);
}
-nb::typed<nb::object, PyType> PyType::maybeDownCast() {
+nb::object PyType::maybeDownCast() {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
@@ -2219,7 +2219,7 @@ nb::object PyValue::getCapsule() {
return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
}
-nanobind::typed<nanobind::object, PyValue> PyValue::maybeDownCast() {
+nb::object PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
@@ -2263,8 +2263,7 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation)
}
}
-nb::typed<nb::object, PyOpView>
-PySymbolTable::dunderGetItem(const std::string &name) {
+nb::object PySymbolTable::dunderGetItem(const std::string &name) {
operation->checkValid();
MlirOperation symbol = mlirSymbolTableLookup(
symbolTable, mlirStringRefCreate(name.data(), name.length()));
@@ -2678,8 +2677,7 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
- nb::typed<nb::object, PyAttribute>
- dunderGetItemNamed(const std::string &name) {
+ nb::object dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 6c53289c5011e..44aad10ded082 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -223,7 +223,7 @@ class PyConcreteOpInterface {
/// Returns the opview of the operation instance from which this object was
/// constructed. Throws a type error if this object was constructed form a
/// subclass of OpView.
- nb::typed<nb::object, PyOpView> getOpView() {
+ nb::object getOpView() {
if (operation == nullptr) {
throw nb::type_error("Cannot get an opview from a static interface");
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 414f37cc97f2a..6e97c00d478f1 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -76,7 +76,7 @@ class PyObjectRef {
/// Releases the object held by this instance, returning it.
/// This is the proper thing to return from a function that wants to return
/// the reference. Note that this does not work from initializers.
- nanobind::typed<nanobind::object, T> releaseObject() {
+ nanobind::object releaseObject() {
assert(referrent && object);
referrent = nullptr;
auto stolen = std::move(object);
@@ -88,12 +88,14 @@ class PyObjectRef {
assert(referrent && object);
return referrent;
}
- nanobind::typed<nanobind::object, T> getObject() {
+ nanobind::object getObject() {
assert(referrent && object);
return object;
}
operator bool() const { return referrent && object; }
+ using NBTypedT = nanobind::typed<nanobind::object, T>;
+
private:
T *referrent;
nanobind::object object;
@@ -680,7 +682,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
PyLocation &location, const nanobind::object &ip, bool inferType);
/// Creates an OpView suitable for this operation.
- nanobind::typed<nanobind::object, PyOpView> createOpView();
+ nanobind::object createOpView();
/// Erases the underlying MlirOperation, removes its pointer from the
/// parent context's live operations map, and sets the valid bit false.
@@ -690,7 +692,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
void setInvalid() { valid = false; }
/// Clones this operation.
- nanobind::typed<nanobind::object, PyOpView> clone(const nanobind::object &ip);
+ nanobind::object clone(const nanobind::object &ip);
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
@@ -890,7 +892,7 @@ class PyType : public BaseContextObject {
/// is taken by calling this function.
static PyType createFromCapsule(nanobind::object capsule);
- nanobind::typed<nanobind::object, PyType> maybeDownCast();
+ nanobind::object maybeDownCast();
private:
MlirType type;
@@ -1020,7 +1022,7 @@ class PyAttribute : public BaseContextObject {
/// is taken by calling this function.
static PyAttribute createFromCapsule(nanobind::object capsule);
- nanobind::typed<nanobind::object, PyAttribute> maybeDownCast();
+ nanobind::object maybeDownCast();
private:
MlirAttribute attr;
@@ -1178,7 +1180,7 @@ class PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
nanobind::object getCapsule();
- nanobind::typed<nanobind::object, PyValue> maybeDownCast();
+ nanobind::object maybeDownCast();
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
@@ -1269,8 +1271,7 @@ class PySymbolTable {
/// Returns the symbol (opview) with the given name, throws if there is no
/// such symbol in the table.
- nanobind::typed<nanobind::object, PyOpView>
- dunderGetItem(const std::string &name);
+ nanobind::object dunderGetItem(const std::string &name);
/// Removes the given operation from the symbol table and erases it.
void erase(PyOperationBase &symbol);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 09ef64d4e0baf..a7aa1c65c6c43 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -731,7 +731,8 @@ class PyRankedTensorType
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return std::nullopt;
- return PyAttribute(self.getContext(), encoding).maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(self.getContext(), encoding).maybeDownCast());
});
}
};
@@ -793,9 +794,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
- return PyAttribute(self.getContext(),
- mlirMemRefTypeGetLayout(self))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
+ .maybeDownCast());
},
"The layout of the MemRef type.")
.def(
@@ -824,7 +825,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(self.getContext(), a).maybeDownCast());
},
"Returns the memory space of the given MemRef type.");
}
@@ -865,7 +867,8 @@ class PyUnrankedMemRefType
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
- return PyAttribute(self.getContext(), a).maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(self.getContext(), a).maybeDownCast());
},
"Returns the memory space of the given Unranked MemRef type.");
}
diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h
index 40b3215f6f5fe..64ea4329f65f1 100644
--- a/mlir/lib/Bindings/Python/NanobindUtils.h
+++ b/mlir/lib/Bindings/Python/NanobindUtils.h
@@ -276,7 +276,7 @@ class Sliceable {
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
- nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
+ nanobind::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
|
f2983ad
to
b2ae390
Compare
I've tested locally but it would be good if someone from JAX/G could test as well... |
I'm wondering if this was due to the old nanobind version in the Bazel config. Also, I'm seeing potentially a similar error in Modular's codebase, when updating this function fails: https://github.com/modular/modular/blob/d244d85e66baba96bb69f08f380ef65defca20e7/max/graph/graph.py#L856 on the specified line. Internal error, but I get
What's happening is the types passed to the second parameter aren't the correct types now, with or without this commit cherry-picked. Was there any functional change to |
Ugh you nailed it:
I changed it from [](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context to [](std::vector<PyType> inputs, std::vector<PyType> results,
DefaultingPyMlirContext context |
I'm guessing that |
Yea we have sort of two sets of bindings, we're (very) slowly moving to our in-house version. The signature of |
there's a way to do it by digging out the CAPIPtr (which actually wraps the opaque C structs) but it's easier if we just restore the old API; try this PR #160194 |
Thanks, I'll give that a shot |
That fixes that test! I have a whole slew of other errors but they may be unrelated, going to try isolating the nanobind changes. |
try this one, which is clean #160203 |
Yup, the other issues went away! Makes sense for a change, basically just restore the old overloads while keeping the new ones. I'll run the full test suite and let you know if anything else comes up. |
feel free to drop a comment on #160203 |
As for this PR, I personally liked the extra typing, and if it's only necessary because of the lower nanobind bazel version (which should be fixed on main), can we potentially revert this one? I can test locally to see if things still work. |
It's fine to leave this one and just update with |
Ah I see, yea. Minor issue, I can follow up with that change later. |
I seem to have hit a bug in nanobind(?) that I've seen in our codebase before but have no idea how to repro 🙃
Unfortunately I don't have much more time to debug today, I think for now your change is good, it's at least forward progress, and it very well could be a bug on our end. |
#157930 changed a few APIs from `Mlir*` to `Py*` and broke users that were using them (see #160183 (comment)). This PR restores those APIs.
llvm/llvm-project#157930 changed a few APIs from `Mlir*` to `Py*` and broke users that were using them (see llvm/llvm-project#160183 (comment)). This PR restores those APIs.
llvm/llvm-project#157930 changed a few APIs from `Mlir*` to `Py*` and broke users that were using them (see llvm/llvm-project#160183 (comment)). This PR restores those APIs.
#160183 removed `nb::typed` annotation to fix bazel but it turned out to be simply a matter of not using the correct version of nanobind (see #160183 (comment)). This PR restores those annotations but (mostly) moves to the return positions of the actual methods.
llvm/llvm-project#160183 removed `nb::typed` annotation to fix bazel but it turned out to be simply a matter of not using the correct version of nanobind (see llvm/llvm-project#160183 (comment)). This PR restores those annotations but (mostly) moves to the return positions of the actual methods.
llvm/llvm-project#160183 removed `nb::typed` annotation to fix bazel but it turned out to be simply a matter of not using the correct version of nanobind (see llvm/llvm-project#160183 (comment)). This PR restores those annotations but (mostly) moves to the return positions of the actual methods.
Just wanted to reply to close the loop, the issue was in fact on our end, all good here :) |
…8738) LLVM has also moved to generated stubfiles, so we need to generate these ourselves (we could omit them, but they're nice to have). See llvm/llvm-project#157930. Also pulls in the following followup fixes, these should be removed when bumping again. llvm/llvm-project#160183 llvm/llvm-project#160203 llvm/llvm-project#160221 This fixes some mypy lint errors, and causes a few more, I fixed a few but mostly just ignored them. MODULAR_ORIG_COMMIT_REV_ID: 524aaf2ab047e5185703c44ab3edd7754c67fa26
#157930 broke bazel build (see #157930 (comment)) because bazel is stricter on implicit conversions (some difference in flags passed to clang). This PR fixes by moving/removing
nb::typed
.EDIT: and also the overlay...