From 418379e108bd7b0994f4f94982747c6a8282266a Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Mon, 17 Nov 2025 18:06:47 +0300 Subject: [PATCH 1/5] Add reconnecting websocket class and use it in CoderApi --- src/api/agentMetadataHelper.ts | 6 +- src/api/coderApi.ts | 139 +++++-- src/api/workspace.ts | 6 +- src/inbox.ts | 6 +- src/remote/workspaceStateMachine.ts | 7 +- src/websocket/reconnectingWebSocket.ts | 239 ++++++++++++ test/unit/api/coderApi.test.ts | 98 ++++- .../websocket/reconnectingWebSocket.test.ts | 362 ++++++++++++++++++ 8 files changed, 811 insertions(+), 52 deletions(-) create mode 100644 src/websocket/reconnectingWebSocket.ts create mode 100644 test/unit/websocket/reconnectingWebSocket.test.ts diff --git a/src/api/agentMetadataHelper.ts b/src/api/agentMetadataHelper.ts index 4de804ad..26ab1b6f 100644 --- a/src/api/agentMetadataHelper.ts +++ b/src/api/agentMetadataHelper.ts @@ -53,7 +53,11 @@ export async function createAgentMetadataWatcher( event.parsedMessage.data, ); - // Overwrite metadata if it changed. + if (watcher.error !== undefined) { + watcher.error = undefined; + onChange.fire(null); + } + if (JSON.stringify(watcher.metadata) !== JSON.stringify(metadata)) { watcher.metadata = metadata; onChange.fire(null); diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index ef120ce4..d0cb1378 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -36,6 +36,10 @@ import { OneWayWebSocket, type OneWayWebSocketInit, } from "../websocket/oneWayWebSocket"; +import { + ReconnectingWebSocket, + type SocketFactory, +} from "../websocket/reconnectingWebSocket"; import { SseConnection } from "../websocket/sseConnection"; import { createHttpAgent } from "./utils"; @@ -47,6 +51,10 @@ const coderSessionTokenHeader = "Coder-Session-Token"; * and WebSocket methods for real-time functionality. */ export class CoderApi extends Api { + private readonly reconnectingSockets = new Set< + ReconnectingWebSocket + >(); + private constructor(private readonly output: Logger) { super(); } @@ -70,6 +78,30 @@ export class CoderApi extends Api { return client; } + setSessionToken = (token: string): void => { + const currentToken = + this.getAxiosInstance().defaults.headers.common[coderSessionTokenHeader]; + this.getAxiosInstance().defaults.headers.common[coderSessionTokenHeader] = + token; + + if (currentToken !== token) { + for (const socket of this.reconnectingSockets) { + socket.reconnect(); + } + } + }; + + setHost = (host: string | undefined): void => { + const currentHost = this.getAxiosInstance().defaults.baseURL; + this.getAxiosInstance().defaults.baseURL = host; + + if (currentHost !== host) { + for (const socket of this.reconnectingSockets) { + socket.reconnect(); + } + } + }; + watchInboxNotifications = async ( watchTemplates: string[], watchTargets: string[], @@ -83,6 +115,7 @@ export class CoderApi extends Api { targets: watchTargets.join(","), }, options, + enableRetry: true, }); }; @@ -91,6 +124,7 @@ export class CoderApi extends Api { apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`, fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`, options, + enableRetry: true, }); }; @@ -102,6 +136,7 @@ export class CoderApi extends Api { apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`, options, + enableRetry: true, }); }; @@ -148,53 +183,73 @@ export class CoderApi extends Api { } private async createWebSocket( - configs: Omit, - ) { - const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; - if (!baseUrlRaw) { - throw new Error("No base URL set on REST client"); - } + configs: Omit & { enableRetry?: boolean }, + ): Promise> { + const { enableRetry, ...socketConfigs } = configs; + + const socketFactory: SocketFactory = async () => { + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } + + const baseUrl = new URL(baseUrlRaw); + const token = this.getAxiosInstance().defaults.headers.common[ + coderSessionTokenHeader + ] as string | undefined; + + const headersFromCommand = await getHeaders( + baseUrlRaw, + getHeaderCommand(vscode.workspace.getConfiguration()), + this.output, + ); - const baseUrl = new URL(baseUrlRaw); - const token = this.getAxiosInstance().defaults.headers.common[ - coderSessionTokenHeader - ] as string | undefined; + const httpAgent = await createHttpAgent( + vscode.workspace.getConfiguration(), + ); - const headersFromCommand = await getHeaders( - baseUrlRaw, - getHeaderCommand(vscode.workspace.getConfiguration()), - this.output, - ); + /** + * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): + * 1. Headers from the header command + * 2. Any headers passed directly to this function + * 3. Coder session token from the Api client (if set) + */ + const headers = { + ...(token ? { [coderSessionTokenHeader]: token } : {}), + ...configs.options?.headers, + ...headersFromCommand, + }; - const httpAgent = await createHttpAgent( - vscode.workspace.getConfiguration(), - ); + const webSocket = new OneWayWebSocket({ + location: baseUrl, + ...socketConfigs, + options: { + ...configs.options, + agent: httpAgent, + followRedirects: true, + headers, + }, + }); - /** - * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): - * 1. Headers from the header command - * 2. Any headers passed directly to this function - * 3. Coder session token from the Api client (if set) - */ - const headers = { - ...(token ? { [coderSessionTokenHeader]: token } : {}), - ...configs.options?.headers, - ...headersFromCommand, + this.attachStreamLogger(webSocket); + return webSocket; }; - const webSocket = new OneWayWebSocket({ - location: baseUrl, - ...configs, - options: { - ...configs.options, - agent: httpAgent, - followRedirects: true, - headers, - }, - }); + if (enableRetry) { + const reconnectingSocket = await ReconnectingWebSocket.create( + socketFactory, + this.output, + configs.apiRoute, + ); + + this.reconnectingSockets.add( + reconnectingSocket as ReconnectingWebSocket, + ); - this.attachStreamLogger(webSocket); - return webSocket; + return reconnectingSocket; + } else { + return socketFactory(); + } } private attachStreamLogger( @@ -230,13 +285,15 @@ export class CoderApi extends Api { fallbackApiRoute: string; searchParams?: Record | URLSearchParams; options?: ClientOptions; + enableRetry?: boolean; }): Promise> { - let webSocket: OneWayWebSocket; + let webSocket: UnidirectionalStream; try { webSocket = await this.createWebSocket({ apiRoute: configs.apiRoute, searchParams: configs.searchParams, options: configs.options, + enableRetry: configs.enableRetry, }); } catch { // Failed to create WebSocket, use SSE fallback diff --git a/src/api/workspace.ts b/src/api/workspace.ts index a24d3a64..1d3b7a4e 100644 --- a/src/api/workspace.ts +++ b/src/api/workspace.ts @@ -11,7 +11,7 @@ import * as vscode from "vscode"; import { type FeatureSet } from "../featureSet"; import { getGlobalFlags } from "../globalFlags"; import { escapeCommandArg } from "../util"; -import { type OneWayWebSocket } from "../websocket/oneWayWebSocket"; +import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; import { errToStr, createWorkspaceIdentifier } from "./api-helper"; import { type CoderApi } from "./coderApi"; @@ -93,7 +93,7 @@ export async function streamBuildLogs( client: CoderApi, writeEmitter: vscode.EventEmitter, workspace: Workspace, -): Promise> { +): Promise> { const socket = await client.watchBuildLogsByBuildId( workspace.latest_build.id, [], @@ -131,7 +131,7 @@ export async function streamAgentLogs( client: CoderApi, writeEmitter: vscode.EventEmitter, agent: WorkspaceAgent, -): Promise> { +): Promise> { const socket = await client.watchWorkspaceAgentLogs(agent.id, []); socket.addEventListener("message", (data) => { diff --git a/src/inbox.ts b/src/inbox.ts index 8dff573f..59b9ae0b 100644 --- a/src/inbox.ts +++ b/src/inbox.ts @@ -7,7 +7,7 @@ import type { import type { CoderApi } from "./api/coderApi"; import type { Logger } from "./logging/logger"; -import type { OneWayWebSocket } from "./websocket/oneWayWebSocket"; +import type { UnidirectionalStream } from "./websocket/eventStreamConnection"; // These are the template IDs of our notifications. // Maybe in the future we should avoid hardcoding @@ -16,7 +16,9 @@ const TEMPLATE_WORKSPACE_OUT_OF_MEMORY = "a9d027b4-ac49-4fb1-9f6d-45af15f64e7a"; const TEMPLATE_WORKSPACE_OUT_OF_DISK = "f047f6a3-5713-40f7-85aa-0394cce9fa3a"; export class Inbox implements vscode.Disposable { - private socket: OneWayWebSocket | undefined; + private socket: + | UnidirectionalStream + | undefined; private disposed = false; private constructor(private readonly logger: Logger) {} diff --git a/src/remote/workspaceStateMachine.ts b/src/remote/workspaceStateMachine.ts index eb7aa335..340ec960 100644 --- a/src/remote/workspaceStateMachine.ts +++ b/src/remote/workspaceStateMachine.ts @@ -21,7 +21,7 @@ import type { CoderApi } from "../api/coderApi"; import type { PathResolver } from "../core/pathResolver"; import type { FeatureSet } from "../featureSet"; import type { Logger } from "../logging/logger"; -import type { OneWayWebSocket } from "../websocket/oneWayWebSocket"; +import type { UnidirectionalStream } from "../websocket/eventStreamConnection"; /** * Manages workspace and agent state transitions until ready for SSH connection. @@ -32,9 +32,10 @@ export class WorkspaceStateMachine implements vscode.Disposable { private agent: { id: string; name: string } | undefined; - private buildLogSocket: OneWayWebSocket | null = null; + private buildLogSocket: UnidirectionalStream | null = null; - private agentLogSocket: OneWayWebSocket | null = null; + private agentLogSocket: UnidirectionalStream | null = + null; constructor( private readonly parts: AuthorityParts, diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts new file mode 100644 index 00000000..6c63c0ca --- /dev/null +++ b/src/websocket/reconnectingWebSocket.ts @@ -0,0 +1,239 @@ +import type { WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; + +import type { Logger } from "../logging/logger"; + +import type { + EventHandler, + UnidirectionalStream, +} from "./eventStreamConnection"; + +export type SocketFactory = () => Promise>; + +export type ReconnectingWebSocketOptions = { + initialBackoffMs?: number; // Default: 250ms + maxBackoffMs?: number; // Default: 30s + jitterFactor?: number; // Default: 0.1 (±10%) +}; + +// 403 Forbidden, 410 Gone, 426 Upgrade Required, 1002/1003 Protocol errors +const UNRECOVERABLE_CLOSE_CODES = new Set([403, 410, 426, 1002, 1003]); + +// Custom close code for intentional reconnection (4000-4999 range is for private use) +const CLOSE_CODE_RECONNECTING = 4000; + +export class ReconnectingWebSocket + implements UnidirectionalStream +{ + readonly #socketFactory: SocketFactory; + readonly #logger: Logger; + readonly #apiRoute: string; + readonly #options: Required; + readonly #eventHandlers = { + open: new Set>(), + close: new Set>(), + error: new Set>(), + message: new Set>(), + }; + + #currentSocket: UnidirectionalStream | null = null; + #backoffMs: number; + #reconnectTimeoutId: NodeJS.Timeout | null = null; + #isDisposed = false; + #isConnecting = false; + + private constructor( + socketFactory: SocketFactory, + logger: Logger, + apiRoute: string, + options: ReconnectingWebSocketOptions = {}, + ) { + this.#socketFactory = socketFactory; + this.#logger = logger; + this.#apiRoute = apiRoute; + this.#options = { + initialBackoffMs: options.initialBackoffMs ?? 250, + maxBackoffMs: options.maxBackoffMs ?? 30000, + jitterFactor: options.jitterFactor ?? 0.1, + }; + this.#backoffMs = this.#options.initialBackoffMs; + } + + static async create( + socketFactory: SocketFactory, + logger: Logger, + apiRoute: string, + options: ReconnectingWebSocketOptions = {}, + ): Promise> { + const instance = new ReconnectingWebSocket( + socketFactory, + logger, + apiRoute, + options, + ); + await instance.#connect(); + return instance; + } + + get url(): string { + return this.#currentSocket?.url ?? ""; + } + + addEventListener( + event: TEvent, + callback: EventHandler, + ): void { + (this.#eventHandlers[event] as Set>).add( + callback, + ); + + if (this.#currentSocket) { + this.#currentSocket.addEventListener(event, callback); + } + } + + removeEventListener( + event: TEvent, + callback: EventHandler, + ): void { + (this.#eventHandlers[event] as Set>).delete( + callback, + ); + + if (this.#currentSocket) { + this.#currentSocket.removeEventListener(event, callback); + } + } + + close(code?: number, reason?: string): void { + if (this.#isDisposed) { + return; + } + + this.#isDisposed = true; + + if (this.#reconnectTimeoutId !== null) { + clearTimeout(this.#reconnectTimeoutId); + this.#reconnectTimeoutId = null; + } + + if (this.#currentSocket) { + this.#currentSocket.close(code, reason); + this.#currentSocket = null; + } + + for (const set of Object.values(this.#eventHandlers)) { + set.clear(); + } + } + + reconnect(): void { + if (this.#isDisposed) { + return; + } + + if (this.#reconnectTimeoutId !== null) { + clearTimeout(this.#reconnectTimeoutId); + this.#reconnectTimeoutId = null; + } + + if (this.#currentSocket) { + this.#currentSocket.close(CLOSE_CODE_RECONNECTING, "Reconnecting"); + } + } + + async #connect(): Promise { + if (this.#isDisposed || this.#isConnecting) { + return; + } + + this.#isConnecting = true; + + try { + const socket = await this.#socketFactory(); + this.#currentSocket = socket; + + socket.addEventListener("open", () => { + this.#backoffMs = this.#options.initialBackoffMs; + }); + + for (const handler of this.#eventHandlers.open) { + socket.addEventListener("open", handler); + } + + for (const handler of this.#eventHandlers.message) { + socket.addEventListener("message", handler); + } + + for (const handler of this.#eventHandlers.error) { + socket.addEventListener("error", handler); + } + + socket.addEventListener("close", (event) => { + for (const handler of this.#eventHandlers.close) { + handler(event); + } + + if (this.#isDisposed) { + return; + } + + if (UNRECOVERABLE_CLOSE_CODES.has(event.code)) { + this.#logger.error( + `[ReconnectingWebSocket] Unrecoverable error (${event.code}) for ${this.#apiRoute}`, + ); + this.#isDisposed = true; + return; + } + + // Reconnect if this was an intentional close for reconnection + if (event.code === CLOSE_CODE_RECONNECTING) { + this.#scheduleReconnect(); + return; + } + + // Don't reconnect on normal closure + if (event.code === 1000 || event.code === 1001) { + return; + } + + // Reconnect on abnormal closures (e.g., 1006) or other unexpected codes + this.#scheduleReconnect(); + }); + } catch (error) { + if (!this.#isDisposed) { + this.#logger.warn( + `[ReconnectingWebSocket] Failed: ${error instanceof Error ? error.message : String(error)} for ${this.#apiRoute}`, + ); + this.#scheduleReconnect(); + } + } finally { + this.#isConnecting = false; + } + } + + #scheduleReconnect(): void { + if (this.#isDisposed || this.#reconnectTimeoutId !== null) { + return; + } + + const jitter = + this.#backoffMs * this.#options.jitterFactor * (Math.random() * 2 - 1); + const delayMs = Math.max(0, this.#backoffMs + jitter); + + this.#logger.debug( + `[ReconnectingWebSocket] Reconnecting in ${Math.round(delayMs)}ms for ${this.#apiRoute}`, + ); + + this.#reconnectTimeoutId = setTimeout(() => { + this.#reconnectTimeoutId = null; + // Errors already handled in #connect + void this.#connect(); + }, delayMs); + + this.#backoffMs = Math.min(this.#backoffMs * 2, this.#options.maxBackoffMs); + } + + isDisposed(): boolean { + return this.#isDisposed; + } +} diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts index f133a72d..d50782f6 100644 --- a/test/unit/api/coderApi.test.ts +++ b/test/unit/api/coderApi.test.ts @@ -10,7 +10,7 @@ import { createHttpAgent } from "@/api/utils"; import { CertificateError } from "@/error"; import { getHeaders } from "@/headers"; import { type RequestConfigWithMeta } from "@/logging/types"; -import { OneWayWebSocket } from "@/websocket/oneWayWebSocket"; +import { ReconnectingWebSocket } from "@/websocket/reconnectingWebSocket"; import { SseConnection } from "@/websocket/sseConnection"; import { @@ -332,7 +332,7 @@ describe("CoderApi", () => { const connection = await api.watchAgentMetadata(AGENT_ID); - expect(connection).toBeInstanceOf(OneWayWebSocket); + expect(connection).toBeInstanceOf(ReconnectingWebSocket); expect(EventSource).not.toHaveBeenCalled(); }); @@ -373,6 +373,100 @@ describe("CoderApi", () => { }); }); + describe("Reconnection on Host/Token Changes", () => { + it("triggers reconnection when session token changes", async () => { + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, + { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); + + api = createApi(CODER_URL, AXIOS_TOKEN); + const _ws = await api.watchAgentMetadata(AGENT_ID); + + // Change token - should trigger reconnection + api.setSessionToken("new-token"); + + expect(mockWs.close).toHaveBeenCalledWith(4000, "Reconnecting"); + }); + + it("triggers reconnection when host changes", async () => { + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, + { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); + + api = createApi(CODER_URL, AXIOS_TOKEN); + const _ws = await api.watchAgentMetadata(AGENT_ID); + + // Change host - should trigger reconnection + api.setHost("https://new-coder.example.com"); + + expect(mockWs.close).toHaveBeenCalledWith(4000, "Reconnecting"); + }); + + it("does not reconnect when token is set to same value", async () => { + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, + { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); + + api = createApi(CODER_URL, AXIOS_TOKEN); + const _ws = await api.watchAgentMetadata(AGENT_ID); + + // Set same token - should NOT trigger reconnection + api.setSessionToken(AXIOS_TOKEN); + + expect(mockWs.close).not.toHaveBeenCalled(); + }); + + it("does not reconnect when host is set to same value", async () => { + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, + { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); + + api = createApi(CODER_URL, AXIOS_TOKEN); + const _ws = await api.watchAgentMetadata(AGENT_ID); + + // Set same host - should NOT trigger reconnection + api.setHost(CODER_URL); + + expect(mockWs.close).not.toHaveBeenCalled(); + }); + }); + describe("Error Handling", () => { it("throws error when no base URL is set", async () => { const api = createApi(); diff --git a/test/unit/websocket/reconnectingWebSocket.test.ts b/test/unit/websocket/reconnectingWebSocket.test.ts new file mode 100644 index 00000000..d3a39645 --- /dev/null +++ b/test/unit/websocket/reconnectingWebSocket.test.ts @@ -0,0 +1,362 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; + +import { + ReconnectingWebSocket, + type SocketFactory, +} from "@/websocket/reconnectingWebSocket"; + +import { createMockLogger } from "../../mocks/testHelpers"; + +import type { CloseEvent, Event as WsEvent } from "ws"; + +import type { UnidirectionalStream } from "@/websocket/eventStreamConnection"; + +describe("ReconnectingWebSocket", () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.useRealTimers(); + }); + + describe("Reconnection Logic", () => { + it("reconnects on abnormal closure (1006)", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code: 1006, reason: "Network error" }); + + // Should schedule reconnect + await vi.advanceTimersByTimeAsync(300); + expect(sockets).toHaveLength(2); + + ws.close(); + }); + + it.each([ + { code: 1000, name: "Normal Closure" }, + { code: 1001, name: "Going Away" }, + ])("does NOT reconnect on $name ($code)", async ({ code }) => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code, reason: "Normal" }); + + // Should NOT reconnect + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + ws.close(); + }); + + it.each([403, 410, 426, 1002, 1003])( + "does NOT reconnect on unrecoverable error (%i)", + async (code) => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code, reason: "Unrecoverable" }); + + // Should NOT reconnect + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + ws.close(); + }, + ); + + it("reconnects when manually calling reconnect()", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + // Manually trigger reconnection + ws.reconnect(); + sockets[0].fireClose({ code: 4000, reason: "Reconnecting" }); + + // Should reconnect + await vi.advanceTimersByTimeAsync(300); + expect(sockets).toHaveLength(2); + + ws.close(); + }); + }); + + describe("Listener Persistence", () => { + it("keeps listeners subscribed across reconnections", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + const handler = vi.fn(); + ws.addEventListener("message", handler); + + // First message + sockets[0].fireMessage({ test: true }); + expect(handler).toHaveBeenCalledTimes(1); + + // Disconnect and reconnect + sockets[0].fireClose({ code: 1006, reason: "Network" }); + await vi.advanceTimersByTimeAsync(300); + expect(sockets).toHaveLength(2); + sockets[1].fireOpen(); + + // Handler should still work on new socket + sockets[1].fireMessage({ test: true }); + expect(handler).toHaveBeenCalledTimes(2); + + ws.close(); + }); + + it("properly removes listeners", async () => { + const socket = createMockSocket(); + const factory = vi.fn(() => Promise.resolve(socket)); + + const ws = await fromFactory(factory); + socket.fireOpen(); + + const handler1 = vi.fn(); + const handler2 = vi.fn(); + + ws.addEventListener("message", handler1); + ws.addEventListener("message", handler2); + ws.removeEventListener("message", handler1); + + socket.fireMessage({ test: true }); + + expect(handler1).not.toHaveBeenCalled(); + expect(handler2).toHaveBeenCalledTimes(1); + + ws.close(); + }); + }); + + describe("Disposal", () => { + it("stops reconnection when disposed", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + const socket = sockets[0]; + socket.fireOpen(); + + // Close and immediately dispose + socket.fireClose({ code: 1006, reason: "Network" }); + ws.close(); + + // Should NOT reconnect + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + }); + + it("closes the underlying socket", async () => { + const socket = createMockSocket(); + const factory = vi.fn(() => Promise.resolve(socket)); + const ws = await fromFactory(factory); + + socket.fireOpen(); + + ws.close(1000, "Test close"); + expect(socket.close).toHaveBeenCalledWith(1000, "Test close"); + }); + }); + + describe("Exponential Backoff", () => { + it("increases backoff exponentially on repeated failures", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + const socket = sockets[0]; + socket.fireOpen(); + + const backoffDelays = [300, 600, 1200, 2400]; + + // Fail repeatedly + for (let i = 0; i < 4; i++) { + const currentSocket = sockets[i]; + currentSocket.fireClose({ code: 1006, reason: "Fail" }); + const delay = backoffDelays[i]; + await vi.advanceTimersByTimeAsync(delay); + const nextSocket = sockets[i + 1]; + nextSocket.fireOpen(); + } + + expect(sockets).toHaveLength(5); + ws.close(); + }); + + it("resets backoff after successful connection", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + const socket1 = sockets[0]; + socket1.fireOpen(); + + // First disconnect + socket1.fireClose({ code: 1006, reason: "Fail" }); + await vi.advanceTimersByTimeAsync(300); + const socket2 = sockets[1]; + socket2.fireOpen(); + + // Second disconnect - should use initial backoff again + socket2.fireClose({ code: 1006, reason: "Fail" }); + await vi.advanceTimersByTimeAsync(300); + + expect(sockets).toHaveLength(3); + ws.close(); + }); + }); + + describe("Edge Cases", () => { + it("handles disposal during reconnection delay", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code: 1006, reason: "Network" }); + + // Dispose while waiting for reconnect + await vi.advanceTimersByTimeAsync(100); + ws.close(); + + // Should not reconnect + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + }); + + it("prevents concurrent reconnect attempts", async () => { + const socket = createMockSocket(); + const factory = vi.fn(() => Promise.resolve(socket)); + const ws = await fromFactory(factory); + + socket.fireOpen(); + + // Call reconnect multiple times rapidly + ws.reconnect(); + ws.reconnect(); + ws.reconnect(); + socket.fireClose({ code: 4000, reason: "Reconnecting" }); + + await vi.advanceTimersByTimeAsync(300); + + // Should only trigger one reconnection + expect(factory).toHaveBeenCalledTimes(2); + + ws.close(); + }); + + it("handles errors during socket factory", async () => { + const sockets: MockSocket[] = []; + let shouldFail = false; + const factory = vi.fn(() => { + if (shouldFail) { + return Promise.reject(new Error("Factory failed")); + } + const socket = createMockSocket(); + sockets.push(socket); + return Promise.resolve(socket); + }); + const ws = await fromFactory(factory); + + sockets[0].fireOpen(); + + // Make factory fail + shouldFail = true; + sockets[0].fireClose({ code: 1006, reason: "Network" }); + + // Should schedule retry + await vi.advanceTimersByTimeAsync(300); + expect(sockets).toHaveLength(1); + + ws.close(); + }); + }); +}); + +type MockSocket = UnidirectionalStream & { + fireOpen: () => void; + fireClose: (event: { code: number; reason: string }) => void; + fireMessage: (data: unknown) => void; + fireError: (error: Error) => void; +}; + +function createMockSocket(): MockSocket { + const listeners: { + open: Set<(event: WsEvent) => void>; + close: Set<(event: CloseEvent) => void>; + error: Set<(event: { error?: Error; message?: string }) => void>; + message: Set<(event: unknown) => void>; + } = { + open: new Set(), + close: new Set(), + error: new Set(), + message: new Set(), + }; + + return { + url: "ws://test.example.com/api/test", + addEventListener: vi.fn( + (event: keyof typeof listeners, callback: unknown) => { + (listeners[event] as Set<(data: unknown) => void>).add( + callback as (data: unknown) => void, + ); + }, + ), + removeEventListener: vi.fn( + (event: keyof typeof listeners, callback: unknown) => { + (listeners[event] as Set<(data: unknown) => void>).delete( + callback as (data: unknown) => void, + ); + }, + ), + close: vi.fn(), + fireOpen: () => { + for (const cb of listeners.open) { + cb({} as WsEvent); + } + }, + fireClose: (event: { code: number; reason: string }) => { + for (const cb of listeners.close) { + cb({ + code: event.code, + reason: event.reason, + wasClean: false, + } as CloseEvent); + } + }, + fireMessage: (data: unknown) => { + for (const cb of listeners.message) { + cb({ + sourceEvent: { data }, + parsedMessage: data, + parseError: undefined, + }); + } + }, + fireError: (error: Error) => { + for (const cb of listeners.error) { + cb({ error, message: error.message }); + } + }, + }; +} + +async function createReconnectingWebSocket(): Promise<{ + ws: ReconnectingWebSocket; + sockets: MockSocket[]; +}> { + const sockets: MockSocket[] = []; + const factory = vi.fn(() => { + const socket = createMockSocket(); + sockets.push(socket); + return Promise.resolve(socket); + }); + const ws = await fromFactory(factory); + + // We start with one socket + expect(sockets).toHaveLength(1); + + return { ws, sockets }; +} + +async function fromFactory( + factory: SocketFactory, +): Promise> { + return await ReconnectingWebSocket.create( + factory, + createMockLogger(), + "/random/api", + ); +} From da57d8aa078108bb64ff6948d1f23aa3740f5944 Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Mon, 17 Nov 2025 19:06:32 +0300 Subject: [PATCH 2/5] Fix an issue with changing host + Fix an issue with catching errors on socket creation --- src/api/coderApi.ts | 28 ++++++++++++++------------ src/websocket/reconnectingWebSocket.ts | 23 ++++++++++----------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index d0cb1378..0c9886a2 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -74,15 +74,14 @@ export class CoderApi extends Api { client.setSessionToken(token); } - setupInterceptors(client, baseUrl, output); + setupInterceptors(client, output); return client; } setSessionToken = (token: string): void => { - const currentToken = - this.getAxiosInstance().defaults.headers.common[coderSessionTokenHeader]; - this.getAxiosInstance().defaults.headers.common[coderSessionTokenHeader] = - token; + const defaultHeaders = this.getAxiosInstance().defaults.headers.common; + const currentToken = defaultHeaders[coderSessionTokenHeader]; + defaultHeaders[coderSessionTokenHeader] = token; if (currentToken !== token) { for (const socket of this.reconnectingSockets) { @@ -92,8 +91,9 @@ export class CoderApi extends Api { }; setHost = (host: string | undefined): void => { - const currentHost = this.getAxiosInstance().defaults.baseURL; - this.getAxiosInstance().defaults.baseURL = host; + const defaults = this.getAxiosInstance().defaults; + const currentHost = defaults.baseURL; + defaults.baseURL = host; if (currentHost !== host) { for (const socket of this.reconnectingSockets) { @@ -380,14 +380,11 @@ export class CoderApi extends Api { /** * Set up logging and request interceptors for the CoderApi instance. */ -function setupInterceptors( - client: CoderApi, - baseUrl: string, - output: Logger, -): void { +function setupInterceptors(client: CoderApi, output: Logger): void { addLoggingInterceptors(client.getAxiosInstance(), output); client.getAxiosInstance().interceptors.request.use(async (config) => { + const baseUrl = client.getAxiosInstance().defaults.baseURL; const headers = await getHeaders( baseUrl, getHeaderCommand(vscode.workspace.getConfiguration()), @@ -413,7 +410,12 @@ function setupInterceptors( client.getAxiosInstance().interceptors.response.use( (r) => r, async (err) => { - throw await CertificateError.maybeWrap(err, baseUrl, output); + const baseUrl = client.getAxiosInstance().defaults.baseURL; + if (baseUrl) { + throw await CertificateError.maybeWrap(err, baseUrl, output); + } else { + throw err; + } }, ); } diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts index 6c63c0ca..8dbe2733 100644 --- a/src/websocket/reconnectingWebSocket.ts +++ b/src/websocket/reconnectingWebSocket.ts @@ -10,9 +10,9 @@ import type { export type SocketFactory = () => Promise>; export type ReconnectingWebSocketOptions = { - initialBackoffMs?: number; // Default: 250ms - maxBackoffMs?: number; // Default: 30s - jitterFactor?: number; // Default: 0.1 (±10%) + initialBackoffMs?: number; + maxBackoffMs?: number; + jitterFactor?: number; }; // 403 Forbidden, 410 Gone, 426 Upgrade Required, 1002/1003 Protocol errors @@ -147,7 +147,6 @@ export class ReconnectingWebSocket } this.#isConnecting = true; - try { const socket = await this.#socketFactory(); this.#currentSocket = socket; @@ -199,13 +198,6 @@ export class ReconnectingWebSocket // Reconnect on abnormal closures (e.g., 1006) or other unexpected codes this.#scheduleReconnect(); }); - } catch (error) { - if (!this.#isDisposed) { - this.#logger.warn( - `[ReconnectingWebSocket] Failed: ${error instanceof Error ? error.message : String(error)} for ${this.#apiRoute}`, - ); - this.#scheduleReconnect(); - } } finally { this.#isConnecting = false; } @@ -227,7 +219,14 @@ export class ReconnectingWebSocket this.#reconnectTimeoutId = setTimeout(() => { this.#reconnectTimeoutId = null; // Errors already handled in #connect - void this.#connect(); + this.#connect().catch((error) => { + if (!this.#isDisposed) { + this.#logger.warn( + `[ReconnectingWebSocket] Failed: ${error instanceof Error ? error.message : String(error)} for ${this.#apiRoute}`, + ); + this.#scheduleReconnect(); + } + }); }, delayMs); this.#backoffMs = Math.min(this.#backoffMs * 2, this.#options.maxBackoffMs); From 53037a37ec791f3157689f51c4808d8104b58f7c Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Mon, 17 Nov 2025 20:54:07 +0300 Subject: [PATCH 3/5] Improve logging messages --- src/websocket/reconnectingWebSocket.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts index 8dbe2733..9d54f2de 100644 --- a/src/websocket/reconnectingWebSocket.ts +++ b/src/websocket/reconnectingWebSocket.ts @@ -178,7 +178,7 @@ export class ReconnectingWebSocket if (UNRECOVERABLE_CLOSE_CODES.has(event.code)) { this.#logger.error( - `[ReconnectingWebSocket] Unrecoverable error (${event.code}) for ${this.#apiRoute}`, + `WebSocket connection closed with unrecoverable error code ${event.code}`, ); this.#isDisposed = true; return; @@ -213,7 +213,7 @@ export class ReconnectingWebSocket const delayMs = Math.max(0, this.#backoffMs + jitter); this.#logger.debug( - `[ReconnectingWebSocket] Reconnecting in ${Math.round(delayMs)}ms for ${this.#apiRoute}`, + `Reconnecting WebSocket in ${Math.round(delayMs)}ms for ${this.#apiRoute}`, ); this.#reconnectTimeoutId = setTimeout(() => { @@ -222,7 +222,7 @@ export class ReconnectingWebSocket this.#connect().catch((error) => { if (!this.#isDisposed) { this.#logger.warn( - `[ReconnectingWebSocket] Failed: ${error instanceof Error ? error.message : String(error)} for ${this.#apiRoute}`, + `WebSocket connection failed for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`, ); this.#scheduleReconnect(); } From 656bb4cfa4e343c70cfefb797ec05de1b1a6be84 Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Wed, 19 Nov 2025 15:49:49 +0300 Subject: [PATCH 4/5] Various improvements to reconnecting websockets + More tests --- src/api/coderApi.ts | 5 + src/websocket/reconnectingWebSocket.ts | 158 ++++++++------ test/unit/api/coderApi.test.ts | 100 ++++----- .../websocket/reconnectingWebSocket.test.ts | 197 +++++++++++------- 4 files changed, 261 insertions(+), 199 deletions(-) diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 0c9886a2..2779befd 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -240,6 +240,11 @@ export class CoderApi extends Api { socketFactory, this.output, configs.apiRoute, + undefined, + () => + this.reconnectingSockets.delete( + reconnectingSocket as ReconnectingWebSocket, + ), ); this.reconnectingSockets.add( diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts index 9d54f2de..a51900a3 100644 --- a/src/websocket/reconnectingWebSocket.ts +++ b/src/websocket/reconnectingWebSocket.ts @@ -1,4 +1,5 @@ import type { WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; +import type { CloseEvent } from "ws"; import type { Logger } from "../logging/logger"; @@ -18,9 +19,6 @@ export type ReconnectingWebSocketOptions = { // 403 Forbidden, 410 Gone, 426 Upgrade Required, 1002/1003 Protocol errors const UNRECOVERABLE_CLOSE_CODES = new Set([403, 410, 426, 1002, 1003]); -// Custom close code for intentional reconnection (4000-4999 range is for private use) -const CLOSE_CODE_RECONNECTING = 4000; - export class ReconnectingWebSocket implements UnidirectionalStream { @@ -40,12 +38,15 @@ export class ReconnectingWebSocket #reconnectTimeoutId: NodeJS.Timeout | null = null; #isDisposed = false; #isConnecting = false; + #pendingReconnect = false; + readonly #onDispose?: () => void; private constructor( socketFactory: SocketFactory, logger: Logger, apiRoute: string, options: ReconnectingWebSocketOptions = {}, + onDispose?: () => void, ) { this.#socketFactory = socketFactory; this.#logger = logger; @@ -56,6 +57,7 @@ export class ReconnectingWebSocket jitterFactor: options.jitterFactor ?? 0.1, }; this.#backoffMs = this.#options.initialBackoffMs; + this.#onDispose = onDispose; } static async create( @@ -63,14 +65,16 @@ export class ReconnectingWebSocket logger: Logger, apiRoute: string, options: ReconnectingWebSocketOptions = {}, + onDispose?: () => void, ): Promise> { const instance = new ReconnectingWebSocket( socketFactory, logger, apiRoute, options, + onDispose, ); - await instance.#connect(); + await instance.connect(); return instance; } @@ -85,10 +89,6 @@ export class ReconnectingWebSocket (this.#eventHandlers[event] as Set>).add( callback, ); - - if (this.#currentSocket) { - this.#currentSocket.addEventListener(event, callback); - } } removeEventListener( @@ -98,95 +98,95 @@ export class ReconnectingWebSocket (this.#eventHandlers[event] as Set>).delete( callback, ); - - if (this.#currentSocket) { - this.#currentSocket.removeEventListener(event, callback); - } } - close(code?: number, reason?: string): void { + reconnect(): void { if (this.#isDisposed) { return; } - this.#isDisposed = true; - if (this.#reconnectTimeoutId !== null) { clearTimeout(this.#reconnectTimeoutId); this.#reconnectTimeoutId = null; } - if (this.#currentSocket) { - this.#currentSocket.close(code, reason); - this.#currentSocket = null; + // If already connecting, schedule reconnect after current attempt + if (this.#isConnecting) { + this.#pendingReconnect = true; + return; } - for (const set of Object.values(this.#eventHandlers)) { - set.clear(); - } + // connect() will close any existing socket + this.connect().catch((error) => { + if (!this.#isDisposed) { + this.#logger.warn( + `Manual reconnection failed for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`, + ); + this.scheduleReconnect(); + } + }); } - reconnect(): void { + close(code?: number, reason?: string): void { if (this.#isDisposed) { return; } - if (this.#reconnectTimeoutId !== null) { - clearTimeout(this.#reconnectTimeoutId); - this.#reconnectTimeoutId = null; - } - + // Fire close handlers synchronously before disposing if (this.#currentSocket) { - this.#currentSocket.close(CLOSE_CODE_RECONNECTING, "Reconnecting"); + this.executeHandlers("close", { + code: code ?? 1000, + reason: reason ?? "", + wasClean: true, + type: "close", + target: this.#currentSocket, + } as CloseEvent); } + + this.dispose(code, reason); } - async #connect(): Promise { + private async connect(): Promise { if (this.#isDisposed || this.#isConnecting) { return; } this.#isConnecting = true; try { + // Close any existing socket before creating a new one + if (this.#currentSocket) { + this.#currentSocket.close(1000, "Replacing connection"); + this.#currentSocket = null; + } + const socket = await this.#socketFactory(); this.#currentSocket = socket; - socket.addEventListener("open", () => { + socket.addEventListener("open", (event) => { this.#backoffMs = this.#options.initialBackoffMs; + this.executeHandlers("open", event); }); - for (const handler of this.#eventHandlers.open) { - socket.addEventListener("open", handler); - } - - for (const handler of this.#eventHandlers.message) { - socket.addEventListener("message", handler); - } + socket.addEventListener("message", (event) => { + this.executeHandlers("message", event); + }); - for (const handler of this.#eventHandlers.error) { - socket.addEventListener("error", handler); - } + socket.addEventListener("error", (event) => { + this.executeHandlers("error", event); + }); socket.addEventListener("close", (event) => { - for (const handler of this.#eventHandlers.close) { - handler(event); - } - if (this.#isDisposed) { return; } + this.executeHandlers("close", event); + if (UNRECOVERABLE_CLOSE_CODES.has(event.code)) { this.#logger.error( `WebSocket connection closed with unrecoverable error code ${event.code}`, ); - this.#isDisposed = true; - return; - } - - // Reconnect if this was an intentional close for reconnection - if (event.code === CLOSE_CODE_RECONNECTING) { - this.#scheduleReconnect(); + this.dispose(); return; } @@ -196,14 +196,19 @@ export class ReconnectingWebSocket } // Reconnect on abnormal closures (e.g., 1006) or other unexpected codes - this.#scheduleReconnect(); + this.scheduleReconnect(); }); } finally { this.#isConnecting = false; + + if (this.#pendingReconnect) { + this.#pendingReconnect = false; + this.reconnect(); + } } } - #scheduleReconnect(): void { + private scheduleReconnect(): void { if (this.#isDisposed || this.#reconnectTimeoutId !== null) { return; } @@ -218,13 +223,12 @@ export class ReconnectingWebSocket this.#reconnectTimeoutId = setTimeout(() => { this.#reconnectTimeoutId = null; - // Errors already handled in #connect - this.#connect().catch((error) => { + this.connect().catch((error) => { if (!this.#isDisposed) { this.#logger.warn( `WebSocket connection failed for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`, ); - this.#scheduleReconnect(); + this.scheduleReconnect(); } }); }, delayMs); @@ -232,7 +236,45 @@ export class ReconnectingWebSocket this.#backoffMs = Math.min(this.#backoffMs * 2, this.#options.maxBackoffMs); } - isDisposed(): boolean { - return this.#isDisposed; + private executeHandlers( + event: TEvent, + eventData: Parameters>[0], + ): void { + const handlers = this.#eventHandlers[event] as Set< + EventHandler + >; + for (const handler of handlers) { + try { + handler(eventData); + } catch (error) { + this.#logger.error( + `Error in ${event} handler for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`, + ); + } + } + } + + private dispose(code?: number, reason?: string): void { + if (this.#isDisposed) { + return; + } + + this.#isDisposed = true; + + if (this.#reconnectTimeoutId !== null) { + clearTimeout(this.#reconnectTimeoutId); + this.#reconnectTimeoutId = null; + } + + if (this.#currentSocket) { + this.#currentSocket.close(code, reason); + this.#currentSocket = null; + } + + for (const set of Object.values(this.#eventHandlers)) { + set.clear(); + } + + this.#onDispose?.(); } } diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts index d50782f6..4f90f33e 100644 --- a/test/unit/api/coderApi.test.ts +++ b/test/unit/api/coderApi.test.ts @@ -374,96 +374,66 @@ describe("CoderApi", () => { }); describe("Reconnection on Host/Token Changes", () => { - it("triggers reconnection when session token changes", async () => { - const mockWs = createMockWebSocket( - `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, - { + const setupAutoOpeningWebSocket = () => { + const sockets: Array> = []; + vi.mocked(Ws).mockImplementation((url: string | URL) => { + const mockWs = createMockWebSocket(String(url), { on: vi.fn((event, handler) => { if (event === "open") { setImmediate(() => handler()); } return mockWs as Ws; }), - }, - ); - setupWebSocketMock(mockWs); + }); + sockets.push(mockWs); + return mockWs as Ws; + }); + return sockets; + }; + it("triggers reconnection when session token changes", async () => { + const sockets = setupAutoOpeningWebSocket(); api = createApi(CODER_URL, AXIOS_TOKEN); - const _ws = await api.watchAgentMetadata(AGENT_ID); + await api.watchAgentMetadata(AGENT_ID); - // Change token - should trigger reconnection api.setSessionToken("new-token"); + await new Promise((resolve) => setImmediate(resolve)); - expect(mockWs.close).toHaveBeenCalledWith(4000, "Reconnecting"); + expect(sockets[0].close).toHaveBeenCalledWith( + 1000, + "Replacing connection", + ); + expect(sockets).toHaveLength(2); }); it("triggers reconnection when host changes", async () => { - const mockWs = createMockWebSocket( - `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, - { - on: vi.fn((event, handler) => { - if (event === "open") { - setImmediate(() => handler()); - } - return mockWs as Ws; - }), - }, - ); - setupWebSocketMock(mockWs); - + const sockets = setupAutoOpeningWebSocket(); api = createApi(CODER_URL, AXIOS_TOKEN); - const _ws = await api.watchAgentMetadata(AGENT_ID); + const wsWrap = await api.watchAgentMetadata(AGENT_ID); + expect(wsWrap.url).toContain(CODER_URL.replace("http", "ws")); - // Change host - should trigger reconnection api.setHost("https://new-coder.example.com"); + await new Promise((resolve) => setImmediate(resolve)); - expect(mockWs.close).toHaveBeenCalledWith(4000, "Reconnecting"); - }); - - it("does not reconnect when token is set to same value", async () => { - const mockWs = createMockWebSocket( - `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, - { - on: vi.fn((event, handler) => { - if (event === "open") { - setImmediate(() => handler()); - } - return mockWs as Ws; - }), - }, + expect(sockets[0].close).toHaveBeenCalledWith( + 1000, + "Replacing connection", ); - setupWebSocketMock(mockWs); - - api = createApi(CODER_URL, AXIOS_TOKEN); - const _ws = await api.watchAgentMetadata(AGENT_ID); - - // Set same token - should NOT trigger reconnection - api.setSessionToken(AXIOS_TOKEN); - - expect(mockWs.close).not.toHaveBeenCalled(); + expect(sockets).toHaveLength(2); + expect(wsWrap.url).toContain("wss://new-coder.example.com"); }); - it("does not reconnect when host is set to same value", async () => { - const mockWs = createMockWebSocket( - `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata-ws`, - { - on: vi.fn((event, handler) => { - if (event === "open") { - setImmediate(() => handler()); - } - return mockWs as Ws; - }), - }, - ); - setupWebSocketMock(mockWs); - + it("does not reconnect when token or host are unchanged", async () => { + const sockets = setupAutoOpeningWebSocket(); api = createApi(CODER_URL, AXIOS_TOKEN); - const _ws = await api.watchAgentMetadata(AGENT_ID); + await api.watchAgentMetadata(AGENT_ID); - // Set same host - should NOT trigger reconnection + // Same values as before + api.setSessionToken(AXIOS_TOKEN); api.setHost(CODER_URL); - expect(mockWs.close).not.toHaveBeenCalled(); + expect(sockets[0].close).not.toHaveBeenCalled(); + expect(sockets).toHaveLength(1); }); }); diff --git a/test/unit/websocket/reconnectingWebSocket.test.ts b/test/unit/websocket/reconnectingWebSocket.test.ts index d3a39645..cce231a1 100644 --- a/test/unit/websocket/reconnectingWebSocket.test.ts +++ b/test/unit/websocket/reconnectingWebSocket.test.ts @@ -22,13 +22,12 @@ describe("ReconnectingWebSocket", () => { }); describe("Reconnection Logic", () => { - it("reconnects on abnormal closure (1006)", async () => { + it("automatically reconnects on abnormal closure (1006)", async () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); sockets[0].fireClose({ code: 1006, reason: "Network error" }); - // Should schedule reconnect await vi.advanceTimersByTimeAsync(300); expect(sockets).toHaveLength(2); @@ -38,28 +37,29 @@ describe("ReconnectingWebSocket", () => { it.each([ { code: 1000, name: "Normal Closure" }, { code: 1001, name: "Going Away" }, - ])("does NOT reconnect on $name ($code)", async ({ code }) => { - const { ws, sockets } = await createReconnectingWebSocket(); + ])( + "does not reconnect on normal closure: $name ($code)", + async ({ code }) => { + const { ws, sockets } = await createReconnectingWebSocket(); - sockets[0].fireOpen(); - sockets[0].fireClose({ code, reason: "Normal" }); + sockets[0].fireOpen(); + sockets[0].fireClose({ code, reason: "Normal" }); - // Should NOT reconnect - await vi.advanceTimersByTimeAsync(10000); - expect(sockets).toHaveLength(1); + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); - ws.close(); - }); + ws.close(); + }, + ); it.each([403, 410, 426, 1002, 1003])( - "does NOT reconnect on unrecoverable error (%i)", + "does not reconnect on unrecoverable error: %i", async (code) => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); sockets[0].fireClose({ code, reason: "Unrecoverable" }); - // Should NOT reconnect await vi.advanceTimersByTimeAsync(10000); expect(sockets).toHaveLength(1); @@ -67,24 +67,68 @@ describe("ReconnectingWebSocket", () => { }, ); - it("reconnects when manually calling reconnect()", async () => { + it("reconnect() connects immediately and cancels pending reconnections", async () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); - // Manually trigger reconnection + sockets[0].fireClose({ code: 1006, reason: "Connection lost" }); + + // Manual reconnect() should happen immediately and cancel the scheduled reconnect ws.reconnect(); - sockets[0].fireClose({ code: 4000, reason: "Reconnecting" }); + expect(sockets).toHaveLength(2); - // Should reconnect - await vi.advanceTimersByTimeAsync(300); + // Verify pending reconnect was cancelled - no third socket should be created + await vi.advanceTimersByTimeAsync(1000); + expect(sockets).toHaveLength(2); + + ws.close(); + }); + + it("queues reconnect() calls made during connection", async () => { + const sockets: MockSocket[] = []; + let pendingResolve: ((socket: MockSocket) => void) | null = null; + + const factory = vi.fn(() => { + const socket = createMockSocket(); + sockets.push(socket); + + // First call resolves immediately, other calls wait for manual resolve + if (sockets.length === 1) { + return Promise.resolve(socket); + } else { + return new Promise((resolve) => { + pendingResolve = resolve; + }); + } + }); + + const ws = await fromFactory(factory); + sockets[0].fireOpen(); + expect(sockets).toHaveLength(1); + + // Start first reconnect (will block on factory promise) + ws.reconnect(); + expect(sockets).toHaveLength(2); + // Call reconnect again while first reconnect is in progress + ws.reconnect(); + // Still only 2 sockets (queued reconnect hasn't started) expect(sockets).toHaveLength(2); + // Complete the first reconnect + pendingResolve!(sockets[1]); + sockets[1].fireOpen(); + + // Wait a tick for the queued reconnect to execute + await Promise.resolve(); + // Now queued reconnect should have executed, creating third socket + expect(sockets).toHaveLength(3); + ws.close(); }); }); - describe("Listener Persistence", () => { - it("keeps listeners subscribed across reconnections", async () => { + describe("Event Handlers", () => { + it("persists event handlers across reconnections", async () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); @@ -108,7 +152,7 @@ describe("ReconnectingWebSocket", () => { ws.close(); }); - it("properly removes listeners", async () => { + it("removes event handlers when removeEventListener is called", async () => { const socket = createMockSocket(); const factory = vi.fn(() => Promise.resolve(socket)); @@ -131,35 +175,71 @@ describe("ReconnectingWebSocket", () => { }); }); - describe("Disposal", () => { - it("stops reconnection when disposed", async () => { + describe("close() and Disposal", () => { + it("stops reconnection when close() is called", async () => { const { ws, sockets } = await createReconnectingWebSocket(); - const socket = sockets[0]; - socket.fireOpen(); - // Close and immediately dispose - socket.fireClose({ code: 1006, reason: "Network" }); + sockets[0].fireOpen(); + sockets[0].fireClose({ code: 1006, reason: "Network" }); ws.close(); - // Should NOT reconnect await vi.advanceTimersByTimeAsync(10000); expect(sockets).toHaveLength(1); }); - it("closes the underlying socket", async () => { + it("closes the underlying socket with provided code and reason", async () => { const socket = createMockSocket(); const factory = vi.fn(() => Promise.resolve(socket)); const ws = await fromFactory(factory); socket.fireOpen(); - ws.close(1000, "Test close"); + expect(socket.close).toHaveBeenCalledWith(1000, "Test close"); }); + + it("calls onDispose callback once, even with multiple close() calls", async () => { + let disposeCount = 0; + const { ws } = await createReconnectingWebSocket(() => ++disposeCount); + + ws.close(); + ws.close(); + ws.close(); + + expect(disposeCount).toBe(1); + }); + + it("calls onDispose callback on unrecoverable error", async () => { + let disposeCount = 0; + const { sockets } = await createReconnectingWebSocket( + () => ++disposeCount, + ); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code: 403, reason: "Forbidden" }); + + expect(disposeCount).toBe(1); + }); + + it("does not call onDispose callback during reconnection", async () => { + let disposeCount = 0; + const { ws, sockets } = await createReconnectingWebSocket( + () => ++disposeCount, + ); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code: 1006, reason: "Network error" }); + + await vi.advanceTimersByTimeAsync(300); + expect(disposeCount).toBe(0); + + ws.close(); + expect(disposeCount).toBe(1); + }); }); - describe("Exponential Backoff", () => { - it("increases backoff exponentially on repeated failures", async () => { + describe("Backoff Strategy", () => { + it("doubles backoff delay after each failed connection", async () => { const { ws, sockets } = await createReconnectingWebSocket(); const socket = sockets[0]; socket.fireOpen(); @@ -180,7 +260,7 @@ describe("ReconnectingWebSocket", () => { ws.close(); }); - it("resets backoff after successful connection", async () => { + it("resets backoff delay after successful connection", async () => { const { ws, sockets } = await createReconnectingWebSocket(); const socket1 = sockets[0]; socket1.fireOpen(); @@ -200,44 +280,8 @@ describe("ReconnectingWebSocket", () => { }); }); - describe("Edge Cases", () => { - it("handles disposal during reconnection delay", async () => { - const { ws, sockets } = await createReconnectingWebSocket(); - - sockets[0].fireOpen(); - sockets[0].fireClose({ code: 1006, reason: "Network" }); - - // Dispose while waiting for reconnect - await vi.advanceTimersByTimeAsync(100); - ws.close(); - - // Should not reconnect - await vi.advanceTimersByTimeAsync(10000); - expect(sockets).toHaveLength(1); - }); - - it("prevents concurrent reconnect attempts", async () => { - const socket = createMockSocket(); - const factory = vi.fn(() => Promise.resolve(socket)); - const ws = await fromFactory(factory); - - socket.fireOpen(); - - // Call reconnect multiple times rapidly - ws.reconnect(); - ws.reconnect(); - ws.reconnect(); - socket.fireClose({ code: 4000, reason: "Reconnecting" }); - - await vi.advanceTimersByTimeAsync(300); - - // Should only trigger one reconnection - expect(factory).toHaveBeenCalledTimes(2); - - ws.close(); - }); - - it("handles errors during socket factory", async () => { + describe("Error Handling", () => { + it("schedules retry when socket factory throws error", async () => { const sockets: MockSocket[] = []; let shouldFail = false; const factory = vi.fn(() => { @@ -252,11 +296,9 @@ describe("ReconnectingWebSocket", () => { sockets[0].fireOpen(); - // Make factory fail shouldFail = true; sockets[0].fireClose({ code: 1006, reason: "Network" }); - // Should schedule retry await vi.advanceTimersByTimeAsync(300); expect(sockets).toHaveLength(1); @@ -312,7 +354,7 @@ function createMockSocket(): MockSocket { cb({ code: event.code, reason: event.reason, - wasClean: false, + wasClean: event.code === 1000, } as CloseEvent); } }, @@ -333,7 +375,7 @@ function createMockSocket(): MockSocket { }; } -async function createReconnectingWebSocket(): Promise<{ +async function createReconnectingWebSocket(onDispose?: () => void): Promise<{ ws: ReconnectingWebSocket; sockets: MockSocket[]; }> { @@ -343,7 +385,7 @@ async function createReconnectingWebSocket(): Promise<{ sockets.push(socket); return Promise.resolve(socket); }); - const ws = await fromFactory(factory); + const ws = await fromFactory(factory, onDispose); // We start with one socket expect(sockets).toHaveLength(1); @@ -353,10 +395,13 @@ async function createReconnectingWebSocket(): Promise<{ async function fromFactory( factory: SocketFactory, + onDispose?: () => void, ): Promise> { return await ReconnectingWebSocket.create( factory, createMockLogger(), "/random/api", + undefined, + onDispose, ); } From eec4c99a21c7fcf0eb7ac8068de9c5176e067e5f Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Fri, 21 Nov 2025 16:29:35 +0300 Subject: [PATCH 5/5] Address review comments * Better error handling for HTTP errors * Refactor magic numbers * Use strict types and avoid casting --- src/api/coderApi.ts | 13 ++- src/websocket/codes.ts | 55 +++++++++ src/websocket/eventStreamConnection.ts | 13 ++- src/websocket/reconnectingWebSocket.ts | 104 +++++++++++------- src/websocket/sseConnection.ts | 27 ++--- .../websocket/reconnectingWebSocket.test.ts | 97 +++++++++++++--- test/unit/websocket/sseConnection.test.ts | 18 ++- 7 files changed, 241 insertions(+), 86 deletions(-) create mode 100644 src/websocket/codes.ts diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 2779befd..04c696be 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -14,7 +14,7 @@ import { type WorkspaceAgentLog, } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; -import { type ClientOptions, type CloseEvent, type ErrorEvent } from "ws"; +import { type ClientOptions } from "ws"; import { CertificateError } from "../error"; import { getHeaderCommand, getHeaders } from "../headers"; @@ -31,7 +31,12 @@ import { HttpClientLogLevel, } from "../logging/types"; import { sizeOf } from "../logging/utils"; -import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; +import { HttpStatusCode } from "../websocket/codes"; +import { + type UnidirectionalStream, + type CloseEvent, + type ErrorEvent, +} from "../websocket/eventStreamConnection"; import { OneWayWebSocket, type OneWayWebSocketInit, @@ -336,8 +341,8 @@ export class CoderApi extends Api { const handleError = (event: ErrorEvent) => { cleanup(); const is404 = - event.message?.includes("404") || - event.error?.message?.includes("404"); + event.message?.includes(String(HttpStatusCode.NOT_FOUND)) || + event.error?.message?.includes(String(HttpStatusCode.NOT_FOUND)); if (is404 && onNotFound) { connection.close(); diff --git a/src/websocket/codes.ts b/src/websocket/codes.ts new file mode 100644 index 00000000..ac8eccf7 --- /dev/null +++ b/src/websocket/codes.ts @@ -0,0 +1,55 @@ +/** + * WebSocket close codes (RFC 6455) and HTTP status codes for socket connections. + * @see https://www.rfc-editor.org/rfc/rfc6455#section-7.4.1 + */ + +/** WebSocket close codes defined in RFC 6455 */ +export const WebSocketCloseCode = { + /** Normal closure - connection successfully completed */ + NORMAL: 1000, + /** Endpoint going away (server shutdown) */ + GOING_AWAY: 1001, + /** Protocol error - connection cannot be recovered */ + PROTOCOL_ERROR: 1002, + /** Unsupported data type received - connection cannot be recovered */ + UNSUPPORTED_DATA: 1003, + /** Abnormal closure - connection closed without close frame (network issues) */ + ABNORMAL: 1006, +} as const; + +/** HTTP status codes used for socket creation and connection logic */ +export const HttpStatusCode = { + /** Authentication or permission denied */ + FORBIDDEN: 403, + /** Endpoint not found */ + NOT_FOUND: 404, + /** Resource permanently gone */ + GONE: 410, + /** Protocol upgrade required */ + UPGRADE_REQUIRED: 426, +} as const; + +/** + * WebSocket close codes indicating unrecoverable errors. + * These appear in close events and should stop reconnection attempts. + */ +export const UNRECOVERABLE_WS_CLOSE_CODES = new Set([ + WebSocketCloseCode.PROTOCOL_ERROR, + WebSocketCloseCode.UNSUPPORTED_DATA, +]); + +/** + * HTTP status codes indicating unrecoverable errors during handshake. + * These appear during socket creation and should stop reconnection attempts. + */ +export const UNRECOVERABLE_HTTP_CODES = new Set([ + HttpStatusCode.FORBIDDEN, + HttpStatusCode.GONE, + HttpStatusCode.UPGRADE_REQUIRED, +]); + +/** Close codes indicating intentional closure - do not reconnect */ +export const NORMAL_CLOSURE_CODES = new Set([ + WebSocketCloseCode.NORMAL, + WebSocketCloseCode.GOING_AWAY, +]); diff --git a/src/websocket/eventStreamConnection.ts b/src/websocket/eventStreamConnection.ts index 2dc6514e..e3100ee6 100644 --- a/src/websocket/eventStreamConnection.ts +++ b/src/websocket/eventStreamConnection.ts @@ -1,11 +1,16 @@ import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; import { - type CloseEvent, + type CloseEvent as WsCloseEvent, type Event as WsEvent, - type ErrorEvent, - type MessageEvent, + type ErrorEvent as WsErrorEvent, + type MessageEvent as WsMessageEvent, } from "ws"; +export type Event = Omit; +export type CloseEvent = Omit; +export type ErrorEvent = Omit; +export type MessageEvent = Omit; + // Event payload types matching OneWayWebSocket export type ParsedMessageEvent = Readonly< | { @@ -24,7 +29,7 @@ export type EventPayloadMap = { close: CloseEvent; error: ErrorEvent; message: ParsedMessageEvent; - open: WsEvent; + open: Event; }; export type EventHandler = ( diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts index a51900a3..2ced9351 100644 --- a/src/websocket/reconnectingWebSocket.ts +++ b/src/websocket/reconnectingWebSocket.ts @@ -1,5 +1,11 @@ +import { + WebSocketCloseCode, + NORMAL_CLOSURE_CODES, + UNRECOVERABLE_WS_CLOSE_CODES, + UNRECOVERABLE_HTTP_CODES, +} from "./codes"; + import type { WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; -import type { CloseEvent } from "ws"; import type { Logger } from "../logging/logger"; @@ -16,9 +22,6 @@ export type ReconnectingWebSocketOptions = { jitterFactor?: number; }; -// 403 Forbidden, 410 Gone, 426 Upgrade Required, 1002/1003 Protocol errors -const UNRECOVERABLE_CLOSE_CODES = new Set([403, 410, 426, 1002, 1003]); - export class ReconnectingWebSocket implements UnidirectionalStream { @@ -26,7 +29,9 @@ export class ReconnectingWebSocket readonly #logger: Logger; readonly #apiRoute: string; readonly #options: Required; - readonly #eventHandlers = { + readonly #eventHandlers: { + [K in WebSocketEventType]: Set>; + } = { open: new Set>(), close: new Set>(), error: new Set>(), @@ -86,18 +91,14 @@ export class ReconnectingWebSocket event: TEvent, callback: EventHandler, ): void { - (this.#eventHandlers[event] as Set>).add( - callback, - ); + this.#eventHandlers[event].add(callback); } removeEventListener( event: TEvent, callback: EventHandler, ): void { - (this.#eventHandlers[event] as Set>).delete( - callback, - ); + this.#eventHandlers[event].delete(callback); } reconnect(): void { @@ -117,14 +118,7 @@ export class ReconnectingWebSocket } // connect() will close any existing socket - this.connect().catch((error) => { - if (!this.#isDisposed) { - this.#logger.warn( - `Manual reconnection failed for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`, - ); - this.scheduleReconnect(); - } - }); + this.connect().catch((error) => this.handleConnectionError(error)); } close(code?: number, reason?: string): void { @@ -135,12 +129,10 @@ export class ReconnectingWebSocket // Fire close handlers synchronously before disposing if (this.#currentSocket) { this.executeHandlers("close", { - code: code ?? 1000, - reason: reason ?? "", + code: code ?? WebSocketCloseCode.NORMAL, + reason: reason ?? "Normal closure", wasClean: true, - type: "close", - target: this.#currentSocket, - } as CloseEvent); + }); } this.dispose(code, reason); @@ -155,7 +147,10 @@ export class ReconnectingWebSocket try { // Close any existing socket before creating a new one if (this.#currentSocket) { - this.#currentSocket.close(1000, "Replacing connection"); + this.#currentSocket.close( + WebSocketCloseCode.NORMAL, + "Replacing connection", + ); this.#currentSocket = null; } @@ -182,7 +177,7 @@ export class ReconnectingWebSocket this.executeHandlers("close", event); - if (UNRECOVERABLE_CLOSE_CODES.has(event.code)) { + if (UNRECOVERABLE_WS_CLOSE_CODES.has(event.code)) { this.#logger.error( `WebSocket connection closed with unrecoverable error code ${event.code}`, ); @@ -191,7 +186,7 @@ export class ReconnectingWebSocket } // Don't reconnect on normal closure - if (event.code === 1000 || event.code === 1001) { + if (NORMAL_CLOSURE_CODES.has(event.code)) { return; } @@ -223,14 +218,7 @@ export class ReconnectingWebSocket this.#reconnectTimeoutId = setTimeout(() => { this.#reconnectTimeoutId = null; - this.connect().catch((error) => { - if (!this.#isDisposed) { - this.#logger.warn( - `WebSocket connection failed for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`, - ); - this.scheduleReconnect(); - } - }); + this.connect().catch((error) => this.handleConnectionError(error)); }, delayMs); this.#backoffMs = Math.min(this.#backoffMs * 2, this.#options.maxBackoffMs); @@ -240,20 +228,56 @@ export class ReconnectingWebSocket event: TEvent, eventData: Parameters>[0], ): void { - const handlers = this.#eventHandlers[event] as Set< - EventHandler - >; - for (const handler of handlers) { + for (const handler of this.#eventHandlers[event]) { try { handler(eventData); } catch (error) { this.#logger.error( - `Error in ${event} handler for ${this.#apiRoute}: ${error instanceof Error ? error.message : String(error)}`, + `Error in ${event} handler for ${this.#apiRoute}`, + error, ); } } } + /** + * Checks if the error is unrecoverable and disposes the connection, + * otherwise schedules a reconnect. + */ + private handleConnectionError(error: unknown): void { + if (this.#isDisposed) { + return; + } + + if (this.isUnrecoverableHttpError(error)) { + this.#logger.error( + `Unrecoverable HTTP error during connection for ${this.#apiRoute}`, + error, + ); + this.dispose(); + return; + } + + this.#logger.warn( + `WebSocket connection failed for ${this.#apiRoute}`, + error, + ); + this.scheduleReconnect(); + } + + /** + * Check if an error contains an unrecoverable HTTP status code. + */ + private isUnrecoverableHttpError(error: unknown): boolean { + const errorMessage = error instanceof Error ? error.message : String(error); + for (const code of UNRECOVERABLE_HTTP_CODES) { + if (errorMessage.includes(String(code))) { + return true; + } + } + return false; + } + private dispose(code?: number, reason?: string): void { if (this.#isDisposed) { return; diff --git a/src/websocket/sseConnection.ts b/src/websocket/sseConnection.ts index 5a71d303..dc20eeda 100644 --- a/src/websocket/sseConnection.ts +++ b/src/websocket/sseConnection.ts @@ -6,19 +6,14 @@ import { EventSource } from "eventsource"; import { createStreamingFetchAdapter } from "../api/streamingFetchAdapter"; import { type Logger } from "../logging/logger"; +import { WebSocketCloseCode } from "./codes"; import { getQueryString } from "./utils"; -import type { - CloseEvent as WsCloseEvent, - ErrorEvent as WsErrorEvent, - Event as WsEvent, - MessageEvent as WsMessageEvent, -} from "ws"; - import type { UnidirectionalStream, ParsedMessageEvent, EventHandler, + ErrorEvent as WsErrorEvent, } from "./eventStreamConnection"; export type SseConnectionInit = { @@ -66,7 +61,7 @@ export class SseConnection implements UnidirectionalStream { private setupEventHandlers(): void { this.eventSource.addEventListener("open", () => - this.invokeCallbacks(this.callbacks.open, {} as WsEvent, "open"), + this.invokeCallbacks(this.callbacks.open, {}, "open"), ); this.eventSource.addEventListener("data", (event: MessageEvent) => { @@ -84,10 +79,10 @@ export class SseConnection implements UnidirectionalStream { this.invokeCallbacks( this.callbacks.close, { - code: 1006, + code: WebSocketCloseCode.ABNORMAL, reason: "Connection lost", wasClean: false, - } as WsCloseEvent, + }, "close", ); } @@ -117,7 +112,7 @@ export class SseConnection implements UnidirectionalStream { return { error: error, message: errorMessage, - } as WsErrorEvent; + }; } public addEventListener( @@ -158,7 +153,7 @@ export class SseConnection implements UnidirectionalStream { private parseMessage( event: MessageEvent, ): ParsedMessageEvent { - const wsEvent = { data: event.data } as WsMessageEvent; + const wsEvent = { data: event.data }; try { return { sourceEvent: wsEvent, @@ -207,14 +202,16 @@ export class SseConnection implements UnidirectionalStream { this.invokeCallbacks( this.callbacks.close, { - code: code ?? 1000, + code: code ?? WebSocketCloseCode.NORMAL, reason: reason ?? "Normal closure", wasClean: true, - } as WsCloseEvent, + }, "close", ); - Object.values(this.callbacks).forEach((callbackSet) => callbackSet.clear()); + for (const callbackSet of Object.values(this.callbacks)) { + callbackSet.clear(); + } this.messageWrappers.clear(); } } diff --git a/test/unit/websocket/reconnectingWebSocket.test.ts b/test/unit/websocket/reconnectingWebSocket.test.ts index cce231a1..cdf08949 100644 --- a/test/unit/websocket/reconnectingWebSocket.test.ts +++ b/test/unit/websocket/reconnectingWebSocket.test.ts @@ -1,5 +1,6 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { WebSocketCloseCode, HttpStatusCode } from "@/websocket/codes"; import { ReconnectingWebSocket, type SocketFactory, @@ -26,7 +27,10 @@ describe("ReconnectingWebSocket", () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); - sockets[0].fireClose({ code: 1006, reason: "Network error" }); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network error", + }); await vi.advanceTimersByTimeAsync(300); expect(sockets).toHaveLength(2); @@ -35,8 +39,8 @@ describe("ReconnectingWebSocket", () => { }); it.each([ - { code: 1000, name: "Normal Closure" }, - { code: 1001, name: "Going Away" }, + { code: WebSocketCloseCode.NORMAL, name: "Normal Closure" }, + { code: WebSocketCloseCode.GOING_AWAY, name: "Going Away" }, ])( "does not reconnect on normal closure: $name ($code)", async ({ code }) => { @@ -52,8 +56,11 @@ describe("ReconnectingWebSocket", () => { }, ); - it.each([403, 410, 426, 1002, 1003])( - "does not reconnect on unrecoverable error: %i", + it.each([ + WebSocketCloseCode.PROTOCOL_ERROR, + WebSocketCloseCode.UNSUPPORTED_DATA, + ])( + "does not reconnect on unrecoverable WebSocket close code: %i", async (code) => { const { ws, sockets } = await createReconnectingWebSocket(); @@ -67,11 +74,44 @@ describe("ReconnectingWebSocket", () => { }, ); + it.each([ + HttpStatusCode.FORBIDDEN, + HttpStatusCode.GONE, + HttpStatusCode.UPGRADE_REQUIRED, + ])( + "does not reconnect on unrecoverable HTTP error during creation: %i", + async (statusCode) => { + let socketCreationAttempts = 0; + const factory = vi.fn(() => { + socketCreationAttempts++; + // Simulate HTTP error during WebSocket handshake + return Promise.reject( + new Error(`Unexpected server response: ${statusCode}`), + ); + }); + + await expect( + ReconnectingWebSocket.create( + factory, + createMockLogger(), + "/api/test", + ), + ).rejects.toThrow(`Unexpected server response: ${statusCode}`); + + // Should not retry after unrecoverable HTTP error + await vi.advanceTimersByTimeAsync(10000); + expect(socketCreationAttempts).toBe(1); + }, + ); + it("reconnect() connects immediately and cancels pending reconnections", async () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); - sockets[0].fireClose({ code: 1006, reason: "Connection lost" }); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Connection lost", + }); // Manual reconnect() should happen immediately and cancel the scheduled reconnect ws.reconnect(); @@ -140,7 +180,10 @@ describe("ReconnectingWebSocket", () => { expect(handler).toHaveBeenCalledTimes(1); // Disconnect and reconnect - sockets[0].fireClose({ code: 1006, reason: "Network" }); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network", + }); await vi.advanceTimersByTimeAsync(300); expect(sockets).toHaveLength(2); sockets[1].fireOpen(); @@ -180,7 +223,10 @@ describe("ReconnectingWebSocket", () => { const { ws, sockets } = await createReconnectingWebSocket(); sockets[0].fireOpen(); - sockets[0].fireClose({ code: 1006, reason: "Network" }); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network", + }); ws.close(); await vi.advanceTimersByTimeAsync(10000); @@ -193,9 +239,12 @@ describe("ReconnectingWebSocket", () => { const ws = await fromFactory(factory); socket.fireOpen(); - ws.close(1000, "Test close"); + ws.close(WebSocketCloseCode.NORMAL, "Test close"); - expect(socket.close).toHaveBeenCalledWith(1000, "Test close"); + expect(socket.close).toHaveBeenCalledWith( + WebSocketCloseCode.NORMAL, + "Test close", + ); }); it("calls onDispose callback once, even with multiple close() calls", async () => { @@ -209,14 +258,17 @@ describe("ReconnectingWebSocket", () => { expect(disposeCount).toBe(1); }); - it("calls onDispose callback on unrecoverable error", async () => { + it("calls onDispose callback on unrecoverable WebSocket close code", async () => { let disposeCount = 0; const { sockets } = await createReconnectingWebSocket( () => ++disposeCount, ); sockets[0].fireOpen(); - sockets[0].fireClose({ code: 403, reason: "Forbidden" }); + sockets[0].fireClose({ + code: WebSocketCloseCode.PROTOCOL_ERROR, + reason: "Protocol error", + }); expect(disposeCount).toBe(1); }); @@ -228,7 +280,10 @@ describe("ReconnectingWebSocket", () => { ); sockets[0].fireOpen(); - sockets[0].fireClose({ code: 1006, reason: "Network error" }); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network error", + }); await vi.advanceTimersByTimeAsync(300); expect(disposeCount).toBe(0); @@ -249,7 +304,10 @@ describe("ReconnectingWebSocket", () => { // Fail repeatedly for (let i = 0; i < 4; i++) { const currentSocket = sockets[i]; - currentSocket.fireClose({ code: 1006, reason: "Fail" }); + currentSocket.fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Fail", + }); const delay = backoffDelays[i]; await vi.advanceTimersByTimeAsync(delay); const nextSocket = sockets[i + 1]; @@ -266,13 +324,13 @@ describe("ReconnectingWebSocket", () => { socket1.fireOpen(); // First disconnect - socket1.fireClose({ code: 1006, reason: "Fail" }); + socket1.fireClose({ code: WebSocketCloseCode.ABNORMAL, reason: "Fail" }); await vi.advanceTimersByTimeAsync(300); const socket2 = sockets[1]; socket2.fireOpen(); // Second disconnect - should use initial backoff again - socket2.fireClose({ code: 1006, reason: "Fail" }); + socket2.fireClose({ code: WebSocketCloseCode.ABNORMAL, reason: "Fail" }); await vi.advanceTimersByTimeAsync(300); expect(sockets).toHaveLength(3); @@ -297,7 +355,10 @@ describe("ReconnectingWebSocket", () => { sockets[0].fireOpen(); shouldFail = true; - sockets[0].fireClose({ code: 1006, reason: "Network" }); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network", + }); await vi.advanceTimersByTimeAsync(300); expect(sockets).toHaveLength(1); @@ -354,7 +415,7 @@ function createMockSocket(): MockSocket { cb({ code: event.code, reason: event.reason, - wasClean: event.code === 1000, + wasClean: event.code === WebSocketCloseCode.NORMAL, } as CloseEvent); } }, diff --git a/test/unit/websocket/sseConnection.test.ts b/test/unit/websocket/sseConnection.test.ts index 61cfce4d..378e6f54 100644 --- a/test/unit/websocket/sseConnection.test.ts +++ b/test/unit/websocket/sseConnection.test.ts @@ -3,10 +3,14 @@ import { type ServerSentEvent } from "coder/site/src/api/typesGenerated"; import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; import { EventSource } from "eventsource"; import { describe, it, expect, vi } from "vitest"; -import { type CloseEvent, type ErrorEvent } from "ws"; import { type Logger } from "@/logging/logger"; -import { type ParsedMessageEvent } from "@/websocket/eventStreamConnection"; +import { WebSocketCloseCode } from "@/websocket/codes"; +import { + type ParsedMessageEvent, + type CloseEvent, + type ErrorEvent, +} from "@/websocket/eventStreamConnection"; import { SseConnection } from "@/websocket/sseConnection"; import { createMockLogger } from "../../mocks/testHelpers"; @@ -168,7 +172,7 @@ describe("SseConnection", () => { await waitForNextTick(); expect(events).toEqual([ { - code: 1006, + code: WebSocketCloseCode.ABNORMAL, reason: "Connection lost", wasClean: false, }, @@ -223,13 +227,17 @@ describe("SseConnection", () => { type CloseHandlingTestCase = [ code: number | undefined, reason: string | undefined, - closeEvent: Omit, + closeEvent: CloseEvent, ]; it.each([ [ undefined, undefined, - { code: 1000, reason: "Normal closure", wasClean: true }, + { + code: WebSocketCloseCode.NORMAL, + reason: "Normal closure", + wasClean: true, + }, ], [ 4000,