diff --git a/unified-runtime/source/adapters/level_zero/v2/context.cpp b/unified-runtime/source/adapters/level_zero/v2/context.cpp index b96e05e9d315a..3d2a7758d6be4 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.cpp @@ -35,11 +35,28 @@ filterP2PDevices(ur_device_handle_t hSourceDevice, } static std::vector> -populateP2PDevices(size_t maxDevices, - const std::vector &devices) { - std::vector> p2pDevices(maxDevices); +populateP2PDevices(const std::vector &devices) { + std::vector allDevices; + std::function collectDeviceAndSubdevices = + [&allDevices, &collectDeviceAndSubdevices](ur_device_handle_t device) { + allDevices.push_back(device); + for (auto &subDevice : device->SubDevices) { + collectDeviceAndSubdevices(subDevice); + } + }; + for (auto &device : devices) { - p2pDevices[device->Id.value()] = filterP2PDevices(device, devices); + collectDeviceAndSubdevices(device); + } + + uint64_t maxDeviceId = 0; + for (auto &device : allDevices) { + maxDeviceId = std::max(maxDeviceId, device->Id.value()); + } + + std::vector> p2pDevices(maxDeviceId + 1); + for (auto &device : allDevices) { + p2pDevices[device->Id.value()] = filterP2PDevices(device, allDevices); } return p2pDevices; } @@ -83,8 +100,7 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext, nativeEventsPool(this, std::make_unique( this, v2::QUEUE_IMMEDIATE, v2::EVENT_FLAGS_PROFILING_ENABLED)), - p2pAccessDevices(populateP2PDevices( - phDevices[0]->Platform->getNumDevices(), this->hDevices)), + p2pAccessDevices(populateP2PDevices(this->hDevices)), defaultUSMPool(this, nullptr), asyncPool(this, nullptr) {} ur_result_t ur_context_handle_t_::retain() {