Skip to content

Conversation

@timnoack
Copy link
Contributor

This patch adds support for PredTypeTrait and PredAttrTrait in type and attribute definitions, enabling declarative predicate-based verification similar to how PredOpTrait works for operations.

Motivation

Previously, PredTypeTrait/PredAttrTrait were defined in TableGen but not implemented in the code generator. Using them would cause mlir-tblgen to crash with an assertion failure when trying to cast PredTrait to InterfaceTrait. This patch fixes the crash and implements the actual verification code generation.

Usage

Use $paramName syntax in predicates to reference type/attribute parameters:

def MyType : MyDialect_Type<"MyType",
    [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
  let parameters = (ins "unsigned":$value);
  let mnemonic = "my_type";
  let assemblyFormat = "`<` $value `>`";
}

This generates verification code in verifyInvariantsImpl():

  if (!(value > 0)) {
    emitError() << "failed to verify that value must be positive";
    return ::mlir::failure();
  }

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 22, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2025

@llvm/pr-subscribers-mlir

Author: Tim Noack (timnoack)

Changes

This patch adds support for PredTypeTrait and PredAttrTrait in type and attribute definitions, enabling declarative predicate-based verification similar to how PredOpTrait works for operations.

Motivation

Previously, PredTypeTrait/PredAttrTrait were defined in TableGen but not implemented in the code generator. Using them would cause mlir-tblgen to crash with an assertion failure when trying to cast PredTrait to InterfaceTrait. This patch fixes the crash and implements the actual verification code generation.

Usage

Use $paramName syntax in predicates to reference type/attribute parameters:

def MyType : MyDialect_Type&lt;"MyType",
    [PredTypeTrait&lt;"value must be positive", CPred&lt;"$value &gt; 0"&gt;&gt;]&gt; {
  let parameters = (ins "unsigned":$value);
  let mnemonic = "my_type";
  let assemblyFormat = "`&lt;` $value `&gt;`";
}

This generates verification code in verifyInvariantsImpl():

  if (!(value &gt; 0)) {
    emitError() &lt;&lt; "failed to verify that value must be positive";
    return ::mlir::failure();
  }

Full diff: https://github.com/llvm/llvm-project/pull/169153.diff

4 Files Affected:

  • (modified) mlir/lib/TableGen/AttrOrTypeDef.cpp (+5-3)
  • (modified) mlir/test/IR/test-verifiers-type.mlir (+48)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+30)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+30-6)
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 4659265e24bda..bf835a860cd5b 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
 }
 
 bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
-  return any_of(parameters, [](const AttrOrTypeParameter &p) {
-    return p.getConstraint() != std::nullopt;
-  });
+  return any_of(parameters,
+                [](const AttrOrTypeParameter &p) {
+                  return p.getConstraint() != std::nullopt;
+                }) ||
+         any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
 }
 
 std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 6512a1b9c8711..a6a5fa3d4fc9f 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -22,3 +22,51 @@
 
 // expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
 "test.type_producer"() : () -> vector<memref<2xf32>>
+
+// -----
+
+// Test PredTypeTrait with single parameter - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
+"test.type_producer"() : () -> !test.type_pred_trait<5>
+
+// -----
+
+// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
+// expected-error @below{{failed to verify that value must be positive}}
+"test.type_producer"() : () -> !test.type_pred_trait<0>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
+// expected-error @below{{failed to verify that value must be at least min}}
+"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
+
+// -----
+
+// Test combined parameter constraint + PredTypeTrait - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+
+// -----
+
+// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
+// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
+
+// -----
+
+// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
+// expected-error @below{{failed to verify that count must match number of elements}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 9859bd06cb526..232d6354d01eb 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
   let assemblyFormat = "`<` $param `>`";
 }
 
