Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][irdl] Add irdl.base op #76400

Merged
merged 4 commits into from
Jan 18, 2024
Merged

[mlir][irdl] Add irdl.base op #76400

merged 4 commits into from
Jan 18, 2024

Conversation

math-fehr
Copy link
Contributor

The irdl.base op represent an attribute constraint that will check that the
base of a type or attribute is the expected one (e.g. IntegerType) .

Example:

irdl.dialect @cmath {
  irdl.type @complex {
    %0 = irdl.base "!builtin.integer"
    irdl.parameters(%0)
  }

  irdl.type @complex_wrapper {
    %0 = irdl.base @complex
    irdl.parameters(%0)
  }
}

The above program defines a cmath.complex type that expects a single
parameter, which is a type with base name builtin.integer, which is the
name of an IntegerType type.
It also defines a cmath.complex_wrapper type that expects a single
parameter, which is a type of base type cmath.complex.

@math-fehr math-fehr self-assigned this Dec 26, 2023
@llvmbot llvmbot added the mlir label Dec 26, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 26, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-irdl

Author: Fehr Mathieu (math-fehr)

Changes

The irdl.base op represent an attribute constraint that will check that the
base of a type or attribute is the expected one (e.g. IntegerType) .

Example:

irdl.dialect @<!-- -->cmath {
  irdl.type @<!-- -->complex {
    %0 = irdl.base "!builtin.integer"
    irdl.parameters(%0)
  }

  irdl.type @<!-- -->complex_wrapper {
    %0 = irdl.base @<!-- -->complex
    irdl.parameters(%0)
  }
}

The above program defines a cmath.complex type that expects a single
parameter, which is a type with base name builtin.integer, which is the
name of an IntegerType type.
It also defines a cmath.complex_wrapper type that expects a single
parameter, which is a type of base type cmath.complex.


Full diff: https://github.com/llvm/llvm-project/pull/76400.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td (+41)
  • (modified) mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h (+42)
  • (modified) mlir/lib/Dialect/IRDL/IR/IRDL.cpp (+33)
  • (modified) mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp (+54)
  • (modified) mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp (+33)
  • (added) mlir/test/Dialect/IRDL/invalid.irdl.mlir (+43)
  • (modified) mlir/test/Dialect/IRDL/testd.irdl.mlir (+41-7)
  • (modified) mlir/test/Dialect/IRDL/testd.mlir (+53-10)
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 681425f8174426..c63a3a70f6703f 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -451,6 +451,47 @@ def IRDL_IsOp : IRDL_ConstraintOp<"is",
   let assemblyFormat = " $expected ` ` attr-dict ";
 }
 
