12 changes: 12 additions & 0 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
Expand All @@ -12,6 +14,16 @@ func.func @roundtrip() {
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>

// CHECK: () -> ![[$NAME]]
// CHECK: () -> ![[$NAME]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>

// CHECK: () -> ![[$NAME2]]
// CHECK: () -> ![[$NAME2]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
return
}

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
os << recAliasType.getName();
return AliasResult::FinalAlias;
}
return AliasResult::NoAlias;
}

Expand Down
22 changes: 22 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,4 +369,26 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}

def TestI32 : Test_Type<"TestI32"> {
let mnemonic = "i32";
}

def TestRecursiveAlias
: Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
let mnemonic = "test_rec_alias";
let storageClass = "TestRecursiveTypeStorage";
let storageNamespace = "test";
let genStorageClass = 0;

let parameters = (ins "llvm::StringRef":$name);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
Type getBody() const;

void setBody(Type type);
}];
}

#endif // TEST_TYPEDEFS
51 changes: 51 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
SetVector<Type> stack;
printTestType(type, printer, stack);
}

Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }

void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }

StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }

Type TestRecursiveAliasType::parse(AsmParser &parser) {
thread_local static SetVector<Type> stack;

StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
auto rec = TestRecursiveAliasType::get(parser.getContext(), name);

// If this type already has been parsed above in the stack, expect just the
// name.
if (stack.contains(rec)) {
if (failed(parser.parseGreater()))
return Type();
return rec;
}

// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
Type subtype;
if (parser.parseType(subtype))
return nullptr;
stack.pop_back();
if (!subtype || failed(parser.parseGreater()))
return Type();

rec.setBody(subtype);

return rec;
}

void TestRecursiveAliasType::print(AsmPrinter &printer) const {
thread_local static SetVector<Type> stack;

printer << "<" << getName();
if (!stack.contains(*this)) {
printer << ", ";
stack.insert(*this);
printer << getBody();
stack.pop_back();
}
printer << ">";
}
6 changes: 3 additions & 3 deletions mlir/test/lib/Dialect/Test/TestTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {

#include "TestTypeInterfaces.h.inc"

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"

namespace test {

/// Storage for simple named recursive types, where the type is identified by
Expand Down Expand Up @@ -150,4 +147,7 @@ class TestRecursiveType

} // namespace test

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"

#endif // MLIR_TESTTYPES_H
11 changes: 2 additions & 9 deletions mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()

// CHECK-NEXT: nan
// CHECK-NEXT: -27
%b = arith.constant -3.0 : f64
%b_p = arith.constant 3.0 : f64
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
Expand All @@ -220,16 +220,9 @@ func.func @powf() {
%f_p = arith.constant 1.2 : f64
call @func_powff64(%f, %f_p) : (f64, f64) -> ()

// CHECK-NEXT: nan
%g = arith.constant 0xff80000000000000 : f64
call @func_powff64(%g, %g) : (f64, f64) -> ()

// CHECK-NEXT: nan
%h = arith.constant 0x7fffffffffffffff : f64
call @func_powff64(%h, %h) : (f64, f64) -> ()

// CHECK-NEXT: nan
%i = arith.constant 1.0 : f64
%h = arith.constant 0x7fffffffffffffff : f64
call @func_powff64(%i, %h) : (f64, f64) -> ()

// CHECK-NEXT: inf
Expand Down
46 changes: 46 additions & 0 deletions mlir/unittests/IR/DialectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,50 @@ TEST(Dialect, RepeatedDelayedRegistration) {
EXPECT_TRUE(testDialectInterface != nullptr);
}

namespace {
/// A dummy extension that increases a counter when being applied and
/// recursively adds additional extensions.
struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
DummyExtension(int *counter, int numRecursive)
: DialectExtension(), counter(counter), numRecursive(numRecursive) {}

void apply(MLIRContext *ctx, TestDialect *dialect) const final {
++(*counter);
DialectRegistry nestedRegistry;
for (int i = 0; i < numRecursive; ++i)
nestedRegistry.addExtension(
std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
// Adding additional extensions may trigger a reallocation of the
// `extensions` vector in the dialect registry.
ctx->appendDialectRegistry(nestedRegistry);
}

private:
int *counter;
int numRecursive;
};
} // namespace

TEST(Dialect, NestedDialectExtension) {
DialectRegistry registry;
registry.insert<TestDialect>();

// Add an extension that adds 100 more extensions.
int counter1 = 0;
registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
// Add one more extension. This should not crash.
int counter2 = 0;
registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));

// Load dialect and apply extensions.
MLIRContext context(registry);
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);

// Extensions may be applied multiple times. Make sure that each expected
// extension was applied at least once.
EXPECT_GE(counter1, 101);
EXPECT_GE(counter2, 1);
}

} // namespace