Skip to content

Commit

Permalink
QUIC Client 0-RTT With Retry Token Fix
Browse files Browse the repository at this point in the history
Summary:
## Issue
When we enabled the JK to set QUIC server's rate limit to zero (D30703498), causing it to send retry packets to all incoming client hellos, we found that android clients were crashing (T99705615).

We were able to reproduce the bug with LigerIntegrationTest and HQClient backed by HQServer, causing the clients to hang.

## Fix
- When receiving a retry packet from the server, migrate the outstandings, stream manager, among other required fields to the new connection state.

- QuicStreamManager & QuicStreamState constructors to facilitate migrating to a new connection state.

Reviewed By: kvtsoy

Differential Revision: D31006905

fbshipit-source-id: 0490ceee1bef52b94c91019426d791e212820508
  • Loading branch information
hanidamlaj authored and facebook-github-bot committed Sep 23, 2021
1 parent acd9116 commit a38a1c7
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 5 deletions.
24 changes: 24 additions & 0 deletions quic/client/state/ClientStateMachine.cpp
Expand Up @@ -7,6 +7,8 @@
*/

#include <quic/client/state/ClientStateMachine.h>
#include <quic/codec/Types.h>
#include <quic/loss/QuicLossFunctions.h>

#include <folly/io/async/AsyncSocketException.h>
#include <quic/client/handshake/CachedServerTransportParameters.h>
Expand Down Expand Up @@ -57,6 +59,7 @@ std::unique_ptr<QuicClientConnectionState> undoAllClientStateForRetry(
std::move(conn->earlyDataAppParamsValidator);
newConn->earlyDataAppParamsGetter = std::move(conn->earlyDataAppParamsGetter);
newConn->happyEyeballsState = std::move(conn->happyEyeballsState);
newConn->flowControlState = std::move(conn->flowControlState);
newConn->pendingOneRttData.reserve(
newConn->transportSettings.maxPacketsToBuffer);
if (conn->congestionControllerFactory) {
Expand All @@ -69,6 +72,27 @@ std::unique_ptr<QuicClientConnectionState> undoAllClientStateForRetry(
*newConn, conn->congestionController->type());
}
}

// only copy over zero-rtt data
for (auto& outstandingPacket : conn->outstandings.packets) {
auto& packetHeader = outstandingPacket.packet.header;
if (packetHeader.getPacketNumberSpace() == PacketNumberSpace::AppData &&
packetHeader.getProtectionType() == ProtectionType::ZeroRtt) {
newConn->outstandings.packets.push_back(std::move(outstandingPacket));
newConn->outstandings.packetCount[PacketNumberSpace::AppData]++;
}
}

newConn->lossState = conn->lossState;
newConn->nodeType = conn->nodeType;
newConn->streamManager = std::make_unique<QuicStreamManager>(
*newConn,
newConn->nodeType,
newConn->transportSettings,
std::move(*conn->streamManager));

markZeroRttPacketsLost(*newConn, markPacketLoss);

return newConn;
}

Expand Down
102 changes: 98 additions & 4 deletions quic/fizz/client/test/QuicClientTransportTest.cpp
Expand Up @@ -5,11 +5,10 @@
* LICENSE file in the root directory of this source tree.
*
*/
#include <quic/api/test/Mocks.h>
#include <quic/client/QuicClientTransport.h>
#include <quic/server/QuicServer.h>

#include <quic/api/test/Mocks.h>

#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>

Expand All @@ -20,6 +19,7 @@
#include <folly/io/SocketOptionMap.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/io/async/test/MockAsyncUDPSocket.h>
#include <quic/QuicConstants.h>
#include <quic/codec/DefaultConnectionIdAlgo.h>
#include <quic/common/test/TestClientUtils.h>
#include <quic/common/test/TestUtils.h>
Expand All @@ -36,7 +36,6 @@
#include <quic/samples/echo/EchoHandler.h>
#include <quic/samples/echo/EchoServer.h>
#include <quic/state/test/MockQuicStats.h>
#include "quic/QuicConstants.h"

