Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"type": "prerelease",
"comment": "Fix WebSocket binaryType handling — stop unconditional Blob interception of binary messages",
"packageName": "react-native-windows",
"email": "gordomacmaster@gmail.com",
"dependentChangeType": "patch"
}
30 changes: 30 additions & 0 deletions vnext/Desktop.IntegrationTests/RNTesterHeadlessTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,36 @@ TEST_CLASS (RNTesterHeadlessTests) {
auto status = TestModule::AwaitCompletion();
Assert::IsTrue(status == TestStatus::Passed, L"Test did not pass (JS did not call markTestPassed within timeout)");
}

BEGIN_TEST_METHOD_ATTRIBUTE(WebSocketArrayBuffer)
TEST_IGNORE()
END_TEST_METHOD_ATTRIBUTE()
TEST_METHOD(WebSocketArrayBuffer) {
TestModule::Reset();

winrt::handle instanceLoadedEvent{CreateEvent(nullptr, TRUE, FALSE, nullptr)};
bool instanceFailed{false};

auto holder = TestReactNativeHostHolder(
L"IntegrationTests/WebSocketArrayBufferTest",
[&instanceLoadedEvent, &instanceFailed](msrn::ReactNativeHost const &host) noexcept {
host.InstanceSettings().InstanceLoaded(
[&instanceLoadedEvent, &instanceFailed](auto const &, msrn::InstanceLoadedEventArgs args) noexcept {
instanceFailed = args.Failed();
SetEvent(instanceLoadedEvent.get());
});
});

WaitForSingleObject(instanceLoadedEvent.get(), INFINITE);
if (instanceFailed) {
auto err = holder.GetLastError();
auto msg = L"InstanceLoaded reported failure: " + (err.empty() ? L"(no error captured)" : err);
Assert::Fail(msg.c_str());
}

auto status = TestModule::AwaitCompletion();
Assert::IsTrue(status == TestStatus::Passed, L"Test did not pass (JS did not call markTestPassed within timeout)");
}
};

} // namespace Microsoft::React::Test
15 changes: 15 additions & 0 deletions vnext/Shared/Modules/IWebSocketModuleContentHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,26 @@ namespace Microsoft::React {
struct IWebSocketModuleContentHandler {
virtual ~IWebSocketModuleContentHandler() noexcept {}

/// Returns true if this handler should process messages for the given socket.
virtual bool CanHandleSocket(int64_t socketId) noexcept = 0;

virtual void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;

virtual void ProcessMessage(
std::vector<uint8_t> &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;

/// Check CanHandleSocket() then ProcessMessage() in one call.
/// Returns true if the message was handled.
virtual bool TryProcessMessage(
int64_t socketId,
std::string &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;

virtual bool TryProcessMessage(
int64_t socketId,
std::vector<uint8_t> &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept = 0;
};

} // namespace Microsoft::React
11 changes: 8 additions & 3 deletions vnext/Shared/Modules/WebSocketModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,23 @@ shared_ptr<IWebSocketResource> WebSocketTurboModule::CreateResource(int64_t id,
if (auto prop = propBag.Get(BlobModuleContentHandlerPropertyId()))
contentHandler = prop.Value().lock();

bool handled = false;
if (contentHandler) {
if (isBinary) {
auto buffer = CryptographicBuffer::DecodeFromBase64String(winrt::to_hstring(message));
winrt::com_array<uint8_t> arr;
CryptographicBuffer::CopyToByteArray(buffer, arr);
auto data = vector<uint8_t>(arr.begin(), arr.end());

contentHandler->ProcessMessage(std::move(data), args);
handled = contentHandler->TryProcessMessage(id, std::move(data), args);
} else {
contentHandler->ProcessMessage(string{message}, args);
handled = contentHandler->TryProcessMessage(id, string{message}, args);
}
} else {
}
// When the content handler processes the message, it takes ownership of the
// payload and populates args itself (e.g. as a blob reference), so we only
// fall back to setting args["data"] when no handler claimed the message.
if (!handled) {
args["data"] = message;
Comment thread
gmacmaster marked this conversation as resolved.
}

Expand Down
37 changes: 37 additions & 0 deletions vnext/Shared/Networking/DefaultBlobResource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ BlobWebSocketModuleContentHandler::BlobWebSocketModuleContentHandler(shared_ptr<

#pragma region IWebSocketModuleContentHandler

bool BlobWebSocketModuleContentHandler::CanHandleSocket(int64_t socketId) noexcept /*override*/ {
scoped_lock lock{m_mutex};
return m_socketIds.find(socketId) != m_socketIds.end();
}

void BlobWebSocketModuleContentHandler::ProcessMessage(
string &&message,
msrn::JSValueObject &params) noexcept /*override*/
Expand All @@ -241,6 +246,38 @@ void BlobWebSocketModuleContentHandler::ProcessMessage(
params[blobKeys.Type] = blobKeys.Blob;
}

bool BlobWebSocketModuleContentHandler::TryProcessMessage(
int64_t socketId,
string &&message,
msrn::JSValueObject &params) noexcept /*override*/
{
scoped_lock lock{m_mutex};
if (m_socketIds.find(socketId) == m_socketIds.end())
return false;

params[blobKeys.Data] = std::move(message);
return true;
}

bool BlobWebSocketModuleContentHandler::TryProcessMessage(
int64_t socketId,
vector<uint8_t> &&message,
msrn::JSValueObject &params) noexcept /*override*/
{
scoped_lock lock{m_mutex};
if (m_socketIds.find(socketId) == m_socketIds.end())
return false;

auto blob = msrn::JSValueObject{
{blobKeys.Offset, 0},
{blobKeys.Size, message.size()},
{blobKeys.BlobId, m_blobPersistor->StoreMessage(std::move(message))}};

params[blobKeys.Data] = std::move(blob);
params[blobKeys.Type] = blobKeys.Blob;
return true;
}

#pragma endregion IWebSocketModuleContentHandler

void BlobWebSocketModuleContentHandler::Register(int64_t socketID) noexcept {
Expand Down
12 changes: 12 additions & 0 deletions vnext/Shared/Networking/DefaultBlobResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,23 @@ class BlobWebSocketModuleContentHandler final : public IWebSocketModuleContentHa

#pragma region IWebSocketModuleContentHandler

bool CanHandleSocket(int64_t socketId) noexcept override;

void ProcessMessage(std::string &&message, winrt::Microsoft::ReactNative::JSValueObject &params) noexcept override;

void ProcessMessage(std::vector<uint8_t> &&message, winrt::Microsoft::ReactNative::JSValueObject &params) noexcept
override;

bool TryProcessMessage(
int64_t socketId,
std::string &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept override;

bool TryProcessMessage(
int64_t socketId,
std::vector<uint8_t> &&message,
winrt::Microsoft::ReactNative::JSValueObject &params) noexcept override;

#pragma endregion IWebSocketModuleContentHandler

void Register(int64_t socketID) noexcept;
Expand Down
4 changes: 4 additions & 0 deletions vnext/overrides.json
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@
"type": "platform",
"file": "src-win/IntegrationTests/websocket_integration_test_server_blob.js"
},
{
"type": "platform",
"file": "src-win/IntegrationTests/WebSocketArrayBufferTest.js"
},
{
"type": "platform",
"file": "src-win/IntegrationTests/WebSocketBinaryTest.js"
Expand Down
76 changes: 76 additions & 0 deletions vnext/src-win/IntegrationTests/WebSocketArrayBufferTest.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/**
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT License.
* @format
*/

'use strict';

const {TurboModuleRegistry} = require('react-native');
const TestModule = TurboModuleRegistry.get('TestModule');

if (!TestModule) {
throw new Error('TestModule is not available');
}

// eslint-disable-next-line @microsoft/sdl/no-insecure-url
const WS_URL = 'ws://localhost:5555/rnw/rntester/websocketbinarytest';

const socket = new WebSocket(WS_URL);
socket.binaryType = 'arraybuffer';

socket.addEventListener('open', () => {
socket.send('hello');
});

socket.addEventListener('message', event => {
const data = event.data;

if (!(data instanceof ArrayBuffer)) {
console.log(
'WebSocketArrayBufferTest FAIL: expected ArrayBuffer, got ' + typeof data,
);
TestModule.markTestPassed(false);
socket.close();
return;
}

const bytes = new Uint8Array(data);
const expected = new Uint8Array([4, 5, 6, 7]);

if (bytes.length !== expected.length) {
console.log(
'WebSocketArrayBufferTest FAIL: expected ' +
expected.length +
' bytes, got ' +
bytes.length,
);
TestModule.markTestPassed(false);
socket.close();
return;
}

for (let i = 0; i < expected.length; i++) {
if (bytes[i] !== expected[i]) {
console.log(
'WebSocketArrayBufferTest FAIL: byte[' +
i +
'] expected ' +
expected[i] +
' got ' +
bytes[i],
);
TestModule.markTestPassed(false);
socket.close();
return;
}
}

TestModule.markTestPassed(true);
socket.close();
});

socket.addEventListener('error', () => {
console.log('WebSocketArrayBufferTest FAIL: WebSocket error');
TestModule.markTestPassed(false);
});
Loading