Skip to content

Commit

Permalink
[OpenMP] [OMPT] [7/8] Invoke tool-supplied callbacks before and after…
Browse files Browse the repository at this point in the history
… target launch and data transfer operations

Implemented RAII objects, initialized at target entry points, that
invoke tool-supplied callbacks. Updated status of target callbacks as
implemented.

Depends on D127365

Patch from John Mellor-Crummey <johnmc@rice.edu>
With contributions from:
Dhruva Chakrabarti <Dhruva.Chakrabarti@amd.com>
Jan-Patrick Lehr <janpatrick.lehr@amd.com>

Reviewed By: jdoerfert, dhruvachak, jplehr

Differential Revision: https://reviews.llvm.org/D127367
  • Loading branch information
mhalk committed Jul 25, 2023
1 parent 3898107 commit 1dec417
Show file tree
Hide file tree
Showing 14 changed files with 473 additions and 88 deletions.
9 changes: 9 additions & 0 deletions openmp/libomptarget/include/OmptCallback.h
Expand Up @@ -27,6 +27,13 @@
FOREACH_OMPT_NOEMI_EVENT(macro) \
FOREACH_OMPT_EMI_EVENT(macro)

/// Execute \p stmt only when llvm::omp::target::ompt::Initialized is true.
/// The do { ... } while (0) wrapper turns the expansion into one statement,
/// so a use followed by ';' stays well-formed inside an unbraced if/else.
#define performIfOmptInitialized(stmt) \
do { \
if (llvm::omp::target::ompt::Initialized) { \
stmt; \
} \
} while (0)

#define performOmptCallback(CallbackName, ...) \
do { \
if (ompt_callback_##CallbackName##_fn) \
Expand Down Expand Up @@ -89,6 +96,8 @@ extern bool Initialized;
} // namespace omp
} // namespace llvm

#else
#define performIfOmptInitialized(stmt)
#endif // OMPT_SUPPORT

#pragma pop_macro("DEBUG_PREFIX")
Expand Down
99 changes: 54 additions & 45 deletions openmp/libomptarget/src/OmptCallback.cpp
Expand Up @@ -35,12 +35,20 @@ FOREACH_OMPT_NOEMI_EVENT(defineOmptCallback)
FOREACH_OMPT_EMI_EVENT(defineOmptCallback)
#undef defineOmptCallback

/// Thread local state for target region and associated metadata
thread_local llvm::omp::target::ompt::Interface OmptInterface;
/// Forward declaration
class LibomptargetRtlFinalizer;

/// Define function pointers
ompt_get_task_data_t ompt_get_task_data_fn = nullptr;
/// Object that will maintain the RTL finalizer from the plugin
LibomptargetRtlFinalizer *LibraryFinalizer = nullptr;

thread_local Interface llvm::omp::target::ompt::RegionInterface;

bool llvm::omp::target::ompt::Initialized = false;

ompt_get_callback_t llvm::omp::target::ompt::lookupCallbackByCode = nullptr;
ompt_function_lookup_t llvm::omp::target::ompt::lookupCallbackByName = nullptr;
ompt_get_target_task_data_t ompt_get_target_task_data_fn = nullptr;
ompt_get_task_data_t ompt_get_task_data_fn = nullptr;

/// Unique correlation id
static std::atomic<uint64_t> IdCounter(1);
Expand All @@ -51,14 +59,14 @@ static uint64_t createId() { return IdCounter.fetch_add(1); }
/// Create a new correlation id and update the operations id
///
/// Obtains a fresh id from createId(), stores it as the host operation id via
/// setHostOpId() and returns it to the caller.
///
/// NOTE(review): the id is written to both OmptInterface and RegionInterface.
/// This looks like a rename captured with the old and new line side by side;
/// confirm whether only one of the two calls should remain.
static uint64_t createOpId() {
uint64_t NewId = createId();
OmptInterface.setHostOpId(NewId);
RegionInterface.setHostOpId(NewId);
return NewId;
}