+// Test type with PredTypeTrait - single parameter predicate.
+def TestTypePredTrait : Test_Type<"TestTypePredTrait",
+    [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
+  let parameters = (ins "unsigned":$value);
+  let mnemonic = "type_pred_trait";
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Test type with PredTypeTrait - two parameter predicate.
+def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
+    [PredTypeTrait<"value must be at least min",
+                   CPred<"$value >= $minValue">>]> {
+  let parameters = (ins "unsigned":$value, "unsigned":$minValue);
+  let mnemonic = "type_pred_trait_multi";
+  let assemblyFormat = "`<` $value `,` $minValue `>`";
+}
+
+// Test type combining parameter type constraints with PredTypeTrait.
+def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
+    [PredTypeTrait<"count must match number of elements",
+                   CPred<"$count == $elements.size()">>]> {
+  let parameters = (ins
+    "unsigned":$count,
+    ArrayRefParameter<"int64_t">:$elements,
+    AnyTypeOf<[I16, I32]>:$elementType
+  );
+  let mnemonic = "type_pred_trait_combined";
+  let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
+}
+
 def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
     [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
   let mnemonic = "op_asm_type_interface";
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a513c3b8cc9b..6547cb196716c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() {
                                  ? strfmt("{0}::{1}", def.getStorageNamespace(),
                                           def.getStorageClassName())
                                  : strfmt("::mlir::{0}Storage", valueType));
-  SmallVector<std::string> traitNames =
-      llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) {
-        return isa<NativeTrait>(&trait)
-                   ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
-                   : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
-      }));
+  SmallVector<std::string> traitNames;
+  for (auto &trait : def.getTraits()) {
+    // Skip PredTrait as it doesn't generate a C++ trait class.
+    if (isa<PredTrait>(&trait))
+      continue;
+    traitNames.push_back(
+        isa<NativeTrait>(&trait)
+            ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
+            : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
+  }
   for (auto &traitName : traitNames)
     defParent.addTemplateParam(traitName);
 
@@ -385,6 +389,26 @@ void DefGen::emitInvariantsVerifierImpl() {
                                 param.getName(), constraint->getSummary())
                      << "\n";
   }
+  {
+    // Generate verification for PredTraits.
+    FmtContext traitCtx;
+    for (auto it : llvm::enumerate(def.getParameters())) {
+      // Note: Skip over the first method parameter (`emitError`).
+      traitCtx.addSubst(it.value().getName(),
+                        builderParams[it.index() + 1].getName());
+    }
+    for (const Trait &trait : def.getTraits()) {
+      if (auto *t = dyn_cast<PredTrait>(&trait)) {
+        verifier->body() << tgfmt(
+            "if (!($0)) {\n"
+            "  emitError() << \"failed to verify that $1\";\n"
+            "  return ::mlir::failure();\n"
+            "}\n",
+            &traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
+      }
+    }
+  }
+
   verifier->body() << "return ::mlir::success();";
 }
 

@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2025

@llvm/pr-subscribers-mlir-core

Author: Tim Noack (timnoack)

Changes

This patch adds support for PredTypeTrait and PredAttrTrait in type and attribute definitions, enabling declarative predicate-based verification similar to how PredOpTrait works for operations.

Motivation

Previously, PredTypeTrait/PredAttrTrait were defined in TableGen but not implemented in the code generator. Using them would cause mlir-tblgen to crash with an assertion failure when trying to cast PredTrait to InterfaceTrait. This patch fixes the crash and implements the actual verification code generation.

Usage

Use $paramName syntax in predicates to reference type/attribute parameters:

def MyType : MyDialect_Type&lt;"MyType",
    [PredTypeTrait&lt;"value must be positive", CPred&lt;"$value &gt; 0"&gt;&gt;]&gt; {
  let parameters = (ins "unsigned":$value);
  let mnemonic = "my_type";
  let assemblyFormat = "`&lt;` $value `&gt;`";
}

This generates verification code in verifyInvariantsImpl():

  if (!(value &gt; 0)) {
    emitError() &lt;&lt; "failed to verify that value must be positive";
    return ::mlir::failure();
  }

Full diff: https://github.com/llvm/llvm-project/pull/169153.diff

4 Files Affected:

  • (modified) mlir/lib/TableGen/AttrOrTypeDef.cpp (+5-3)
  • (modified) mlir/test/IR/test-verifiers-type.mlir (+48)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+30)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+30-6)
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 4659265e24bda..bf835a860cd5b 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
 }
 
 bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
-  return any_of(parameters, [](const AttrOrTypeParameter &p) {
-    return p.getConstraint() != std::nullopt;
-  });
+  return any_of(parameters,
+                [](const AttrOrTypeParameter &p) {
+                  return p.getConstraint() != std::nullopt;
+                }) ||
+         any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
 }
 
 std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
