diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir index d98d12675a7a7..3764e42939c1c 100644 --- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir +++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir @@ -321,3 +321,35 @@ func.func private @_QPwork() // CHECK: } // CHECK: llvm.func @_QPwork() attributes {sym_visibility = "private"} // CHECK: } + +// ----- + +func.func @_QPs() { + %0 = fir.address_of(@_QFsEc) : !fir.ref + omp.atomic.update %0 : !fir.ref { + ^bb0(%arg0: i32): + %c1_i32 = arith.constant 1 : i32 + %1 = arith.addi %arg0, %c1_i32 : i32 + omp.yield(%1 : i32) + } + return +} +fir.global internal @_QFsEc : i32 { + %c10_i32 = arith.constant 10 : i32 + fir.has_value %c10_i32 : i32 +} + +// CHECK-LABEL: llvm.func @_QPs() { +// CHECK: %[[GLOBAL_VAR:.*]] = llvm.mlir.addressof @[[GLOBAL:.*]] : !llvm.ptr +// CHECK: omp.atomic.update %[[GLOBAL_VAR]] : !llvm.ptr { +// CHECK: ^bb0(%[[IN_VAL:.*]]: i32): +// CHECK: %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[OUT_VAL:.*]] = llvm.add %[[IN_VAL]], %[[CONST_1]] : i32 +// CHECK: omp.yield(%[[OUT_VAL]] : i32) +// CHECK: } +// CHECK: llvm.return +// CHECK: } +// CHECK: llvm.mlir.global internal @[[GLOBAL]]() {{.*}} : i32 { +// CHECK: %[[INIT_10:.*]] = llvm.mlir.constant(10 : i32) : i32 +// CHECK: llvm.return %[[INIT_10]] : i32 +// CHECK: } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index d494e89b1274a..7dec1c8126b3c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1359,6 +1359,18 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update", /// Returns the new value if the operation is equivalent to just a write /// operation. Otherwise, returns nullptr. Value getWriteOpVal(); + + /// The number of variable operands. + unsigned getNumVariableOperands() { + assert(getX() && "expected 'x' operand"); + return 1; + } + + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + assert(i == 0 && "invalid index position for an operand"); + return getX(); + } }]; } diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 822a1abd0b282..621600b268c0f 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -83,6 +83,44 @@ struct RegionLessOpWithVarOperandsConversion } }; +template +struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(T curOp, typename T::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + SmallVector resTypes; + if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) + return failure(); + SmallVector convertedOperands; + assert(curOp.getNumVariableOperands() == + curOp.getOperation()->getNumOperands() && + "unexpected non-variable operands"); + for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { + Value originalVariableOperand = curOp.getVariableOperand(idx); + if (!originalVariableOperand) + return failure(); + if (originalVariableOperand.getType().isa()) { + // TODO: Support memref type in variable operands + return rewriter.notifyMatchFailure(curOp, + "memref is not supported yet"); + } + convertedOperands.emplace_back(adaptor.getOperands()[idx]); + } + auto newOp = rewriter.create(curOp.getLoc(), resTypes, convertedOperands, + curOp->getAttrs()); + rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *this->getTypeConverter()))) + return failure(); + + rewriter.eraseOp(curOp); + return success(); + } +}; + struct ReductionOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -114,13 +152,14 @@ struct LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern { void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, LLVMTypeConverter &typeConverter) { target.addDynamicallyLegalOp< - mlir::omp::CriticalOp, mlir::omp::ParallelOp, mlir::omp::WsLoopOp, - mlir::omp::SimdLoopOp, mlir::omp::MasterOp, mlir::omp::SectionsOp, - mlir::omp::SingleOp, mlir::omp::TaskOp>([&](Operation *op) { - return typeConverter.isLegal(&op->getRegion(0)) && - typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + mlir::omp::AtomicUpdateOp, mlir::omp::CriticalOp, mlir::omp::ParallelOp, + mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp, mlir::omp::MasterOp, + mlir::omp::SectionsOp, mlir::omp::SingleOp, mlir::omp::TaskOp>( + [&](Operation *op) { + return typeConverter.isLegal(&op->getRegion(0)) && + typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); target.addDynamicallyLegalOp, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, + RegionOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, LegalizeDataOpForLLVMTranslation, diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index 354c67912377b..74c1b19ea5102 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -109,6 +109,32 @@ func.func @atomic_read(%a: !llvm.ptr, %b: !llvm.ptr) -> () { // ----- +func.func @atomic_update() { + %0 = llvm.mlir.addressof @_QFsEc : !llvm.ptr + omp.atomic.update %0 : !llvm.ptr { + ^bb0(%arg0: i32): + %1 = arith.constant 1 : i32 + %2 = arith.addi %arg0, %1 : i32 + omp.yield(%2 : i32) + } + return +} +llvm.mlir.global internal @_QFsEc() : i32 { + %0 = arith.constant 10 : i32 + llvm.return %0 : i32 +} + +// CHECK-LABEL: @atomic_update +// CHECK: %[[GLOBAL_VAR:.*]] = llvm.mlir.addressof @_QFsEc : !llvm.ptr +// CHECK: omp.atomic.update %[[GLOBAL_VAR]] : !llvm.ptr { +// CHECK: ^bb0(%[[IN_VAL:.*]]: i32): +// CHECK: %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK: %[[OUT_VAL:.*]] = llvm.add %[[IN_VAL]], %[[CONST_1]] : i32 +// CHECK: omp.yield(%[[OUT_VAL]] : i32) +// CHECK: } + +// ----- + // CHECK-LABEL: @threadprivate // CHECK: (%[[ARG0:.*]]: !llvm.ptr) // CHECK: %[[VAL0:.*]] = omp.threadprivate %[[ARG0]] : !llvm.ptr -> !llvm.ptr