/// Create a new correlation id and update the target region id
///
/// Obtains a fresh id from createId(), stores it as the target data value via
/// setTargetDataValue() and returns it to the caller.
///
/// NOTE(review): the id is written to both OmptInterface and RegionInterface.
/// This looks like a rename captured with the old and new line side by side;
/// confirm whether only one of the two calls should remain.
static uint64_t createRegionId() {
uint64_t NewId = createId();
OmptInterface.setTargetDataValue(NewId);
RegionInterface.setTargetDataValue(NewId);
return NewId;
}

Expand All @@ -68,18 +76,19 @@ void Interface::beginTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
if (ompt_callback_target_data_op_emi_fn) {
// HostOpId will be set by the tool. Invoke the tool supplied data op EMI
// callback
ompt_callback_target_data_op_emi_fn(ompt_scope_begin, TargetTaskData,
&TargetData, &TargetRegionOpId,
ompt_target_data_alloc, HstPtrBegin,
DeviceId, /* TgtPtrBegin */ nullptr,
/* TgtDeviceNum */ 0, Size, Code);
ompt_callback_target_data_op_emi_fn(
ompt_scope_begin, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_alloc, HstPtrBegin,
/* SrcDeviceNum */ omp_get_initial_device(), /* TgtPtrBegin */ nullptr,
/* TgtDeviceNum */ DeviceId, Size, Code);
} else if (ompt_callback_target_data_op_fn) {
// HostOpId is set by the runtime
HostOpId = createOpId();
// Invoke the tool supplied data op callback
ompt_callback_target_data_op_fn(
TargetData.value, HostOpId, ompt_target_data_alloc, HstPtrBegin,
DeviceId, /* TgtPtrBegin */ nullptr, /* TgtDeviceNum */ 0, Size, Code);
/* SrcDeviceNum */ omp_get_initial_device(), /* TgtPtrBegin */ nullptr,
/* TgtDeviceNum */ DeviceId, Size, Code);
}
}

Expand All @@ -89,11 +98,11 @@ void Interface::endTargetDataAlloc(int64_t DeviceId, void *HstPtrBegin,
if (ompt_callback_target_data_op_emi_fn) {
// HostOpId will be set by the tool. Invoke the tool supplied data op EMI
// callback
ompt_callback_target_data_op_emi_fn(ompt_scope_end, TargetTaskData,
&TargetData, &TargetRegionOpId,
ompt_target_data_alloc, HstPtrBegin,
DeviceId, /* TgtPtrBegin */ nullptr,
/* TgtDeviceNum */ 0, Size, Code);
ompt_callback_target_data_op_emi_fn(
ompt_scope_end, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_alloc, HstPtrBegin,
/* SrcDeviceNum */ omp_get_initial_device(), /* TgtPtrBegin */ nullptr,
/* TgtDeviceNum */ DeviceId, Size, Code);
}
endTargetDataOperation();
}
Expand All @@ -108,14 +117,16 @@ void Interface::beginTargetDataSubmit(int64_t DeviceId, void *TgtPtrBegin,
ompt_callback_target_data_op_emi_fn(
ompt_scope_begin, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_transfer_to_device, HstPtrBegin,
/* SrcDeviceNum */ 0, TgtPtrBegin, DeviceId, Size, Code);
/* SrcDeviceNum */ omp_get_initial_device(), TgtPtrBegin, DeviceId,
Size, Code);
} else if (ompt_callback_target_data_op_fn) {
// HostOpId is set by the runtime
HostOpId = createOpId();
// Invoke the tool supplied data op callback
ompt_callback_target_data_op_fn(
TargetData.value, HostOpId, ompt_target_data_transfer_to_device,
HstPtrBegin, /* SrcDeviceNum */ 0, TgtPtrBegin, DeviceId, Size, Code);
HstPtrBegin, /* SrcDeviceNum */ omp_get_initial_device(), TgtPtrBegin,
DeviceId, Size, Code);
}
}

