From 268f067d322fd72d067da68aaab36b0bec300921 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 5 Aug 2024 23:40:08 -0700 Subject: [PATCH 1/9] add DecomposeAggregatedOps --- include/gc/Transforms/Passes.td | 6 +++ lib/gc/Transforms/CMakeLists.txt | 3 +- lib/gc/Transforms/DecomposeAggregatedOps.cpp | 49 ++++++++++++++++++++ lib/gc/Transforms/Pipeline.cpp | 2 + 4 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 lib/gc/Transforms/DecomposeAggregatedOps.cpp diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 6ccf86fc6..4dcd9c677 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -92,6 +92,12 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> { /*default=*/"\"CPU\"", "The device to verify. Supported device: CPU, ">, ]; + +def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> { + let summary = "Decompose aggregated operations."; + let description = [{ + Decompose operations that implement the `AggregatedOpInterface`. + }]; } #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index d240f28c1..a8f203aee 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -16,7 +16,8 @@ gc_add_mlir_library(GcPasses IterativeTilingAndFusion.cpp TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp - + DecomposeAggregatedOps.cpp + DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/DecomposeAggregatedOps.cpp b/lib/gc/Transforms/DecomposeAggregatedOps.cpp new file mode 100644 index 000000000..cc5db66da --- /dev/null +++ b/lib/gc/Transforms/DecomposeAggregatedOps.cpp @@ -0,0 +1,49 @@ +//===- DecomposeAggregatedOps.cpp --------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS +#include "gc/Transforms/Passes.h.inc" +} // namespace gc +} // namespace mlir + +namespace { + +struct DecomposeAggregateOpsImpl : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::SoftmaxOp softmaxOp, + PatternRewriter &rewriter) const override { + auto decomposableOp = + cast(softmaxOp.getOperation()); + FailureOr> maybeNewResult = + decomposableOp.decomposeOperation(rewriter); + if (failed(maybeNewResult)) + return failure(); + rewriter.replaceOp(softmaxOp, *maybeNewResult); + return success(); + } +}; + +struct DecomposeAggregatedOps + : public gc::impl::DecomposeAggregatedOpsBase { + void runOnOperation() override { + RewritePatternSet patterns(getOperation().getContext()); + patterns.add(patterns.getContext()); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 8ab630026..d2f689c42 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -54,6 +54,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // REMOVE this pass after the above passes are added. Currently we add this // pass to make the pipeline work properly pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); + // copied from tpp project + pm.addNestedPass(createDecomposeAggregatedOps()); } // scf + arith + math + vector + tensor + linalg.brgemm From 3b19f4e25130b2db3aa8f9542bdf0c30c722295d Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 6 Aug 2024 00:18:46 -0700 Subject: [PATCH 2/9] add test case --- .../test/gc/Transforms/DecomposeAggregatedOps.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir diff --git a/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir b/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir new file mode 100644 index 000000000..9b178b3e7 --- /dev/null +++ b/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir @@ -0,0 +1,11 @@ +// RUN: gc-opt %s -decompose-aggregated-ops | FileCheck %s + +// CHECK-LABEL: softmax +func.func @softmax(%arg0: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { + %0 = tensor.empty() : tensor<2x2x2x2xf32> + // CHECK-NOT: linalg.softmax + // CHECK-COUNT-4: linalg.generic + %1 = linalg.softmax dimension(3) + ins(%arg0 : tensor<2x2x2x2xf32>) outs(%0 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %1 : tensor<2x2x2x2xf32> +} From 49fcd5bb67f45a9835e8ee674cd2bbf79940ce7d Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 6 Aug 2024 00:19:06 -0700 Subject: [PATCH 3/9] fix --- include/gc/Transforms/Passes.td | 1 + 1 file changed, 1 insertion(+) diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 4dcd9c677..312920211 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -92,6 +92,7 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> { /*default=*/"\"CPU\"", "The device to verify. Supported device: CPU, ">, ]; +} def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> { let summary = "Decompose aggregated operations."; From 3c72835992b591206c0d1f3a4940b52935da886e Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 6 Aug 2024 00:37:12 -0700 Subject: [PATCH 4/9] fix license --- lib/gc/Transforms/DecomposeAggregatedOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/DecomposeAggregatedOps.cpp b/lib/gc/Transforms/DecomposeAggregatedOps.cpp index cc5db66da..a9cf889a9 100644 --- a/lib/gc/Transforms/DecomposeAggregatedOps.cpp +++ b/lib/gc/Transforms/DecomposeAggregatedOps.cpp @@ -1,6 +1,6 @@ -//===- DecomposeAggregatedOps.cpp --------------------------------*- C++-*-===// +//===-- DecomposeAggregatedOps.cpp - Decompose Aggregated Ops ---*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // From 54666105b97b9d4a0ceea46671368048838bb7ad Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Mon, 5 Aug 2024 23:40:08 -0700 Subject: [PATCH 5/9] add DecomposeAggregatedOps --- include/gc/Transforms/Passes.td | 6 +++ lib/gc/Transforms/CMakeLists.txt | 3 +- lib/gc/Transforms/DecomposeAggregatedOps.cpp | 49 ++++++++++++++++++++ lib/gc/Transforms/Pipeline.cpp | 2 + 4 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 lib/gc/Transforms/DecomposeAggregatedOps.cpp diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 6ccf86fc6..4dcd9c677 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -92,6 +92,12 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> { /*default=*/"\"CPU\"", "The device to verify. Supported device: CPU, ">, ]; + +def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> { + let summary = "Decompose aggregated operations."; + let description = [{ + Decompose operations that implement the `AggregatedOpInterface`. + }]; } #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index d240f28c1..a8f203aee 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -16,7 +16,8 @@ gc_add_mlir_library(GcPasses IterativeTilingAndFusion.cpp TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp - + DecomposeAggregatedOps.cpp + DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/DecomposeAggregatedOps.cpp b/lib/gc/Transforms/DecomposeAggregatedOps.cpp new file mode 100644 index 000000000..cc5db66da --- /dev/null +++ b/lib/gc/Transforms/DecomposeAggregatedOps.cpp @@ -0,0 +1,49 @@ +//===- DecomposeAggregatedOps.cpp --------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS +#include "gc/Transforms/Passes.h.inc" +} // namespace gc +} // namespace mlir + +namespace { + +struct DecomposeAggregateOpsImpl : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::SoftmaxOp softmaxOp, + PatternRewriter &rewriter) const override { + auto decomposableOp = + cast(softmaxOp.getOperation()); + FailureOr> maybeNewResult = + decomposableOp.decomposeOperation(rewriter); + if (failed(maybeNewResult)) + return failure(); + rewriter.replaceOp(softmaxOp, *maybeNewResult); + return success(); + } +}; + +struct DecomposeAggregatedOps + : public gc::impl::DecomposeAggregatedOpsBase { + void runOnOperation() override { + RewritePatternSet patterns(getOperation().getContext()); + patterns.add(patterns.getContext()); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 13f73857c..ed8eee9cf 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -54,6 +54,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // REMOVE this pass after the above passes are added. Currently we add this // pass to make the pipeline work properly pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); + // copied from tpp project + pm.addNestedPass(createDecomposeAggregatedOps()); } // scf + arith + math + vector + tensor + linalg.brgemm From a85fb9dd2e5ce242241f01f45500f80c273ca43e Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 6 Aug 2024 00:18:46 -0700 Subject: [PATCH 6/9] add test case --- .../test/gc/Transforms/DecomposeAggregatedOps.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir diff --git a/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir b/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir new file mode 100644 index 000000000..9b178b3e7 --- /dev/null +++ b/test/mlir/test/gc/Transforms/DecomposeAggregatedOps.mlir @@ -0,0 +1,11 @@ +// RUN: gc-opt %s -decompose-aggregated-ops | FileCheck %s + +// CHECK-LABEL: softmax +func.func @softmax(%arg0: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { + %0 = tensor.empty() : tensor<2x2x2x2xf32> + // CHECK-NOT: linalg.softmax + // CHECK-COUNT-4: linalg.generic + %1 = linalg.softmax dimension(3) + ins(%arg0 : tensor<2x2x2x2xf32>) outs(%0 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %1 : tensor<2x2x2x2xf32> +} From 3fb5965364b504fbf28739634ca77385017944a9 Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 6 Aug 2024 00:19:06 -0700 Subject: [PATCH 7/9] fix --- include/gc/Transforms/Passes.td | 1 + 1 file changed, 1 insertion(+) diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 4dcd9c677..312920211 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -92,6 +92,7 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> { /*default=*/"\"CPU\"", "The device to verify. Supported device: CPU, ">, ]; +} def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> { let summary = "Decompose aggregated operations."; From c639459824e15caa280bc5b71c6abfca8afc836a Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Tue, 6 Aug 2024 00:37:12 -0700 Subject: [PATCH 8/9] fix license --- lib/gc/Transforms/DecomposeAggregatedOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/DecomposeAggregatedOps.cpp b/lib/gc/Transforms/DecomposeAggregatedOps.cpp index cc5db66da..a9cf889a9 100644 --- a/lib/gc/Transforms/DecomposeAggregatedOps.cpp +++ b/lib/gc/Transforms/DecomposeAggregatedOps.cpp @@ -1,6 +1,6 @@ -//===- DecomposeAggregatedOps.cpp --------------------------------*- C++-*-===// +//===-- DecomposeAggregatedOps.cpp - Decompose Aggregated Ops ---*- C++ -*-===// // -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // From 26ae9f4ff231999381ae088071ae7f6dd495160d Mon Sep 17 00:00:00 2001 From: "Xu, Rui" Date: Fri, 23 Aug 2024 00:43:27 -0700 Subject: [PATCH 9/9] fix --- include/gc/Transforms/Passes.td | 1 + 1 file changed, 1 insertion(+) diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index b5f15693d..d7bf6db1b 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -110,6 +110,7 @@ def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> { let description = [{ Decompose operations that implement the `AggregatedOpInterface`. }]; +} def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> { let summary = "Sink operations into inner loops";