diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 45b11c818245e..7d2fe869322f3 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -325,6 +325,26 @@ static void collectIteratorIVs( // ClauseProcessor unique clauses //===----------------------------------------------------------------------===// +bool ClauseProcessor::processAlign(mlir::omp::AlignClauseOps &result) const { + if (auto *clause = findUniqueClause()) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + const std::optional align = evaluate::ToInt64(clause->v); + result.align = firOpBuilder.getI64IntegerAttr(*align); + return true; + } + return false; +} + +bool ClauseProcessor::processAllocator( + lower::StatementContext &stmtCtx, + mlir::omp::AllocatorClauseOps &result) const { + if (auto *clause = findUniqueClause()) { + result.allocator = fir::getBase(converter.genExprValue(clause->v, stmtCtx)); + return true; + } + return false; +} + bool ClauseProcessor::processBare(mlir::omp::BareClauseOps &result) const { return markClauseOccurrence(result.bare); } diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index f343ee8ff4332..29b5c29b8e33a 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -57,6 +57,9 @@ class ClauseProcessor { : converter(converter), semaCtx(semaCtx), clauses(clauses) {} // 'Unique' clauses: They can appear at most once in the clause list. + bool processAlign(mlir::omp::AlignClauseOps &result) const; + bool processAllocator(lower::StatementContext &stmtCtx, + mlir::omp::AllocatorClauseOps &result) const; bool processBare(mlir::omp::BareClauseOps &result) const; bool processBind(mlir::omp::BindClauseOps &result) const; bool processCancelDirectiveName( diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index a24f137386235..5cdb54a07fecd 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1510,6 +1510,21 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter, // Code generation functions for clauses //===----------------------------------------------------------------------===// +static void genAllocateClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + const ObjectList &objects, + const List &clauses, mlir::Location loc, + llvm::SmallVectorImpl &operandRange, + mlir::omp::AllocateDirOperands &clauseOps) { + if (!objects.empty()) + genObjectList(objects, converter, operandRange); + + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAlign(clauseOps); + cp.processAllocator(stmtCtx, clauseOps); +} + static void genCancelClauses(lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, const List &clauses, mlir::Location loc, @@ -1930,6 +1945,30 @@ static void genWsloopClauses( //===----------------------------------------------------------------------===// // Code generation functions for leaf constructs //===----------------------------------------------------------------------===// +static mlir::omp::AllocateDirOp genAllocateDirOp( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ObjectList &objects, const ConstructQueue &queue, + ConstructQueue::const_iterator item) { + llvm::SmallVector operandRange; + mlir::omp::AllocateDirOperands clauseOps; + genAllocateClauses(converter, semaCtx, stmtCtx, objects, item->clauses, loc, + operandRange, clauseOps); + + auto allocDirOp = mlir::omp::AllocateDirOp::create( + converter.getFirOpBuilder(), loc, operandRange, clauseOps.align, + clauseOps.allocator); + + // Register a cleanup at the Fortran scope exit. + fir::FirOpBuilder *builder = &converter.getFirOpBuilder(); + mlir::Value allocator = clauseOps.allocator; + converter.getFctCtx().attachCleanup([builder, loc, operandRange, + allocator]() { + mlir::omp::AllocateFreeOp::create(*builder, loc, operandRange, allocator); + }); + + return allocDirOp; +} static mlir::omp::BarrierOp genBarrierOp(lower::AbstractConverter &converter, lower::SymMap &symTable, @@ -3844,8 +3883,18 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, const parser::OmpAllocateDirective &allocate) { - if (!semaCtx.langOptions().OpenMPSimd) - TODO(converter.getCurrentLocation(), "OmpAllocateDirective"); + lower::StatementContext stmtCtx; + ObjectList objects = makeObjects((allocate.BeginDir().Arguments()), semaCtx); + const auto &clauseList = (allocate.BeginDir().Clauses()); + List clauses = makeClauses(clauseList, semaCtx); + mlir::Location loc = converter.genLocation(allocate.source); + + ConstructQueue queue{buildConstructQueue( + converter.getFirOpBuilder().getModule(), semaCtx, eval, allocate.source, + llvm::omp::Directive::OMPD_allocate, clauses)}; + + genAllocateDirOp(converter, semaCtx, stmtCtx, eval, loc, objects, queue, + queue.begin()); } static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, diff --git a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate-align.f90 b/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate-align.f90 deleted file mode 100644 index fec146ac70313..0000000000000 --- a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate-align.f90 +++ /dev/null @@ -1,10 +0,0 @@ -! This test checks lowering of OpenMP allocate Directive with align clause. - -! RUN: not %flang_fc1 -emit-fir -fopenmp -fopenmp-version=51 %s 2>&1 | FileCheck %s - -program main - integer :: x - - ! CHECK: not yet implemented: OmpAllocateDirective - !$omp allocate(x) align(32) -end diff --git a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate.f90 b/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate.f90 deleted file mode 100644 index 7cae8051fda77..0000000000000 --- a/flang/test/Lower/OpenMP/Todo/omp-declarative-allocate.f90 +++ /dev/null @@ -1,10 +0,0 @@ -! This test checks lowering of OpenMP allocate Directive. - -! RUN: not %flang_fc1 -emit-fir -fopenmp %s 2>&1 | FileCheck %s - -program main - integer :: x, y - - ! CHECK: not yet implemented: OmpAllocateDirective - !$omp allocate(x, y) -end diff --git a/flang/test/Lower/OpenMP/omp-declarative-allocate-align.f90 b/flang/test/Lower/OpenMP/omp-declarative-allocate-align.f90 new file mode 100644 index 0000000000000..fdcc4ac1fef20 --- /dev/null +++ b/flang/test/Lower/OpenMP/omp-declarative-allocate-align.f90 @@ -0,0 +1,51 @@ +! This test checks lowering of OpenMP allocate Directive with align and allocator +! clauses to HLFIR. Verifies code generation for: +! - align(16) only (null allocator) +! - allocator(1) only (no align) +! - align(64) allocator(6) (both clauses, array variable) +! - align(32) allocator(3) (both clauses, multiple variables) +! Each omp.allocate_dir must be paired with a matching omp.allocate_free + +! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=51 %s -o - 2>&1 | FileCheck %s + +program main + integer :: x, y + integer :: z(10) + character c + real :: r + complex :: cmplx + !$omp allocate(x) align(16) + !$omp allocate(y) allocator(1) + !$omp allocate(z) align(64) allocator(6) + !$omp allocate(c, r, cmplx) align(32) allocator(3) + x = 1 + y = 2 + z = x + y + print *, "z : ", z +end program + +! CHECK: %[[C1_IDX:.*]] = arith.constant 1 : index +! CHECK: %[[C_ALLOC:.*]] = fir.alloca !fir.char<1> {bindc_name = "c", uniq_name = "_QFEc"} +! CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[C_ALLOC]] typeparams %[[C1_IDX]] {uniq_name = "_QFEc"} : (!fir.ref>, index) -> (!fir.ref>, !fir.ref>) +! CHECK: %[[CMPLX_ALLOC:.*]] = fir.alloca complex {bindc_name = "cmplx", uniq_name = "_QFEcmplx"} +! CHECK: %[[CMPLX_DECL:.*]]:2 = hlfir.declare %[[CMPLX_ALLOC]] {uniq_name = "_QFEcmplx"} : (!fir.ref>) -> (!fir.ref>, !fir.ref>) +! CHECK: %[[R_ALLOC:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFEr"} +! CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[R_ALLOC]] {uniq_name = "_QFEr"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[X_ALLOC:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"} +! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ALLOC]] {uniq_name = "_QFEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[Y_ALLOC:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"} +! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_ALLOC]] {uniq_name = "_QFEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[Z_REF:.*]] = fir.address_of(@_QFEz) : !fir.ref> +! CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z_REF]]({{.*}}) {uniq_name = "_QFEz"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) +! CHECK: omp.allocate_dir(%[[X_DECL]]#0 : !fir.ref) align(16) +! CHECK: %[[ALLOC1:.*]] = arith.constant 1 : i32 +! CHECK: omp.allocate_dir(%[[Y_DECL]]#0 : !fir.ref) allocator(%[[ALLOC1]] : i32) +! CHECK: %[[ALLOC6:.*]] = arith.constant 6 : i32 +! CHECK: omp.allocate_dir(%[[Z_DECL]]#0 : !fir.ref>) align(64) allocator(%[[ALLOC6]] : i32) +! CHECK: %[[ALLOC3:.*]] = arith.constant 3 : i32 +! CHECK: omp.allocate_dir(%[[C_DECL]]#0, %[[R_DECL]]#0, %[[CMPLX_DECL]]#0 : !fir.ref>, !fir.ref, !fir.ref>) align(32) allocator(%[[ALLOC3]] : i32) +! CHECK: omp.allocate_free(%[[C_DECL]]#0, %[[R_DECL]]#0, %[[CMPLX_DECL]]#0 : !fir.ref>, !fir.ref, !fir.ref>) allocator(%[[ALLOC3]] : i32) +! CHECK: omp.allocate_free(%[[Z_DECL]]#0 : !fir.ref>) allocator(%[[ALLOC6]] : i32) +! CHECK: omp.allocate_free(%[[Y_DECL]]#0 : !fir.ref) allocator(%[[ALLOC1]] : i32) +! CHECK: omp.allocate_free(%[[X_DECL]]#0 : !fir.ref) +! CHECK: return diff --git a/flang/test/Lower/OpenMP/omp-declarative-allocate.f90 b/flang/test/Lower/OpenMP/omp-declarative-allocate.f90 new file mode 100644 index 0000000000000..77f211ccf0aeb --- /dev/null +++ b/flang/test/Lower/OpenMP/omp-declarative-allocate.f90 @@ -0,0 +1,18 @@ +! This test checks lowering of OpenMP allocate Directive to HLFIR. +! Verifies code generation for default (no align, null allocator) case. +! omp.allocate_free must be emitted at the exit (before return). + +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +program main + integer :: x, y + !$omp allocate(x, y) +end program + +! CHECK: %[[X_ALLOC:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"} +! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ALLOC]] {uniq_name = "_QFEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[Y_ALLOC:.*]] = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"} +! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_ALLOC]] {uniq_name = "_QFEy"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: omp.allocate_dir(%[[X_DECL]]#0, %[[Y_DECL]]#0 : !fir.ref, !fir.ref) +! CHECK: omp.allocate_free(%[[X_DECL]]#0, %[[Y_DECL]]#0 : !fir.ref, !fir.ref) +! CHECK: return diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 19a8a53556a73..593e7497f4442 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -3201,7 +3201,7 @@ class OpenMPIRBuilder { llvm::IntegerType *IntPtrTy, bool BranchtoEnd = true); - /// Create a runtime call for kmpc_Alloc + /// Create a runtime call for kmpc_alloc /// /// \param Loc The insert and source location description. /// \param Size Size of allocated memory space @@ -3212,6 +3212,20 @@ class OpenMPIRBuilder { LLVM_ABI CallInst *createOMPAlloc(const LocationDescription &Loc, Value *Size, Value *Allocator, std::string Name = ""); + /// Create a runtime call for kmpc_align_alloc + /// + /// \param Loc The insert and source location description. + /// \param Align Align value + /// \param Size Size of allocated memory space + /// \param Allocator Allocator information instruction + /// \param Name Name of call Instruction for OMP_Align_Alloc + /// + /// \returns CallInst to the OMP_Align_Alloc call + LLVM_ABI CallInst *createOMPAlignedAlloc(const LocationDescription &Loc, + Value *Align, Value *Size, + Value *Allocator, + std::string Name = ""); + /// Create a runtime call for kmpc_free /// /// \param Loc The insert and source location description. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 04431e22483d9..6a4ad54f353ba 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -7671,7 +7671,8 @@ CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc, Value *Size, Value *Allocator, std::string Name) { IRBuilder<>::InsertPointGuard IPG(Builder); - updateToLocation(Loc); + if (!updateToLocation(Loc)) + return nullptr; uint32_t SrcLocStrSize; Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize); @@ -7684,11 +7685,31 @@ CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc, return createRuntimeFunctionCall(Fn, Args, Name); } +CallInst *OpenMPIRBuilder::createOMPAlignedAlloc(const LocationDescription &Loc, + Value *Align, Value *Size, + Value *Allocator, + std::string Name) { + IRBuilder<>::InsertPointGuard IPG(Builder); + if (!updateToLocation(Loc)) + return nullptr; + + uint32_t SrcLocStrSize; + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize); + Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); + Value *ThreadId = getOrCreateThreadID(Ident); + Value *Args[] = {ThreadId, Align, Size, Allocator}; + + Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_aligned_alloc); + + return Builder.CreateCall(Fn, Args, Name); +} + CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc, Value *Addr, Value *Allocator, std::string Name) { IRBuilder<>::InsertPointGuard IPG(Builder); - updateToLocation(Loc); + if (!updateToLocation(Loc)) + return nullptr; uint32_t SrcLocStrSize; Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index f24efd0d4fc42..13a1fc3bd08bc 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -146,11 +146,11 @@ class OpenMP_AllocatorClauseSkip< extraClassDeclaration> { let arguments = (ins - Optional:$allocator + Optional:$allocator ); let optAssemblyFormat = [{ - `allocator` `(` $allocator `)` + `allocator` `(` $allocator `:` type($allocator) `)` }]; let description = [{ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 669dd3cd1544a..1931c91080644 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2219,6 +2219,26 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", [AttrSizedOperandSegments], clause let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// AllocateFreeOp +//===----------------------------------------------------------------------===// + +def AllocateFreeOp : OpenMP_Op<"allocate_free", [AttrSizedOperandSegments], + clauses = [OpenMP_AllocatorClause]> { + let summary = "free-op paired with allocate directive"; + let description = [{ + At the end of the scope each list item allocated using allocate directive + should be deallocated(using this free operation). + }] # clausesDescription; + + let arguments = !con((ins Variadic:$varList), + clausesArgs); + + let assemblyFormat = " `(` $varList `:` type($varList) `)` oilist(" # + clausesOptAssemblyFormat # + ") attr-dict "; +} + //===----------------------------------------------------------------------===// // TargetAllocMemOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationDialectInterface.td b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationDialectInterface.td index 6d8c7174bd2e3..01c1b3a3cfaa3 100644 --- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationDialectInterface.td +++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationDialectInterface.td @@ -55,7 +55,7 @@ def LLVMTranslationDialectInterface : DialectInterface<"LLVMTranslationDialectIn [{ return ::llvm::success(); }] - > + >, ]; } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8806daee06000..227e6d205ace6 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4814,6 +4814,26 @@ static Operation *getGlobalOpFromValue(Value value) { return nullptr; } +static Value getBaseValueForTypeLookup(Value value) { + while (Operation *op = value.getDefiningOp()) { + if (auto addrCast = dyn_cast_if_present(op)) + value = addrCast.getOperand(); + // Traces through hlfir.declare, fir.declare to reach the base address and + // use for type lookup. + else if (op->getName().getIdentifier() && + (op->getName().getIdentifier().str() == "hlfir.declare" || + op->getName().getIdentifier().str() == "fir.declare")) { + if (op->getNumOperands() > 0) + value = op->getOperand(0); + else + break; + } else { + break; + } + } + return value; +} + static llvm::SmallString<64> getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, llvm::OpenMPIRBuilder &ompBuilder) { @@ -7486,6 +7506,26 @@ class OpenMPDialectLLVMIRTranslationInterface amendOperation(Operation *op, ArrayRef instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final; + + /// Records the LLVM alloc pointer produced for an OMP ALLOCATE variable so + /// that the paired omp.allocate_free op can generate the matching + /// __kmpc_free call. + void registerAllocatedPtr(Value var, llvm::Value *ptr) const { + ompAllocatedPtrs[var] = ptr; + } + + /// Returns the LLVM alloc pointer previously registered for var, or + /// nullptr if no allocation was recorded. + llvm::Value *lookupAllocatedPtr(Value var) const { + auto it = ompAllocatedPtrs.find(var); + return it != ompAllocatedPtrs.end() ? it->second : nullptr; + } + +private: + /// Maps each MLIR variable value that appeared in an omp.allocate_dir op to + /// the LLVM pointer returned by the corresponding __kmpc_alloc call. The + /// paired omp.allocate_free op looks up these pointers to emit __kmpc_free. + mutable DenseMap ompAllocatedPtrs; }; } // namespace @@ -7663,6 +7703,121 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } +static LogicalResult +convertAllocateDirOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + const OpenMPDialectLLVMIRTranslationInterface &ompIface) { + auto allocateDirOp = cast(opInst); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::DataLayout dataLayout = llvmModule->getDataLayout(); + SmallVector vars = allocateDirOp.getVarList(); + std::optional alignAttr = allocateDirOp.getAlign(); + + llvm::Value *allocator; + if (auto allocatorVar = allocateDirOp.getAllocator()) { + allocator = moduleTranslation.lookupValue(allocatorVar); + if (allocator->getType()->isIntegerTy()) + allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy()); + else if (allocator->getType()->isPointerTy()) + allocator = builder.CreatePointerBitCastOrAddrSpaceCast( + allocator, builder.getPtrTy()); + } else { + allocator = llvm::ConstantPointerNull::get(builder.getPtrTy()); + } + + for (Value var : vars) { + llvm::Type *llvmVarTy = moduleTranslation.convertType(var.getType()); + + // Opaque pointers lose element type. Trace to GlobalOp for type + // Falls back to llvmVarTy when not from a global. + llvm::Type *typeToInspect = llvmVarTy; + if (llvmVarTy->isPointerTy()) { + Value baseVar = getBaseValueForTypeLookup(var); + if (Operation *globalOp = getGlobalOpFromValue(baseVar)) { + if (auto gop = dyn_cast(globalOp)) + typeToInspect = moduleTranslation.convertType(gop.getGlobalType()); + } + } + + llvm::Value *size; + if (auto arrTy = llvm::dyn_cast(typeToInspect)) { + llvm::Value *elementCount = builder.getInt64(1); + llvm::Type *currentType = arrTy; + while (auto nestedArrTy = llvm::dyn_cast(currentType)) { + elementCount = builder.CreateMul( + elementCount, builder.getInt64(nestedArrTy->getNumElements())); + currentType = nestedArrTy->getElementType(); + } + uint64_t elemSizeInBits = dataLayout.getTypeSizeInBits(currentType); + size = + builder.CreateMul(elementCount, builder.getInt64(elemSizeInBits / 8)); + } else { + size = builder.getInt64( + dataLayout.getTypeStoreSize(typeToInspect).getFixedValue()); + } + + uint64_t alignValue = + alignAttr ? alignAttr.value() + : dataLayout.getABITypeAlign(typeToInspect).value(); + llvm::Value *alignConst = builder.getInt64(alignValue); + // Align the size: ((size + align - 1) / align) * align + size = builder.CreateAdd(size, builder.getInt64(alignValue - 1), "", true); + size = builder.CreateUDiv(size, alignConst); + size = builder.CreateMul(size, alignConst, "", true); + + std::string allocName = + ompBuilder->createPlatformSpecificName({".void.addr"}); + llvm::CallInst *allocCall; + if (alignAttr.has_value()) { + allocCall = ompBuilder->createOMPAlignedAlloc( + ompLoc, builder.getInt64(alignAttr.value()), size, allocator, + allocName); + } else { + allocCall = + ompBuilder->createOMPAlloc(ompLoc, size, allocator, allocName); + } + // Record the alloc pointer keyed by the MLIR variable value. + ompIface.registerAllocatedPtr(var, allocCall); + } + + return success(); +} + +static LogicalResult +convertAllocateFreeOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + const OpenMPDialectLLVMIRTranslationInterface &ompIface) { + auto freeOp = cast(opInst); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + + llvm::Value *allocator; + if (auto allocatorVar = freeOp.getAllocator()) { + allocator = moduleTranslation.lookupValue(allocatorVar); + if (allocator->getType()->isIntegerTy()) + allocator = builder.CreateIntToPtr(allocator, builder.getPtrTy()); + else if (allocator->getType()->isPointerTy()) + allocator = builder.CreatePointerBitCastOrAddrSpaceCast( + allocator, builder.getPtrTy()); + } else { + allocator = llvm::ConstantPointerNull::get(builder.getPtrTy()); + } + + // Emit __kmpc_free for each variable in reverse allocation order. + SmallVector vars = freeOp.getVarList(); + for (Value var : llvm::reverse(vars)) { + llvm::Value *allocPtr = ompIface.lookupAllocatedPtr(var); + if (!allocPtr) + return opInst.emitError("omp.allocate_free: no allocation recorded"); + ompBuilder->createOMPFree(ompLoc, allocPtr, allocator, ""); + } + + return success(); +} + static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule) { llvm::Type *ptrTy = builder.getPtrTy(0); @@ -7908,6 +8063,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( .Case([&](omp::TargetFreeMemOp) { return convertTargetFreeMemOp(*op, builder, moduleTranslation); }) + .Case([&](omp::AllocateDirOp) { + return convertAllocateDirOp(*op, builder, moduleTranslation, *this); + }) + .Case([&](omp::AllocateFreeOp) { + return convertAllocateFreeOp(*op, builder, moduleTranslation, + *this); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 90db9187a56bf..7b3d2c9a0732e 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3485,27 +3485,51 @@ func.func @omp_allocate_dir(%arg0 : memref, %arg1 : memref) -> () { // Test with one data var and allocator clause // CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64 %omp_default_mem_alloc = arith.constant 1 : i64 - // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(%[[VAL_1:.*]]) - omp.allocate_dir (%arg0 : memref) allocator(%omp_default_mem_alloc) + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(%[[VAL_1:.*]] : i64) + omp.allocate_dir (%arg0 : memref) allocator(%omp_default_mem_alloc : i64) // Test with one data var, align clause and allocator clause // CHECK: %[[VAL_2:.*]] = arith.constant 7 : i64 %omp_pteam_mem_alloc = arith.constant 7 : i64 - // CHECK: omp.allocate_dir(%[[ARG0]] : memref) align(4) allocator(%[[VAL_2:.*]]) - omp.allocate_dir (%arg0 : memref) align(4) allocator(%omp_pteam_mem_alloc) + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) align(4) allocator(%[[VAL_2:.*]] : i64) + omp.allocate_dir (%arg0 : memref) align(4) allocator(%omp_pteam_mem_alloc : i64) // Test with two data vars, align clause and allocator clause // CHECK: %[[VAL_3:.*]] = arith.constant 6 : i64 %omp_cgroup_mem_alloc = arith.constant 6 : i64 - // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref, memref) align(8) allocator(%[[VAL_3:.*]]) - omp.allocate_dir (%arg0, %arg1 : memref, memref) align(8) allocator(%omp_cgroup_mem_alloc) + // CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref, memref) align(8) allocator(%[[VAL_3:.*]] : i64) + omp.allocate_dir (%arg0, %arg1 : memref, memref) align(8) allocator(%omp_cgroup_mem_alloc : i64) // Test with one data var and user defined allocator clause // CHECK: %[[VAL_4:.*]] = arith.constant 9 : i64 %custom_allocator = arith.constant 9 : i64 %custom_mem_alloc = func.call @omp_init_allocator(%custom_allocator) : (i64) -> (i64) - // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(%[[VAL_5:.*]]) - omp.allocate_dir (%arg0 : memref) allocator(%custom_mem_alloc) + // CHECK: omp.allocate_dir(%[[ARG0]] : memref) allocator(%[[VAL_5:.*]] : i64) + omp.allocate_dir (%arg0 : memref) allocator(%custom_mem_alloc : i64) + + return +} + +// CHECK-LABEL: func.func @omp_allocate_free( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: memref) { +func.func @omp_allocate_free(%arg0 : memref, %arg1 : memref) -> () { + + // Test free with no allocator + // CHECK: omp.allocate_free(%[[ARG0]] : memref) + omp.allocate_free (%arg0 : memref) + + // Test free with allocator clause + // CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64 + %omp_default_mem_alloc = arith.constant 1 : i64 + // CHECK: omp.allocate_free(%[[ARG0]] : memref) allocator(%[[VAL_1:.*]] : i64) + omp.allocate_free (%arg0 : memref) allocator(%omp_default_mem_alloc : i64) + + // Test free with two variables and allocator clause + // CHECK: %[[VAL_3:.*]] = arith.constant 6 : i64 + %omp_cgroup_mem_alloc = arith.constant 6 : i64 + // CHECK: omp.allocate_free(%[[ARG0]], %[[ARG1]] : memref, memref) allocator(%[[VAL_3:.*]] : i64) + omp.allocate_free (%arg0, %arg1 : memref, memref) allocator(%omp_cgroup_mem_alloc : i64) return } diff --git a/mlir/test/Target/LLVMIR/openmp-allocate-directive.mlir b/mlir/test/Target/LLVMIR/openmp-allocate-directive.mlir new file mode 100644 index 0000000000000..d8975eb512abe --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-allocate-directive.mlir @@ -0,0 +1,117 @@ +// Tests for translation of omp.allocate_dir / omp.allocate_free pairs to +// LLVM IR, covering all combinations of align and allocator clauses. +// The frontend is responsible for placing omp.allocate_free at the correct +// Fortran scope exit; here each function pairs the ops manually. + +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +// ----- + +// CHECK-LABEL: define void @test_allocate_default +// CHECK-SAME: (ptr %[[ARG0:.*]]) { +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC:.*]] = call ptr @__kmpc_alloc(i32 %[[TID]], i64 8, ptr null) +// CHECK: %[[TID_FREE:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: call void @__kmpc_free(i32 %[[TID_FREE]], ptr %[[ALLOC]], ptr null) +// CHECK: ret void +// CHECK: } +// CHECK: declare noalias ptr @__kmpc_alloc(i32, i64, ptr) +// CHECK: declare void @__kmpc_free(i32, ptr, ptr) +llvm.func @test_allocate_default(%arg0: !llvm.ptr) { + omp.allocate_dir (%arg0 : !llvm.ptr) + omp.allocate_free (%arg0 : !llvm.ptr) + llvm.return +} + +// ----- + +// CHECK-LABEL: define void @test_allocate_align_only +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC:.*]] = call ptr @__kmpc_aligned_alloc(i32 %[[TID]], i64 16, i64 16, ptr null) +// CHECK: %[[TID_FREE:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: call void @__kmpc_free(i32 %[[TID_FREE]], ptr %[[ALLOC]], ptr null) +// CHECK: ret void +// CHECK: declare noalias ptr @__kmpc_aligned_alloc(i32, i64, i64, ptr) +llvm.func @test_allocate_align_only(%arg0: !llvm.ptr) { + omp.allocate_dir (%arg0 : !llvm.ptr) align(16) + omp.allocate_free (%arg0 : !llvm.ptr) + llvm.return +} + +// ----- + +// CHECK-LABEL: define void @test_allocate_allocator_only +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC:.*]] = call ptr @__kmpc_alloc(i32 %[[TID]], i64 8, ptr inttoptr (i32 1 to ptr)) +// CHECK: %[[TID_FREE:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: call void @__kmpc_free(i32 %[[TID_FREE]], ptr %[[ALLOC]], ptr inttoptr (i32 1 to ptr)) +// CHECK: ret void +llvm.func @test_allocate_allocator_only(%arg0: !llvm.ptr) { + %alloc1 = llvm.mlir.constant(1 : i32) : i32 + omp.allocate_dir (%arg0 : !llvm.ptr) allocator(%alloc1 : i32) + omp.allocate_free (%arg0 : !llvm.ptr) allocator(%alloc1 : i32) + llvm.return +} + +// ----- + +// CHECK-LABEL: define void @test_allocate_align_and_allocator +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC:.*]] = call ptr @__kmpc_aligned_alloc(i32 %[[TID]], i64 64, i64 64, ptr inttoptr (i32 6 to ptr)) +// CHECK: %[[TID_FREE:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: call void @__kmpc_free(i32 %[[TID_FREE]], ptr %[[ALLOC]], ptr inttoptr (i32 6 to ptr)) +// CHECK: ret void +llvm.func @test_allocate_align_and_allocator(%arg0: !llvm.ptr) { + %alloc6 = llvm.mlir.constant(6 : i32) : i32 + omp.allocate_dir (%arg0 : !llvm.ptr) align(64) allocator(%alloc6 : i32) + omp.allocate_free (%arg0 : !llvm.ptr) allocator(%alloc6 : i32) + llvm.return +} + +// ----- + +// Verifies that multiple variables each get their own __kmpc_aligned_alloc +// call, and that __kmpc_free calls are emitted in reverse allocation order. +// +// CHECK-LABEL: define void @test_allocate_multiple_vars +// CHECK: %[[TID0:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC0:.*]] = call ptr @__kmpc_aligned_alloc(i32 %[[TID0]], i64 32, i64 32, ptr inttoptr (i32 3 to ptr)) +// CHECK: %[[TID1:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC1:.*]] = call ptr @__kmpc_aligned_alloc(i32 %[[TID1]], i64 32, i64 32, ptr inttoptr (i32 3 to ptr)) +// CHECK: %[[TID2:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC2:.*]] = call ptr @__kmpc_aligned_alloc(i32 %[[TID2]], i64 32, i64 32, ptr inttoptr (i32 3 to ptr)) +// Free order is reversed relative to allocation order. +// CHECK: call void @__kmpc_free({{.*}}, ptr %[[ALLOC2]], ptr inttoptr (i32 3 to ptr)) +// CHECK: call void @__kmpc_free({{.*}}, ptr %[[ALLOC1]], ptr inttoptr (i32 3 to ptr)) +// CHECK: call void @__kmpc_free({{.*}}, ptr %[[ALLOC0]], ptr inttoptr (i32 3 to ptr)) +// CHECK: ret void +llvm.func @test_allocate_multiple_vars(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { + %alloc3 = llvm.mlir.constant(3 : i32) : i32 + omp.allocate_dir (%arg0, %arg1, %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) align(32) allocator(%alloc3 : i32) + omp.allocate_free (%arg0, %arg1, %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) allocator(%alloc3 : i32) + llvm.return +} + +// ----- + +// Verifies that array size is correctly calculated from the global's element +// type: [10 x i32] = 40 bytes, rounded up to alignment 64 => 64 bytes. +// +// CHECK-LABEL: define void @test_allocate_array_global +// CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: %[[ALLOC:.*]] = call ptr @__kmpc_aligned_alloc(i32 %[[TID]], i64 64, i64 64, ptr inttoptr (i32 6 to ptr)) +// CHECK: %[[TID_FREE:.*]] = call i32 @__kmpc_global_thread_num( +// CHECK: call void @__kmpc_free(i32 %[[TID_FREE]], ptr %[[ALLOC]], ptr inttoptr (i32 6 to ptr)) +// CHECK: ret void +llvm.mlir.global internal @arr_global() : !llvm.array<10 x i32> { + %0 = llvm.mlir.zero : !llvm.array<10 x i32> + llvm.return %0 : !llvm.array<10 x i32> +} + +llvm.func @test_allocate_array_global() { + %z = llvm.mlir.addressof @arr_global : !llvm.ptr + %alloc6 = llvm.mlir.constant(6 : i32) : i32 + omp.allocate_dir (%z : !llvm.ptr) align(64) allocator(%alloc6 : i32) + omp.allocate_free (%z : !llvm.ptr) allocator(%alloc6 : i32) + llvm.return +}