Expand All @@ -129,7 +140,8 @@ void Interface::endTargetDataSubmit(int64_t DeviceId, void *TgtPtrBegin,
ompt_callback_target_data_op_emi_fn(
ompt_scope_end, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_transfer_to_device, HstPtrBegin,
/* SrcDeviceNum */ 0, TgtPtrBegin, DeviceId, Size, Code);
/* SrcDeviceNum */ omp_get_initial_device(), TgtPtrBegin, DeviceId,
Size, Code);
}
endTargetDataOperation();
}
Expand All @@ -143,15 +155,15 @@ void Interface::beginTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin,
ompt_callback_target_data_op_emi_fn(
ompt_scope_begin, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_delete, TgtPtrBegin, DeviceId,
/* TgtPtrBegin */ nullptr, /* TgtDeviceNum */ 0, /* Bytes */ 0, Code);
/* TgtPtrBegin */ nullptr, /* TgtDeviceNum */ -1, /* Bytes */ 0, Code);
} else if (ompt_callback_target_data_op_fn) {
// HostOpId is set by the runtime
HostOpId = createOpId();
// Invoke the tool supplied data op callback
ompt_callback_target_data_op_fn(TargetData.value, HostOpId,
ompt_target_data_delete, TgtPtrBegin,
DeviceId, /* TgtPtrBegin */ nullptr,
/* TgtDeviceNum */ 0, /* Bytes */ 0, Code);
/* TgtDeviceNum */ -1, /* Bytes */ 0, Code);
}
}

Expand All @@ -164,7 +176,7 @@ void Interface::endTargetDataDelete(int64_t DeviceId, void *TgtPtrBegin,
ompt_callback_target_data_op_emi_fn(
ompt_scope_end, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_delete, TgtPtrBegin, DeviceId,
/* TgtPtrBegin */ nullptr, /* TgtDeviceNum */ 0, /* Bytes */ 0, Code);
/* TgtPtrBegin */ nullptr, /* TgtDeviceNum */ -1, /* Bytes */ 0, Code);
}
endTargetDataOperation();
}
Expand All @@ -176,19 +188,19 @@ void Interface::beginTargetDataRetrieve(int64_t DeviceId, void *HstPtrBegin,
if (ompt_callback_target_data_op_emi_fn) {
// HostOpId will be set by the tool. Invoke the tool supplied data op EMI
// callback
ompt_callback_target_data_op_emi_fn(ompt_scope_begin, TargetTaskData,
&TargetData, &TargetRegionOpId,
ompt_target_data_transfer_from_device,
TgtPtrBegin, DeviceId, HstPtrBegin,
/* TgtDeviceNum */ 0, Size, Code);
ompt_callback_target_data_op_emi_fn(
ompt_scope_begin, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_transfer_from_device, TgtPtrBegin, DeviceId,
HstPtrBegin,
/* TgtDeviceNum */ omp_get_initial_device(), Size, Code);
} else if (ompt_callback_target_data_op_fn) {
// HostOpId is set by the runtime
HostOpId = createOpId();
// Invoke the tool supplied data op callback
ompt_callback_target_data_op_fn(TargetData.value, HostOpId,
ompt_target_data_transfer_from_device,
TgtPtrBegin, DeviceId, HstPtrBegin,
/* TgtDeviceNum */ 0, Size, Code);
ompt_callback_target_data_op_fn(
TargetData.value, HostOpId, ompt_target_data_transfer_from_device,
TgtPtrBegin, DeviceId, HstPtrBegin,
/* TgtDeviceNum */ omp_get_initial_device(), Size, Code);
}
}

