Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
];
}

// Exposes -decompose-aggregated-ops, a function-level pass that lowers ops
// implementing `AggregatedOpInterface` into their constituent operations.
def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> {
let summary = "Decompose aggregated operations.";
let description = [{
Decompose operations that implement the `AggregatedOpInterface`.
}];
}

def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> {
let summary = "Sink operations into inner loops";
let description = [{The pass tries to sink operations into inner loops as deep as possible to maximize the chance for outer loop optimization.
Expand Down
1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ gc_add_mlir_library(GcPasses
IterativeTilingAndFusion.cpp
TilingUsingInterfaceX.cpp
VerifyTargetDescription.cpp
DecomposeAggregatedOps.cpp
DeepTileContractionOp.cpp
TilingUtil.cpp
SinkOpIntoInnerLoop.cpp
Expand Down
49 changes: 49 additions & 0 deletions lib/gc/Transforms/DecomposeAggregatedOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===-- DecomposeAggregatedOps.cpp - Decompose Aggregated Ops ---*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "gc/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
namespace mlir {
namespace gc {
#define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS
#include "gc/Transforms/Passes.h.inc"
} // namespace gc
} // namespace mlir

namespace {

/// Rewrites a `linalg.softmax` op into its constituent operations by calling
/// the op's `AggregatedOpInterface::decomposeOperation` hook and replacing the
/// original op with the decomposed results.
struct DecomposeAggregateOpsImpl : public OpRewritePattern<linalg::SoftmaxOp> {
  using OpRewritePattern<linalg::SoftmaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::SoftmaxOp softmaxOp,
                                PatternRewriter &rewriter) const override {
    // linalg.softmax implements AggregatedOpInterface, so this cast cannot
    // fail.
    auto aggregatedOp =
        cast<linalg::AggregatedOpInterface>(softmaxOp.getOperation());
    FailureOr<SmallVector<Value>> decomposedResults =
        aggregatedOp.decomposeOperation(rewriter);
    if (failed(decomposedResults))
      return failure();
    rewriter.replaceOp(softmaxOp, *decomposedResults);
    return success();
  }
};

/// Pass driver: greedily applies the aggregated-op decomposition patterns to
/// the current function until a fixed point is reached.
struct DecomposeAggregatedOps
    : public gc::impl::DecomposeAggregatedOpsBase<DecomposeAggregatedOps> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<DecomposeAggregateOpsImpl>(ctx);
    // Best-effort application: convergence failure is not treated as a pass
    // failure (same as the original driver invocation).
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace
2 changes: 2 additions & 0 deletions lib/gc/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
// REMOVE this pass after the above passes are added. Currently we add this
// pass to make the pipeline work properly
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
// copied from tpp project
pm.addNestedPass<func::FuncOp>(createDecomposeAggregatedOps());
// fold useless tensor operation pass
pm.addPass(createFoldTensorOperation());
pm.addPass(createLoopInvariantCodeMotionPass());
Expand Down
11 changes: 11 additions & 0 deletions test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: gc-opt %s -decompose-aggregated-ops | FileCheck %s

// Verifies that the decompose-aggregated-ops pass removes linalg.softmax and
// replaces it with four linalg.generic ops.
// CHECK-LABEL: softmax
func.func @softmax(%arg0: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
%0 = tensor.empty() : tensor<2x2x2x2xf32>
// CHECK-NOT: linalg.softmax
// CHECK-COUNT-4: linalg.generic
%1 = linalg.softmax dimension(3)
ins(%arg0 : tensor<2x2x2x2xf32>) outs(%0 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
return %1 : tensor<2x2x2x2xf32>
}