From 02dfaf1aa2c798315d0dd7f809cc469771b36ffc Mon Sep 17 00:00:00 2001 From: DD Date: Fri, 14 Apr 2023 23:26:37 +0300 Subject: [PATCH] refactor: abstract identify throttling and correct max_concurrency handling (#9375) * refactor: properly support max_concurrency ratelimit keys * fix: properly block for same key * chore: export session state * chore: throttler no longer requires manager * refactor: abstract throttlers * chore: proper member order * chore: remove leftover debug log * chore: use @link tag in doc comment Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com> * chore: suggested changes * fix(WebSocketShard): cancel identify if the shard closed in the meantime * refactor(throttlers): support abort signals * fix: memory leak * chore: remove leftover --------- Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com> Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com> --- .../strategy/WorkerShardingStrategy.test.ts | 30 ++++----- .../__tests__/util/IdentifyThrottler.test.ts | 46 ------------- .../util/SimpleIdentifyThrottler.test.ts | 32 +++++++++ packages/ws/src/index.ts | 4 +- .../context/IContextFetchingStrategy.ts | 28 ++++++-- .../context/SimpleContextFetchingStrategy.ts | 30 ++++----- .../context/WorkerContextFetchingStrategy.ts | 60 +++++++++++++---- .../sharding/SimpleShardingStrategy.ts | 1 + .../sharding/WorkerShardingStrategy.ts | 65 ++++++++++++++----- .../ws/src/throttling/IIdentifyThrottler.ts | 11 ++++ .../src/throttling/SimpleIdentifyThrottler.ts | 50 ++++++++++++++ packages/ws/src/utils/IdentifyThrottler.ts | 39 ----------- packages/ws/src/utils/WorkerBootstrapper.ts | 14 ++-- packages/ws/src/utils/constants.ts | 7 +- packages/ws/src/ws/WebSocketManager.ts | 7 +- packages/ws/src/ws/WebSocketShard.ts | 16 ++++- 16 files changed, 279 insertions(+), 161 deletions(-) delete mode 100644 packages/ws/__tests__/util/IdentifyThrottler.test.ts create mode 100644 packages/ws/__tests__/util/SimpleIdentifyThrottler.test.ts create mode 100644 packages/ws/src/throttling/IIdentifyThrottler.ts create mode 100644 packages/ws/src/throttling/SimpleIdentifyThrottler.ts delete mode 100644 packages/ws/src/utils/IdentifyThrottler.ts diff --git a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts index 472f518caf5d..b62fceb59b86 100644 --- a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts +++ b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts @@ -57,9 +57,9 @@ vi.mock('node:worker_threads', async () => { this.emit('online'); // same deal here setImmediate(() => { - const message = { + const message: WorkerReceivePayload = { op: WorkerReceivePayloadOp.WorkerReady, - } satisfies WorkerReceivePayload; + }; this.emit('message', message); }); }); @@ -68,39 +68,39 @@ vi.mock('node:worker_threads', async () => { public postMessage(message: WorkerSendPayload) { switch (message.op) { case WorkerSendPayloadOp.Connect: { - const response = { + const response: WorkerReceivePayload = { op: WorkerReceivePayloadOp.Connected, shardId: message.shardId, - } satisfies WorkerReceivePayload; + }; this.emit('message', response); break; } case WorkerSendPayloadOp.Destroy: { - const response = { + const response: WorkerReceivePayload = { op: WorkerReceivePayloadOp.Destroyed, shardId: message.shardId, - } satisfies WorkerReceivePayload; + }; this.emit('message', response); break; } case WorkerSendPayloadOp.Send: { if (message.payload.op === GatewayOpcodes.RequestGuildMembers) { - const response = { + const response: WorkerReceivePayload = { op: WorkerReceivePayloadOp.Event, shardId: message.shardId, event: WebSocketShardEvents.Dispatch, data: memberChunkData, - } satisfies WorkerReceivePayload; + }; this.emit('message', response); // Fetch session info - const sessionFetch = { + const sessionFetch: WorkerReceivePayload = { op: WorkerReceivePayloadOp.RetrieveSessionInfo, shardId: message.shardId, nonce: Math.random(), - } satisfies WorkerReceivePayload; + }; this.emit('message', sessionFetch); } @@ -111,16 +111,16 @@ vi.mock('node:worker_threads', async () => { case WorkerSendPayloadOp.SessionInfoResponse: { message.session ??= sessionInfo; - const session = { + const session: WorkerReceivePayload = { op: WorkerReceivePayloadOp.UpdateSessionInfo, shardId: message.session.shardId, session: { ...message.session, sequence: message.session.sequence + 1 }, - } satisfies WorkerReceivePayload; + }; this.emit('message', session); break; } - case WorkerSendPayloadOp.ShardCanIdentify: { + case WorkerSendPayloadOp.ShardIdentifyResponse: { break; } @@ -198,10 +198,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => { expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }), ); - const payload = { + const payload: GatewaySendPayload = { op: GatewayOpcodes.RequestGuildMembers, d: { guild_id: '123', limit: 0, query: '' }, - } satisfies GatewaySendPayload; + }; await manager.send(0, payload); expect(mockSend).toHaveBeenCalledWith(0, payload); expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, { diff --git a/packages/ws/__tests__/util/IdentifyThrottler.test.ts b/packages/ws/__tests__/util/IdentifyThrottler.test.ts deleted file mode 100644 index 74417d5881eb..000000000000 --- a/packages/ws/__tests__/util/IdentifyThrottler.test.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { setTimeout as sleep } from 'node:timers/promises'; -import { expect, test, vi, type Mock } from 'vitest'; -import { IdentifyThrottler, type WebSocketManager } from '../../src/index.js'; - -vi.mock('node:timers/promises', () => ({ - setTimeout: vi.fn(), -})); - -const fetchGatewayInformation = vi.fn(); - -const manager = { - fetchGatewayInformation, -} as unknown as WebSocketManager; - -const throttler = new IdentifyThrottler(manager); - -vi.useFakeTimers(); - -const NOW = vi.fn().mockReturnValue(Date.now()); -global.Date.now = NOW; - -test('wait for identify', async () => { - fetchGatewayInformation.mockReturnValue({ - session_start_limit: { - max_concurrency: 2, - }, - }); - - // First call should never wait - await throttler.waitForIdentify(); - expect(sleep).not.toHaveBeenCalled(); - - // Second call still won't wait because max_concurrency is 2 - await throttler.waitForIdentify(); - expect(sleep).not.toHaveBeenCalled(); - - // Third call should wait - await throttler.waitForIdentify(); - expect(sleep).toHaveBeenCalled(); - - (sleep as Mock).mockRestore(); - - // Fourth call shouldn't wait, because our max_concurrency is 2 and we waited for a reset - await throttler.waitForIdentify(); - expect(sleep).not.toHaveBeenCalled(); -}); diff --git a/packages/ws/__tests__/util/SimpleIdentifyThrottler.test.ts b/packages/ws/__tests__/util/SimpleIdentifyThrottler.test.ts new file mode 100644 index 000000000000..518b6d8f2218 --- /dev/null +++ b/packages/ws/__tests__/util/SimpleIdentifyThrottler.test.ts @@ -0,0 +1,32 @@ +import { setTimeout as sleep } from 'node:timers/promises'; +import { expect, test, vi, type Mock } from 'vitest'; +import { SimpleIdentifyThrottler } from '../../src/index.js'; + +vi.mock('node:timers/promises', () => ({ + setTimeout: vi.fn(), +})); + +const throttler = new SimpleIdentifyThrottler(2); + +vi.useFakeTimers(); + +const NOW = vi.fn().mockReturnValue(Date.now()); +global.Date.now = NOW; + +test('basic case', async () => { + // Those shouldn't wait since they're in different keys + + await throttler.waitForIdentify(0); + expect(sleep).not.toHaveBeenCalled(); + + await throttler.waitForIdentify(1); + expect(sleep).not.toHaveBeenCalled(); + + // Those should wait + + await throttler.waitForIdentify(2); + expect(sleep).toHaveBeenCalledTimes(1); + + await throttler.waitForIdentify(3); + expect(sleep).toHaveBeenCalledTimes(2); +}); diff --git a/packages/ws/src/index.ts b/packages/ws/src/index.ts index efeba21bd760..9283dd549a5f 100644 --- a/packages/ws/src/index.ts +++ b/packages/ws/src/index.ts @@ -6,8 +6,10 @@ export * from './strategies/sharding/IShardingStrategy.js'; export * from './strategies/sharding/SimpleShardingStrategy.js'; export * from './strategies/sharding/WorkerShardingStrategy.js'; +export * from './throttling/IIdentifyThrottler.js'; +export * from './throttling/SimpleIdentifyThrottler.js'; + export * from './utils/constants.js'; -export * from './utils/IdentifyThrottler.js'; export * from './utils/WorkerBootstrapper.js'; export * from './ws/WebSocketManager.js'; diff --git a/packages/ws/src/strategies/context/IContextFetchingStrategy.ts b/packages/ws/src/strategies/context/IContextFetchingStrategy.ts index 6f6c3155d687..d2d056fbae3c 100644 --- a/packages/ws/src/strategies/context/IContextFetchingStrategy.ts +++ b/packages/ws/src/strategies/context/IContextFetchingStrategy.ts @@ -5,7 +5,13 @@ import type { SessionInfo, WebSocketManager, WebSocketManagerOptions } from '../ export interface FetchingStrategyOptions extends Omit< WebSocketManagerOptions, - 'buildStrategy' | 'rest' | 'retrieveSessionInfo' | 'shardCount' | 'shardIds' | 'updateSessionInfo' + | 'buildIdentifyThrottler' + | 'buildStrategy' + | 'rest' + | 'retrieveSessionInfo' + | 'shardCount' + | 'shardIds' + | 'updateSessionInfo' > { readonly gatewayInformation: APIGatewayBotInfo; readonly shardCount: number; @@ -18,13 +24,25 @@ export interface IContextFetchingStrategy { readonly options: FetchingStrategyOptions; retrieveSessionInfo(shardId: number): Awaitable; updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable; - waitForIdentify(): Promise; + /** + * Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted + */ + waitForIdentify(shardId: number, signal: AbortSignal): Promise; } export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise { - // eslint-disable-next-line @typescript-eslint/unbound-method - const { buildStrategy, retrieveSessionInfo, updateSessionInfo, shardCount, shardIds, rest, ...managerOptions } = - manager.options; + /* eslint-disable @typescript-eslint/unbound-method */ + const { + buildIdentifyThrottler, + buildStrategy, + retrieveSessionInfo, + updateSessionInfo, + shardCount, + shardIds, + rest, + ...managerOptions + } = manager.options; + /* eslint-enable @typescript-eslint/unbound-method */ return { ...managerOptions, diff --git a/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts b/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts index 4865f68e131a..511179aa4991 100644 --- a/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts +++ b/packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts @@ -1,29 +1,26 @@ -import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js'; +import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler.js'; import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js'; import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js'; export class SimpleContextFetchingStrategy implements IContextFetchingStrategy { // This strategy assumes every shard is running under the same process - therefore we need a single // IdentifyThrottler per manager. - private static throttlerCache = new WeakMap(); + private static throttlerCache = new WeakMap(); - private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler { - const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager); - if (existing) { - return existing; + private static async ensureThrottler(manager: WebSocketManager): Promise { + const throttler = SimpleContextFetchingStrategy.throttlerCache.get(manager); + if (throttler) { + return throttler; } - const throttler = new IdentifyThrottler(manager); - SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler); - return throttler; - } - - private readonly throttler: IdentifyThrottler; + const newThrottler = await manager.options.buildIdentifyThrottler(manager); + SimpleContextFetchingStrategy.throttlerCache.set(manager, newThrottler); - public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) { - this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager); + return newThrottler; } + public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {} + public async retrieveSessionInfo(shardId: number): Promise { return this.manager.options.retrieveSessionInfo(shardId); } @@ -32,7 +29,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy { return this.manager.options.updateSessionInfo(shardId, sessionInfo); } - public async waitForIdentify(): Promise { - await this.throttler.waitForIdentify(); + public async waitForIdentify(shardId: number, signal: AbortSignal): Promise { + const throttler = await SimpleContextFetchingStrategy.ensureThrottler(this.manager); + await throttler.waitForIdentify(shardId, signal); } } diff --git a/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts b/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts index 5039123ce489..5a79eb88f99f 100644 --- a/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts +++ b/packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts @@ -9,10 +9,17 @@ import { } from '../sharding/WorkerShardingStrategy.js'; import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js'; +// Because the global types are incomplete for whatever reason +interface PolyFillAbortSignal { + readonly aborted: boolean; + addEventListener(type: 'abort', listener: () => void): void; + removeEventListener(type: 'abort', listener: () => void): void; +} + export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { private readonly sessionPromises = new Collection void>(); - private readonly waitForIdentifyPromises = new Collection void>(); + private readonly waitForIdentifyPromises = new Collection(); public constructor(public readonly options: FetchingStrategyOptions) { if (isMainThread) { @@ -25,8 +32,14 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { this.sessionPromises.delete(payload.nonce); } - if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) { - this.waitForIdentifyPromises.get(payload.nonce)?.(); + if (payload.op === WorkerSendPayloadOp.ShardIdentifyResponse) { + const promise = this.waitForIdentifyPromises.get(payload.nonce); + if (payload.ok) { + promise?.resolve(); + } else { + promise?.reject(); + } + this.waitForIdentifyPromises.delete(payload.nonce); } }); @@ -34,11 +47,11 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { public async retrieveSessionInfo(shardId: number): Promise { const nonce = Math.random(); - const payload = { + const payload: WorkerReceivePayload = { op: WorkerReceivePayloadOp.RetrieveSessionInfo, shardId, nonce, - } satisfies WorkerReceivePayload; + }; // eslint-disable-next-line no-promise-executor-return const promise = new Promise((resolve) => this.sessionPromises.set(nonce, resolve)); parentPort!.postMessage(payload); @@ -46,23 +59,44 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy { } public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) { - const payload = { + const payload: WorkerReceivePayload = { op: WorkerReceivePayloadOp.UpdateSessionInfo, shardId, session: sessionInfo, - } satisfies WorkerReceivePayload; + }; parentPort!.postMessage(payload); } - public async waitForIdentify(): Promise { + public async waitForIdentify(shardId: number, signal: AbortSignal): Promise { const nonce = Math.random(); - const payload = { + + const payload: WorkerReceivePayload = { op: WorkerReceivePayloadOp.WaitForIdentify, nonce, - } satisfies WorkerReceivePayload; - // eslint-disable-next-line no-promise-executor-return - const promise = new Promise((resolve) => this.waitForIdentifyPromises.set(nonce, resolve)); + shardId, + }; + const promise = new Promise((resolve, reject) => + // eslint-disable-next-line no-promise-executor-return + this.waitForIdentifyPromises.set(nonce, { resolve, reject }), + ); + parentPort!.postMessage(payload); - return promise; + + const listener = () => { + const payload: WorkerReceivePayload = { + op: WorkerReceivePayloadOp.CancelIdentify, + nonce, + }; + + parentPort!.postMessage(payload); + }; + + (signal as unknown as PolyFillAbortSignal).addEventListener('abort', listener); + + try { + await promise; + } finally { + (signal as unknown as PolyFillAbortSignal).removeEventListener('abort', listener); + } } } diff --git a/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts b/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts index 17b5ffe07c49..9e3a54a9953d 100644 --- a/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts +++ b/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts @@ -23,6 +23,7 @@ export class SimpleShardingStrategy implements IShardingStrategy { */ public async spawn(shardIds: number[]) { const strategyOptions = await managerToFetchingStrategyOptions(this.manager); + for (const shardId of shardIds) { const strategy = new SimpleContextFetchingStrategy(this.manager, strategyOptions); const shard = new WebSocketShard(strategy, shardId); diff --git a/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts b/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts index f61ad68fb187..ca92d6c01f26 100644 --- a/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts +++ b/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts @@ -3,7 +3,7 @@ import { join, isAbsolute, resolve } from 'node:path'; import { Worker } from 'node:worker_threads'; import { Collection } from '@discordjs/collection'; import type { GatewaySendPayload } from 'discord-api-types/v10'; -import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js'; +import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler'; import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager'; import type { WebSocketShardDestroyOptions, WebSocketShardEvents, WebSocketShardStatus } from '../../ws/WebSocketShard'; import { managerToFetchingStrategyOptions, type FetchingStrategyOptions } from '../context/IContextFetchingStrategy.js'; @@ -18,14 +18,14 @@ export enum WorkerSendPayloadOp { Destroy, Send, SessionInfoResponse, - ShardCanIdentify, + ShardIdentifyResponse, FetchStatus, } export type WorkerSendPayload = + | { nonce: number; ok: boolean; op: WorkerSendPayloadOp.ShardIdentifyResponse } | { nonce: number; op: WorkerSendPayloadOp.FetchStatus; shardId: number } | { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null } - | { nonce: number; op: WorkerSendPayloadOp.ShardCanIdentify } | { op: WorkerSendPayloadOp.Connect; shardId: number } | { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number } | { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number }; @@ -39,14 +39,16 @@ export enum WorkerReceivePayloadOp { WaitForIdentify, FetchStatusResponse, WorkerReady, + CancelIdentify, } export type WorkerReceivePayload = // Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now | { data: any; event: WebSocketShardEvents; op: WorkerReceivePayloadOp.Event; shardId: number } + | { nonce: number; op: WorkerReceivePayloadOp.CancelIdentify } | { nonce: number; op: WorkerReceivePayloadOp.FetchStatusResponse; status: WebSocketShardStatus } | { nonce: number; op: WorkerReceivePayloadOp.RetrieveSessionInfo; shardId: number } - | { nonce: number; op: WorkerReceivePayloadOp.WaitForIdentify } + | { nonce: number; op: WorkerReceivePayloadOp.WaitForIdentify; shardId: number } | { op: WorkerReceivePayloadOp.Connected; shardId: number } | { op: WorkerReceivePayloadOp.Destroyed; shardId: number } | { op: WorkerReceivePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number } @@ -84,11 +86,12 @@ export class WorkerShardingStrategy implements IShardingStrategy { private readonly fetchStatusPromises = new Collection void>(); - private readonly throttler: IdentifyThrottler; + private readonly waitForIdentifyControllers = new Collection(); + + private throttler?: IIdentifyThrottler; public constructor(manager: WebSocketManager, options: WorkerShardingStrategyOptions) { this.manager = manager; - this.throttler = new IdentifyThrottler(manager); this.options = options; } @@ -122,10 +125,10 @@ export class WorkerShardingStrategy implements IShardingStrategy { const promises = []; for (const [shardId, worker] of this.#workerByShardId.entries()) { - const payload = { + const payload: WorkerSendPayload = { op: WorkerSendPayloadOp.Connect, shardId, - } satisfies WorkerSendPayload; + }; // eslint-disable-next-line no-promise-executor-return const promise = new Promise((resolve) => this.connectPromises.set(shardId, resolve)); @@ -143,11 +146,11 @@ export class WorkerShardingStrategy implements IShardingStrategy { const promises = []; for (const [shardId, worker] of this.#workerByShardId.entries()) { - const payload = { + const payload: WorkerSendPayload = { op: WorkerSendPayloadOp.Destroy, shardId, options, - } satisfies WorkerSendPayload; + }; promises.push( // eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then @@ -171,11 +174,11 @@ export class WorkerShardingStrategy implements IShardingStrategy { throw new Error(`No worker found for shard ${shardId}`); } - const payload = { + const payload: WorkerSendPayload = { op: WorkerSendPayloadOp.Send, shardId, payload: data, - } satisfies WorkerSendPayload; + }; worker.postMessage(payload); } @@ -187,11 +190,11 @@ export class WorkerShardingStrategy implements IShardingStrategy { for (const [shardId, worker] of this.#workerByShardId.entries()) { const nonce = Math.random(); - const payload = { + const payload: WorkerSendPayload = { op: WorkerSendPayloadOp.FetchStatus, shardId, nonce, - } satisfies WorkerSendPayload; + }; // eslint-disable-next-line no-promise-executor-return const promise = new Promise((resolve) => this.fetchStatusPromises.set(nonce, resolve)); @@ -297,10 +300,21 @@ export class WorkerShardingStrategy implements IShardingStrategy { } case WorkerReceivePayloadOp.WaitForIdentify: { - await this.throttler.waitForIdentify(); + const throttler = await this.ensureThrottler(); + + // If this rejects it means we aborted, in which case we reply elsewhere. + try { + const controller = new AbortController(); + this.waitForIdentifyControllers.set(payload.nonce, controller); + await throttler.waitForIdentify(payload.shardId, controller.signal); + } catch { + return; + } + const response: WorkerSendPayload = { - op: WorkerSendPayloadOp.ShardCanIdentify, + op: WorkerSendPayloadOp.ShardIdentifyResponse, nonce: payload.nonce, + ok: true, }; worker.postMessage(response); break; @@ -315,6 +329,25 @@ export class WorkerShardingStrategy implements IShardingStrategy { case WorkerReceivePayloadOp.WorkerReady: { break; } + + case WorkerReceivePayloadOp.CancelIdentify: { + this.waitForIdentifyControllers.get(payload.nonce)?.abort(); + this.waitForIdentifyControllers.delete(payload.nonce); + + const response: WorkerSendPayload = { + op: WorkerSendPayloadOp.ShardIdentifyResponse, + nonce: payload.nonce, + ok: false, + }; + worker.postMessage(response); + + break; + } } } + + private async ensureThrottler(): Promise { + this.throttler ??= await this.manager.options.buildIdentifyThrottler(this.manager); + return this.throttler; + } } diff --git a/packages/ws/src/throttling/IIdentifyThrottler.ts b/packages/ws/src/throttling/IIdentifyThrottler.ts new file mode 100644 index 000000000000..f2faad5495b7 --- /dev/null +++ b/packages/ws/src/throttling/IIdentifyThrottler.ts @@ -0,0 +1,11 @@ +/** + * IdentifyThrottlers are responsible for dictating when a shard is allowed to identify. + * + * @see {@link https://discord.com/developers/docs/topics/gateway#sharding-max-concurrency} + */ +export interface IIdentifyThrottler { + /** + * Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted. + */ + waitForIdentify(shardId: number, signal: AbortSignal): Promise; +} diff --git a/packages/ws/src/throttling/SimpleIdentifyThrottler.ts b/packages/ws/src/throttling/SimpleIdentifyThrottler.ts new file mode 100644 index 000000000000..a612012b453b --- /dev/null +++ b/packages/ws/src/throttling/SimpleIdentifyThrottler.ts @@ -0,0 +1,50 @@ +import { setTimeout as sleep } from 'node:timers/promises'; +import { Collection } from '@discordjs/collection'; +import { AsyncQueue } from '@sapphire/async-queue'; +import type { IIdentifyThrottler } from './IIdentifyThrottler'; + +/** + * The state of a rate limit key's identify queue. + */ +export interface IdentifyState { + queue: AsyncQueue; + resetsAt: number; +} + +/** + * Local, in-memory identify throttler. + */ +export class SimpleIdentifyThrottler implements IIdentifyThrottler { + private readonly states = new Collection(); + + public constructor(private readonly maxConcurrency: number) {} + + /** + * {@inheritDoc IIdentifyThrottler.waitForIdentify} + */ + public async waitForIdentify(shardId: number, signal: AbortSignal): Promise { + const key = shardId % this.maxConcurrency; + + const state = this.states.ensure(key, () => { + return { + queue: new AsyncQueue(), + resetsAt: Number.POSITIVE_INFINITY, + }; + }); + + await state.queue.wait({ signal }); + + try { + const diff = state.resetsAt - Date.now(); + if (diff <= 5_000) { + // To account for the latency the IDENTIFY payload goes through, we add a bit more wait time + const time = diff + Math.random() * 1_500; + await sleep(time); + } + + state.resetsAt = Date.now() + 5_000; + } finally { + state.queue.shift(); + } + } +} diff --git a/packages/ws/src/utils/IdentifyThrottler.ts b/packages/ws/src/utils/IdentifyThrottler.ts deleted file mode 100644 index 45c35c5e7183..000000000000 --- a/packages/ws/src/utils/IdentifyThrottler.ts +++ /dev/null @@ -1,39 +0,0 @@ -import { setTimeout as sleep } from 'node:timers/promises'; -import { AsyncQueue } from '@sapphire/async-queue'; -import type { WebSocketManager } from '../ws/WebSocketManager.js'; - -export class IdentifyThrottler { - private readonly queue = new AsyncQueue(); - - private identifyState = { - remaining: 0, - resetsAt: Number.POSITIVE_INFINITY, - }; - - public constructor(private readonly manager: WebSocketManager) {} - - public async waitForIdentify(): Promise { - await this.queue.wait(); - - try { - if (this.identifyState.remaining <= 0) { - const diff = this.identifyState.resetsAt - Date.now(); - if (diff <= 5_000) { - // To account for the latency the IDENTIFY payload goes through, we add a bit more wait time - const time = diff + Math.random() * 1_500; - await sleep(time); - } - - const info = await this.manager.fetchGatewayInformation(); - this.identifyState = { - remaining: info.session_start_limit.max_concurrency, - resetsAt: Date.now() + 5_000, - }; - } - - this.identifyState.remaining--; - } finally { - this.queue.shift(); - } - } -} diff --git a/packages/ws/src/utils/WorkerBootstrapper.ts b/packages/ws/src/utils/WorkerBootstrapper.ts index 6becafe41e1a..dd101293e969 100644 --- a/packages/ws/src/utils/WorkerBootstrapper.ts +++ b/packages/ws/src/utils/WorkerBootstrapper.ts @@ -117,7 +117,7 @@ export class WorkerBootstrapper { break; } - case WorkerSendPayloadOp.ShardCanIdentify: { + case WorkerSendPayloadOp.ShardIdentifyResponse: { break; } @@ -127,11 +127,11 @@ export class WorkerBootstrapper { throw new Error(`Shard ${payload.shardId} does not exist`); } - const response = { + const response: WorkerReceivePayload = { op: WorkerReceivePayloadOp.FetchStatusResponse, status: shard.status, nonce: payload.nonce, - } satisfies WorkerReceivePayload; + }; parentPort!.postMessage(response); break; @@ -150,12 +150,12 @@ export class WorkerBootstrapper { for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) { // @ts-expect-error: Event types incompatible shard.on(event, (data) => { - const payload = { + const payload: WorkerReceivePayload = { op: WorkerReceivePayloadOp.Event, event, data, shardId, - } satisfies WorkerReceivePayload; + }; parentPort!.postMessage(payload); }); } @@ -168,9 +168,9 @@ export class WorkerBootstrapper { // Lastly, start listening to messages from the parent thread this.setupThreadEvents(); - const message = { + const message: WorkerReceivePayload = { op: WorkerReceivePayloadOp.WorkerReady, - } satisfies WorkerReceivePayload; + }; parentPort!.postMessage(message); } } diff --git a/packages/ws/src/utils/constants.ts b/packages/ws/src/utils/constants.ts index b0208d00f2dd..a99f817f9289 100644 --- a/packages/ws/src/utils/constants.ts +++ b/packages/ws/src/utils/constants.ts @@ -3,7 +3,8 @@ import { Collection } from '@discordjs/collection'; import { lazy } from '@discordjs/util'; import { APIVersion, GatewayOpcodes } from 'discord-api-types/v10'; import { SimpleShardingStrategy } from '../strategies/sharding/SimpleShardingStrategy.js'; -import type { SessionInfo, OptionalWebSocketManagerOptions } from '../ws/WebSocketManager.js'; +import { SimpleIdentifyThrottler } from '../throttling/SimpleIdentifyThrottler.js'; +import type { SessionInfo, OptionalWebSocketManagerOptions, WebSocketManager } from '../ws/WebSocketManager.js'; import type { SendRateLimitState } from '../ws/WebSocketShard.js'; /** @@ -28,6 +29,10 @@ const getDefaultSessionStore = lazy(() => new Collection new SimpleShardingStrategy(manager), shardCount: null, shardIds: null, diff --git a/packages/ws/src/ws/WebSocketManager.ts b/packages/ws/src/ws/WebSocketManager.ts index a18986b4aee5..2c646c04b223 100644 --- a/packages/ws/src/ws/WebSocketManager.ts +++ b/packages/ws/src/ws/WebSocketManager.ts @@ -11,6 +11,7 @@ import { type GatewaySendPayload, } from 'discord-api-types/v10'; import type { IShardingStrategy } from '../strategies/sharding/IShardingStrategy'; +import type { IIdentifyThrottler } from '../throttling/IIdentifyThrottler'; import { DefaultWebSocketManagerOptions, type CompressionMethod, type Encoding } from '../utils/constants.js'; import type { WebSocketShardDestroyOptions, WebSocketShardEventsMap } from './WebSocketShard.js'; @@ -55,7 +56,7 @@ export interface RequiredWebSocketManagerOptions { /** * The intents to request */ - intents: GatewayIntentBits; + intents: GatewayIntentBits | 0; /** * The REST instance to use for fetching gateway information */ @@ -70,6 +71,10 @@ export interface RequiredWebSocketManagerOptions { * Optional additional configuration for the WebSocketManager */ export interface OptionalWebSocketManagerOptions { + /** + * Builds an identify throttler to use for this manager's shards + */ + buildIdentifyThrottler(manager: WebSocketManager): Awaitable; /** * Builds the strategy to use for sharding * diff --git a/packages/ws/src/ws/WebSocketShard.ts b/packages/ws/src/ws/WebSocketShard.ts index abf1f10adde7..85b3f3c22cb5 100644 --- a/packages/ws/src/ws/WebSocketShard.ts +++ b/packages/ws/src/ws/WebSocketShard.ts @@ -358,7 +358,21 @@ export class WebSocketShard extends AsyncEventEmitter { private async identify() { this.debug(['Waiting for identify throttle']); - await this.strategy.waitForIdentify(); + const controller = new AbortController(); + const closeHandler = () => { + controller.abort(); + }; + + this.on(WebSocketShardEvents.Closed, closeHandler); + + try { + await this.strategy.waitForIdentify(this.id, controller.signal); + } catch { + this.debug(['Was waiting for an identify, but the shard closed in the meantime']); + return; + } finally { + this.off(WebSocketShardEvents.Closed, closeHandler); + } this.debug([ 'Identifying',