diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index eae79d2a74867..ba5946415b6f9 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -1215,20 +1215,23 @@ struct EmboxCharOpConversion : public FIROpConversion { } // namespace /// Return the LLVMFuncOp corresponding to the standard malloc call. -static mlir::LLVM::LLVMFuncOp +static mlir::SymbolRefAttr getMalloc(fir::AllocMemOp op, mlir::ConversionPatternRewriter &rewriter) { + static constexpr char mallocName[] = "malloc"; auto module = op->getParentOfType(); - if (mlir::LLVM::LLVMFuncOp mallocFunc = - module.lookupSymbol("malloc")) - return mallocFunc; + if (auto mallocFunc = module.lookupSymbol(mallocName)) + return mlir::SymbolRefAttr::get(mallocFunc); + if (auto userMalloc = module.lookupSymbol(mallocName)) + return mlir::SymbolRefAttr::get(userMalloc); mlir::OpBuilder moduleBuilder( op->getParentOfType().getBodyRegion()); auto indexType = mlir::IntegerType::get(op.getContext(), 64); - return moduleBuilder.create( - rewriter.getUnknownLoc(), "malloc", + auto mallocDecl = moduleBuilder.create( + op.getLoc(), mallocName, mlir::LLVM::LLVMFunctionType::get(getLlvmPtrType(op.getContext()), indexType, /*isVarArg=*/false)); + return mlir::SymbolRefAttr::get(mallocDecl); } /// Helper function for generating the LLVM IR that computes the distance @@ -1276,7 +1279,6 @@ struct AllocMemOpConversion : public FIROpConversion { matchAndRewrite(fir::AllocMemOp heap, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type heapTy = heap.getType(); - mlir::LLVM::LLVMFuncOp mallocFunc = getMalloc(heap, rewriter); mlir::Location loc = heap.getLoc(); auto ity = lowerTy().indexType(); mlir::Type dataTy = fir::unwrapRefType(heapTy); @@ -1289,7 +1291,7 @@ struct AllocMemOpConversion : public FIROpConversion { for (mlir::Value opnd : adaptor.getOperands()) size = rewriter.create( loc, ity, size, integerCast(loc, rewriter, ity, opnd)); - heap->setAttr("callee", mlir::SymbolRefAttr::get(mallocFunc)); + heap->setAttr("callee", getMalloc(heap, rewriter)); rewriter.replaceOpWithNewOp( heap, ::getLlvmPtrType(heap.getContext()), size, heap->getAttrs()); return mlir::success(); @@ -1307,19 +1309,25 @@ struct AllocMemOpConversion : public FIROpConversion { } // namespace /// Return the LLVMFuncOp corresponding to the standard free call. -static mlir::LLVM::LLVMFuncOp -getFree(fir::FreeMemOp op, mlir::ConversionPatternRewriter &rewriter) { +static mlir::SymbolRefAttr getFree(fir::FreeMemOp op, + mlir::ConversionPatternRewriter &rewriter) { + static constexpr char freeName[] = "free"; auto module = op->getParentOfType(); - if (mlir::LLVM::LLVMFuncOp freeFunc = - module.lookupSymbol("free")) - return freeFunc; + // Check if free already defined in the module. + if (auto freeFunc = module.lookupSymbol(freeName)) + return mlir::SymbolRefAttr::get(freeFunc); + if (auto freeDefinedByUser = + module.lookupSymbol(freeName)) + return mlir::SymbolRefAttr::get(freeDefinedByUser); + // Create llvm declaration for free. mlir::OpBuilder moduleBuilder(module.getBodyRegion()); auto voidType = mlir::LLVM::LLVMVoidType::get(op.getContext()); - return moduleBuilder.create( - rewriter.getUnknownLoc(), "free", + auto freeDecl = moduleBuilder.create( + rewriter.getUnknownLoc(), freeName, mlir::LLVM::LLVMFunctionType::get(voidType, getLlvmPtrType(op.getContext()), /*isVarArg=*/false)); + return mlir::SymbolRefAttr::get(freeDecl); } static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) { @@ -1339,9 +1347,8 @@ struct FreeMemOpConversion : public FIROpConversion { mlir::LogicalResult matchAndRewrite(fir::FreeMemOp freemem, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::LLVM::LLVMFuncOp freeFunc = getFree(freemem, rewriter); mlir::Location loc = freemem.getLoc(); - freemem->setAttr("callee", mlir::SymbolRefAttr::get(freeFunc)); + freemem->setAttr("callee", getFree(freemem, rewriter)); rewriter.create(loc, mlir::TypeRange{}, mlir::ValueRange{adaptor.getHeapref()}, freemem->getAttrs()); diff --git a/flang/test/Fir/already-defined-free.fir b/flang/test/Fir/already-defined-free.fir new file mode 100644 index 0000000000000..1a6aa7b5b5404 --- /dev/null +++ b/flang/test/Fir/already-defined-free.fir @@ -0,0 +1,22 @@ +// Test that FIR codegen handles cases when free and malloc have +// already been defined in FIR (either by the user in Fortran via +// BIND(C) or by some FIR pass in between). +// RUN: fir-opt --fir-to-llvm-ir %s | FileCheck %s + + +func.func @already_declared_free_malloc() { + %c4 = arith.constant 4 : index + %0 = fir.call @malloc(%c4) : (index) -> !fir.heap + fir.call @free(%0) : (!fir.heap) -> () + %1 = fir.allocmem i32 + fir.freemem %1 : !fir.heap + return +} + +// CHECK: llvm.call @malloc(%{{.*}}) +// CHECK: llvm.call @free(%{{.*}}) +// CHECK: llvm.call @malloc(%{{.*}}) +// CHECK: llvm.call @free(%{{.*}}) + +func.func private @free(!fir.heap) +func.func private @malloc(index) -> !fir.heap