diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 5b89f741e296d..8c33be7ff1747 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -2999,6 +2999,12 @@ def AtomicCaptureOp : OpenACC_Op<"atomic.capture", acc.atomic.write ... acc.terminator } + + acc.atomic.capture { + acc.atomic.update ... + acc.atomic.write ... + acc.terminator + } ``` }]; diff --git a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td index 223bee9ab1c27..9df6b907eb326 100644 --- a/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.td @@ -239,6 +239,7 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> { implement one of the atomic interfaces. It can be found in one of these forms: `{ atomic.update, atomic.read }` + `{ atomic.update, atomic.write }` `{ atomic.read, atomic.update }` `{ atomic.read, atomic.write }` }]; @@ -291,12 +292,15 @@ def AtomicCaptureOpInterface : OpInterface<"AtomicCaptureOpInterface"> { auto secondWriteStmt = dyn_cast(secondOp); if (!((firstUpdateStmt && secondReadStmt) || + (firstUpdateStmt && secondWriteStmt) || (firstReadStmt && secondUpdateStmt) || (firstReadStmt && secondWriteStmt))) return ops.front().emitError() << "invalid sequence of operations in the capture region"; - if (firstUpdateStmt && secondReadStmt && - firstUpdateStmt.getX() != secondReadStmt.getX()) + if ((firstUpdateStmt && secondReadStmt && + firstUpdateStmt.getX() != secondReadStmt.getX()) || + (firstUpdateStmt && secondWriteStmt && + firstUpdateStmt.getX() != secondWriteStmt.getX())) return firstUpdateStmt.emitError() << "updated variable in atomic.update must be captured in " "second operation"; diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index 0e75894eaeceb..d56bb66ab7efe 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -690,13 +690,14 @@ func.func @acc_atomic_capture(%x: memref, %v: memref, %expr: i32) { func.func @acc_atomic_capture(%x: memref, %v: memref, %expr: i32) { acc.atomic.capture { - // expected-error @below {{invalid sequence of operations in the capture region}} + // expected-error @below {{updated variable in atomic.update must be captured in second operation}} acc.atomic.update %x : memref { ^bb0(%xval: i32): %newval = llvm.add %xval, %expr : i32 acc.yield %newval : i32 } - acc.atomic.write %x = %expr : memref, i32 + acc.atomic.write %v = %expr : memref, i32 + acc.terminator } return diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index fc11bae60d9e0..7bd115f83d3d2 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -1938,6 +1938,24 @@ func.func @acc_atomic_capture(%v: memref, %x: memref, %expr: i32) { acc.atomic.write %x = %expr : memref, i32 } + // CHECK: acc.atomic.capture { + // CHECK-NEXT: acc.atomic.update %[[x]] : memref + // CHECK-NEXT: (%[[xval:.*]]: i32): + // CHECK-NEXT: %[[newval:.*]] = llvm.add %[[xval]], %[[expr]] : i32 + // CHECK-NEXT: acc.yield %[[newval]] : i32 + // CHECK-NEXT: } + // CHECK-NEXT: acc.atomic.write %[[x]] = %[[expr]] : memref, i32 + // CHECK-NEXT: } + acc.atomic.capture { + acc.atomic.update %x : memref { + ^bb0(%xval: i32): + %newval = llvm.add %xval, %expr : i32 + acc.yield %newval : i32 + } + acc.atomic.write %x = %expr : memref, i32 + acc.terminator + } + return } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index af24d969064ab..b59f1aeb6eab4 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1263,13 +1263,13 @@ func.func @omp_atomic_capture(%x: memref, %v: memref, %expr: i32) { func.func @omp_atomic_capture(%x: memref, %v: memref, %expr: i32) { omp.atomic.capture { - // expected-error @below {{invalid sequence of operations in the capture region}} + // expected-error @below {{updated variable in atomic.update must be captured in second operation}} omp.atomic.update %x : memref { ^bb0(%xval: i32): %newval = llvm.add %xval, %expr : i32 omp.yield (%newval : i32) } - omp.atomic.write %x = %expr : memref, i32 + omp.atomic.write %v = %expr : memref, i32 omp.terminator } return diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index ac29e20907b55..c2fa808856118 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1926,6 +1926,23 @@ func.func @omp_atomic_capture(%v: memref, %x: memref, %expr: i32) { omp.atomic.read %v = %x : memref, memref, i32 } + // CHECK: omp.atomic.capture { + // CHECK-NEXT: omp.atomic.update %[[x]] : memref + // CHECK-NEXT: (%[[xval:.*]]: i32): + // CHECK-NEXT: %[[newval:.*]] = llvm.add %[[xval]], %[[expr]] : i32 + // CHECK-NEXT: omp.yield(%[[newval]] : i32) + // CHECK-NEXT: } + // CHECK-NEXT: omp.atomic.write %[[x]] = %[[expr]] : memref, i32 + // CHECK-NEXT: } + omp.atomic.capture { + omp.atomic.update %x : memref { + ^bb0(%xval: i32): + %newval = llvm.add %xval, %expr : i32 + omp.yield (%newval : i32) + } + omp.atomic.write %x = %expr : memref, i32 + } + return }