Skip to content

Commit

Permalink
[mlir][Complex] Fix bug in MergeComplexBitcast (#74271)
Browse files Browse the repository at this point in the history
When two `complex.bitcast` ops are folded and the resulting bitcast is a
non-complex -> non-complex bitcast, an `arith.bitcast` should be
generated. Otherwise, the generated `complex.bitcast` op is invalid.

Also remove a pattern that convertes non-complex -> non-complex
`complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are
invalid and should not appear in the input.

Note: This bug can only be triggered by running with `-debug` (which
will should intermediate IR that does not verify) or with
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (#74270).
  • Loading branch information
matthias-springer committed Dec 5, 2023
1 parent c3a9c90 commit 192439d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 20 deletions.
31 changes: 12 additions & 19 deletions mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
}

if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
return emitOpError("requires input or output is a complex type");
return emitOpError(
"requires that either input or output has a complex type");
}

if (isa<ComplexType>(resultType))
Expand All @@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
LogicalResult matchAndRewrite(BitcastOp op,
PatternRewriter &rewriter) const override {
if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
defining.getOperand());
if (isa<ComplexType>(op.getType()) ||
isa<ComplexType>(defining.getOperand().getType())) {
// complex.bitcast requires that input or output is complex.
rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
defining.getOperand());
} else {
rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
defining.getOperand());
}
return success();
}

Expand Down Expand Up @@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
}
};

struct ArithBitcast final : OpRewritePattern<BitcastOp> {
using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(BitcastOp op,
PatternRewriter &rewriter) const override {
if (isa<ComplexType>(op.getType()) ||
isa<ComplexType>(op.getOperand().getType()))
return failure();

rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
op.getOperand());
return success();
}
};

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
results.add<MergeComplexBitcast, MergeArithBitcast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Complex/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
// -----

func.func @complex_bitcast_i64(%arg0 : i64) {
// expected-error @+1 {{op requires input or output is a complex type}}
// expected-error @+1 {{op requires that either input or output has a complex type}}
%0 = complex.bitcast %arg0: i64 to f64
return
}

0 comments on commit 192439d

Please sign in to comment.