diff --git a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp index 381b2a29c517a..c1a6b06d6a52b 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp @@ -222,35 +222,47 @@ static mlir::Type convertObjectType(const fir::LLVMTypeConverter &converter, return converter.convertType(firType); } -// FIR Op specific conversion for TargetAllocMemOp -struct TargetAllocMemOpConversion - : public OpenMPFIROpConversion { - using OpenMPFIROpConversion::OpenMPFIROpConversion; +// FIR Op specific conversion for allocation operations +template +struct AllocMemOpConversion : public OpenMPFIROpConversion { + using OpenMPFIROpConversion::OpenMPFIROpConversion; llvm::LogicalResult - matchAndRewrite(mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor, + matchAndRewrite(T allocmemOp, + typename OpenMPFIROpConversion::OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type heapTy = allocmemOp.getAllocatedType(); mlir::Location loc = allocmemOp.getLoc(); - auto ity = lowerTy().indexType(); + auto ity = OpenMPFIROpConversion::lowerTy().indexType(); mlir::Type dataTy = fir::unwrapRefType(heapTy); - mlir::Type llvmObjectTy = convertObjectType(lowerTy(), dataTy); + mlir::Type llvmObjectTy = + convertObjectType(OpenMPFIROpConversion::lowerTy(), dataTy); if (fir::isRecordWithTypeParameters(fir::unwrapSequenceType(dataTy))) - TODO(loc, "omp.target_allocmem codegen of derived type with length " - "parameters"); + TODO(loc, allocmemOp->getName().getStringRef() + + " codegen of derived type with length parameters"); mlir::Value size = fir::computeElementDistance( - loc, llvmObjectTy, ity, rewriter, lowerTy().getDataLayout()); + loc, llvmObjectTy, ity, rewriter, + OpenMPFIROpConversion::lowerTy().getDataLayout()); if (auto scaleSize = fir::genAllocationScaleSize( loc, allocmemOp.getInType(), ity, rewriter)) size = rewriter.create(loc, ity, size, scaleSize); - 
for (mlir::Value opnd : adaptor.getOperands().drop_front()) + for (mlir::Value opnd : adaptor.getTypeparams()) + size = rewriter.create( + loc, ity, size, + integerCast(OpenMPFIROpConversion::lowerTy(), loc, rewriter, ity, + opnd)); + for (mlir::Value opnd : adaptor.getShape()) size = rewriter.create( - loc, ity, size, integerCast(lowerTy(), loc, rewriter, ity, opnd)); - auto mallocTyWidth = lowerTy().getIndexTypeBitwidth(); + loc, ity, size, + integerCast(OpenMPFIROpConversion::lowerTy(), loc, rewriter, ity, + opnd)); + auto mallocTyWidth = + OpenMPFIROpConversion::lowerTy().getIndexTypeBitwidth(); auto mallocTy = mlir::IntegerType::get(rewriter.getContext(), mallocTyWidth); if (mallocTyWidth != ity.getIntOrFloatBitWidth()) - size = integerCast(lowerTy(), loc, rewriter, mallocTy, size); + size = integerCast(OpenMPFIROpConversion::lowerTy(), loc, rewriter, + mallocTy, size); rewriter.modifyOpInPlace(allocmemOp, [&]() { allocmemOp.setInType(rewriter.getI8Type()); allocmemOp.getTypeparamsMutable().clear(); @@ -265,5 +277,6 @@ void fir::populateOpenMPFIRToLLVMConversionPatterns( const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) { patterns.add(converter); patterns.add(converter); - patterns.add(converter); + patterns.add, + AllocMemOpConversion>(converter); } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 02d61c1a3626a..d8e5f8cf5a45e 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2950,6 +2950,17 @@ class OpenMPIRBuilder { LLVM_ABI CallInst *createOMPFree(const LocationDescription &Loc, Value *Addr, Value *Allocator, std::string Name = ""); + /// Create a runtime call for kmpc_alloc_shared. + /// + /// \param Loc The insert and source location description. + /// \param Size Size of allocated memory space. + /// \param Name Name of call Instruction. 
+ /// + /// \returns CallInst to the kmpc_alloc_shared call. + LLVM_ABI CallInst *createOMPAllocShared(const LocationDescription &Loc, + Value *Size, + const Twine &Name = Twine("")); + /// Create a runtime call for kmpc_alloc_shared. /// /// \param Loc The insert and source location description. @@ -2961,6 +2972,18 @@ class OpenMPIRBuilder { Type *VarType, const Twine &Name = Twine("")); + /// Create a runtime call for kmpc_free_shared. + /// + /// \param Loc The insert and source location description. + /// \param Addr Value obtained from the corresponding kmpc_alloc_shared call. + /// \param Size Size of allocated memory space. + /// \param Name Name of call Instruction. + /// + /// \returns CallInst to the kmpc_free_shared call. + LLVM_ABI CallInst *createOMPFreeShared(const LocationDescription &Loc, + Value *Addr, Value *Size, + const Twine &Name = Twine("")); + /// Create a runtime call for kmpc_free_shared. /// /// \param Loc The insert and source location description. diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index bd483aa2c5e02..a18db939b5876 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6855,32 +6855,45 @@ CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc, } CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc, - Type *VarType, + Value *Size, const Twine &Name) { IRBuilder<>::InsertPointGuard IPG(Builder); updateToLocation(Loc); - const DataLayout &DL = M.getDataLayout(); - Value *Args[] = {Builder.getInt64(DL.getTypeStoreSize(VarType))}; + Value *Args[] = {Size}; Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc_shared); CallInst *Call = Builder.CreateCall(Fn, Args, Name); - Call->addRetAttr( - Attribute::getWithAlignment(M.getContext(), DL.getPrefTypeAlign(Int64))); + Call->addRetAttr(Attribute::getWithAlignment( + M.getContext(), M.getDataLayout().getPrefTypeAlign(Int64))); 
return Call; } +CallInst *OpenMPIRBuilder::createOMPAllocShared(const LocationDescription &Loc, + Type *VarType, + const Twine &Name) { + return createOMPAllocShared( + Loc, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType)), Name); +} + CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc, - Value *Addr, Type *VarType, + Value *Addr, Value *Size, const Twine &Name) { IRBuilder<>::InsertPointGuard IPG(Builder); updateToLocation(Loc); - Value *Args[] = { - Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType))}; + Value *Args[] = {Addr, Size}; Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free_shared); return Builder.CreateCall(Fn, Args, Name); } +CallInst *OpenMPIRBuilder::createOMPFreeShared(const LocationDescription &Loc, + Value *Addr, Type *VarType, + const Twine &Name) { + return createOMPFreeShared( + Loc, Addr, Builder.getInt64(M.getDataLayout().getTypeStoreSize(VarType)), + Name); +} + CallInst *OpenMPIRBuilder::createOMPInteropInit( const LocationDescription &Loc, Value *InteropVar, omp::OMPInteropType InteropType, Value *Device, Value *NumDependences, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 8b206f58c7733..fa037c2ff9496 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2202,6 +2202,68 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem", Arg:$heapref ); let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AllocSharedMemOp +//===----------------------------------------------------------------------===// + +def AllocSharedMemOp : OpenMP_Op<"alloc_shared_mem", traits = [ + AttrSizedOperandSegments + ], clauses = [ + OpenMP_HeapAllocClause + ]> { + let summary = "allocate storage on shared memory for an 
object of a given type"; + + let description = [{ + Allocates memory shared across threads of a team for an object of the given + type. Returns a pointer representing the allocated memory. The memory is + uninitialized after allocation. Operations must be paired with + `omp.free_shared_mem` to avoid memory leaks. + + ```mlir + // Allocate a static 3x3 integer vector. + %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr + // ... + omp.free_shared_mem %ptr_shared : !llvm.ptr + ``` + }] # clausesDescription; + + let results = (outs OpenMP_PointerLikeType); + let assemblyFormat = clausesAssemblyFormat # " attr-dict `:` type(results)"; +} + +//===----------------------------------------------------------------------===// +// FreeSharedMemOp +//===----------------------------------------------------------------------===// + +def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", [MemoryEffects<[MemFree]>]> { + let summary = "free shared memory"; + + let description = [{ + Deallocates shared memory that was previously allocated by an + `omp.alloc_shared_mem` operation. After this operation, the deallocated + memory is in an undefined state and should not be accessed. + It is crucial to ensure that all accesses to the memory region are completed + before `omp.free_shared_mem` is called to avoid undefined behavior. + + ```mlir + // Example of allocating and freeing shared memory. + %ptr_shared = omp.alloc_shared_mem vector<3x3xi32> : !llvm.ptr + // ... + omp.free_shared_mem %ptr_shared : !llvm.ptr + ``` + + The `heapref` operand represents the pointer to shared memory to be + deallocated, previously returned by `omp.alloc_shared_mem`.
+ }]; + + let arguments = (ins + Arg:$heapref + ); + let assemblyFormat = "$heapref attr-dict `:` type($heapref)"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index fabb1b8c173a2..3b48dce4b7989 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -4161,6 +4161,28 @@ LogicalResult AllocateDirOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// TargetFreeMemOp +//===----------------------------------------------------------------------===// + +LogicalResult TargetFreeMemOp::verify() { + return getHeapref().getDefiningOp() + ? success() + : emitOpError() << "'heapref' operand must be defined by an " + "'omp.target_allocmem' op"; +} + +//===----------------------------------------------------------------------===// +// FreeSharedMemOp +//===----------------------------------------------------------------------===// + +LogicalResult FreeSharedMemOp::verify() { + return getHeapref().getDefiningOp() + ? 
success() + : emitOpError() << "'heapref' operand must be defined by an " + "'omp.alloc_shared_memory' op"; +} + //===----------------------------------------------------------------------===// // WorkdistributeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 80e052105dc4c..3accca891ba9c 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -6104,11 +6104,9 @@ static bool isTargetDeviceOp(Operation *op) { // by taking it in as an operand, so we must always lower these in // some manner or result in an ICE (whether they end up in a no-op // or otherwise). - if (mlir::isa(op)) - return true; - - if (mlir::isa(op) || - mlir::isa(op)) + if (mlir::isa(op)) return true; if (auto parentFn = op->getParentOfType()) @@ -6135,6 +6133,21 @@ static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, return func; } +static llvm::Value * +getAllocationSize(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, Type allocatedTy, + OperandRange typeparams, OperandRange shape) { + llvm::DataLayout dataLayout = + moduleTranslation.getLLVMModule()->getDataLayout(); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(allocatedTy); + llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); + for (auto typeParam : typeparams) + allocSize = + builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam)); + return allocSize; +} + static LogicalResult convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -6149,14 +6162,9 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, mlir::Value 
deviceNum = allocMemOp.getDevice(); llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); // Get the allocation size. - llvm::DataLayout dataLayout = llvmModule->getDataLayout(); - mlir::Type heapTy = allocMemOp.getAllocatedType(); - llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy); - llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); - llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); - for (auto typeParam : allocMemOp.getTypeparams()) - allocSize = - builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam)); + llvm::Value *allocSize = getAllocationSize( + builder, moduleTranslation, allocMemOp.getAllocatedType(), + allocMemOp.getTypeparams(), allocMemOp.getShape()); // Create call to "omp_target_alloc" with the args as translated llvm values. llvm::CallInst *call = builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); @@ -6167,6 +6175,19 @@ convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } +static LogicalResult +convertAllocSharedMemOp(omp::AllocSharedMemOp allocMemOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Value *size = getAllocationSize( + builder, moduleTranslation, allocMemOp.getAllocatedType(), + allocMemOp.getTypeparams(), allocMemOp.getShape()); + moduleTranslation.mapValue(allocMemOp.getResult(), + ompBuilder->createOMPAllocShared(builder, size)); + return success(); +} + static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, llvm::Module *llvmModule) { llvm::Type *ptrTy = builder.getPtrTy(0); @@ -6202,6 +6223,21 @@ convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } +static LogicalResult +convertFreeSharedMemOp(omp::FreeSharedMemOp freeMemOp, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder 
*ompBuilder = moduleTranslation.getOpenMPBuilder(); + auto allocMemOp = + freeMemOp.getHeapref().getDefiningOp(); + llvm::Value *size = getAllocationSize( + builder, moduleTranslation, allocMemOp.getAllocatedType(), + allocMemOp.getTypeparams(), allocMemOp.getShape()); + ompBuilder->createOMPFreeShared( + builder, moduleTranslation.lookupValue(freeMemOp.getHeapref()), size); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including /// OpenMP runtime calls). static LogicalResult @@ -6382,6 +6418,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::TargetFreeMemOp) { return convertTargetFreeMemOp(*op, builder, moduleTranslation); }) + .Case([&](omp::AllocSharedMemOp op) { + return convertAllocSharedMemOp(op, builder, moduleTranslation); + }) + .Case([&](omp::FreeSharedMemOp op) { + return convertFreeSharedMemOp(op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 0cc4b522db466..9f28172161fa8 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3153,3 +3153,31 @@ func.func @target_allocmem_invalid_bindc_name(%device : i32) -> () { %0 = omp.target_allocmem %device : i32, i64 {bindc_name=2} return } + +// ----- +func.func @target_freemem_invalid_ptr(%device : i32, %ptr : i64) -> () { + // expected-error @below {{op 'heapref' operand must be defined by an 'omp.target_allocmem' op}} + omp.target_freemem %device, %ptr : i32, i64 + return +} + +// ----- +func.func @alloc_shared_mem_invalid_uniq_name() -> () { + // expected-error @below {{op attribute 'uniq_name' failed to satisfy constraint: string attribute}} + %0 = omp.alloc_shared_mem i64 {uniq_name=2} + return +} + +// ----- +func.func @alloc_shared_mem_invalid_bindc_name() -> () { + // expected-error 
@below {{op attribute 'bindc_name' failed to satisfy constraint: string attribute}} + %0 = omp.alloc_shared_mem i64 {bindc_name=2} + return +} + +// ----- +func.func @free_shared_mem_invalid_ptr(%ptr : !llvm.ptr) -> () { + // expected-error @below {{op 'heapref' operand must be defined by an 'omp.alloc_shared_memory' op}} + omp.free_shared_mem %ptr : !llvm.ptr + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 9e7287178ff66..55e6d77857972 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3339,9 +3339,36 @@ func.func @omp_target_allocmem(%device: i32, %x: index, %y: index, %z: i32) { } // CHECK-LABEL: func.func @omp_target_freemem( -// CHECK-SAME: %[[DEVICE:.*]]: i32, %[[PTR:.*]]: i64) { -func.func @omp_target_freemem(%device : i32, %ptr : i64) { +// CHECK-SAME: %[[DEVICE:.*]]: i32) { +func.func @omp_target_freemem(%device : i32) { + // CHECK: %[[PTR:.*]] = omp.target_allocmem + %ptr = omp.target_allocmem %device : i32, i64 // CHECK: omp.target_freemem %[[DEVICE]], %[[PTR]] : i32, i64 omp.target_freemem %device, %ptr : i32, i64 return } + +// CHECK-LABEL: func.func @omp_alloc_shared_mem( +// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index, %[[Z:.*]]: i32) { +func.func @omp_alloc_shared_mem(%x: index, %y: index, %z: i32) { + // CHECK: %{{.*}} = omp.alloc_shared_mem i64 : !llvm.ptr + %0 = omp.alloc_shared_mem i64 : !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem vector<16x16xf32> {bindc_name = "bindc", uniq_name = "uniq"} : !llvm.ptr + %1 = omp.alloc_shared_mem vector<16x16xf32> {uniq_name="uniq", bindc_name="bindc"} : !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32) : !llvm.ptr + %2 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32) : !llvm.ptr + // CHECK: %{{.*}} = omp.alloc_shared_mem !llvm.ptr, %[[X]], %[[Y]] : !llvm.ptr + %3 = omp.alloc_shared_mem !llvm.ptr, %x, %y : !llvm.ptr + // CHECK: %{{.*}} = 
omp.alloc_shared_mem !llvm.ptr(%[[X]], %[[Y]], %[[Z]] : index, index, i32), %[[X]], %[[Y]] : !llvm.ptr + %4 = omp.alloc_shared_mem !llvm.ptr(%x, %y, %z : index, index, i32), %x, %y : !llvm.ptr + return +} + +// CHECK-LABEL: func.func @omp_free_shared_mem() { +func.func @omp_free_shared_mem() { + // CHECK: %[[PTR:.*]] = omp.alloc_shared_mem + %0 = omp.alloc_shared_mem i64 : !llvm.ptr + // CHECK: omp.free_shared_mem %[[PTR]] : !llvm.ptr + omp.free_shared_mem %0 : !llvm.ptr + return +}