+def IRDL_BaseOp : IRDL_ConstraintOp<"base",
+    [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Constraints an attribute/type base";
+  let description = [{
+    `irdl.base` defines a constraint that only accepts a single type
+    or attribute base, e.g. an `IntegerType`. The attribute base is defined
+    either by a symbolic reference to the corresponding IRDL definition,
+    or by the name of the base. Named bases are prefixed with `!` or `#`
+    respectively for types and attributes.
+
+    Example:
+
+    ```mlir
+    irdl.dialect @cmath {
+      irdl.type @complex {
+        %0 = irdl.base "!builtin.integer"
+        irdl.parameters(%0)
+      }
+
+      irdl.type @complex_wrapper {
+        %0 = irdl.base @complex
+        irdl.parameters(%0)
+      }
+    }
+    ```
+
+    The above program defines a `cmath.complex` type that expects a single
+    parameter, which is a type with base name `builtin.integer`, which is the
+    name of an `IntegerType` type.
+    It also defines a `cmath.complex_wrapper` type that expects a single
+    parameter, which is a type of base type `cmath.complex`.
+  }];
+
+  let arguments = (ins OptionalAttr<SymbolRefAttr>:$base_ref,
+                       OptionalAttr<StrAttr>:$base_name);
+  let results = (outs IRDL_AttributeType:$output);
+  let assemblyFormat = " ($base_ref^)? ($base_name^)? ` ` attr-dict";
+  let hasVerifier = 1;
+}
+
 def IRDL_ParametricOp : IRDL_ConstraintOp<"parametric",
     [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>, Pure]> {
   let summary = "Constraints an attribute/type base and its parameters";
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
index f8ce77cbc50e9e..9ecb7c0107d7f8 100644
--- a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -99,6 +99,48 @@ class IsConstraint : public Constraint {
   Attribute expectedAttribute;
 };
 
+/// A constraint that checks that an attribute is of a given attribute base
+/// (e.g. IntegerAttr).
+class BaseAttrConstraint : public Constraint {
+public:
+  BaseAttrConstraint(TypeID baseTypeID, StringRef baseName)
+      : baseTypeID(baseTypeID), baseName(baseName) {}
+
+  virtual ~BaseAttrConstraint() = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The expected base attribute typeID.
+  TypeID baseTypeID;
+
+  /// The base attribute name, only used for error reporting.
+  StringRef baseName;
+};
+
+/// A constraint that checks that a type is of a given type base (e.g.
+/// IntegerType).
+class BaseTypeConstraint : public Constraint {
+public:
+  BaseTypeConstraint(TypeID baseTypeID, StringRef baseName)
+      : baseTypeID(baseTypeID), baseName(baseName) {}
+
+  virtual ~BaseTypeConstraint() = default;
+
+  LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The expected base type typeID.
+  TypeID baseTypeID;
+
+  /// The base type name, only used for error reporting.
+  StringRef baseName;
+};
+
 /// A constraint that checks that an attribute is of a
 /// specific dynamic attribute definition, and that all of its parameters
 /// satisfy the given constraints.
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index 33c6bb869a643f..4eae2b03024c24 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -117,6 +117,39 @@ LogicalResult AttributesOp::verify() {
   return success();
 }
 
+LogicalResult BaseOp::verify() {
+  std::optional<StringRef> baseName = getBaseName();
+  std::optional<SymbolRefAttr> baseRef = getBaseRef();
+  if (baseName.has_value() == baseRef.has_value())
+    return emitOpError() << "the base type or attribute should be specified by "
+                            "either a name or a reference";
+
+  if (baseName &&
+      (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
+    return emitOpError() << "the base type or attribute name should start with "
+                            "'!' or '#'";
+
+  return success();
+}
+
+LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  std::optional<SymbolRefAttr> baseRef = getBaseRef();
+  if (!baseRef)
+    return success();
+
+  TypeOp typeOp = symbolTable.lookupNearestSymbolFrom<TypeOp>(*this, *baseRef);
+  if (typeOp)
+    return success();
+
+  AttributeOp attrOp =
+      symbolTable.lookupNearestSymbolFrom<AttributeOp>(*this, *baseRef);
+  if (attrOp)
+    return success();
+
+  return emitOpError() << "'" << *baseRef
+                       << "' does not refer to a type or attribute definition";
+}
+
 /// Parse a value with its variadicity first. By default, the variadicity is
 /// single.
 ///
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index e172039712f24c..0895306b8bce1a 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -37,6 +37,60 @@ std::unique_ptr<Constraint> IsOp::getVerifier(
   return std::make_unique<IsConstraint>(getExpectedAttr());
 }
 
+std::unique_ptr<Constraint> BaseOp::getVerifier(
+    ArrayRef<Value> valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+        &attrs) {
+  MLIRContext *ctx = getContext();
+
+  // Case where the input is a symbol reference.
+  // This corresponds to the case where the base is an IRDL type or attribute.
+  if (auto baseRef = getBaseRef()) {
+    Operation *defOp =
+        SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+
+    // Type case.
+    if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
+      DynamicTypeDefinition *typeDef = types.at(typeOp).get();
+      auto name = StringAttr::get(ctx, typeDef->getDialect()->getNamespace() +
+                                           "." + typeDef->getName().str());
+      return std::make_unique<BaseTypeConstraint>(typeDef->getTypeID(), name);
+    }
+
+    // Attribute case.
+    auto attrOp = cast<AttributeOp>(defOp);
+    DynamicAttrDefinition *attrDef = attrs.at(attrOp).get();
+    auto name = StringAttr::get(ctx, attrDef->getDialect()->getNamespace() +
+                                         "." + attrDef->getName().str());
+    return std::make_unique<BaseAttrConstraint>(attrDef->getTypeID(), name);
+  }
+
+  // Case where the input is string literal.
+  // This corresponds to the case where the base is a registered type or
+  // attribute.
+  StringRef baseName = getBaseName().value();
+
+  // Type case.
+  if (baseName[0] == '!') {
+    auto abstractType = AbstractType::lookup(baseName.drop_front(1), ctx);
+    if (!abstractType) {
+      emitError() << "no registered type with name " << baseName;
+      return nullptr;
+    }
+    return std::make_unique<BaseTypeConstraint>(abstractType->get().getTypeID(),
+                                                abstractType->get().getName());
+  }
+
+  auto abstractAttr = AbstractAttribute::lookup(baseName.drop_front(1), ctx);
+  if (!abstractAttr) {
+    emitError() << "no registered attribute with name " << baseName;
+    return nullptr;
+  }
+  return std::make_unique<BaseAttrConstraint>(abstractAttr->get().getTypeID(),
+                                              abstractAttr->get().getName());
+}
+
 std::unique_ptr<Constraint> ParametricOp::getVerifier(
     ArrayRef<Value> valueToConstr,
     DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
index 90b068ba35831b..2310c11ea0e8ed 100644
--- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -69,6 +69,39 @@ LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
   return failure();
 }
 
+LogicalResult
+BaseAttrConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Attribute attr, ConstraintVerifier &context) const {
+  if (attr.getTypeID() == baseTypeID)
+    return success();
+
+  if (emitError)
+    return emitError() << "expected base attribute '" << baseName
+                       << "' but got '" << attr.getAbstractAttribute().getName()
+                       << "'";
+  return failure();
+}
+
+LogicalResult
+BaseTypeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Attribute attr, ConstraintVerifier &context) const {
+  auto typeAttr = dyn_cast<TypeAttr>(attr);
+  if (!typeAttr) {
+    if (emitError)
+      return emitError() << "expected type, got attribute '" << attr;
+    return failure();
+  }
+
+  Type type = typeAttr.getValue();
+  if (type.getTypeID() == baseTypeID)
+    return success();
+
+  if (emitError)
+    return emitError() << "expected base type '" << baseName << "' but got '"
+                       << type.getAbstractType().getName() << "'";
+  return failure();
+}
+
 LogicalResult DynParametricAttrConstraint::verify(
     function_ref<InFlightDiagnostic()> emitError, Attribute attr,
     ConstraintVerifier &context) const {
diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
new file mode 100644
index 00000000000000..d62bb498a7ad98
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file
+
+// Testing invalid IRDL IRs
+
+func.func private @foo()
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // expected-error@+1 {{'@foo' does not refer to a type or attribute definition}}
+    %0 = irdl.base @foo
+    irdl.parameters(%0)
+  }
+}
+
+// -----
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
+    %0 = irdl.base "builtin.integer"
+    irdl.parameters(%0)
+  }
+}
+
+// -----
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // expected-error@+1 {{the base type or attribute name should start with '!' or '#'}}
+    %0 = irdl.base ""
+    irdl.parameters(%0)
+  }
+}
+
+// -----
+
+irdl.dialect @testd {
+  irdl.type @type {
+    // expected-error@+1 {{the base type or attribute should be specified by either a name}}
+    %0 = irdl.base
+    irdl.parameters(%0)
+  }
+}
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index 684286e4afeb0f..f828d95bdb81d5 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -11,6 +11,15 @@ irdl.dialect @testd {
     irdl.parameters(%0)
   }
 
