diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cda4fe19c16f8..c21240926ddac 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -18,6 +18,7 @@ #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "nanobind/nanobind.h" +#include "nanobind/typing.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -1596,7 +1597,11 @@ class PyConcreteValue : public PyValue { /// Binds the Python module objects to functions of this class. static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); + auto cls = ClassTy( + m, DerivedTy::pyClassName, nb::is_generic(), + nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])") + .str() + .c_str())); cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); cls.def_static( "isinstance", @@ -4278,7 +4283,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of Value. //---------------------------------------------------------------------------- - nb::class_(m, "Value") + m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type")); + + nb::class_(m, "Value", nb::is_generic(), + nb::sig("class Value(Generic[_T])")) .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) @@ -4371,18 +4379,20 @@ void mlir::python::populateIRCore(nb::module_ &m) { return printAccum.join(); }, nb::arg("state"), kGetNameAsOperand) - .def_prop_ro("type", - [](PyValue &self) -> nb::typed { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())) - .maybeDownCast(); - }) + .def_prop_ro( + "type", + [](PyValue &self) { + return PyType(self.getParentOperation()->getContext(), + mlirValueGetType(self.get())) + .maybeDownCast(); + }, + nb::sig("def type(self) -> _T")) .def( "set_type", [](PyValue &self, const PyType &type) { return mlirValueSetType(self.get(), type); }, - nb::arg("type")) + nb::arg("type"), nb::sig("def set_type(self, type: _T)")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 42de7e4eda573..ff16ad8ca0cdd 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip) // CHECK: @builtins.property - // CHECK: def f32(self) -> _ods_ir.Value: + // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]: // CHECK: return self.operation.operands[1] let arguments = (ins I32, F32:$f32, I64); // CHECK: @builtins.property - // CHECK: def i32(self) -> _ods_ir.OpResult: + // CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[0] // // CHECK: @builtins.property - // CHECK: def i64(self) -> _ods_ir.OpResult: + // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[2] let results = (outs I32:$i32, AnyFloat, I64:$i64); } @@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> { // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip) // CHECK: @builtins.property - // CHECK: def i32(self) -> _ods_ir.Value: + // CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]: // CHECK: return self.operation.operands[0] // // CHECK: @builtins.property - // CHECK: def f32(self) -> _ods_ir.Value: + // CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]: // CHECK: return self.operation.operands[1] let arguments = (ins I32:$i32, F32:$f32); // CHECK: @builtins.property - // CHECK: def i64(self) -> _ods_ir.OpResult: + // CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]: // CHECK: return self.operation.results[0] // // CHECK: @builtins.property - // CHECK: def f64(self) -> _ods_ir.OpResult: + // CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]: // CHECK: return self.operation.results[1] let results = (outs I64:$i64, AnyFloat:$f64); } diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 1194e32c960c8..f0f74ebc12155 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -554,7 +554,7 @@ def testOptionalOperandOp(): ) assert ( typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"] - is OpResult + == OpResult[IntegerType] ) assert type(op1.result) is OpResult @@ -662,6 +662,13 @@ def testCustomType(): raise +@run +# CHECK-LABEL: TEST: testValue +def testValue(): + # Check that Value is a generic class at runtime. + assert hasattr(Value, "__class_getitem__") + + @run # CHECK-LABEL: TEST: testTensorValue def testTensorValue(): diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 0172b3fa38a6b..f01563fd49d17 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) { StringRef(kind).drop_front()); } +static StringRef getPythonType(StringRef cppType) { + return llvm::StringSwitch(cppType) + .Case("::mlir::MemRefType", "_ods_ir.MemRefType") + .Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType") + .Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType") + .Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType") + .Case("::mlir::VectorType", "_ods_ir.VectorType") + .Case("::mlir::IntegerType", "_ods_ir.IntegerType") + .Case("::mlir::FloatType", "_ods_ir.FloatType") + .Case("::mlir::IndexType", "_ods_ir.IndexType") + .Case("::mlir::ComplexType", "_ods_ir.ComplexType") + .Case("::mlir::TupleType", "_ods_ir.TupleType") + .Case("::mlir::NoneType", "_ods_ir.NoneType") + .Default(StringRef()); +} + /// Emits accessors to "elements" of an Op definition. Currently, the supported /// elements are operands and results, indicated by `kind`, which must be either /// `operand` or `result` and is used verbatim in the emitted code. @@ -370,8 +386,11 @@ static void emitElementAccessors( seenVariableLength = true; if (element.name.empty()) continue; - const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" + std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; + if (StringRef pythonType = getPythonType(element.constraint.getCppType()); + !pythonType.empty()) + type = llvm::formatv("{0}[{1}]", type, pythonType); if (element.isVariableLength()) { if (element.isOptional()) { os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind, @@ -418,6 +437,11 @@ static void emitElementAccessors( type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; } + if (std::strcmp(kind, "operand") == 0) { + StringRef pythonType = getPythonType(element.constraint.getCppType()); + if (!pythonType.empty()) + type += "[" + pythonType.str() + "]"; + } os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), kind, numSimpleLength, numVariadicGroups, numPrecedingSimple, numPrecedingVariadic, type); @@ -449,6 +473,11 @@ static void emitElementAccessors( if (!element.isVariableLength() || element.isOptional()) { type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" : "_ods_ir.OpResult"; + if (std::strcmp(kind, "operand") == 0) { + StringRef pythonType = getPythonType(element.constraint.getCppType()); + if (!pythonType.empty()) + type += "[" + pythonType.str() + "]"; + } if (!element.isVariableLength()) { trailing = "[0]"; } else if (element.isOptional()) {