Skip to content

Commit

Permalink
[mlir] Prevent SubElementInterface from going into infinite recursion
Browse files Browse the repository at this point in the history
Since only mutable types and attributes can go into infinite recursion
inside SubElementInterface::walkSubElement, and there are only a few of
them (mutable types and attributes), we introduce new traits for Type
and Attribute: TypeTrait::IsMutable and AttributeTrait::IsMutable,
respectively. They indicate whether a type or attribute is mutable.
Such traits are required if the ImplType defines a `mutate` function.

Then, inside SubElementInterface, we use a set to record visited mutable
types and attributes that have been visited before.

Differential Revision: https://reviews.llvm.org/D127537
  • Loading branch information
mshockwave committed Jun 29, 2022
1 parent bc5e7ce commit d410286
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 11 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Expand Up @@ -264,7 +264,8 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
/// structs, but does not in uniquing of identified structs.
class LLVMStructType
: public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
DataLayoutTypeInterface::Trait> {
DataLayoutTypeInterface::Trait,
TypeTrait::IsMutable> {
public:
/// Inherit base constructors.
using Base::Base;
Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Expand Up @@ -275,8 +275,9 @@ class SampledImageType
/// In the above, expressing recursive struct types is accomplished by giving a
/// recursive struct a unique identified and using that identifier in the struct
/// definition for recursive references.
class StructType : public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage> {
class StructType
: public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage, TypeTrait::IsMutable> {
public:
using Base::Base;

Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/IR/Attributes.h
Expand Up @@ -231,6 +231,18 @@ class AttributeInterface
friend InterfaceBase;
};

//===----------------------------------------------------------------------===//
// Core AttributeTrait
//===----------------------------------------------------------------------===//

/// This trait is used to determine if an attribute is mutable or not. It is
/// attached on an attribute if the corresponding ImplType defines a `mutate`
/// function with proper signature.
namespace AttributeTrait {
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
} // namespace AttributeTrait

} // namespace mlir.

namespace llvm {
Expand Down
14 changes: 14 additions & 0 deletions mlir/include/mlir/IR/StorageUniquerSupport.h
Expand Up @@ -53,6 +53,16 @@ class StorageUserTraitBase {
}
};

namespace StorageUserTrait {
/// This trait is used to determine if a storage user, like Type, is mutable
/// or not. A storage user is mutable if ImplType of the derived class defines
/// a `mutate` function with a proper signature. Note that this trait is not
/// supposed to be used publicly. Users should use alias names like
/// `TypeTrait::IsMutable` instead.
template <typename ConcreteType>
struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
} // namespace StorageUserTrait

//===----------------------------------------------------------------------===//
// StorageUserBase
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -173,6 +183,10 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
template <typename... Args> LogicalResult mutate(Args &&...args) {
static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
ConcreteT>::value,
"The `mutate` function expects mutable trait "
"(e.g. TypeTrait::IsMutable) to be attached on parent.");
return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
std::forward<Args>(args)...);
}
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/IR/Types.h
Expand Up @@ -222,6 +222,18 @@ class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
friend InterfaceBase;
};

//===----------------------------------------------------------------------===//
// Core TypeTrait
//===----------------------------------------------------------------------===//

/// This trait is used to determine if a type is mutable or not. It is attached
/// on a type if the corresponding ImplType defines a `mutate` function with
/// a proper signature.
namespace TypeTrait {
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
} // namespace TypeTrait

//===----------------------------------------------------------------------===//
// Type Utils
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 31 additions & 5 deletions mlir/lib/IR/SubElementInterfaces.cpp
Expand Up @@ -8,22 +8,34 @@

#include "mlir/IR/SubElementInterfaces.h"

#include "llvm/ADT/DenseSet.h"

using namespace mlir;

