Skip to content

Commit

Permalink
Don't Allow Version Negotiation Packets for Server Connections
Browse files Browse the repository at this point in the history
Adds a new test case that was able to repro the crash before the fix (in connection.c) was added.
  • Loading branch information
nibanks committed Oct 10, 2023
1 parent 5732f89 commit d1cb96a
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 31 deletions.
3 changes: 2 additions & 1 deletion src/core/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -3687,7 +3687,8 @@ QuicConnRecvHeader(
//
// Do not return FALSE here, continue with the connection.
//
} else if (Packet->Invariant->LONG_HDR.Version == QUIC_VERSION_VER_NEG &&
} else if (QuicConnIsClient(Connection) &&
Packet->Invariant->LONG_HDR.Version == QUIC_VERSION_VER_NEG &&
!Connection->Stats.VersionNegotiation) {
//
// Version negotiation packet received.
Expand Down
19 changes: 15 additions & 4 deletions src/inc/msquic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1140,21 +1140,32 @@ struct MsQuicConnection {
};

struct MsQuicAutoAcceptListener : public MsQuicListener {
const MsQuicConfiguration& Configuration;
const MsQuicConfiguration* Configuration;
MsQuicConnectionCallback* ConnectionHandler;
void* ConnectionContext;
#ifdef CX_PLATFORM_TYPE
uint32_t AcceptedConnectionCount {0};
#endif

MsQuicAutoAcceptListener(
_In_ const MsQuicRegistration& Registration,
_In_ MsQuicConnectionCallback* _ConnectionHandler,
_In_ void* _ConnectionContext = nullptr
) noexcept :
MsQuicListener(Registration, ListenerCallback, this),
Configuration(nullptr),
ConnectionHandler(_ConnectionHandler),
ConnectionContext(_ConnectionContext)
{ }

MsQuicAutoAcceptListener(
_In_ const MsQuicRegistration& Registration,
_In_ const MsQuicConfiguration& Config,
_In_ MsQuicConnectionCallback* _ConnectionHandler,
_In_ void* _ConnectionContext = nullptr
) noexcept :
MsQuicListener(Registration, ListenerCallback, this),
Configuration(Config),
Configuration(&Config),
ConnectionHandler(_ConnectionHandler),
ConnectionContext(_ConnectionContext)
{ }
Expand All @@ -1176,8 +1187,8 @@ struct MsQuicAutoAcceptListener : public MsQuicListener {
if (Event->Type == QUIC_LISTENER_EVENT_NEW_CONNECTION) {
auto Connection = new(std::nothrow) MsQuicConnection(Event->NEW_CONNECTION.Connection, CleanUpAutoDelete, pThis->ConnectionHandler, pThis->ConnectionContext);
if (Connection) {
Status = Connection->SetConfiguration(pThis->Configuration);
if (QUIC_FAILED(Status)) {
if (!pThis->Configuration ||
QUIC_FAILED(Status = Connection->SetConfiguration(*pThis->Configuration))) {
//
// The connection is being rejected. Let MsQuic free the handle.
//
Expand Down
11 changes: 10 additions & 1 deletion src/test/MsQuicTests.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,11 @@ QuicDrillTestInitialToken(
_In_ int Family
);

void
QuicDrillTestServerVNPacket(
_In_ int Family
);

//
// Datagram tests
//
Expand Down Expand Up @@ -1088,4 +1093,8 @@ typedef struct {
#define IOCTL_QUIC_RUN_CONNECT_AND_IDLE_FOR_DEST_CID_CHANGE \
QUIC_CTL_CODE(101, METHOD_BUFFERED, FILE_WRITE_DATA)

#define QUIC_MAX_IOCTL_FUNC_CODE 101
#define IOCTL_QUIC_RUN_DRILL_VN_PACKET_TOKEN \
QUIC_CTL_CODE(102, METHOD_BUFFERED, FILE_WRITE_DATA)
// int - Family

#define QUIC_MAX_IOCTL_FUNC_CODE 102
9 changes: 9 additions & 0 deletions src/test/bin/quic_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,15 @@ TEST_P(WithDrillInitialPacketTokenArgs, DrillInitialPacketToken) {
}
}

TEST_P(WithDrillInitialPacketTokenArgs, QuicDrillTestServerVNPacket) {
TestLoggerT<ParamType> Logger("QuicDrillTestServerVNPacket", GetParam());
if (TestingKernelMode) {
ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_DRILL_VN_PACKET_TOKEN, GetParam().Family));
} else {
QuicDrillTestServerVNPacket(GetParam().Family);
}
}

TEST_P(WithDatagramNegotiationArgs, DatagramNegotiation) {
TestLoggerT<ParamType> Logger("QuicTestDatagramNegotiation", GetParam());
if (TestingKernelMode) {
Expand Down
6 changes: 6 additions & 0 deletions src/test/bin/winkernel/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ size_t QUIC_IOCTL_BUFFER_SIZES[] =
0,
0,
0,
sizeof(INT32),
};

CXPLAT_STATIC_ASSERT(
Expand Down Expand Up @@ -1272,6 +1273,11 @@ QuicTestCtlEvtIoDeviceControl(
QuicTestCtlRun(QuicTestConnectAndIdleForDestCidChange());
break;

case IOCTL_QUIC_RUN_DRILL_VN_PACKET_TOKEN:
CXPLAT_FRE_ASSERT(Params != nullptr);
QuicTestCtlRun(QuicDrillTestServerVNPacket(Params->Family));
break;

default:
Status = STATUS_NOT_IMPLEMENTED;
break;
Expand Down
35 changes: 24 additions & 11 deletions src/test/lib/DrillDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,31 @@ DrillPacketDescriptor::write(
}
PacketBuffer.insert(PacketBuffer.end(), SourceCid.begin(), SourceCid.end());

//
// TODO: Do type-specific stuff here.
//
return PacketBuffer;
}

DrillBuffer
DrillVNPacketDescriptor::write(
) const
{
DrillBuffer PacketBuffer = DrillPacketDescriptor::write();

// uint32_t SupportedVersions[]
uint32_t SupportedVer = QUIC_VERSION_2_H;
for (size_t i = 0; i < sizeof(SupportedVer); ++i) {
PacketBuffer.push_back((uint8_t) (SupportedVer >> (((sizeof(SupportedVer) - 1) - i) * 8)));
}
SupportedVer = QUIC_VERSION_1_MS_H;
for (size_t i = 0; i < sizeof(SupportedVer); ++i) {
PacketBuffer.push_back((uint8_t) (SupportedVer >> (((sizeof(SupportedVer) - 1) - i) * 8)));
}

return PacketBuffer;
}

DrillInitialPacketDescriptor::DrillInitialPacketDescriptor(
) : DrillPacketDescriptor(), TokenLen(nullptr), PacketLength(nullptr), PacketNumber(0)
DrillInitialPacketDescriptor::DrillInitialPacketDescriptor()
{
Type = Initial;
Header.LongHeader = 1;
Header.FixedBit = 1;
Version = QUIC_VERSION_LATEST_H;

Expand Down Expand Up @@ -197,10 +210,7 @@ DrillInitialPacketDescriptor::write(
}

CalculatedPacketLength += PacketNumberBuffer.size();

//
// TODO: Calculate the payload length.
//
CalculatedPacketLength += Payload.size();

//
// Write packet length.
Expand All @@ -219,8 +229,11 @@ DrillInitialPacketDescriptor::write(
PacketBuffer.insert(PacketBuffer.end(), PacketNumberBuffer.begin(), PacketNumberBuffer.end());

//
// TODO: Write payload here.
// Write payload.
//
if (Payload.size() > 0) {
PacketBuffer.insert(PacketBuffer.end(), Payload.begin(), Payload.end());
}

return PacketBuffer;
}
26 changes: 17 additions & 9 deletions src/test/lib/DrillDescriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,39 +64,45 @@ struct DrillPacketDescriptor {
//
// The type of datagram this describes.
//
DrillPacketDescriptorType Type;
DrillPacketDescriptorType Type {VersionNegotiation};

QuicHeader Header;
QuicHeader Header {0};

uint32_t Version;
uint32_t Version {QUIC_VERSION_VER_NEG};

//
// Optional destination CID length. If not set, will use length of DestCid.
//
uint8_t* DestCidLen;
uint8_t* DestCidLen {nullptr};
DrillBuffer DestCid;

//
// Optional source CID length. If not set, will use length of SourceCid.
//
uint8_t* SourceCidLen;
uint8_t* SourceCidLen {nullptr};
DrillBuffer SourceCid;

DrillPacketDescriptor() : DestCidLen(nullptr), SourceCidLen(nullptr) {};
DrillPacketDescriptor() { Header.LongHeader = TRUE; }

//
// Write this descriptor to a byte array to send on the wire.
//
virtual DrillBuffer write() const;
};

struct DrillVNPacketDescriptor : DrillPacketDescriptor {
//
// Write this descriptor to a byte array to send on the wire.
//
virtual DrillBuffer write() const;
};

struct DrillInitialPacketDescriptor : DrillPacketDescriptor {
//
// Optional Token length for the token. If unspecified, uses the length
// of Token below.
//
uint64_t* TokenLen;
uint64_t* TokenLen {nullptr};

//
// Token is optional. If unspecified, then it is elidded.
Expand All @@ -107,13 +113,15 @@ struct DrillInitialPacketDescriptor : DrillPacketDescriptor {
// If unspecified, this value is auto-calculated from the fields.
// Otherwise, this value is used regardless of actual packet length.
//
uint64_t* PacketLength;
uint64_t* PacketLength {nullptr};

//
// The caller must ensure the packet number length bits in the header
// match the magnitude of this PacketNumber.
//
uint32_t PacketNumber;
uint32_t PacketNumber {0};

DrillBuffer Payload;


DrillInitialPacketDescriptor();
Expand Down
50 changes: 45 additions & 5 deletions src/test/lib/QuicDrill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,12 @@ struct DrillSender {

QUIC_STATUS
Send(
_In_ const DrillBuffer* PacketBuffer
_In_ const DrillBuffer& PacketBuffer
)
{
QUIC_STATUS Status = QUIC_STATUS_SUCCESS;
CXPLAT_FRE_ASSERT(PacketBuffer->size() <= UINT16_MAX);
const uint16_t DatagramLength = (uint16_t) PacketBuffer->size();
CXPLAT_FRE_ASSERT(PacketBuffer.size() <= UINT16_MAX);
const uint16_t DatagramLength = (uint16_t) PacketBuffer.size();

CXPLAT_ROUTE Route;
CxPlatSocketGetLocalAddress(Binding, &Route.LocalAddress);
Expand All @@ -215,7 +215,7 @@ struct DrillSender {
//
// Copy test packet into SendBuffer.
//
memcpy(SendBuffer->Buffer, PacketBuffer->data(), DatagramLength);
memcpy(SendBuffer->Buffer, PacketBuffer.data(), DatagramLength);

Status =
CxPlatSocketSend(
Expand Down Expand Up @@ -307,7 +307,7 @@ QuicDrillInitialPacketFailureTest(
//
// Send test packet to the server.
//
Status = Sender.Send(&PacketBuffer);
Status = Sender.Send(PacketBuffer);
if (QUIC_FAILED(Status)) {
return false;
}
Expand Down Expand Up @@ -492,3 +492,43 @@ QuicDrillTestInitialToken(
}
}
}

void
QuicDrillTestServerVNPacket(
_In_ int Family
)
{
MsQuicRegistration Registration(true);
TEST_QUIC_SUCCEEDED(Registration.GetInitStatus());

QUIC_ADDRESS_FAMILY QuicAddrFamily = (Family == 4) ? QUIC_ADDRESS_FAMILY_INET : QUIC_ADDRESS_FAMILY_INET6;
QuicAddr ServerLocalAddr(QuicAddrFamily);

MsQuicAutoAcceptListener Listener(Registration, MsQuicConnection::NoOpCallback);
TEST_QUIC_SUCCEEDED(Listener.Start("MsQuicTest", &ServerLocalAddr.SockAddr));
TEST_QUIC_SUCCEEDED(Listener.GetInitStatus());
TEST_QUIC_SUCCEEDED(Listener.GetLocalAddr(ServerLocalAddr));

DrillSender Sender;
TEST_QUIC_SUCCEEDED(
Sender.Initialize(
QUIC_TEST_LOOPBACK_FOR_AF(QuicAddrFamily),
QuicAddrFamily,
(QuicAddrFamily == QUIC_ADDRESS_FAMILY_INET) ?
ServerLocalAddr.SockAddr.Ipv4.sin_port :
ServerLocalAddr.SockAddr.Ipv6.sin6_port));

uint8_t SourceCidLen = 0;
DrillInitialPacketDescriptor InitialPacketBuffer;
InitialPacketBuffer.SourceCidLen = &SourceCidLen;
for (uint8_t i = 0; i < 8; ++i) { InitialPacketBuffer.DestCid.push_back(i); }
for (uint16_t i = 0; i < 1200; ++i) { InitialPacketBuffer.Payload.push_back(0); }

DrillVNPacketDescriptor VNPacketBuffer;
for (uint8_t i = 0; i < 8; ++i) { VNPacketBuffer.DestCid.push_back(i); }

TEST_QUIC_SUCCEEDED(Sender.Send(InitialPacketBuffer.write()));
TEST_QUIC_SUCCEEDED(Sender.Send(VNPacketBuffer.write()));

CxPlatSleep(500);
}

0 comments on commit d1cb96a

Please sign in to comment.