Expand All @@ -199,11 +211,11 @@ void Interface::endTargetDataRetrieve(int64_t DeviceId, void *HstPtrBegin,
if (ompt_callback_target_data_op_emi_fn) {
// HostOpId will be set by the tool. Invoke the tool supplied data op EMI
// callback
ompt_callback_target_data_op_emi_fn(ompt_scope_end, TargetTaskData,
&TargetData, &TargetRegionOpId,
ompt_target_data_transfer_from_device,
TgtPtrBegin, DeviceId, HstPtrBegin,
/* TgtDeviceNum */ 0, Size, Code);
ompt_callback_target_data_op_emi_fn(
ompt_scope_end, TargetTaskData, &TargetData, &TargetRegionOpId,
ompt_target_data_transfer_from_device, TgtPtrBegin, DeviceId,
HstPtrBegin,
/* TgtDeviceNum */ omp_get_initial_device(), Size, Code);
}
endTargetDataOperation();
}
Expand All @@ -230,6 +242,7 @@ void Interface::endTargetSubmit(unsigned int numTeams) {
numTeams);
}
}

void Interface::beginTargetDataEnter(int64_t DeviceId, void *Code) {
beginTargetRegion();
if (ompt_callback_target_emi_fn) {
Expand Down Expand Up @@ -391,14 +404,6 @@ class LibomptargetRtlFinalizer {
llvm::SmallVector<ompt_finalize_t> RtlFinalizationFunctions;
};

/// Object that will maintain the RTL finalizer from the plugin
LibomptargetRtlFinalizer *LibraryFinalizer = nullptr;

bool llvm::omp::target::ompt::Initialized = false;

ompt_get_callback_t llvm::omp::target::ompt::lookupCallbackByCode = nullptr;
ompt_function_lookup_t llvm::omp::target::ompt::lookupCallbackByName = nullptr;

int llvm::omp::target::ompt::initializeLibrary(ompt_function_lookup_t lookup,
int initial_device_num,
ompt_data_t *tool_data) {
Expand All @@ -418,6 +423,9 @@ int llvm::omp::target::ompt::initializeLibrary(ompt_function_lookup_t lookup,

assert(lookupCallbackByCode && "lookupCallbackByCode should be non-null");
assert(lookupCallbackByName && "lookupCallbackByName should be non-null");
assert(ompt_get_task_data_fn && "ompt_get_task_data_fn should be non-null");
assert(ompt_get_target_task_data_fn &&
"ompt_get_target_task_data_fn should be non-null");
assert(LibraryFinalizer == nullptr &&
"LibraryFinalizer should not be initialized yet");

Expand All @@ -434,6 +442,7 @@ void llvm::omp::target::ompt::finalizeLibrary(ompt_data_t *data) {
// with this library
LibraryFinalizer->finalize();
delete LibraryFinalizer;
Initialized = false;
}

void llvm::omp::target::ompt::connectLibrary() {
Expand Down
120 changes: 113 additions & 7 deletions openmp/libomptarget/src/OmptInterface.h
Expand Up @@ -13,14 +13,18 @@
#ifndef _OMPTARGET_OMPTINTERFACE_H
#define _OMPTARGET_OMPTINTERFACE_H

// Only provide functionality if target OMPT support is enabled
#ifdef OMPT_SUPPORT
#include <functional>
#include <tuple>

#include "OmptCallback.h"
#include "omp-tools.h"

// If target OMPT support is compiled in
#ifdef OMPT_SUPPORT
#include "llvm/Support/ErrorHandling.h"

#define OMPT_IF_BUILT(stmt) stmt
#else
#define OMPT_IF_BUILT(stmt)
#endif
#define OMPT_GET_RETURN_ADDRESS(level) __builtin_return_address(level)

/// Callbacks for target regions require task_data representing the
/// encountering task.
Expand Down Expand Up @@ -108,6 +112,66 @@ class Interface {
/// Top-level function for invoking callback after target construct
void endTarget(int64_t DeviceId, void *Code);

/// Callback getter for target data operations.
///
/// Picks, at compile time, the begin/end pair of Interface member functions
/// that belongs to \p OpType. Synchronous and `_async` variants of an
/// operation share the same pair.
///
/// \tparam OpType target data operation kind to look up.
/// \returns std::pair of std::mem_fn wrappers; `first` wraps the begin member,
/// `second` the matching end member. For an \p OpType that is not handled the
/// function ends in llvm_unreachable and returns nothing.
template <ompt_target_data_op_t OpType> auto getCallbacks() {
  constexpr bool IsAlloc = OpType == ompt_target_data_alloc ||
                           OpType == ompt_target_data_alloc_async;
  constexpr bool IsDelete = OpType == ompt_target_data_delete ||
                            OpType == ompt_target_data_delete_async;
  constexpr bool IsSubmit =
      OpType == ompt_target_data_transfer_to_device ||
      OpType == ompt_target_data_transfer_to_device_async;
  constexpr bool IsRetrieve =
      OpType == ompt_target_data_transfer_from_device ||
      OpType == ompt_target_data_transfer_from_device_async;

  if constexpr (IsAlloc) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetDataAlloc),
                          std::mem_fn(&Interface::endTargetDataAlloc));
  } else if constexpr (IsDelete) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetDataDelete),
                          std::mem_fn(&Interface::endTargetDataDelete));
  } else if constexpr (IsSubmit) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetDataSubmit),
                          std::mem_fn(&Interface::endTargetDataSubmit));
  } else if constexpr (IsRetrieve) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetDataRetrieve),
                          std::mem_fn(&Interface::endTargetDataRetrieve));
  } else {
    // No known data operation matched.
    llvm_unreachable("Unhandled target data operation type!");
  }
}

