diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 7070351755e7a..5ae35155b4ed5 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a `Type`. A context-aware conversion function converts a `Value` into a type. The latter allows users to customize type conversion rules based on the IR. -Note: When there is at least one context-aware type conversion function, the -result of type conversions can no longer be cached, which can increase -compilation time. Use this feature with caution! +Note: context-aware type conversion functions impact the ability of the +framework to cache the conversion result. In the absence of a context-aware +conversion, all context-free type conversions can be cached. Otherwise only the +context-free conversions added after a context-aware type conversion can be +cached (conversions are applied in reverse order). +As such it is advised to add context-aware conversions as early as possible in +the sequence of `addConversion` calls (so that they apply last). A `materialization` describes how a list of values should be converted to a list of values with specific types. An important distinction from a diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 6949f4a14fdba..a096f82a4cfd8 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -433,7 +433,7 @@ class TypeConverter { std::is_same_v, ConversionCallbackFn> wrapCallback(FnT &&callback) { - hasContextAwareTypeConversions = true; + contextAwareTypeConversionsIndex = conversions.size(); return [callback = std::forward(callback)]( PointerUnion typeOrValue, SmallVectorImpl &results) -> std::optional { @@ -555,6 +555,10 @@ class TypeConverter { cachedMultiConversions.clear(); } + /// Internal implementation of the type conversion. + LogicalResult convertTypeImpl(PointerUnion t, + SmallVectorImpl &results) const; + /// The set of registered conversion functions. SmallVector conversions; @@ -575,10 +579,13 @@ class TypeConverter { mutable llvm::sys::SmartRWMutex cacheMutex; /// Whether the type converter has context-aware type conversions. I.e., /// conversion rules that depend on the SSA value instead of just the type. - /// Type conversion caching is deactivated when there are context-aware - /// conversions because the type converter may return different results for - /// the same input type. - bool hasContextAwareTypeConversions = false; + /// We store here the index in the `conversions` vector of the last added + /// context-aware conversion, if any. This is useful because we can't cache + /// the result of type conversion happening after context-aware conversions, + /// because the type converter may return different results for the same input + /// type. This is why it is recommened to add context-aware conversions first, + /// any context-free conversions after will benefit from caching. + int contextAwareTypeConversionsIndex = -1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 36ee87b533b3b..df9700f11200f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput( SmallVector(replacements.begin(), replacements.end())}; } -LogicalResult TypeConverter::convertType(Type t, - SmallVectorImpl &results) const { - assert(t && "expected non-null type"); - +/// Internal implementation of the type conversion. +/// This is used with either a Type or a Value as the first argument. +/// - we can cache the context-free conversions until the last registered +/// context-aware conversion. +/// - we can't cache the result of type conversion happening after context-aware +/// conversions, because the type converter may return different results for the +/// same input type. +LogicalResult +TypeConverter::convertTypeImpl(PointerUnion typeOrValue, + SmallVectorImpl &results) const { + assert(typeOrValue && "expected non-null type"); + Type t = (isa(typeOrValue)) ? cast(typeOrValue).getType() + : cast(typeOrValue); { std::shared_lock cacheReadLock(cacheMutex, std::defer_lock); @@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t, // registered first. size_t currentCount = results.size(); + // We can cache the context-free conversions until the last registered + // context-aware conversion. But only if we're processing a Value right now. + auto isCacheable = [&](int index) { + int numberOfConversionsUntilContextAware = + conversions.size() - 1 - contextAwareTypeConversionsIndex; + return index < numberOfConversionsUntilContextAware; + }; + std::unique_lock cacheWriteLock(cacheMutex, std::defer_lock); - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional result = converter(t, results)) { - if (t.getContext()->isMultithreadingEnabled()) - cacheWriteLock.lock(); - if (!succeeded(*result)) { - assert(results.size() == currentCount && - "failed type conversion should not change results"); - cachedDirectConversions.try_emplace(t, nullptr); - return failure(); - } - auto newTypes = ArrayRef(results).drop_front(currentCount); - if (newTypes.size() == 1) - cachedDirectConversions.try_emplace(t, newTypes.front()); - else - cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) { + const ConversionCallbackFn &converter = indexedConverter.value(); + std::optional result = converter(typeOrValue, results); + if (!result) { + assert(results.size() == currentCount && + "failed type conversion should not change results"); + continue; + } + if (!isCacheable(indexedConverter.index())) return success(); - } else { + if (t.getContext()->isMultithreadingEnabled()) + cacheWriteLock.lock(); + if (!succeeded(*result)) { assert(results.size() == currentCount && "failed type conversion should not change results"); + cachedDirectConversions.try_emplace(t, nullptr); + return failure(); } + auto newTypes = ArrayRef(results).drop_front(currentCount); + if (newTypes.size() == 1) + cachedDirectConversions.try_emplace(t, newTypes.front()); + else + cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + return success(); } return failure(); } -LogicalResult TypeConverter::convertType(Value v, +LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl &results) const { - assert(v && "expected non-null value"); - - // If this type converter does not have context-aware type conversions, call - // the type-based overload, which has caching. - if (!hasContextAwareTypeConversions) - return convertType(v.getType(), results); + return convertTypeImpl(t, results); +} - // Walk the added converters in reverse order to apply the most recently - // registered first. - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional result = converter(v, results)) { - if (!succeeded(*result)) - return failure(); - return success(); - } - } - return failure(); +LogicalResult TypeConverter::convertType(Value v, + SmallVectorImpl &results) const { + return convertTypeImpl(v, results); } Type TypeConverter::convertType(Type t) const {