Skip to content
Merged
6 changes: 5 additions & 1 deletion sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,18 @@ class Command {
/// \param Blocking if this argument is true, function will wait for the
/// command to be unblocked before calling enqueueImp.
/// \return true if the command is enqueued.
bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking);
virtual bool enqueue(EnqueueResultT &EnqueueResult, BlockingT Blocking);

bool isFinished();

bool isSuccessfullyEnqueued() const {
return MEnqueueStatus == EnqueueResultT::SyclEnqueueSuccess;
}

bool isEnqueueBlocked() const {
return MEnqueueStatus == EnqueueResultT::SyclEnqueueBlocked;
}

std::shared_ptr<queue_impl> getQueue() const { return MQueue; }

std::shared_ptr<event_impl> getEvent() const { return MEvent; }
Expand Down
32 changes: 9 additions & 23 deletions sycl/source/detail/scheduler/graph_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,33 +58,19 @@ bool Scheduler::GraphProcessor::enqueueCommand(Command *Cmd,
if (!Cmd || Cmd->isSuccessfullyEnqueued())
return true;

// Indicates whether dependency cannot be enqueued
bool BlockedByDep = false;
// Exit early if the command is blocked and the enqueue type is non-blocking
if (Cmd->isEnqueueBlocked() && !Blocking) {
EnqueueResult = EnqueueResultT(EnqueueResultT::SyclEnqueueBlocked, Cmd);
return false;
}

// Recursively enqueue all the dependencies first and
// exit immediately if any of the commands cannot be enqueued.
for (DepDesc &Dep : Cmd->MDeps) {
const bool Enqueued =
enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking);
if (!Enqueued)
switch (EnqueueResult.MResult) {
case EnqueueResultT::SyclEnqueueFailed:
default:
// Exit immediately if a command fails to avoid enqueueing commands
// result of which will be discarded.
return false;
case EnqueueResultT::SyclEnqueueBlocked:
// If some dependency is blocked from enqueueing remember that, but
// try to enqueue other dependencies(that can be ready for
// enqueueing).
BlockedByDep = true;
break;
}
if (!enqueueCommand(Dep.MDepCommand, EnqueueResult, Blocking))
return false;
}

// Exit if some command is blocked from enqueueing, the EnqueueResult is set
// by the latest dependency which was blocked.
if (BlockedByDep)
return false;

return Cmd->enqueue(EnqueueResult, Blocking);
}

Expand Down
85 changes: 85 additions & 0 deletions sycl/unittests/scheduler/BlockedCommands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "SchedulerTestUtils.hpp"

using namespace cl::sycl;
using namespace testing;

TEST_F(SchedulerTest, BlockedCommands) {
MockCommand MockCmd(detail::getSyclObjImpl(MQueue));
Expand Down Expand Up @@ -45,3 +46,87 @@ TEST_F(SchedulerTest, BlockedCommands) {
Res.MResult == detail::EnqueueResultT::SyclEnqueueSuccess)
<< "The command is expected to be successfully enqueued.\n";
}

TEST_F(SchedulerTest, DontEnqueueDepsIfOneOfThemIsBlocked) {
MockCommand A(detail::getSyclObjImpl(MQueue));
A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
A.MIsBlockable = true;
A.MRetVal = CL_SUCCESS;

MockCommand B(detail::getSyclObjImpl(MQueue));
B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
B.MIsBlockable = true;
B.MRetVal = CL_SUCCESS;

MockCommand C(detail::getSyclObjImpl(MQueue));
C.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
C.MIsBlockable = true;

MockCommand D(detail::getSyclObjImpl(MQueue));
D.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
D.MIsBlockable = true;
D.MRetVal = CL_SUCCESS;

addEdge(&A, &B, nullptr);
addEdge(&A, &C, nullptr);
addEdge(&A, &D, nullptr);

// We have such a graph:
//
// A
// / | \
// B C D
//
// If C is blocked, we should not try to enqueue D.
Comment on lines +74 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand it correctly that A depends on B, C and D? Why can't we enqueue D then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed it with @romanovvlad and AFAIU, in the current design there are no benefits of including D, as it is likely to be blocked as well.


EXPECT_CALL(A, enqueue(_, _)).Times(0);
EXPECT_CALL(B, enqueue(_, _)).Times(1);
EXPECT_CALL(C, enqueue(_, _)).Times(0);
EXPECT_CALL(D, enqueue(_, _)).Times(0);

detail::EnqueueResultT Res;
bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING);
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult)
<< "Result of enqueueing blocked command should be BLOCKED.\n";
ASSERT_EQ(&C, Res.MCmd) << "Expected different failed command.\n";
}

TEST_F(SchedulerTest, EnqueueBlockedCommandEarlyExit) {
MockCommand A(detail::getSyclObjImpl(MQueue));
A.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
A.MIsBlockable = true;

MockCommand B(detail::getSyclObjImpl(MQueue));
B.MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueReady;
B.MRetVal = CL_OUT_OF_RESOURCES;

addEdge(&A, &B, nullptr);

// We have such a graph:
//
// A -> B
//
// If A is blocked, we should not try to enqueue B.

EXPECT_CALL(A, enqueue(_, _)).Times(0);
EXPECT_CALL(B, enqueue(_, _)).Times(0);

detail::EnqueueResultT Res;
bool Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::NON_BLOCKING);
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueBlocked, Res.MResult)
<< "Result of enqueueing blocked command should be BLOCKED.\n";
ASSERT_EQ(&A, Res.MCmd) << "Expected different failed command.\n";

// But if the enqueue type is blocking we should not exit early.

EXPECT_CALL(A, enqueue(_, _)).Times(0);
EXPECT_CALL(B, enqueue(_, _)).Times(1);

