Skip to content

Commit

Permalink
fix(server): Handle upgrade requests with multiple subprotocols and o…
Browse files Browse the repository at this point in the history
…mit `Sec-WebSocket-Protocol` header if none supported
  • Loading branch information
enisdenjo committed Feb 21, 2022
1 parent d47e1bb commit 9bae064
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 39 deletions.
28 changes: 28 additions & 0 deletions docs/modules/server.md
Expand Up @@ -18,8 +18,36 @@

### Functions

- [handleProtocols](server.md#handleprotocols)
- [makeServer](server.md#makeserver)

## Other

### handleProtocols

**handleProtocols**(`protocols`): typeof [`GRAPHQL_TRANSPORT_WS_PROTOCOL`](common.md#graphql_transport_ws_protocol) \| ``false``

Helper utility for choosing the "graphql-transport-ws" subprotocol from
a set of WebSocket subprotocols.

Accepts a set of already extracted WebSocket subprotocols or the raw
Sec-WebSocket-Protocol header value. In either case, if the right
protocol appears, it will be returned.

By specification, the server should not provide a value with Sec-WebSocket-Protocol
if it does not agree with client's subprotocols. The client has a responsibility
to handle the connection afterwards.

#### Parameters

| Name | Type |
| :------ | :------ |
| `protocols` | `string` \| `Set`<`string`\> |

#### Returns

typeof [`GRAPHQL_TRANSPORT_WS_PROTOCOL`](common.md#graphql_transport_ws_protocol) \| ``false``

## Server

### GraphQLExecutionContextValue
Expand Down
70 changes: 48 additions & 22 deletions src/__tests__/use.ts
Expand Up @@ -15,6 +15,7 @@ import {
WSExtra,
UWSExtra,
FastifyExtra,
TClient,
} from './utils';

// silence console.error calls for nicer tests overview
Expand All @@ -30,44 +31,69 @@ afterAll(() => {

for (const { tServer, skipUWS, startTServer } of tServers) {
describe(tServer, () => {
it('should allow connections with valid protocols only', async () => {
it("should omit the subprotocol from the response if there's no valid one offered by the client", async () => {
const { url } = await startTServer();

const warn = console.warn;
console.warn = () => {
/* hide warnings for test */
};

let client = await createTClient(url, 'notme');
await client.waitForClose((event) => {
expect(event.code).toBe(CloseCode.SubprotocolNotAcceptable);
expect(event.reason).toBe('Subprotocol not acceptable');
expect(event.wasClean).toBeTruthy();
});
let client: TClient;
try {
client = await createTClient(url, ['notme', 'notmeither']);
} catch (err) {
expect(err).toMatchInlineSnapshot(
'[Error: Server sent no subprotocol]',
);
}

client = await createTClient(url, ['graphql', 'json']);
await client.waitForClose((event) => {
expect(event.code).toBe(CloseCode.SubprotocolNotAcceptable);
expect(event.reason).toBe('Subprotocol not acceptable');
expect(event.wasClean).toBeTruthy();
});
try {
client = await createTClient(url, 'notme');
} catch (err) {
expect(err).toMatchInlineSnapshot(
'[Error: Server sent no subprotocol]',
);
}

client = await createTClient(
url,
GRAPHQL_TRANSPORT_WS_PROTOCOL + 'gibberish',
);
await client.waitForClose((event) => {
expect(event.code).toBe(CloseCode.SubprotocolNotAcceptable);
expect(event.reason).toBe('Subprotocol not acceptable');
expect(event.wasClean).toBeTruthy();
});
try {
client = await createTClient(url, ['graphql', 'json']);
} catch (err) {
expect(err).toMatchInlineSnapshot(
'[Error: Server sent no subprotocol]',
);
}

try {
client = await createTClient(
url,
GRAPHQL_TRANSPORT_WS_PROTOCOL + 'gibberish',
);
} catch (err) {
expect(err).toMatchInlineSnapshot(
'[Error: Server sent no subprotocol]',
);
}

client = await createTClient(url, GRAPHQL_TRANSPORT_WS_PROTOCOL);
await client.waitForClose(
() => fail('shouldnt close for valid protocol'),
30, // should be kicked off within this time
);

client = await createTClient(url, [
'this',
GRAPHQL_TRANSPORT_WS_PROTOCOL,
'one',
]);
await client.waitForClose(
(e) => {
console.log(e);
fail('shouldnt close for valid protocol');
},
30, // should be kicked off within this time
);

console.warn = warn;
});

Expand Down
27 changes: 15 additions & 12 deletions src/__tests__/utils/tclient.ts
@@ -1,27 +1,30 @@
import WebSocket from 'ws';
import { GRAPHQL_TRANSPORT_WS_PROTOCOL } from '../../common';

export interface TClient {
ws: WebSocket;
waitForMessage: (
test?: (data: WebSocket.MessageEvent) => void,
expire?: number,
) => Promise<void>;
waitForClose: (
test?: (event: WebSocket.CloseEvent) => void,
expire?: number,
) => Promise<void>;
}

// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export function createTClient(
url: string,
protocols: string | string[] = GRAPHQL_TRANSPORT_WS_PROTOCOL,
) {
): Promise<TClient> {
let closeEvent: WebSocket.CloseEvent;
const queue: WebSocket.MessageEvent[] = [];
return new Promise<{
ws: WebSocket;
waitForMessage: (
test?: (data: WebSocket.MessageEvent) => void,
expire?: number,
) => Promise<void>;
waitForClose: (
test?: (event: WebSocket.CloseEvent) => void,
expire?: number,
) => Promise<void>;
}>((resolve) => {
return new Promise((resolve, reject) => {
const ws = new WebSocket(url, protocols);
ws.onclose = (event) => (closeEvent = event); // just so that none are missed
ws.onmessage = (message) => queue.push(message); // guarantee message delivery with a queue
ws.once('error', reject);
ws.once('open', () =>
resolve({
ws,
Expand Down
27 changes: 27 additions & 0 deletions src/server.ts
Expand Up @@ -848,3 +848,30 @@ export function makeServer<
},
};
}

/**
* Helper utility for choosing the "graphql-transport-ws" subprotocol from
* a set of WebSocket subprotocols.
*
* Accepts a set of already extracted WebSocket subprotocols or the raw
* Sec-WebSocket-Protocol header value. In either case, if the right
* protocol appears, it will be returned.
*
* By specification, the server should not provide a value with Sec-WebSocket-Protocol
* if it does not agree with client's subprotocols. The client has a responsibility
* to handle the connection afterwards.
*/
export function handleProtocols(
protocols: Set<string> | string,
): typeof GRAPHQL_TRANSPORT_WS_PROTOCOL | false {
return (
typeof protocols === 'string'
? protocols
.split(',')
.map((p) => p.trim())
.includes(GRAPHQL_TRANSPORT_WS_PROTOCOL)
: protocols.has(GRAPHQL_TRANSPORT_WS_PROTOCOL)
)
? GRAPHQL_TRANSPORT_WS_PROTOCOL
: false;
}
5 changes: 4 additions & 1 deletion src/use/fastify-websocket.ts
@@ -1,6 +1,6 @@
import type { FastifyRequest } from 'fastify';
import type * as fastifyWebsocket from 'fastify-websocket';
import { makeServer, ServerOptions } from '../server';
import { handleProtocols, makeServer, ServerOptions } from '../server';
import {
GRAPHQL_TRANSPORT_WS_PROTOCOL,
ConnectionInitMessage,
Expand Down Expand Up @@ -56,6 +56,9 @@ export function makeHandler<
return function handler(connection, request) {
const { socket } = connection;

// might be too late, but meh
this.websocketServer.options.handleProtocols = handleProtocols;

// handle server emitted errors only if not already handling
if (!handlingServerEmittedErrors) {
handlingServerEmittedErrors = true;
Expand Down
11 changes: 8 additions & 3 deletions src/use/uWebSockets.ts
@@ -1,6 +1,6 @@
import type * as uWS from 'uWebSockets.js';
import type http from 'http';
import { makeServer, ServerOptions } from '../server';
import { handleProtocols, makeServer, ServerOptions } from '../server';
import { ConnectionInitMessage, CloseCode } from '../common';
import { limitCloseReason } from '../utils';

Expand Down Expand Up @@ -92,6 +92,7 @@ export function makeBehavior<

return {
...behavior,

pong(...args) {
behavior.pong?.(...args);
const [socket] = args;
Expand Down Expand Up @@ -123,7 +124,8 @@ export function makeBehavior<
},
},
req.getHeader('sec-websocket-key'),
req.getHeader('sec-websocket-protocol'),
handleProtocols(req.getHeader('sec-websocket-protocol')) ||
new Uint8Array(),
req.getHeader('sec-websocket-extensions'),
context,
);
Expand All @@ -147,7 +149,10 @@ export function makeBehavior<

client.closed = server.opened(
{
protocol: persistedRequest.headers['sec-websocket-protocol'] ?? '',
protocol:
handleProtocols(
persistedRequest.headers['sec-websocket-protocol'] || '',
) || '',
send: async (message) => {
// the socket might have been destroyed in the meantime
if (!clients.has(socket)) return;
Expand Down
4 changes: 3 additions & 1 deletion src/use/ws.ts
@@ -1,6 +1,6 @@
import type * as http from 'http';
import type * as ws from 'ws';
import { makeServer, ServerOptions } from '../server';
import { handleProtocols, makeServer, ServerOptions } from '../server';
import {
GRAPHQL_TRANSPORT_WS_PROTOCOL,
ConnectionInitMessage,
Expand Down Expand Up @@ -54,6 +54,8 @@ export function useServer<
const isProd = process.env.NODE_ENV === 'production';
const server = makeServer(options);

ws.options.handleProtocols = handleProtocols;

ws.once('error', (err) => {
console.error(
'Internal error emitted on the WebSocket server. ' +
Expand Down

0 comments on commit 9bae064

Please sign in to comment.