Skip to content

Commit

Permalink
Fuse Generic Ops Generated by gather Lowering (#17341)
Browse files Browse the repository at this point in the history
  • Loading branch information
IanWood1 committed May 15, 2024
1 parent 428adf2 commit 748db31
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,22 @@
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::Flow {
Expand Down Expand Up @@ -131,14 +142,76 @@ struct FoldSuccessiveTensorInsertSliceOps
}
};

//===----------------------------------------------------------------------===//
// GatherFusionPattern
//===----------------------------------------------------------------------===//

// Specific case. The linalg.generic implementation of "gather" cannot be
// fused through the normal elementwise-fusion machinery because there is no
// affine producer-consumer relationship between the two generics: the
// producer's result is read via tensor.extract with data-dependent indices
// (the index values come from a tensor). This pattern instead inlines the
// body of a single-input, elementwise, dequantization-like producer directly
// after the extract inside the consumer, re-pointing the extract at the
// producer's input tensor so the producer computation is redone per element.
struct GatherFusionPattern : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only fire on tensor.extract ops that appear inside a linalg.generic
    // (i.e. inside the scalarized body of the gather-like consumer).
    auto consumerOp =
        dyn_cast_or_null<linalg::GenericOp>(extractOp->getParentOp());
    if (!consumerOp) {
      return rewriter.notifyMatchFailure(
          extractOp, "expected extract op to be inside a generic op");
    }

    // The tensor being extracted from must itself come from a linalg.generic.
    auto producerOp = extractOp.getTensor().getDefiningOp<linalg::GenericOp>();
    if (!producerOp) {
      return rewriter.notifyMatchFailure(
          consumerOp, "expected extract operand to be a generic op");
    }

    // Check if the producerOp is fusible: exactly one input and one result,
    // elementwise, and dequantization-like (cheap enough to recompute for
    // every extracted element).
    if (producerOp.getNumDpsInputs() != 1 || producerOp.getNumResults() != 1 ||
        !isElementwise(producerOp) || !isDequantizationLikeOp(producerOp)) {
      return rewriter.notifyMatchFailure(producerOp,
                                         "producer op is not fusible");
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(extractOp);

    // Create a new extract op that extracts from the producer's *input*
    // tensor using the original extract's indices. Then clone the
    // producerOp's body into the consumerOp's region and inline the cloned
    // block (which erases the block) right after the original extract.
    auto newExtractOp = rewriter.create<tensor::ExtractOp>(
        extractOp.getLoc(), producerOp.getDpsInputOperand(0)->get(),
        extractOp.getIndices());
    rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
                               consumerOp.getRegion().begin());
    Block &clonedBlock = consumerOp.getRegion().front();
    auto producerTermOp = clonedBlock.getTerminator();

    // Both producer block arguments (the input element and the output init
    // element) are mapped to the newly extracted scalar. NOTE(review): this
    // assumes the elementwise producer body never reads its output/init
    // argument — the dequantization-like check presumably guarantees that;
    // verify against isDequantizationLikeOp.
    rewriter.inlineBlockBefore(
        &clonedBlock, extractOp->getNextNode(),
        {newExtractOp.getResult(), newExtractOp.getResult()});

    // Replace all references to the original extract result with the value
    // yielded by the inlined producer body, then remove the now-dead
    // terminator and the original extract.
    extractOp.getResult().replaceAllUsesWith(producerTermOp->getOperand(0));
    rewriter.eraseOp(producerTermOp);
    rewriter.eraseOp(extractOp);

    return success();
  }
};

struct FusionPreprocessingPass
: public IREE::Flow::impl::FusionPreprocessingPassBase<
FusionPreprocessingPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns
.add<FoldSuccessiveTensorInsertSliceOps, GenericOpInterchangePattern>(
&getContext());
patterns.add<FoldSuccessiveTensorInsertSliceOps,
GenericOpInterchangePattern, GatherFusionPattern>(
&getContext());

