diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp index 8d10aacb53ec9..54c7d17a97b50 100644 --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -533,6 +533,41 @@ static bool getBases(Operation *op, SmallPtrSet ¶mIds, return false; } + if (auto base = dyn_cast(op)) { + if (base.getBaseName()) { + StringRef baseName = *base.getBaseName(); + if (baseName[0] == '!') { + auto abstractType = + AbstractType::lookup(baseName.drop_front(1), op->getContext()); + assert(abstractType && "type name should refer to an existing type"); + paramIds.insert(abstractType->get().getTypeID()); + } else if (baseName[0] == '#') { + auto abstractAttr = + AbstractAttribute::lookup(baseName.drop_front(1), op->getContext()); + assert(abstractAttr && "attribute name should refer to an existing " + "attribute"); + paramIds.insert(abstractAttr->get().getTypeID()); + } else { + llvm_unreachable( + "invalid `irdl.base` operation: base name should start " + "with '!' for types or '#' for attributes"); + } + return false; + } + + if (base.getBaseRef()) { + SymbolRefAttr symRef = *base.getBaseRef(); + Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef); + assert(defOp && "symbol reference should refer to an existing operation"); + paramIrdlOps.insert(defOp); + return false; + } + + llvm_unreachable( + "invalid `irdl.base` operation: expected either a base name " + "or a base symbol reference"); + } + // For `irdl.any`, we return `false` since we can match any type or attribute // base. if (auto isA = dyn_cast(op)) diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py index 1900b8c162456..dfcd7f2d641d0 100644 --- a/mlir/python/mlir/dialects/ext.py +++ b/mlir/python/mlir/dialects/ext.py @@ -85,29 +85,22 @@ def _lower(self, type_) -> ir.Value: return irdl.any() elif isinstance(type_, TypeVar): return self.lower(type_) + elif origin and issubclass(origin, Type | Attribute): + return irdl.parametric( + base_type=[origin._dialect_name, origin._name], + args=[self.lower(arg) for arg in get_args(type_)], + ) elif origin and issubclass(origin, ir.Type): - if issubclass(origin, Type): - return irdl.parametric( - base_type=[origin._dialect_name, origin._name], - args=[self.lower(arg) for arg in get_args(type_)], - ) t = construct_instance(origin, get_args(type_)) return irdl.is_(ir.TypeAttr.get(t)) elif origin and issubclass(origin, ir.Attribute): - if issubclass(origin, Attribute): - return irdl.parametric( - base_type=[origin._dialect_name, origin._name], - args=[self.lower(arg) for arg in get_args(type_)], - ) attr = construct_instance(origin, get_args(type_)) return irdl.is_(attr) + elif issubclass(type_, Type | Attribute): + return irdl.base(base_ref=[type_._dialect_name, type_._name]) elif issubclass(type_, ir.Type): - if issubclass(type_, Type): - return irdl.base(base_ref=[type_._dialect_name, type_._name]) return irdl.base(base_name=f"!{type_.type_name}") elif issubclass(type_, ir.Attribute): - if issubclass(type_, Attribute): - return irdl.base(base_ref=[type_._dialect_name, type_._name]) return irdl.base(base_name=f"#{type_.attr_name}") raise TypeError(f"unsupported type in constraints: {type_}") @@ -197,13 +190,9 @@ def from_type_hint(name, type_, specifier) -> "FieldDef": get_args(type_)[0], kw_only=specifier.kw_only(), ) - elif issubclass(origin or type_, ir.Attribute): - return AttributeDef(name, variadicity, type_) elif type_ is ir.Region: return RegionDef(name, variadicity, Any) - raise TypeError( - f"unsupported type for field '{name}' in operation definition: {type_}" - ) + return AttributeDef(name, variadicity, type_) @dataclass diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py index a1593c35855ea..78c74684cef77 100644 --- a/mlir/test/python/dialects/ext.py +++ b/mlir/test/python/dialects/ext.py @@ -736,3 +736,44 @@ class AssignNoneOnNonOptionalOp( except ValueError as e: # CHECK: only optional operand can be a keyword parameter print(e) + + +# CHECK: TEST: testExtDialectWithAttrInOp +@run +def testExtDialectWithAttrInOp(): + class TestAttrInOp(Dialect, name="ext_attr_in_op"): + pass + + class OpWithAttr(TestAttrInOp.Operation, name="op_with_attr"): + a: IntegerAttr | StringAttr + b: IntegerType[32] | IntegerType[64] + + with Context(), Location.unknown(): + TestAttrInOp.load() + # CHECK: irdl.dialect @ext_attr_in_op { + # CHECK: irdl.operation @op_with_attr { + # CHECK: %0 = irdl.base "#builtin.integer" + # CHECK: %1 = irdl.base "#builtin.string" + # CHECK: %2 = irdl.any_of(%0, %1) + # CHECK: %3 = irdl.is i32 + # CHECK: %4 = irdl.is i64 + # CHECK: %5 = irdl.any_of(%3, %4) + # CHECK: irdl.attributes {"a" = %2, "b" = %5} + # CHECK: } + # CHECK: } + print(TestAttrInOp._mlir_module) + + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + iattr = IntegerAttr.get(i32, 42) + sattr = StringAttr.get("hello") + + module = Module.create() + with InsertionPoint(module.body): + OpWithAttr(iattr, TypeAttr.get(i32)) + OpWithAttr(sattr, TypeAttr.get(i64)) + + assert module.operation.verify() + # CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> () + # CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> () + print(module)