Skip to content

Commit

Permalink
Fuse iota ops with consumers always. (#14070)
Browse files Browse the repository at this point in the history
Current fusion heuristics always fuse copy-like ops with their consumers. Iota ops are also copy-like ops (indeed, if the deprecated linalg.indexed_generic were still around, an iota op would still be a copy-like op).

Fixes #13745
  • Loading branch information
MaheshRavishankar committed Jun 13, 2023
1 parent 389152b commit e0d36f9
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand) {
// If the generic op is "just" copy, then fuse always.
Block &body = producerOp->getRegion(0).front();
if (std::begin(body)->hasTrait<OpTrait::IsTerminator>()) return true;
if (llvm::all_of(body.getArguments(),
[](BlockArgument arg) { return arg.use_empty(); })) {
// The operands aren't used; it's just a `linalg.index` op.
return true;
}

// If the producer does not have a single user, don't fuse.
if (!producerOp->hasOneUse()) return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,59 @@ module {
// CHECK-SAME: ins(%[[GENERIC1]], %[[GENERIC0]] :
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: return %[[GENERIC2]]

// -----

// Test input: an iota-style producer (a `linalg.generic` whose body reads only
// `linalg.index` values and never its block argument) that is consumed by two
// separate elementwise `linalg.generic` ops.
func.func @fuse_iota_ops(%arg0: tensor<10x20xi32>) -> (tensor<10x20xi32>, tensor<10x20xi32>) {
  %stride = arith.constant 20 : index
  %init_i32 = tensor.empty() : tensor<10x20xi32>
  %init_index = tensor.empty() : tensor<10x20xindex>

  // Producer: each element is `index(0) + index(1) * 20`; the `outs` block
  // argument is not read.
  %iota = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      outs(%init_index : tensor<10x20xindex>) {
    ^bb0(%unused : index):
      %i0 = linalg.index 0 : index
      %i1 = linalg.index 1 : index
      %scaled = arith.muli %i1, %stride : index
      %linear = arith.addi %i0, %scaled : index
      linalg.yield %linear : index
  } -> tensor<10x20xindex>

  // First consumer: casts the iota value to i32 and adds the input element.
  %add_result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %iota: tensor<10x20xi32>, tensor<10x20xindex>) outs(%init_i32 : tensor<10x20xi32>) {
    ^bb0(%in : i32, %idx : index, %out : i32):
      %idx_i32 = arith.index_cast %idx : index to i32
      %sum = arith.addi %idx_i32, %in : i32
      linalg.yield %sum : i32
  } -> tensor<10x20xi32>

  // Second consumer: same operands, but multiplies instead of adding.
  %mul_result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %iota: tensor<10x20xi32>, tensor<10x20xindex>) outs(%init_i32 : tensor<10x20xi32>) {
    ^bb0(%in : i32, %idx : index, %out : i32):
      %idx_i32 = arith.index_cast %idx : index to i32
      %prod = arith.muli %idx_i32, %in : i32
      linalg.yield %prod : i32
  } -> tensor<10x20xi32>

  return %add_result, %mul_result : tensor<10x20xi32>, tensor<10x20xi32>
}
// CHECK-LABEL: func @fuse_iota_ops(
// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xi32>)
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<10x20xi32>
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<10x20xi32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<10x20xi32>)
// CHECK: linalg.index
// CHECK: linalg.index
// CHECK: arith.addi
// CHECK: linalg.yield
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<10x20xi32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<10x20xi32>)
// CHECK: linalg.index
// CHECK: linalg.index
// CHECK: arith.muli
// CHECK: linalg.yield
// CHECK: return %[[GENERIC1]], %[[GENERIC2]]

0 comments on commit e0d36f9

Please sign in to comment.