Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/shared/inc/SocketChannel.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class SocketChannel

#ifdef WIN32
m_exitEvents = std::move(other.m_exitEvents);
m_pendingBytes = std::move(other.m_pendingBytes);
#endif
m_ignore_sequence = other.m_ignore_sequence;
m_sent_non_transaction_messages = other.m_sent_non_transaction_messages;
Expand Down Expand Up @@ -636,9 +637,8 @@ class SocketChannel
auto io = CreateIO();

gsl::span<gsl::byte> message;

io.AddHandle(std::make_unique<windows::common::io::ReadSocketMessageHandle>(
m_socket.get(), m_buffer, [&message](auto& received) { message = received; }));
m_socket.get(), m_buffer, m_pendingBytes, [&message](auto& received) { message = received; }));

io.Run(TimeoutToMilliseconds(timeout));

Expand Down Expand Up @@ -723,6 +723,7 @@ class SocketChannel
#ifdef WIN32

std::vector<HANDLE> m_exitEvents;
std::vector<gsl::byte> m_pendingBytes;

#endif
uint32_t m_sent_non_transaction_messages = 0;
Expand Down
107 changes: 81 additions & 26 deletions src/windows/common/HandleIO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ LARGE_INTEGER InitializeFileOffset(HANDLE File)
return Offset;
}

void CancelPendingIo(auto Handle, OVERLAPPED& Overlapped)
DWORD CancelPendingIo(auto Handle, OVERLAPPED& Overlapped)
{
DWORD bytesTransferred{};
if (CancelIoEx((HANDLE)Handle, &Overlapped))
if (CancelIoEx((HANDLE)Handle, &Overlapped) || GetLastError() == ERROR_NOT_FOUND)
{
if constexpr (std::is_same_v<decltype(Handle), SOCKET>)
{
if (!WSAGetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true, nullptr))
DWORD flagsReturned{};
if (!WSAGetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true, &flagsReturned))
{
auto error = WSAGetLastError();
LOG_LAST_ERROR_IF(error != WSAECONNABORTED && error != WSA_OPERATION_ABORTED && error != WSAECONNRESET);
Expand All @@ -56,9 +57,10 @@ void CancelPendingIo(auto Handle, OVERLAPPED& Overlapped)
}
else
{
// ERROR_NOT_FOUND is returned if there was no IO to cancel.
LOG_LAST_ERROR_IF(GetLastError() != ERROR_NOT_FOUND);
LOG_LAST_ERROR_MSG("Unexpected error while cancelling IO on handle: 0x%p", (void*)Handle);
}

return bytesTransferred;
}

inline void UnregisterWait(HANDLE waitHandle) noexcept
Expand Down Expand Up @@ -528,22 +530,65 @@ void HTTPChunkBasedReadHandle::OnRead(const gsl::span<char>& Input)
// ReadSocketMessageHandle

ReadSocketMessageHandle::ReadSocketMessageHandle(
HandleWrapper&& MovedSocket, std::vector<gsl::byte>& Buffer, std::function<void(const gsl::span<gsl::byte>& Message)>&& OnMessage) :
Socket(std::move(MovedSocket)), Buffer(Buffer), OnMessage(std::move(OnMessage))
HandleWrapper&& MovedSocket,
std::vector<gsl::byte>& Buffer,
std::vector<gsl::byte>& PendingBytes,
std::function<void(const gsl::span<gsl::byte>& Message)>&& OnMessage) :
Socket(std::move(MovedSocket)), Buffer(Buffer), PendingBytes(PendingBytes), OnMessage(std::move(OnMessage))
{
Overlapped.hEvent = Event.get();

if (Buffer.size() < sizeof(MESSAGE_HEADER))
{
Buffer.resize(sizeof(MESSAGE_HEADER));
}

if (PendingBytes.empty())
{
return;
}

// If bytes from a previously cancelled transaction are passed, process them now.
if (Buffer.size() < PendingBytes.size())
{
Buffer.resize(PendingBytes.size());
}

std::copy(PendingBytes.begin(), PendingBytes.end(), Buffer.begin());
CurrentOffset = PendingBytes.size();
PendingBytes.clear();

if (CurrentOffset < sizeof(MESSAGE_HEADER))
{
BytesRemaining = sizeof(MESSAGE_HEADER) - CurrentOffset;
}
else
{
BytesRemaining = 0;
}
}

