Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,11 @@ class AsmPrinter {
/// Attempts to start a cyclic printing region for `attrOrType`.
/// A cyclic printing region starts with this call and ends with the
/// destruction of the returned `CyclicPrintReset`. During this time,
/// calling `tryStartCyclicPrint` with the same attribute in any printer
/// will lead to returning failure.
/// calling `tryStartCyclicPrint` with the same attribute or type in any
/// printer will lead to returning failure. Additionally, if the printer
/// knows a complete definition of the attribute or type will be emitted in
/// the future, it'll also return failure to permit abbreviated definitions
/// to be used wherever possible.
///
/// This makes it possible to break infinite recursions when trying to print
/// cyclic attributes or types by printing only immutable parameters if nested
Expand All @@ -278,6 +281,8 @@ class AsmPrinter {
AttrOrTypeT> ||
std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>,
"Only mutable attributes or types can be cyclic");
if (hasFutureAlias(attrOrType.getAsOpaquePointer()))
return failure();
if (failed(pushCyclicPrinting(attrOrType.getAsOpaquePointer())))
return failure();
return CyclicPrintReset(this);
Expand All @@ -299,6 +304,12 @@ class AsmPrinter {
/// in reverse order of all successful `pushCyclicPrinting`.
virtual void popCyclicPrinting();

/// Check if the given attribute or type (in the form of a type erased
/// pointer) will be printed as an alias in the future. Returns false if the
/// type has an alias that's currently being printed or has already been
/// printed. This enables cyclic print checking for mutual recursion.
virtual bool hasFutureAlias(const void *opaquePointer) const;

private:
AsmPrinter(const AsmPrinter &) = delete;
void operator=(const AsmPrinter &) = delete;
Expand Down
37 changes: 34 additions & 3 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ class AsmPrinter::Impl {

void popCyclicPrinting();

bool hasFutureAlias(const void *opaquePointer) const;

void printDimensionList(ArrayRef<int64_t> shape);

protected:
Expand Down Expand Up @@ -547,8 +549,13 @@ class SymbolAlias {
bool isDeferrable : 1;

public:
/// Used to distinguish aliases that are currently being or have previously
/// been printed from those that will be printed in the future, which can aid
/// printing mutually recursive types.
bool hasStartedPrinting = false;

/// Used to avoid printing incomplete aliases for recursive types.
bool isPrinted = false;
bool hasFinishedPrinting = false;
};

/// This class represents a utility that initializes the set of attribute and
Expand Down Expand Up @@ -974,6 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {

void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }

bool hasFutureAlias(const void *) const override { return false; }

/// Stack of potentially cyclic mutable attributes or type currently being
/// printed.
SetVector<const void *> cyclicPrintingStack;
Expand Down Expand Up @@ -1182,6 +1191,12 @@ class AliasState {
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Type ty, raw_ostream &os) const;

/// Check if the given attribute or type (in the form of a type erased
/// pointer) will be printed as an alias in the future. Returns false if the
/// type has an alias that's currently being printed or has already been
/// printed. This enables cyclic print checking for mutual recursion.
bool hasFutureAlias(const void *opaquePointer) const;

/// Print all of the referenced aliases that can not be resolved in a deferred
/// manner.
void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
Expand Down Expand Up @@ -1226,13 +1241,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
if (it == attrTypeToAlias.end())
return failure();
if (!it->second.isPrinted)
if (!it->second.hasFinishedPrinting)
return failure();

it->second.print(os);
return success();
}

bool AliasState::hasFutureAlias(const void *opaquePointer) const {
const auto *it = attrTypeToAlias.find(opaquePointer);
if (it == attrTypeToAlias.end())
return false;
return !it->second.hasStartedPrinting;
}

void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred) {
auto filterFn = [=](const auto &aliasIt) {
Expand All @@ -1245,8 +1267,9 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,

if (alias.isTypeAlias()) {
Type type = Type::getFromOpaquePointer(opaqueSymbol);
alias.hasStartedPrinting = true;
p.printTypeImpl(type);
alias.isPrinted = true;
alias.hasFinishedPrinting = true;
} else {
// TODO: Support nested aliases in mutable attributes.
Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
Expand Down Expand Up @@ -2791,6 +2814,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {

void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }

bool AsmPrinter::Impl::hasFutureAlias(const void *opaquePointer) const {
return state.getAliasState().hasFutureAlias(opaquePointer);
}

void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
detail::printDimensionList(os, shape);
}
Expand Down Expand Up @@ -2870,6 +2897,10 @@ LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {

void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); }

bool AsmPrinter::hasFutureAlias(const void *opaquePointer) const {
return impl->hasFutureAlias(opaquePointer);
}

//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 13 additions & 4 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3>>
// CHECK: ![[$NAME7:.*]] = !test.test_rec_alias<name7, !test.test_rec_alias<name6>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, ![[$NAME5]]>
// CHECK: ![[$NAME6:.*]] = !test.test_rec_alias<name6, ![[$NAME7]]>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, ![[$NAME4]]>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
Expand All @@ -28,13 +30,20 @@ func.func @roundtrip() {
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>

// Mutual recursion.
// Mutual recursion with types fully spelled out.
// CHECK: () -> ![[$NAME3]]
// CHECK: () -> ![[$NAME4]]
// CHECK: () -> ![[$NAME5]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3>>>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4>>>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>

// Mutual recursion with incomplete types.
// CHECK: () -> ![[$NAME6]]
// CHECK: () -> ![[$NAME7]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name6, !test.test_rec_alias<name7>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name7, !test.test_rec_alias<name6>>

return
}

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,10 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) {
return rec;
}

// Allow incomplete definitions that can be completed later.
if (succeeded(parser.parseGreater()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (succeeded(parser.parseGreater()))
if (succeeded(parser.parseOptionalGreater()))

Surprised this doesn't cause any issues I must admit.

return rec;

// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
Expand Down
Loading