Skip to content

Commit

Permalink
[MLIR][OpenMP] Add Conversion for Atomic Update Op
Browse files Browse the repository at this point in the history
Reviewed By: TIFitis

Differential Revision: https://reviews.llvm.org/D143964
  • Loading branch information
kiranchandramohan committed Feb 16, 2023
1 parent b374423 commit 22cdeb5
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 7 deletions.
32 changes: 32 additions & 0 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,35 @@ func.func private @_QPwork()
// CHECK: }
// CHECK: llvm.func @_QPwork() attributes {sym_visibility = "private"}
// CHECK: }

// -----

func.func @_QPs() {
%0 = fir.address_of(@_QFsEc) : !fir.ref<i32>
omp.atomic.update %0 : !fir.ref<i32> {
^bb0(%arg0: i32):
%c1_i32 = arith.constant 1 : i32
%1 = arith.addi %arg0, %c1_i32 : i32
omp.yield(%1 : i32)
}
return
}
fir.global internal @_QFsEc : i32 {
%c10_i32 = arith.constant 10 : i32
fir.has_value %c10_i32 : i32
}

// CHECK-LABEL: llvm.func @_QPs() {
// CHECK: %[[GLOBAL_VAR:.*]] = llvm.mlir.addressof @[[GLOBAL:.*]] : !llvm.ptr<i32>
// CHECK: omp.atomic.update %[[GLOBAL_VAR]] : !llvm.ptr<i32> {
// CHECK: ^bb0(%[[IN_VAL:.*]]: i32):
// CHECK: %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[OUT_VAL:.*]] = llvm.add %[[IN_VAL]], %[[CONST_1]] : i32
// CHECK: omp.yield(%[[OUT_VAL]] : i32)
// CHECK: }
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.mlir.global internal @[[GLOBAL]]() {{.*}} : i32 {
// CHECK: %[[INIT_10:.*]] = llvm.mlir.constant(10 : i32) : i32
// CHECK: llvm.return %[[INIT_10]] : i32
// CHECK: }
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,18 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update",
/// Returns the new value if the operation is equivalent to just a write
/// operation. Otherwise, returns nullptr.
Value getWriteOpVal();

/// The number of variable operands.
unsigned getNumVariableOperands() {
assert(getX() && "expected 'x' operand");
return 1;
}

/// The i-th variable operand passed.
Value getVariableOperand(unsigned i) {
assert(i == 0 && "invalid index position for an operand");
return getX();
}
}];
}

Expand Down
54 changes: 47 additions & 7 deletions mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,44 @@ struct RegionLessOpWithVarOperandsConversion
}
};

template <typename T>
struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(T curOp, typename T::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
SmallVector<Type> resTypes;
if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
return failure();
SmallVector<Value> convertedOperands;
assert(curOp.getNumVariableOperands() ==
curOp.getOperation()->getNumOperands() &&
"unexpected non-variable operands");
for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
Value originalVariableOperand = curOp.getVariableOperand(idx);
if (!originalVariableOperand)
return failure();
if (originalVariableOperand.getType().isa<MemRefType>()) {
// TODO: Support memref type in variable operands
return rewriter.notifyMatchFailure(curOp,
"memref is not supported yet");
}
convertedOperands.emplace_back(adaptor.getOperands()[idx]);
}
auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
curOp->getAttrs());
rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
newOp.getRegion().end());
if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
*this->getTypeConverter())))
return failure();

rewriter.eraseOp(curOp);
return success();
}
};

struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
using ConvertOpToLLVMPattern<omp::ReductionOp>::ConvertOpToLLVMPattern;
LogicalResult
Expand Down Expand Up @@ -114,13 +152,14 @@ struct LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
void mlir::configureOpenMPToLLVMConversionLegality(
ConversionTarget &target, LLVMTypeConverter &typeConverter) {
target.addDynamicallyLegalOp<
mlir::omp::CriticalOp, mlir::omp::ParallelOp, mlir::omp::WsLoopOp,
mlir::omp::SimdLoopOp, mlir::omp::MasterOp, mlir::omp::SectionsOp,
mlir::omp::SingleOp, mlir::omp::TaskOp>([&](Operation *op) {
return typeConverter.isLegal(&op->getRegion(0)) &&
typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});
mlir::omp::AtomicUpdateOp, mlir::omp::CriticalOp, mlir::omp::ParallelOp,
mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp, mlir::omp::MasterOp,
mlir::omp::SectionsOp, mlir::omp::SingleOp, mlir::omp::TaskOp>(
[&](Operation *op) {
return typeConverter.isLegal(&op->getRegion(0)) &&
typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
});
target.addDynamicallyLegalOp<mlir::omp::AtomicReadOp,
mlir::omp::AtomicWriteOp, mlir::omp::FlushOp,
mlir::omp::ThreadprivateOp, mlir::omp::DataOp,
Expand All @@ -145,6 +184,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionOpConversion<omp::TaskOp>,
RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>,
RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>,
RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
LegalizeDataOpForLLVMTranslation<omp::DataOp>,
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,32 @@ func.func @atomic_read(%a: !llvm.ptr<i32>, %b: !llvm.ptr<i32>) -> () {

// -----

func.func @atomic_update() {
%0 = llvm.mlir.addressof @_QFsEc : !llvm.ptr<i32>
omp.atomic.update %0 : !llvm.ptr<i32> {
^bb0(%arg0: i32):
%1 = arith.constant 1 : i32
%2 = arith.addi %arg0, %1 : i32
omp.yield(%2 : i32)
}
return
}
llvm.mlir.global internal @_QFsEc() : i32 {
%0 = arith.constant 10 : i32
llvm.return %0 : i32
}

// CHECK-LABEL: @atomic_update
// CHECK: %[[GLOBAL_VAR:.*]] = llvm.mlir.addressof @_QFsEc : !llvm.ptr<i32>
// CHECK: omp.atomic.update %[[GLOBAL_VAR]] : !llvm.ptr<i32> {
// CHECK: ^bb0(%[[IN_VAL:.*]]: i32):
// CHECK: %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[OUT_VAL:.*]] = llvm.add %[[IN_VAL]], %[[CONST_1]] : i32
// CHECK: omp.yield(%[[OUT_VAL]] : i32)
// CHECK: }

// -----

// CHECK-LABEL: @threadprivate
// CHECK: (%[[ARG0:.*]]: !llvm.ptr<i32>)
// CHECK: %[[VAL0:.*]] = omp.threadprivate %[[ARG0]] : !llvm.ptr<i32> -> !llvm.ptr<i32>
Expand Down

0 comments on commit 22cdeb5

Please sign in to comment.