Enqueued = MockScheduler::enqueueCommand(&A, Res, detail::BLOCKING);
ASSERT_FALSE(Enqueued) << "Blocked command should not be enqueued\n";
ASSERT_EQ(detail::EnqueueResultT::SyclEnqueueFailed, Res.MResult)
<< "Result of enqueueing blocked command should be BLOCKED.\n";
ASSERT_EQ(&B, Res.MCmd) << "Expected different failed command.\n";
}
31 changes: 17 additions & 14 deletions sycl/unittests/scheduler/LeafLimit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,45 @@ using namespace cl::sycl;
// overflowed.
TEST_F(SchedulerTest, LeafLimit) {
MockScheduler MS;
std::vector<std::unique_ptr<MockCommand>> LeavesToAdd;
std::unique_ptr<MockCommand> MockDepCmd;

buffer<int, 1> Buf(range<1>(1));
detail::Requirement MockReq = getMockRequirement(Buf);
MockCommand *MockDepCmd =
new MockCommand(detail::getSyclObjImpl(MQueue), MockReq);

MockDepCmd =
std::make_unique<MockCommand>(detail::getSyclObjImpl(MQueue), MockReq);
detail::MemObjRecord *Rec =
MS.getOrInsertMemObjRecord(detail::getSyclObjImpl(MQueue), &MockReq);

// Create commands that will be added as leaves exceeding the limit by 1
std::vector<MockCommand *> LeavesToAdd;
for (std::size_t i = 0; i < Rec->MWriteLeaves.genericCommandsCapacity() + 1;
++i) {
LeavesToAdd.push_back(
new MockCommand(detail::getSyclObjImpl(MQueue), MockReq));
std::make_unique<MockCommand>(detail::getSyclObjImpl(MQueue), MockReq));
}
// Create edges: all soon-to-be leaves are direct users of MockDep
for (auto Leaf : LeavesToAdd) {
MockDepCmd->addUser(Leaf);
Leaf->addDep(detail::DepDesc{MockDepCmd, Leaf->getRequirement(), nullptr});
for (auto &Leaf : LeavesToAdd) {
MockDepCmd->addUser(Leaf.get());
Leaf->addDep(
detail::DepDesc{MockDepCmd.get(), Leaf->getRequirement(), nullptr});
}
// Add edges as leaves and exceed the leaf limit
for (auto LeafPtr : LeavesToAdd) {
MS.addNodeToLeaves(Rec, LeafPtr);
for (auto &LeafPtr : LeavesToAdd) {
MS.addNodeToLeaves(Rec, LeafPtr.get());
}
// Check that the oldest leaf has been removed from the leaf list
// and added as a dependency of the newest one instead
const detail::CircularBuffer<detail::Command *> &Leaves =
Rec->MWriteLeaves.getGenericCommands();
ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd.front()) ==
Leaves.end());
ASSERT_TRUE(std::find(Leaves.begin(), Leaves.end(),
LeavesToAdd.front().get()) == Leaves.end());
for (std::size_t i = 1; i < LeavesToAdd.size(); ++i) {
assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i]) !=
assert(std::find(Leaves.begin(), Leaves.end(), LeavesToAdd[i].get()) !=
Leaves.end());
}
MockCommand *OldestLeaf = LeavesToAdd.front();
MockCommand *NewestLeaf = LeavesToAdd.back();
MockCommand *OldestLeaf = LeavesToAdd.front().get();
MockCommand *NewestLeaf = LeavesToAdd.back().get();
ASSERT_EQ(OldestLeaf->MUsers.size(), 1U);
EXPECT_GT(OldestLeaf->MUsers.count(NewestLeaf), 0U);
ASSERT_EQ(NewestLeaf->MDeps.size(), 2U);
Expand Down
23 changes: 21 additions & 2 deletions sycl/unittests/scheduler/SchedulerTestUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <detail/scheduler/scheduler.hpp>

#include <functional>
#include <gmock/gmock.h>

// This header contains a few common classes/methods used in
// execution graph testing.

Expand All @@ -24,12 +26,22 @@ class MockCommand : public cl::sycl::detail::Command {
cl::sycl::detail::Requirement Req,
cl::sycl::detail::Command::CommandType Type =
cl::sycl::detail::Command::RUN_CG)
: Command{Type, Queue}, MRequirement{std::move(Req)} {}
: Command{Type, Queue}, MRequirement{std::move(Req)} {
using namespace testing;
ON_CALL(*this, enqueue(_, _))
.WillByDefault(Invoke(this, &MockCommand::enqueueOrigin));
EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber());
}

MockCommand(cl::sycl::detail::QueueImplPtr Queue,
cl::sycl::detail::Command::CommandType Type =
cl::sycl::detail::Command::RUN_CG)
: Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} {}
: Command{Type, Queue}, MRequirement{std::move(getMockRequirement())} {
using namespace testing;
ON_CALL(*this, enqueue(_, _))
.WillByDefault(Invoke(this, &MockCommand::enqueueOrigin));
EXPECT_CALL(*this, enqueue(_, _)).Times(AnyNumber());
}

void printDot(std::ostream &) const override {}
void emitInstrumentationData() override {}
Expand All @@ -40,6 +52,13 @@ class MockCommand : public cl::sycl::detail::Command {

cl_int enqueueImp() override { return MRetVal; }

MOCK_METHOD2(enqueue, bool(cl::sycl::detail::EnqueueResultT &,
cl::sycl::detail::BlockingT));
bool enqueueOrigin(cl::sycl::detail::EnqueueResultT &EnqueueResult,
cl::sycl::detail::BlockingT Blocking) {
return cl::sycl::detail::Command::enqueue(EnqueueResult, Blocking);
}

cl_int MRetVal = CL_SUCCESS;

void waitForEventsCall(
Expand Down