/// Callback getter for target region operations.
///
/// Picks, at compile time, the begin/end pair of Interface member functions
/// that belongs to \p OpType. Each `_nowait` variant maps to the same pair as
/// its blocking counterpart.
///
/// \tparam OpType target region kind to look up.
/// \returns std::pair of std::mem_fn wrappers; `first` wraps the begin member,
/// `second` the matching end member. For an \p OpType that is not handled the
/// function ends in llvm_unreachable and returns nothing.
template <ompt_target_t OpType> auto getCallbacks() {
  constexpr bool IsEnterData = OpType == ompt_target_enter_data ||
                               OpType == ompt_target_enter_data_nowait;
  constexpr bool IsExitData = OpType == ompt_target_exit_data ||
                              OpType == ompt_target_exit_data_nowait;
  constexpr bool IsUpdate =
      OpType == ompt_target_update || OpType == ompt_target_update_nowait;
  constexpr bool IsTarget =
      OpType == ompt_target || OpType == ompt_target_nowait;

  if constexpr (IsEnterData) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetDataEnter),
                          std::mem_fn(&Interface::endTargetDataEnter));
  } else if constexpr (IsExitData) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetDataExit),
                          std::mem_fn(&Interface::endTargetDataExit));
  } else if constexpr (IsUpdate) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetUpdate),
                          std::mem_fn(&Interface::endTargetUpdate));
  } else if constexpr (IsTarget) {
    return std::make_pair(std::mem_fn(&Interface::beginTarget),
                          std::mem_fn(&Interface::endTarget));
  } else {
    // No known region operation matched.
    llvm_unreachable("Unknown target region operation type!");
  }
}

/// Callback getter for the kernel launch (target submit) operation.
///
/// 'ompt_callbacks_t' serves as the selector type because no other enum is
/// currently available to model a kernel launch / target submit operation.
///
/// \tparam OpType callback identifier; only ompt_callback_target_submit is
/// handled.
/// \returns std::pair of std::mem_fn wrappers around
/// Interface::beginTargetSubmit (`first`) and Interface::endTargetSubmit
/// (`second`). Any other \p OpType ends in llvm_unreachable.
template <ompt_callbacks_t OpType> auto getCallbacks() {
  if constexpr (OpType == ompt_callback_target_submit) {
    return std::make_pair(std::mem_fn(&Interface::beginTargetSubmit),
                          std::mem_fn(&Interface::endTargetSubmit));
  } else {
    llvm_unreachable("Unhandled target operation!");
  }
}

