diff --git a/docs/modules/server.md b/docs/modules/server.md index f91826cc..7bf019c8 100644 --- a/docs/modules/server.md +++ b/docs/modules/server.md @@ -18,8 +18,36 @@ ### Functions +- [handleProtocols](server.md#handleprotocols) - [makeServer](server.md#makeserver) +## Other + +### handleProtocols + +▸ **handleProtocols**(`protocols`): typeof [`GRAPHQL_TRANSPORT_WS_PROTOCOL`](common.md#graphql_transport_ws_protocol) \| ``false`` + +Helper utility for choosing the "graphql-transport-ws" subprotocol from +a set of WebSocket subprotocols. + +Accepts a set of already extracted WebSocket subprotocols or the raw +Sec-WebSocket-Protocol header value. In either case, if the right +protocol appears, it will be returned. + +By specification, the server should not provide a value with Sec-WebSocket-Protocol +if it does not agree with client's subprotocols. The client has a responsibility +to handle the connection afterwards. + +#### Parameters + +| Name | Type | +| :------ | :------ | +| `protocols` | `string` \| `Set`<`string`\> | + +#### Returns + +typeof [`GRAPHQL_TRANSPORT_WS_PROTOCOL`](common.md#graphql_transport_ws_protocol) \| ``false`` + ## Server ### GraphQLExecutionContextValue diff --git a/src/__tests__/use.ts b/src/__tests__/use.ts index bfa1775b..bfdc5ed7 100644 --- a/src/__tests__/use.ts +++ b/src/__tests__/use.ts @@ -15,6 +15,7 @@ import { WSExtra, UWSExtra, FastifyExtra, + TClient, } from './utils'; // silence console.error calls for nicer tests overview @@ -30,7 +31,7 @@ afterAll(() => { for (const { tServer, skipUWS, startTServer } of tServers) { describe(tServer, () => { - it('should allow connections with valid protocols only', async () => { + it("should omit the subprotocol from the response if there's no valid one offered by the client", async () => { const { url } = await startTServer(); const warn = console.warn; @@ -38,29 +39,41 @@ for (const { tServer, skipUWS, startTServer } of tServers) { /* hide warnings for test */ }; - let client = await createTClient(url, 'notme'); - await client.waitForClose((event) => { - expect(event.code).toBe(CloseCode.SubprotocolNotAcceptable); - expect(event.reason).toBe('Subprotocol not acceptable'); - expect(event.wasClean).toBeTruthy(); - }); + let client: TClient; + try { + client = await createTClient(url, ['notme', 'notmeither']); + } catch (err) { + expect(err).toMatchInlineSnapshot( + '[Error: Server sent no subprotocol]', + ); + } - client = await createTClient(url, ['graphql', 'json']); - await client.waitForClose((event) => { - expect(event.code).toBe(CloseCode.SubprotocolNotAcceptable); - expect(event.reason).toBe('Subprotocol not acceptable'); - expect(event.wasClean).toBeTruthy(); - }); + try { + client = await createTClient(url, 'notme'); + } catch (err) { + expect(err).toMatchInlineSnapshot( + '[Error: Server sent no subprotocol]', + ); + } - client = await createTClient( - url, - GRAPHQL_TRANSPORT_WS_PROTOCOL + 'gibberish', - ); - await client.waitForClose((event) => { - expect(event.code).toBe(CloseCode.SubprotocolNotAcceptable); - expect(event.reason).toBe('Subprotocol not acceptable'); - expect(event.wasClean).toBeTruthy(); - }); + try { + client = await createTClient(url, ['graphql', 'json']); + } catch (err) { + expect(err).toMatchInlineSnapshot( + '[Error: Server sent no subprotocol]', + ); + } + + try { + client = await createTClient( + url, + GRAPHQL_TRANSPORT_WS_PROTOCOL + 'gibberish', + ); + } catch (err) { + expect(err).toMatchInlineSnapshot( + '[Error: Server sent no subprotocol]', + ); + } client = await createTClient(url, GRAPHQL_TRANSPORT_WS_PROTOCOL); await client.waitForClose( @@ -68,6 +81,19 @@ for (const { tServer, skipUWS, startTServer } of tServers) { 30, // should be kicked off within this time ); + client = await createTClient(url, [ + 'this', + GRAPHQL_TRANSPORT_WS_PROTOCOL, + 'one', + ]); + await client.waitForClose( + (e) => { + console.log(e); + fail('shouldnt close for valid protocol'); + }, + 30, // should be kicked off within this time + ); + console.warn = warn; }); diff --git a/src/__tests__/utils/tclient.ts b/src/__tests__/utils/tclient.ts index ce02be93..d7d971bb 100644 --- a/src/__tests__/utils/tclient.ts +++ b/src/__tests__/utils/tclient.ts @@ -1,27 +1,30 @@ import WebSocket from 'ws'; import { GRAPHQL_TRANSPORT_WS_PROTOCOL } from '../../common'; +export interface TClient { + ws: WebSocket; + waitForMessage: ( + test?: (data: WebSocket.MessageEvent) => void, + expire?: number, + ) => Promise; + waitForClose: ( + test?: (event: WebSocket.CloseEvent) => void, + expire?: number, + ) => Promise; +} + // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types export function createTClient( url: string, protocols: string | string[] = GRAPHQL_TRANSPORT_WS_PROTOCOL, -) { +): Promise { let closeEvent: WebSocket.CloseEvent; const queue: WebSocket.MessageEvent[] = []; - return new Promise<{ - ws: WebSocket; - waitForMessage: ( - test?: (data: WebSocket.MessageEvent) => void, - expire?: number, - ) => Promise; - waitForClose: ( - test?: (event: WebSocket.CloseEvent) => void, - expire?: number, - ) => Promise; - }>((resolve) => { + return new Promise((resolve, reject) => { const ws = new WebSocket(url, protocols); ws.onclose = (event) => (closeEvent = event); // just so that none are missed ws.onmessage = (message) => queue.push(message); // guarantee message delivery with a queue + ws.once('error', reject); ws.once('open', () => resolve({ ws, diff --git a/src/server.ts b/src/server.ts index 012c176e..2a2cbb4b 100644 --- a/src/server.ts +++ b/src/server.ts @@ -848,3 +848,30 @@ export function makeServer< }, }; } + +/** + * Helper utility for choosing the "graphql-transport-ws" subprotocol from + * a set of WebSocket subprotocols. + * + * Accepts a set of already extracted WebSocket subprotocols or the raw + * Sec-WebSocket-Protocol header value. In either case, if the right + * protocol appears, it will be returned. + * + * By specification, the server should not provide a value with Sec-WebSocket-Protocol + * if it does not agree with client's subprotocols. The client has a responsibility + * to handle the connection afterwards. + */ +export function handleProtocols( + protocols: Set | string, +): typeof GRAPHQL_TRANSPORT_WS_PROTOCOL | false { + return ( + typeof protocols === 'string' + ? protocols + .split(',') + .map((p) => p.trim()) + .includes(GRAPHQL_TRANSPORT_WS_PROTOCOL) + : protocols.has(GRAPHQL_TRANSPORT_WS_PROTOCOL) + ) + ? GRAPHQL_TRANSPORT_WS_PROTOCOL + : false; +} diff --git a/src/use/fastify-websocket.ts b/src/use/fastify-websocket.ts index e1304988..d4f2172c 100644 --- a/src/use/fastify-websocket.ts +++ b/src/use/fastify-websocket.ts @@ -1,6 +1,6 @@ import type { FastifyRequest } from 'fastify'; import type * as fastifyWebsocket from 'fastify-websocket'; -import { makeServer, ServerOptions } from '../server'; +import { handleProtocols, makeServer, ServerOptions } from '../server'; import { GRAPHQL_TRANSPORT_WS_PROTOCOL, ConnectionInitMessage, @@ -56,6 +56,9 @@ export function makeHandler< return function handler(connection, request) { const { socket } = connection; + // might be too late, but meh + this.websocketServer.options.handleProtocols = handleProtocols; + // handle server emitted errors only if not already handling if (!handlingServerEmittedErrors) { handlingServerEmittedErrors = true; diff --git a/src/use/uWebSockets.ts b/src/use/uWebSockets.ts index a789a933..d175b3fc 100644 --- a/src/use/uWebSockets.ts +++ b/src/use/uWebSockets.ts @@ -1,6 +1,6 @@ import type * as uWS from 'uWebSockets.js'; import type http from 'http'; -import { makeServer, ServerOptions } from '../server'; +import { handleProtocols, makeServer, ServerOptions } from '../server'; import { ConnectionInitMessage, CloseCode } from '../common'; import { limitCloseReason } from '../utils'; @@ -92,6 +92,7 @@ export function makeBehavior< return { ...behavior, + pong(...args) { behavior.pong?.(...args); const [socket] = args; @@ -123,7 +124,8 @@ export function makeBehavior< }, }, req.getHeader('sec-websocket-key'), - req.getHeader('sec-websocket-protocol'), + handleProtocols(req.getHeader('sec-websocket-protocol')) || + new Uint8Array(), req.getHeader('sec-websocket-extensions'), context, ); @@ -147,7 +149,10 @@ export function makeBehavior< client.closed = server.opened( { - protocol: persistedRequest.headers['sec-websocket-protocol'] ?? '', + protocol: + handleProtocols( + persistedRequest.headers['sec-websocket-protocol'] || '', + ) || '', send: async (message) => { // the socket might have been destroyed in the meantime if (!clients.has(socket)) return; diff --git a/src/use/ws.ts b/src/use/ws.ts index af61ac6b..1ed7a17a 100644 --- a/src/use/ws.ts +++ b/src/use/ws.ts @@ -1,6 +1,6 @@ import type * as http from 'http'; import type * as ws from 'ws'; -import { makeServer, ServerOptions } from '../server'; +import { handleProtocols, makeServer, ServerOptions } from '../server'; import { GRAPHQL_TRANSPORT_WS_PROTOCOL, ConnectionInitMessage, @@ -54,6 +54,8 @@ export function useServer< const isProd = process.env.NODE_ENV === 'production'; const server = makeServer(options); + ws.options.handleProtocols = handleProtocols; + ws.once('error', (err) => { console.error( 'Internal error emitted on the WebSocket server. ' +