Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {

PyArrayAttributeIterator &dunderIter() { return *this; }

nb::typed<nb::object, PyAttribute> dunderNext() {
nb::object dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
Expand Down
26 changes: 12 additions & 14 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ class PyOperationIterator {

PyOperationIterator &dunderIter() { return *this; }

nb::typed<nb::object, PyOpView> dunderNext() {
nb::object dunderNext() {
parentOperation->checkValid();
if (mlirOperationIsNull(next)) {
throw nb::stop_iteration();
Expand Down Expand Up @@ -562,7 +562,7 @@ class PyOperationList {
return count;
}

nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
nb::object dunderGetItem(intptr_t index) {
parentOperation->checkValid();
if (index < 0) {
index += dunderLen();
Expand Down Expand Up @@ -1534,7 +1534,7 @@ nb::object PyOperation::create(std::string_view name,
return created.getObject();
}

nb::typed<nb::object, PyOpView> PyOperation::clone(const nb::object &maybeIp) {
nb::object PyOperation::clone(const nb::object &maybeIp) {
MlirOperation clonedOperation = mlirOperationClone(operation);
PyOperationRef cloned =
PyOperation::createDetached(getContext(), clonedOperation);
Expand All @@ -1543,7 +1543,7 @@ nb::typed<nb::object, PyOpView> PyOperation::clone(const nb::object &maybeIp) {
return cloned->createOpView();
}

nb::typed<nb::object, PyOpView> PyOperation::createOpView() {
nb::object PyOperation::createOpView() {
checkValid();
MlirIdentifier ident = mlirOperationGetName(get());
MlirStringRef identStr = mlirIdentifierStr(ident);
Expand Down Expand Up @@ -1638,9 +1638,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {

/// Returns the list of types of the values held by container.
template <typename Container>
static std::vector<nb::typed<nb::object, PyType>>
getValueTypes(Container &container, PyMlirContextRef &context) {
std::vector<nb::typed<nb::object, PyType>> result;
static std::vector<nb::object> getValueTypes(Container &container,
PyMlirContextRef &context) {
std::vector<nb::object> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(PyType(context->getRef(),
Expand Down Expand Up @@ -2133,7 +2133,7 @@ PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
}

nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
nb::object PyAttribute::maybeDownCast() {
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
Expand Down Expand Up @@ -2179,7 +2179,7 @@ PyType PyType::createFromCapsule(nb::object capsule) {
rawType);
}

nb::typed<nb::object, PyType> PyType::maybeDownCast() {
nb::object PyType::maybeDownCast() {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
Expand Down Expand Up @@ -2219,7 +2219,7 @@ nb::object PyValue::getCapsule() {
return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
}

nanobind::typed<nanobind::object, PyValue> PyValue::maybeDownCast() {
nb::object PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
Expand Down Expand Up @@ -2263,8 +2263,7 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation)
}
}

nb::typed<nb::object, PyOpView>
PySymbolTable::dunderGetItem(const std::string &name) {
nb::object PySymbolTable::dunderGetItem(const std::string &name) {
operation->checkValid();
MlirOperation symbol = mlirSymbolTableLookup(
symbolTable, mlirStringRefCreate(name.data(), name.length()));
Expand Down Expand Up @@ -2678,8 +2677,7 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}

nb::typed<nb::object, PyAttribute>
dunderGetItemNamed(const std::string &name) {
nb::object dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Bindings/Python/IRInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class PyConcreteOpInterface {
/// Returns the opview of the operation instance from which this object was
/// constructed. Throws a type error if this object was constructed form a
/// subclass of OpView.
nb::typed<nb::object, PyOpView> getOpView() {
nb::object getOpView() {
if (operation == nullptr) {
throw nb::type_error("Cannot get an opview from a static interface");
}
Expand Down
19 changes: 10 additions & 9 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class PyObjectRef {
/// Releases the object held by this instance, returning it.
/// This is the proper thing to return from a function that wants to return
/// the reference. Note that this does not work from initializers.
nanobind::typed<nanobind::object, T> releaseObject() {
nanobind::object releaseObject() {
assert(referrent && object);
referrent = nullptr;
auto stolen = std::move(object);
Expand All @@ -88,12 +88,14 @@ class PyObjectRef {
assert(referrent && object);
return referrent;
}
nanobind::typed<nanobind::object, T> getObject() {
nanobind::object getObject() {
assert(referrent && object);
return object;
}
operator bool() const { return referrent && object; }

using NBTypedT = nanobind::typed<nanobind::object, T>;

private:
T *referrent;
nanobind::object object;
Expand Down Expand Up @@ -680,7 +682,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
PyLocation &location, const nanobind::object &ip, bool inferType);

/// Creates an OpView suitable for this operation.
nanobind::typed<nanobind::object, PyOpView> createOpView();
nanobind::object createOpView();

/// Erases the underlying MlirOperation, removes its pointer from the
/// parent context's live operations map, and sets the valid bit false.
Expand All @@ -690,7 +692,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
void setInvalid() { valid = false; }

/// Clones this operation.
nanobind::typed<nanobind::object, PyOpView> clone(const nanobind::object &ip);
nanobind::object clone(const nanobind::object &ip);

PyOperation(PyMlirContextRef contextRef, MlirOperation operation);

Expand Down Expand Up @@ -890,7 +892,7 @@ class PyType : public BaseContextObject {
/// is taken by calling this function.
static PyType createFromCapsule(nanobind::object capsule);

nanobind::typed<nanobind::object, PyType> maybeDownCast();
nanobind::object maybeDownCast();

private:
MlirType type;
Expand Down Expand Up @@ -1020,7 +1022,7 @@ class PyAttribute : public BaseContextObject {
/// is taken by calling this function.
static PyAttribute createFromCapsule(nanobind::object capsule);

nanobind::typed<nanobind::object, PyAttribute> maybeDownCast();
nanobind::object maybeDownCast();

private:
MlirAttribute attr;
Expand Down Expand Up @@ -1178,7 +1180,7 @@ class PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
nanobind::object getCapsule();

nanobind::typed<nanobind::object, PyValue> maybeDownCast();
nanobind::object maybeDownCast();

/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
Expand Down Expand Up @@ -1269,8 +1271,7 @@ class PySymbolTable {

/// Returns the symbol (opview) with the given name, throws if there is no
/// such symbol in the table.
nanobind::typed<nanobind::object, PyOpView>
dunderGetItem(const std::string &name);
nanobind::object dunderGetItem(const std::string &name);

/// Removes the given operation from the symbol table and erases it.
void erase(PyOperationBase &symbol);
Expand Down
15 changes: 9 additions & 6 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,8 @@ class PyRankedTensorType
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return std::nullopt;
return PyAttribute(self.getContext(), encoding).maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(self.getContext(), encoding).maybeDownCast());
});
}
};
Expand Down Expand Up @@ -793,9 +794,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
return PyAttribute(self.getContext(),
mlirMemRefTypeGetLayout(self))
.maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
.maybeDownCast());
},
"The layout of the MemRef type.")
.def(
Expand Down Expand Up @@ -824,7 +825,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
return PyAttribute(self.getContext(), a).maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(self.getContext(), a).maybeDownCast());
},
"Returns the memory space of the given MemRef type.");
}
Expand Down Expand Up @@ -865,7 +867,8 @@ class PyUnrankedMemRefType
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
return PyAttribute(self.getContext(), a).maybeDownCast();
return nb::cast<nb::typed<nb::object, PyAttribute>>(
PyAttribute(self.getContext(), a).maybeDownCast());
},
"Returns the memory space of the given Unranked MemRef type.");
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Bindings/Python/NanobindUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class Sliceable {
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
nanobind::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
Expand Down
15 changes: 0 additions & 15 deletions utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@ filegroup(
]),
)

filegroup(
name = "IRPyIFiles",
srcs = [
"mlir/_mlir_libs/_mlir/__init__.pyi",
"mlir/_mlir_libs/_mlir/ir.pyi",
],
)

filegroup(
name = "MlirLibsPyFiles",
srcs = [
Expand All @@ -75,13 +67,6 @@ filegroup(
],
)

filegroup(
name = "PassManagerPyIFiles",
srcs = [
"mlir/_mlir_libs/_mlir/passmanager.pyi",
],
)

filegroup(
name = "RewritePyFiles",
srcs = [
Expand Down