-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][tblgen] Add PredTypeTrait/PredAttrTrait support #169153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…bute verification
|
@llvm/pr-subscribers-mlir Author: Tim Noack (timnoack) ChangesThis patch adds support for MotivationPreviously, UsageUse This generates verification code in if (!(value > 0)) {
emitError() << "failed to verify that value must be positive";
return ::mlir::failure();
}Full diff: https://github.com/llvm/llvm-project/pull/169153.diff 4 Files Affected:
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 4659265e24bda..bf835a860cd5b 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
}
bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
- return any_of(parameters, [](const AttrOrTypeParameter &p) {
- return p.getConstraint() != std::nullopt;
- });
+ return any_of(parameters,
+ [](const AttrOrTypeParameter &p) {
+ return p.getConstraint() != std::nullopt;
+ }) ||
+ any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
}
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 6512a1b9c8711..a6a5fa3d4fc9f 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -22,3 +22,51 @@
// expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
"test.type_producer"() : () -> vector<memref<2xf32>>
+
+// -----
+
+// Test PredTypeTrait with single parameter - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
+"test.type_producer"() : () -> !test.type_pred_trait<5>
+
+// -----
+
+// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
+// expected-error @below{{failed to verify that value must be positive}}
+"test.type_producer"() : () -> !test.type_pred_trait<0>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
+// expected-error @below{{failed to verify that value must be at least min}}
+"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
+
+// -----
+
+// Test combined parameter constraint + PredTypeTrait - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+
+// -----
+
+// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
+// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
+
+// -----
+
+// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
+// expected-error @below{{failed to verify that count must match number of elements}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 9859bd06cb526..232d6354d01eb 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}
+// Test type with PredTypeTrait - single parameter predicate.
+def TestTypePredTrait : Test_Type<"TestTypePredTrait",
+ [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
+ let parameters = (ins "unsigned":$value);
+ let mnemonic = "type_pred_trait";
+ let assemblyFormat = "`<` $value `>`";
+}
+
+// Test type with PredTypeTrait - two parameter predicate.
+def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
+ [PredTypeTrait<"value must be at least min",
+ CPred<"$value >= $minValue">>]> {
+ let parameters = (ins "unsigned":$value, "unsigned":$minValue);
+ let mnemonic = "type_pred_trait_multi";
+ let assemblyFormat = "`<` $value `,` $minValue `>`";
+}
+
+// Test type combining parameter type constraints with PredTypeTrait.
+def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
+ [PredTypeTrait<"count must match number of elements",
+ CPred<"$count == $elements.size()">>]> {
+ let parameters = (ins
+ "unsigned":$count,
+ ArrayRefParameter<"int64_t">:$elements,
+ AnyTypeOf<[I16, I32]>:$elementType
+ );
+ let mnemonic = "type_pred_trait_combined";
+ let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
+}
+
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
let mnemonic = "op_asm_type_interface";
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a513c3b8cc9b..6547cb196716c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() {
? strfmt("{0}::{1}", def.getStorageNamespace(),
def.getStorageClassName())
: strfmt("::mlir::{0}Storage", valueType));
- SmallVector<std::string> traitNames =
- llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) {
- return isa<NativeTrait>(&trait)
- ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
- : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
- }));
+ SmallVector<std::string> traitNames;
+ for (auto &trait : def.getTraits()) {
+ // Skip PredTrait as it doesn't generate a C++ trait class.
+ if (isa<PredTrait>(&trait))
+ continue;
+ traitNames.push_back(
+ isa<NativeTrait>(&trait)
+ ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
+ : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
+ }
for (auto &traitName : traitNames)
defParent.addTemplateParam(traitName);
@@ -385,6 +389,26 @@ void DefGen::emitInvariantsVerifierImpl() {
param.getName(), constraint->getSummary())
<< "\n";
}
+ {
+ // Generate verification for PredTraits.
+ FmtContext traitCtx;
+ for (auto it : llvm::enumerate(def.getParameters())) {
+ // Note: Skip over the first method parameter (`emitError`).
+ traitCtx.addSubst(it.value().getName(),
+ builderParams[it.index() + 1].getName());
+ }
+ for (const Trait &trait : def.getTraits()) {
+ if (auto *t = dyn_cast<PredTrait>(&trait)) {
+ verifier->body() << tgfmt(
+ "if (!($0)) {\n"
+ " emitError() << \"failed to verify that $1\";\n"
+ " return ::mlir::failure();\n"
+ "}\n",
+ &traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
+ }
+ }
+ }
+
verifier->body() << "return ::mlir::success();";
}
|
|
@llvm/pr-subscribers-mlir-core Author: Tim Noack (timnoack) ChangesThis patch adds support for MotivationPreviously, UsageUse This generates verification code in if (!(value > 0)) {
emitError() << "failed to verify that value must be positive";
return ::mlir::failure();
}Full diff: https://github.com/llvm/llvm-project/pull/169153.diff 4 Files Affected:
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 4659265e24bda..bf835a860cd5b 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
}
bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
- return any_of(parameters, [](const AttrOrTypeParameter &p) {
- return p.getConstraint() != std::nullopt;
- });
+ return any_of(parameters,
+ [](const AttrOrTypeParameter &p) {
+ return p.getConstraint() != std::nullopt;
+ }) ||
+ any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
}
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 6512a1b9c8711..a6a5fa3d4fc9f 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -22,3 +22,51 @@
// expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
"test.type_producer"() : () -> vector<memref<2xf32>>
+
+// -----
+
+// Test PredTypeTrait with single parameter - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
+"test.type_producer"() : () -> !test.type_pred_trait<5>
+
+// -----
+
+// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
+// expected-error @below{{failed to verify that value must be positive}}
+"test.type_producer"() : () -> !test.type_pred_trait<0>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
+// expected-error @below{{failed to verify that value must be at least min}}
+"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
+
+// -----
+
+// Test combined parameter constraint + PredTypeTrait - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+
+// -----
+
+// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
+// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
+
+// -----
+
+// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
+// expected-error @below{{failed to verify that count must match number of elements}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 9859bd06cb526..232d6354d01eb 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}
+// Test type with PredTypeTrait - single parameter predicate.
+def TestTypePredTrait : Test_Type<"TestTypePredTrait",
+ [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
+ let parameters = (ins "unsigned":$value);
+ let mnemonic = "type_pred_trait";
+ let assemblyFormat = "`<` $value `>`";
+}
+
+// Test type with PredTypeTrait - two parameter predicate.
+def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
+ [PredTypeTrait<"value must be at least min",
+ CPred<"$value >= $minValue">>]> {
+ let parameters = (ins "unsigned":$value, "unsigned":$minValue);
+ let mnemonic = "type_pred_trait_multi";
+ let assemblyFormat = "`<` $value `,` $minValue `>`";
+}
+
+// Test type combining parameter type constraints with PredTypeTrait.
+def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
+ [PredTypeTrait<"count must match number of elements",
+ CPred<"$count == $elements.size()">>]> {
+ let parameters = (ins
+ "unsigned":$count,
+ ArrayRefParameter<"int64_t">:$elements,
+ AnyTypeOf<[I16, I32]>:$elementType
+ );
+ let mnemonic = "type_pred_trait_combined";
+ let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
+}
+
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
let mnemonic = "op_asm_type_interface";
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a513c3b8cc9b..6547cb196716c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() {
? strfmt("{0}::{1}", def.getStorageNamespace(),
def.getStorageClassName())
: strfmt("::mlir::{0}Storage", valueType));
- SmallVector<std::string> traitNames =
- llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) {
- return isa<NativeTrait>(&trait)
- ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
- : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
- }));
+ SmallVector<std::string> traitNames;
+ for (auto &trait : def.getTraits()) {
+ // Skip PredTrait as it doesn't generate a C++ trait class.
+ if (isa<PredTrait>(&trait))
+ continue;
+ traitNames.push_back(
+ isa<NativeTrait>(&trait)
+ ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
+ : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
+ }
for (auto &traitName : traitNames)
defParent.addTemplateParam(traitName);
@@ -385,6 +389,26 @@ void DefGen::emitInvariantsVerifierImpl() {
param.getName(), constraint->getSummary())
<< "\n";
}
+ {
+ // Generate verification for PredTraits.
+ FmtContext traitCtx;
+ for (auto it : llvm::enumerate(def.getParameters())) {
+ // Note: Skip over the first method parameter (`emitError`).
+ traitCtx.addSubst(it.value().getName(),
+ builderParams[it.index() + 1].getName());
+ }
+ for (const Trait &trait : def.getTraits()) {
+ if (auto *t = dyn_cast<PredTrait>(&trait)) {
+ verifier->body() << tgfmt(
+ "if (!($0)) {\n"
+ " emitError() << \"failed to verify that $1\";\n"
+ " return ::mlir::failure();\n"
+ "}\n",
+ &traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
+ }
+ }
+ }
+
verifier->body() << "return ::mlir::success();";
}
|
|
@joker-eph @ftynse Not sure who to ping here. I saw that @matthias-springer implemented verification of type constraints on types / attributes. |
This patch adds support for
PredTypeTraitandPredAttrTraitin type and attribute definitions, enabling declarative predicate-based verification similar to howPredOpTraitworks for operations.Motivation
Previously,
PredTypeTrait/PredAttrTraitwere defined in TableGen but not implemented in the code generator. Using them would cause mlir-tblgen to crash with an assertion failure when trying to castPredTraittoInterfaceTrait. This patch fixes the crash and implements the actual verification code generation.Usage
Use
$paramNamesyntax in predicates to reference type/attribute parameters:This generates verification code in
verifyInvariantsImpl():