Skip to content

Commit

Permalink
[mlir][openacc] Add verifier for dataOperands on compute operations
Browse files Browse the repository at this point in the history
Data operands associated with acc.parallel, acc.serial and
acc.kernels should comes from acc data entry/exit operations
or acc.getdeviceptr.

Reviewed By: razvanlupusoru, jeanPerier

Differential Revision: https://reviews.llvm.org/D149994
  • Loading branch information
clementval committed May 8, 2023
1 parent 78a09cb commit aad53e3
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
)
$region attr-dict-with-keyword
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -562,6 +564,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
)
$region attr-dict-with-keyword
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -646,6 +650,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
)
$region attr-dict-with-keyword
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,21 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
// ParallelOp
//===----------------------------------------------------------------------===//

/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
template <typename Op>
static LogicalResult checkDataOperands(Op op,
const mlir::ValueRange &operands) {
for (mlir::Value operand : operands)
if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
operand.getDefiningOp()))
return op.emitError(
"expect data entry/exit operation or acc.getdeviceptr "
"as defining op");
return success();
}

unsigned ParallelOp::getNumDataOperands() {
return getReductionOperands().size() + getCopyOperands().size() +
getCopyinOperands().size() + getCopyinReadonlyOperands().size() +
Expand All @@ -306,6 +321,10 @@ Value ParallelOp::getDataOperand(unsigned i) {
return getOperand(getWaitOperands().size() + numOptional + i);
}

LogicalResult acc::ParallelOp::verify() {
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}

//===----------------------------------------------------------------------===//
// SerialOp
//===----------------------------------------------------------------------===//
Expand All @@ -328,6 +347,10 @@ Value SerialOp::getDataOperand(unsigned i) {
return getOperand(getWaitOperands().size() + numOptional + i);
}

LogicalResult acc::SerialOp::verify() {
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}

//===----------------------------------------------------------------------===//
// KernelsOp
//===----------------------------------------------------------------------===//
Expand All @@ -348,6 +371,10 @@ Value KernelsOp::getDataOperand(unsigned i) {
return getOperand(getWaitOperands().size() + numOptional + i);
}

LogicalResult acc::KernelsOp::verify() {
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}

//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/OpenACC/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,27 @@ acc.enter_data dataOperands(%value : memref<10xf32>)
%value = memref.alloc() : memref<10xf32>
// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
acc.update dataOperands(%value : memref<10xf32>)

// -----

%value = memref.alloc() : memref<10xf32>
// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
acc.parallel dataOperands(%value : memref<10xf32>) {
acc.yield
}

// -----

%value = memref.alloc() : memref<10xf32>
// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
acc.serial dataOperands(%value : memref<10xf32>) {
acc.yield
}

// -----

%value = memref.alloc() : memref<10xf32>
// expected-error@+1 {{expect data entry/exit operation or acc.getdeviceptr as defining op}}
acc.kernels dataOperands(%value : memref<10xf32>) {
acc.yield
}

0 comments on commit aad53e3

Please sign in to comment.