Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Dialect conversion: Fix missing source materialization #97903

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value.
/// a signature conversion of a single block argument, to a single SSA value
/// with the old argument type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
"apply_conversion_patterns.scf.scf_to_control_flow",
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
let description = [{
Collects patterns that lower structured control flow ops to unstructured
control flow.
}];

let assemblyFormat = "attr-dict";
}

def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;

def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ class TypeConverter {
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
/// must return a Value of the converted type on success, an `std::nullopt` if
/// must return a Value of the type `T` on success, an `std::nullopt` if
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. It will only be called for (sub)types of `T`.
/// Materialization functions must be provided when a type conversion may
/// persist after the conversion has finished.
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.

/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value.
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
Expand Down
28 changes: 21 additions & 7 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});

// Materialization for memrefs creates descriptor structs from individual
// values constituting them, when descriptors are used, i.e. more than one
// value represents a memref.
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
Expand All @@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// memref descriptor cannot be built just from a bare pointer.
return std::nullopt;
}
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
inputs);
Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
resultType, inputs);
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
Expand All @@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return std::nullopt;
return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
} else {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
}
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
MLIRIR
MLIRLoopLikeInterface
MLIRSCFDialect
MLIRSCFToControlFlow
MLIRSCFTransforms
MLIRSCFUtils
MLIRTransformDialect
Expand Down
13 changes: 11 additions & 2 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"

#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
conversionTarget);
}

void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
populateSCFToControlFlowConversionPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ForallToForOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -261,8 +268,10 @@ loopScheduling(scf::ForOp forOp,
return 1;
};

std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
std::optional<int64_t> ubConstant =
getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> lbConstant =
getConstantIntValue(forOp.getLowerBound());
DenseMap<Operation *, unsigned> opCycles;
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
for (Operation &op : forOp.getBody()->getOperations()) {
Expand Down
87 changes: 44 additions & 43 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
MaterializationKind kind = MaterializationKind::Target,
Type origOutputType = nullptr)
MaterializationKind kind = MaterializationKind::Target)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), origOutputType(origOutputType) {}
converterAndKind(converter, kind) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
Expand All @@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}

/// Return the original illegal output type of the input values.
Type getOrigOutputType() const { return origOutputType; }

private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
converterAndKind;

/// The original output type. This is only used for argument conversions.
Type origOutputType;
};
} // namespace

Expand Down Expand Up @@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Block *insertBlock,
Block::iterator insertPt, Location loc,
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);

Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
Type origOutputType,
Type outputType,
const TypeConverter *converter);

Expand Down Expand Up @@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (replArgs.size() == 1 &&
(!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
mapping.map(origArg, newArg);
} else {
Type origOutputType = origArg.getType();

// Legalize the argument output type.
Type outputType = origOutputType;
if (Type legalOutputType = converter->convertType(outputType))
outputType = legalOutputType;

newArg = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
converter);
// Build argument materialization: new block arguments -> old block
// argument type.
Value argMat = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
mapping.map(origArg, argMat);

// Build target materialization: old block argument type -> legal type.
matthias-springer marked this conversation as resolved.
Show resolved Hide resolved
// Note: This function returns an "empty" type if no valid conversion to
// a legal type exists. In that case, we continue the conversion with the
// original block argument type.
Type legalOutputType = converter->convertType(origArg.getType());
if (legalOutputType && legalOutputType != origArg.getType()) {
newArg = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, newArg);
} else {
newArg = argMat;
}
}

