diff --git a/src/shared/inc/SocketChannel.h b/src/shared/inc/SocketChannel.h index b20221736..83d4710f4 100644 --- a/src/shared/inc/SocketChannel.h +++ b/src/shared/inc/SocketChannel.h @@ -110,6 +110,7 @@ class SocketChannel #ifdef WIN32 m_exitEvents = std::move(other.m_exitEvents); + m_pendingBytes = std::move(other.m_pendingBytes); #endif m_ignore_sequence = other.m_ignore_sequence; m_sent_non_transaction_messages = other.m_sent_non_transaction_messages; @@ -636,9 +637,8 @@ class SocketChannel auto io = CreateIO(); gsl::span message; - io.AddHandle(std::make_unique( - m_socket.get(), m_buffer, [&message](auto& received) { message = received; })); + m_socket.get(), m_buffer, m_pendingBytes, [&message](auto& received) { message = received; })); io.Run(TimeoutToMilliseconds(timeout)); @@ -723,6 +723,7 @@ class SocketChannel #ifdef WIN32 std::vector m_exitEvents; + std::vector m_pendingBytes; #endif uint32_t m_sent_non_transaction_messages = 0; diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index b78e73a98..980714f72 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -31,14 +31,15 @@ LARGE_INTEGER InitializeFileOffset(HANDLE File) return Offset; } -void CancelPendingIo(auto Handle, OVERLAPPED& Overlapped) +DWORD CancelPendingIo(auto Handle, OVERLAPPED& Overlapped) { DWORD bytesTransferred{}; - if (CancelIoEx((HANDLE)Handle, &Overlapped)) + if (CancelIoEx((HANDLE)Handle, &Overlapped) || GetLastError() == ERROR_NOT_FOUND) { if constexpr (std::is_same_v) { - if (!WSAGetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true, nullptr)) + DWORD flagsReturned{}; + if (!WSAGetOverlappedResult(Handle, &Overlapped, &bytesTransferred, true, &flagsReturned)) { auto error = WSAGetLastError(); LOG_LAST_ERROR_IF(error != WSAECONNABORTED && error != WSA_OPERATION_ABORTED && error != WSAECONNRESET); @@ -56,9 +57,10 @@ void CancelPendingIo(auto Handle, OVERLAPPED& Overlapped) } else { - // ERROR_NOT_FOUND is returned if there was no IO to cancel. - LOG_LAST_ERROR_IF(GetLastError() != ERROR_NOT_FOUND); + LOG_LAST_ERROR_MSG("Unexpected error while cancelling IO on handle: 0x%p", (void*)Handle); } + + return bytesTransferred; } inline void UnregisterWait(HANDLE waitHandle) noexcept @@ -528,8 +530,11 @@ void HTTPChunkBasedReadHandle::OnRead(const gsl::span& Input) // ReadSocketMessageHandle ReadSocketMessageHandle::ReadSocketMessageHandle( - HandleWrapper&& MovedSocket, std::vector& Buffer, std::function& Message)>&& OnMessage) : - Socket(std::move(MovedSocket)), Buffer(Buffer), OnMessage(std::move(OnMessage)) + HandleWrapper&& MovedSocket, + std::vector& Buffer, + std::vector& PendingBytes, + std::function& Message)>&& OnMessage) : + Socket(std::move(MovedSocket)), Buffer(Buffer), PendingBytes(PendingBytes), OnMessage(std::move(OnMessage)) { Overlapped.hEvent = Event.get(); @@ -537,13 +542,53 @@ ReadSocketMessageHandle::ReadSocketMessageHandle( { Buffer.resize(sizeof(MESSAGE_HEADER)); } + + if (PendingBytes.empty()) + { + return; + } + + // If bytes from a previously cancelled transaction are passed, process them now. + if (Buffer.size() < PendingBytes.size()) + { + Buffer.resize(PendingBytes.size()); + } + + std::copy(PendingBytes.begin(), PendingBytes.end(), Buffer.begin()); + CurrentOffset = PendingBytes.size(); + PendingBytes.clear(); + + if (CurrentOffset < sizeof(MESSAGE_HEADER)) + { + BytesRemaining = sizeof(MESSAGE_HEADER) - CurrentOffset; + } + else + { + BytesRemaining = 0; + } } ReadSocketMessageHandle::~ReadSocketMessageHandle() { - if (State == IOHandleStatus::Pending) + if (State != IOHandleStatus::Completed) { - CancelPendingIo((SOCKET)Socket.Get(), Overlapped); + auto pendingSize = CurrentOffset; + + if (State == IOHandleStatus::Pending) + { + // Cancel the pending receive and move any bytes already buffered for the in-flight message into PendingBytes + const auto socket = reinterpret_cast(Socket.Get()); + pendingSize += CancelPendingIo(socket, Overlapped); + } + + if (pendingSize > 0) + { + WI_ASSERT(pendingSize <= Buffer.size()); + PendingBytes.assign(Buffer.begin(), Buffer.begin() + pendingSize); + + WSL_LOG( + "CanceledMessageRead", TraceLoggingValue(pendingSize, "TotalBytes"), TraceLoggingValue(Socket.Get(), "Socket")); + } } } @@ -601,40 +646,50 @@ void ReadSocketMessageHandle::ProcessRecvResult(DWORD BytesRead) return; } + ProcessChunk(); +} + +bool ReadSocketMessageHandle::ProcessChunk() +{ + const auto messageSize = gslhelpers::get_struct(gsl::make_span(Buffer.data(), sizeof(MESSAGE_HEADER)))->MessageSize; + if (ReadingHeader) { - auto messageSize = gslhelpers::get_struct(gsl::make_span(Buffer.data(), sizeof(MESSAGE_HEADER)))->MessageSize; - THROW_HR_IF_MSG(E_UNEXPECTED, messageSize < sizeof(MESSAGE_HEADER), "Unexpected message size: %u", messageSize); THROW_HR_IF_MSG(E_UNEXPECTED, messageSize > 4 * 1024 * 1024, "Message size too large: %u", messageSize); - if (messageSize == sizeof(MESSAGE_HEADER)) - { - OnMessage(gsl::make_span(Buffer.data(), messageSize)); - State = IOHandleStatus::Completed; - return; - } - if (Buffer.size() < messageSize) { Buffer.resize(messageSize); } ReadingHeader = false; - CurrentOffset = sizeof(MESSAGE_HEADER); - BytesRemaining = messageSize - sizeof(MESSAGE_HEADER); - } - else - { - auto messageSize = gslhelpers::get_struct(gsl::make_span(Buffer.data(), sizeof(MESSAGE_HEADER)))->MessageSize; - OnMessage(gsl::make_span(Buffer.data(), messageSize)); - State = IOHandleStatus::Completed; + if (CurrentOffset < messageSize) + { + BytesRemaining = messageSize - CurrentOffset; + } + + if (BytesRemaining > 0) + { + return true; + } } + + OnMessage(gsl::make_span(Buffer.data(), messageSize)); + State = IOHandleStatus::Completed; + return false; } void ReadSocketMessageHandle::Schedule() { WI_ASSERT(State == IOHandleStatus::Standby); + + // Process previously received bytes, if any. + if (BytesRemaining == 0 && !ProcessChunk()) + { + return; // Message has been fully received, no need to schedule a receive. + } + ScheduleRecv(); } diff --git a/src/windows/common/HandleIO.h b/src/windows/common/HandleIO.h index 387b4259d..c197d21ee 100644 --- a/src/windows/common/HandleIO.h +++ b/src/windows/common/HandleIO.h @@ -178,7 +178,11 @@ class ReadSocketMessageHandle : public OverlappedIOHandle NON_COPYABLE(ReadSocketMessageHandle); NON_MOVABLE(ReadSocketMessageHandle); - ReadSocketMessageHandle(HandleWrapper&& Socket, std::vector& Buffer, std::function& Message)>&& OnMessage); + ReadSocketMessageHandle( + HandleWrapper&& Socket, + std::vector& Buffer, + std::vector& PendingBytes, + std::function& Message)>&& OnMessage); ~ReadSocketMessageHandle(); void Schedule() override; @@ -188,9 +192,11 @@ class ReadSocketMessageHandle : public OverlappedIOHandle private: void ScheduleRecv(); void ProcessRecvResult(DWORD BytesRead); + bool ProcessChunk(); HandleWrapper Socket; std::vector& Buffer; + std::vector& PendingBytes; std::function& Message)> OnMessage; wil::unique_event Event{wil::EventOptions::ManualReset}; OVERLAPPED Overlapped{}; diff --git a/test/windows/UnitTests.cpp b/test/windows/UnitTests.cpp index 7b888a770..b9d54b369 100644 --- a/test/windows/UnitTests.cpp +++ b/test/windows/UnitTests.cpp @@ -6842,14 +6842,17 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n", // Drive a ReadSocketMessageHandle until completion and return the bytes delivered to its // OnMessage callback. If a non-success HRESULT is supplied, the call is expected to throw // that HRESULT instead, and the OnMessage callback must not be invoked. - auto readMessage = [](wil::unique_socket&& server, HRESULT expectedHr = S_OK) { + auto readMessage = [](wil::unique_socket&& server, HRESULT expectedHr = S_OK, std::vector pendingBytes = {}) { std::vector buffer; bool callbackInvoked = false; std::vector message; wsl::windows::common::io::MultiHandleWait io; io.AddHandle(std::make_unique( - wsl::windows::common::io::HandleWrapper{std::move(server)}, buffer, [&callbackInvoked, &message](const gsl::span& received) { + wsl::windows::common::io::HandleWrapper{std::move(server)}, + buffer, + pendingBytes, + [&callbackInvoked, &message](const gsl::span& received) { callbackInvoked = true; message.assign(received.begin(), received.end()); })); @@ -6944,6 +6947,136 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n", readMessage(std::move(server), E_UNEXPECTED); } + + // Scenario 6: PendingBytes carries a complete header-only message left over from a + // previous aborted receive. The reader should deliver it without touching the socket. + { + auto [client, server] = MakeSocketPair(); + client.reset(); // close the peer; we should still complete from PendingBytes alone. + + MESSAGE_HEADER header{}; + header.MessageType = LxMiniInitMessageAny; + header.MessageSize = sizeof(header); + header.TransactionId = 77; + header.TransactionStep = 1; + + const auto* headerBytes = reinterpret_cast(&header); + std::vector pendingBytes(headerBytes, headerBytes + sizeof(header)); + + const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes)); + VERIFY_ARE_EQUAL(message.size(), sizeof(header)); + VERIFY_IS_TRUE(std::memcmp(message.data(), &header, sizeof(header)) == 0); + } + + // Scenario 7: PendingBytes carries a complete message with a body left over from a + // previous aborted receive. The reader should deliver it without touching the socket. + { + auto [client, server] = MakeSocketPair(); + client.reset(); + + constexpr size_t bodySize = 32; + std::vector payload(sizeof(MESSAGE_HEADER) + bodySize); + auto* header = reinterpret_cast(payload.data()); + header->MessageType = LxMiniInitMessageAny; + header->MessageSize = gsl::narrow_cast(payload.size()); + header->TransactionId = 81; + header->TransactionStep = 3; + for (size_t i = 0; i < bodySize; ++i) + { + payload[sizeof(MESSAGE_HEADER) + i] = static_cast(i ^ 0xA5); + } + + std::vector pendingBytes(payload.begin(), payload.end()); + + const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes)); + VERIFY_ARE_EQUAL(message.size(), payload.size()); + VERIFY_IS_TRUE(std::memcmp(message.data(), payload.data(), payload.size()) == 0); + } + + // Scenario 8: PendingBytes carries only part of a header. The reader must fill in the + // rest of the header (and the body) from the socket and deliver the assembled message. + { + auto [client, server] = MakeSocketPair(); + + constexpr size_t bodySize = 48; + constexpr size_t prebufferedBytes = 6; // less than sizeof(MESSAGE_HEADER) = 16 + std::vector payload(sizeof(MESSAGE_HEADER) + bodySize); + auto* header = reinterpret_cast(payload.data()); + header->MessageType = LxMiniInitMessageAny; + header->MessageSize = gsl::narrow_cast(payload.size()); + header->TransactionId = 91; + header->TransactionStep = 4; + for (size_t i = 0; i < bodySize; ++i) + { + payload[sizeof(MESSAGE_HEADER) + i] = static_cast(0xC3); + } + + std::vector pendingBytes(payload.begin(), payload.begin() + prebufferedBytes); + WriteSocket(client.get(), payload.data() + prebufferedBytes, payload.size() - prebufferedBytes); + client.reset(); + + const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes)); + VERIFY_ARE_EQUAL(message.size(), payload.size()); + VERIFY_IS_TRUE(std::memcmp(message.data(), payload.data(), payload.size()) == 0); + } + + // Scenario 9: PendingBytes carries the full header plus part of the body. The reader + // must read the remaining body bytes from the socket and deliver the assembled message. + { + auto [client, server] = MakeSocketPair(); + + constexpr size_t bodySize = 64; + constexpr size_t prebufferedBodyBytes = 12; + std::vector payload(sizeof(MESSAGE_HEADER) + bodySize); + auto* header = reinterpret_cast(payload.data()); + header->MessageType = LxMiniInitMessageAny; + header->MessageSize = gsl::narrow_cast(payload.size()); + header->TransactionId = 92; + header->TransactionStep = 5; + for (size_t i = 0; i < bodySize; ++i) + { + payload[sizeof(MESSAGE_HEADER) + i] = static_cast(i & 0xFF); + } + + const size_t prebufferedBytes = sizeof(MESSAGE_HEADER) + prebufferedBodyBytes; + std::vector pendingBytes(payload.begin(), payload.begin() + prebufferedBytes); + WriteSocket(client.get(), payload.data() + prebufferedBytes, payload.size() - prebufferedBytes); + client.reset(); + + const auto message = readMessage(std::move(server), S_OK, std::move(pendingBytes)); + VERIFY_ARE_EQUAL(message.size(), payload.size()); + VERIFY_IS_TRUE(std::memcmp(message.data(), payload.data(), payload.size()) == 0); + } + + // Scenario 10: PendingBytes contains an invalid (too-small) message size. The + // IO should detect this and throw E_UNEXPECTED without invoking OnMessage. + { + auto [client, server] = MakeSocketPair(); + client.reset(); + + MESSAGE_HEADER header{}; + header.MessageType = LxMiniInitMessageAny; + header.MessageSize = sizeof(header) - 1; // invalid: smaller than the header itself + header.TransactionId = 99; + header.TransactionStep = 1; + + const auto* headerBytes = reinterpret_cast(&header); + std::vector pendingBytes{headerBytes, headerBytes + sizeof(header)}; + + std::vector buffer; + bool callbackInvoked = false; + const auto hr = wil::ResultFromException([&]() { + wsl::windows::common::io::MultiHandleWait io; + io.AddHandle(std::make_unique( + wsl::windows::common::io::HandleWrapper{std::move(server)}, + buffer, + pendingBytes, + [&callbackInvoked](const gsl::span&) { callbackInvoked = true; })); + io.Run(std::chrono::seconds(60)); + }); + VERIFY_ARE_EQUAL(hr, E_UNEXPECTED); + VERIFY_IS_FALSE(callbackInvoked); + } } TEST_METHOD(MultiHandleWaitAboveMaximumWaitObjects)