diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index bf75123e85377..af6bd41cbb71d 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -439,7 +439,7 @@ def CufImplicitDeviceGlobal : def CUFAddConstructor : Pass<"cuf-add-constructor", "mlir::ModuleOp"> { let summary = "Add constructor to register CUDA Fortran allocators"; let dependentDialects = [ - "mlir::func::FuncDialect" + "cuf::CUFDialect", "mlir::func::FuncDialect" ]; } diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 5e1a0293e63c9..352fe4cbe09e9 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -49,6 +49,7 @@ add_flang_library(FIRTransforms HLFIRDialect MLIRAffineUtils MLIRFuncDialect + MLIRGPUDialect MLIRLLVMDialect MLIRLLVMCommonConversion MLIRMathTransforms diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp index 48620fbc58586..3db24226e7504 100644 --- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp +++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp @@ -12,6 +12,7 @@ #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Runtime/entry-names.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallVector.h" @@ -23,6 +24,8 @@ namespace fir { namespace { +static constexpr llvm::StringRef cudaModName{"cuda_device_mod"}; + static constexpr llvm::StringRef cudaFortranCtorName{ "__cudaFortranConstructor"}; @@ -31,6 +34,7 @@ struct CUFAddConstructor void runOnOperation() override { mlir::ModuleOp mod = getOperation(); + mlir::SymbolTable symTab(mod); mlir::OpBuilder builder{mod.getBodyRegion()}; builder.setInsertionPointToEnd(mod.getBody()); mlir::Location loc = mod.getLoc(); @@ -48,13 +52,25 @@ struct CUFAddConstructor mod.getContext(), RTNAME_STRING(CUFRegisterAllocator)); builder.setInsertionPointToEnd(mod.getBody()); - // Create the constructor function that cal CUFRegisterAllocator. - builder.setInsertionPointToEnd(mod.getBody()); + // Create the constructor function that call CUFRegisterAllocator. auto func = builder.create(loc, cudaFortranCtorName, funcTy); func.setLinkage(mlir::LLVM::Linkage::Internal); builder.setInsertionPointToStart(func.addEntryBlock(builder)); builder.create(loc, funcTy, cufRegisterAllocatorRef); + + // Register kernels + auto gpuMod = symTab.lookup(cudaModName); + if (gpuMod) { + for (auto func : gpuMod.getOps()) { + if (func.isKernel()) { + auto kernelName = mlir::SymbolRefAttr::get( + builder.getStringAttr(cudaModName), + {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())}); + builder.create(loc, kernelName); + } + } + } builder.create(loc, mlir::ValueRange{}); // Create the llvm.global_ctor with the function. diff --git a/flang/test/Fir/CUDA/cuda-register-func.fir b/flang/test/Fir/CUDA/cuda-register-func.fir index a428f68eb3bf4..277475f0883dc 100644 --- a/flang/test/Fir/CUDA/cuda-register-func.fir +++ b/flang/test/Fir/CUDA/cuda-register-func.fir @@ -1,4 +1,4 @@ -// RUN: fir-opt %s | FileCheck %s +// RUN: fir-opt --cuf-add-constructor %s | FileCheck %s module attributes {gpu.container_module} { gpu.module @cuda_device_mod { @@ -9,12 +9,8 @@ module attributes {gpu.container_module} { gpu.return } } - llvm.func internal @__cudaFortranConstructor() { - cuf.register_kernel @cuda_device_mod::@_QPsub_device1 - cuf.register_kernel @cuda_device_mod::@_QPsub_device2 - llvm.return - } } +// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor() // CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1 // CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2