index 6512a1b9c8711..a6a5fa3d4fc9f 100644
--- a/mlir/test/IR/test-verifiers-type.mlir
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -22,3 +22,51 @@
 
 // expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
 "test.type_producer"() : () -> vector<memref<2xf32>>
+
+// -----
+
+// Test PredTypeTrait with single parameter - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
+"test.type_producer"() : () -> !test.type_pred_trait<5>
+
+// -----
+
+// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
+// expected-error @below{{failed to verify that value must be positive}}
+"test.type_producer"() : () -> !test.type_pred_trait<0>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
+
+// -----
+
+// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
+// expected-error @below{{failed to verify that value must be at least min}}
+"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
+
+// -----
+
+// Test combined parameter constraint + PredTypeTrait - valid case.
+// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
+
+// -----
+
+// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
+// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
+
+// -----
+
+// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
+// expected-error @below{{failed to verify that count must match number of elements}}
+"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 9859bd06cb526..232d6354d01eb 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
   let assemblyFormat = "`<` $param `>`";
 }
 
+// Test type with PredTypeTrait - single parameter predicate.
+def TestTypePredTrait : Test_Type<"TestTypePredTrait",
+    [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
+  let parameters = (ins "unsigned":$value);
+  let mnemonic = "type_pred_trait";
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Test type with PredTypeTrait - two parameter predicate.
+def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
+    [PredTypeTrait<"value must be at least min",
+                   CPred<"$value >= $minValue">>]> {
+  let parameters = (ins "unsigned":$value, "unsigned":$minValue);
+  let mnemonic = "type_pred_trait_multi";
+  let assemblyFormat = "`<` $value `,` $minValue `>`";
+}
+
+// Test type combining parameter type constraints with PredTypeTrait.
+def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
+    [PredTypeTrait<"count must match number of elements",
+                   CPred<"$count == $elements.size()">>]> {
+  let parameters = (ins
+    "unsigned":$count,
+    ArrayRefParameter<"int64_t">:$elements,
+    AnyTypeOf<[I16, I32]>:$elementType
+  );
+  let mnemonic = "type_pred_trait_combined";
+  let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
+}
+
 def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
     [DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
   let mnemonic = "op_asm_type_interface";
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 2a513c3b8cc9b..6547cb196716c 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -245,12 +245,16 @@ void DefGen::createParentWithTraits() {
                                  ? strfmt("{0}::{1}", def.getStorageNamespace(),
                                           def.getStorageClassName())
                                  : strfmt("::mlir::{0}Storage", valueType));
-  SmallVector<std::string> traitNames =
-      llvm::to_vector(llvm::map_range(def.getTraits(), [](auto &trait) {
-        return isa<NativeTrait>(&trait)
-                   ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
-                   : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
-      }));
+  SmallVector<std::string> traitNames;
+  for (auto &trait : def.getTraits()) {
+    // Skip PredTrait as it doesn't generate a C++ trait class.
+    if (isa<PredTrait>(&trait))
+      continue;
+    traitNames.push_back(
+        isa<NativeTrait>(&trait)
+            ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
+            : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
+  }
   for (auto &traitName : traitNames)
     defParent.addTemplateParam(traitName);
 
@@ -385,6 +389,26 @@ void DefGen::emitInvariantsVerifierImpl() {
                                 param.getName(), constraint->getSummary())
                      << "\n";
   }
+  {
+    // Generate verification for PredTraits.
+    FmtContext traitCtx;
+    for (auto it : llvm::enumerate(def.getParameters())) {
+      // Note: Skip over the first method parameter (`emitError`).
+      traitCtx.addSubst(it.value().getName(),
+                        builderParams[it.index() + 1].getName());
+    }
+    for (const Trait &trait : def.getTraits()) {
+      if (auto *t = dyn_cast<PredTrait>(&trait)) {
+        verifier->body() << tgfmt(
+            "if (!($0)) {\n"
+            "  emitError() << \"failed to verify that $1\";\n"
+            "  return ::mlir::failure();\n"
+            "}\n",
+            &traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
+      }
+    }
+  }
+
   verifier->body() << "return ::mlir::success();";
 }
 

@timnoack
Copy link
Contributor Author

@joker-eph @ftynse Not sure who to ping here. I saw that @matthias-springer implemented verification of type constraints on types / attributes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants