[Libomptarget] Configure the RPC port count from the plugin
This patch allows us to configure the RPC port count to match the parallelism
the specific card requires. For AMDGPU we need to use the maximum amount of
hardware parallelism to avoid deadlocks. For NVPTX we don't have this problem
thanks to its friendlier scheduler, so we use the number of warps active on an
SM times the number of SMs as a good guess.

Note that the maximum port count is currently smaller than these numbers;
that will be improved in the future.
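
A minimal standalone sketch (not libomptarget code) of the policy described
above; the cap value below is an assumption for illustration, while the real
plugins query the hardware numbers from the driver as the diffs show:

#include <algorithm>
#include <cstdint>

// Assumed stand-in for the runtime's current hard cap on RPC ports.
constexpr uint64_t AssumedMaxPortCount = 512;

// AMDGPU: request one port per wave the hardware can keep resident, so a
// blocked wave never has to wait for another wave to release a port.
constexpr uint64_t amdgpuPorts(uint64_t ComputeUnits, uint64_t WavesPerCU) {
  return ComputeUnits * WavesPerCU;
}

// NVPTX: request one port per warp that fits when every SM is fully occupied.
constexpr uint64_t nvptxPorts(uint64_t NumSMs, uint64_t MaxThreadsPerSM,
                              uint64_t WarpSize) {
  return NumSMs * (MaxThreadsPerSM / WarpSize);
}

// The RPC server clamps whatever the plugin requests to its maximum.
constexpr uint64_t portCount(uint64_t Requested) {
  return std::min(Requested, AssumedMaxPortCount);
}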

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D155903
jhuber6 committed Aug 11, 2023
1 parent be237b7 commit 06adac8
Showing 4 changed files with 54 additions and 3 deletions.
15 changes: 15 additions & 0 deletions openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -1785,6 +1785,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Err;
GridValues.GV_Default_Num_Teams = ComputeUnits * OMPX_DefaultTeamsPerCU;

uint32_t WavesPerCU = 0;
if (auto Err =
getDeviceAttr(HSA_AMD_AGENT_INFO_MAX_WAVES_PER_CU, WavesPerCU))
return Err;
HardwareParallelism = ComputeUnits * WavesPerCU;

// Get maximum size of any device queues and maximum number of queues.
uint32_t MaxQueueSize;
if (auto Err = getDeviceAttr(HSA_AGENT_INFO_QUEUE_MAX_SIZE, MaxQueueSize))
@@ -1932,6 +1938,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return libomptargetSupportsRPC();
}

/// AMDGPU returns the product of the number of compute units and the waves
/// per compute unit.
uint64_t requestedRPCPortCount() const override {
return HardwareParallelism;
}

/// Get the stream of the asynchronous info structure or get a new one.
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper,
AMDGPUStreamTy *&Stream) {
@@ -2577,6 +2589,9 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
/// The frequency of the steady clock inside the device.
uint64_t ClockFrequency;

/// The total number of concurrent work items that can be running on the GPU.
uint64_t HardwareParallelism;

/// Reference to the host device.
AMDHostDeviceTy &HostDevice;
};
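As a worked illustration of the AMDGPU change above (the device numbers are
assumptions, not values queried from HSA):

#include <cstdint>

// Hypothetical CDNA2-class card: 104 compute units, 32 waves per CU.
constexpr uint64_t ComputeUnits = 104;
constexpr uint64_t WavesPerCU = 32;
// Mirrors HardwareParallelism = ComputeUnits * WavesPerCU in the hunk above.
constexpr uint64_t RequestedPorts = ComputeUnits * WavesPerCU;
static_assert(RequestedPorts == 3328, "one RPC port per resident wave");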
@@ -782,6 +782,19 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
/// Get the RPC server running on this device.
RPCServerTy *getRPCServer() const { return RPCServer; }

/// The number of parallel RPC ports to use on the device. In general, this
/// should be roughly equivalent to the amount of hardware parallelism the
/// device can support. This is because GPUs in general do not have forward
/// progress guarantees, so we minimize thread level dependencies by
/// allocating enough space such that each device thread can have a port. This
/// is likely overly pessimistic in the average case, but guarantees no
/// deadlocks at the cost of memory. This must be overloaded by targets
/// expecting to use the RPC server.
virtual uint64_t requestedRPCPortCount() const {
assert(!shouldSetupRPCServer() && "Default implementation cannot be used");
return 0;
}

private:
/// Register offload entry for global variable.
Error registerGlobalOffloadEntry(DeviceImageTy &DeviceImage,
@@ -888,7 +901,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
#endif

private:

/// Return the kernel environment object for kernel \p Name.
Expected<KernelEnvironmentTy>
getKernelEnvironmentForKernel(StringRef Name, DeviceImageTy &Image);
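A stripped-down model of the requestedRPCPortCount contract added above,
assuming a simplified class hierarchy rather than the real GenericDeviceTy:
plugins that set up the RPC server override the hook, and the asserting
default is only ever reached by devices that never start the server.

#include <cassert>
#include <cstdint>

struct DeviceModel {
  virtual ~DeviceModel() = default;
  virtual bool shouldSetupRPCServer() const { return false; }
  // Default mirrors the behavior added above: only valid if RPC is unused.
  virtual uint64_t requestedRPCPortCount() const {
    assert(!shouldSetupRPCServer() && "Default implementation cannot be used");
    return 0;
  }
};

struct GPUDeviceModel : DeviceModel {
  bool shouldSetupRPCServer() const override { return true; }
  // One port per unit of hardware parallelism, per the rationale above.
  uint64_t requestedRPCPortCount() const override { return HardwareParallelism; }
  uint64_t HardwareParallelism = 4096;
};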
@@ -59,8 +59,9 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
*reinterpret_cast<plugin::GenericDeviceTy *>(Data);
return Device.allocate(Size, nullptr, TARGET_ALLOC_HOST);
};
// TODO: Allow the device to declare its requested port count.
if (rpc_status_t Err = rpc_server_init(DeviceId, RPC_MAXIMUM_PORT_COUNT,
uint64_t NumPorts =
std::min(Device.requestedRPCPortCount(), RPC_MAXIMUM_PORT_COUNT);
if (rpc_status_t Err = rpc_server_init(DeviceId, NumPorts,
Device.getWarpSize(), Alloc, &Device))
return plugin::Plugin::error(
"Failed to initialize RPC server for device %d: %d", DeviceId, Err);
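The effect of the new std::min clamp, sketched with an assumed cap
(AssumedMaxPortCount stands in for the real RPC_MAXIMUM_PORT_COUNT, which the
RPC server implementation defines); as the commit message notes, today the cap
rather than the device request is usually the limiting factor:

#include <algorithm>
#include <cstdint>

constexpr uint64_t AssumedMaxPortCount = 512; // assumption for illustration

constexpr uint64_t clampPorts(uint64_t Requested) {
  return std::min(Requested, AssumedMaxPortCount);
}

static_assert(clampPorts(3328) == AssumedMaxPortCount,
              "a large hardware request is clamped to the cap");
static_assert(clampPorts(256) == 256, "a small request passes through");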
23 changes: 23 additions & 0 deletions openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
@@ -295,6 +295,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
ComputeCapability.Minor))
return Err;

uint32_t NumMultiprocessors = 0;
uint32_t MaxThreadsPerSM = 0;
uint32_t WarpSize = 0;
if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
NumMultiprocessors))
return Err;
if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR,
MaxThreadsPerSM))
return Err;
if (auto Err = getDeviceAttr(CU_DEVICE_ATTRIBUTE_WARP_SIZE, WarpSize))
return Err;
HardwareParallelism = NumMultiprocessors * (MaxThreadsPerSM / WarpSize);

return Plugin::success();
}

@@ -366,6 +379,12 @@ struct CUDADeviceTy : public GenericDeviceTy {
return libomptargetSupportsRPC();
}

/// NVIDIA returns the product of the SM count and the number of warps that
/// fit if the maximum number of threads were scheduled on each SM.
uint64_t requestedRPCPortCount() const override {
return HardwareParallelism;
}

/// Get the stream of the asynchronous info structure or get a new one.
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
// Get the stream (if any) from the async info.
@@ -876,6 +895,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
return "sm_" + std::to_string(Major * 10 + Minor);
}
} ComputeCapability;

/// The maximum number of warps that can be resident on all the SMs
/// simultaneously.
uint32_t HardwareParallelism = 0;
};

Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
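As a worked illustration of the CUDA change above (the device numbers are
assumptions, not values queried from the CUDA driver):

#include <cstdint>

// Hypothetical A100-class card: 108 SMs, 2048 max threads per SM, warp size 32.
constexpr uint64_t NumSMs = 108;
constexpr uint64_t MaxThreadsPerSM = 2048;
constexpr uint64_t WarpSize = 32;
// Mirrors HardwareParallelism = NumSMs * (MaxThreadsPerSM / WarpSize) above.
constexpr uint64_t RequestedPorts = NumSMs * (MaxThreadsPerSM / WarpSize);
static_assert(RequestedPorts == 6912, "one RPC port per schedulable warp");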
