Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a
`Type`. A context-aware conversion function converts a `Value` into a type. The
latter allows users to customize type conversion rules based on the IR.

Note: When there is at least one context-aware type conversion function, the
result of type conversions can no longer be cached, which can increase
compilation time. Use this feature with caution!
Note: context-aware type conversion functions impact the ability of the
framework to cache the conversion result. In the absence of a context-aware
conversion, all context-free type conversions can be cached. Otherwise only the
context-free conversions added after a context-aware type conversion can be
cached (conversions are applied in reverse order).
As such it is advised to add context-aware conversions as early as possible in
the sequence of `addConversion` calls (so that they apply last).

A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
Expand Down
17 changes: 12 additions & 5 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ class TypeConverter {
std::is_same_v<T, Value>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
hasContextAwareTypeConversions = true;
contextAwareTypeConversionsIndex = conversions.size();
return [callback = std::forward<FnT>(callback)](
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
Expand Down Expand Up @@ -555,6 +555,10 @@ class TypeConverter {
cachedMultiConversions.clear();
}

/// Internal implementation of the type conversion.
LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
SmallVectorImpl<Type> &results) const;

/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;

Expand All @@ -575,10 +579,13 @@ class TypeConverter {
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
/// Whether the type converter has context-aware type conversions. I.e.,
/// conversion rules that depend on the SSA value instead of just the type.
/// Type conversion caching is deactivated when there are context-aware
/// conversions because the type converter may return different results for
/// the same input type.
bool hasContextAwareTypeConversions = false;
/// We store here the index in the `conversions` vector of the last added
/// context-aware conversion, if any. This is useful because we can't cache
/// the result of type conversion happening after context-aware conversions,
/// because the type converter may return different results for the same input
/// type. This is why it is recommened to add context-aware conversions first,
/// any context-free conversions after will benefit from caching.
int contextAwareTypeConversionsIndex = -1;
};

//===----------------------------------------------------------------------===//
Expand Down
84 changes: 47 additions & 37 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput(
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
}

LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
assert(t && "expected non-null type");

/// Internal implementation of the type conversion.
/// This is used with either a Type or a Value as the first argument.
/// - we can cache the context-free conversions until the last registered
/// context-aware conversion.
/// - we can't cache the result of type conversion happening after context-aware
/// conversions, because the type converter may return different results for the
/// same input type.
LogicalResult
TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) const {
assert(typeOrValue && "expected non-null type");
Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
: cast<Type>(typeOrValue);
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
Expand All @@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t,
// registered first.
size_t currentCount = results.size();

// We can cache the context-free conversions until the last registered
// context-aware conversion. But only if we're processing a Value right now.
auto isCacheable = [&](int index) {
int numberOfConversionsUntilContextAware =
conversions.size() - 1 - contextAwareTypeConversionsIndex;
return index < numberOfConversionsUntilContextAware;
};

std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
std::defer_lock);

for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
if (std::optional<LogicalResult> result = converter(t, results)) {
if (t.getContext()->isMultithreadingEnabled())
cacheWriteLock.lock();
if (!succeeded(*result)) {
assert(results.size() == currentCount &&
"failed type conversion should not change results");
cachedDirectConversions.try_emplace(t, nullptr);
return failure();
}
auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
if (newTypes.size() == 1)
cachedDirectConversions.try_emplace(t, newTypes.front());
else
cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
const ConversionCallbackFn &converter = indexedConverter.value();
std::optional<LogicalResult> result = converter(typeOrValue, results);
if (!result) {
assert(results.size() == currentCount &&
"failed type conversion should not change results");
continue;
}
if (!isCacheable(indexedConverter.index()))
return success();
} else {
if (t.getContext()->isMultithreadingEnabled())
cacheWriteLock.lock();
if (!succeeded(*result)) {
assert(results.size() == currentCount &&
"failed type conversion should not change results");
cachedDirectConversions.try_emplace(t, nullptr);
return failure();
}
auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
if (newTypes.size() == 1)
cachedDirectConversions.try_emplace(t, newTypes.front());
else
cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
return success();
}
return failure();
}

LogicalResult TypeConverter::convertType(Value v,
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
assert(v && "expected non-null value");

// If this type converter does not have context-aware type conversions, call
// the type-based overload, which has caching.
if (!hasContextAwareTypeConversions)
return convertType(v.getType(), results);
return convertTypeImpl(t, results);
}

// Walk the added converters in reverse order to apply the most recently
// registered first.
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
if (std::optional<LogicalResult> result = converter(v, results)) {
if (!succeeded(*result))
return failure();
return success();
}
}
return failure();
LogicalResult TypeConverter::convertType(Value v,
SmallVectorImpl<Type> &results) const {
return convertTypeImpl(v, results);
}

Type TypeConverter::convertType(Type t) const {
Expand Down