Skip to content

Commit

Permalink
Fix Cancel On Loss Test's Race Condition (#4121)
Browse files Browse the repository at this point in the history
  • Loading branch information
nibanks committed Feb 9, 2024
1 parent fa50391 commit c19e900
Showing 1 changed file with 23 additions and 65 deletions.
88 changes: 23 additions & 65 deletions src/test/lib/DataTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1398,7 +1398,6 @@ struct CancelOnLossContext
CxPlatEvent SendPhaseEndedEvent = {};
};


_Function_class_(MsQuicStreamCallback)
QUIC_STATUS
QuicCancelOnLossStreamHandler(
Expand All @@ -1407,27 +1406,18 @@ QuicCancelOnLossStreamHandler(
_Inout_ QUIC_STREAM_EVENT* Event
)
{
if (Context == nullptr) {
return QUIC_STATUS_INVALID_PARAMETER;
}

auto TestContext = reinterpret_cast<CancelOnLossContext*>(Context);

QUIC_STATUS Status = QUIC_STATUS_SUCCESS;

switch (Event->Type) {
case QUIC_STREAM_EVENT_RECEIVE:
if (TestContext->IsServer) { // only server receives
TestContext->SendPhaseEndedEvent.Set();
TestContext->ExitCode = CancelOnLossContext::SuccessExitCode;
TestContext->SendPhaseEndedEvent.Set();
}
break;
case QUIC_STREAM_EVENT_PEER_SEND_ABORTED:
if (TestContext->IsServer) { // server-side 'cancel on loss' detection
TestContext->SendPhaseEndedEvent.Set();
TestContext->ExitCode = Event->PEER_SEND_ABORTED.ErrorCode;
} else {
Status = QUIC_STATUS_INVALID_STATE;
TestContext->SendPhaseEndedEvent.Set();
}
break;
case QUIC_STREAM_EVENT_PEER_SEND_SHUTDOWN:
Expand All @@ -1440,23 +1430,18 @@ QuicCancelOnLossStreamHandler(
if (!TestContext->IsDropScenario) { // if drop scenario, we use 'cancel on loss' event
TestContext->SendPhaseEndedEvent.Set();
}
} else {
Status = QUIC_STATUS_INVALID_STATE;
}
break;
case QUIC_STREAM_EVENT_CANCEL_ON_LOSS:
if (!TestContext->IsServer && TestContext->IsDropScenario) { // only client sends & only happens if in drop scenario
Event->CANCEL_ON_LOSS.ErrorCode = CancelOnLossContext::ErrorExitCode;
TestContext->SendPhaseEndedEvent.Set();
} else {
Status = QUIC_STATUS_INVALID_STATE;
}
break;
default:
break;
}

return Status;
return QUIC_STATUS_SUCCESS;
}

_Function_class_(MsQuicConnectionCallback)
Expand All @@ -1467,14 +1452,7 @@ QuicCancelOnLossConnectionHandler(
_Inout_ QUIC_CONNECTION_EVENT* Event
)
{
if (Context == nullptr) {
return QUIC_STATUS_INVALID_PARAMETER;
}

auto TestContext = reinterpret_cast<CancelOnLossContext*>(Context);

QUIC_STATUS Status = QUIC_STATUS_SUCCESS;

switch (Event->Type) {
case QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED:
TestContext->Stream = new(std::nothrow) MsQuicStream(
Expand All @@ -1489,8 +1467,7 @@ QuicCancelOnLossConnectionHandler(
default:
break;
}

return Status;
return QUIC_STATUS_SUCCESS;
}

void
Expand All @@ -1511,49 +1488,38 @@ QuicCancelOnLossSend(

uint8_t RawBuffer[] = "cancel on loss message";
QUIC_BUFFER MessageBuffer = { sizeof(RawBuffer), RawBuffer };

SelectiveLossHelper LossHelper; // used later to trigger packet drops

// Start the server.
MsQuicConfiguration ServerConfiguration(Registration, Alpn, Settings, ServerSelfSignedCredConfig);
TEST_TRUE(ServerConfiguration.IsValid());
MsQuicCredentialConfig ClientCredConfig;
MsQuicConfiguration ClientConfiguration(Registration, Alpn, Settings, ClientCredConfig);
TEST_TRUE(ClientConfiguration.IsValid());

CancelOnLossContext ServerContext{ DropPackets, true /* IsServer */, &ServerConfiguration};
QuicAddr ServerLocalAddr;

CancelOnLossContext ServerContext(DropPackets, true /* IsServer */, &ServerConfiguration);
MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, QuicCancelOnLossConnectionHandler, &ServerContext);
TEST_TRUE(Listener.IsValid());
TEST_EQUAL(Listener.Start(Alpn), QUIC_STATUS_SUCCESS);
QuicAddr ServerLocalAddr;
TEST_EQUAL(Listener.GetLocalAddr(ServerLocalAddr), QUIC_STATUS_SUCCESS);

// Start the client.
MsQuicCredentialConfig ClientCredConfig;
MsQuicConfiguration ClientConfiguration(Registration, Alpn, Settings, ClientCredConfig);
TEST_TRUE(ClientConfiguration.IsValid());

CancelOnLossContext ClientContext{ DropPackets, false /* IsServer */, &ClientConfiguration};

// Initiate connection.
CancelOnLossContext ClientContext(DropPackets, false /* IsServer */, &ClientConfiguration);
ClientContext.Connection = new(std::nothrow) MsQuicConnection(
Registration,
CleanUpManual,
QuicCancelOnLossConnectionHandler,
&ClientContext);
TEST_TRUE(ClientContext.Connection->IsValid());

QUIC_STATUS Status = ClientContext.Connection->Start(
ClientConfiguration,
QUIC_ADDRESS_FAMILY_INET,
QUIC_TEST_LOOPBACK_FOR_AF(QUIC_ADDRESS_FAMILY_INET),
ServerLocalAddr.GetPort());
if (QUIC_FAILED(Status)) {
TEST_FAILURE("Failed to start a connection from the client.");
return;
}
TEST_QUIC_SUCCEEDED(
ClientContext.Connection->Start(
ClientConfiguration,
QUIC_ADDRESS_FAMILY_INET,
QUIC_TEST_LOOPBACK_FOR_AF(QUIC_ADDRESS_FAMILY_INET),
ServerLocalAddr.GetPort()));

// Wait for connection to be established.
constexpr uint32_t EventWaitTimeoutMs{ 1'000 };

if (!ClientContext.ConnectedEvent.WaitTimeout(EventWaitTimeoutMs)) {
TEST_FAILURE("Client failed to get connected before timeout!");
return;
Expand All @@ -1574,18 +1540,8 @@ QuicCancelOnLossSend(
QuicCancelOnLossStreamHandler,
&ClientContext);
TEST_TRUE(ClientContext.Stream->IsValid());
Status = ClientContext.Stream->Start();
if (QUIC_FAILED(Status)) {
TEST_FAILURE("Client failed to start stream.");
return;
}

// Send test message.
Status = ClientContext.Stream->Send(&MessageBuffer, 1, QUIC_SEND_FLAG_CANCEL_ON_LOSS);
if (QUIC_FAILED(Status)) {
TEST_FAILURE("Client failed to send message.");
return;
}
TEST_QUIC_SUCCEEDED(ClientContext.Stream->Start());
TEST_QUIC_SUCCEEDED(ClientContext.Stream->Send(&MessageBuffer, 1, QUIC_SEND_FLAG_CANCEL_ON_LOSS));

// If requested, drop packets.
if (DropPackets) {
Expand All @@ -1603,11 +1559,13 @@ QuicCancelOnLossSend(

// Check results.
if (DropPackets) {
TEST_EQUAL(ServerContext.ExitCode, CancelOnLossContext::ErrorExitCode);
if (ServerContext.ExitCode != CancelOnLossContext::ErrorExitCode) {
TEST_FAILURE("ServerContext.ExitCode %u != ErrorExitCode", ServerContext.ExitCode);
}
} else {
if (ServerContext.ExitCode != CancelOnLossContext::SuccessExitCode) {
TEST_FAILURE("ServerContext.ExitCode %u, CancelOnLossContext::SuccessExitCode: %u", ServerContext.ExitCode, CancelOnLossContext::SuccessExitCode);
}
TEST_FAILURE("ServerContext.ExitCode %u != SuccessExitCode", ServerContext.ExitCode);
}
}

if (Listener.LastConnection) {
Expand Down

0 comments on commit c19e900

Please sign in to comment.