diff --git a/src/api/agentMetadataHelper.ts b/src/api/agentMetadataHelper.ts index 4de804ad..26ab1b6f 100644 --- a/src/api/agentMetadataHelper.ts +++ b/src/api/agentMetadataHelper.ts @@ -53,7 +53,11 @@ export async function createAgentMetadataWatcher( event.parsedMessage.data, ); - // Overwrite metadata if it changed. + if (watcher.error !== undefined) { + watcher.error = undefined; + onChange.fire(null); + } + if (JSON.stringify(watcher.metadata) !== JSON.stringify(metadata)) { watcher.metadata = metadata; onChange.fire(null); diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index ef120ce4..04c696be 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -14,7 +14,7 @@ import { type WorkspaceAgentLog, } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; -import { type ClientOptions, type CloseEvent, type ErrorEvent } from "ws"; +import { type ClientOptions } from "ws"; import { CertificateError } from "../error"; import { getHeaderCommand, getHeaders } from "../headers"; @@ -31,11 +31,20 @@ import { HttpClientLogLevel, } from "../logging/types"; import { sizeOf } from "../logging/utils"; -import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; +import { HttpStatusCode } from "../websocket/codes"; +import { + type UnidirectionalStream, + type CloseEvent, + type ErrorEvent, +} from "../websocket/eventStreamConnection"; import { OneWayWebSocket, type OneWayWebSocketInit, } from "../websocket/oneWayWebSocket"; +import { + ReconnectingWebSocket, + type SocketFactory, +} from "../websocket/reconnectingWebSocket"; import { SseConnection } from "../websocket/sseConnection"; import { createHttpAgent } from "./utils"; @@ -47,6 +56,10 @@ const coderSessionTokenHeader = "Coder-Session-Token"; * and WebSocket methods for real-time functionality. */ export class CoderApi extends Api { + private readonly reconnectingSockets = new Set< + ReconnectingWebSocket + >(); + private constructor(private readonly output: Logger) { super(); } @@ -66,10 +79,34 @@ export class CoderApi extends Api { client.setSessionToken(token); } - setupInterceptors(client, baseUrl, output); + setupInterceptors(client, output); return client; } + setSessionToken = (token: string): void => { + const defaultHeaders = this.getAxiosInstance().defaults.headers.common; + const currentToken = defaultHeaders[coderSessionTokenHeader]; + defaultHeaders[coderSessionTokenHeader] = token; + + if (currentToken !== token) { + for (const socket of this.reconnectingSockets) { + socket.reconnect(); + } + } + }; + + setHost = (host: string | undefined): void => { + const defaults = this.getAxiosInstance().defaults; + const currentHost = defaults.baseURL; + defaults.baseURL = host; + + if (currentHost !== host) { + for (const socket of this.reconnectingSockets) { + socket.reconnect(); + } + } + }; + watchInboxNotifications = async ( watchTemplates: string[], watchTargets: string[], @@ -83,6 +120,7 @@ export class CoderApi extends Api { targets: watchTargets.join(","), }, options, + enableRetry: true, }); }; @@ -91,6 +129,7 @@ export class CoderApi extends Api { apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`, fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`, options, + enableRetry: true, }); }; @@ -102,6 +141,7 @@ export class CoderApi extends Api { apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`, options, + enableRetry: true, }); }; @@ -148,53 +188,78 @@ export class CoderApi extends Api { } private async createWebSocket( - configs: Omit, - ) { - const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; - if (!baseUrlRaw) { - throw new Error("No base URL set on REST client"); - } + configs: Omit & { enableRetry?: boolean }, + ): Promise> { + const { enableRetry, ...socketConfigs } = configs; + + const socketFactory: SocketFactory = async () => { + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } + + const baseUrl = new URL(baseUrlRaw); + const token = this.getAxiosInstance().defaults.headers.common[ + coderSessionTokenHeader + ] as string | undefined; + + const headersFromCommand = await getHeaders( + baseUrlRaw, + getHeaderCommand(vscode.workspace.getConfiguration()), + this.output, + ); - const baseUrl = new URL(baseUrlRaw); - const token = this.getAxiosInstance().defaults.headers.common[ - coderSessionTokenHeader - ] as string | undefined; + const httpAgent = await createHttpAgent( + vscode.workspace.getConfiguration(), + ); - const headersFromCommand = await getHeaders( - baseUrlRaw, - getHeaderCommand(vscode.workspace.getConfiguration()), - this.output, - ); + /** + * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): + * 1. Headers from the header command + * 2. Any headers passed directly to this function + * 3. Coder session token from the Api client (if set) + */ + const headers = { + ...(token ? { [coderSessionTokenHeader]: token } : {}), + ...configs.options?.headers, + ...headersFromCommand, + }; - const httpAgent = await createHttpAgent( - vscode.workspace.getConfiguration(), - ); + const webSocket = new OneWayWebSocket({ + location: baseUrl, + ...socketConfigs, + options: { + ...configs.options, + agent: httpAgent, + followRedirects: true, + headers, + }, + }); - /** - * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): - * 1. Headers from the header command - * 2. Any headers passed directly to this function - * 3. Coder session token from the Api client (if set) - */ - const headers = { - ...(token ? { [coderSessionTokenHeader]: token } : {}), - ...configs.options?.headers, - ...headersFromCommand, + this.attachStreamLogger(webSocket); + return webSocket; }; - const webSocket = new OneWayWebSocket({ - location: baseUrl, - ...configs, - options: { - ...configs.options, - agent: httpAgent, - followRedirects: true, - headers, - }, - }); + if (enableRetry) { + const reconnectingSocket = await ReconnectingWebSocket.create( + socketFactory, + this.output, + configs.apiRoute, + undefined, + () => + this.reconnectingSockets.delete( + reconnectingSocket as ReconnectingWebSocket, + ), + ); + + this.reconnectingSockets.add( + reconnectingSocket as ReconnectingWebSocket, + ); - this.attachStreamLogger(webSocket); - return webSocket; + return reconnectingSocket; + } else { + return socketFactory(); + } } private attachStreamLogger( @@ -230,13 +295,15 @@ export class CoderApi extends Api { fallbackApiRoute: string; searchParams?: Record | URLSearchParams; options?: ClientOptions; + enableRetry?: boolean; }): Promise> { - let webSocket: OneWayWebSocket; + let webSocket: UnidirectionalStream; try { webSocket = await this.createWebSocket({ apiRoute: configs.apiRoute, searchParams: configs.searchParams, options: configs.options, + enableRetry: configs.enableRetry, }); } catch { // Failed to create WebSocket, use SSE fallback @@ -274,8 +341,8 @@ export class CoderApi extends Api { const handleError = (event: ErrorEvent) => { cleanup(); const is404 = - event.message?.includes("404") || - event.error?.message?.includes("404"); + event.message?.includes(String(HttpStatusCode.NOT_FOUND)) || + event.error?.message?.includes(String(HttpStatusCode.NOT_FOUND)); if (is404 && onNotFound) { connection.close(); @@ -323,14 +390,11 @@ export class CoderApi extends Api { /** * Set up logging and request interceptors for the CoderApi instance. */ -function setupInterceptors( - client: CoderApi, - baseUrl: string, - output: Logger, -): void { +function setupInterceptors(client: CoderApi, output: Logger): void { addLoggingInterceptors(client.getAxiosInstance(), output); client.getAxiosInstance().interceptors.request.use(async (config) => { + const baseUrl = client.getAxiosInstance().defaults.baseURL; const headers = await getHeaders( baseUrl, getHeaderCommand(vscode.workspace.getConfiguration()), @@ -356,7 +420,12 @@ function setupInterceptors( client.getAxiosInstance().interceptors.response.use( (r) => r, async (err) => { - throw await CertificateError.maybeWrap(err, baseUrl, output); + const baseUrl = client.getAxiosInstance().defaults.baseURL; + if (baseUrl) { + throw await CertificateError.maybeWrap(err, baseUrl, output); + } else { + throw err; + } }, ); } diff --git a/src/api/workspace.ts b/src/api/workspace.ts index a24d3a64..1d3b7a4e 100644 --- a/src/api/workspace.ts +++ b/src/api/workspace.ts @@ -11,7 +11,7 @@ import * as vscode from "vscode"; import { type FeatureSet } from "../featureSet"; import { getGlobalFlags } from "../globalFlags"; import { escapeCommandArg } from "../util"; -import { type OneWayWebSocket } from "../websocket/oneWayWebSocket"; +import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; import { errToStr, createWorkspaceIdentifier } from "./api-helper"; import { type CoderApi } from "./coderApi"; @@ -93,7 +93,7 @@ export async function streamBuildLogs( client: CoderApi, writeEmitter: vscode.EventEmitter, workspace: Workspace, -): Promise> { +): Promise> { const socket = await client.watchBuildLogsByBuildId( workspace.latest_build.id, [], @@ -131,7 +131,7 @@ export async function streamAgentLogs( client: CoderApi, writeEmitter: vscode.EventEmitter, agent: WorkspaceAgent, -): Promise> { +): Promise> { const socket = await client.watchWorkspaceAgentLogs(agent.id, []); socket.addEventListener("message", (data) => { diff --git a/src/inbox.ts b/src/inbox.ts index 8dff573f..59b9ae0b 100644 --- a/src/inbox.ts +++ b/src/inbox.ts @@ -7,7 +7,7 @@ import type { import type { CoderApi } from "./api/coderApi"; import type { Logger } from "./logging/logger"; -import type { OneWayWebSocket } from "./websocket/oneWayWebSocket"; +import type { UnidirectionalStream } from "./websocket/eventStreamConnection"; // These are the template IDs of our notifications. // Maybe in the future we should avoid hardcoding @@ -16,7 +16,9 @@ const TEMPLATE_WORKSPACE_OUT_OF_MEMORY = "a9d027b4-ac49-4fb1-9f6d-45af15f64e7a"; const TEMPLATE_WORKSPACE_OUT_OF_DISK = "f047f6a3-5713-40f7-85aa-0394cce9fa3a"; export class Inbox implements vscode.Disposable { - private socket: OneWayWebSocket | undefined; + private socket: + | UnidirectionalStream + | undefined; private disposed = false; private constructor(private readonly logger: Logger) {} diff --git a/src/remote/workspaceStateMachine.ts b/src/remote/workspaceStateMachine.ts index eb7aa335..340ec960 100644 --- a/src/remote/workspaceStateMachine.ts +++ b/src/remote/workspaceStateMachine.ts @@ -21,7 +21,7 @@ import type { CoderApi } from "../api/coderApi"; import type { PathResolver } from "../core/pathResolver"; import type { FeatureSet } from "../featureSet"; import type { Logger } from "../logging/logger"; -import type { OneWayWebSocket } from "../websocket/oneWayWebSocket"; +import type { UnidirectionalStream } from "../websocket/eventStreamConnection"; /** * Manages workspace and agent state transitions until ready for SSH connection. @@ -32,9 +32,10 @@ export class WorkspaceStateMachine implements vscode.Disposable { private agent: { id: string; name: string } | undefined; - private buildLogSocket: OneWayWebSocket | null = null; + private buildLogSocket: UnidirectionalStream | null = null; - private agentLogSocket: OneWayWebSocket | null = null; + private agentLogSocket: UnidirectionalStream | null = + null; constructor( private readonly parts: AuthorityParts, diff --git a/src/websocket/codes.ts b/src/websocket/codes.ts new file mode 100644 index 00000000..ac8eccf7 --- /dev/null +++ b/src/websocket/codes.ts @@ -0,0 +1,55 @@ +/** + * WebSocket close codes (RFC 6455) and HTTP status codes for socket connections. + * @see https://www.rfc-editor.org/rfc/rfc6455#section-7.4.1 + */ + +/** WebSocket close codes defined in RFC 6455 */ +export const WebSocketCloseCode = { + /** Normal closure - connection successfully completed */ + NORMAL: 1000, + /** Endpoint going away (server shutdown) */ + GOING_AWAY: 1001, + /** Protocol error - connection cannot be recovered */ + PROTOCOL_ERROR: 1002, + /** Unsupported data type received - connection cannot be recovered */ + UNSUPPORTED_DATA: 1003, + /** Abnormal closure - connection closed without close frame (network issues) */ + ABNORMAL: 1006, +} as const; + +/** HTTP status codes used for socket creation and connection logic */ +export const HttpStatusCode = { + /** Authentication or permission denied */ + FORBIDDEN: 403, + /** Endpoint not found */ + NOT_FOUND: 404, + /** Resource permanently gone */ + GONE: 410, + /** Protocol upgrade required */ + UPGRADE_REQUIRED: 426, +} as const; + +/** + * WebSocket close codes indicating unrecoverable errors. + * These appear in close events and should stop reconnection attempts. + */ +export const UNRECOVERABLE_WS_CLOSE_CODES = new Set([ + WebSocketCloseCode.PROTOCOL_ERROR, + WebSocketCloseCode.UNSUPPORTED_DATA, +]); + +/** + * HTTP status codes indicating unrecoverable errors during handshake. + * These appear during socket creation and should stop reconnection attempts. + */ +export const UNRECOVERABLE_HTTP_CODES = new Set([ + HttpStatusCode.FORBIDDEN, + HttpStatusCode.GONE, + HttpStatusCode.UPGRADE_REQUIRED, +]); + +/** Close codes indicating intentional closure - do not reconnect */ +export const NORMAL_CLOSURE_CODES = new Set([ + WebSocketCloseCode.NORMAL, + WebSocketCloseCode.GOING_AWAY, +]); diff --git a/src/websocket/eventStreamConnection.ts b/src/websocket/eventStreamConnection.ts index 2dc6514e..e3100ee6 100644 --- a/src/websocket/eventStreamConnection.ts +++ b/src/websocket/eventStreamConnection.ts @@ -1,11 +1,16 @@ import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; import { - type CloseEvent, + type CloseEvent as WsCloseEvent, type Event as WsEvent, - type ErrorEvent, - type MessageEvent, + type ErrorEvent as WsErrorEvent, + type MessageEvent as WsMessageEvent, } from "ws"; +export type Event = Omit; +export type CloseEvent = Omit; +export type ErrorEvent = Omit; +export type MessageEvent = Omit; + // Event payload types matching OneWayWebSocket export type ParsedMessageEvent = Readonly< | { @@ -24,7 +29,7 @@ export type EventPayloadMap = { close: CloseEvent; error: ErrorEvent; message: ParsedMessageEvent; - open: WsEvent; + open: Event; }; export type EventHandler = ( diff --git a/src/websocket/reconnectingWebSocket.ts b/src/websocket/reconnectingWebSocket.ts new file mode 100644 index 00000000..2ced9351 --- /dev/null +++ b/src/websocket/reconnectingWebSocket.ts @@ -0,0 +1,304 @@ +import { + WebSocketCloseCode, + NORMAL_CLOSURE_CODES, + UNRECOVERABLE_WS_CLOSE_CODES, + UNRECOVERABLE_HTTP_CODES, +} from "./codes"; + +import type { WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; + +import type { Logger } from "../logging/logger"; + +import type { + EventHandler, + UnidirectionalStream, +} from "./eventStreamConnection"; + +export type SocketFactory = () => Promise>; + +export type ReconnectingWebSocketOptions = { + initialBackoffMs?: number; + maxBackoffMs?: number; + jitterFactor?: number; +}; + +export class ReconnectingWebSocket + implements UnidirectionalStream +{ + readonly #socketFactory: SocketFactory; + readonly #logger: Logger; + readonly #apiRoute: string; + readonly #options: Required; + readonly #eventHandlers: { + [K in WebSocketEventType]: Set>; + } = { + open: new Set>(), + close: new Set>(), + error: new Set>(), + message: new Set>(), + }; + + #currentSocket: UnidirectionalStream | null = null; + #backoffMs: number; + #reconnectTimeoutId: NodeJS.Timeout | null = null; + #isDisposed = false; + #isConnecting = false; + #pendingReconnect = false; + readonly #onDispose?: () => void; + + private constructor( + socketFactory: SocketFactory, + logger: Logger, + apiRoute: string, + options: ReconnectingWebSocketOptions = {}, + onDispose?: () => void, + ) { + this.#socketFactory = socketFactory; + this.#logger = logger; + this.#apiRoute = apiRoute; + this.#options = { + initialBackoffMs: options.initialBackoffMs ?? 250, + maxBackoffMs: options.maxBackoffMs ?? 30000, + jitterFactor: options.jitterFactor ?? 0.1, + }; + this.#backoffMs = this.#options.initialBackoffMs; + this.#onDispose = onDispose; + } + + static async create( + socketFactory: SocketFactory, + logger: Logger, + apiRoute: string, + options: ReconnectingWebSocketOptions = {}, + onDispose?: () => void, + ): Promise> { + const instance = new ReconnectingWebSocket( + socketFactory, + logger, + apiRoute, + options, + onDispose, + ); + await instance.connect(); + return instance; + } + + get url(): string { + return this.#currentSocket?.url ?? ""; + } + + addEventListener( + event: TEvent, + callback: EventHandler, + ): void { + this.#eventHandlers[event].add(callback); + } + + removeEventListener( + event: TEvent, + callback: EventHandler, + ): void { + this.#eventHandlers[event].delete(callback); + } + + reconnect(): void { + if (this.#isDisposed) { + return; + } + + if (this.#reconnectTimeoutId !== null) { + clearTimeout(this.#reconnectTimeoutId); + this.#reconnectTimeoutId = null; + } + + // If already connecting, schedule reconnect after current attempt + if (this.#isConnecting) { + this.#pendingReconnect = true; + return; + } + + // connect() will close any existing socket + this.connect().catch((error) => this.handleConnectionError(error)); + } + + close(code?: number, reason?: string): void { + if (this.#isDisposed) { + return; + } + + // Fire close handlers synchronously before disposing + if (this.#currentSocket) { + this.executeHandlers("close", { + code: code ?? WebSocketCloseCode.NORMAL, + reason: reason ?? "Normal closure", + wasClean: true, + }); + } + + this.dispose(code, reason); + } + + private async connect(): Promise { + if (this.#isDisposed || this.#isConnecting) { + return; + } + + this.#isConnecting = true; + try { + // Close any existing socket before creating a new one + if (this.#currentSocket) { + this.#currentSocket.close( + WebSocketCloseCode.NORMAL, + "Replacing connection", + ); + this.#currentSocket = null; + } + + const socket = await this.#socketFactory(); + this.#currentSocket = socket; + + socket.addEventListener("open", (event) => { + this.#backoffMs = this.#options.initialBackoffMs; + this.executeHandlers("open", event); + }); + + socket.addEventListener("message", (event) => { + this.executeHandlers("message", event); + }); + + socket.addEventListener("error", (event) => { + this.executeHandlers("error", event); + }); + + socket.addEventListener("close", (event) => { + if (this.#isDisposed) { + return; + } + + this.executeHandlers("close", event); + + if (UNRECOVERABLE_WS_CLOSE_CODES.has(event.code)) { + this.#logger.error( + `WebSocket connection closed with unrecoverable error code ${event.code}`, + ); + this.dispose(); + return; + } + + // Don't reconnect on normal closure + if (NORMAL_CLOSURE_CODES.has(event.code)) { + return; + } + + // Reconnect on abnormal closures (e.g., 1006) or other unexpected codes + this.scheduleReconnect(); + }); + } finally { + this.#isConnecting = false; + + if (this.#pendingReconnect) { + this.#pendingReconnect = false; + this.reconnect(); + } + } + } + + private scheduleReconnect(): void { + if (this.#isDisposed || this.#reconnectTimeoutId !== null) { + return; + } + + const jitter = + this.#backoffMs * this.#options.jitterFactor * (Math.random() * 2 - 1); + const delayMs = Math.max(0, this.#backoffMs + jitter); + + this.#logger.debug( + `Reconnecting WebSocket in ${Math.round(delayMs)}ms for ${this.#apiRoute}`, + ); + + this.#reconnectTimeoutId = setTimeout(() => { + this.#reconnectTimeoutId = null; + this.connect().catch((error) => this.handleConnectionError(error)); + }, delayMs); + + this.#backoffMs = Math.min(this.#backoffMs * 2, this.#options.maxBackoffMs); + } + + private executeHandlers( + event: TEvent, + eventData: Parameters>[0], + ): void { + for (const handler of this.#eventHandlers[event]) { + try { + handler(eventData); + } catch (error) { + this.#logger.error( + `Error in ${event} handler for ${this.#apiRoute}`, + error, + ); + } + } + } + + /** + * Checks if the error is unrecoverable and disposes the connection, + * otherwise schedules a reconnect. + */ + private handleConnectionError(error: unknown): void { + if (this.#isDisposed) { + return; + } + + if (this.isUnrecoverableHttpError(error)) { + this.#logger.error( + `Unrecoverable HTTP error during connection for ${this.#apiRoute}`, + error, + ); + this.dispose(); + return; + } + + this.#logger.warn( + `WebSocket connection failed for ${this.#apiRoute}`, + error, + ); + this.scheduleReconnect(); + } + + /** + * Check if an error contains an unrecoverable HTTP status code. + */ + private isUnrecoverableHttpError(error: unknown): boolean { + const errorMessage = error instanceof Error ? error.message : String(error); + for (const code of UNRECOVERABLE_HTTP_CODES) { + if (errorMessage.includes(String(code))) { + return true; + } + } + return false; + } + + private dispose(code?: number, reason?: string): void { + if (this.#isDisposed) { + return; + } + + this.#isDisposed = true; + + if (this.#reconnectTimeoutId !== null) { + clearTimeout(this.#reconnectTimeoutId); + this.#reconnectTimeoutId = null; + } + + if (this.#currentSocket) { + this.#currentSocket.close(code, reason); + this.#currentSocket = null; + } + + for (const set of Object.values(this.#eventHandlers)) { + set.clear(); + } + + this.#onDispose?.(); + } +} diff --git a/src/websocket/sseConnection.ts b/src/websocket/sseConnection.ts index 5a71d303..dc20eeda 100644 --- a/src/websocket/sseConnection.ts +++ b/src/websocket/sseConnection.ts @@ -6,19 +6,14 @@ import { EventSource } from "eventsource"; import { createStreamingFetchAdapter } from "../api/streamingFetchAdapter"; import { type Logger } from "../logging/logger"; +import { WebSocketCloseCode } from "./codes"; import { getQueryString } from "./utils"; -import type { - CloseEvent as WsCloseEvent, - ErrorEvent as WsErrorEvent, - Event as WsEvent, - MessageEvent as WsMessageEvent, -} from "ws"; - import type { UnidirectionalStream, ParsedMessageEvent, EventHandler, + ErrorEvent as WsErrorEvent, } from "./eventStreamConnection"; export type SseConnectionInit = { @@ -66,7 +61,7 @@ export class SseConnection implements UnidirectionalStream { private setupEventHandlers(): void { this.eventSource.addEventListener("open", () => - this.invokeCallbacks(this.callbacks.open, {} as WsEvent, "open"), + this.invokeCallbacks(this.callbacks.open, {}, "open"), ); this.eventSource.addEventListener("data", (event: MessageEvent) => { @@ -84,10 +79,10 @@ export class SseConnection implements UnidirectionalStream { this.invokeCallbacks( this.callbacks.close, { - code: 1006, + code: WebSocketCloseCode.ABNORMAL, reason: "Connection lost", wasClean: false, - } as WsCloseEvent, + }, "close", ); } @@ -117,7 +112,7 @@ export class SseConnection implements UnidirectionalStream { return { error: error, message: errorMessage, - } as WsErrorEvent; + }; } public addEventListener( @@ -158,7 +153,7 @@ export class SseConnection implements UnidirectionalStream { private parseMessage( event: MessageEvent, ): ParsedMessageEvent { - const wsEvent = { data: event.data } as WsMessageEvent; + const wsEvent = { data: event.data }; try { return { sourceEvent: wsEvent, @@ -207,14 +202,16 @@ export class SseConnection implements UnidirectionalStream { this.invokeCallbacks( this.callbacks.close, { - code: code ?? 1000, + code: code ?? WebSocketCloseCode.NORMAL, reason: reason ?? "Normal closure", wasClean: true, - } as WsCloseEvent, + }, "close", ); - Object.values(this.callbacks).forEach((callbackSet) => callbackSet.clear()); + for (const callbackSet of Object.values(this.callbacks)) { + callbackSet.clear(); + } this.messageWrappers.clear(); } } diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts index f133a72d..4f90f33e 100644 --- a/test/unit/api/coderApi.test.ts +++ b/test/unit/api/coderApi.test.ts @@ -10,7 +10,7 @@ import { createHttpAgent } from "@/api/utils"; import { CertificateError } from "@/error"; import { getHeaders } from "@/headers"; import { type RequestConfigWithMeta } from "@/logging/types"; -import { OneWayWebSocket } from "@/websocket/oneWayWebSocket"; +import { ReconnectingWebSocket } from "@/websocket/reconnectingWebSocket"; import { SseConnection } from "@/websocket/sseConnection"; import { @@ -332,7 +332,7 @@ describe("CoderApi", () => { const connection = await api.watchAgentMetadata(AGENT_ID); - expect(connection).toBeInstanceOf(OneWayWebSocket); + expect(connection).toBeInstanceOf(ReconnectingWebSocket); expect(EventSource).not.toHaveBeenCalled(); }); @@ -373,6 +373,70 @@ describe("CoderApi", () => { }); }); + describe("Reconnection on Host/Token Changes", () => { + const setupAutoOpeningWebSocket = () => { + const sockets: Array> = []; + vi.mocked(Ws).mockImplementation((url: string | URL) => { + const mockWs = createMockWebSocket(String(url), { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }); + sockets.push(mockWs); + return mockWs as Ws; + }); + return sockets; + }; + + it("triggers reconnection when session token changes", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + api.setSessionToken("new-token"); + await new Promise((resolve) => setImmediate(resolve)); + + expect(sockets[0].close).toHaveBeenCalledWith( + 1000, + "Replacing connection", + ); + expect(sockets).toHaveLength(2); + }); + + it("triggers reconnection when host changes", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + const wsWrap = await api.watchAgentMetadata(AGENT_ID); + expect(wsWrap.url).toContain(CODER_URL.replace("http", "ws")); + + api.setHost("https://new-coder.example.com"); + await new Promise((resolve) => setImmediate(resolve)); + + expect(sockets[0].close).toHaveBeenCalledWith( + 1000, + "Replacing connection", + ); + expect(sockets).toHaveLength(2); + expect(wsWrap.url).toContain("wss://new-coder.example.com"); + }); + + it("does not reconnect when token or host are unchanged", async () => { + const sockets = setupAutoOpeningWebSocket(); + api = createApi(CODER_URL, AXIOS_TOKEN); + await api.watchAgentMetadata(AGENT_ID); + + // Same values as before + api.setSessionToken(AXIOS_TOKEN); + api.setHost(CODER_URL); + + expect(sockets[0].close).not.toHaveBeenCalled(); + expect(sockets).toHaveLength(1); + }); + }); + describe("Error Handling", () => { it("throws error when no base URL is set", async () => { const api = createApi(); diff --git a/test/unit/websocket/reconnectingWebSocket.test.ts b/test/unit/websocket/reconnectingWebSocket.test.ts new file mode 100644 index 00000000..cdf08949 --- /dev/null +++ b/test/unit/websocket/reconnectingWebSocket.test.ts @@ -0,0 +1,468 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; + +import { WebSocketCloseCode, HttpStatusCode } from "@/websocket/codes"; +import { + ReconnectingWebSocket, + type SocketFactory, +} from "@/websocket/reconnectingWebSocket"; + +import { createMockLogger } from "../../mocks/testHelpers"; + +import type { CloseEvent, Event as WsEvent } from "ws"; + +import type { UnidirectionalStream } from "@/websocket/eventStreamConnection"; + +describe("ReconnectingWebSocket", () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.useRealTimers(); + }); + + describe("Reconnection Logic", () => { + it("automatically reconnects on abnormal closure (1006)", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network error", + }); + + await vi.advanceTimersByTimeAsync(300); + expect(sockets).toHaveLength(2); + + ws.close(); + }); + + it.each([ + { code: WebSocketCloseCode.NORMAL, name: "Normal Closure" }, + { code: WebSocketCloseCode.GOING_AWAY, name: "Going Away" }, + ])( + "does not reconnect on normal closure: $name ($code)", + async ({ code }) => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code, reason: "Normal" }); + + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + ws.close(); + }, + ); + + it.each([ + WebSocketCloseCode.PROTOCOL_ERROR, + WebSocketCloseCode.UNSUPPORTED_DATA, + ])( + "does not reconnect on unrecoverable WebSocket close code: %i", + async (code) => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ code, reason: "Unrecoverable" }); + + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + + ws.close(); + }, + ); + + it.each([ + HttpStatusCode.FORBIDDEN, + HttpStatusCode.GONE, + HttpStatusCode.UPGRADE_REQUIRED, + ])( + "does not reconnect on unrecoverable HTTP error during creation: %i", + async (statusCode) => { + let socketCreationAttempts = 0; + const factory = vi.fn(() => { + socketCreationAttempts++; + // Simulate HTTP error during WebSocket handshake + return Promise.reject( + new Error(`Unexpected server response: ${statusCode}`), + ); + }); + + await expect( + ReconnectingWebSocket.create( + factory, + createMockLogger(), + "/api/test", + ), + ).rejects.toThrow(`Unexpected server response: ${statusCode}`); + + // Should not retry after unrecoverable HTTP error + await vi.advanceTimersByTimeAsync(10000); + expect(socketCreationAttempts).toBe(1); + }, + ); + + it("reconnect() connects immediately and cancels pending reconnections", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Connection lost", + }); + + // Manual reconnect() should happen immediately and cancel the scheduled reconnect + ws.reconnect(); + expect(sockets).toHaveLength(2); + + // Verify pending reconnect was cancelled - no third socket should be created + await vi.advanceTimersByTimeAsync(1000); + expect(sockets).toHaveLength(2); + + ws.close(); + }); + + it("queues reconnect() calls made during connection", async () => { + const sockets: MockSocket[] = []; + let pendingResolve: ((socket: MockSocket) => void) | null = null; + + const factory = vi.fn(() => { + const socket = createMockSocket(); + sockets.push(socket); + + // First call resolves immediately, other calls wait for manual resolve + if (sockets.length === 1) { + return Promise.resolve(socket); + } else { + return new Promise((resolve) => { + pendingResolve = resolve; + }); + } + }); + + const ws = await fromFactory(factory); + sockets[0].fireOpen(); + expect(sockets).toHaveLength(1); + + // Start first reconnect (will block on factory promise) + ws.reconnect(); + expect(sockets).toHaveLength(2); + // Call reconnect again while first reconnect is in progress + ws.reconnect(); + // Still only 2 sockets (queued reconnect hasn't started) + expect(sockets).toHaveLength(2); + + // Complete the first reconnect + pendingResolve!(sockets[1]); + sockets[1].fireOpen(); + + // Wait a tick for the queued reconnect to execute + await Promise.resolve(); + // Now queued reconnect should have executed, creating third socket + expect(sockets).toHaveLength(3); + + ws.close(); + }); + }); + + describe("Event Handlers", () => { + it("persists event handlers across reconnections", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + sockets[0].fireOpen(); + + const handler = vi.fn(); + ws.addEventListener("message", handler); + + // First message + sockets[0].fireMessage({ test: true }); + expect(handler).toHaveBeenCalledTimes(1); + + // Disconnect and reconnect + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network", + }); + await vi.advanceTimersByTimeAsync(300); + expect(sockets).toHaveLength(2); + sockets[1].fireOpen(); + + // Handler should still work on new socket + sockets[1].fireMessage({ test: true }); + expect(handler).toHaveBeenCalledTimes(2); + + ws.close(); + }); + + it("removes event handlers when removeEventListener is called", async () => { + const socket = createMockSocket(); + const factory = vi.fn(() => Promise.resolve(socket)); + + const ws = await fromFactory(factory); + socket.fireOpen(); + + const handler1 = vi.fn(); + const handler2 = vi.fn(); + + ws.addEventListener("message", handler1); + ws.addEventListener("message", handler2); + ws.removeEventListener("message", handler1); + + socket.fireMessage({ test: true }); + + expect(handler1).not.toHaveBeenCalled(); + expect(handler2).toHaveBeenCalledTimes(1); + + ws.close(); + }); + }); + + describe("close() and Disposal", () => { + it("stops reconnection when close() is called", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + + sockets[0].fireOpen(); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network", + }); + ws.close(); + + await vi.advanceTimersByTimeAsync(10000); + expect(sockets).toHaveLength(1); + }); + + it("closes the underlying socket with provided code and reason", async () => { + const socket = createMockSocket(); + const factory = vi.fn(() => Promise.resolve(socket)); + const ws = await fromFactory(factory); + + socket.fireOpen(); + ws.close(WebSocketCloseCode.NORMAL, "Test close"); + + expect(socket.close).toHaveBeenCalledWith( + WebSocketCloseCode.NORMAL, + "Test close", + ); + }); + + it("calls onDispose callback once, even with multiple close() calls", async () => { + let disposeCount = 0; + const { ws } = await createReconnectingWebSocket(() => ++disposeCount); + + ws.close(); + ws.close(); + ws.close(); + + expect(disposeCount).toBe(1); + }); + + it("calls onDispose callback on unrecoverable WebSocket close code", async () => { + let disposeCount = 0; + const { sockets } = await createReconnectingWebSocket( + () => ++disposeCount, + ); + + sockets[0].fireOpen(); + sockets[0].fireClose({ + code: WebSocketCloseCode.PROTOCOL_ERROR, + reason: "Protocol error", + }); + + expect(disposeCount).toBe(1); + }); + + it("does not call onDispose callback during reconnection", async () => { + let disposeCount = 0; + const { ws, sockets } = await createReconnectingWebSocket( + () => ++disposeCount, + ); + + sockets[0].fireOpen(); + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network error", + }); + + await vi.advanceTimersByTimeAsync(300); + expect(disposeCount).toBe(0); + + ws.close(); + expect(disposeCount).toBe(1); + }); + }); + + describe("Backoff Strategy", () => { + it("doubles backoff delay after each failed connection", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + const socket = sockets[0]; + socket.fireOpen(); + + const backoffDelays = [300, 600, 1200, 2400]; + + // Fail repeatedly + for (let i = 0; i < 4; i++) { + const currentSocket = sockets[i]; + currentSocket.fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Fail", + }); + const delay = backoffDelays[i]; + await vi.advanceTimersByTimeAsync(delay); + const nextSocket = sockets[i + 1]; + nextSocket.fireOpen(); + } + + expect(sockets).toHaveLength(5); + ws.close(); + }); + + it("resets backoff delay after successful connection", async () => { + const { ws, sockets } = await createReconnectingWebSocket(); + const socket1 = sockets[0]; + socket1.fireOpen(); + + // First disconnect + socket1.fireClose({ code: WebSocketCloseCode.ABNORMAL, reason: "Fail" }); + await vi.advanceTimersByTimeAsync(300); + const socket2 = sockets[1]; + socket2.fireOpen(); + + // Second disconnect - should use initial backoff again + socket2.fireClose({ code: WebSocketCloseCode.ABNORMAL, reason: "Fail" }); + await vi.advanceTimersByTimeAsync(300); + + expect(sockets).toHaveLength(3); + ws.close(); + }); + }); + + describe("Error Handling", () => { + it("schedules retry when socket factory throws error", async () => { + const sockets: MockSocket[] = []; + let shouldFail = false; + const factory = vi.fn(() => { + if (shouldFail) { + return Promise.reject(new Error("Factory failed")); + } + const socket = createMockSocket(); + sockets.push(socket); + return Promise.resolve(socket); + }); + const ws = await fromFactory(factory); + + sockets[0].fireOpen(); + + shouldFail = true; + sockets[0].fireClose({ + code: WebSocketCloseCode.ABNORMAL, + reason: "Network", + }); + + await vi.advanceTimersByTimeAsync(300); + expect(sockets).toHaveLength(1); + + ws.close(); + }); + }); +}); + +type MockSocket = UnidirectionalStream & { + fireOpen: () => void; + fireClose: (event: { code: number; reason: string }) => void; + fireMessage: (data: unknown) => void; + fireError: (error: Error) => void; +}; + +function createMockSocket(): MockSocket { + const listeners: { + open: Set<(event: WsEvent) => void>; + close: Set<(event: CloseEvent) => void>; + error: Set<(event: { error?: Error; message?: string }) => void>; + message: Set<(event: unknown) => void>; + } = { + open: new Set(), + close: new Set(), + error: new Set(), + message: new Set(), + }; + + return { + url: "ws://test.example.com/api/test", + addEventListener: vi.fn( + (event: keyof typeof listeners, callback: unknown) => { + (listeners[event] as Set<(data: unknown) => void>).add( + callback as (data: unknown) => void, + ); + }, + ), + removeEventListener: vi.fn( + (event: keyof typeof listeners, callback: unknown) => { + (listeners[event] as Set<(data: unknown) => void>).delete( + callback as (data: unknown) => void, + ); + }, + ), + close: vi.fn(), + fireOpen: () => { + for (const cb of listeners.open) { + cb({} as WsEvent); + } + }, + fireClose: (event: { code: number; reason: string }) => { + for (const cb of listeners.close) { + cb({ + code: event.code, + reason: event.reason, + wasClean: event.code === WebSocketCloseCode.NORMAL, + } as CloseEvent); + } + }, + fireMessage: (data: unknown) => { + for (const cb of listeners.message) { + cb({ + sourceEvent: { data }, + parsedMessage: data, + parseError: undefined, + }); + } + }, + fireError: (error: Error) => { + for (const cb of listeners.error) { + cb({ error, message: error.message }); + } + }, + }; +} + +async function createReconnectingWebSocket(onDispose?: () => void): Promise<{ + ws: ReconnectingWebSocket; + sockets: MockSocket[]; +}> { + const sockets: MockSocket[] = []; + const factory = vi.fn(() => { + const socket = createMockSocket(); + sockets.push(socket); + return Promise.resolve(socket); + }); + const ws = await fromFactory(factory, onDispose); + + // We start with one socket + expect(sockets).toHaveLength(1); + + return { ws, sockets }; +} + +async function fromFactory( + factory: SocketFactory, + onDispose?: () => void, +): Promise> { + return await ReconnectingWebSocket.create( + factory, + createMockLogger(), + "/random/api", + undefined, + onDispose, + ); +} diff --git a/test/unit/websocket/sseConnection.test.ts b/test/unit/websocket/sseConnection.test.ts index 61cfce4d..378e6f54 100644 --- a/test/unit/websocket/sseConnection.test.ts +++ b/test/unit/websocket/sseConnection.test.ts @@ -3,10 +3,14 @@ import { type ServerSentEvent } from "coder/site/src/api/typesGenerated"; import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; import { EventSource } from "eventsource"; import { describe, it, expect, vi } from "vitest"; -import { type CloseEvent, type ErrorEvent } from "ws"; import { type Logger } from "@/logging/logger"; -import { type ParsedMessageEvent } from "@/websocket/eventStreamConnection"; +import { WebSocketCloseCode } from "@/websocket/codes"; +import { + type ParsedMessageEvent, + type CloseEvent, + type ErrorEvent, +} from "@/websocket/eventStreamConnection"; import { SseConnection } from "@/websocket/sseConnection"; import { createMockLogger } from "../../mocks/testHelpers"; @@ -168,7 +172,7 @@ describe("SseConnection", () => { await waitForNextTick(); expect(events).toEqual([ { - code: 1006, + code: WebSocketCloseCode.ABNORMAL, reason: "Connection lost", wasClean: false, }, @@ -223,13 +227,17 @@ describe("SseConnection", () => { type CloseHandlingTestCase = [ code: number | undefined, reason: string | undefined, - closeEvent: Omit, + closeEvent: CloseEvent, ]; it.each([ [ undefined, undefined, - { code: 1000, reason: "Normal closure", wasClean: true }, + { + code: WebSocketCloseCode.NORMAL, + reason: "Normal closure", + wasClean: true, + }, ], [ 4000,