-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR] Enable caching of type conversion in the presence of context-aware conversion #158072
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Mehdi Amini (joker-eph) ChangesThe current implementation is overly conservative and disable all possible caching as soon as a context-aware conversion is present. However the context-aware conversion only affects subsequent converters, we can cache the previous ones. Full diff: https://github.com/llvm/llvm-project/pull/158072.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6949f4a14fdba..1b9f7a76fc579 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -433,7 +433,7 @@ class TypeConverter {
std::is_same_v<T, Value>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
- hasContextAwareTypeConversions = true;
+ contextAwareTypeConversionsIndex = conversions.size();
return [callback = std::forward<FnT>(callback)](
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -555,6 +555,10 @@ class TypeConverter {
cachedMultiConversions.clear();
}
+ /// Internal implementation of the type conversion.
+ template <typename T>
+ LogicalResult convertTypeImpl(T t, SmallVectorImpl<Type> &results) const;
+
/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;
@@ -575,10 +579,13 @@ class TypeConverter {
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
/// Whether the type converter has context-aware type conversions. I.e.,
/// conversion rules that depend on the SSA value instead of just the type.
- /// Type conversion caching is deactivated when there are context-aware
- /// conversions because the type converter may return different results for
- /// the same input type.
- bool hasContextAwareTypeConversions = false;
+ /// We store here the index in the `conversions` vector of the last added
+ /// context-aware conversion, if any. This is useful because we can't cache
+ /// the result of type conversion happening after context-aware conversions,
+ /// because the type converter may return different results for the same input
+ /// type. This is why it is recommened to add context-aware conversions first,
+ /// any context-free conversions after will benefit from caching.
+ int contextAwareTypeConversionsIndex = -1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..8f36a653e3a17 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3406,22 +3406,37 @@ void TypeConverter::SignatureConversion::remapInput(
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
}
-LogicalResult TypeConverter::convertType(Type t,
- SmallVectorImpl<Type> &results) const {
+/// Internal implementation of the type conversion.
+/// This is used with either a Type or a Value as the first argument.
+/// When using a value, the caching behavior is different:
+/// - we can cache the context-free conversions until the last registered
+/// context-aware conversion.
+/// - we can't cache the result of type conversion happening after context-aware
+/// conversions,
+/// because the type converter may return different results for the same input
+/// type.
+template <typename T>
+LogicalResult
+TypeConverter::convertTypeImpl(T t, SmallVectorImpl<Type> &results) const {
assert(t && "expected non-null type");
-
+ auto getType = [&](auto typeOrValue) {
+ if constexpr (std::is_same_v<decltype(typeOrValue), Type>)
+ return typeOrValue;
+ else
+ return typeOrValue.getType();
+ };
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
if (t.getContext()->isMultithreadingEnabled())
cacheReadLock.lock();
- auto existingIt = cachedDirectConversions.find(t);
+ auto existingIt = cachedDirectConversions.find(getType(t));
if (existingIt != cachedDirectConversions.end()) {
if (existingIt->second)
results.push_back(existingIt->second);
return success(existingIt->second != nullptr);
}
- auto multiIt = cachedMultiConversions.find(t);
+ auto multiIt = cachedMultiConversions.find(getType(t));
if (multiIt != cachedMultiConversions.end()) {
results.append(multiIt->second.begin(), multiIt->second.end());
return success();
@@ -3431,52 +3446,56 @@ LogicalResult TypeConverter::convertType(Type t,
// registered first.
size_t currentCount = results.size();
+ // We can cache the context-free conversions until the last registered
+ // context-aware conversion. But only if we're processing a Value right now.
+ auto isCacheable = [&](int index) {
+ if constexpr (std::is_same_v<T, Type>)
+ return true;
+ int numberOfConversionsUntilContextAware =
+ conversions.size() - 1 - contextAwareTypeConversionsIndex;
+ return index < numberOfConversionsUntilContextAware;
+ };
+
std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
std::defer_lock);
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(t, results)) {
- if (t.getContext()->isMultithreadingEnabled())
- cacheWriteLock.lock();
- if (!succeeded(*result)) {
- assert(results.size() == currentCount &&
- "failed type conversion should not change results");
- cachedDirectConversions.try_emplace(t, nullptr);
- return failure();
- }
- auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
- if (newTypes.size() == 1)
- cachedDirectConversions.try_emplace(t, newTypes.front());
- else
- cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+ for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
+ const ConversionCallbackFn &converter = indexedConverter.value();
+ std::optional<LogicalResult> result = converter(t, results);
+ if (!result) {
+ assert(results.size() == currentCount &&
+ "failed type conversion should not change results");
+ continue;
+ }
+ if (!isCacheable(indexedConverter.index()))
return success();
- } else {
+ if (t.getContext()->isMultithreadingEnabled())
+ cacheWriteLock.lock();
+ if (!succeeded(*result)) {
assert(results.size() == currentCount &&
"failed type conversion should not change results");
+ cachedDirectConversions.try_emplace(getType(t), nullptr);
+ return failure();
}
+ auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
+ if (newTypes.size() == 1)
+ cachedDirectConversions.try_emplace(getType(t), newTypes.front());
+ else
+ cachedMultiConversions.try_emplace(getType(t),
+ llvm::to_vector<2>(newTypes));
+ return success();
}
return failure();
}
-LogicalResult TypeConverter::convertType(Value v,
+LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
- assert(v && "expected non-null value");
-
- // If this type converter does not have context-aware type conversions, call
- // the type-based overload, which has caching.
- if (!hasContextAwareTypeConversions)
- return convertType(v.getType(), results);
+ return convertTypeImpl(t, results);
+}
- // Walk the added converters in reverse order to apply the most recently
- // registered first.
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(v, results)) {
- if (!succeeded(*result))
- return failure();
- return success();
- }
- }
- return failure();
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ return convertTypeImpl(v, results);
}
Type TypeConverter::convertType(Type t) const {
|
8438c7d
to
c3a5930
Compare
…ware conversion The current implementation is overly conservative and disable all possible caching as soon as a context-aware conversion is present. However the context-aware conversion only affects subsequent converters, we can cache the previous ones.
c3a5930
to
f8f81d2
Compare
Thanks for this! |
The current implementation is overly conservative and disable all possible caching as soon as a context-aware conversion is present. However the context-aware conversion only affects subsequent converters, we can cache the previous ones.
This isn't NFC because if fixed a bug where we use to unconditionally cache when using the
convertType(Type t, ...
API, while now all APIs are aware of context-aware conversions.