diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index 57ac7893a..d7bf6db1b 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -105,6 +105,13 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
   ];
 }
 
+def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> {
+  let summary = "Decompose aggregated operations.";
+  let description = [{
+    Decompose operations that implement the `AggregatedOpInterface`.
+  }];
+}
+
 def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> {
   let summary = "Sink operations into inner loops";
   let description = [{The pass tries to sink operations into inner loops as deep as possible to maximize the chance for outer loop optimization.
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index e53493bd7..6bbdb251f 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ gc_add_mlir_library(GcPasses
   IterativeTilingAndFusion.cpp
   TilingUsingInterfaceX.cpp
   VerifyTargetDescription.cpp
+  DecomposeAggregatedOps.cpp
   DeepTileContractionOp.cpp
   TilingUtil.cpp
   SinkOpIntoInnerLoop.cpp
diff --git a/lib/gc/Transforms/DecomposeAggregatedOps.cpp b/lib/gc/Transforms/DecomposeAggregatedOps.cpp
new file mode 100644
index 000000000..a9cf889a9
--- /dev/null
+++ b/lib/gc/Transforms/DecomposeAggregatedOps.cpp
@@ -0,0 +1,49 @@
+//===-- DecomposeAggregatedOps.cpp - Decompose Aggregated Ops ---*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace mlir {
+namespace gc {
+#define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS
+#include "gc/Transforms/Passes.h.inc"
+} // namespace gc
+} // namespace mlir
+
+namespace {
+
+// Rewrite a linalg.softmax into its decomposed form via the op's
+// AggregatedOpInterface implementation.
+struct DecomposeAggregateOpsImpl : public OpRewritePattern<linalg::SoftmaxOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::SoftmaxOp softmaxOp,
+                                PatternRewriter &rewriter) const override {
+    auto decomposableOp =
+        cast<linalg::AggregatedOpInterface>(softmaxOp.getOperation());
+    FailureOr<SmallVector<Value>> maybeNewResult =
+        decomposableOp.decomposeOperation(rewriter);
+    if (failed(maybeNewResult))
+      return failure();
+    rewriter.replaceOp(softmaxOp, *maybeNewResult);
+    return success();
+  }
+};
+
+struct DecomposeAggregatedOps
+    : public gc::impl::DecomposeAggregatedOpsBase<DecomposeAggregatedOps> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(getOperation().getContext());
+    patterns.add<DecomposeAggregateOpsImpl>(patterns.getContext());
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
index 4dfc85f0b..a39cbb0d2 100644
--- a/lib/gc/Transforms/Pipeline.cpp
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -62,6 +62,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
   // REMOVE this pass after the above passes are added. Currently we add this
   // pass to make the pipeline work properly
   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
+  // copied from tpp project
+  pm.addNestedPass<func::FuncOp>(createDecomposeAggregatedOps());
   // fold useless tensor operation pass
   pm.addPass(createFoldTensorOperation());
   pm.addPass(createLoopInvariantCodeMotionPass());
diff --git a/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir b/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir
new file mode 100644
index 000000000..9b178b3e7
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir
@@ -0,0 +1,11 @@
+// RUN: gc-opt %s -decompose-aggregated-ops | FileCheck %s
+
+// CHECK-LABEL: softmax
+func.func @softmax(%arg0: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
+  %0 = tensor.empty() : tensor<2x2x2x2xf32>
+  // CHECK-NOT: linalg.softmax
+  // CHECK-COUNT-4: linalg.generic
+  %1 = linalg.softmax dimension(3)
+    ins(%arg0 : tensor<2x2x2x2xf32>) outs(%0 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+  return %1 : tensor<2x2x2x2xf32>
+}