Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
}

bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
return any_of(parameters, [](const AttrOrTypeParameter &p) {
return p.getConstraint() != std::nullopt;
});
return any_of(parameters,
[](const AttrOrTypeParameter &p) {
return p.getConstraint() != std::nullopt;
}) ||
any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
}

std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
Expand Down
48 changes: 48 additions & 0 deletions mlir/test/IR/test-verifiers-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,51 @@

// expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
"test.type_producer"() : () -> vector<memref<2xf32>>

// -----

// Test PredTypeTrait with single parameter - valid case.
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
"test.type_producer"() : () -> !test.type_pred_trait<5>

// -----

// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
// expected-error @below{{failed to verify that value must be positive}}
"test.type_producer"() : () -> !test.type_pred_trait<0>

// -----

// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>

// -----

// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>

// -----

// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
// expected-error @below{{failed to verify that value must be at least min}}
"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>

// -----

// Test combined parameter constraint + PredTypeTrait - valid case.
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>

// -----

// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>

// -----

// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
// expected-error @below{{failed to verify that count must match number of elements}}
"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
30 changes: 30 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}

// Test type with PredTypeTrait - single parameter predicate.
def TestTypePredTrait : Test_Type<"TestTypePredTrait",
[PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
let parameters = (ins "unsigned":$value);
let mnemonic = "type_pred_trait";
let assemblyFormat = "`<` $value `>`";
}

// Test type with PredTypeTrait - two parameter predicate.
def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
[PredTypeTrait<"value must be at least min",
CPred<"$value >= $minValue">>]> {
let parameters = (ins "unsigned":$value, "unsigned":$minValue);
let mnemonic = "type_pred_trait_multi";
let assemblyFormat = "`<` $value `,` $minValue `>`";
}

// Test type combining parameter type constraints with PredTypeTrait.
def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
[PredTypeTrait<"count must match number of elements",
CPred<"$count == $elements.size()">>]> {
let parameters = (ins
"unsigned":$count,
ArrayRefParameter<"int64_t">:$elements,
AnyTypeOf<[I16, I32]>:$elementType
);
let mnemonic = "type_pred_trait_combined";
let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
}

def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
let mnemonic = "op_asm_type_interface";
Expand Down
36 changes: 30 additions & 6 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() {
? strfmt("{0}::{1}", def.getStorageNamespace(),
def.getStorageClassName())
: strfmt("::mlir::{0}Storage", valueType));
SmallVector<std::string> traitNames =
llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) {
return isa<NativeTrait>(&trait)
? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
: cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
}));
SmallVector<std::string> traitNames;
for (auto &trait : def.getTraits()) {
// Skip PredTrait as it doesn't generate a C++ trait class.
if (isa<PredTrait>(&trait))
continue;
traitNames.push_back(
isa<NativeTrait>(&trait)
? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
: cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
}
for (auto &traitName : traitNames)
defParent.addTemplateParam(traitName);

Expand Down Expand Up @@ -385,6 +389,26 @@ void DefGen::emitInvariantsVerifierImpl() {
param.getName(), constraint->getSummary())
<< "\n";
}
{
// Generate verification for PredTraits.
FmtContext traitCtx;
for (auto it : llvm::enumerate(def.getParameters())) {
// Note: Skip over the first method parameter (`emitError`).
traitCtx.addSubst(it.value().getName(),
builderParams[it.index() + 1].getName());
}
for (const Trait &trait : def.getTraits()) {
if (auto *t = dyn_cast<PredTrait>(&trait)) {
verifier->body() << tgfmt(
"if (!($0)) {\n"
" emitError() << \"failed to verify that $1\";\n"
" return ::mlir::failure();\n"
"}\n",
&traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
}
}
}

verifier->body() << "return ::mlir::success();";
}

Expand Down