Skip to content

Commit

Permalink
Remove the conversionCallStack from the MLIR TypeConverter
Browse files Browse the repository at this point in the history
This vector keeps tracks of recursive types through the recursive invocations
of `convertType()`. However this is something only useful for some specific
cases, in which the dedicated conversion callbacks can handle this stack
privately.

This allows removing a mutable member of the type converter.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158351
  • Loading branch information
joker-eph committed Aug 27, 2023
1 parent 3823395 commit dc3dc97
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 54 deletions.
3 changes: 1 addition & 2 deletions flang/include/flang/Optimizer/CodeGen/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<mlir::LogicalResult>
convertRecordType(fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results,
llvm::ArrayRef<mlir::Type> callStack) const;
llvm::SmallVectorImpl<mlir::Type> &results);

// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
Expand Down
21 changes: 13 additions & 8 deletions flang/lib/Optimizer/CodeGen/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"

namespace fir {
Expand Down Expand Up @@ -81,11 +82,10 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
});
addConversion(
[&](fir::PointerType pointer) { return convertPointerLike(pointer); });
addConversion([&](fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results,
llvm::ArrayRef<mlir::Type> callStack) {
return convertRecordType(derived, results, callStack);
});
addConversion(
[&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
return convertRecordType(derived, results);
});
addConversion(
[&](fir::RealType real) { return convertRealType(real.getFKind()); });
addConversion(
Expand Down Expand Up @@ -167,14 +167,19 @@ mlir::Type LLVMTypeConverter::indexType() const {

// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<mlir::LogicalResult> LLVMTypeConverter::convertRecordType(
fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results,
llvm::ArrayRef<mlir::Type> callStack) const {
fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
auto name = derived.getName();
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
if (llvm::count(callStack, derived) > 1) {

auto &callStack = getCurrentThreadRecursiveStack();
if (llvm::count(callStack, derived)) {
results.push_back(st);
return mlir::success();
}
callStack.push_back(derived);
auto popConversionCallStack =
llvm::make_scope_exit([&callStack]() { callStack.pop_back(); });

llvm::SmallVector<mlir::Type> members;
for (auto mem : derived.getTypeList()) {
// Prevent fir.box from degenerating to a pointer to a descriptor in the
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ class LLVMTypeConverter : public TypeConverter {
/// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;

// Recursive structure detection.
// We store one entry per thread here, and rely on locking.
DenseMap<uint64_t, std::unique_ptr<SmallVector<Type>>> conversionCallStack;
llvm::sys::SmartRWMutex<true> callStackMutex;
SmallVector<Type> &getCurrentThreadRecursiveStack();

private:
/// Convert a function type. The arguments and results are converted one by
/// one. Additionally, if the function returns more than one value, pack the
Expand Down
48 changes: 15 additions & 33 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
Type, SmallVectorImpl<Type> &)>;

/// The signature of the callback used to materialize a conversion.
using MaterializationCallbackFn = std::function<std::optional<Value>(
Expand All @@ -330,44 +330,30 @@ class TypeConverter {
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
if (std::optional<Type> resultOpt = callback(type)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
return std::optional<LogicalResult>(success(wasSuccess));
}
return std::optional<LogicalResult>();
});
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results) {
if (std::optional<Type> resultOpt = callback(type)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
return std::optional<LogicalResult>(success(wasSuccess));
}
return std::optional<LogicalResult>();
});
}
/// With callback of form: `std::optional<LogicalResult>(
/// T, SmallVectorImpl<Type> &)`.
/// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
return callback(type, results);
});
}
/// With callback of form: `std::optional<LogicalResult>(
/// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
Type type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
Type type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
T derivedType = dyn_cast<T>(type);
if (!derivedType)
return std::nullopt;
return callback(derivedType, results, callStack);
return callback(derivedType, results);
};
}

Expand Down Expand Up @@ -435,10 +421,6 @@ class TypeConverter {
mutable DenseMap<Type, Type> cachedDirectConversions;
/// This cache stores the successful 1->N conversions, where N != 1.
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;

/// Stores the types that are being converted in the case when convertType
/// is being called recursively to convert nested types.
mutable SmallVector<Type, 2> conversionCallStack;
};

//===----------------------------------------------------------------------===//
Expand Down
36 changes: 33 additions & 3 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,34 @@
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Threading.h"
#include <memory>
#include <mutex>
#include <optional>

using namespace mlir;

SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
{
// Most of the time, the entry already exists in the map.
std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
std::defer_lock);
if (getContext().isMultithreadingEnabled())
lock.lock();
auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
if (recursiveStack != conversionCallStack.end())
return *recursiveStack->second;
}

// First time this thread gets here, we have to get an exclusive access to
// inset in the map
std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
return *recursiveStackInserted.first->second.get();
}

/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const DataLayoutAnalysis *analysis)
Expand Down Expand Up @@ -56,8 +80,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
return std::nullopt;
});
addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> std::optional<LogicalResult> {

addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
-> std::optional<LogicalResult> {
// Fastpath for types that won't be converted by this callback anyway.
if (LLVM::isCompatibleType(type)) {
results.push_back(type);
Expand All @@ -75,10 +100,15 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.getContext(),
("_Converted_" + std::to_string(counter) + type.getName()).str());
}
if (llvm::count(callStack, type) > 1) {

SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
if (llvm::count(recursiveStack, type)) {
results.push_back(convertedType);
return success();
}
recursiveStack.push_back(type);
auto popConversionCallStack = llvm::make_scope_exit(
[&recursiveStack]() { recursiveStack.pop_back(); });

SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(type.getBody().size());
Expand Down
7 changes: 2 additions & 5 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2935,12 +2935,9 @@ LogicalResult TypeConverter::convertType(Type t,
// Walk the added converters in reverse order to apply the most recently
// registered first.
size_t currentCount = results.size();
conversionCallStack.push_back(t);
auto popConversionCallStack =
llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });

for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
if (std::optional<LogicalResult> result =
converter(t, results, conversionCallStack)) {
if (std::optional<LogicalResult> result = converter(t, results)) {
if (!succeeded(*result)) {
cachedDirectConversions.try_emplace(t, nullptr);
return failure();
Expand Down
13 changes: 10 additions & 3 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"

using namespace mlir;
using namespace test;
Expand Down Expand Up @@ -1374,6 +1375,7 @@ struct TestTypeConversionDriver

void runOnOperation() override {
// Initialize the type converter.
SmallVector<Type, 2> conversionCallStack;
TypeConverter converter;

/// Add the legal set of type conversions.
Expand All @@ -1394,20 +1396,25 @@ struct TestTypeConversionDriver
converter.addConversion(
// Convert a recursive self-referring type into a non-self-referring
// type named "outer_converted_type" that contains a SimpleAType.
[&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
[&](test::TestRecursiveType type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
// If the type is already converted, return it to indicate that it is
// legal.
if (type.getName() == "outer_converted_type") {
results.push_back(type);
return success();
}

conversionCallStack.push_back(type);
auto popConversionCallStack = llvm::make_scope_exit(
[&conversionCallStack]() { conversionCallStack.pop_back(); });

// If the type is on the call stack more than once (it is there at
// least once because of the _current_ call, which is always the last
// element on the stack), we've hit the recursive case. Just return
// SimpleAType here to create a non-recursive type as a result.
if (llvm::is_contained(callStack.drop_back(), type)) {
if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
type)) {
results.push_back(test::SimpleAType::get(type.getContext()));
return success();
}
Expand Down

0 comments on commit dc3dc97

Please sign in to comment.