
Conversation

joker-eph (Collaborator) commented Sep 11, 2025

The current implementation is overly conservative and disables all possible caching as soon as a context-aware conversion is present. However, a context-aware conversion only affects the converters that run after it in the lookup, so the results of those that run before it can still be cached.

This isn't NFC because it fixes a bug where we used to unconditionally cache when using the convertType(Type t, ...) API; now all APIs are aware of context-aware conversions.
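
To make the registration-order recommendation concrete, here is a minimal sketch (not taken from the patch) of how a converter might be populated so the context-free rules keep benefiting from caching. The Value-based addConversion signature is assumed from the wrapCallback specialization in this diff, and the populateConverter helper, the "widen" attribute, and the widening policy are purely hypothetical.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hedged sketch: register the context-aware rule FIRST so that every
// context-free rule added afterwards can still have its results cached.
static void populateConverter(TypeConverter &converter) {
  // Context-aware rule: depends on the producing op, so its results can
  // never be cached. Signature assumed from the Value-based wrapCallback.
  converter.addConversion(
      [](Value v, SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
        auto intTy = dyn_cast<IntegerType>(v.getType());
        if (!intTy)
          return std::nullopt; // not handled here, try the next converter
        Operation *def = v.getDefiningOp();
        if (def && def->hasAttr("widen")) // hypothetical attribute
          results.push_back(IntegerType::get(intTy.getContext(), 64));
        else
          results.push_back(intTy);
        return success();
      });

  // Context-free rule, registered AFTER the context-aware one. Converters
  // are tried most-recently-registered first, so this one runs before the
  // rule above and, with this patch, its results remain cacheable. It
  // declines integer types so the context-aware rule gets a chance to see them.
  converter.addConversion([](Type t) -> std::optional<Type> {
    if (isa<IntegerType>(t))
      return std::nullopt; // defer to the context-aware rule
    return t;              // identity conversion for everything else
  });
}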

llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir labels Sep 11, 2025
llvmbot (Member) commented Sep 11, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Changes

The current implementation is overly conservative and disables all possible caching as soon as a context-aware conversion is present. However, a context-aware conversion only affects the converters that run after it in the lookup, so the results of those that run before it can still be cached.


Full diff: https://github.com/llvm/llvm-project/pull/158072.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+12-5)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+57-38)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6949f4a14fdba..1b9f7a76fc579 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -433,7 +433,7 @@ class TypeConverter {
                        std::is_same_v<T, Value>,
                    ConversionCallbackFn>
   wrapCallback(FnT &&callback) {
-    hasContextAwareTypeConversions = true;
+    contextAwareTypeConversionsIndex = conversions.size();
     return [callback = std::forward<FnT>(callback)](
                PointerUnion<Type, Value> typeOrValue,
                SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -555,6 +555,10 @@ class TypeConverter {
     cachedMultiConversions.clear();
   }
 
+  /// Internal implementation of the type conversion.
+  template <typename T>
+  LogicalResult convertTypeImpl(T t, SmallVectorImpl<Type> &results) const;
+
   /// The set of registered conversion functions.
   SmallVector<ConversionCallbackFn, 4> conversions;
 
@@ -575,10 +579,13 @@ class TypeConverter {
   mutable llvm::sys::SmartRWMutex<true> cacheMutex;
   /// Whether the type converter has context-aware type conversions. I.e.,
   /// conversion rules that depend on the SSA value instead of just the type.
-  /// Type conversion caching is deactivated when there are context-aware
-  /// conversions because the type converter may return different results for
-  /// the same input type.
-  bool hasContextAwareTypeConversions = false;
+  /// We store here the index in the `conversions` vector of the last added
+  /// context-aware conversion, if any. This is useful because we can't cache
+  /// the result of type conversion happening after context-aware conversions,
+  /// because the type converter may return different results for the same input
+  /// type. This is why it is recommended to add context-aware conversions first,
+  /// any context-free conversions added after will benefit from caching.
+  int contextAwareTypeConversionsIndex = -1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..8f36a653e3a17 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3406,22 +3406,37 @@ void TypeConverter::SignatureConversion::remapInput(
       SmallVector<Value, 1>(replacements.begin(), replacements.end())};
 }
 
-LogicalResult TypeConverter::convertType(Type t,
-                                         SmallVectorImpl<Type> &results) const {
+/// Internal implementation of the type conversion.
+/// This is used with either a Type or a Value as the first argument.
+/// When using a value, the caching behavior is different:
+/// - we can cache the context-free conversions until the last registered
+/// context-aware conversion.
+/// - we can't cache the result of type conversion happening after context-aware
+/// conversions,
+///   because the type converter may return different results for the same input
+///   type.
+template <typename T>
+LogicalResult
+TypeConverter::convertTypeImpl(T t, SmallVectorImpl<Type> &results) const {
   assert(t && "expected non-null type");
-
+  auto getType = [&](auto typeOrValue) {
+    if constexpr (std::is_same_v<decltype(typeOrValue), Type>)
+      return typeOrValue;
+    else
+      return typeOrValue.getType();
+  };
   {
     std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
                                                          std::defer_lock);
     if (t.getContext()->isMultithreadingEnabled())
       cacheReadLock.lock();
-    auto existingIt = cachedDirectConversions.find(t);
+    auto existingIt = cachedDirectConversions.find(getType(t));
     if (existingIt != cachedDirectConversions.end()) {
       if (existingIt->second)
         results.push_back(existingIt->second);
       return success(existingIt->second != nullptr);
     }
-    auto multiIt = cachedMultiConversions.find(t);
+    auto multiIt = cachedMultiConversions.find(getType(t));
     if (multiIt != cachedMultiConversions.end()) {
       results.append(multiIt->second.begin(), multiIt->second.end());
       return success();
@@ -3431,52 +3446,56 @@ LogicalResult TypeConverter::convertType(Type t,
   // registered first.
   size_t currentCount = results.size();
 
+  // We can cache the context-free conversions until the last registered
+  // context-aware conversion. But only if we're processing a Value right now.
+  auto isCacheable = [&](int index) {
+    if constexpr (std::is_same_v<T, Type>)
+      return true;
+    int numberOfConversionsUntilContextAware =
+        conversions.size() - 1 - contextAwareTypeConversionsIndex;
+    return index < numberOfConversionsUntilContextAware;
+  };
+
   std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
                                                         std::defer_lock);
 
-  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
-    if (std::optional<LogicalResult> result = converter(t, results)) {
-      if (t.getContext()->isMultithreadingEnabled())
-        cacheWriteLock.lock();
-      if (!succeeded(*result)) {
-        assert(results.size() == currentCount &&
-               "failed type conversion should not change results");
-        cachedDirectConversions.try_emplace(t, nullptr);
-        return failure();
-      }
-      auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
-      if (newTypes.size() == 1)
-        cachedDirectConversions.try_emplace(t, newTypes.front());
-      else
-        cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+  for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
+    const ConversionCallbackFn &converter = indexedConverter.value();
+    std::optional<LogicalResult> result = converter(t, results);
+    if (!result) {
+      assert(results.size() == currentCount &&
+             "failed type conversion should not change results");
+      continue;
+    }
+    if (!isCacheable(indexedConverter.index()))
       return success();
-    } else {
+    if (t.getContext()->isMultithreadingEnabled())
+      cacheWriteLock.lock();
+    if (!succeeded(*result)) {
       assert(results.size() == currentCount &&
              "failed type conversion should not change results");
+      cachedDirectConversions.try_emplace(getType(t), nullptr);
+      return failure();
     }
+    auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
+    if (newTypes.size() == 1)
+      cachedDirectConversions.try_emplace(getType(t), newTypes.front());
+    else
+      cachedMultiConversions.try_emplace(getType(t),
+                                         llvm::to_vector<2>(newTypes));
+    return success();
   }
   return failure();
 }
 
-LogicalResult TypeConverter::convertType(Value v,
+LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
-  assert(v && "expected non-null value");
-
-  // If this type converter does not have context-aware type conversions, call
-  // the type-based overload, which has caching.
-  if (!hasContextAwareTypeConversions)
-    return convertType(v.getType(), results);
+  return convertTypeImpl(t, results);
+}
 
-  // Walk the added converters in reverse order to apply the most recently
-  // registered first.
-  for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
-    if (std::optional<LogicalResult> result = converter(v, results)) {
-      if (!succeeded(*result))
-        return failure();
-      return success();
-    }
-  }
-  return failure();
+LogicalResult TypeConverter::convertType(Value v,
+                                         SmallVectorImpl<Type> &results) const {
+  return convertTypeImpl(v, results);
 }
 
 Type TypeConverter::convertType(Type t) const {

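For readers tracing the new isCacheable lambda, here is a small self-contained illustration of the index arithmetic on the Value-based path; the registration counts below are hypothetical, not taken from any real converter.

#include <cassert>

int main() {
  // Hypothetical registration order: [cf0, contextAware, cf2, cf3]
  // (cf = context-free), so contextAwareTypeConversionsIndex == 1.
  const int numConversions = 4;
  const int contextAwareTypeConversionsIndex = 1;

  // convertTypeImpl walks the list in reverse; mirror its cacheability test
  // as applied when converting a Value (for a Type everything is cacheable).
  auto isCacheableForValue = [&](int reversedIndex) {
    int numberOfConversionsUntilContextAware =
        numConversions - 1 - contextAwareTypeConversionsIndex;
    return reversedIndex < numberOfConversionsUntilContextAware;
  };

  assert(isCacheableForValue(0));  // cf3, registered after the context-aware rule
  assert(isCacheableForValue(1));  // cf2, registered after the context-aware rule
  assert(!isCacheableForValue(2)); // the context-aware rule itself
  assert(!isCacheableForValue(3)); // cf0, registered before it
  return 0;
}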
joker-eph force-pushed the caching_conversions branch 2 times, most recently from 8438c7d to c3a5930 on September 11, 2025 14:04
…ware conversion

The current implementation is overly conservative and disables all
possible caching as soon as a context-aware conversion is present.
However, a context-aware conversion only affects the converters that
run after it in the lookup, so the results of those that run before it
can still be cached.
joker-eph enabled auto-merge (squash) September 11, 2025 14:10
joker-eph merged commit b22f94d into llvm:main Sep 11, 2025
9 checks passed
joker-eph deleted the caching_conversions branch September 11, 2025 15:25
j2kun (Contributor) commented Sep 11, 2025

Thanks for this!