ReadSocketMessageHandle::~ReadSocketMessageHandle()
{
if (State == IOHandleStatus::Pending)
if (State != IOHandleStatus::Completed)
{
CancelPendingIo((SOCKET)Socket.Get(), Overlapped);
auto pendingSize = CurrentOffset;

if (State == IOHandleStatus::Pending)
{
// Cancel the pending receive and move any bytes already buffered for the in-flight message into PendingBytes
const auto socket = reinterpret_cast<SOCKET>(Socket.Get());
pendingSize += CancelPendingIo(socket, Overlapped);
}

if (pendingSize > 0)
{
WI_ASSERT(pendingSize <= Buffer.size());
PendingBytes.assign(Buffer.begin(), Buffer.begin() + pendingSize);

WSL_LOG(
"CanceledMessageRead", TraceLoggingValue(pendingSize, "TotalBytes"), TraceLoggingValue(Socket.Get(), "Socket"));
}
}
}

Expand Down Expand Up @@ -601,40 +646,50 @@ void ReadSocketMessageHandle::ProcessRecvResult(DWORD BytesRead)
return;
}

ProcessChunk();
}

bool ReadSocketMessageHandle::ProcessChunk()
{
const auto messageSize = gslhelpers::get_struct<MESSAGE_HEADER>(gsl::make_span(Buffer.data(), sizeof(MESSAGE_HEADER)))->MessageSize;

if (ReadingHeader)
{
auto messageSize = gslhelpers::get_struct<MESSAGE_HEADER>(gsl::make_span(Buffer.data(), sizeof(MESSAGE_HEADER)))->MessageSize;

THROW_HR_IF_MSG(E_UNEXPECTED, messageSize < sizeof(MESSAGE_HEADER), "Unexpected message size: %u", messageSize);
THROW_HR_IF_MSG(E_UNEXPECTED, messageSize > 4 * 1024 * 1024, "Message size too large: %u", messageSize);

if (messageSize == sizeof(MESSAGE_HEADER))
{
OnMessage(gsl::make_span(Buffer.data(), messageSize));
State = IOHandleStatus::Completed;
return;
}

if (Buffer.size() < messageSize)
{
Buffer.resize(messageSize);
}

ReadingHeader = false;
CurrentOffset = sizeof(MESSAGE_HEADER);
BytesRemaining = messageSize - sizeof(MESSAGE_HEADER);
}
else
{
auto messageSize = gslhelpers::get_struct<MESSAGE_HEADER>(gsl::make_span(Buffer.data(), sizeof(MESSAGE_HEADER)))->MessageSize;
OnMessage(gsl::make_span(Buffer.data(), messageSize));
State = IOHandleStatus::Completed;
if (CurrentOffset < messageSize)
{
BytesRemaining = messageSize - CurrentOffset;
}

if (BytesRemaining > 0)
{
return true;
}
}

OnMessage(gsl::make_span(Buffer.data(), messageSize));
State = IOHandleStatus::Completed;
return false;
}

void ReadSocketMessageHandle::Schedule()
{
WI_ASSERT(State == IOHandleStatus::Standby);

// Process previously received bytes, if any.
if (BytesRemaining == 0 && !ProcessChunk())
{
return; // Message has been fully received, no need to schedule a receive.
}

ScheduleRecv();
}

Expand Down
8 changes: 7 additions & 1 deletion src/windows/common/HandleIO.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ class ReadSocketMessageHandle : public OverlappedIOHandle
NON_COPYABLE(ReadSocketMessageHandle);
NON_MOVABLE(ReadSocketMessageHandle);

ReadSocketMessageHandle(HandleWrapper&& Socket, std::vector<gsl::byte>& Buffer, std::function<void(const gsl::span<gsl::byte>& Message)>&& OnMessage);
ReadSocketMessageHandle(
HandleWrapper&& Socket,
std::vector<gsl::byte>& Buffer,
std::vector<gsl::byte>& PendingBytes,
std::function<void(const gsl::span<gsl::byte>& Message)>&& OnMessage);
~ReadSocketMessageHandle();

