diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 325f986f97694..af41532670890 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1043,6 +1043,12 @@ std::pair AliasInitializer::visitImpl( void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { auto it = std::next(aliases.begin(), aliasIndex); + + // If already marked non-deferrable stop the recursion. + // All children should already be marked non-deferrable as well. + if (!it->second.canBeDeferred) + return; + it->second.canBeDeferred = false; // Propagate the non-deferrable flag to any child aliases. diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir index bc9b2cdbea6b6..121ba095573ba 100644 --- a/mlir/test/IR/recursive-type.mlir +++ b/mlir/test/IR/recursive-type.mlir @@ -1,6 +1,8 @@ // RUN: mlir-opt %s -test-recursive-types | FileCheck %s // CHECK: !testrec = !test.test_rec> +// CHECK: ![[$NAME:.*]] = !test.test_rec_alias> +// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias, i32>> // CHECK-LABEL: @roundtrip func.func @roundtrip() { @@ -12,6 +14,16 @@ func.func @roundtrip() { // into inifinite recursion. // CHECK: !testrec "test.dummy_op_for_roundtrip"() : () -> !test.test_rec> + + // CHECK: () -> ![[$NAME]] + // CHECK: () -> ![[$NAME]] + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + + // CHECK: () -> ![[$NAME2]] + // CHECK: () -> ![[$NAME2]] + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 072f6ff4b84d3..debe733f59be4 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -312,6 +312,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { return AliasResult::FinalAlias; } } + if (auto recAliasType = dyn_cast(type)) { + os << recAliasType.getName(); + return AliasResult::FinalAlias; + } return AliasResult::NoAlias; } diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 15dbd74aec118..2a8bdad8fb25d 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -369,4 +369,26 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> { let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`"; } +def TestI32 : Test_Type<"TestI32"> { + let mnemonic = "i32"; +} + +def TestRecursiveAlias + : Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> { + let mnemonic = "test_rec_alias"; + let storageClass = "TestRecursiveTypeStorage"; + let storageNamespace = "test"; + let genStorageClass = 0; + + let parameters = (ins "llvm::StringRef":$name); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + Type getBody() const; + + void setBody(Type type); + }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 0633752067a14..20dc03a765269 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { SetVector stack; printTestType(type, printer, stack); } + +Type TestRecursiveAliasType::getBody() const { return getImpl()->body; } + +void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); } + +StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; } + +Type TestRecursiveAliasType::parse(AsmParser &parser) { + thread_local static SetVector stack; + + StringRef name; + if (parser.parseLess() || parser.parseKeyword(&name)) + return Type(); + auto rec = TestRecursiveAliasType::get(parser.getContext(), name); + + // If this type already has been parsed above in the stack, expect just the + // name. + if (stack.contains(rec)) { + if (failed(parser.parseGreater())) + return Type(); + return rec; + } + + // Otherwise, parse the body and update the type. + if (failed(parser.parseComma())) + return Type(); + stack.insert(rec); + Type subtype; + if (parser.parseType(subtype)) + return nullptr; + stack.pop_back(); + if (!subtype || failed(parser.parseGreater())) + return Type(); + + rec.setBody(subtype); + + return rec; +} + +void TestRecursiveAliasType::print(AsmPrinter &printer) const { + thread_local static SetVector stack; + + printer << "<" << getName(); + if (!stack.contains(*this)) { + printer << ", "; + stack.insert(*this); + printer << getBody(); + stack.pop_back(); + } + printer << ">"; +} diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index c7d169d020d56..0ce86dd70ab90 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -91,9 +91,6 @@ struct FieldParser> { #include "TestTypeInterfaces.h.inc" -#define GET_TYPEDEF_CLASSES -#include "TestTypeDefs.h.inc" - namespace test { /// Storage for simple named recursive types, where the type is identified by @@ -150,4 +147,7 @@ class TestRecursiveType } // namespace test +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.h.inc" + #endif // MLIR_TESTTYPES_H