mapping.map(origArg, newArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
Expand All @@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
Location loc, ValueRange inputs, Type outputType, Type origOutputType,
Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
Expand All @@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
OpBuilder builder(insertBlock, insertPt);
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
origOutputType);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
Block *block, Location loc, ValueRange inputs, Type origOutputType,
Type outputType, const TypeConverter *converter) {
Block *block, Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
origOutputType, converter);
converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
Expand All @@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(

return buildUnresolvedMaterialization(MaterializationKind::Target,
insertBlock, insertPt, loc, input,
outputType, outputType, converter);
outputType, converter);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2672,19 +2670,28 @@ static void computeNecessaryMaterializations(
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping,
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
// Helper function to check if the given value or a not yet materialized
// replacement of the given value is live.
// Note: `inverseMapping` maps from replaced values to original values.
auto isLive = [&](Value value) {
auto findFn = [&](Operation *user) {
auto matIt = materializationOps.find(user);
if (matIt != materializationOps.end())
return !necessaryMaterializations.count(matIt->second);
return rewriterImpl.isOpIgnored(user);
};
// This value may be replacing another value that has a live user.
for (Value inv : inverseMapping.lookup(value))
if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
// A worklist is needed because a value may have gone through a chain of
// replacements and each of the replaced values may have live users.
SmallVector<Value> worklist;
worklist.push_back(value);
while (!worklist.empty()) {
Value next = worklist.pop_back_val();
if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
return true;
// Or have live users itself.
return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
// This value may be replacing another value that has a live user.
llvm::append_range(worklist, inverseMapping.lookup(next));
}
return false;
};

llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
Expand Down Expand Up @@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
switch (mat.getMaterializationKind()) {
case MaterializationKind::Argument:
Copy link
Member

@zero9178 zero9178 Jul 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am slightly confused by the comment at line 2804 which states that this code only deals with target materializations (I am interpreting this as a materializations to the target type system, not specifically target conversions).
Doesn't the argument materialization now returning values from the source type system somewhat contradict this? Same with the fallback to target materialization which is guaranteed to return a different type.

It seems me either the comment needs to be updated or the fallback path can be removed or changed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, that comment makes no sense to me. This part of the code base deals exclusively with materializations; there are no type conversions anymore.

When this comment was added, the implementation already handled both argument and target materializations. Source materializations are not handled here because they never show up as "unresolved materializations". There isn't even a MaterializationKind::Source enum value. So I think this should be rephrased as We currently handle only argument and target materializations here. I believe if we were to handle source materializations here, a few unrealized_conversion_cast ops (the ones that cancel out with target materializations) would not have to be materialized. So the comment could be a kind of TODO to support source materializations.

// We currently only handle target materializations here.
OpResult opResult = op->getOpResult(0);

Interestingly this comment is right before the getOpResult(0). Another limitation of this part of the code base is that 1:N materializations are not supported. (But neither does the type converter API support it when adding materialization functions.)

// Try to materialize an argument conversion.
// FIXME: The current argument materialization hook expects the original
// output type, even though it doesn't use that as the actual output type
// of the generated IR. The output type is just used as an indicator of
// the type of materialization to do. This behavior is really awkward in
// that it diverges from the behavior of the other hooks, and can be
// easily misunderstood. We should clean up the argument hooks to better
// represent the desired invariants we actually care about.
newMaterialization = converter->materializeArgumentConversion(
rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
rewriter, op->getLoc(), outputType, inputOperands);
if (newMaterialization)
break;

// If an argument materialization failed, fallback to trying a target
// materialization.
[[fallthrough]];
Expand All @@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
break;
}
if (newMaterialization) {
assert(newMaterialization.getType() == outputType &&
"materialization callback produced value of incorrect type");
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
inverseMapping);
return success();
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s

// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR

// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR

// These tests were separated from func-memref.mlir because applying
// -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
Expand Down
44 changes: 44 additions & 0 deletions mlir/test/Transforms/test-block-legalization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s

// CHECK-LABEL: func @complex_block_signature_conversion(
// CHECK: %[[cst:.*]] = complex.constant
// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
// Note: Some blocks are omitted.
// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]]
// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
// CHECK: llvm.br ^[[block2:.*]]
// CHECK: ^[[block2]]:
// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
func.func @complex_block_signature_conversion() {
%cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
%true = arith.constant true
%0 = scf.if %true -> complex<f64> {
scf.yield %cst : complex<f64>
} else {
scf.yield %cst : complex<f64>
}

// Regression test to ensure that the a source materialization is inserted.
// The operand of "test.consumer_of_complex" must not change.
"test.consumer_of_complex"(%0) : (complex<f64>) -> ()
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
: (!transform.any_op) -> !transform.any_op
transform.apply_conversion_patterns to %func {
transform.apply_conversion_patterns.dialect_to_llvm "cf"
transform.apply_conversion_patterns.func.func_to_llvm
transform.apply_conversion_patterns.scf.scf_to_control_flow
} with type_converter {
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
} {
legal_dialects = ["llvm"],
partial_conversion
} : !transform.any_op
transform.yield
}
}
Loading