Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions unified-runtime/source/adapters/level_zero/v2/queue_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ ur_queue_batched_t::renewBatchUnlocked(locked<batch_manager> &batchLocked) {
if (batchLocked->isLimitOfUsedCommandListsReached()) {
return queueFinishUnlocked(batchLocked);
} else {
UR_CALL(batchLocked->enqueueCurrentBatchUnlocked());
return batchLocked->renewRegularUnlocked(getNewRegularCmdList());
}
}
Expand Down Expand Up @@ -157,7 +158,7 @@ ur_queue_batched_t::onEventWaitListUse(ur_event_generation_t batch_generation) {

auto batchLocked = currentCmdLists.lock();
if (batchLocked->isCurrentGeneration(batch_generation)) {
return queueFlushUnlocked(batchLocked);
return renewBatchUnlocked(batchLocked);
Comment thread
ldorau marked this conversation as resolved.
} else {
return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -228,7 +229,9 @@ ur_result_t batch_manager::batchFinish() {
UR_CALL(activeBatch.releaseSubmittedKernels());

if (!isActiveBatchEmpty()) {
// Should have been enqueued as part of queueFinishUnlocked
// The active batch was already submitted to the immediate command list
// by queueFinishUnlocked. Reset it here so it is ready to record new
// commands.
TRACK_SCOPE_LATENCY("ur_queue_batched_t::resetRegCmdlist");
ZE2UR_CALL(zeCommandListReset, (activeBatch.getZeCommandList()));

Expand Down Expand Up @@ -432,7 +435,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMFreeExp(
createEventIfRequestedRegular(phEvent,
lockedBatch->getCurrentGeneration())));

return queueFlushUnlocked(lockedBatch);
return renewBatchUnlocked(lockedBatch);
}

ur_result_t ur_queue_batched_t::enqueueMemBufferMap(
Expand Down Expand Up @@ -634,7 +637,7 @@ ur_result_t ur_queue_batched_t::enqueueEventsWaitWithBarrier(
phEvent, lockedBatch->getCurrentGeneration())));
}

return queueFlushUnlocked(lockedBatch);
return renewBatchUnlocked(lockedBatch);
}

ur_result_t
Expand All @@ -652,7 +655,7 @@ ur_queue_batched_t::enqueueEventsWait(uint32_t numEventsInWaitList,
waitListView, createEventIfRequestedRegular(
phEvent, lockedBatch->getCurrentGeneration())));

return queueFlushUnlocked(lockedBatch);
return renewBatchUnlocked(lockedBatch);
}

ur_result_t ur_queue_batched_t::enqueueMemBufferCopy(
Expand Down Expand Up @@ -818,7 +821,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMDeviceAllocExp(
lockedBatch->getCurrentGeneration()),
UR_USM_TYPE_DEVICE));

return queueFlushUnlocked(lockedBatch);
return renewBatchUnlocked(lockedBatch);
}

ur_result_t ur_queue_batched_t::enqueueUSMSharedAllocExp(
Expand All @@ -840,7 +843,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMSharedAllocExp(
lockedBatch->getCurrentGeneration()),
UR_USM_TYPE_SHARED));

return queueFlushUnlocked(lockedBatch);
return renewBatchUnlocked(lockedBatch);
}

ur_result_t ur_queue_batched_t::enqueueUSMHostAllocExp(
Expand All @@ -861,7 +864,7 @@ ur_result_t ur_queue_batched_t::enqueueUSMHostAllocExp(
lockedBatch->getCurrentGeneration()),
UR_USM_TYPE_HOST));

return queueFlushUnlocked(lockedBatch);
return renewBatchUnlocked(lockedBatch);
}

