Skip to content

Commit

Permalink
[mlir][LLVMIR] Memorize compatible LLVM types
Browse files Browse the repository at this point in the history
This patch memorize compatible LLVM types in `LLVM::isCompatibleType` in
order to avoid redundant works.

This is especially useful when the size of program is big and there are
multiple occurrences of some deeply nested LLVM struct types, in which
case we can gain quite some speedups with this patch.

Differential Revision: https://reviews.llvm.org/D127918
  • Loading branch information
mshockwave committed Jun 27, 2022
1 parent 856056d commit fc7f726
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 50 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/ThreadLocalCache.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Expand Up @@ -61,6 +61,15 @@ def LLVM_Dialect : Dialect {
static StringRef getEmitCWrapperAttrName() {
return "llvm.emit_c_interface";
}

/// Returns `true` if the given type is compatible with the LLVM dialect.
static bool isCompatibleType(Type);

private:
/// A cache storing compatible LLVM types that have been verified. This
/// can save us lots of verification time if there are many occurrences
/// of some deeply-nested aggregate types in the program.
ThreadLocalCache<DenseSet<Type>> compatibleTypes;
}];

let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Expand Up @@ -451,7 +451,8 @@ void printType(Type type, AsmPrinter &printer);
// Utility functions.
//===----------------------------------------------------------------------===//

/// Returns `true` if the given type is compatible with the LLVM dialect.
/// Returns `true` if the given type is compatible with the LLVM dialect. This
/// is an alias to `LLVMDialect::isCompatibleType`.
bool isCompatibleType(Type type);

/// Returns `true` if the given outer type is compatible with the LLVM dialect
Expand Down
110 changes: 61 additions & 49 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Expand Up @@ -721,63 +721,75 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
return false;
}

static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
if (callstack.contains(type))
static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
if (!compatibleTypes.insert(type).second)
return true;

callstack.insert(type);
auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });

auto isCompatible = [&](Type type) {
return isCompatibleImpl(type, callstack);
return isCompatibleImpl(type, compatibleTypes);
};

return llvm::TypeSwitch<Type, bool>(type)
.Case<LLVMStructType>([&](auto structType) {
return llvm::all_of(structType.getBody(), isCompatible);
})
.Case<LLVMFunctionType>([&](auto funcType) {
return isCompatible(funcType.getReturnType()) &&
llvm::all_of(funcType.getParams(), isCompatible);
})
.Case<IntegerType>([](auto intType) { return intType.isSignless(); })
.Case<VectorType>([&](auto vecType) {
return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
})
.Case<LLVMPointerType>([&](auto pointerType) {
if (pointerType.isOpaque())
return true;
return isCompatible(pointerType.getElementType());
})
// clang-format off
.Case<
LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMArrayType
>([&](auto containerType) {
return isCompatible(containerType.getElementType());
})
.Case<
BFloat16Type,
Float16Type,
Float32Type,
Float64Type,
Float80Type,
Float128Type,
LLVMLabelType,
LLVMMetadataType,
LLVMPPCFP128Type,
LLVMTokenType,
LLVMVoidType,
LLVMX86MMXType
>([](Type) { return true; })
// clang-format on
.Default([](Type) { return false; });
bool result =
llvm::TypeSwitch<Type, bool>(type)
.Case<LLVMStructType>([&](auto structType) {
return llvm::all_of(structType.getBody(), isCompatible);
})
.Case<LLVMFunctionType>([&](auto funcType) {
return isCompatible(funcType.getReturnType()) &&
llvm::all_of(funcType.getParams(), isCompatible);
})
.Case<IntegerType>([](auto intType) { return intType.isSignless(); })
.Case<VectorType>([&](auto vecType) {
return vecType.getRank() == 1 &&
isCompatible(vecType.getElementType());
})
.Case<LLVMPointerType>([&](auto pointerType) {
if (pointerType.isOpaque())
return true;
return isCompatible(pointerType.getElementType());
})
// clang-format off
.Case<
LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMArrayType
>([&](auto containerType) {
return isCompatible(containerType.getElementType());
})
.Case<
BFloat16Type,
Float16Type,
Float32Type,
Float64Type,
Float80Type,
Float128Type,
LLVMLabelType,
LLVMMetadataType,
LLVMPPCFP128Type,
LLVMTokenType,
LLVMVoidType,
LLVMX86MMXType
>([](Type) { return true; })
// clang-format on
.Default([](Type) { return false; });

if (!result)
compatibleTypes.erase(type);

return result;
}

bool LLVMDialect::isCompatibleType(Type type) {
if (auto *llvmDialect =
type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());

DenseSet<Type> localCompatibleTypes;
return isCompatibleImpl(type, localCompatibleTypes);
}

bool mlir::LLVM::isCompatibleType(Type type) {
SetVector<Type> callstack;
return isCompatibleImpl(type, callstack);
return LLVMDialect::isCompatibleType(type);
}

bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
Expand Down

0 comments on commit fc7f726

Please sign in to comment.