Skip to content

Commit a3594cd

Browse files
authored
[MLIR][Python] fixup Context and Location stubs and NanobindAdaptors (#161433)
add correct names for `NB_TYPE_CASTER(..., name)` so users of `NanobindAdaptors.h` can generate the correct hints. Also fix a few straggler stubs.
1 parent 197e77b commit a3594cd

File tree

5 files changed

+50
-62
lines changed

5 files changed

+50
-62
lines changed

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ mlirApiObjectToCapsule(nanobind::handle apiObject) {
116116
/// Casts object <-> MlirAffineMap.
117117
template <>
118118
struct type_caster<MlirAffineMap> {
119-
NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
119+
NB_TYPE_CASTER(MlirAffineMap,
120+
const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.AffineMap")))
120121
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
121122
if (auto capsule = mlirApiObjectToCapsule(src)) {
122123
value = mlirPythonCapsuleToAffineMap(capsule->ptr());
@@ -138,7 +139,8 @@ struct type_caster<MlirAffineMap> {
138139
/// Casts object <-> MlirAttribute.
139140
template <>
140141
struct type_caster<MlirAttribute> {
141-
NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
142+
NB_TYPE_CASTER(MlirAttribute,
143+
const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute")))
142144
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
143145
if (auto capsule = mlirApiObjectToCapsule(src)) {
144146
value = mlirPythonCapsuleToAttribute(capsule->ptr());
@@ -161,7 +163,7 @@ struct type_caster<MlirAttribute> {
161163
/// Casts object -> MlirBlock.
162164
template <>
163165
struct type_caster<MlirBlock> {
164-
NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
166+
NB_TYPE_CASTER(MlirBlock, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Block")))
165167
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
166168
if (auto capsule = mlirApiObjectToCapsule(src)) {
167169
value = mlirPythonCapsuleToBlock(capsule->ptr());
@@ -174,7 +176,8 @@ struct type_caster<MlirBlock> {
174176
/// Casts object -> MlirContext.
175177
template <>
176178
struct type_caster<MlirContext> {
177-
NB_TYPE_CASTER(MlirContext, const_name("MlirContext"))
179+
NB_TYPE_CASTER(MlirContext,
180+
const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Context")))
178181
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
179182
if (src.is_none()) {
180183
// Gets the current thread-bound context.
@@ -192,7 +195,8 @@ struct type_caster<MlirContext> {
192195
/// Casts object <-> MlirDialectRegistry.
193196
template <>
194197
struct type_caster<MlirDialectRegistry> {
195-
NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
198+
NB_TYPE_CASTER(MlirDialectRegistry,
199+
const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry")))
196200
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
197201
if (auto capsule = mlirApiObjectToCapsule(src)) {
198202
value = mlirPythonCapsuleToDialectRegistry(capsule->ptr());
@@ -214,7 +218,8 @@ struct type_caster<MlirDialectRegistry> {
214218
/// Casts object <-> MlirLocation.
215219
template <>
216220
struct type_caster<MlirLocation> {
217-
NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"))
221+
NB_TYPE_CASTER(MlirLocation,
222+
const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Location")))
218223
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
219224
if (src.is_none()) {
220225
// Gets the current thread-bound context.
@@ -240,7 +245,7 @@ struct type_caster<MlirLocation> {
240245
/// Casts object <-> MlirModule.
241246
template <>
242247
struct type_caster<MlirModule> {
243-
NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
248+
NB_TYPE_CASTER(MlirModule, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Module")))
244249
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
245250
if (auto capsule = mlirApiObjectToCapsule(src)) {
246251
value = mlirPythonCapsuleToModule(capsule->ptr());
@@ -262,8 +267,9 @@ struct type_caster<MlirModule> {
262267
/// Casts object <-> MlirFrozenRewritePatternSet.
263268
template <>
264269
struct type_caster<MlirFrozenRewritePatternSet> {
265-
NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
266-
const_name("MlirFrozenRewritePatternSet"))
270+
NB_TYPE_CASTER(
271+
MlirFrozenRewritePatternSet,
272+
const_name(MAKE_MLIR_PYTHON_QUALNAME("rewrite.FrozenRewritePatternSet")))
267273
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
268274
if (auto capsule = mlirApiObjectToCapsule(src)) {
269275
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr());
@@ -285,7 +291,8 @@ struct type_caster<MlirFrozenRewritePatternSet> {
285291
/// Casts object <-> MlirOperation.
286292
template <>
287293
struct type_caster<MlirOperation> {
288-
NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
294+
NB_TYPE_CASTER(MlirOperation,
295+
const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")))
289296
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
290297
if (auto capsule = mlirApiObjectToCapsule(src)) {
291298
value = mlirPythonCapsuleToOperation(capsule->ptr());
@@ -309,7 +316,7 @@ struct type_caster<MlirOperation> {
309316
/// Casts object <-> MlirValue.
310317
template <>
311318
struct type_caster<MlirValue> {
312-
NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
319+
NB_TYPE_CASTER(MlirValue, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Value")))
313320
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
314321
if (auto capsule = mlirApiObjectToCapsule(src)) {
315322
value = mlirPythonCapsuleToValue(capsule->ptr());
@@ -334,7 +341,8 @@ struct type_caster<MlirValue> {
334341
/// Casts object -> MlirPassManager.
335342
template <>
336343
struct type_caster<MlirPassManager> {
337-
NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
344+
NB_TYPE_CASTER(MlirPassManager, const_name(MAKE_MLIR_PYTHON_QUALNAME(
345+
"passmanager.PassManager")))
338346
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
339347
if (auto capsule = mlirApiObjectToCapsule(src)) {
340348
value = mlirPythonCapsuleToPassManager(capsule->ptr());
@@ -347,7 +355,7 @@ struct type_caster<MlirPassManager> {
347355
/// Casts object <-> MlirTypeID.
348356
template <>
349357
struct type_caster<MlirTypeID> {
350-
NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
358+
NB_TYPE_CASTER(MlirTypeID, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID")))
351359
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
352360
if (auto capsule = mlirApiObjectToCapsule(src)) {
353361
value = mlirPythonCapsuleToTypeID(capsule->ptr());
@@ -371,7 +379,7 @@ struct type_caster<MlirTypeID> {
371379
/// Casts object <-> MlirType.
372380
template <>
373381
struct type_caster<MlirType> {
374-
NB_TYPE_CASTER(MlirType, const_name("MlirType"))
382+
NB_TYPE_CASTER(MlirType, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Type")))
375383
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
376384
if (auto capsule = mlirApiObjectToCapsule(src)) {
377385
value = mlirPythonCapsuleToType(capsule->ptr());
@@ -394,7 +402,7 @@ struct type_caster<MlirType> {
394402
/// Casts MlirStringRef -> object.
395403
template <>
396404
struct type_caster<MlirStringRef> {
397-
NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef"))
405+
NB_TYPE_CASTER(MlirStringRef, const_name("str"))
398406
static handle from_cpp(MlirStringRef s, rv_policy,
399407
cleanup_list *cleanup) noexcept {
400408
return nanobind::str(s.data, s.length).release();

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3219,13 +3219,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32193219
nb::arg("end_line"), nb::arg("end_col"),
32203220
nb::arg("context") = nb::none(), kContextGetFileRangeDocstring)
32213221
.def("is_a_file", mlirLocationIsAFileLineColRange)
3222-
.def_prop_ro(
3223-
"filename",
3224-
[](MlirLocation loc) {
3225-
return mlirIdentifierStr(
3226-
mlirLocationFileLineColRangeGetFilename(loc));
3227-
},
3228-
nb::sig("def filename(self) -> str"))
3222+
.def_prop_ro("filename",
3223+
[](MlirLocation loc) {
3224+
return mlirIdentifierStr(
3225+
mlirLocationFileLineColRangeGetFilename(loc));
3226+
})
32293227
.def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
32303228
.def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
32313229
.def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
@@ -3274,12 +3272,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32743272
nb::arg("name"), nb::arg("childLoc") = nb::none(),
32753273
nb::arg("context") = nb::none(), kContextGetNameLocationDocString)
32763274
.def("is_a_name", mlirLocationIsAName)
3277-
.def_prop_ro(
3278-
"name_str",
3279-
[](MlirLocation loc) {
3280-
return mlirIdentifierStr(mlirLocationNameGetName(loc));
3281-
},
3282-
nb::sig("def name_str(self) -> str"))
3275+
.def_prop_ro("name_str",
3276+
[](MlirLocation loc) {
3277+
return mlirIdentifierStr(mlirLocationNameGetName(loc));
3278+
})
32833279
.def_prop_ro("child_loc",
32843280
[](PyLocation &self) {
32853281
return PyLocation(self.getContext(),
@@ -3453,15 +3449,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34533449
return concreteOperation.getContext().getObject();
34543450
},
34553451
"Context that owns the Operation")
3456-
.def_prop_ro(
3457-
"name",
3458-
[](PyOperationBase &self) {
3459-
auto &concreteOperation = self.getOperation();
3460-
concreteOperation.checkValid();
3461-
MlirOperation operation = concreteOperation.get();
3462-
return mlirIdentifierStr(mlirOperationGetName(operation));
3463-
},
3464-
nb::sig("def name(self) -> str"))
3452+
.def_prop_ro("name",
3453+
[](PyOperationBase &self) {
3454+
auto &concreteOperation = self.getOperation();
3455+
concreteOperation.checkValid();
3456+
MlirOperation operation = concreteOperation.get();
3457+
return mlirIdentifierStr(mlirOperationGetName(operation));
3458+
})
34653459
.def_prop_ro("operands",
34663460
[](PyOperationBase &self) {
34673461
return PyOpOperandList(self.getOperation().getRef());
@@ -3603,12 +3597,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36033597
},
36043598
"Reports if the operation is attached to its parent block.")
36053599
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3606-
.def(
3607-
"walk", &PyOperationBase::walk, nb::arg("callback"),
3608-
nb::arg("walk_order") = MlirWalkPostOrder,
3609-
// clang-format off
3610-
nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = " MAKE_MLIR_PYTHON_QUALNAME("ir.WalkOrder.POST_ORDER") ") -> None")
3611-
// clang-format on
3600+
.def("walk", &PyOperationBase::walk, nb::arg("callback"),
3601+
nb::arg("walk_order") = MlirWalkPostOrder,
3602+
// clang-format off
3603+
nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None")
3604+
// clang-format on
36123605
);
36133606

36143607
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
@@ -4124,7 +4117,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
41244117
[](PyNamedAttribute &self) {
41254118
return mlirIdentifierStr(self.namedAttr.name);
41264119
},
4127-
nb::sig("def name(self) -> str"),
41284120
"The name of the NamedAttribute binding")
41294121
.def_prop_ro(
41304122
"attr",
@@ -4342,17 +4334,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
43424334
kValueReplaceAllUsesWithDocstring)
43434335
.def(
43444336
"replace_all_uses_except",
4345-
[](MlirValue self, MlirValue with, PyOperation &exception) {
4337+
[](PyValue &self, PyValue &with, PyOperation &exception) {
43464338
MlirOperation exceptedUser = exception.get();
43474339
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
43484340
},
43494341
nb::arg("with_"), nb::arg("exceptions"),
4350-
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
4351-
"Operation) -> None"),
43524342
kValueReplaceAllUsesExceptDocstring)
43534343
.def(
43544344
"replace_all_uses_except",
4355-
[](MlirValue self, MlirValue with, nb::list exceptions) {
4345+
[](PyValue &self, PyValue &with, const nb::list &exceptions) {
43564346
// Convert Python list to a SmallVector of MlirOperations
43574347
llvm::SmallVector<MlirOperation> exceptionOps;
43584348
for (nb::handle exception : exceptions) {
@@ -4364,8 +4354,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
43644354
exceptionOps.data());
43654355
},
43664356
nb::arg("with_"), nb::arg("exceptions"),
4367-
nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
4368-
"Sequence[Operation]) -> None"),
43694357
kValueReplaceAllUsesExceptDocstring)
43704358
.def(
43714359
"replace_all_uses_except",

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,7 @@ class DefaultingPyMlirContext
273273
: public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
274274
public:
275275
using Defaulting::Defaulting;
276-
static constexpr const char kTypeDescription[] =
277-
MAKE_MLIR_PYTHON_QUALNAME("ir.Context");
276+
static constexpr const char kTypeDescription[] = "Context";
278277
static PyMlirContext &resolve();
279278
};
280279

@@ -500,8 +499,7 @@ class DefaultingPyLocation
500499
: public Defaulting<DefaultingPyLocation, PyLocation> {
501500
public:
502501
using Defaulting::Defaulting;
503-
static constexpr const char kTypeDescription[] =
504-
MAKE_MLIR_PYTHON_QUALNAME("ir.Location");
502+
static constexpr const char kTypeDescription[] = "Location";
505503
static PyLocation &resolve();
506504

507505
operator MlirLocation() const { return *get(); }

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
10101010
},
10111011
nb::arg("elements"), nb::arg("context") = nb::none(),
10121012
// clang-format off
1013-
nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"),
1013+
nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
10141014
// clang-format on
10151015
"Create a tuple type");
10161016
c.def(
@@ -1070,7 +1070,7 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
10701070
},
10711071
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
10721072
// clang-format off
1073-
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
1073+
nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
10741074
// clang-format on
10751075
"Gets a FunctionType from a list of input and result types");
10761076
c.def_prop_ro(

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,6 @@ NB_MODULE(_mlir, m) {
115115
});
116116
},
117117
"typeid"_a, nb::kw_only(), "replace"_a = false,
118-
// clang-format off
119-
nb::sig("def register_type_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"),
120-
// clang-format on
121118
"Register a type caster for casting MLIR types to custom user types.");
122119
m.def(
123120
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
@@ -130,9 +127,6 @@ NB_MODULE(_mlir, m) {
130127
});
131128
},
132129
"typeid"_a, nb::kw_only(), "replace"_a = false,
133-
// clang-format off
134-
nb::sig("def register_value_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"),
135-
// clang-format on
136130
"Register a value caster for casting MLIR values to custom user values.");
137131

138132
// Define and populate IR submodule.

0 commit comments

Comments
 (0)