208 changes: 100 additions & 108 deletions openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp
@@ -11,11 +11,11 @@
#include "JIT.h"
#include "Debug.h"

#include "PluginInterface.h"
#include "Utilities.h"
#include "omptarget.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/IR/LLVMContext.h"
@@ -28,7 +28,6 @@
#include "llvm/Object/IRObjectFile.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
@@ -39,15 +38,23 @@
#include "llvm/Target/TargetOptions.h"

#include <mutex>
#include <shared_mutex>
#include <system_error>

using namespace llvm;
using namespace llvm::object;
using namespace omp;
using namespace omp::target;

static codegen::RegisterCodeGenFlags RCGF;

namespace {

/// A map from a bitcode image start address to its corresponding triple. If the
/// image is not in the map, it is not a bitcode image.
DenseMap<void *, Triple::ArchType> BitcodeImageMap;
std::shared_mutex BitcodeImageMapMutex;

std::once_flag InitFlag;

void init(Triple TT) {
@@ -70,10 +77,8 @@ void init(Triple TT) {
JITTargetInitialized = true;
}
#endif
if (!JITTargetInitialized) {
FAILURE_MESSAGE("unsupported JIT target: %s\n", TT.str().c_str());
abort();
}
if (!JITTargetInitialized)
return;

// Initialize passes
PassRegistry &Registry = *PassRegistry::getPassRegistry();
@@ -125,9 +130,9 @@ createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
return std::move(Mod);
}
Expected<std::unique_ptr<Module>>
createModuleFromImage(__tgt_device_image *Image, LLVMContext &Context) {
StringRef Data((const char *)Image->ImageStart,
(char *)Image->ImageEnd - (char *)Image->ImageStart);
createModuleFromImage(const __tgt_device_image &Image, LLVMContext &Context) {
StringRef Data((const char *)Image.ImageStart,
target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
return createModuleFromMemoryBuffer(MB, Context);
@@ -192,44 +197,11 @@ createTargetMachine(Module &M, std::string CPU, unsigned OptLevel) {
return std::move(TM);
}

///
class JITEngine {
public:
JITEngine(Triple::ArchType TA, std::string MCpu)
: TT(Triple::getArchTypeName(TA)), CPU(MCpu),
ReplacementModuleFileName("LIBOMPTARGET_JIT_REPLACEMENT_MODULE"),
PreOptIRModuleFileName("LIBOMPTARGET_JIT_PRE_OPT_IR_MODULE"),
PostOptIRModuleFileName("LIBOMPTARGET_JIT_POST_OPT_IR_MODULE") {
std::call_once(InitFlag, init, TT);
}

/// Run jit compilation. It is expected to get a memory buffer containing the
/// generated device image that could be loaded to the device directly.
Expected<std::unique_ptr<MemoryBuffer>>
run(__tgt_device_image *Image, unsigned OptLevel,
jit::PostProcessingFn PostProcessing);

private:
/// Run backend, which contains optimization and code generation.
Expected<std::unique_ptr<MemoryBuffer>> backend(Module &M, unsigned OptLevel);

/// Run optimization pipeline.
void opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
unsigned OptLevel);

/// Run code generation.
void codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
raw_pwrite_stream &OS);

LLVMContext Context;
const Triple TT;
const std::string CPU;

/// Control environment variables.
target::StringEnvar ReplacementModuleFileName;
target::StringEnvar PreOptIRModuleFileName;
target::StringEnvar PostOptIRModuleFileName;
};

} // namespace

JITEngine::JITEngine(Triple::ArchType TA) : TT(Triple::getArchTypeName(TA)) {
std::call_once(InitFlag, init, TT);
}

void JITEngine::opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
unsigned OptLevel) {
@@ -274,18 +246,19 @@ void JITEngine::codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII,
PM.run(M);
}

Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
unsigned OptLevel) {
Expected<std::unique_ptr<MemoryBuffer>>
JITEngine::backend(Module &M, const std::string &ComputeUnitKind,
unsigned OptLevel) {

auto RemarksFileOrErr = setupLLVMOptimizationRemarks(
Context, /* RemarksFilename */ "", /* RemarksPasses */ "",
M.getContext(), /* RemarksFilename */ "", /* RemarksPasses */ "",
/* RemarksFormat */ "", /* RemarksWithHotness */ false);
if (Error E = RemarksFileOrErr.takeError())
return std::move(E);
if (*RemarksFileOrErr)
(*RemarksFileOrErr)->keep();

auto TMOrErr = createTargetMachine(M, CPU, OptLevel);
auto TMOrErr = createTargetMachine(M, ComputeUnitKind, OptLevel);
if (!TMOrErr)
return TMOrErr.takeError();

@@ -302,7 +275,8 @@ Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
M.print(FD, nullptr);
}

opt(TM.get(), &TLII, M, OptLevel);
if (!JITSkipOpt)
opt(TM.get(), &TLII, M, OptLevel);

if (PostOptIRModuleFileName.isPresent()) {
std::error_code EC;
@@ -324,13 +298,26 @@ Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
}

Expected<std::unique_ptr<MemoryBuffer>>
JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
jit::PostProcessingFn PostProcessing) {
JITEngine::getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
const std::string &ComputeUnitKind) {

// Check if the user replaces the module at runtime with a finished object.
if (ReplacementObjectFileName.isPresent()) {
auto MBOrErr =
MemoryBuffer::getFileOrSTDIN(ReplacementObjectFileName.get());
if (!MBOrErr)
return createStringError(MBOrErr.getError(),
"Could not read replacement obj from %s\n",
ReplacementObjectFileName.get().c_str());
return std::move(*MBOrErr);
}

Module *Mod = nullptr;
// Check if the user replaces the module at runtime or we read it from the
// image.
// TODO: Allow the user to specify images per device (Arch + ComputeUnitKind).
if (!ReplacementModuleFileName.isPresent()) {
auto ModOrErr = createModuleFromImage(Image, Context);
auto ModOrErr = createModuleFromImage(Image, Ctx);
if (!ModOrErr)
return ModOrErr.takeError();
Mod = ModOrErr->release();
@@ -341,45 +328,79 @@ JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
return createStringError(MBOrErr.getError(),
"Could not read replacement module from %s\n",
ReplacementModuleFileName.get().c_str());
auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), Context);
auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), Ctx);
if (!ModOrErr)
return ModOrErr.takeError();
Mod = ModOrErr->release();
}

auto MBOrError = backend(*Mod, OptLevel);
if (!MBOrError)
return MBOrError.takeError();
return backend(*Mod, ComputeUnitKind, JITOptLevel);
}

Expected<const __tgt_device_image *>
JITEngine::compile(const __tgt_device_image &Image,
const std::string &ComputeUnitKind,
PostProcessingFn PostProcessing) {
std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);

// Check if we JITed this image for the given compute unit kind before.
ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
return JITedImage;

auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind);
if (!ObjMBOrErr)
return ObjMBOrErr.takeError();

return PostProcessing(std::move(*MBOrError));
auto ImageMBOrErr = PostProcessing(std::move(*ObjMBOrErr));
if (!ImageMBOrErr)
return ImageMBOrErr.takeError();

CUI.JITImages.push_back(std::move(*ImageMBOrErr));
__tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image];
JITedImage = new __tgt_device_image();
*JITedImage = Image;

auto &ImageMB = CUI.JITImages.back();

JITedImage->ImageStart = (void *)ImageMB->getBufferStart();
JITedImage->ImageEnd = (void *)ImageMB->getBufferEnd();

return JITedImage;
}

/// A map from a bitcode image start address to its corresponding triple. If the
/// image is not in the map, it is not a bitcode image.
DenseMap<void *, Triple::ArchType> BitcodeImageMap;

/// Output images generated from LLVM backend.
SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;

/// A list of __tgt_device_image images.
std::list<__tgt_device_image> TgtImages;
} // namespace

Expected<const __tgt_device_image *>
JITEngine::process(const __tgt_device_image &Image,
target::plugin::GenericDeviceTy &Device) {
const std::string &ComputeUnitKind = Device.getComputeUnitKind();

PostProcessingFn PostProcessing = [&Device](std::unique_ptr<MemoryBuffer> MB)
-> Expected<std::unique_ptr<MemoryBuffer>> {
return Device.doJITPostProcessing(std::move(MB));
};

{
std::shared_lock<std::shared_mutex> SharedLock(BitcodeImageMapMutex);
auto Itr = BitcodeImageMap.find(Image.ImageStart);
if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
return compile(Image, ComputeUnitKind, PostProcessing);
}

return &Image;
}

namespace llvm {
namespace omp {
namespace jit {
bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
bool JITEngine::checkBitcodeImage(const __tgt_device_image &Image) {
TimeTraceScope TimeScope("Check bitcode image");
std::lock_guard<std::shared_mutex> Lock(BitcodeImageMapMutex);

{
auto Itr = BitcodeImageMap.find(Image->ImageStart);
if (Itr != BitcodeImageMap.end() && Itr->second == TA)
auto Itr = BitcodeImageMap.find(Image.ImageStart);
if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
return true;
}

StringRef Data(reinterpret_cast<const char *>(Image->ImageStart),
reinterpret_cast<char *>(Image->ImageEnd) -
reinterpret_cast<char *>(Image->ImageStart));
StringRef Data(reinterpret_cast<const char *>(Image.ImageStart),
target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
if (!MB)
@@ -392,37 +413,8 @@ bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
}

auto ActualTriple = FOrErr->TheReader.getTargetTriple();
auto BitcodeTA = Triple(ActualTriple).getArch();
BitcodeImageMap[Image.ImageStart] = BitcodeTA;

if (Triple(ActualTriple).getArch() == TA) {
BitcodeImageMap[Image->ImageStart] = TA;
return true;
}

return false;
return BitcodeTA == TT.getArch();
}

Expected<__tgt_device_image *> compile(__tgt_device_image *Image,
Triple::ArchType TA, std::string MCPU,
unsigned OptLevel,
PostProcessingFn PostProcessing) {
JITEngine J(TA, MCPU);

auto ImageMBOrErr = J.run(Image, OptLevel, PostProcessing);
if (!ImageMBOrErr)
return ImageMBOrErr.takeError();

JITImages.push_back(std::move(*ImageMBOrErr));
TgtImages.push_back(*Image);

auto &ImageMB = JITImages.back();
auto *NewImage = &TgtImages.back();

NewImage->ImageStart = (void *)ImageMB->getBufferStart();
NewImage->ImageEnd = (void *)ImageMB->getBufferEnd();

return NewImage;
}

} // namespace jit
} // namespace omp
} // namespace llvm
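Taken together, the JIT.cpp changes above replace the old free-standing jit::compile and jit::checkBitcodeImage entry points with a per-plugin JITEngine. The following is a minimal sketch of how a plugin front end is now expected to route an incoming image through that engine; it assumes only the interfaces introduced in this patch, and the loadJITedOrPlainImage helper name and comments are illustrative, not part of the patch.

// Illustrative sketch only (not part of the patch).
#include "JIT.h"
#include "PluginInterface.h"

using namespace llvm;
using namespace llvm::omp::target;

static Expected<const __tgt_device_image *>
loadJITedOrPlainImage(JITEngine &JIT, plugin::GenericDeviceTy &Device,
                      const __tgt_device_image &Image) {
  // process() returns &Image unchanged for non-bitcode images; for bitcode
  // images it compiles once per compute unit kind (e.g. sm_80) and caches
  // the resulting __tgt_device_image for subsequent calls.
  return JIT.process(Image, Device);
}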
115 changes: 96 additions & 19 deletions openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.h
@@ -11,12 +11,20 @@
#ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_JIT_H
#define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_JIT_H

#include "Utilities.h"

#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"
#include "llvm/Target/TargetMachine.h"

#include <functional>
#include <memory>
#include <shared_mutex>
#include <string>

struct __tgt_device_image;
@@ -25,25 +33,94 @@ namespace llvm {
class MemoryBuffer;

namespace omp {
namespace jit {

/// Function type for a callback that will be called after the backend is
/// called.
using PostProcessingFn = std::function<Expected<std::unique_ptr<MemoryBuffer>>(
std::unique_ptr<MemoryBuffer>)>;

/// Check if \p Image contains bitcode with triple \p Triple.
bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA);

/// Compile the bitcode image \p Image and generate the binary image that can be
/// loaded to the target device of the triple \p Triple architecture \p MCpu. \p
/// PostProcessing will be called after codegen to handle cases such as assember
/// as an external tool.
Expected<__tgt_device_image *> compile(__tgt_device_image *Image,
Triple::ArchType TA, std::string MCpu,
unsigned OptLevel,
PostProcessingFn PostProcessing);
} // namespace jit
namespace target {
namespace plugin {
struct GenericDeviceTy;
} // namespace plugin

/// The JIT infrastructure and caching mechanism.
struct JITEngine {
/// Function type for a callback that will be called after the backend is
/// called.
using PostProcessingFn =
std::function<Expected<std::unique_ptr<MemoryBuffer>>(
std::unique_ptr<MemoryBuffer>)>;

JITEngine(Triple::ArchType TA);

/// Run JIT compilation if \p Image is a bitcode image; otherwise simply
/// return \p Image. The returned device image can be loaded onto the
/// device directly.
Expected<const __tgt_device_image *>
process(const __tgt_device_image &Image,
target::plugin::GenericDeviceTy &Device);

/// Return true if \p Image is a bitcode image that can be JITed for the given
/// architecture.
bool checkBitcodeImage(const __tgt_device_image &Image);

private:
/// Compile the bitcode image \p Image and generate a binary image that can
/// be loaded onto the target device for the given \p ComputeUnitKind.
/// \p PostProcessing will be called after codegen to handle cases such as
/// running an external assembler.
Expected<const __tgt_device_image *>
compile(const __tgt_device_image &Image, const std::string &ComputeUnitKind,
PostProcessingFn PostProcessing);

/// Create or retrieve the object image file from the file system or via
/// compilation of the \p Image.
Expected<std::unique_ptr<MemoryBuffer>>
getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
const std::string &ComputeUnitKind);

/// Run backend, which contains optimization and code generation.
Expected<std::unique_ptr<MemoryBuffer>>
backend(Module &M, const std::string &ComputeUnitKind, unsigned OptLevel);

/// Run optimization pipeline.
void opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
unsigned OptLevel);

/// Run code generation.
void codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
raw_pwrite_stream &OS);

/// The target triple used by the JIT.
const Triple TT;

struct ComputeUnitInfo {
/// LLVM Context in which the modules will be constructed.
LLVMContext Context;

/// Output images generated from LLVM backend.
SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;

/// A map of embedded IR images to JITed images.
DenseMap<const __tgt_device_image *, __tgt_device_image *> TgtImageMap;
};

/// Map from (march) "CPUs" (e.g., sm_80 or gfx90a), which we call compute
/// units since they are not CPUs, to the image information we cached for them.
StringMap<ComputeUnitInfo> ComputeUnitMap;
std::mutex ComputeUnitMapMutex;

/// Control environment variables.
target::StringEnvar ReplacementObjectFileName =
target::StringEnvar("LIBOMPTARGET_JIT_REPLACEMENT_OBJECT");
target::StringEnvar ReplacementModuleFileName =
target::StringEnvar("LIBOMPTARGET_JIT_REPLACEMENT_MODULE");
target::StringEnvar PreOptIRModuleFileName =
target::StringEnvar("LIBOMPTARGET_JIT_PRE_OPT_IR_MODULE");
target::StringEnvar PostOptIRModuleFileName =
target::StringEnvar("LIBOMPTARGET_JIT_POST_OPT_IR_MODULE");
target::UInt32Envar JITOptLevel =
target::UInt32Envar("LIBOMPTARGET_JIT_OPT_LEVEL", 3);
target::BoolEnvar JITSkipOpt =
target::BoolEnvar("LIBOMPTARGET_JIT_SKIP_OPT", false);
};

} // namespace target
} // namespace omp
} // namespace llvm
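The environment-variable members above are the user-facing controls for the new JIT path. Below is a minimal sketch of how they could be inspected, assuming only the Envar operations already exercised in this patch (construction with a name and optional default, isPresent(), get(), and the implicit conversions used by JITOptLevel and JITSkipOpt); the printJITConfig helper is illustrative, not part of the patch.

// Illustrative sketch only (not part of the patch).
#include "Utilities.h"
#include <cstdio>

using namespace llvm::omp::target;

// Report the JIT configuration exactly as the JITEngine members would see it.
static void printJITConfig() {
  UInt32Envar OptLevel("LIBOMPTARGET_JIT_OPT_LEVEL", 3);
  BoolEnvar SkipOpt("LIBOMPTARGET_JIT_SKIP_OPT", false);
  StringEnvar ReplacementObj("LIBOMPTARGET_JIT_REPLACEMENT_OBJECT");

  std::printf("JIT opt level: %u, skip opt: %s\n", (unsigned)OptLevel,
              SkipOpt ? "yes" : "no");
  if (ReplacementObj.isPresent())
    std::printf("Replacement object file: %s\n", ReplacementObj.get().c_str());
}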

openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -212,12 +212,22 @@ Error GenericDeviceTy::deinit() {

Expected<__tgt_target_table *>
GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
const __tgt_device_image *TgtImage) {
DP("Load data from image " DPxMOD "\n", DPxPTR(TgtImage->ImageStart));
const __tgt_device_image *InputTgtImage) {
assert(InputTgtImage && "Expected non-null target image");
DP("Load data from image " DPxMOD "\n", DPxPTR(InputTgtImage->ImageStart));

auto PostJITImageOrErr = Plugin.getJIT().process(*InputTgtImage, *this);
if (!PostJITImageOrErr) {
auto Err = PostJITImageOrErr.takeError();
REPORT("Failure to jit IR image %p on device %d: %s\n", InputTgtImage,
DeviceId, toString(std::move(Err)).data());
return nullptr;
}

// Load the binary and allocate the image object. Use the next available id
// for the image id, which is the number of previously loaded images.
auto ImageOrErr = loadBinaryImpl(TgtImage, LoadedImages.size());
auto ImageOrErr =
loadBinaryImpl(PostJITImageOrErr.get(), LoadedImages.size());
if (!ImageOrErr)
return ImageOrErr.takeError();

@@ -668,7 +678,7 @@ int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *TgtImage) {
if (elf_check_machine(TgtImage, Plugin::get().getMagicElfBits()))
return true;

return jit::checkBitcodeImage(TgtImage, Plugin::get().getTripleArch());
return Plugin::get().getJIT().checkBitcodeImage(*TgtImage);
}

int32_t __tgt_rtl_is_valid_binary_info(__tgt_device_image *TgtImage,
Expand Down Expand Up @@ -745,34 +755,6 @@ __tgt_target_table *__tgt_rtl_load_binary(int32_t DeviceId,
GenericPluginTy &Plugin = Plugin::get();
GenericDeviceTy &Device = Plugin.getDevice(DeviceId);

// If it is a bitcode image, we have to jit the binary image before loading to
// the device.
{
// TODO: Move this (at least the environment variable) into the JIT.h.
UInt32Envar JITOptLevel("LIBOMPTARGET_JIT_OPT_LEVEL", 3);
Triple::ArchType TA = Plugin.getTripleArch();
std::string Arch = Device.getArch();

jit::PostProcessingFn PostProcessing =
[&Device](std::unique_ptr<MemoryBuffer> MB)
-> Expected<std::unique_ptr<MemoryBuffer>> {
return Device.doJITPostProcessing(std::move(MB));
};

if (jit::checkBitcodeImage(TgtImage, TA)) {
auto TgtImageOrErr =
jit::compile(TgtImage, TA, Arch, JITOptLevel, PostProcessing);
if (!TgtImageOrErr) {
auto Err = TgtImageOrErr.takeError();
REPORT("Failure to jit binary image from bitcode image %p on device "
"%d: %s\n",
TgtImage, DeviceId, toString(std::move(Err)).data());
return nullptr;
}

TgtImage = *TgtImageOrErr;
}
}

auto TableOrErr = Device.loadBinary(Plugin, TgtImage);
if (!TableOrErr) {
openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
@@ -21,6 +21,7 @@
#include "Debug.h"
#include "DeviceEnvironment.h"
#include "GlobalHandler.h"
#include "JIT.h"
#include "MemoryManager.h"
#include "Utilities.h"
#include "omptarget.h"
@@ -37,6 +38,7 @@
namespace llvm {
namespace omp {
namespace target {

namespace plugin {

struct GenericPluginTy;
@@ -132,7 +134,7 @@ class DeviceImageTy {

/// Get the image size.
size_t getSize() const {
return ((char *)TgtImage->ImageEnd) - ((char *)TgtImage->ImageStart);
return getPtrDiff(TgtImage->ImageEnd, TgtImage->ImageStart);
}

/// Get a memory buffer reference to the whole image.
@@ -378,10 +380,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
}
uint32_t getDynamicMemorySize() const { return OMPX_SharedMemorySize; }

/// Get target architecture.
virtual std::string getArch() const {
return "unknown";
}
/// Get target compute unit kind (e.g., sm_80, or gfx908).
virtual std::string getComputeUnitKind() const { return "unknown"; }

/// Post processing after jit backend. The ownership of \p MB will be taken.
virtual Expected<std::unique_ptr<MemoryBuffer>>
@@ -469,7 +469,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
--It;

// Evaluate whether the buffer is contained in the pinned allocation.
return ((const char *)It->first + It->second > (const char *)Buffer);
return (advanceVoidPtr(It->first, It->second) > (const char *)Buffer);
}

/// Return the execution mode used for kernel \p Name.
@@ -513,8 +513,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
struct GenericPluginTy {

/// Construct a plugin instance.
GenericPluginTy()
: RequiresFlags(OMP_REQ_UNDEFINED), GlobalHandler(nullptr) {}
GenericPluginTy(Triple::ArchType TA)
: RequiresFlags(OMP_REQ_UNDEFINED), GlobalHandler(nullptr), JIT(TA) {}

virtual ~GenericPluginTy() {}

@@ -543,9 +543,7 @@ struct GenericPluginTy {
virtual uint16_t getMagicElfBits() const = 0;

/// Get the target triple of this plugin.
virtual Triple::ArchType getTripleArch() const {
return Triple::ArchType::UnknownArch;
}
virtual Triple::ArchType getTripleArch() const = 0;

/// Allocate a structure using the internal allocator.
template <typename Ty> Ty *allocate() {
@@ -558,6 +556,10 @@
return *GlobalHandler;
}

/// Get the reference to the JIT used for all devices connected to this
/// plugin.
JITEngine &getJIT() { return JIT; }

/// Get the OpenMP requires flags set for this plugin.
int64_t getRequiresFlags() const { return RequiresFlags; }

Expand Down Expand Up @@ -609,6 +611,9 @@ struct GenericPluginTy {

/// Internal allocator for different structures.
BumpPtrAllocator Allocator;

/// The JIT engine shared by all devices connected to this plugin.
JITEngine JIT;
};

/// Class for simplifying the getter operation of the plugin. Anywhere on the
8 changes: 5 additions & 3 deletions openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
@@ -784,8 +784,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
return Plugin::check(Res, "Error in cuDeviceGetAttribute: %s");
}

/// See GenericDeviceTy::getArch().
std::string getArch() const override { return ComputeCapability.str(); }
/// See GenericDeviceTy::getComputeUnitKind().
std::string getComputeUnitKind() const override {
return ComputeCapability.str();
}

private:
using CUDAStreamManagerTy = GenericDeviceResourceManagerTy<CUDAStreamRef>;
@@ -867,7 +869,7 @@ class CUDAGlobalHandlerTy final : public GenericGlobalHandlerTy {
/// Class implementing the CUDA-specific functionalities of the plugin.
struct CUDAPluginTy final : public GenericPluginTy {
/// Create a CUDA plugin.
CUDAPluginTy() : GenericPluginTy() {}
CUDAPluginTy() : GenericPluginTy(getTripleArch()) {}

/// This class should not be copied.
CUDAPluginTy(const CUDAPluginTy &) = delete;
openmp/libomptarget/plugins-nextgen/generic-elf-64bit/src/rtl.cpp
@@ -340,7 +340,7 @@ class GenELF64GlobalHandlerTy final : public GenericGlobalHandlerTy {
/// Class implementing the plugin functionalities for GenELF64.
struct GenELF64PluginTy final : public GenericPluginTy {
/// Create the GenELF64 plugin.
GenELF64PluginTy() : GenericPluginTy() {}
GenELF64PluginTy() : GenericPluginTy(getTripleArch()) {}

/// This class should not be copied.
GenELF64PluginTy(const GenELF64PluginTy &) = delete;