+  // CHECK: irdl.attribute @parametric_attr {
+  // CHECK:  %[[v0:[^ ]*]] = irdl.any
+  // CHECK:  irdl.parameters(%[[v0]])
+  // CHECK: }
+  irdl.attribute @parametric_attr {
+    %0 = irdl.any
+    irdl.parameters(%0)
+  }
+
   // CHECK: irdl.type @attr_in_type_out {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
   // CHECK:   irdl.parameters(%[[v0]])
@@ -66,15 +75,40 @@ irdl.dialect @testd {
     irdl.results(%0)
   }
 
-  // CHECK: irdl.operation @dynbase {
-  // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @parametric<%[[v0]]>
+  // CHECK: irdl.operation @dyn_type_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric
   // CHECK:   irdl.results(%[[v1]])
   // CHECK: }
-  irdl.operation @dynbase {
-    %0 = irdl.any
-    %1 = irdl.parametric @parametric<%0>
-    irdl.results(%1)
+  irdl.operation @dyn_type_base {
+    %0 = irdl.base @parametric
+    irdl.results(%0)
+  }
+
+  // CHECK: irdl.operation @dyn_attr_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric_attr
+  // CHECK:   irdl.attributes {"attr1" = %[[v1]]}
+  // CHECK: }
+  irdl.operation @dyn_attr_base {
+    %0 = irdl.base @parametric_attr
+    irdl.attributes {"attr1" = %0}
+  }
+
+  // CHECK: irdl.operation @named_type_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base "!builtin.integer"
+  // CHECK:   irdl.results(%[[v1]])
+  // CHECK: }
+  irdl.operation @named_type_base {
+    %0 = irdl.base "!builtin.integer"
+    irdl.results(%0)
+  }
+
+  // CHECK: irdl.operation @named_attr_base {
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base "#builtin.integer"
+  // CHECK:   irdl.attributes {"attr1" = %[[v1]]}
+  // CHECK: }
+  irdl.operation @named_attr_base {
+    %0 = irdl.base "#builtin.integer"
+    irdl.attributes {"attr1" = %0}
   }
 
   // CHECK: irdl.operation @dynparams {
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
index bb1e9f46356411..333bb96eb2e60f 100644
--- a/mlir/test/Dialect/IRDL/testd.mlir
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -120,24 +120,67 @@ func.func @succeededAnyConstraint() {
 // -----
 
 //===----------------------------------------------------------------------===//
-// Dynamic base constraint
+// Base constraints
 //===----------------------------------------------------------------------===//
 
 func.func @succeededDynBaseConstraint() {
-  // CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
-  "testd.dynbase"() : () -> !testd.parametric<i32>
-  // CHECK: "testd.dynbase"() : () -> !testd.parametric<i64>
-  "testd.dynbase"() : () -> !testd.parametric<i64>
-  // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
-  "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i64>>
+  // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i32>
+  "testd.dyn_type_base"() : () -> !testd.parametric<i32>
+  // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<i64>
+  "testd.dyn_type_base"() : () -> !testd.parametric<i64>
+  // CHECK: "testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
+  "testd.dyn_type_base"() : () -> !testd.parametric<!testd.parametric<i64>>
+  // CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
+  "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i32>} : () -> ()
+  // CHECK: "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
+  "testd.dyn_attr_base"() {attr1 = #testd.parametric_attr<i64>} : () -> ()
   return
 }
 
 // -----
 
-func.func @failedDynBaseConstraint() {
-  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
-  "testd.dynbase"() : () -> i32
+func.func @failedDynTypeBaseConstraint() {
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'builtin.integer'}}
+  "testd.dyn_type_base"() : () -> i32
+  return
+}
+
+// -----
+
+func.func @failedDynAttrBaseConstraintNotType() {
+  // expected-error@+1 {{expected base attribute 'testd.parametric_attr' but got 'builtin.type'}}
+  "testd.dyn_attr_base"() {attr1 = i32}: () -> ()
+  return
+}
+
+// -----
+
+
+func.func @succeededNamedBaseConstraint() {
+  // CHECK: "testd.named_type_base"() : () -> i32
+  "testd.named_type_base"() : () -> i32
+  // CHECK: "testd.named_type_base"() : () -> i64
+  "testd.named_type_base"() : () -> i64
+  // CHECK: "testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
+  "testd.named_attr_base"() {attr1 = 0 : i32} : () -> ()
+  // CHECK: "testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
+  "testd.named_attr_base"() {attr1 = 0 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedNamedTypeBaseConstraint() {
+  // expected-error@+1 {{expected base type 'builtin.integer' but got 'builtin.vector'}}
+  "testd.named_type_base"() : () -> vector<i32>
+  return
+}
+
+// -----
+
+func.func @failedDynAttrBaseConstraintNotType() {
+  // expected-error@+1 {{expected base attribute 'builtin.integer' but got 'builtin.type'}}
+  "testd.named_attr_base"() {attr1 = i32}: () -> ()
   return
 }
 

if (attr.getTypeID() == baseTypeID)
return success();

if (emitError)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this always non-null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's null when you are inside the check of an irdl.any_of.
If you have a constraint like AnyOf<i32, i64> for instance, you will verify both sides, so you don't want an error to be emitted until you checked both.

@math-fehr
Copy link
Contributor Author

gentle ping @Mogball @joker-eph

@math-fehr math-fehr merged commit 914cfa4 into llvm:main Jan 18, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants