diff --git a/clang/test/OpenMP/cancel_codegen.cpp b/clang/test/OpenMP/cancel_codegen.cpp index 16e7542a8e826..150cdb9b2cc14 100644 --- a/clang/test/OpenMP/cancel_codegen.cpp +++ b/clang/test/OpenMP/cancel_codegen.cpp @@ -811,16 +811,16 @@ for (int i = 0; i < argc; ++i) { // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_BODY_CASE23_SECTION_AFTER:%.*]] // CHECK3: omp_section_loop.body.case23.section.after: // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_BODY16_SECTIONS_AFTER]] -// CHECK3: omp_section_loop.body.case25: +// CHECK3: omp_section_loop.body.case26: // CHECK3-NEXT: [[OMP_GLOBAL_THREAD_NUM27:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]]) // CHECK3-NEXT: [[TMP18:%.*]] = call i32 @__kmpc_cancel(ptr @[[GLOB1]], i32 [[OMP_GLOBAL_THREAD_NUM27]], i32 3) // CHECK3-NEXT: [[TMP19:%.*]] = icmp eq i32 [[TMP18]], 0 // CHECK3-NEXT: br i1 [[TMP19]], label [[OMP_SECTION_LOOP_BODY_CASE25_SPLIT:%.*]], label [[OMP_SECTION_LOOP_BODY_CASE25_CNCL:%.*]] -// CHECK3: omp_section_loop.body.case25.split: +// CHECK3: omp_section_loop.body.case26.split: // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_BODY_CASE25_SECTION_AFTER26:%.*]] -// CHECK3: omp_section_loop.body.case25.section.after26: +// CHECK3: omp_section_loop.body.case26.section.after27: // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_BODY_CASE25_SECTION_AFTER:%.*]] -// CHECK3: omp_section_loop.body.case25.section.after: +// CHECK3: omp_section_loop.body.case26.section.after: // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_BODY16_SECTIONS_AFTER]] // CHECK3: omp_section_loop.body16.sections.after: // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_INC17]] @@ -891,10 +891,12 @@ for (int i = 0; i < argc; ++i) { // CHECK3: .cancel.exit: // CHECK3-NEXT: br label [[CANCEL_EXIT:%.*]] // CHECK3: omp_section_loop.body.case.cncl: -// CHECK3-NEXT: br label [[OMP_SECTION_LOOP_EXIT]] -// CHECK3: omp_section_loop.body.case23.cncl: +// CHECK3-NEXT: br label [[FINI10:.*]] +// CHECK3: .fini25: // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_EXIT18]] -// CHECK3: 
omp_section_loop.body.case25.cncl: +// CHECK3: omp_section_loop.body.case26.cncl: +// CHECK3-NEXT: br label [[FINI29:.*]] +// CHECK3: .fini29: // CHECK3-NEXT: br label [[OMP_SECTION_LOOP_EXIT18]] // CHECK3: .cancel.continue: // CHECK3-NEXT: br label [[OMP_IF_END:%.*]] @@ -967,8 +969,10 @@ for (int i = 0; i < argc; ++i) { // CHECK3-NEXT: [[TMP8:%.*]] = call i32 @__kmpc_cancel_barrier(ptr @[[GLOB3:[0-9]+]], i32 [[OMP_GLOBAL_THREAD_NUM4]]) // CHECK3-NEXT: [[TMP9:%.*]] = icmp eq i32 [[TMP8]], 0 // CHECK3-NEXT: br i1 [[TMP9]], label [[DOTCONT:%.*]], label [[DOTCNCL5:%.*]] -// CHECK3: .cncl5: -// CHECK3-NEXT: br label [[OMP_PAR_OUTLINED_EXIT_EXITSTUB:%.*]] +// CHECK3: .cncl4: +// CHECK3-NEXT: br label [[FINI:%.*]] +// CHECK3: .fini +// CHECK3-NEXT: br label %[[EXIT_STUB:omp.par.exit.exitStub]] // CHECK3: .cont: // CHECK3-NEXT: [[TMP10:%.*]] = load i32, ptr [[LOADGEP_ARGC_ADDR]], align 4 // CHECK3-NEXT: [[TMP11:%.*]] = load ptr, ptr [[LOADGEP_ARGV_ADDR]], align 8 @@ -984,16 +988,14 @@ for (int i = 0; i < argc; ++i) { // CHECK3: omp.par.region.parallel.after: // CHECK3-NEXT: br label [[OMP_PAR_PRE_FINALIZE:%.*]] // CHECK3: omp.par.pre_finalize: -// CHECK3-NEXT: br label [[OMP_PAR_OUTLINED_EXIT_EXITSTUB]] +// CHECK3-NEXT: br label [[FINI]] // CHECK3: 14: // CHECK3-NEXT: [[OMP_GLOBAL_THREAD_NUM1:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]]) // CHECK3-NEXT: [[TMP15:%.*]] = call i32 @__kmpc_cancel(ptr @[[GLOB1]], i32 [[OMP_GLOBAL_THREAD_NUM1]], i32 1) // CHECK3-NEXT: [[TMP16:%.*]] = icmp eq i32 [[TMP15]], 0 // CHECK3-NEXT: br i1 [[TMP16]], label [[DOTSPLIT:%.*]], label [[DOTCNCL:%.*]] // CHECK3: .cncl: -// CHECK3-NEXT: [[OMP_GLOBAL_THREAD_NUM2:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]]) -// CHECK3-NEXT: [[TMP17:%.*]] = call i32 @__kmpc_cancel_barrier(ptr @[[GLOB2]], i32 [[OMP_GLOBAL_THREAD_NUM2]]) -// CHECK3-NEXT: br label [[OMP_PAR_OUTLINED_EXIT_EXITSTUB]] +// CHECK3-NEXT: br label [[FINI]] // CHECK3: .split: // CHECK3-NEXT: br label [[TMP4]] // 
CHECK3: omp.par.exit.exitStub: @@ -1089,7 +1091,7 @@ for (int i = 0; i < argc; ++i) { // CHECK3: .omp.sections.case.split: // CHECK3-NEXT: br label [[DOTOMP_SECTIONS_EXIT]] // CHECK3: .omp.sections.case.cncl: -// CHECK3-NEXT: br label [[CANCEL_CONT:%.*]] +// CHECK3-NEXT: br label [[FINI:%.*]] // CHECK3: .omp.sections.exit: // CHECK3-NEXT: br label [[OMP_INNER_FOR_INC:%.*]] // CHECK3: omp.inner.for.inc: @@ -1100,7 +1102,7 @@ for (int i = 0; i < argc; ++i) { // CHECK3: omp.inner.for.end: // CHECK3-NEXT: [[OMP_GLOBAL_THREAD_NUM3:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB19:[0-9]+]]) // CHECK3-NEXT: call void @__kmpc_for_static_fini(ptr @[[GLOB15]], i32 [[OMP_GLOBAL_THREAD_NUM3]]) -// CHECK3-NEXT: br label [[CANCEL_CONT]] +// CHECK3-NEXT: br label [[CANCEL_CONT:.*]] // CHECK3: cancel.cont: // CHECK3-NEXT: ret void // CHECK3: cancel.exit: @@ -1153,6 +1155,8 @@ for (int i = 0; i < argc; ++i) { // CHECK3: .omp.sections.case.split: // CHECK3-NEXT: br label [[DOTOMP_SECTIONS_EXIT]] // CHECK3: .omp.sections.case.cncl: +// CHECK3-NEXT: br label [[DOTFINI:%.*]] +// CHECK3: .fini: // CHECK3-NEXT: br label [[CANCEL_CONT:%.*]] // CHECK3: .omp.sections.case2: // CHECK3-NEXT: [[OMP_GLOBAL_THREAD_NUM3:%.*]] = call i32 @__kmpc_global_thread_num(ptr @[[GLOB1]]) @@ -1164,7 +1168,7 @@ for (int i = 0; i < argc; ++i) { // CHECK3: .omp.sections.case2.section.after: // CHECK3-NEXT: br label [[DOTOMP_SECTIONS_EXIT]] // CHECK3: .omp.sections.case2.cncl: -// CHECK3-NEXT: br label [[OMP_INNER_FOR_END]] +// CHECK3-NEXT: br label [[FINI:.*]] // CHECK3: .omp.sections.exit: // CHECK3-NEXT: br label [[OMP_INNER_FOR_INC:%.*]] // CHECK3: omp.inner.for.inc: diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 5331cb5abdc6f..3658dfe263424 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -580,6 +580,11 @@ class OpenMPIRBuilder { /// Flag to indicate if the
directive is cancellable. bool IsCancellable; + + /// The basic block to which control should be transferred to + /// implement the FiniCB. Memoized to avoid generating finalization + /// multiple times. + llvm::BasicBlock *FiniBB = nullptr; }; /// Push a finalization callback on the finalization stack. @@ -2181,8 +2186,7 @@ class OpenMPIRBuilder { /// /// \return an error, if any were triggered during execution. LLVM_ABI Error emitCancelationCheckImpl(Value *CancelFlag, - omp::Directive CanceledDirective, - FinalizeCallbackTy ExitCB = {}); + omp::Directive CanceledDirective); /// Generate a target region entry call. /// diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 286ed039b1214..340b0604af528 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1108,21 +1108,9 @@ OpenMPIRBuilder::createCancel(const LocationDescription &Loc, Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind}; Value *Result = Builder.CreateCall( getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args); - auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error { - if (CanceledDirective == OMPD_parallel) { - IRBuilder<>::InsertPointGuard IPG(Builder); - Builder.restoreIP(IP); - return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL), - omp::Directive::OMPD_unknown, - /* ForceSimpleCall */ false, - /* CheckCancelFlag */ false) - .takeError(); - } - return Error::success(); - }; // The actual cancel logic is shared with others, e.g., cancel_barriers. - if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB)) + if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective)) return Err; // Update the insertion point and remove the terminator we introduced. 
@@ -1159,21 +1147,9 @@ OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc, Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind}; Value *Result = Builder.CreateCall( getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancellationpoint), Args); - auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error { - if (CanceledDirective == OMPD_parallel) { - IRBuilder<>::InsertPointGuard IPG(Builder); - Builder.restoreIP(IP); - return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL), - omp::Directive::OMPD_unknown, - /* ForceSimpleCall */ false, - /* CheckCancelFlag */ false) - .takeError(); - } - return Error::success(); - }; // The actual cancel logic is shared with others, e.g., cancel_barriers. - if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB)) + if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective)) return Err; // Update the insertion point and remove the terminator we introduced. @@ -1277,8 +1253,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch( } Error OpenMPIRBuilder::emitCancelationCheckImpl( - Value *CancelFlag, omp::Directive CanceledDirective, - FinalizeCallbackTy ExitCB) { + Value *CancelFlag, omp::Directive CanceledDirective) { assert(isLastFinalizationInfoCancellable(CanceledDirective) && "Unexpected cancellation!"); @@ -1305,13 +1280,17 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl( // From the cancellation block we finalize all variables and go to the // post finalization block that is known to the FiniCB callback. 
- Builder.SetInsertPoint(CancellationBlock); - if (ExitCB) - if (Error Err = ExitCB(Builder.saveIP())) - return Err; auto &FI = FinalizationStack.back(); - if (Error Err = FI.FiniCB(Builder.saveIP())) - return Err; + if (!FI.FiniBB) { + llvm::IRBuilderBase::InsertPointGuard Guard(Builder); + FI.FiniBB = BasicBlock::Create(BB->getContext(), ".fini", BB->getParent()); + Builder.SetInsertPoint(FI.FiniBB); + // FiniCB adds the branch to the exit stub. + if (Error Err = FI.FiniCB(Builder.saveIP())) + return Err; + } + Builder.SetInsertPoint(CancellationBlock); + Builder.CreateBr(FI.FiniBB); // The continuation block is where code generation continues. Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin()); @@ -1800,8 +1779,18 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel( Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator(); InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator()); - if (Error Err = FiniCB(PreFiniIP)) - return Err; + if (!FiniInfo.FiniBB) { + if (Error Err = FiniCB(PreFiniIP)) + return Err; + } else { + IRBuilderBase::InsertPointGuard Guard{Builder}; + Builder.restoreIP(PreFiniIP); + Builder.CreateBr(FiniInfo.FiniBB); + // There's currently a branch to omp.par.exit. Delete it. We will get there + // via the fini block + if (llvm::Instruction *Term = Builder.GetInsertBlock()->getTerminator()) + Term->eraseFromParent(); + } // Register the outlined info. 
addOutlineInfo(std::move(OI)); @@ -6556,13 +6545,14 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit( FinalizationInfo Fi = FinalizationStack.pop_back_val(); assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!"); - if (Error Err = Fi.FiniCB(FinIP)) - return Err; - - BasicBlock *FiniBB = FinIP.getBlock(); - Instruction *FiniBBTI = FiniBB->getTerminator(); + if (!Fi.FiniBB) { + if (Error Err = Fi.FiniCB(FinIP)) + return Err; + Fi.FiniBB = FinIP.getBlock(); + } // set Builder IP for call creation + Instruction *FiniBBTI = Fi.FiniBB->getTerminator(); Builder.SetInsertPoint(FiniBBTI); } diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index e56872320b4ac..da1760a56d952 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -428,8 +428,8 @@ TEST_F(OpenMPIRBuilderTest, CreateCancel) { OMPBuilder.createCancel(Loc, nullptr, OMPD_parallel)); Builder.restoreIP(NewIP); EXPECT_FALSE(M->global_empty()); - EXPECT_EQ(M->size(), 4U); - EXPECT_EQ(F->size(), 4U); + EXPECT_EQ(M->size(), 3U); + EXPECT_EQ(F->size(), 5U); EXPECT_EQ(BB->size(), 4U); CallInst *GTID = dyn_cast(&BB->front()); @@ -449,23 +449,16 @@ TEST_F(OpenMPIRBuilderTest, CreateCancel) { Instruction *CancelBBTI = Cancel->getParent()->getTerminator(); EXPECT_EQ(CancelBBTI->getNumSuccessors(), 2U); EXPECT_EQ(CancelBBTI->getSuccessor(0), NewIP.getBlock()); - EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 3U); - CallInst *GTID1 = dyn_cast(&CancelBBTI->getSuccessor(1)->front()); - EXPECT_NE(GTID1, nullptr); - EXPECT_EQ(GTID1->arg_size(), 1U); - EXPECT_EQ(GTID1->getCalledFunction()->getName(), "__kmpc_global_thread_num"); - EXPECT_FALSE(GTID1->getCalledFunction()->doesNotAccessMemory()); - EXPECT_FALSE(GTID1->getCalledFunction()->doesNotFreeMemory()); - CallInst *Barrier = dyn_cast(GTID1->getNextNode()); - EXPECT_NE(Barrier, nullptr); - 
EXPECT_EQ(Barrier->arg_size(), 2U); - EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_cancel_barrier"); - EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory()); - EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory()); - EXPECT_TRUE(Barrier->use_empty()); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(), 1U); - EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), CBB); + // cancel branch instruction (1) -> .cncl -> .fini -> CBB + EXPECT_EQ(CancelBBTI->getSuccessor(1) + ->getTerminator() + ->getSuccessor(0) + ->getTerminator() + ->getSuccessor(0), + CBB); EXPECT_EQ(cast(Cancel)->getArgOperand(1), GTID); @@ -497,8 +490,8 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) { OMPBuilder.createCancel(Loc, Builder.getTrue(), OMPD_parallel)); Builder.restoreIP(NewIP); EXPECT_FALSE(M->global_empty()); - EXPECT_EQ(M->size(), 4U); - EXPECT_EQ(F->size(), 7U); + EXPECT_EQ(M->size(), 3U); + EXPECT_EQ(F->size(), 8U); EXPECT_EQ(BB->size(), 1U); ASSERT_TRUE(isa(BB->getTerminator())); ASSERT_EQ(BB->getTerminator()->getNumSuccessors(), 2U); @@ -524,23 +517,15 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelIfCond) { EXPECT_EQ(CancelBBTI->getSuccessor(0)->size(), 1U); EXPECT_EQ(CancelBBTI->getSuccessor(0)->getUniqueSuccessor(), NewIP.getBlock()); - EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 3U); - CallInst *GTID1 = dyn_cast(&CancelBBTI->getSuccessor(1)->front()); - EXPECT_NE(GTID1, nullptr); - EXPECT_EQ(GTID1->arg_size(), 1U); - EXPECT_EQ(GTID1->getCalledFunction()->getName(), "__kmpc_global_thread_num"); - EXPECT_FALSE(GTID1->getCalledFunction()->doesNotAccessMemory()); - EXPECT_FALSE(GTID1->getCalledFunction()->doesNotFreeMemory()); - CallInst *Barrier = dyn_cast(GTID1->getNextNode()); - EXPECT_NE(Barrier, nullptr); - EXPECT_EQ(Barrier->arg_size(), 2U); - EXPECT_EQ(Barrier->getCalledFunction()->getName(), "__kmpc_cancel_barrier"); - 
EXPECT_FALSE(Barrier->getCalledFunction()->doesNotAccessMemory()); - EXPECT_FALSE(Barrier->getCalledFunction()->doesNotFreeMemory()); - EXPECT_TRUE(Barrier->use_empty()); + EXPECT_EQ(CancelBBTI->getSuccessor(1)->size(), 1U); EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(), 1U); - EXPECT_EQ(CancelBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), CBB); + EXPECT_EQ(CancelBBTI->getSuccessor(1) + ->getTerminator() + ->getSuccessor(0) + ->getTerminator() + ->getSuccessor(0), + CBB); EXPECT_EQ(cast(Cancel)->getArgOperand(1), GTID); @@ -572,7 +557,7 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) { Builder.restoreIP(NewIP); EXPECT_FALSE(M->global_empty()); EXPECT_EQ(M->size(), 3U); - EXPECT_EQ(F->size(), 4U); + EXPECT_EQ(F->size(), 5U); EXPECT_EQ(BB->size(), 4U); CallInst *GTID = dyn_cast(&BB->front()); @@ -595,7 +580,11 @@ TEST_F(OpenMPIRBuilderTest, CreateCancelBarrier) { EXPECT_EQ(BarrierBBTI->getSuccessor(1)->size(), 1U); EXPECT_EQ(BarrierBBTI->getSuccessor(1)->getTerminator()->getNumSuccessors(), 1U); - EXPECT_EQ(BarrierBBTI->getSuccessor(1)->getTerminator()->getSuccessor(0), + EXPECT_EQ(BarrierBBTI->getSuccessor(1) + ->getTerminator() + ->getSuccessor(0) + ->getTerminator() + ->getSuccessor(0), CBB); EXPECT_EQ(cast(Barrier)->getArgOperand(1), GTID); @@ -1291,8 +1280,8 @@ TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) { EXPECT_EQ(NumBodiesGenerated, 1U); EXPECT_EQ(NumPrivatizedVars, 0U); - EXPECT_EQ(NumFinalizationPoints, 2U); - EXPECT_TRUE(FakeDestructor->hasNUses(2)); + EXPECT_EQ(NumFinalizationPoints, 1U); + EXPECT_TRUE(FakeDestructor->hasNUses(1)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8de49dd397d26..1629be2ba46c4 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2662,6 +2662,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, ArrayRef isByRef = getIsByRef(opInst.getReductionByref()); assert(isByRef.size() == opInst.getNumReductionVars()); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + bool isCancellable = constructIsCancellable(opInst); if (failed(checkImplementationStatus(*opInst))) return failure(); @@ -2797,6 +2798,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, privateVarsInfo.privatizers))) return llvm::make_error(); + // If we could be performing cancellation, add the cancellation barrier on + // the way out of the outlined region. + if (isCancellable) { + auto IPOrErr = ompBuilder->createBarrier( + llvm::OpenMPIRBuilder::LocationDescription(builder), + llvm::omp::Directive::OMPD_unknown, + /* ForceSimpleCall */ false, + /* CheckCancelFlag */ false); + if (!IPOrErr) + return IPOrErr.takeError(); + } + builder.restoreIP(oldIP); return llvm::Error::success(); }; @@ -2810,7 +2823,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, auto pbKind = llvm::omp::OMP_PROC_BIND_default; if (auto bind = opInst.getProcBindKind()) pbKind = getProcBindKind(*bind); - bool isCancellable = constructIsCancellable(opInst); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); diff --git a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir index c4b245667a1f3..6585549de7f96 100644 --- a/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir +++ b/mlir/test/Target/LLVMIR/openmp-barrier-cancel.mlir @@ -29,22 +29,24 @@ llvm.func @test() { // CHECK: %[[VAL_14:.*]] = icmp eq i32 %[[VAL_13]], 0 // CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_17:.*]] = call i32 
@__kmpc_global_thread_num(ptr @1) -// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_17]]) -// CHECK: br label %[[VAL_19:.*]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[CNCL_BARRIER:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[TID]]) +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_11]] // CHECK: %[[VAL_20:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_21:.*]] = call i32 @__kmpc_cancel_barrier(ptr @3, i32 %[[VAL_20]]) // CHECK: %[[VAL_22:.*]] = icmp eq i32 %[[VAL_21]], 0 // CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_24:.*]] // CHECK: omp.par.region1.split.cncl: ; preds = %[[VAL_15]] -// CHECK: br label %[[VAL_19]] +// CHECK: br label %[[FINI]] // CHECK: omp.par.region1.split.cont: ; preds = %[[VAL_15]] // CHECK: br label %[[VAL_25:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_23]] // CHECK: br label %[[VAL_26:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] -// CHECK: br label %[[VAL_19]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_26]], %[[VAL_24]], %[[VAL_16]] +// CHECK: br label %[[FINI]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir index 21241702ad569..035e1f546bafa 100644 --- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir +++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir @@ -24,16 +24,18 @@ llvm.func @cancel_parallel() { // CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 // CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] +// CHECK: br label %[[VAL_20:.*]] +// CHECK: .fini: // CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) -// CHECK: br label %[[VAL_20:.*]] 
+// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] // CHECK: br label %[[VAL_21:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_16]] // CHECK: br label %[[VAL_22:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] // CHECK: br label %[[VAL_20]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancel_parallel_if(%arg0 : i1) { @@ -67,18 +69,20 @@ llvm.func @cancel_parallel_if(%arg0 : i1) { // CHECK: br label %[[VAL_26:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] // CHECK: br label %[[VAL_27:.*]] -// CHECK: 5: ; preds = %[[VAL_20]] +// CHECK: .fini: +// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) +// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) +// CHECK: br label %[[EXIT_STUB:.*]] +// CHECK: 6: ; preds = %[[VAL_20]] // CHECK: %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1) // CHECK: %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0 // CHECK: br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]] // CHECK: .cncl: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) -// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) // CHECK: br label %[[VAL_27]] // CHECK: .split: ; preds = %[[VAL_21]] // CHECK: br label %[[VAL_23]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancel_sections_if(%cond : i1) { @@ -145,7 +149,7 @@ llvm.func @cancel_sections_if(%cond : i1) { // CHECK: omp_section_loop.inc: ; preds = %[[VAL_23]] // CHECK: %[[VAL_15]] = add nuw i32 %[[VAL_14]], 1 // CHECK: br label %[[VAL_12]] -// CHECK: omp_section_loop.exit: ; preds = %[[VAL_33]], %[[VAL_16]] +// CHECK: omp_section_loop.exit: // CHECK: call void 
@__kmpc_for_static_fini(ptr @1, i32 %[[VAL_7]]) // CHECK: %[[VAL_36:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_36]]) @@ -232,7 +236,7 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) { // CHECK: omp_loop.inc: ; preds = %[[VAL_52]] // CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1 // CHECK: br label %[[VAL_31]] -// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]] +// CHECK: omp_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]]) // CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]]) diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir index 5e0d3f9f7e293..eeac1344cd7fb 100644 --- a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir +++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir @@ -24,16 +24,18 @@ llvm.func @cancellation_point_parallel() { // CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 // CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] // CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] +// CHECK: br label %[[FINI:.*]] +// CHECK: .fini: // CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) -// CHECK: br label %[[VAL_20:.*]] +// CHECK: br label %[[EXIT_STUB:.*]] // CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] // CHECK: br label %[[VAL_21:.*]] // CHECK: omp.region.cont: ; preds = %[[VAL_16]] // CHECK: br label %[[VAL_22:.*]] // CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_20]] -// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] +// CHECK: br label %[[FINI]] +// CHECK: omp.par.exit.exitStub: // CHECK: ret void llvm.func @cancellation_point_sections() { @@ -94,7 +96,7 @@ llvm.func 
@cancellation_point_sections() { // CHECK: omp_section_loop.inc: ; preds = %[[VAL_46]] // CHECK: %[[VAL_38]] = add nuw i32 %[[VAL_37]], 1 // CHECK: br label %[[VAL_35]] -// CHECK: omp_section_loop.exit: ; preds = %[[VAL_53]], %[[VAL_39]] +// CHECK: omp_section_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_30]]) // CHECK: %[[VAL_55:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_55]]) @@ -175,7 +177,7 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) { // CHECK: omp_loop.inc: ; preds = %[[VAL_106]] // CHECK: %[[VAL_92]] = add nuw i32 %[[VAL_91]], 1 // CHECK: br label %[[VAL_89]] -// CHECK: omp_loop.exit: ; preds = %[[VAL_105]], %[[VAL_93]] +// CHECK: omp_loop.exit: // CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_84]]) // CHECK: %[[VAL_107:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) // CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_107]])