/// Setters for target region and target operation correlation ids
void setTargetDataValue(uint64_t DataValue) { TargetData.value = DataValue; }
void setTargetDataPtr(void *DataPtr) { TargetData.ptr = DataPtr; }
Expand Down Expand Up @@ -147,11 +211,53 @@ class Interface {
void endTargetRegion();
};

/// Thread local state for target region and associated metadata
extern thread_local Interface RegionInterface;

/// Call \p Func through std::invoke with RegionInterface as the first
/// argument, followed by the elements of \p Args selected by IndexSeq.
///
/// \param Func callable to invoke; InterfaceRAII passes the std::mem_fn
/// wrappers it obtained from a callback pair.
/// \param Args tuple-like argument pack, taken by value; element I is read
/// with std::get<I>.
/// The unnamed std::index_sequence parameter exists only so that IndexSeq can
/// be deduced; its value is not used.
template <typename FuncTy, typename ArgsTy, size_t... IndexSeq>
void InvokeInterfaceFunction(FuncTy Func, ArgsTy Args,
std::index_sequence<IndexSeq...>) {
std::invoke(Func, RegionInterface, std::get<IndexSeq>(Args)...);
}

/// Scoped guard that brackets a region of code with a begin/end callback pair.
///
/// The constructor stores copies of the given arguments and of both callables,
/// then runs the begin callable; the destructor runs the end callable with the
/// same stored arguments. Both invocations are wrapped in
/// performIfOmptInitialized, and both go through InvokeInterfaceFunction,
/// which supplies RegionInterface as the object to call on.
///
/// \tparam CallbackPairTy pair-like type providing `first_type`,
/// `second_type` and std::get<0>/std::get<1> access.
/// \tparam ArgsTy types of the arguments forwarded to both callables.
template <typename CallbackPairTy, typename... ArgsTy> class InterfaceRAII {
  /// Index pack with one index per stored argument.
  using ArgIndicesTy = std::index_sequence_for<ArgsTy...>;

public:
  /// Capture \p Callbacks and \p Args, then fire the begin callable.
  InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
      : Arguments(Args...), beginFunction(std::get<0>(Callbacks)),
        endFunction(std::get<1>(Callbacks)) {
    performIfOmptInitialized(begin());
  }

  /// Fire the end callable when the guard leaves scope.
  ~InterfaceRAII() { performIfOmptInitialized(end()); }

private:
  /// Invoke the begin callable with the stored arguments.
  void begin() {
    InvokeInterfaceFunction(beginFunction, Arguments, ArgIndicesTy{});
  }

  /// Invoke the end callable with the stored arguments.
  void end() {
    InvokeInterfaceFunction(endFunction, Arguments, ArgIndicesTy{});
  }

  /// Arguments captured at construction and replayed for both callables.
  std::tuple<ArgsTy...> Arguments;
  /// Callable run by the constructor.
  typename CallbackPairTy::first_type beginFunction;
  /// Callable run by the destructor.
  typename CallbackPairTy::second_type endFunction;
};

// InterfaceRAII's class template argument deduction guide: lets a guard be
// declared without spelling out template arguments, deducing CallbackPairTy
// and ArgsTy from the constructor arguments (each taken by value).
template <typename CallbackPairTy, typename... ArgsTy>
InterfaceRAII(CallbackPairTy Callbacks, ArgsTy... Args)
-> InterfaceRAII<CallbackPairTy, ArgsTy...>;

} // namespace ompt
} // namespace target
} // namespace omp
} // namespace llvm

extern thread_local llvm::omp::target::ompt::Interface OmptInterface;
#else
#define OMPT_IF_BUILT(stmt)
#endif

#endif // _OMPTARGET_OMPTINTERFACE_H

0 comments on commit 1dec417

Please sign in to comment.