Merge 5c9dac6 into 7b58c71
MaheshRavishankar committed Jun 22, 2024
2 parents 7b58c71 + 5c9dac6 commit b23c024
Showing 2 changed files with 78 additions and 8 deletions.
@@ -45,10 +45,36 @@ class SinkReshapesPass : public impl::SinkReshapesPassBase<SinkReshapesPass> {
/// fusion analysis actually happens, but that requires a direct producer ->
/// consumer relationship and indexing maps for the right analysis. Here
/// we just approximate it (and try to be optimistic)
static bool isFusableUsingTileAndFuse(Operation *producer,
Operation *consumer) {
return llvm::isa_and_nonnull<linalg::LinalgOp, tensor::UnPackOp,
Encoding::UnsetEncodingOp>(producer);
static bool isFusableUsingTileAndFuse(Operation *producer, Operation *consumer,
unsigned consumerOperandNumber) {
if (llvm::isa_and_nonnull<tensor::UnPackOp, Encoding::UnsetEncodingOp>(
producer)) {
return true;
}

// If the producer is a linalg op.
auto producerLinalgOp = dyn_cast_or_null<linalg::LinalgOp>(producer);
if (!producerLinalgOp) {
return false;
}
// Ignore elementwise linalg op producers.
if (producerLinalgOp.getNumLoops() ==
producerLinalgOp.getNumParallelLoops()) {
return false;
}

// For now just check that the consumer iteration space rank is the same as the
// producer's parallel iteration space rank. This is done by checking that the
// consumer uses a permutation indexing map for the corresponding operand.
auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer);
if (!consumerLinalgOp) {
return false;
}
AffineMap indexingMap =
cast<AffineMapAttr>(
consumerLinalgOp.getIndexingMaps()[consumerOperandNumber])
.getValue();
return indexingMap.isPermutation();
}
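For reference, here is a minimal standalone sketch (not part of this patch) of the permutation check the new code applies to the consumer's indexing map; the two maps mirror the operands of the @better_producer_estimate test added below, and only the upstream mlir::AffineMap API is assumed.

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  // (d0, d1, d2, d3) -> (d0, d1, d2, d3): the map used for the %expanded operand.
  mlir::AffineMap identity = mlir::AffineMap::getMultiDimIdentityMap(4, &ctx);
  // (d0, d1, d2, d3) -> (d0, d1): the map used for the %reduction operand.
  mlir::AffineMap projected = mlir::AffineMap::get(
      /*dimCount=*/4, /*symbolCount=*/0,
      {mlir::getAffineDimExpr(0, &ctx), mlir::getAffineDimExpr(1, &ctx)}, &ctx);
  // isPermutation() holds only when every iteration dimension appears exactly
  // once in the results, i.e. the operand spans the full iteration space.
  llvm::outs() << identity.isPermutation() << "\n";  // prints 1
  llvm::outs() << projected.isPermutation() << "\n"; // prints 0
  return 0;
}

In the new test below, this is exactly why the reshape gets sunk: the %expanded operand's map is a permutation, so fusing the consumer with the batch_matmul is considered profitable, while the %reduction operand's map is not a permutation, so that producer no longer blocks the sinking.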

/// Control function to check if a `tensor.expand_shape` (which is the producer of
@@ -65,7 +91,7 @@ static bool shouldSinkExpandShapeOp(OpOperand *opOperand) {
return false;
}

// Do not sink reshapes across dequantize operations since they are
// cloned into their producers.
if (isDequantizationLikeOp(consumer)) {
return false;
@@ -76,7 +102,8 @@ static bool shouldSinkExpandShapeOp(OpOperand *opOperand) {
if (llvm::any_of(consumer->getOpOperands(), [](OpOperand &opOperand) {
Operation *currProducer = opOperand.get().getDefiningOp();
Operation *currConsumer = opOperand.getOwner();
return isFusableUsingTileAndFuse(currProducer, currConsumer) &&
return isFusableUsingTileAndFuse(currProducer, currConsumer,
opOperand.getOperandNumber()) &&
// The check for the producer having a single use is not fully
// worked out. Ideally we can fuse with a producer irrespective
// of the number of uses, but this is a good rule of thumb in practice.
@@ -91,8 +118,8 @@ static bool shouldSinkExpandShapeOp(OpOperand *opOperand) {
return false;
}

return isFusableUsingTileAndFuse(reshapeOp.getSrc().getDefiningOp(),
consumer);
return isFusableUsingTileAndFuse(reshapeOp.getSrc().getDefiningOp(), consumer,
opOperand->getOperandNumber());
}

void SinkReshapesPass::runOnOperation() {
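The body of SinkReshapesPass::runOnOperation() is collapsed in this diff. For orientation, here is a hypothetical sketch of how a control function with shouldSinkExpandShapeOp's signature can be wired into the upstream reshape-propagation patterns; the choice of linalg::populateFoldReshapeOpsByCollapsingPatterns and the greedy rewrite driver is an assumption for illustration, not taken from this diff.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical pass body; the actual implementation is collapsed above.
void SinkReshapesPass::runOnOperation() {
  mlir::MLIRContext *context = &getContext();
  mlir::RewritePatternSet patterns(context);
  // The control function is invoked per reshape-consuming operand and decides
  // whether the expand_shape feeding that operand may be sunk past the consumer.
  mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
      patterns, shouldSinkExpandShapeOp);
  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                      std::move(patterns)))) {
    signalPassFailure();
  }
}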
@@ -127,3 +127,46 @@ func.func @do_not_sink_across_dequantize_ops(%arg0: tensor<?x?xf32>) -> tensor<2
// CHECK: %[[DEQUANT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND]] :
// CHECK: return %[[DEQUANT]]

// -----

// Check that the reshape sinks based on a better estimate of which producer
// -> consumer pairs are fusable.
func.func @better_producer_estimate(%lhs : tensor<2x4096x640xi32>, %rhs : tensor<2x640x640xi32>,
%fill0 : tensor<2x4096x640xi32>, %fill1 : tensor<2x4096xi32>) -> tensor<2x4096x640x1xf16> {
%bmm = linalg.batch_matmul_transpose_b ins(%lhs, %rhs : tensor<2x4096x640xi32>, tensor<2x640x640xi32>)
outs(%fill0 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>
%reduction = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%lhs : tensor<2x4096x640xi32>) outs(%fill1 : tensor<2x4096xi32>) {
^bb0(%in: i32, %out: i32):
%12 = arith.addi %in, %out : i32
linalg.yield %12 : i32
} -> tensor<2x4096xi32>
%expanded = tensor.expand_shape %bmm [[0], [1], [2, 3]] output_shape [2, 4096, 640, 1]
: tensor<2x4096x640xi32> into tensor<2x4096x640x1xi32>
%empty = tensor.empty() : tensor<2x4096x640x1xf16>
%quant = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%expanded, %reduction : tensor<2x4096x640x1xi32>, tensor<2x4096xi32>)
outs(%empty : tensor<2x4096x640x1xf16>) {
^bb0(%in: i32, %in_3: i32, %out: f16):
%14 = arith.subi %in, %in_3 : i32
%16 = arith.sitofp %14 : i32 to f32
%18 = arith.truncf %16 : f32 to f16
linalg.yield %18 : f16
} -> tensor<2x4096x640x1xf16>
return %quant : tensor<2x4096x640x1xf16>
}
// CHECK-LABEL: func @better_producer_estimate(
// CHECK: %[[BMM:.+]] = linalg.batch_matmul_transpose_b
// CHECK: %[[REDUCTION:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[BMM]], %[[REDUCTION]] :
// CHECK: %[[COLLAPSE:.+]] = tensor.expand_shape %[[GENERIC]]
// CHECK: return %[[COLLAPSE]]