ur_result_t ur_queue_batched_t::bindlessImagesImageCopyExp(
Expand Down Expand Up @@ -969,7 +972,7 @@ ur_result_t ur_queue_batched_t::enqueueCommandBufferExp(
// command buffer batch (also a regular list) to preserve the order of
// operations
if (!lockedBatch->isActiveBatchEmpty()) {
UR_CALL(queueFlushUnlocked(lockedBatch));
UR_CALL(renewBatchUnlocked(lockedBatch));
}

// Regular lists cannot be appended to other regular lists for execution, only
Expand Down Expand Up @@ -1075,20 +1078,13 @@ ur_queue_batched_t::queueGetNativeHandle(ur_queue_native_desc_t * /*pDesc*/,
return UR_RESULT_SUCCESS;
}

ur_result_t
ur_queue_batched_t::queueFlushUnlocked(locked<batch_manager> &batchLocked) {
UR_CALL(batchLocked->enqueueCurrentBatchUnlocked());

return renewBatchUnlocked(batchLocked);
}

ur_result_t ur_queue_batched_t::queueFlush() {
auto batchLocked = currentCmdLists.lock();

if (batchLocked->isActiveBatchEmpty()) {
return UR_RESULT_SUCCESS;
} else {
return queueFlushUnlocked(batchLocked);
return renewBatchUnlocked(batchLocked);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,6 @@ struct ur_queue_batched_t : ur_object, ur_queue_t_ {

ur_result_t queueFinishUnlocked(locked<batch_manager> &batchLocked);

ur_result_t queueFlushUnlocked(locked<batch_manager> &batchLocked);

ur_result_t markIssuedCommandInBatch(locked<batch_manager> &batchLocked);

public:
Expand Down
51 changes: 29 additions & 22 deletions unified-runtime/test/adapters/level_zero/enqueue_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,19 @@ std::ostream &operator<<(std::ostream &os,
}

struct urL0EnqueueAllocMultiQueueSameDeviceTest
: uur::urContextTestWithParam<EnqueueAllocMultiQueueTestParam> {
: uur::urContextTestWithParam<
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a fixture urMultiQueueTypeTestWithParam, which is defined as urContextTestWithParam<MultiQueueParam<T>> (which is exactly what is defined here).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EuphoricThinking @pbalcer I think the proposed replacement is not viable, because urMultiQueueTypeTestWithParam<T> creates exactly one ur_queue_handle_t queue in its SetUp.
But urL0EnqueueAllocMultiQueueSameDeviceTest creates a variable-length vector of queues
(std::vector<ur_queue_handle_t> queues) sized by param.numQueues. If the test inherited from urMultiQueueTypeTestWithParam, the base SetUp would create a single unused queue handle in addition to the test's own queues vector, what IMO would be wasteful and misleading.

uur::MultiQueueParam<EnqueueAllocMultiQueueTestParam>> {
void SetUp() override {
UUR_RETURN_ON_FATAL_FAILURE(urContextTestWithParam::SetUp());
auto param = std::get<1>(this->GetParam());

const auto &param = getAllocParam();

ur_queue_properties_t props = {UR_STRUCTURE_TYPE_QUEUE_PROPERTIES, nullptr,
getQueueFlags()};
queues.reserve(param.numQueues);
for (size_t i = 0; i < param.numQueues; i++) {
ur_queue_handle_t queue = nullptr;
ASSERT_SUCCESS(urQueueCreate(context, device, 0, &queue));
SKIP_IF_BATCHED_QUEUE(queue);
ASSERT_SUCCESS(urQueueCreate(context, device, &props, &queue));
queues.push_back(queue);
}
}
Expand All @@ -95,6 +98,14 @@ struct urL0EnqueueAllocMultiQueueSameDeviceTest
UUR_RETURN_ON_FATAL_FAILURE(urContextTestWithParam::TearDown());
}

const EnqueueAllocMultiQueueTestParam &getAllocParam() const {
return std::get<0>(this->getParam());
}

ur_queue_flag_t getQueueFlags() const {
return std::get<1>(this->getParam());
}

std::vector<ur_queue_handle_t> queues;
};

Expand Down Expand Up @@ -322,7 +333,7 @@ TEST_P(urL0EnqueueAllocTest, SuccessWithKernelRepeat) {
ValidateEnqueueFree(ptr2);
}

UUR_DEVICE_TEST_SUITE_WITH_PARAM(
UUR_MULTI_QUEUE_TYPE_TEST_SUITE_WITH_PARAM(
urL0EnqueueAllocMultiQueueSameDeviceTest,
::testing::ValuesIn({
EnqueueAllocMultiQueueTestParam{1024, 256, 8, urEnqueueUSMHostAllocExp,
Expand All @@ -334,20 +345,16 @@ UUR_DEVICE_TEST_SUITE_WITH_PARAM(
urEnqueueUSMDeviceAllocExp,
uur::GetDeviceUSMDeviceSupport},
}),
uur::deviceTestWithParamPrinter<EnqueueAllocMultiQueueTestParam>);
uur::deviceTestWithParamPrinterMulti<EnqueueAllocMultiQueueTestParam>);

TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessMt) {
const size_t allocSize = std::get<1>(this->GetParam()).allocSize;
const size_t numQueues = std::get<1>(this->GetParam()).numQueues;
const size_t iterations = std::get<1>(this->GetParam()).iterations;
const size_t allocSize = getAllocParam().allocSize;
const size_t numQueues = getAllocParam().numQueues;
const size_t iterations = getAllocParam().iterations;
const auto enqueueUSMAllocFunc =
std::get<1>(this->GetParam()).funcParams.enqueueUSMAllocFunc;
getAllocParam().funcParams.enqueueUSMAllocFunc;
const auto checkUSMSupportFunc =
std::get<1>(this->GetParam()).funcParams.checkUSMSupportFunc;

if (numQueues > 0) {
SKIP_IF_BATCHED_QUEUE(queues[0]);
}
getAllocParam().funcParams.checkUSMSupportFunc;

ur_device_usm_access_capability_flags_t USMSupport = 0;
ASSERT_SUCCESS(checkUSMSupportFunc(device, USMSupport));
Expand Down Expand Up @@ -394,11 +401,11 @@ TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessMt) {
TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessReuse) {
GTEST_SKIP() << "Multi queue reuse is not supported.";

const size_t allocSize = std::get<1>(this->GetParam()).allocSize;
const size_t allocSize = getAllocParam().allocSize;
const auto enqueueUSMAllocFunc =
std::get<1>(this->GetParam()).funcParams.enqueueUSMAllocFunc;
getAllocParam().funcParams.enqueueUSMAllocFunc;
const auto checkUSMSupportFunc =
std::get<1>(this->GetParam()).funcParams.checkUSMSupportFunc;
getAllocParam().funcParams.checkUSMSupportFunc;

ur_device_usm_access_capability_flags_t USMSupport = 0;
ASSERT_SUCCESS(checkUSMSupportFunc(device, USMSupport));
Expand Down Expand Up @@ -457,12 +464,12 @@ TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessReuse) {
}

TEST_P(urL0EnqueueAllocMultiQueueSameDeviceTest, SuccessDependantMt) {
const size_t allocSize = std::get<1>(this->GetParam()).allocSize;
const size_t iterations = std::get<1>(this->GetParam()).iterations;
const size_t allocSize = getAllocParam().allocSize;
const size_t iterations = getAllocParam().iterations;
const auto enqueueUSMAllocFunc =
std::get<1>(this->GetParam()).funcParams.enqueueUSMAllocFunc;
getAllocParam().funcParams.enqueueUSMAllocFunc;
const auto checkUSMSupportFunc =
std::get<1>(this->GetParam()).funcParams.checkUSMSupportFunc;
getAllocParam().funcParams.checkUSMSupportFunc;

ur_device_usm_access_capability_flags_t USMSupport = 0;
ASSERT_SUCCESS(checkUSMSupportFunc(device, USMSupport));
Expand Down
Loading