[OpenMP][JIT] Cleanup JIT interface, caching, and races
The JIT interface was somewhat irregular as it used multiple global
functions. It also did not cache the results of the JIT, so multi-GPU
systems would perform the same work multiple times. Finally, there
might have been races on the state when different embedded images were
initialized from multiple threads, or when one image was initialized on
multiple devices.

This patch tries to rectify all of the above. The JITEngine is now part
of the GenericPluginTy and tied to one target triple. To support
multiple "ComputeUnitKind"s (previously confusingly called Arch or
[M]CPU) and to avoid re-jitting for the same ComputeUnitKind, we keep a
map of JIT results per ComputeUnitKind. All interaction with the JIT
now happens directly through the JITEngine, which exposes two
functions. Both use (shared) locks to avoid races and cache their
results, as sketched below. All JIT-related environment variables are
now defined together.

Differential Revision: https://reviews.llvm.org/D141081
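
The caching and locking scheme can be summarized without LLVM's types.
Below is a minimal, self-contained C++ sketch of the per-ComputeUnitKind
cache that JITEngine::compile implements in the diff. The names
SimpleJITCache, FakeImage, and the string "compilation" are illustrative
only, and the single mutex stands in for the combination of
ComputeUnitMapMutex and the shared BitcodeImageMapMutex used in the real
code.

// Simplified model of the per-ComputeUnitKind JIT cache this patch introduces.
// All names are illustrative; the real implementation in JIT.cpp below uses
// __tgt_device_image, MemoryBuffer, Expected<>, and a per-kind LLVMContext.
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct FakeImage { std::string Bitcode; }; // stand-in for __tgt_device_image

class SimpleJITCache {
public:
  // Return the cached result for (ComputeUnitKind, Image) or compile once.
  const std::string &process(const FakeImage &Image,
                             const std::string &ComputeUnitKind) {
    std::lock_guard<std::mutex> Lock(CacheMutex); // serialize cache accesses
    auto &PerKind = Cache[ComputeUnitKind];       // one sub-map per kind
    auto It = PerKind.find(&Image);
    if (It != PerKind.end())
      return *It->second;                         // cache hit: no re-JIT
    auto Compiled = std::make_unique<std::string>(
        "jitted(" + Image.Bitcode + ") for " + ComputeUnitKind);
    return *(PerKind[&Image] = std::move(Compiled)); // miss: compile and store
  }

private:
  std::mutex CacheMutex;
  // ComputeUnitKind -> (image address -> JITed result), mirroring ComputeUnitMap.
  std::map<std::string,
           std::map<const FakeImage *, std::unique_ptr<std::string>>>
      Cache;
};

int main() {
  SimpleJITCache JIT;
  FakeImage Img{"device bitcode"};
  std::cout << JIT.process(Img, "gfx90a") << "\n"; // compiles
  std::cout << JIT.process(Img, "gfx90a") << "\n"; // served from the cache
  std::cout << JIT.process(Img, "sm_80") << "\n";  // new kind: compiles again
}

Keying the cache by (ComputeUnitKind, image address) means a second device
of the same kind reuses the first device's JIT result, while a device of a
different kind triggers a fresh compile.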
jdoerfert committed Jan 15, 2023
1 parent 158aa99 commit f8e094b
Showing 7 changed files with 204 additions and 174 deletions.
11 changes: 6 additions & 5 deletions openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -1530,7 +1530,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
char GPUName[64];
if (auto Err = getDeviceAttr(HSA_AGENT_INFO_NAME, GPUName))
return Err;
Arch = GPUName;
ComputeUnitKind = GPUName;

// Get the wavefront size.
uint32_t WavefrontSize = 0;
@@ -1669,7 +1669,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
INFO(OMP_INFOTYPE_PLUGIN_KERNEL, getDeviceId(),
"Using `%s` to link JITed amdgcn ouput.", LLDPath.c_str());

std::string MCPU = "-plugin-opt=mcpu=" + getArch();
std::string MCPU = "-plugin-opt=mcpu=" + getComputeUnitKind();

StringRef Args[] = {LLDPath,
"-flavor",
@@ -1692,7 +1692,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
MemoryBuffer::getFileOrSTDIN(LinkerOutputFilePath.data()).get());
}

std::string getArch() const override { return Arch; }
/// See GenericDeviceTy::getComputeUnitKind().
std::string getComputeUnitKind() const override { return ComputeUnitKind; }

/// Allocate and construct an AMDGPU kernel.
Expected<GenericKernelTy *>
@@ -2096,7 +2097,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
hsa_agent_t Agent;

/// The GPU architecture.
std::string Arch;
std::string ComputeUnitKind;

/// Reference to the host device.
AMDHostDeviceTy &HostDevice;
@@ -2244,7 +2245,7 @@ struct AMDGPUGlobalHandlerTy final : public GenericGlobalHandlerTy {
/// Class implementing the AMDGPU-specific functionalities of the plugin.
struct AMDGPUPluginTy final : public GenericPluginTy {
/// Create an AMDGPU plugin and initialize the AMDGPU driver.
AMDGPUPluginTy() : GenericPluginTy(), HostDevice(nullptr) {}
AMDGPUPluginTy() : GenericPluginTy(getTripleArch()), HostDevice(nullptr) {}

/// This class should not be copied.
AMDGPUPluginTy(const AMDGPUPluginTy &) = delete;
184 changes: 79 additions & 105 deletions openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp
@@ -11,11 +11,11 @@
#include "JIT.h"
#include "Debug.h"

#include "PluginInterface.h"
#include "Utilities.h"
#include "omptarget.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/IR/LLVMContext.h"
@@ -28,7 +28,6 @@
#include "llvm/Object/IRObjectFile.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
@@ -39,15 +38,23 @@
#include "llvm/Target/TargetOptions.h"

#include <mutex>
#include <shared_mutex>
#include <system_error>

using namespace llvm;
using namespace llvm::object;
using namespace omp;
using namespace omp::target;

static codegen::RegisterCodeGenFlags RCGF;

namespace {

/// A map from a bitcode image start address to its corresponding triple. If the
/// image is not in the map, it is not a bitcode image.
DenseMap<void *, Triple::ArchType> BitcodeImageMap;
std::shared_mutex BitcodeImageMapMutex;

std::once_flag InitFlag;

void init(Triple TT) {
@@ -70,10 +77,8 @@ void init(Triple TT) {
JITTargetInitialized = true;
}
#endif
if (!JITTargetInitialized) {
FAILURE_MESSAGE("unsupported JIT target: %s\n", TT.str().c_str());
abort();
}
if (!JITTargetInitialized)
return;

// Initialize passes
PassRegistry &Registry = *PassRegistry::getPassRegistry();
@@ -125,9 +130,9 @@ createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
return std::move(Mod);
}
Expected<std::unique_ptr<Module>>
createModuleFromImage(__tgt_device_image *Image, LLVMContext &Context) {
StringRef Data((const char *)Image->ImageStart,
target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
createModuleFromImage(const __tgt_device_image &Image, LLVMContext &Context) {
StringRef Data((const char *)Image.ImageStart,
target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
return createModuleFromMemoryBuffer(MB, Context);
@@ -192,44 +197,11 @@ createTargetMachine(Module &M, std::string CPU, unsigned OptLevel) {
return std::move(TM);
}

///
class JITEngine {
public:
JITEngine(Triple::ArchType TA, std::string MCpu)
: TT(Triple::getArchTypeName(TA)), CPU(MCpu),
ReplacementModuleFileName("LIBOMPTARGET_JIT_REPLACEMENT_MODULE"),
PreOptIRModuleFileName("LIBOMPTARGET_JIT_PRE_OPT_IR_MODULE"),
PostOptIRModuleFileName("LIBOMPTARGET_JIT_POST_OPT_IR_MODULE") {
std::call_once(InitFlag, init, TT);
}

/// Run jit compilation. It is expected to get a memory buffer containing the
/// generated device image that could be loaded to the device directly.
Expected<std::unique_ptr<MemoryBuffer>>
run(__tgt_device_image *Image, unsigned OptLevel,
jit::PostProcessingFn PostProcessing);

private:
/// Run backend, which contains optimization and code generation.
Expected<std::unique_ptr<MemoryBuffer>> backend(Module &M, unsigned OptLevel);

/// Run optimization pipeline.
void opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
unsigned OptLevel);

/// Run code generation.
void codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
raw_pwrite_stream &OS);

LLVMContext Context;
const Triple TT;
const std::string CPU;
} // namespace

/// Control environment variables.
target::StringEnvar ReplacementModuleFileName;
target::StringEnvar PreOptIRModuleFileName;
target::StringEnvar PostOptIRModuleFileName;
};
JITEngine::JITEngine(Triple::ArchType TA) : TT(Triple::getArchTypeName(TA)) {
std::call_once(InitFlag, init, TT);
}

void JITEngine::opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
unsigned OptLevel) {
@@ -274,18 +246,19 @@ void JITEngine::codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII,
PM.run(M);
}

Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
unsigned OptLevel) {
Expected<std::unique_ptr<MemoryBuffer>>
JITEngine::backend(Module &M, const std::string &ComputeUnitKind,
unsigned OptLevel) {

auto RemarksFileOrErr = setupLLVMOptimizationRemarks(
Context, /* RemarksFilename */ "", /* RemarksPasses */ "",
M.getContext(), /* RemarksFilename */ "", /* RemarksPasses */ "",
/* RemarksFormat */ "", /* RemarksWithHotness */ false);
if (Error E = RemarksFileOrErr.takeError())
return std::move(E);
if (*RemarksFileOrErr)
(*RemarksFileOrErr)->keep();

auto TMOrErr = createTargetMachine(M, CPU, OptLevel);
auto TMOrErr = createTargetMachine(M, ComputeUnitKind, OptLevel);
if (!TMOrErr)
return TMOrErr.takeError();

@@ -323,14 +296,23 @@ Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
return MemoryBuffer::getMemBufferCopy(OS.str());
}

Expected<std::unique_ptr<MemoryBuffer>>
JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
jit::PostProcessingFn PostProcessing) {
Expected<const __tgt_device_image *>
JITEngine::compile(const __tgt_device_image &Image,
const std::string &ComputeUnitKind,
PostProcessingFn PostProcessing) {
std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);

// Check if we JITed this image for the given compute unit kind before.
ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
return JITedImage;

Module *Mod = nullptr;
// Check if the user replaces the module at runtime or we read it from the
// image.
// TODO: Allow the user to specify images per device (Arch + ComputeUnitKind).
if (!ReplacementModuleFileName.isPresent()) {
auto ModOrErr = createModuleFromImage(Image, Context);
auto ModOrErr = createModuleFromImage(Image, CUI.Context);
if (!ModOrErr)
return ModOrErr.takeError();
Mod = ModOrErr->release();
@@ -341,44 +323,65 @@ JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
return createStringError(MBOrErr.getError(),
"Could not read replacement module from %s\n",
ReplacementModuleFileName.get().c_str());
auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), Context);
auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), CUI.Context);
if (!ModOrErr)
return ModOrErr.takeError();
Mod = ModOrErr->release();
}

auto MBOrError = backend(*Mod, OptLevel);
auto MBOrError = backend(*Mod, ComputeUnitKind, JITOptLevel);
if (!MBOrError)
return MBOrError.takeError();

return PostProcessing(std::move(*MBOrError));
auto ImageMBOrErr = PostProcessing(std::move(*MBOrError));
if (!ImageMBOrErr)
return ImageMBOrErr.takeError();

CUI.JITImages.push_back(std::move(*ImageMBOrErr));
__tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image];
JITedImage = new __tgt_device_image();
*JITedImage = Image;

auto &ImageMB = CUI.JITImages.back();

JITedImage->ImageStart = (void *)ImageMB->getBufferStart();
JITedImage->ImageEnd = (void *)ImageMB->getBufferEnd();

return JITedImage;
}

/// A map from a bitcode image start address to its corresponding triple. If the
/// image is not in the map, it is not a bitcode image.
DenseMap<void *, Triple::ArchType> BitcodeImageMap;
Expected<const __tgt_device_image *>
JITEngine::process(const __tgt_device_image &Image,
target::plugin::GenericDeviceTy &Device) {
const std::string &ComputeUnitKind = Device.getComputeUnitKind();

/// Output images generated from LLVM backend.
SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
PostProcessingFn PostProcessing = [&Device](std::unique_ptr<MemoryBuffer> MB)
-> Expected<std::unique_ptr<MemoryBuffer>> {
return Device.doJITPostProcessing(std::move(MB));
};

/// A list of __tgt_device_image images.
std::list<__tgt_device_image> TgtImages;
} // namespace
{
std::shared_lock<std::shared_mutex> SharedLock(BitcodeImageMapMutex);
auto Itr = BitcodeImageMap.find(Image.ImageStart);
if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
return compile(Image, ComputeUnitKind, PostProcessing);
}

namespace llvm {
namespace omp {
namespace jit {
bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
return &Image;
}

bool JITEngine::checkBitcodeImage(const __tgt_device_image &Image) {
TimeTraceScope TimeScope("Check bitcode image");
std::lock_guard<std::shared_mutex> Lock(BitcodeImageMapMutex);

{
auto Itr = BitcodeImageMap.find(Image->ImageStart);
if (Itr != BitcodeImageMap.end() && Itr->second == TA)
auto Itr = BitcodeImageMap.find(Image.ImageStart);
if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
return true;
}

StringRef Data(reinterpret_cast<const char *>(Image->ImageStart),
target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
StringRef Data(reinterpret_cast<const char *>(Image.ImageStart),
target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
if (!MB)
@@ -391,37 +394,8 @@ bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
}

auto ActualTriple = FOrErr->TheReader.getTargetTriple();
auto BitcodeTA = Triple(ActualTriple).getArch();
BitcodeImageMap[Image.ImageStart] = BitcodeTA;

if (Triple(ActualTriple).getArch() == TA) {
BitcodeImageMap[Image->ImageStart] = TA;
return true;
}

return false;
return BitcodeTA == TT.getArch();
}

Expected<__tgt_device_image *> compile(__tgt_device_image *Image,
Triple::ArchType TA, std::string MCPU,
unsigned OptLevel,
PostProcessingFn PostProcessing) {
JITEngine J(TA, MCPU);

auto ImageMBOrErr = J.run(Image, OptLevel, PostProcessing);
if (!ImageMBOrErr)
return ImageMBOrErr.takeError();

JITImages.push_back(std::move(*ImageMBOrErr));
TgtImages.push_back(*Image);

auto &ImageMB = JITImages.back();
auto *NewImage = &TgtImages.back();

NewImage->ImageStart = (void *)ImageMB->getBufferStart();
NewImage->ImageEnd = (void *)ImageMB->getBufferEnd();

return NewImage;
}

} // namespace jit
} // namespace omp
} // namespace llvm