// Fold away `tensor.dim` operations that can be resolved in terms of its
// operand shapes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,87 @@ util.func public @fold_insert_slices(%source : tensor<?x?xf32>,
// CHECK: %[[RETURN:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
// CHECK-SAME: [%[[NEW_OFFSET0]], %[[NEW_OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]]
// CHECK: util.return %[[RETURN]]


// -----

// Checks that an elementwise dequantization-like producer (%15: f16 -> f32
// via arith.extf) is fused into the gather-style consumer: after the
// rewrite, tensor.extract reads the original f16 tensor directly and the
// extf is recomputed per extracted element inside the consumer body.
util.func public @fuse_generic_gather(
    %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>,
    %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>)
    -> tensor<4x?x4096xf32>{

  // Producer: elementwise f16 -> f32 widening of the whole table.
  %15 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%11 : tensor<128256x4096xf16>)
      outs(%14 : tensor<128256x4096xf32>) {
    ^bb0(%in: f16, %out: f32):
      %17 = arith.extf %in : f16 to f32
      linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  // Consumer: gather-like generic — row index is data-dependent
  // (comes from %12), so there is no affine producer-consumer relation.
  %16 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%12 : tensor<4x?xi64>)
      outs(%13 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %17 = arith.index_cast %in : i64 to index
      %18 = linalg.index 2 : index
      %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
      linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>
  util.return %16 : tensor<4x?x4096xf32>
}

// After fusion the extract targets the f16 tensor and the extf follows it.
// CHECK:      %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
// CHECK:      %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32
// CHECK-NEXT: linalg.yield %[[RES]] : f32


// -----

// Same as @fuse_generic_gather, but the extracted value has multiple uses
// inside the consumer body — checks that all uses are rewired to the value
// yielded by the inlined producer body.
util.func public @fuse_generic_gather2(
    %11 :tensor<128256x4096xf16>, %12 : tensor<4x?xi64>,
    %13 : tensor<4x?x4096xf32>, %14 : tensor<128256x4096xf32>)
    -> tensor<4x?x4096xf32>{

  // Producer: elementwise f16 -> f32 widening of the whole table.
  %15 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%11 : tensor<128256x4096xf16>)
      outs(%14 : tensor<128256x4096xf32>) {
    ^bb0(%in: f16, %out: f32):
      %17 = arith.extf %in : f16 to f32
      linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  // Consumer: gather-like generic that uses %extracted three times.
  %16 = linalg.generic {
      indexing_maps = [ affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%12 : tensor<4x?xi64>)
      outs(%13 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %17 = arith.index_cast %in : i64 to index
      %18 = linalg.index 2 : index
      %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
      %result = arith.addf %extracted, %extracted : f32
      %result2 = arith.mulf %extracted, %extracted : f32
      %final = arith.addf %result, %result2 : f32
      linalg.yield %final: f32
  } -> tensor<4x?x4096xf32>
  util.return %16 : tensor<4x?x4096xf32>
}

// All three former uses of %extracted now consume the inlined extf result.
// CHECK:      %[[INDEX0:[a-zA-Z0-9]+]] = arith.index_cast %in : i64 to index
// CHECK:      %[[INDEX1:[a-zA-Z0-9]+]] = linalg.index 2 : index
// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[TENSOR0:.+]][%[[INDEX0]], %[[INDEX1]]] : tensor<128256x4096xf16>
// CHECK-NEXT: %[[RES:[a-zA-Z0-9]+]] = arith.extf %[[EXTRACTED]] : f16 to f32
// CHECK-NEXT: %[[RES2:[a-zA-Z0-9]+]] = arith.addf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES3:[a-zA-Z0-9]+]] = arith.mulf %[[RES]], %[[RES]] : f32
// CHECK-NEXT: %[[RES4:[a-zA-Z0-9]+]] = arith.addf %[[RES2]], %[[RES3]] : f32
// CHECK-NEXT: linalg.yield %[[RES4]] : f32

0 comments on commit 748db31

Please sign in to comment.