using namespace testing;
using namespace folly;
Expand Down Expand Up @@ -250,7 +249,9 @@ class QuicClientTransportIntegrationTest : public TestWithParam<TestingParams> {
return client;
}

std::shared_ptr<QuicServer> createServer(ProcessId processId) {
std::shared_ptr<QuicServer> createServer(
ProcessId processId,
bool withRetryPacket = false) {
auto server = QuicServer::createQuicServer();
auto transportSettings = server->getTransportSettings();
auto statsFactory = std::make_unique<NiceMock<MockQuicStatsFactory>>();
Expand All @@ -261,6 +262,12 @@ class QuicClientTransportIntegrationTest : public TestWithParam<TestingParams> {
}));
transportSettings.zeroRttSourceTokenMatchingPolicy =
ZeroRttSourceTokenMatchingPolicy::LIMIT_IF_NO_EXACT_MATCH;
if (withRetryPacket) {
std::array<uint8_t, kRetryTokenSecretLength> secret;
folly::Random::secureRandom(secret.data(), secret.size());
transportSettings.retryTokenSecret = secret;
server->setRateLimit([]() { return 0u; }, 1s);
}
server->setTransportStatsCallbackFactory(std::move(statsFactory));
server->setTransportSettings(transportSettings);
server->setQuicServerTransportFactory(
Expand Down Expand Up @@ -650,6 +657,93 @@ TEST_P(QuicClientTransportIntegrationTest, TestZeroRttSuccess) {
EXPECT_TRUE(client->getConn().statelessResetToken.has_value());
}

TEST_P(QuicClientTransportIntegrationTest, ZeroRttRetryPacketTest) {
/**
* logic extrapolated from TestZeroRttSuccess and RetryPacket tests
*/
auto retryServer = createServer(ProcessId::ONE, true);
client->getNonConstConn().peerAddress = retryServer->getAddress();

SCOPE_EXIT {
retryServer->shutdown();
retryServer = nullptr;
};

auto cachedPsk = setupZeroRttOnClientCtx(*clientCtx, hostname);
pskCache_->putPsk(hostname, cachedPsk);
setupZeroRttOnServerCtx(*serverCtx, cachedPsk);
// Change the ctx
retryServer->setFizzContext(serverCtx);

std::vector<uint8_t> clientConnIdVec = {};
ConnectionId clientConnId(clientConnIdVec);

ConnectionId initialDstConnId(kInitialDstConnIdVecForRetryTest);

auto qLogger = std::make_shared<FileQLogger>(VantagePoint::Client);
client->getNonConstConn().qLogger = qLogger;
client->getNonConstConn().readCodec->setClientConnectionId(clientConnId);
client->getNonConstConn().initialDestinationConnectionId = initialDstConnId;
client->getNonConstConn().originalDestinationConnectionId = initialDstConnId;
client->setCongestionControllerFactory(
std::make_shared<DefaultCongestionControllerFactory>());
client->setCongestionControl(CongestionControlType::NewReno);

folly::Optional<std::string> alpn = std::string("h1q-fb");
bool performedValidation = false;
client->setEarlyDataAppParamsFunctions(
[&](const folly::Optional<std::string>& alpnToValidate, const Buf&) {
performedValidation = true;
EXPECT_EQ(alpnToValidate, alpn);
return true;
},
[]() -> Buf { return nullptr; });
client->start(&clientConnCallback);
EXPECT_TRUE(performedValidation);
CHECK(client->getConn().zeroRttWriteCipher);
EXPECT_TRUE(client->serverInitialParamsSet());
EXPECT_EQ(
client->peerAdvertisedInitialMaxData(), kDefaultConnectionWindowSize);
EXPECT_EQ(
client->peerAdvertisedInitialMaxStreamDataBidiLocal(),
kDefaultStreamWindowSize);
EXPECT_EQ(
client->peerAdvertisedInitialMaxStreamDataBidiRemote(),
kDefaultStreamWindowSize);
EXPECT_EQ(
client->peerAdvertisedInitialMaxStreamDataUni(),
kDefaultStreamWindowSize);
EXPECT_CALL(clientConnCallback, onTransportReady()).WillOnce(Invoke([&] {
ASSERT_EQ(client->getAppProtocol(), "h1q-fb");
CHECK(client->getConn().zeroRttWriteCipher);
eventbase_.terminateLoopSoon();
}));
eventbase_.loopForever();

EXPECT_TRUE(client->getConn().zeroRttWriteCipher);
EXPECT_TRUE(client->good());
EXPECT_FALSE(client->replaySafe());

auto streamId = client->createBidirectionalStream().value();
auto data = IOBuf::copyBuffer("hello");
auto expected = std::shared_ptr<IOBuf>(IOBuf::copyBuffer("echo "));
expected->prependChain(data->clone());

EXPECT_CALL(clientConnCallback, onReplaySafe()).WillOnce(Invoke([&] {
EXPECT_TRUE(!client->getConn().retryToken.empty());
}));
sendRequestAndResponseAndWait(*expected, data->clone(), streamId, &readCb);

// Check CC is kept after retry recreates QuicClientConnectionState
EXPECT_TRUE(client->getConn().congestionControllerFactory);
EXPECT_EQ(
client->getConn().congestionController->type(),
CongestionControlType::NewReno);

EXPECT_FALSE(client->getConn().zeroRttWriteCipher);
EXPECT_TRUE(client->getConn().statelessResetToken.has_value());
}

TEST_P(QuicClientTransportIntegrationTest, TestZeroRttRejection) {
expectTransportCallbacks();
auto qLogger = std::make_shared<FileQLogger>(VantagePoint::Client);
Expand Down
3 changes: 2 additions & 1 deletion quic/samples/echo/EchoHandler.h
Expand Up @@ -49,7 +49,8 @@ class EchoHandler : public quic::QuicSocket::ConnectionCallback,

void onConnectionError(
std::pair<quic::QuicErrorCode, std::string> error) noexcept override {
LOG(ERROR) << "Socket error=" << toString(error.first);
LOG(ERROR) << "Socket error=" << toString(error.first) << " "
<< error.second;
}

void readAvailable(quic::StreamId id) noexcept override {
Expand Down
83 changes: 83 additions & 0 deletions quic/state/QuicStreamManager.h
Expand Up @@ -57,6 +57,89 @@ class QuicStreamManager {
}
refreshTransportSettings(transportSettings);
}

/**
* Constructor to facilitate migration of a QuicStreamManager to another
* QuicConnectionStateBase
*/
explicit QuicStreamManager(
QuicConnectionStateBase& conn,
QuicNodeType nodeType,
const TransportSettings& transportSettings,
QuicStreamManager&& other)
: conn_(conn),
nodeType_(nodeType),
transportSettings_(&transportSettings) {
nextAcceptablePeerBidirectionalStreamId_ =
other.nextAcceptablePeerBidirectionalStreamId_;
nextAcceptablePeerUnidirectionalStreamId_ =
other.nextAcceptablePeerUnidirectionalStreamId_;
nextAcceptableLocalBidirectionalStreamId_ =
other.nextAcceptableLocalBidirectionalStreamId_;
nextAcceptableLocalUnidirectionalStreamId_ =
other.nextAcceptableLocalUnidirectionalStreamId_;
nextBidirectionalStreamId_ = other.nextBidirectionalStreamId_;
nextUnidirectionalStreamId_ = other.nextUnidirectionalStreamId_;
maxLocalBidirectionalStreamId_ = other.maxLocalBidirectionalStreamId_;
maxLocalUnidirectionalStreamId_ = other.maxLocalUnidirectionalStreamId_;
maxRemoteBidirectionalStreamId_ = other.maxRemoteBidirectionalStreamId_;
maxRemoteUnidirectionalStreamId_ = other.maxRemoteUnidirectionalStreamId_;
initialLocalBidirectionalStreamId_ =
other.initialLocalBidirectionalStreamId_;
initialLocalUnidirectionalStreamId_ =
other.initialLocalUnidirectionalStreamId_;
initialRemoteBidirectionalStreamId_ =
other.initialRemoteBidirectionalStreamId_;
initialRemoteUnidirectionalStreamId_ =
other.initialRemoteUnidirectionalStreamId_;

streamLimitWindowingFraction_ = other.streamLimitWindowingFraction_;
remoteBidirectionalStreamLimitUpdate_ =
other.remoteBidirectionalStreamLimitUpdate_;
remoteUnidirectionalStreamLimitUpdate_ =
other.remoteUnidirectionalStreamLimitUpdate_;
numControlStreams_ = other.numControlStreams_;
openBidirectionalPeerStreams_ =
std::move(other.openBidirectionalPeerStreams_);
openUnidirectionalPeerStreams_ =
std::move(other.openUnidirectionalPeerStreams_);
openBidirectionalLocalStreams_ =
std::move(other.openBidirectionalLocalStreams_);
openUnidirectionalLocalStreams_ =
std::move(other.openUnidirectionalLocalStreams_);
newPeerStreams_ = std::move(other.newPeerStreams_);
blockedStreams_ = std::move(other.blockedStreams_);
stopSendingStreams_ = std::move(other.stopSendingStreams_);
windowUpdates_ = std::move(other.windowUpdates_);
flowControlUpdated_ = std::move(other.flowControlUpdated_);
lossStreams_ = std::move(other.lossStreams_);
readableStreams_ = std::move(other.readableStreams_);
peekableStreams_ = std::move(other.peekableStreams_);
writableStreams_ = std::move(other.writableStreams_);
writableDSRStreams_ = std::move(other.writableDSRStreams_);
writableControlStreams_ = std::move(other.writableControlStreams_);
txStreams_ = std::move(other.txStreams_);
deliverableStreams_ = std::move(other.deliverableStreams_);
closedStreams_ = std::move(other.closedStreams_);
isAppIdle_ = other.isAppIdle_;
maxLocalBidirectionalStreamIdIncreased_ =
other.maxLocalBidirectionalStreamIdIncreased_;
maxLocalUnidirectionalStreamIdIncreased_ =
other.maxLocalUnidirectionalStreamIdIncreased_;

/**
* We can't simply std::move the streams as the underlying
* QuicStreamState(s) hold a reference to the other.conn_.
*/
for (auto& pair : other.streams_) {
streams_.emplace(
std::piecewise_construct,
std::forward_as_tuple(pair.first),
std::forward_as_tuple(
/* migrate state to new conn ref */ conn_,
std::move(pair.second)));
}
}
/*
* Create the state for a stream if it does not exist and return it. Note this
* function is only used internally or for testing.
Expand Down
24 changes: 24 additions & 0 deletions quic/state/StreamData.h
Expand Up @@ -228,6 +228,30 @@ struct QuicStreamState : public QuicStreamLike {

QuicStreamState(QuicStreamState&&) = default;

/**
* Constructor to migrate QuicStreamState to another
* QuicConnectionStateBase.
*/
QuicStreamState(QuicConnectionStateBase& connIn, QuicStreamState&& other)
: QuicStreamLike(std::move(other)), conn(connIn), id(other.id) {
// QuicStreamState fields
finalWriteOffset = other.finalWriteOffset;
flowControlState = other.flowControlState;
streamReadError = other.streamReadError;
streamWriteError = other.streamWriteError;
sendState = other.sendState;
recvState = other.recvState;
isControl = other.isControl;
lastHolbTime = other.lastHolbTime;
totalHolbTime = other.totalHolbTime;
holbCount = other.holbCount;
priority = other.priority;
dsrSender = std::move(other.dsrSender);
writeBufMeta = other.writeBufMeta;
retransmissionBufMetas = std::move(other.retransmissionBufMetas);
lossBufMetas = std::move(other.lossBufMetas);
}

// Connection that this stream is associated with.
QuicConnectionStateBase& conn;

Expand Down

0 comments on commit a38a1c7

Please sign in to comment.