diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 023c65b3dd9df..586748df8d154 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -722,12 +722,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { #undef DEBUG_TYPE // Return PTX if the compilation target is `assembly`. - if (targetOptions.getCompilationTarget() == - gpu::CompilationTarget::Assembly) { - // Make sure to include the null terminator. - StringRef bin(serializedISA->c_str(), serializedISA->size() + 1); - return SmallVector(bin.begin(), bin.end()); - } + if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Assembly) + return SmallVector(serializedISA->begin(), serializedISA->end()); std::optional> result; moduleToObjectTimer.startTimer(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp index ade239c526af8..8d4a0bcf8adbf 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp @@ -116,8 +116,11 @@ LogicalResult SelectObjectAttrImpl::embedBinary( llvm::Module *module = moduleTranslation.getLLVMModule(); // Embed the object as a global string. + // Add null for assembly output for JIT paths that expect null-terminated + // strings. + bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly); llvm::Constant *binary = llvm::ConstantDataArray::getString( - builder.getContext(), object.getObject().getValue(), false); + builder.getContext(), object.getObject().getValue(), addNull); llvm::GlobalVariable *serializedObj = new llvm::GlobalVariable(*module, binary->getType(), true, llvm::GlobalValue::LinkageTypes::InternalLinkage, diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp index eabfd1c4d32eb..cae713a1ce1d2 100644 --- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp +++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp @@ -130,6 +130,7 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToPTX)) { ASSERT_TRUE( StringRef(object->data(), object->size()).contains("nvvm_kernel")); + ASSERT_TRUE(StringRef(object->data(), object->size()).count('\0') == 0); } }