template <typename InterfaceT>
static void walkSubElementsImpl(InterfaceT interface,
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
function_ref<void(Type)> walkTypesFn,
DenseSet<Attribute> &visitedAttrs,
DenseSet<Type> &visitedTypes) {
interface.walkImmediateSubElements(
[&](Attribute attr) {
// Guard against potentially null inputs. This removes the need for the
// derived attribute/type to do it.
if (!attr)
return;

// Avoid infinite recursion when visiting sub attributes later, if this
// is a mutable attribute.
if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
if (!visitedAttrs.insert(attr).second)
return;
}

// Walk any sub elements first.
if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);

// Walk this attribute.
walkAttrsFn(attr);
Expand All @@ -34,9 +46,17 @@ static void walkSubElementsImpl(InterfaceT interface,
if (!type)
return;

// Avoid infinite recursion when visiting sub types later, if this
// is a mutable type.
if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
if (!visitedTypes.insert(type).second)
return;
}

// Walk any sub elements first.
if (auto interface = type.dyn_cast<SubElementTypeInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);

// Walk this type.
walkTypesFn(type);
Expand All @@ -47,14 +67,20 @@ void SubElementAttrInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
DenseSet<Attribute> visitedAttrs;
DenseSet<Type> visitedTypes;
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
}

void SubElementTypeInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
DenseSet<Attribute> visitedAttrs;
DenseSet<Type> visitedTypes;
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
visitedTypes);
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/IR/recursive-type.mlir
@@ -1,11 +1,17 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
// CHECK: !test.test_rec<c, test_rec<c>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
// Make sure walkSubElementType, which is used to generate aliases, doesn't go
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
return
}

Expand Down
7 changes: 7 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Expand Up @@ -160,6 +160,13 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
if (auto recType = type.dyn_cast<TestRecursiveType>()) {
if (recType.getName() == "type_to_alias") {
// We only make alias for a specific recursive type.
os << "testrec";
return AliasResult::FinalAlias;
}
}
return AliasResult::NoAlias;
}

Expand Down
13 changes: 11 additions & 2 deletions mlir/test/lib/Dialect/Test/TestTypes.h
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SubElementInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

Expand Down Expand Up @@ -130,7 +131,9 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
/// from type creation.
class TestRecursiveType
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
TestRecursiveTypeStorage> {
TestRecursiveTypeStorage,
::mlir::SubElementTypeInterface::Trait,
::mlir::TypeTrait::IsMutable> {
public:
using Base::Base;

Expand All @@ -141,10 +144,16 @@ class TestRecursiveType

/// Body getter and setter.
::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
::mlir::Type getBody() { return getImpl()->body; }
::mlir::Type getBody() const { return getImpl()->body; }

/// Name/key getter.
::llvm::StringRef getName() { return getImpl()->name; }

void walkImmediateSubElements(
::llvm::function_ref<void(::mlir::Attribute)> walkAttrsFn,
::llvm::function_ref<void(::mlir::Type)> walkTypesFn) const {
walkTypesFn(getBody());
}
};

} // namespace test
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/Dialect/CMakeLists.txt
Expand Up @@ -7,8 +7,8 @@ target_link_libraries(MLIRDialectTests
MLIRDialect)

add_subdirectory(Affine)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)

add_subdirectory(Quant)
add_subdirectory(SparseTensor)
add_subdirectory(SPIRV)
Expand Down
7 changes: 7 additions & 0 deletions mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -0,0 +1,7 @@
add_mlir_unittest(MLIRLLVMIRTests
LLVMTypeTest.cpp
)
target_link_libraries(MLIRLLVMIRTests
PRIVATE
MLIRLLVMDialect
)
27 changes: 27 additions & 0 deletions mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
@@ -0,0 +1,27 @@
//===- LLVMTestBase.h - Test fixure for LLVM dialect tests ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test fixure for LLVM dialect tests.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

class LLVMIRTest : public ::testing::Test {
protected:
LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }

mlir::MLIRContext context;
};

#endif
20 changes: 20 additions & 0 deletions mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -0,0 +1,20 @@
//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LLVMTestBase.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/SubElementInterfaces.h"

using namespace mlir;
using namespace mlir::LLVM;

TEST_F(LLVMIRTest, IsStructTypeMutable) {
auto structTy = LLVMStructType::getIdentified(&context, "foo");
ASSERT_TRUE(bool(structTy));
ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
}

0 comments on commit d410286

Please sign in to comment.