Gate the MLIR CUDA conversions on MLIR_CUDA_CONVERSIONS_ENABLED.

Summary of the changes below:
  * mlir/CMakeLists.txt passes MLIR_CUDA_CONVERSIONS_ENABLED to the compiler
    as a -D definition (with a TODO to move to a config.h-style header).
  * mlir/include/mlir/InitAllPasses.h only calls
    createConvertGPUKernelToCubinPass when that definition is non-zero.
  * mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt only adds
    ConvertKernelFuncToCubin.cpp to the library sources when
    MLIR_CUDA_CONVERSIONS_ENABLED is set; the file is also listed in
    LLVM_OPTIONAL_SOURCES.
  * ConvertKernelFuncToCubin.cpp calls the four LLVMInitializeNVPTX*
    functions instead of the llvm::InitializeAll* functions.

diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 4b46939aa88db7..b9980c3795b3aa 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -44,6 +44,8 @@ if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
 else()
   set(MLIR_CUDA_CONVERSIONS_ENABLED 0)
 endif()
+# TODO: we should use a config.h file like LLVM does
+add_definitions(-DMLIR_CUDA_CONVERSIONS_ENABLED=${MLIR_CUDA_CONVERSIONS_ENABLED})
 
 set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
 
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index b867e0624916e6..6b28041bf98072 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -90,8 +90,10 @@ inline void registerAllPasses() {
 
   // CUDA
   createConvertGpuLaunchFuncToCudaCallsPass();
+#if MLIR_CUDA_CONVERSIONS_ENABLED
   createConvertGPUKernelToCubinPass(
       [](const std::string &, Location, StringRef) { return nullptr; });
+#endif
   createLowerGpuOpsToNVVMOpsPass();
 
   // Linalg
diff --git a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
index a758f7b935efd2..484bc9dbd89c32 100644
--- a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
@@ -1,7 +1,16 @@
-add_llvm_library(MLIRGPUtoCUDATransforms
+set(LLVM_OPTIONAL_SOURCES
   ConvertKernelFuncToCubin.cpp
+)
+
+set(SOURCES
   ConvertLaunchFuncToCudaCalls.cpp
   )
+
+if (MLIR_CUDA_CONVERSIONS_ENABLED)
+  list(APPEND SOURCES "ConvertKernelFuncToCubin.cpp")
+endif()
+
+add_llvm_library(MLIRGPUtoCUDATransforms ${SOURCES})
 target_link_libraries(MLIRGPUtoCUDATransforms
   MLIRGPU
   MLIRLLVMIR
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
index fe571cf31548b4..140026eaf64344 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
@@ -57,9 +57,10 @@ class GpuKernelToCubinPass
     gpu::GPUModuleOp module = getOperation();
 
     // Make sure the NVPTX target is initialized.
-    llvm::InitializeAllTargets();
-    llvm::InitializeAllTargetMCs();
-    llvm::InitializeAllAsmPrinters();
+    LLVMInitializeNVPTXTarget();
+    LLVMInitializeNVPTXTargetInfo();
+    LLVMInitializeNVPTXTargetMC();
+    LLVMInitializeNVPTXAsmPrinter();
 
     auto llvmModule = translateModuleToNVVMIR(module);
     if (!llvmModule)