Skip to content

Commit

Permalink
[mlir] Fix folding for scf.for(tensor.cast).
Browse files Browse the repository at this point in the history
We should only fold tensor.casts that provide some new static information about
shapes, instead of looking for a symmetric pattern cast(for(cast)).

Differential Revision: https://reviews.llvm.org/D144577
  • Loading branch information
pifon2a committed Feb 23, 2023
1 parent 08f0388 commit a5cdcf4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 28 deletions.
20 changes: 8 additions & 12 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Expand Up @@ -894,8 +894,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
/// %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
/// scf.yield %2 : tensor<?x?xf32>
/// }
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
/// use_of(%2)
/// use_of(%1)
/// ```
///
/// folds into:
Expand All @@ -908,7 +907,8 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
/// %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
/// scf.yield %4 : tensor<32x1024xf32>
/// }
/// use_of(%0)
/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor<?x?xf32>
/// use_of(%1)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;
Expand All @@ -920,17 +920,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
if (!incomingCast)
continue;
// Skip this iter arg if the dest type of the cast does not preserve
// static information present in the source type.
if (!tensor::preservesStaticInformation(incomingCast.getDest().getType(),
incomingCast.getSource().getType()))
continue;
if (!std::get<1>(it).hasOneUse())
continue;
auto outgoingCastOp =
dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
if (!outgoingCastOp)
continue;

// Must be a tensor.cast op pair with matching types.
if (outgoingCastOp.getResult().getType() !=
incomingCast.getSource().getType())
continue;

// Create a new ForOp with that iter operand replaced.
auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
Expand Down
29 changes: 13 additions & 16 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Expand Up @@ -850,32 +850,29 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,

func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>

// CHECK-LABEL: matmul_on_tensors
// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>
// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<1024x1024xf32>
func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
%2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %2 : tensor<?x?xf32>
} {some_attr}
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: matmul_on_tensors
// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<32x1024xf32>

// CHECK-NOT: tensor.cast
// CHECK: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[ITER_T0:.*]] = %[[T0]]) -> (tensor<32x1024xf32>) {
// CHECK: %[[CAST:.*]] = tensor.cast %[[ITER_T0]] : tensor<32x1024xf32> to tensor<?x?xf32>
// CHECK: %[[DONE:.*]] = func.call @do(%[[CAST]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[UNCAST:.*]] = tensor.cast %[[DONE]] : tensor<?x?xf32> to tensor<32x1024xf32>
// CHECK: scf.yield %[[UNCAST]] : tensor<32x1024xf32>
// CHECK: } {some_attr}
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0) -> (tensor<?x?xf32>) {
%2 = func.call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
scf.yield %2 : tensor<?x?xf32>
} {some_attr}
// CHECK-NOT: tensor.cast
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[FOR_RES]] into %[[T1]][0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
// CHECK: return %[[RES]] : tensor<1024x1024xf32>
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
%res = tensor.insert_slice %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
return %res : tensor<1024x1024xf32>
}
// CHECK: %[[RES:.*]] = tensor.cast
// CHECK: return %[[RES]] : tensor<?x?xf32>

// -----

Expand Down

0 comments on commit a5cdcf4

Please sign in to comment.