36 changes: 31 additions & 5 deletions mlir/lib/IR/SubElementInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,34 @@

#include "mlir/IR/SubElementInterfaces.h"

#include "llvm/ADT/DenseSet.h"

using namespace mlir;

template <typename InterfaceT>
static void walkSubElementsImpl(InterfaceT interface,
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
function_ref<void(Type)> walkTypesFn,
DenseSet<Attribute> &visitedAttrs,
DenseSet<Type> &visitedTypes) {
interface.walkImmediateSubElements(
[&](Attribute attr) {
// Guard against potentially null inputs. This removes the need for the
// derived attribute/type to do it.
if (!attr)
return;

// Avoid infinite recursion when visiting sub attributes later, if this
// is a mutable attribute.
if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
if (!visitedAttrs.insert(attr).second)
return;
}

// Walk any sub elements first.
if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);

// Walk this attribute.
walkAttrsFn(attr);
Expand All @@ -34,9 +46,17 @@ static void walkSubElementsImpl(InterfaceT interface,
if (!type)
return;

// Avoid infinite recursion when visiting sub types later, if this
// is a mutable type.
if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
if (!visitedTypes.insert(type).second)
return;
}

// Walk any sub elements first.
if (auto interface = type.dyn_cast<SubElementTypeInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);

// Walk this type.
walkTypesFn(type);
Expand All @@ -47,14 +67,20 @@ void SubElementAttrInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
DenseSet<Attribute> visitedAttrs;
DenseSet<Type> visitedTypes;
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
}

void SubElementTypeInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
DenseSet<Attribute> visitedAttrs;
DenseSet<Type> visitedTypes;
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
// CHECK: !test.test_rec<c, test_rec<c>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
// Make sure walkSubElementType, which is used to generate aliases, doesn't go
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
return
}

Expand Down
7 changes: 7 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
if (auto recType = type.dyn_cast<TestRecursiveType>()) {
if (recType.getName() == "type_to_alias") {
// We only make alias for a specific recursive type.
os << "testrec";
return AliasResult::FinalAlias;
}
}
return AliasResult::NoAlias;
}

Expand Down
13 changes: 11 additions & 2 deletions mlir/test/lib/Dialect/Test/TestTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

Expand Down Expand Up @@ -130,7 +131,9 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
/// from type creation.
class TestRecursiveType
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
TestRecursiveTypeStorage> {
TestRecursiveTypeStorage,
::mlir::SubElementTypeInterface::Trait,
::mlir::TypeTrait::IsMutable> {
public:
using Base::Base;

Expand All @@ -141,10 +144,16 @@ class TestRecursiveType

/// Body getter and setter.
::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
::mlir::Type getBody() { return getImpl()->body; }
::mlir::Type getBody() const { return getImpl()->body; }

/// Name/key getter.
::llvm::StringRef getName() { return getImpl()->name; }

void walkImmediateSubElements(
::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
walkTypesFn(getBody());
}
};

} // namespace test
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ target_link_libraries(MLIRDialectTests
MLIRDialect)

add_subdirectory(Affine)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)

add_subdirectory(Quant)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
Expand Down
7 changes: 7 additions & 0 deletions mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
add_mlir_unittest(MLIRLLVMIRTests
LLVMTypeTest.cpp
)
target_link_libraries(MLIRLLVMIRTests
PRIVATE
MLIRLLVMDialect
)
27 changes: 27 additions & 0 deletions mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- LLVMTestBase.h - Test fixure for LLVM dialect tests ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test fixure for LLVM dialect tests.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

class LLVMIRTest : public ::testing::Test {
protected:
LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }

mlir::MLIRContext context;
};

#endif
63 changes: 63 additions & 0 deletions mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LLVMTestBase.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/SubElementInterfaces.h"

using namespace mlir;
using namespace mlir::LLVM;

TEST_F(LLVMIRTest, IsStructTypeMutable) {
auto structTy = LLVMStructType::getIdentified(&context, "foo");
ASSERT_TRUE(bool(structTy));
ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
}

TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
auto fooStructTy = LLVMStructType::getIdentified(&context, "foo");
ASSERT_TRUE(bool(fooStructTy));
auto barStructTy = LLVMStructType::getIdentified(&context, "bar");
ASSERT_TRUE(bool(barStructTy));

// Created two structs that are referencing each other.
Type fooBody[] = {LLVMPointerType::get(barStructTy)};
ASSERT_TRUE(succeeded(fooStructTy.setBody(fooBody, /*packed=*/false)));
Type barBody[] = {LLVMPointerType::get(fooStructTy)};
ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*packed=*/false)));

auto subElementInterface = fooStructTy.dyn_cast<SubElementTypeInterface>();
ASSERT_TRUE(bool(subElementInterface));
// Test if walkSubElements goes into infinite loops.
SmallVector<Type, 4> subElementTypes;
subElementInterface.walkSubElements(
[](Attribute attr) {},
[&](Type type) { subElementTypes.push_back(type); });
// We don't record LLVMPointerType (because it's immutable), thus
// !llvm.ptr<struct<"bar",...>> will be visited twice.
ASSERT_EQ(subElementTypes.size(), 5U);

// !llvm.ptr<struct<"bar",...>>
ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());

// !llvm.struct<"foo",...>
auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
ASSERT_TRUE(bool(structType));
ASSERT_TRUE(structType.getName().equals("foo"));

// !llvm.ptr<struct<"foo",...>>
ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());

// !llvm.struct<"bar",...>
structType = subElementTypes[3].dyn_cast<LLVMStructType>();
ASSERT_TRUE(bool(structType));
ASSERT_TRUE(structType.getName().equals("bar"));

// !llvm.ptr<struct<"bar",...>>
ASSERT_TRUE(subElementTypes[4].isa<LLVMPointerType>());
}