-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][openacc] Add legalize data pass for compute operation #80351
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesThis patch adds a simple pass to replace the uses inside compute operation. It replaces the private and reductions variables are not included in this pass since they will normally be replace when they are materialized. Full diff: https://github.com/llvm/llvm-project/pull/80351.diff 12 Files Affected:
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 8c47ad3d9f445..b5c41699205f4 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
@@ -74,6 +75,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
/// Register the standard passes we use. This comes from registerAllPasses(),
/// but is a smaller set since we aren't using many of the passes found there.
inline void registerMLIRPassesForFortranTools() {
+ mlir::acc::registerOpenACCPasses();
mlir::registerCanonicalizerPass();
mlir::registerCSEPass();
mlir::affine::registerAffineLoopFusionPass();
diff --git a/flang/test/Fir/OpenACC/legalize-data.fir b/flang/test/Fir/OpenACC/legalize-data.fir
new file mode 100644
index 0000000000000..3b8695434e6e4
--- /dev/null
+++ b/flang/test/Fir/OpenACC/legalize-data.fir
@@ -0,0 +1,24 @@
+// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
+
+func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
+ %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+ acc.parallel dataOperands(%1 : !fir.ref<i32>) {
+ %c0_i32 = arith.constant 0 : i32
+ hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
+ acc.yield
+ }
+ acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub1
+// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
+// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+// CHECK: acc.parallel dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
+// CHECK: %c0_i32 = arith.constant 0 : i32
+// CHECK: hlfir.assign %c0{{.*}} to %[[COPYIN]] : i32, !fir.ref<i32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
diff --git a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
index 56ba2976ee5d4..8a4b1c7b196ea 100644
--- a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenACC/ACC.td)
mlir_tablegen(AccCommon.td --gen-directive-decl --directives-dialect=OpenACC)
add_public_tablegen_target(acc_common_td)
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..ddbd5839576fc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenACC)
+add_public_tablegen_target(MLIROpenACCPassIncGen)
+
+add_mlir_doc(Passes OpenACCPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
new file mode 100644
index 0000000000000..5a11056cda609
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -0,0 +1,40 @@
+//===- Passes.h - OpenACC Passes Construction and Registration ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
+
+#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
+#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
+#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
+#include "mlir/Pass/Pass.h"
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+} // namespace func
+
+namespace acc {
+
+/// Create a pass to replace ssa values in region with device/host values.
+std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
+
+/// Generate the code for registering conversion passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+
+} // namespace acc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
new file mode 100644
index 0000000000000..abbc27765e342
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -0,0 +1,28 @@
+//===-- Passes.td - OpenACC pass definition file -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
+#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
+ let summary = "Legalize the data in the compute region";
+ let description = [{
+ This pass replace uses of varPtr in the compute region with their accPtr
+ gathered from the data clause operands.
+ }];
+ let options = [
+ Option<"hostToDevice", "host-to-device", "bool", "true",
+ "Replace varPtr uses with accPtr if true. Replace accPtr uses with "
+ "varPtr if false">
+ ];
+ let constructor = "::mlir::acc::createLegalizeDataInRegion()";
+}
+
+#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 28dc3cc23daf2..e28921619fe58 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -34,6 +34,7 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -64,6 +65,7 @@ inline void registerAllPasses() {
registerConversionPasses();
// Dialect passes
+ acc::registerOpenACCPasses();
affine::registerAffinePasses();
amdgpu::registerAMDGPUPasses();
registerAsyncPasses();
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 27285246ef997..9f57627c321fb 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,20 +1,2 @@
-add_mlir_dialect_library(MLIROpenACCDialect
- IR/OpenACC.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
-
- DEPENDS
- MLIROpenACCOpsIncGen
- MLIROpenACCEnumsIncGen
- MLIROpenACCAttributesIncGen
- MLIROpenACCOpsInterfacesIncGen
- MLIROpenACCTypeInterfacesIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRLLVMDialect
- MLIRMemRefDialect
- MLIROpenACCMPCommon
- )
-
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..b802de165b8f3
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_dialect_library(MLIROpenACCDialect
+ OpenACC.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+ DEPENDS
+ MLIROpenACCOpsIncGen
+ MLIROpenACCEnumsIncGen
+ MLIROpenACCAttributesIncGen
+ MLIROpenACCOpsInterfacesIncGen
+ MLIROpenACCTypeInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRMemRefDialect
+ MLIROpenACCMPCommon
+ )
+
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..78f676362ecfd
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIROpenACCTransforms
+ LegalizeData.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+ DEPENDS
+ MLIROpenACCPassIncGen
+ MLIROpenACCOpsIncGen
+ MLIROpenACCEnumsIncGen
+ MLIROpenACCAttributesIncGen
+ MLIROpenACCOpsInterfacesIncGen
+ MLIROpenACCTypeInterfacesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIROpenACCDialect
+)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
new file mode 100644
index 0000000000000..ef44a0ec68d9c
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -0,0 +1,72 @@
+//===- LegalizeData.cpp - -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_LEGALIZEDATAINREGION
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+ llvm::SmallVector<std::pair<Value, Value>> values;
+ for (auto operand : op.getDataClauseOperands()) {
+ Value varPtr = acc::getVarPtr(operand.getDefiningOp());
+ Value accPtr = acc::getAccPtr(operand.getDefiningOp());
+ if (varPtr && accPtr) {
+ if (hostToDevice)
+ values.push_back({varPtr, accPtr});
+ else
+ values.push_back({accPtr, varPtr});
+ }
+ }
+
+ for (auto p : values)
+ replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
+}
+
+struct LegalizeDataInRegion
+ : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
+
+ void runOnOperation() override {
+ func::FuncOp funcOp = getOperation();
+ bool replaceHostVsDevice = this->hostToDevice.getValue();
+
+ funcOp.walk([&](Operation *op) {
+ if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+ return;
+
+ if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
+ collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
+ } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
+ collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
+ } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
+ collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+ }
+ });
+ }
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::acc::createLegalizeDataInRegion() {
+ return std::make_unique<LegalizeDataInRegion>();
+}
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
new file mode 100644
index 0000000000000..f9857411306c0
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
+// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+ %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel dataOperands(%create : memref<10xf32>) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+ %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.serial dataOperands(%create : memref<10xf32>) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+ %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.kernels dataOperands(%create : memref<10xf32>) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.kernels dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel dataOperands(%create : memref<10xf32>) {
+ acc.loop (%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
+// CHECK: acc.loop (%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[CREATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Implementation looks much cleaner than I imagined.
I do think that the ssa uses for private/reduction/firstprivate operations should also be replaced. This means that the logic would also need to apply to acc.loop in addition to the compute constructs. This should be added also under an option just like you did for hostToDevice.
I would say this is OK to submit as-is - and we can make the improvement when it is needed. Nice work! Thank you for getting this done :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
Co-authored-by: Slava Zakharin <szakharin@nvidia.com>
I'm gonna make a follow up PR for private/reduction |
This patch adds a simple pass to replace the uses inside compute operation. It replaces the `varPtr` values with their corresponding `accPtr` values gathered through the dataClauseOperands. private and reductions variables are not included in this pass since they will normally be replace when they are materialized.
This patch adds a simple pass to replace the uses inside compute operation. It replaces the `varPtr` values with their corresponding `accPtr` values gathered through the dataClauseOperands. private and reductions variables are not included in this pass since they will normally be replace when they are materialized. Reland with fix for dependencies
) This patch adds a simple pass to replace the uses inside compute operation. It replaces the `varPtr` values with their corresponding `accPtr` values gathered through the dataClauseOperands. private and reductions variables are not included in this pass since they will normally be replace when they are materialized. --------- Co-authored-by: Slava Zakharin <szakharin@nvidia.com>
…llvm#80710) Reverts llvm#80351 Breaks some buildbot
This is a follow up to #80351 and adds private and reduction operands from acc.loop, acc.parallel and acc.serial operations.
This patch adds a simple pass to replace the uses inside compute operation. It replaces the
varPtr
values with their correspondingaccPtr
values gathered through the dataClauseOperands.private and reductions variables are not included in this pass since they will normally be replace when they are materialized.