Skip to content

Commit

Permalink
[mlir] Require folders to produce Values of same type (#75887)
Browse files Browse the repository at this point in the history
This commit adds extra assertions to `OperationFolder` and `OpBuilder`
to ensure that the types of the folded SSA values match with the result
types of the op. There used to be checks that discard the folded results
if the types do not match. This commit makes these checks stricter and
turns them into assertions.

Discarding folded results with the wrong type (without failing
explicitly) can hide bugs in op folders. Two such bugs became apparent
in MLIR (and some more in downstream projects) and are fixed with this
change.

Note: The existing type checks were introduced in
https://reviews.llvm.org/D95991.

Migration guide: If you see failing assertions (`folder produced value
of incorrect type`; make sure to run with assertions enabled!), run with
`-debug` or dump the operation right before the failing assertion. This
will point you to the op that has the broken folder. A common mistake is
a mismatch between static/dynamic dimensions (e.g., input has a static
dimension but folded result has a dynamic dimension).
  • Loading branch information
matthias-springer committed Dec 20, 2023
1 parent 560564f commit f10302e
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 48 deletions.
6 changes: 4 additions & 2 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,13 @@ void fir::BoxAddrOp::build(mlir::OpBuilder &builder,
mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
if (auto *v = getVal().getDefiningOp()) {
if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) {
if (!box.getSlice()) // Fold only if not sliced
// Fold only if not sliced
if (!box.getSlice() && box.getMemref().getType() == getType())
return box.getMemref();
}
if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
return box.getMemref();
if (box.getMemref().getType() == getType())
return box.getMemref();
}
return {};
}
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
setOperand(src);
return getResult();
}

// trunci(zexti(a)) -> a
// trunci(sexti(a)) -> a
return src;
if (srcType == dstType)
return src;
}

// trunci(trunci(a)) -> trunci(a))
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,8 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
if (!inputTy.hasRank()) \
return {}; \
if (inputTy != getType()) \
return {}; \
if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
return getInput(); \
return {}; \
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,9 +1602,10 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};

// If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
if (broadcastSrcRank == 0)
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
return source;

unsigned extractResultRank = getRank(extractOp.getType());
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,14 +486,11 @@ LogicalResult OpBuilder::tryFold(Operation *op,

// Populate the results with the folded results.
Dialect *dialect = op->getDialect();
for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
Type expectedType = std::get<1>(it);

// Normal values get pushed back directly.
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
if (value.getType() != expectedType)
return cleanupFailure();

results.push_back(value);
continue;
}
Expand Down
26 changes: 24 additions & 2 deletions mlir/lib/IR/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,13 +606,30 @@ void Operation::setSuccessor(Block *block, unsigned index) {
getBlockOperands()[index].set(block);
}

#ifndef NDEBUG
/// Assert that the folded results (in case of values) have the same type as
/// the results of the given op.
static void checkFoldResultTypes(Operation *op,
SmallVectorImpl<OpFoldResult> &results) {
if (!results.empty())
for (auto [ofr, opResult] : llvm::zip_equal(results, op->getResults()))
if (auto value = ofr.dyn_cast<Value>())
assert(value.getType() == opResult.getType() &&
"folder produced value of incorrect type");
}
#endif // NDEBUG

/// Attempt to fold this operation using the Op's registered foldHook.
LogicalResult Operation::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
if (succeeded(name.foldHook(this, operands, results)))
if (succeeded(name.foldHook(this, operands, results))) {
#ifndef NDEBUG
checkFoldResultTypes(this, results);
#endif // NDEBUG
return success();
}

// Otherwise, fall back on the dialect hook to handle it.
Dialect *dialect = getDialect();
Expand All @@ -623,7 +640,12 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
if (!interface)
return failure();

return interface->fold(this, operands, results);
LogicalResult status = interface->fold(this, operands, results);
#ifndef NDEBUG
if (succeeded(status))
checkFoldResultTypes(this, results);
#endif // NDEBUG
return status;
}

LogicalResult Operation::fold(SmallVectorImpl<OpFoldResult> &results) {
Expand Down
4 changes: 0 additions & 4 deletions mlir/lib/Transforms/Utils/FoldUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,6 @@ OperationFolder::processFoldResults(Operation *op,

// Check if the result was an SSA value.
if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
if (repl.getType() != op->getResult(i).getType()) {
results.clear();
return failure();
}
results.emplace_back(repl);
continue;
}
Expand Down
13 changes: 0 additions & 13 deletions mlir/test/Transforms/test-canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,6 @@ func.func @test_commutative_multi_cst(%arg0: i32, %arg1: i32) -> (i32, i32) {
return %y, %z: i32, i32
}

// CHECK-LABEL: func @typemismatch

func.func @typemismatch() -> i32 {
%c42 = arith.constant 42.0 : f32

// The "passthrough_fold" folder will naively return its operand, but we don't
// want to fold here because of the type mismatch.

// CHECK: "test.passthrough_fold"
%0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
return %0 : i32
}

// CHECK-LABEL: test_dialect_canonicalizer
func.func @test_dialect_canonicalizer() -> (i32) {
%0 = "test.dialect_canonicalizable"() : () -> (i32)
Expand Down
10 changes: 0 additions & 10 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -310,16 +310,6 @@ builtin.module {

// -----

// The "passthrough_fold" folder will naively return its operand, but we don't
// want to fold here because of the type mismatch.
func.func @typemismatch(%arg: f32) -> i32 {
// expected-remark@+1 {{op 'test.passthrough_fold' is not legalizable}}
%0 = "test.passthrough_fold"(%arg) : (f32) -> (i32)
"test.return"(%0) : (i32) -> ()
}

// -----

// expected-remark @below {{applyPartialConversion failed}}
module {
func.func private @callee(%0 : f32) -> f32
Expand Down
4 changes: 0 additions & 4 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,6 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
return {};
}

OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) {
return getOperand();
}

OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
int64_t sum = 0;
if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
Expand Down
7 changes: 0 additions & 7 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1363,13 +1363,6 @@ def TestOpFoldWithFoldAdaptor
let hasFolder = 1;
}

// An op that always fold itself.
def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
let arguments = (ins AnyType:$op);
let results = (outs AnyType);
let hasFolder = 1;
}

def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
let arguments = (ins);
let results = (outs I32);
Expand Down

0 comments on commit f10302e

Please sign in to comment.