From aad53e3b26103bf2525264264e07b0427403fb49 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Mon, 8 May 2023 09:30:01 -0700 Subject: [PATCH] [mlir][openacc] Add verifier for dataOperands on compute operations Data operands associated with acc.parallel, acc.serial and acc.kernels should comes from acc data entry/exit operations or acc.getdeviceptr. Reviewed By: razvanlupusoru, jeanPerier Differential Revision: https://reviews.llvm.org/D149994 --- .../mlir/Dialect/OpenACC/OpenACCOps.td | 6 +++++ mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 27 +++++++++++++++++++ mlir/test/Dialect/OpenACC/invalid.mlir | 24 +++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 4af056cbba1e5..b028b5eeab88e 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -477,6 +477,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", ) $region attr-dict-with-keyword }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -562,6 +564,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", ) $region attr-dict-with-keyword }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -646,6 +650,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", ) $region attr-dict-with-keyword }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index efcc809383da8..6f53c4e75b7e7 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -285,6 +285,21 @@ struct RemoveConstantIfCondition : public OpRewritePattern { // ParallelOp //===----------------------------------------------------------------------===// +/// Check dataOperands for acc.parallel, acc.serial and acc.kernels. +template +static LogicalResult checkDataOperands(Op op, + const mlir::ValueRange &operands) { + for (mlir::Value operand : operands) + if (!mlir::isa( + operand.getDefiningOp())) + return op.emitError( + "expect data entry/exit operation or acc.getdeviceptr " + "as defining op"); + return success(); +} + unsigned ParallelOp::getNumDataOperands() { return getReductionOperands().size() + getCopyOperands().size() + getCopyinOperands().size() + getCopyinReadonlyOperands().size() + @@ -306,6 +321,10 @@ Value ParallelOp::getDataOperand(unsigned i) { return getOperand(getWaitOperands().size() + numOptional + i); } +LogicalResult acc::ParallelOp::verify() { + return checkDataOperands(*this, getDataClauseOperands()); +} + //===----------------------------------------------------------------------===// // SerialOp //===----------------------------------------------------------------------===// @@ -328,6 +347,10 @@ Value SerialOp::getDataOperand(unsigned i) { return getOperand(getWaitOperands().size() + numOptional + i); } +LogicalResult acc::SerialOp::verify() { + return checkDataOperands(*this, getDataClauseOperands()); +} + //===----------------------------------------------------------------------===// // KernelsOp //===----------------------------------------------------------------------===// @@ -348,6 +371,10 @@ Value KernelsOp::getDataOperand(unsigned i) { return getOperand(getWaitOperands().size() + numOptional + i); } +LogicalResult acc::KernelsOp::verify() { + return checkDataOperands(*this, getDataClauseOperands()); +} + //===----------------------------------------------------------------------===// // LoopOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index 9c6ec08ca5a23..66644a7f31125 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -224,3 +224,27 @@ acc.enter_data dataOperands(%value : memref<10xf32>) %value = memref.alloc() : memref<10xf32> // expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}} acc.update dataOperands(%value : memref<10xf32>) + +// ----- + +%value = memref.alloc() : memref<10xf32> +// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}} +acc.parallel dataOperands(%value : memref<10xf32>) { + acc.yield +} + +// ----- + +%value = memref.alloc() : memref<10xf32> +// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}} +acc.serial dataOperands(%value : memref<10xf32>) { + acc.yield +} + +// ----- + +%value = memref.alloc() : memref<10xf32> +// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}} +acc.kernels dataOperands(%value : memref<10xf32>) { + acc.yield +}