Skip to content

Commit

Permalink
Merge a62dd5a into 4294a5b
Browse files Browse the repository at this point in the history
  • Loading branch information
MaheshRavishankar committed Jun 28, 2024
2 parents 4294a5b + a62dd5a commit b8a2701
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,12 @@ isFusableWithConsumer(OpOperand &fusedOperand,
// Check if the iteration spaces of the producer and consumer are the same.
// TODO(#12664): This is an unnecessary requirement, but we need a better config
// to tile the consumer with a larger iteration space.
auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
if (producerIterationSpace.size() < consumerIterationSpace.size()) {
return false;
if (!options.aggressiveFusion) {
auto producerIterationSpace = producerFusionOp.getStaticLoopRanges();
auto consumerIterationSpace = consumerFusionOp.getStaticLoopRanges();
if (producerIterationSpace.size() < consumerIterationSpace.size()) {
return false;
}
}

// Under aggressive fusion assume that the dispatches are vectorized. In which
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions))" --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s

util.func public @pack_elementwise_fusion(%arg0 : tensor<?xf32>,
%arg1 : tensor<?x?xf32>) -> tensor<?x?x8x32xf32> {
Expand Down Expand Up @@ -640,3 +640,109 @@ util.func public @broadcasting_dequant_op(%arg0 : tensor<?x?xi8>,
// CHECK-SAME: ins(%[[GENERIC]],
// CHECK: flow.return %[[MATMUL]]
// CHECK: return %[[RETURN]]

// -----

// Test input: a chain of linalg.generic ops consisting of an element-wise
// f16 -> f32 extension, two reductions over the innermost dimension, an
// element-wise scaling of the first reduction, and a final 4-D element-wise op
// that consumes both reduction results. The FileCheck expectations that follow
// this function require the extension op to stay outside the dispatch region
// and the remaining four generics to be placed in one flow.dispatch.region.
// NOTE(review): despite the "softmax" name, the arithmetic resembles a
// mean/variance style normalization -- confirm the intended naming.
util.func @softmax_like_fusion(%arg0: tensor<2x4096x640xf16>,
%arg1: tensor<640xf16>, %arg2: tensor<640xf16>) -> tensor<2x4096x640x1xf16> {
// Reshape %arg0 to 4-D by appending a trailing unit dimension.
%expanded = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
output_shape [2, 4096, 640, 1] : tensor<2x4096x640xf16> into tensor<2x4096x640x1xf16>
// %cst initializes the reduction accumulators; %cst_0 is the divisor used in
// %6 and %8; %cst_1 is added before the rsqrt in %8.
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 1.100000e+01 : f32
%cst_1 = arith.constant 4.000000e+00 : f32
%0 = tensor.empty() : tensor<2x4096x640xf32>
%1 = tensor.empty() : tensor<2x4096x640x1xf16>
// %2: element-wise extension of %arg0 from f16 to f32 (all-parallel).
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<2x4096x640xf16>) outs(%0 : tensor<2x4096x640xf32>) {
^bb0(%in: f16, %out: f32):
%9 = arith.extf %in : f16 to f32
linalg.yield %9 : f32
} -> tensor<2x4096x640xf32>
// %4: zero-filled 2x4096 tensor used as the init value of both reductions.
%3 = tensor.empty() : tensor<2x4096xf32>
%4 = linalg.fill ins(%cst : f32)
outs(%3 : tensor<2x4096xf32>) -> tensor<2x4096xf32>
// %5: sum of %2 along d2 (the reduction dimension).
%5 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%2 : tensor<2x4096x640xf32>) outs(%4 : tensor<2x4096xf32>) {
^bb0(%in: f32, %out: f32):
%9 = arith.addf %in, %out : f32
linalg.yield %9 : f32
} -> tensor<2x4096xf32>
// %6: element-wise division of the sum %5 by the constant %cst_0.
%6 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%5 : tensor<2x4096xf32>) outs(%3 : tensor<2x4096xf32>) {
^bb0(%in: f32, %out: f32):
%9 = arith.divf %in, %cst_0 : f32
linalg.yield %9 : f32
} -> tensor<2x4096xf32>
// %7: sum along d2 of the squared difference between %2 and %6, where %6 is
// indexed by (d0, d1) only and therefore broadcast along d2.
%7 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%2, %6 : tensor<2x4096x640xf32>, tensor<2x4096xf32>)
outs(%4 : tensor<2x4096xf32>) {
^bb0(%in: f32, %in_4: f32, %out: f32):
%9 = arith.subf %in, %in_4 : f32
%10 = arith.mulf %9, %9 : f32
%11 = arith.addf %10, %out : f32
linalg.yield %11 : f32
} -> tensor<2x4096xf32>
// Reshape the two 1-D operands to 640x1 so they can be indexed by (d2, d3).
%expanded_2 = tensor.expand_shape %arg1 [[0, 1]] output_shape [640, 1]
: tensor<640xf16> into tensor<640x1xf16>
%expanded_3 = tensor.expand_shape %arg2 [[0, 1]] output_shape [640, 1]
: tensor<640xf16> into tensor<640x1xf16>
// %8: all-parallel 4-D op. Per element it computes
//   trunc((ext(%expanded) - %6) * rsqrt(%7 / %cst_0 + %cst_1)
//         * ext(%expanded_2) + ext(%expanded_3))
// with %6 and %7 indexed by (d0, d1) and the expanded 1-D operands indexed by
// (d2, d3). This op has a 4-D iteration space while its reduction producers
// have a 3-D one.
%8 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%expanded, %6, %7, %expanded_2, %expanded_3
: tensor<2x4096x640x1xf16>, tensor<2x4096xf32>, tensor<2x4096xf32>,
tensor<640x1xf16>, tensor<640x1xf16>)
outs(%1 : tensor<2x4096x640x1xf16>) {
^bb0(%in: f16, %in_4: f32, %in_5: f32, %in_6: f16, %in_7: f16, %out: f16):
%9 = arith.divf %in_5, %cst_0 : f32
%10 = arith.addf %9, %cst_1 : f32
%11 = math.rsqrt %10 : f32
%12 = arith.extf %in : f16 to f32
%13 = arith.subf %12, %in_4 : f32
%14 = arith.mulf %13, %11 : f32
%15 = arith.extf %in_6 : f16 to f32
%16 = arith.mulf %14, %15 : f32
%17 = arith.extf %in_7 : f16 to f32
%18 = arith.addf %16, %17 : f32
%19 = arith.truncf %18 : f32 to f16
linalg.yield %19 : f16
} -> tensor<2x4096x640x1xf16>
util.return %8 : tensor<2x4096x640x1xf16>
}
// CHECK-LABEL: func public @softmax_like_fusion(
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x4096x640xf16>
// CHECK: %[[BITEXTEND:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK: %[[RESULT:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[BITEXTEND]] :
// CHECK: %[[GENERIC2:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[GENERIC1]] :
// CHECK: %[[GENERIC3:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[BITEXTEND]], %[[GENERIC2]] :
// CHECK: %[[GENERIC4:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC2]], %[[GENERIC3]]
// CHECK: flow.return %[[GENERIC4]]
// CHECK: util.return %[[RESULT]]

0 comments on commit b8a2701

Please sign in to comment.