diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h index 139360f8bd3fc..e5eb043dc36e1 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -46,6 +46,11 @@ class OffloadingTranslationAttrTrait /// ensure type safeness. Targets are free to ignore these options. class TargetOptions { public: + using DiagnosticCallback = function_ref; + using LLVMIRCallback = + function_ref; + using ISACallback = + function_ref; /// Constructor initializing the toolkit path, the list of files to link to, /// extra command line options, the compilation target and a callback for /// obtaining the parent symbol table. The default compilation target is @@ -55,10 +60,10 @@ class TargetOptions { StringRef cmdOptions = {}, StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref getSymbolTableCallback = {}, - function_ref initialLlvmIRCallback = {}, - function_ref linkedLlvmIRCallback = {}, - function_ref optimizedLlvmIRCallback = {}, - function_ref isaCallback = {}); + LLVMIRCallback initialLlvmIRCallback = {}, + LLVMIRCallback linkedLlvmIRCallback = {}, + LLVMIRCallback optimizedLlvmIRCallback = {}, + ISACallback isaCallback = {}); /// Returns the typeID. TypeID getTypeID() const; @@ -97,19 +102,19 @@ class TargetOptions { /// Returns the callback invoked with the initial LLVM IR for the device /// module. - function_ref getInitialLlvmIRCallback() const; + LLVMIRCallback getInitialLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module /// after linking the device libraries. - function_ref getLinkedLlvmIRCallback() const; + LLVMIRCallback getLinkedLlvmIRCallback() const; /// Returns the callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref getOptimizedLlvmIRCallback() const; + LLVMIRCallback getOptimizedLlvmIRCallback() const; /// Returns the callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref getISACallback() const; + ISACallback getISACallback() const; /// Returns the default compilation target: `CompilationTarget::Fatbin`. static CompilationTarget getDefaultCompilationTarget(); @@ -127,10 +132,10 @@ class TargetOptions { StringRef elfSection = {}, CompilationTarget compilationTarget = getDefaultCompilationTarget(), function_ref getSymbolTableCallback = {}, - function_ref initialLlvmIRCallback = {}, - function_ref linkedLlvmIRCallback = {}, - function_ref optimizedLlvmIRCallback = {}, - function_ref isaCallback = {}); + LLVMIRCallback initialLlvmIRCallback = {}, + LLVMIRCallback linkedLlvmIRCallback = {}, + LLVMIRCallback optimizedLlvmIRCallback = {}, + ISACallback isaCallback = {}); /// Path to the target toolkit. std::string toolkitPath; @@ -153,19 +158,19 @@ class TargetOptions { function_ref getSymbolTableCallback; /// Callback invoked with the initial LLVM IR for the device module. - function_ref initialLlvmIRCallback; + LLVMIRCallback initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref linkedLlvmIRCallback; + LLVMIRCallback linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref optimizedLlvmIRCallback; + LLVMIRCallback optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref isaCallback; + ISACallback isaCallback; private: TypeID typeID; diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h index 11fea6f0a4443..eb5d4f9906cb9 100644 --- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h +++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h @@ -29,13 +29,17 @@ class ModuleTranslation; /// operations being transformed must be translatable into LLVM IR. class ModuleToObject { public: - ModuleToObject( - Operation &module, StringRef triple, StringRef chip, - StringRef features = {}, int optLevel = 3, - function_ref initialLlvmIRCallback = {}, - function_ref linkedLlvmIRCallback = {}, - function_ref optimizedLlvmIRCallback = {}, - function_ref isaCallback = {}); + using DiagnosticCallback = function_ref; + using LLVMIRCallback = + function_ref; + using ISACallback = + function_ref; + ModuleToObject(Operation &module, StringRef triple, StringRef chip, + StringRef features = {}, int optLevel = 3, + LLVMIRCallback initialLlvmIRCallback = {}, + LLVMIRCallback linkedLlvmIRCallback = {}, + LLVMIRCallback optimizedLlvmIRCallback = {}, + ISACallback isaCallback = {}); virtual ~ModuleToObject(); /// Returns the operation being serialized. @@ -120,19 +124,19 @@ class ModuleToObject { int optLevel; /// Callback invoked with the initial LLVM IR for the device module. - function_ref initialLlvmIRCallback; + LLVMIRCallback initialLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// linking the device libraries. - function_ref linkedLlvmIRCallback; + LLVMIRCallback linkedLlvmIRCallback; /// Callback invoked with LLVM IR for the device module after /// LLVM optimizations but before codegen. - function_ref optimizedLlvmIRCallback; + LLVMIRCallback optimizedLlvmIRCallback; /// Callback invoked with the target ISA for the device, /// for example PTX assembly. - function_ref isaCallback; + ISACallback isaCallback; private: /// The TargetMachine created for the given Triple, if available. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2bad55d..240822d1530ed 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2652,10 +2652,8 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref getSymbolTableCallback, - function_ref initialLlvmIRCallback, - function_ref linkedLlvmIRCallback, - function_ref optimizedLlvmIRCallback, - function_ref isaCallback) + LLVMIRCallback initialLlvmIRCallback, LLVMIRCallback linkedLlvmIRCallback, + LLVMIRCallback optimizedLlvmIRCallback, ISACallback isaCallback) : TargetOptions(TypeID::get(), toolkitPath, librariesToLink, cmdOptions, elfSection, compilationTarget, getSymbolTableCallback, initialLlvmIRCallback, @@ -2667,10 +2665,8 @@ TargetOptions::TargetOptions( StringRef cmdOptions, StringRef elfSection, CompilationTarget compilationTarget, function_ref getSymbolTableCallback, - function_ref initialLlvmIRCallback, - function_ref linkedLlvmIRCallback, - function_ref optimizedLlvmIRCallback, - function_ref isaCallback) + LLVMIRCallback initialLlvmIRCallback, LLVMIRCallback linkedLlvmIRCallback, + LLVMIRCallback optimizedLlvmIRCallback, ISACallback isaCallback) : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink), cmdOptions(cmdOptions.str()), elfSection(elfSection.str()), compilationTarget(compilationTarget), @@ -2696,22 +2692,20 @@ SymbolTable *TargetOptions::getSymbolTable() const { return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; } -function_ref -TargetOptions::getInitialLlvmIRCallback() const { +TargetOptions::LLVMIRCallback TargetOptions::getInitialLlvmIRCallback() const { return initialLlvmIRCallback; } -function_ref -TargetOptions::getLinkedLlvmIRCallback() const { +TargetOptions::LLVMIRCallback TargetOptions::getLinkedLlvmIRCallback() const { return linkedLlvmIRCallback; } -function_ref +TargetOptions::LLVMIRCallback TargetOptions::getOptimizedLlvmIRCallback() const { return optimizedLlvmIRCallback; } -function_ref TargetOptions::getISACallback() const { +TargetOptions::ISACallback TargetOptions::getISACallback() const { return isaCallback; } diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp index 4098ccc548dc1..6e50c6c735662 100644 --- a/mlir/lib/Target/LLVM/ModuleToObject.cpp +++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp @@ -34,12 +34,12 @@ using namespace mlir; using namespace mlir::LLVM; -ModuleToObject::ModuleToObject( - Operation &module, StringRef triple, StringRef chip, StringRef features, - int optLevel, function_ref initialLlvmIRCallback, - function_ref linkedLlvmIRCallback, - function_ref optimizedLlvmIRCallback, - function_ref isaCallback) +ModuleToObject::ModuleToObject(Operation &module, StringRef triple, + StringRef chip, StringRef features, int optLevel, + LLVMIRCallback initialLlvmIRCallback, + LLVMIRCallback linkedLlvmIRCallback, + LLVMIRCallback optimizedLlvmIRCallback, + ISACallback isaCallback) : module(module), triple(triple), chip(chip), features(features), optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback), linkedLlvmIRCallback(linkedLlvmIRCallback), @@ -254,8 +254,13 @@ std::optional> ModuleToObject::run() { } setDataLayoutAndTriple(*llvmModule); + auto diagnosticCallback = [&]() -> InFlightDiagnostic { + return getOperation().emitError(); + }; + if (initialLlvmIRCallback) - initialLlvmIRCallback(*llvmModule); + if (failed(initialLlvmIRCallback(*llvmModule, diagnosticCallback))) + return std::nullopt; // Link bitcode files. handleModulePreLink(*llvmModule); @@ -270,14 +275,16 @@ std::optional> ModuleToObject::run() { } if (linkedLlvmIRCallback) - linkedLlvmIRCallback(*llvmModule); + if (failed(linkedLlvmIRCallback(*llvmModule, diagnosticCallback))) + return std::nullopt; // Optimize the module. if (failed(optimizeModule(*llvmModule, optLevel))) return std::nullopt; if (optimizedLlvmIRCallback) - optimizedLlvmIRCallback(*llvmModule); + if (failed(optimizedLlvmIRCallback(*llvmModule, diagnosticCallback))) + return std::nullopt; // Return the serialized object. return moduleToObject(*llvmModule); diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 8760ea8588e2c..7d52957cdf6ac 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -707,8 +707,13 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { return std::nullopt; } + auto diagnosticCallback = [&]() -> InFlightDiagnostic { + return getOperation().emitError(); + }; + if (isaCallback) - isaCallback(serializedISA.value()); + if (failed(isaCallback(serializedISA.value(), diagnosticCallback))) + return std::nullopt; #define DEBUG_TYPE "serialize-to-isa" LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n" diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index af0af89c7d07e..e9987f0bcf13c 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -176,26 +176,40 @@ TEST_F(MLIRTargetLLVMNVVM, ASSERT_TRUE(!!serializer); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR]( + llvm::Module &module, + gpu::TargetOptions::DiagnosticCallback) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = + [&linkedLLVMIR](llvm::Module &module, + gpu::TargetOptions::DiagnosticCallback) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR]( + llvm::Module &module, + gpu::TargetOptions::DiagnosticCallback) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; std::string isaResult; - auto isaCallback = [&isaResult](llvm::StringRef isa) { + auto isaCallback = + [&isaResult](llvm::StringRef isa, + gpu::TargetOptions::DiagnosticCallback) -> LogicalResult { isaResult = isa.str(); + return success(); }; gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, @@ -220,6 +234,36 @@ TEST_F(MLIRTargetLLVMNVVM, } } +// Test callback functions failure with ISA. +TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(CallbackFailedWithISA)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + + NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context); + + auto serializer = dyn_cast(target); + ASSERT_TRUE(!!serializer); + + auto isaCallback = + [](llvm::StringRef /*isa*/, + gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult { + return diag() << "test"; + }; + + gpu::TargetOptions options({}, {}, {}, {}, gpu::CompilationTarget::Assembly, + {}, {}, {}, {}, isaCallback); + + for (auto gpuModule : (*module).getBody()->getOps()) { + std::optional> object = + serializer.serializeToObject(gpuModule, options); + + ASSERT_TRUE(object == std::nullopt); + } +} + // Test linking LLVM IR from a resource attribute. TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { MLIRContext context(registry); @@ -261,9 +305,13 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(LinkedLLVMIRResource)) { // Hook to intercept the LLVM IR after linking external libs. std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = + [&linkedLLVMIR]( + llvm::Module &module, + gpu::TargetOptions::DiagnosticCallback /*diag*/) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; // Store the bitcode as a DenseI8ArrayAttr. diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp index 3c880edee4ffc..4726bf8169515 100644 --- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp +++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp @@ -168,9 +168,13 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) { auto targetAttr = dyn_cast(target); std::string initialLLVMIR; - auto initialCallback = [&initialLLVMIR](llvm::Module &module) { + auto initialCallback = + [&initialLLVMIR]( + llvm::Module &module, + gpu::TargetOptions::DiagnosticCallback) -> LogicalResult { llvm::raw_string_ostream ros(initialLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -196,9 +200,12 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) { auto targetAttr = dyn_cast(target); std::string linkedLLVMIR; - auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) { + auto linkedCallback = + [&linkedLLVMIR](llvm::Module &module, + gpu::TargetOptions::DiagnosticCallback) -> LogicalResult { llvm::raw_string_ostream ros(linkedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -225,9 +232,13 @@ TEST_F(MLIRTargetLLVM, auto targetAttr = dyn_cast(target); std::string optimizedLLVMIR; - auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) { + auto optimizedCallback = + [&optimizedLLVMIR]( + llvm::Module &module, + gpu::TargetOptions::DiagnosticCallback) -> LogicalResult { llvm::raw_string_ostream ros(optimizedLLVMIR); module.print(ros, nullptr); + return success(); }; gpu::TargetOptions opts( @@ -240,3 +251,81 @@ TEST_F(MLIRTargetLLVM, ASSERT_TRUE(!serializedBinary->empty()); ASSERT_TRUE(!optimizedLLVMIR.empty()); } + +// Test callback function failure with initial LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithInitialLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast(target); + + auto initialCallback = + [](llvm::Module & /*module*/, + gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult { + return diag() << "test"; + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, initialCallback); + std::optional> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with linked LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithLinkedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast(target); + + auto linkedCallback = + [](llvm::Module & /*module*/, + gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult { + return diag() << "test"; + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, linkedCallback); + std::optional> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +} + +// Test callback function failure with optimized LLVM IR +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackFailedWithOptimizedLLVMIR)) { + MLIRContext context(registry); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast(target); + + auto optimizedCallback = + [](llvm::Module & /*module*/, + gpu::TargetOptions::DiagnosticCallback diag) -> LogicalResult { + return diag() << "test"; + }; + + gpu::TargetOptions opts( + {}, {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), + {}, {}, {}, optimizedCallback); + std::optional> serializedBinary = + targetAttr.serializeToObject(*module, opts); + + ASSERT_TRUE(serializedBinary == std::nullopt); +}