From fce09402983f69c052d75779b38122c97f500b59 Mon Sep 17 00:00:00 2001 From: tn Date: Sat, 22 Nov 2025 09:11:12 +0100 Subject: [PATCH] [mlir][tblgen] Add PredTypeTrait/PredAttrTrait support for type/attribute verification --- mlir/lib/TableGen/AttrOrTypeDef.cpp | 8 ++-- mlir/test/IR/test-verifiers-type.mlir | 48 +++++++++++++++++++++ mlir/test/lib/Dialect/Test/TestTypeDefs.td | 30 +++++++++++++ mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 36 +++++++++++++--- 4 files changed, 113 insertions(+), 9 deletions(-) diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp index 4659265e24bda..bf835a860cd5b 100644 --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const { } bool AttrOrTypeDef::genVerifyInvariantsImpl() const { - return any_of(parameters, [](const AttrOrTypeParameter &p) { - return p.getConstraint() != std::nullopt; - }); + return any_of(parameters, + [](const AttrOrTypeParameter &p) { + return p.getConstraint() != std::nullopt; + }) || + any_of(traits, [](const Trait &t) { return isa(&t); }); } std::optional AttrOrTypeDef::getExtraDecls() const { diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir index 6512a1b9c8711..a6a5fa3d4fc9f 100644 --- a/mlir/test/IR/test-verifiers-type.mlir +++ b/mlir/test/IR/test-verifiers-type.mlir @@ -22,3 +22,51 @@ // expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}} "test.type_producer"() : () -> vector> + +// ----- + +// Test PredTypeTrait with single parameter - valid case. +// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5> +"test.type_producer"() : () -> !test.type_pred_trait<5> + +// ----- + +// Test PredTypeTrait with single parameter - invalid case (zero is not positive). +// expected-error @below{{failed to verify that value must be positive}} +"test.type_producer"() : () -> !test.type_pred_trait<0> + +// ----- + +// Test PredTypeTrait with multiple parameters - valid case (5 >= 3). +// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3> +"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3> + +// ----- + +// Test PredTypeTrait with multiple parameters - edge case (3 >= 3). +// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3> +"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3> + +// ----- + +// Test PredTypeTrait with multiple parameters - invalid case (2 < 5). +// expected-error @below{{failed to verify that value must be at least min}} +"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5> + +// ----- + +// Test combined parameter constraint + PredTypeTrait - valid case. +// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32> +"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32> + +// ----- + +// Test combined - parameter type constraint fails (f16 not in [I16, I32]). +// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}} +"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16> + +// ----- + +// Test combined - PredTypeTrait fails (count 2 != elements.size() 3). +// expected-error @below{{failed to verify that count must match number of elements}} +"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16> diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 9859bd06cb526..232d6354d01eb 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> { let assemblyFormat = "`<` $param `>`"; } +// Test type with PredTypeTrait - single parameter predicate. +def TestTypePredTrait : Test_Type<"TestTypePredTrait", + [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> { + let parameters = (ins "unsigned":$value); + let mnemonic = "type_pred_trait"; + let assemblyFormat = "`<` $value `>`"; +} + +// Test type with PredTypeTrait - two parameter predicate. +def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam", + [PredTypeTrait<"value must be at least min", + CPred<"$value >= $minValue">>]> { + let parameters = (ins "unsigned":$value, "unsigned":$minValue); + let mnemonic = "type_pred_trait_multi"; + let assemblyFormat = "`<` $value `,` $minValue `>`"; +} + +// Test type combining parameter type constraints with PredTypeTrait. +def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined", + [PredTypeTrait<"count must match number of elements", + CPred<"$count == $elements.size()">>]> { + let parameters = (ins + "unsigned":$count, + ArrayRefParameter<"int64_t">:$elements, + AnyTypeOf<[I16, I32]>:$elementType + ); + let mnemonic = "type_pred_trait_combined"; + let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`"; +} + def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface", [DeclareTypeInterfaceMethods]> { let mnemonic = "op_asm_type_interface"; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 2a513c3b8cc9b..6547cb196716c 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() { ? strfmt("{0}::{1}", def.getStorageNamespace(), def.getStorageClassName()) : strfmt("::mlir::{0}Storage", valueType)); - SmallVector traitNames = - llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) { - return isa(&trait) - ? cast(&trait)->getFullyQualifiedTraitName() - : cast(&trait)->getFullyQualifiedTraitName(); - })); + SmallVector traitNames; + for (auto &trait : def.getTraits()) { + // Skip PredTrait as it doesn't generate a C++ trait class. + if (isa(&trait)) + continue; + traitNames.push_back( + isa(&trait) + ? cast(&trait)->getFullyQualifiedTraitName() + : cast(&trait)->getFullyQualifiedTraitName()); + } for (auto &traitName : traitNames) defParent.addTemplateParam(traitName); @@ -385,6 +389,26 @@ void DefGen::emitInvariantsVerifierImpl() { param.getName(), constraint->getSummary()) << "\n"; } + { + // Generate verification for PredTraits. + FmtContext traitCtx; + for (auto it : llvm::enumerate(def.getParameters())) { + // Note: Skip over the first method parameter (`emitError`). + traitCtx.addSubst(it.value().getName(), + builderParams[it.index() + 1].getName()); + } + for (const Trait &trait : def.getTraits()) { + if (auto *t = dyn_cast(&trait)) { + verifier->body() << tgfmt( + "if (!($0)) {\n" + " emitError() << \"failed to verify that $1\";\n" + " return ::mlir::failure();\n" + "}\n", + &traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary()); + } + } + } + verifier->body() << "return ::mlir::success();"; }