void Schedule() override;
Expand All @@ -188,9 +192,11 @@ class ReadSocketMessageHandle : public OverlappedIOHandle
private:
void ScheduleRecv();
void ProcessRecvResult(DWORD BytesRead);
bool ProcessChunk();

HandleWrapper Socket;
std::vector<gsl::byte>& Buffer;
std::vector<gsl::byte>& PendingBytes;
std::function<void(const gsl::span<gsl::byte>& Message)> OnMessage;
wil::unique_event Event{wil::EventOptions::ManualReset};
OVERLAPPED Overlapped{};
Expand Down
137 changes: 135 additions & 2 deletions test/windows/UnitTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6842,14 +6842,17 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",
// Drive a ReadSocketMessageHandle until completion and return the bytes delivered to its
// OnMessage callback. If a non-success HRESULT is supplied, the call is expected to throw
// that HRESULT instead, and the OnMessage callback must not be invoked.
auto readMessage = [](wil::unique_socket&& server, HRESULT expectedHr = S_OK) {
auto readMessage = [](wil::unique_socket&& server, HRESULT expectedHr = S_OK, std::vector<gsl::byte> pendingBytes = {}) {
std::vector<gsl::byte> buffer;
bool callbackInvoked = false;
std::vector<gsl::byte> message;

wsl::windows::common::io::MultiHandleWait io;
io.AddHandle(std::make_unique<wsl::windows::common::io::ReadSocketMessageHandle>(
wsl::windows::common::io::HandleWrapper{std::move(server)}, buffer, [&callbackInvoked, &message](const gsl::span<gsl::byte>& received) {
wsl::windows::common::io::HandleWrapper{std::move(server)},
buffer,
pendingBytes,
[&callbackInvoked, &message](const gsl::span<gsl::byte>& received) {
callbackInvoked = true;
message.assign(received.begin(), received.end());
}));
Expand Down Expand Up @@ -6944,6 +6947,136 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",

readMessage(std::move(server), E_UNEXPECTED);
}

// Scenario 6: PendingBytes carries a complete header-only message left over from a
// previous aborted receive. The reader should deliver it without touching the socket.
{
auto [client, server] = MakeSocketPair();
client.reset(); // close the peer; we should still complete from PendingBytes alone.

MESSAGE_HEADER header{};
header.MessageType = LxMiniInitMessageAny;
header.MessageSize = sizeof(header);
header.TransactionId = 77;
header.TransactionStep = 1;

const auto* headerBytes = reinterpret_cast<const gsl::byte*>(&header);
std::vector<gsl::byte> pendingBytes(headerBytes, headerBytes + sizeof(header));

const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes));
VERIFY_ARE_EQUAL(message.size(), sizeof(header));
VERIFY_IS_TRUE(std::memcmp(message.data(), &header, sizeof(header)) == 0);
}

// Scenario 7: PendingBytes carries a complete message with a body left over from a
// previous aborted receive. The reader should deliver it without touching the socket.
{
auto [client, server] = MakeSocketPair();
client.reset();

constexpr size_t bodySize = 32;
std::vector<gsl::byte> payload(sizeof(MESSAGE_HEADER) + bodySize);
auto* header = reinterpret_cast<MESSAGE_HEADER*>(payload.data());
header->MessageType = LxMiniInitMessageAny;
header->MessageSize = gsl::narrow_cast<unsigned int>(payload.size());
header->TransactionId = 81;
header->TransactionStep = 3;
for (size_t i = 0; i < bodySize; ++i)
{
payload[sizeof(MESSAGE_HEADER) + i] = static_cast<gsl::byte>(i ^ 0xA5);
}

std::vector<gsl::byte> pendingBytes(payload.begin(), payload.end());

const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes));
VERIFY_ARE_EQUAL(message.size(), payload.size());
VERIFY_IS_TRUE(std::memcmp(message.data(), payload.data(), payload.size()) == 0);
}

