diff --git a/flang/test/Lower/OpenMP/flush.f90 b/flang/test/Lower/OpenMP/flush.f90 index 86f6c68c166f2..e6970f47a7aa9 100644 --- a/flang/test/Lower/OpenMP/flush.f90 +++ b/flang/test/Lower/OpenMP/flush.f90 @@ -1,14 +1,16 @@ ! This test checks lowering of OpenMP Flush Directive. !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect" -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="LLVMIRDialect,OMPDialect" subroutine flush_standalone(a, b, c) integer, intent(inout) :: a, b, c !$omp flush(a,b,c) !$omp flush -!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : !fir.ref, !fir.ref, !fir.ref) +!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : +!FIRDialect: !fir.ref, !fir.ref, !fir.ref) +!LLVMIRDialect: !llvm.ptr, !llvm.ptr, !llvm.ptr) !OMPDialect: omp.flush end subroutine flush_standalone @@ -19,7 +21,9 @@ subroutine flush_parallel(a, b, c) !$omp parallel !OMPDialect: omp.parallel { -!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : !fir.ref, !fir.ref, !fir.ref) +!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : +!FIRDialect: !fir.ref, !fir.ref, !fir.ref) +!LLVMIRDialect: !llvm.ptr, !llvm.ptr, !llvm.ptr) !OMPDialect: omp.flush !$omp flush(a,b,c) !$omp flush diff --git a/flang/test/Lower/OpenMP/parallel-sections.f90 b/flang/test/Lower/OpenMP/parallel-sections.f90 index e9759072c5234..0b04bfadfb849 100644 --- a/flang/test/Lower/OpenMP/parallel-sections.f90 +++ b/flang/test/Lower/OpenMP/parallel-sections.f90 @@ -41,7 +41,9 @@ subroutine omp_parallel_sections_allocate(x, y) !FIRDialect: %[[allocator:.*]] = arith.constant 1 : i32 !LLVMDialect: %[[allocator:.*]] = llvm.mlir.constant(1 : i32) : i32 !OMPDialect: omp.parallel { - !OMPDialect: omp.sections allocate(%[[allocator]] : i32 -> %{{.*}} : !fir.ref) { + !OMPDialect: omp.sections allocate( + !FIRDialect: %[[allocator]] : i32 -> %{{.*}} : !fir.ref) { + !LLVMDialect: %[[allocator]] : i32 -> %{{.*}} : !llvm.ptr) { !$omp parallel sections allocate(omp_high_bw_mem_alloc: x) !OMPDialect: omp.section { !$omp section diff --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90 index 70dcee4ddd255..0f20115cf66e3 100644 --- a/flang/test/Lower/OpenMP/parallel.f90 +++ b/flang/test/Lower/OpenMP/parallel.f90 @@ -1,5 +1,5 @@ !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect" -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="LLVMDialect,OMPDialect" !FIRDialect-LABEL: func @_QPparallel_simple subroutine parallel_simple() @@ -152,7 +152,10 @@ end subroutine parallel_proc_bind subroutine parallel_allocate() use omp_lib integer :: x - !OMPDialect: omp.parallel allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !OMPDialect: omp.parallel allocate( + !FIRDialect: %{{.+}} : i32 -> %{{.+}} : !fir.ref + !LLVMDialect: %{{.+}} : i32 -> %{{.+}} : !llvm.ptr + !OMPDialect: ) { !$omp parallel allocate(omp_high_bw_mem_alloc: x) private(x) !FIRDialect: arith.addi x = x + 12 @@ -191,7 +194,10 @@ subroutine parallel_multiple_clauses(alpha, num_threads) !OMPDialect: omp.terminator !$omp end parallel - !OMPDialect: omp.parallel if({{.*}} : i1) num_threads({{.*}} : i32) allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !OMPDialect: omp.parallel if({{.*}} : i1) num_threads({{.*}} : i32) allocate( + !FIRDialect: %{{.+}} : i32 -> %{{.+}} : !fir.ref + !LLVMDialect: %{{.+}} : i32 -> %{{.+}} : !llvm.ptr + !OMPDialect: ) { !$omp parallel num_threads(num_threads) if(alpha .le. 0) allocate(omp_high_bw_mem_alloc: alpha) private(alpha) !FIRDialect: fir.call call f3() diff --git a/flang/test/Lower/OpenMP/single.f90 b/flang/test/Lower/OpenMP/single.f90 index e159dcf73d9e2..b655d78075bdd 100644 --- a/flang/test/Lower/OpenMP/single.f90 +++ b/flang/test/Lower/OpenMP/single.f90 @@ -1,5 +1,5 @@ !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect" -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="LLVMDialect,OMPDialect" !=============================================================================== ! Single construct @@ -55,7 +55,10 @@ subroutine single_allocate() integer :: x !OMPDialect: omp.parallel { !$omp parallel - !OMPDialect: omp.single allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !OMPDialect: omp.single allocate( + !FIRDialect: %{{.+}} : i32 -> %{{.+}} : !fir.ref + !LLVMDialect: %{{.+}} : i32 -> %{{.+}} : !llvm.ptr + !OMPDialect: ) { !$omp single allocate(omp_high_bw_mem_alloc: x) private(x) !FIRDialect: arith.addi x = x + 12 diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 8d3654ce37ba4..7c183a9788ed6 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -574,9 +574,19 @@ def FlushOp : OpenMP_Op<"flush"> { specified or implied. }]; - let arguments = (ins Variadic:$varList); + let arguments = (ins Variadic:$varList); let assemblyFormat = [{ ( `(` $varList^ `:` type($varList) `)` )? attr-dict}]; + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + return getOperation()->getNumOperands(); + } + /// The i-th variable operand passed. + Value getVariableOperand(unsigned i) { + return getOperand(i); + } + }]; } //===----------------------------------------------------------------------===// // 2.14.5 target construct diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 3c8416b75511f..bb8f523a1374f 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -47,7 +47,8 @@ struct RegionOpConversion : public ConvertOpToLLVMPattern { }; template -struct RegionLessOpConversion : public ConvertOpToLLVMPattern { +struct RegionLessOpWithVarOperandsConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, @@ -57,6 +58,9 @@ struct RegionLessOpConversion : public ConvertOpToLLVMPattern { if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); SmallVector convertedOperands; + assert(curOp.getNumVariableOperands() == + curOp.getOperation()->getNumOperands() && + "unexpected non-variable operands"); for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) @@ -78,23 +82,31 @@ struct RegionLessOpConversion : public ConvertOpToLLVMPattern { void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, LLVMTypeConverter &typeConverter) { target.addDynamicallyLegalOp( - [&](Operation *op) { return typeConverter.isLegal(&op->getRegion(0)); }); + mlir::omp::MasterOp, mlir::omp::SectionsOp, + mlir::omp::SingleOp>([&](Operation *op) { + return typeConverter.isLegal(&op->getRegion(0)) && + typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); target .addDynamicallyLegalOp([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()); - }); + mlir::omp::FlushOp, mlir::omp::ThreadprivateOp>( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); } void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add, - RegionOpConversion, - RegionOpConversion, - RegionLessOpConversion, - RegionLessOpConversion, - RegionLessOpConversion>(converter); + patterns.add< + RegionOpConversion, RegionOpConversion, + RegionOpConversion, RegionOpConversion, + RegionOpConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion>(converter); } namespace { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index fe85130a8a101..3a2ac4b905081 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -29,19 +29,19 @@ func.func @omp_taskyield() -> () { } // CHECK-LABEL: func @omp_flush -// CHECK-SAME: ([[ARG0:%.*]]: i32) { -func.func @omp_flush(%arg0 : i32) -> () { +// CHECK-SAME: ([[ARG0:%.*]]: memref) { +func.func @omp_flush(%arg0 : memref) -> () { // Test without data var // CHECK: omp.flush omp.flush // Test with one data var - // CHECK: omp.flush([[ARG0]] : i32) - omp.flush(%arg0 : i32) + // CHECK: omp.flush([[ARG0]] : memref) + omp.flush(%arg0 : memref) // Test with two data var - // CHECK: omp.flush([[ARG0]], [[ARG0]] : i32, i32) - omp.flush(%arg0, %arg0: i32, i32) + // CHECK: omp.flush([[ARG0]], [[ARG0]] : memref, memref) + omp.flush(%arg0, %arg0: memref, memref) return } diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 7b484812d22e3..bcf4ddd40d4aa 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -18,16 +18,16 @@ llvm.func @test_stand_alone_directives() { llvm.return } -// CHECK-LABEL: define void @test_flush_construct(i32 %0) -llvm.func @test_flush_construct(%arg0: i32) { +// CHECK-LABEL: define void @test_flush_construct(ptr %{{[0-9]+}}) +llvm.func @test_flush_construct(%arg0: !llvm.ptr) { // CHECK: call void @__kmpc_flush(ptr @{{[0-9]+}} omp.flush // CHECK: call void @__kmpc_flush(ptr @{{[0-9]+}} - omp.flush (%arg0 : i32) + omp.flush (%arg0 : !llvm.ptr) // CHECK: call void @__kmpc_flush(ptr @{{[0-9]+}} - omp.flush (%arg0, %arg0 : i32, i32) + omp.flush (%arg0, %arg0 : !llvm.ptr, !llvm.ptr) %0 = llvm.mlir.constant(1 : i64) : i64 // CHECK: alloca {{.*}} align 4