160 changes: 160 additions & 0 deletions llvm/lib/Object/OffloadBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,147 @@

#include "llvm/ADT/StringSwitch.h"
#include "llvm/BinaryFormat/Magic.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/MC/StringTableBuilder.h"
#include "llvm/Object/Archive.h"
#include "llvm/Object/ArchiveWriter.h"
#include "llvm/Object/Binary.h"
#include "llvm/Object/ELFObjectFile.h"
#include "llvm/Object/Error.h"
#include "llvm/Object/IRObjectFile.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/FileOutputBuffer.h"
#include "llvm/Support/SourceMgr.h"

using namespace llvm;
using namespace llvm::object;

namespace {

/// Attempts to extract all the embedded device images contained inside the
/// buffer \p Contents. The buffer is expected to contain a valid offloading
/// binary format.
Error extractOffloadFiles(MemoryBufferRef Contents,
SmallVectorImpl<OffloadFile> &Binaries) {
uint64_t Offset = 0;
// There could be multiple offloading binaries stored at this section.
while (Offset < Contents.getBuffer().size()) {
std::unique_ptr<MemoryBuffer> Buffer =
MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "",
/*RequiresNullTerminator*/ false);
auto BinaryOrErr = OffloadBinary::create(*Buffer);
if (!BinaryOrErr)
return BinaryOrErr.takeError();
OffloadBinary &Binary = **BinaryOrErr;

// Create a new owned binary with a copy of the original memory.
std::unique_ptr<MemoryBuffer> BufferCopy = MemoryBuffer::getMemBufferCopy(
Binary.getData().take_front(Binary.getSize()),
Contents.getBufferIdentifier());
auto NewBinaryOrErr = OffloadBinary::create(*BufferCopy);
if (!NewBinaryOrErr)
return NewBinaryOrErr.takeError();
Binaries.emplace_back(std::move(*NewBinaryOrErr), std::move(BufferCopy));

Offset += Binary.getSize();
}

return Error::success();
}

// Extract offloading binaries from an Object file \p Obj.
Error extractFromBinary(const ObjectFile &Obj,
SmallVectorImpl<OffloadFile> &Binaries) {
for (ELFSectionRef Sec : Obj.sections()) {
if (Sec.getType() != ELF::SHT_LLVM_OFFLOADING)
continue;

Expected<StringRef> Buffer = Sec.getContents();
if (!Buffer)
return Buffer.takeError();

MemoryBufferRef Contents(*Buffer, Obj.getFileName());
if (Error Err = extractOffloadFiles(Contents, Binaries))
return Err;
}

return Error::success();
}

Error extractFromBitcode(MemoryBufferRef Buffer,
SmallVectorImpl<OffloadFile> &Binaries) {
LLVMContext Context;
SMDiagnostic Err;
std::unique_ptr<Module> M = getLazyIRModule(
MemoryBuffer::getMemBuffer(Buffer, /*RequiresNullTerminator=*/false), Err,
Context);
if (!M)
return createStringError(inconvertibleErrorCode(),
"Failed to create module");

// Extract offloading data from globals referenced by the
// `llvm.embedded.object` metadata with the `.llvm.offloading` section.
auto *MD = M->getNamedMetadata("llvm.embedded.objects");
if (!MD)
return Error::success();

for (const MDNode *Op : MD->operands()) {
if (Op->getNumOperands() < 2)
continue;

MDString *SectionID = dyn_cast<MDString>(Op->getOperand(1));
if (!SectionID || SectionID->getString() != ".llvm.offloading")
continue;

GlobalVariable *GV =
mdconst::dyn_extract_or_null<GlobalVariable>(Op->getOperand(0));
if (!GV)
continue;

auto *CDS = dyn_cast<ConstantDataSequential>(GV->getInitializer());
if (!CDS)
continue;

MemoryBufferRef Contents(CDS->getAsString(), M->getName());
if (Error Err = extractOffloadFiles(Contents, Binaries))
return Err;
}

return Error::success();
}

Error extractFromArchive(const Archive &Library,
SmallVectorImpl<OffloadFile> &Binaries) {
// Try to extract device code from each file stored in the static archive.
Error Err = Error::success();
for (auto Child : Library.children(Err)) {
auto ChildBufferOrErr = Child.getMemoryBufferRef();
if (!ChildBufferOrErr)
return ChildBufferOrErr.takeError();
std::unique_ptr<MemoryBuffer> ChildBuffer =
MemoryBuffer::getMemBuffer(*ChildBufferOrErr, false);

// Check if the buffer has the required alignment.
if (!isAddrAligned(Align(OffloadBinary::getAlignment()),
ChildBuffer->getBufferStart()))
ChildBuffer = MemoryBuffer::getMemBufferCopy(
ChildBufferOrErr->getBuffer(),
ChildBufferOrErr->getBufferIdentifier());

if (Error Err = extractOffloadBinaries(*ChildBuffer, Binaries))
return Err;
}

if (Err)
return Err;
return Error::success();
}

} // namespace

Expected<std::unique_ptr<OffloadBinary>>
OffloadBinary::create(MemoryBufferRef Buf) {
if (Buf.getBufferSize() < sizeof(Header) + sizeof(Entry))
Expand Down Expand Up @@ -115,6 +248,33 @@ OffloadBinary::write(const OffloadingImage &OffloadingData) {
return MemoryBuffer::getMemBufferCopy(OS.str());
}

Error object::extractOffloadBinaries(MemoryBufferRef Buffer,
SmallVectorImpl<OffloadFile> &Binaries) {
file_magic Type = identify_magic(Buffer.getBuffer());
switch (Type) {
case file_magic::bitcode:
return extractFromBitcode(Buffer, Binaries);
case file_magic::elf_relocatable: {
Expected<std::unique_ptr<ObjectFile>> ObjFile =
ObjectFile::createObjectFile(Buffer, Type);
if (!ObjFile)
return ObjFile.takeError();
return extractFromBinary(*ObjFile->get(), Binaries);
}
case file_magic::archive: {
Expected<std::unique_ptr<llvm::object::Archive>> LibFile =
object::Archive::create(Buffer);
if (!LibFile)
return LibFile.takeError();
return extractFromArchive(*LibFile->get(), Binaries);
}
case file_magic::offload_binary:
return extractOffloadFiles(Buffer, Binaries);
default:
return Error::success();
}
}

OffloadKind object::getOffloadKind(StringRef Name) {
return llvm::StringSwitch<OffloadKind>(Name)
.Case("openmp", OFK_OpenMP)
Expand Down