// Scenario 8: PendingBytes carries only part of a header. The reader must fill in the
// rest of the header (and the body) from the socket and deliver the assembled message.
{
auto [client, server] = MakeSocketPair();

constexpr size_t bodySize = 48;
constexpr size_t prebufferedBytes = 6; // less than sizeof(MESSAGE_HEADER) = 16
std::vector<gsl::byte> payload(sizeof(MESSAGE_HEADER) + bodySize);
auto* header = reinterpret_cast<MESSAGE_HEADER*>(payload.data());
header->MessageType = LxMiniInitMessageAny;
header->MessageSize = gsl::narrow_cast<unsigned int>(payload.size());
header->TransactionId = 91;
header->TransactionStep = 4;
for (size_t i = 0; i < bodySize; ++i)
{
payload[sizeof(MESSAGE_HEADER) + i] = static_cast<gsl::byte>(0xC3);
}

std::vector<gsl::byte> pendingBytes(payload.begin(), payload.begin() + prebufferedBytes);
WriteSocket(client.get(), payload.data() + prebufferedBytes, payload.size() - prebufferedBytes);
client.reset();

const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes));
VERIFY_ARE_EQUAL(message.size(), payload.size());
VERIFY_IS_TRUE(std::memcmp(message.data(), payload.data(), payload.size()) == 0);
}

// Scenario 9: PendingBytes carries the full header plus part of the body. The reader
// must read the remaining body bytes from the socket and deliver the assembled message.
{
auto [client, server] = MakeSocketPair();

constexpr size_t bodySize = 64;
constexpr size_t prebufferedBodyBytes = 12;
std::vector<gsl::byte> payload(sizeof(MESSAGE_HEADER) + bodySize);
auto* header = reinterpret_cast<MESSAGE_HEADER*>(payload.data());
header->MessageType = LxMiniInitMessageAny;
header->MessageSize = gsl::narrow_cast<unsigned int>(payload.size());
header->TransactionId = 92;
header->TransactionStep = 5;
for (size_t i = 0; i < bodySize; ++i)
{
payload[sizeof(MESSAGE_HEADER) + i] = static_cast<gsl::byte>(i & 0xFF);
}

const size_t prebufferedBytes = sizeof(MESSAGE_HEADER) + prebufferedBodyBytes;
std::vector<gsl::byte> pendingBytes(payload.begin(), payload.begin() + prebufferedBytes);
WriteSocket(client.get(), payload.data() + prebufferedBytes, payload.size() - prebufferedBytes);
client.reset();

const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes));
VERIFY_ARE_EQUAL(message.size(), payload.size());
VERIFY_IS_TRUE(std::memcmp(message.data(), payload.data(), payload.size()) == 0);
}

// Scenario 10: PendingBytes contains an invalid (too-small) message size. The
// IO should detect this and throw E_UNEXPECTED without invoking OnMessage.
{
auto [client, server] = MakeSocketPair();
client.reset();

MESSAGE_HEADER header{};
header.MessageType = LxMiniInitMessageAny;
header.MessageSize = sizeof(header) - 1; // invalid: smaller than the header itself
header.TransactionId = 99;
header.TransactionStep = 1;

const auto* headerBytes = reinterpret_cast<const gsl::byte*>(&header);
std::vector<gsl::byte> pendingBytes{headerBytes, headerBytes + sizeof(header)};

std::vector<gsl::byte> buffer;
bool callbackInvoked = false;
const auto hr = wil::ResultFromException([&]() {
wsl::windows::common::io::MultiHandleWait io;
io.AddHandle(std::make_unique<wsl::windows::common::io::ReadSocketMessageHandle>(
wsl::windows::common::io::HandleWrapper{std::move(server)},
buffer,
pendingBytes,
[&callbackInvoked](const gsl::span<gsl::byte>&) { callbackInvoked = true; }));
io.Run(std::chrono::seconds(60));
});
VERIFY_ARE_EQUAL(hr, E_UNEXPECTED);
VERIFY_IS_FALSE(callbackInvoked);
}
}

TEST_METHOD(MultiHandleWaitAboveMaximumWaitObjects)
Expand Down