Skip to content

Commit

Permalink
[mlir][linalg] Fix FoldTensorCastConsumerOp invalid folding
Browse files Browse the repository at this point in the history
CastOp can be in conditionally reachable region, in which case this folding will be invalid.
Only conservatively fold ops in same block for now.

Fixes #56557

Differential Revision: https://reviews.llvm.org/D130314
  • Loading branch information
Hardcode84 committed Jul 22, 2022
1 parent 1ac12a5 commit f46744b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Up @@ -1712,10 +1712,17 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
PatternRewriter &rewriter) const override {
if (!tensor::canFoldIntoProducerOp(castOp))
return failure();

auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
if (!linalgOp)
return failure();

// Cast can be in conditionally reachable region, if which case folding will
// generate invalid code. Only conservatively fold ops in same block for
// now.
if (castOp->getBlock() != linalgOp->getBlock())
return failure();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(linalgOp);

Expand Down
27 changes: 27 additions & 0 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
Expand Up @@ -846,6 +846,33 @@ func.func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : ten
// CHECK: %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
// CHECK: return %[[MATMUL]], %[[RESULT_CAST]]

// -----

func.func private @some_use(%0 : tensor<4x8xf32>)

func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>, %arg3 : i1) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
scf.if %arg3 {
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
func.call @some_use(%1) : (tensor<4x8xf32>) -> ()
}
return %0 : tensor<?x?xf32>
}

// Check conditionally reachable cast is not folded into producer.
// CHECK-LABEL: func @linalgop_with_cond_cast_consumer
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
// CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: scf.if %[[ARG3]] {
// CHECK: %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32>
// CHECK: func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> ()
// CHECK: }
// CHECK: return %[[RES]] : tensor<?x?xf32>


// -----

func.func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>,
Expand Down

0 comments on commit f46744b

Please sign in to comment.