Skip to content

Commit

Permalink
adding more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
csegarragonz committed Jun 14, 2021
1 parent aa9f765 commit 3101386
Showing 1 changed file with 137 additions and 8 deletions.
145 changes: 137 additions & 8 deletions tests/test/scheduler/test_remote_mpi_worlds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test send across hosts", "[mpi]")
// Send a message that should get sent to this host
remoteWorld.send(
rankB, rankA, BYTES(messageData.data()), MPI_INT, messageData.size());
usleep(1000 * 500);
remoteWorld.destroy();
});

Expand Down Expand Up @@ -93,21 +92,18 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
faabric::util::setMockMode(false);

std::thread senderThread([this, rankA, rankB, numMessages] {
std::vector<int> messageData = { 0, 1, 2 };

remoteWorld.initialiseFromMsg(msg);

for (int i = 0; i < numMessages; i++) {
remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, sizeof(int));
remoteWorld.send(rankB, rankA, BYTES(&i), MPI_INT, 1);
}
usleep(1000 * 500);
remoteWorld.destroy();
});

int recv;
for (int i = 0; i < numMessages; i++) {
localWorld.recv(
rankB, rankA, BYTES(&recv), MPI_INT, sizeof(int), MPI_STATUS_IGNORE);
rankB, rankA, BYTES(&recv), MPI_INT, 1, MPI_STATUS_IGNORE);

// Check in-order delivery
if (i % (numMessages / 10) == 0) {
Expand Down Expand Up @@ -331,7 +327,7 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
int recvRank = 0;
std::vector<int> messageData = { 0, 1, 2 };

// Initi world
// Init world
MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
faabric::util::setMockMode(false);

Expand All @@ -352,7 +348,6 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
MPI_INT,
messageData.size());

usleep(1000 * 500);
remoteWorld.destroy();
});

Expand Down Expand Up @@ -384,4 +379,138 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
senderThread.join();
localWorld.destroy();
}

TEST_CASE_METHOD(RemoteMpiTestFixture,
"Test receiving remote async requests out of order",
"[mpi]")
{
// Allocate two ranks in total, one rank per host
this->setWorldsSizes(2, 1, 1);
int sendRank = 1;
int recvRank = 0;

// Init world
MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
faabric::util::setMockMode(false);

std::thread senderThread([this, sendRank, recvRank] {
remoteWorld.initialiseFromMsg(msg);

// Send different messages
for (int i = 0; i < 3; i++) {
remoteWorld.send(sendRank,
recvRank,
BYTES(&i),
MPI_INT,
1);
}

remoteWorld.destroy();
});

// Receive two messages asynchronously
int recv1, recv2, recv3;
int recvId1 = localWorld.irecv(sendRank,
recvRank,
BYTES(&recv1),
MPI_INT,
1);

int recvId2 = localWorld.irecv(sendRank,
recvRank,
BYTES(&recv2),
MPI_INT,
1);

// Receive one message synchronously
localWorld.recv(sendRank,
recvRank,
BYTES(&recv3),
MPI_INT,
1,
MPI_STATUS_IGNORE);

SECTION("Wait out of order")
{
localWorld.awaitAsyncRequest(recvId2);
localWorld.awaitAsyncRequest(recvId1);
}

SECTION("Wait in order")
{
localWorld.awaitAsyncRequest(recvId1);
localWorld.awaitAsyncRequest(recvId2);
}

// Checks
REQUIRE(recv1 == 0);
REQUIRE(recv2 == 1);
REQUIRE(recv3 == 2);

// Destroy world
senderThread.join();
localWorld.destroy();
}

TEST_CASE_METHOD(RemoteMpiTestFixture,
"Test ring sendrecv across hosts",
"[mpi]")
{
// Allocate two ranks in total, one rank per host
this->setWorldsSizes(3, 1, 2);
int worldSize = 3;
std::vector<int> localRanks = {0};

// Init world
MpiWorld& localWorld = getMpiWorldRegistry().createWorld(msg, worldId);
faabric::util::setMockMode(false);

std::thread senderThread([this, worldSize] {
std::vector<int> remoteRanks = {1, 2};
remoteWorld.initialiseFromMsg(msg);

// Send different messages
for (auto& rank : remoteRanks) {
int left = rank > 0 ? rank - 1 : worldSize - 1;
int right = (rank + 1) % worldSize;
int recvData = -1;

remoteWorld.sendRecv(BYTES(&rank),
1,
MPI_INT,
right,
BYTES(&recvData),
1,
MPI_INT,
left,
rank,
MPI_STATUS_IGNORE);
}

remoteWorld.destroy();
});

for (auto& rank : localRanks) {
int left = rank > 0 ? rank - 1 : worldSize - 1;
int right = (rank + 1) % worldSize;
int recvData = -1;

localWorld.sendRecv(BYTES(&rank),
1,
MPI_INT,
right,
BYTES(&recvData),
1,
MPI_INT,
left,
rank,
MPI_STATUS_IGNORE);

REQUIRE(recvData == left);
}

// Destroy world
senderThread.join();
localWorld.destroy();
}
}

0 comments on commit 3101386

Please sign in to comment.