From ea3d1149f769fe59dec0ae4272438bd45bd326ab Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Mon, 13 Oct 2025 10:56:06 +0300 Subject: [PATCH 1/5] Add SSE fallback to some one way WS connections --- src/api/coderApi.ts | 140 +++++++++++-- src/api/streamingFetchAdapter.ts | 62 ++++++ .../{wsLogger.ts => eventStreamLogger.ts} | 16 +- src/websocket/eventStreamConnection.ts | 51 +++++ src/websocket/oneWayWebSocket.ts | 69 ++----- src/websocket/sseConnection.ts | 191 ++++++++++++++++++ src/websocket/utils.ts | 15 ++ src/workspace/workspaceMonitor.ts | 14 +- ...gger.test.ts => eventStreamLogger.test.ts} | 62 ++++-- 9 files changed, 511 insertions(+), 109 deletions(-) create mode 100644 src/api/streamingFetchAdapter.ts rename src/logging/{wsLogger.ts => eventStreamLogger.ts} (77%) create mode 100644 src/websocket/eventStreamConnection.ts create mode 100644 src/websocket/sseConnection.ts create mode 100644 src/websocket/utils.ts rename test/unit/logging/{wsLogger.test.ts => eventStreamLogger.test.ts} (50%) diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 99976ff7..6e8148ab 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -6,17 +6,18 @@ import { } from "axios"; import { Api } from "coder/site/src/api/api"; import { + type ServerSentEvent, type GetInboxNotificationResponse, type ProvisionerJobLog, - type ServerSentEvent, type Workspace, type WorkspaceAgent, } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; -import { type ClientOptions } from "ws"; +import { type ClientOptions, type CloseEvent, type ErrorEvent } from "ws"; import { CertificateError } from "../error"; import { getHeaderCommand, getHeaders } from "../headers"; +import { EventStreamLogger } from "../logging/eventStreamLogger"; import { createRequestMeta, logRequest, @@ -29,11 +30,12 @@ import { HttpClientLogLevel, } from "../logging/types"; import { sizeOf } from "../logging/utils"; -import { WsLogger } from "../logging/wsLogger"; +import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; import { OneWayWebSocket, type OneWayWebSocketInit, } from "../websocket/oneWayWebSocket"; +import { SseConnection } from "../websocket/sseConnection"; import { createHttpAgent } from "./utils"; @@ -84,8 +86,9 @@ export class CoderApi extends Api { }; watchWorkspace = async (workspace: Workspace, options?: ClientOptions) => { - return this.createWebSocket({ + return this.createWebSocketWithFallback({ apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`, + fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`, options, }); }; @@ -94,8 +97,9 @@ export class CoderApi extends Api { agentId: WorkspaceAgent["id"], options?: ClientOptions, ) => { - return this.createWebSocket({ + return this.createWebSocketWithFallback({ apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, + fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`, options, }); }; @@ -137,6 +141,7 @@ export class CoderApi extends Api { const httpAgent = await createHttpAgent( vscode.workspace.getConfiguration(), ); + const webSocket = new OneWayWebSocket({ location: baseUrl, ...configs, @@ -152,28 +157,123 @@ export class CoderApi extends Api { }, }); - const wsUrl = new URL(webSocket.url); - const pathWithQuery = wsUrl.pathname + wsUrl.search; - const wsLogger = new WsLogger(this.output, pathWithQuery); - wsLogger.logConnecting(); + this.attachStreamLogger(webSocket); + return webSocket; + } - webSocket.addEventListener("open", () => { - wsLogger.logOpen(); - }); + private attachStreamLogger( + connection: UnidirectionalStream, + ): void { + const url = new URL(connection.url); + const logger = new EventStreamLogger( + this.output, + url.pathname + url.search, + url.protocol.startsWith("http") ? "SSE" : "WS", + ); + logger.logConnecting(); - webSocket.addEventListener("message", (event) => { - wsLogger.logMessage(event.sourceEvent.data); - }); + connection.addEventListener("open", () => logger.logOpen()); + connection.addEventListener("close", (event: CloseEvent) => + logger.logClose(event.code, event.reason), + ); + connection.addEventListener("error", (event: ErrorEvent) => + logger.logError(event.error, event.message), + ); + connection.addEventListener("message", (event) => + logger.logMessage(event.sourceEvent.data), + ); + } + + /** + * Create a WebSocket connection with SSE fallback on 404 + */ + private async createWebSocketWithFallback(configs: { + apiRoute: string; + fallbackApiRoute: string; + searchParams?: Record | URLSearchParams; + options?: ClientOptions; + }): Promise> { + let webSocket: OneWayWebSocket; + try { + webSocket = await this.createWebSocket({ + apiRoute: configs.apiRoute, + searchParams: configs.searchParams, + options: configs.options, + }); + } catch { + // Failed to create WebSocket, use SSE fallback + return this.createSseFallback( + configs.fallbackApiRoute, + configs.searchParams, + ); + } - webSocket.addEventListener("close", (event) => { - wsLogger.logClose(event.code, event.reason); + return this.waitForConnection(webSocket, () => + this.createSseFallback( + configs.fallbackApiRoute, + configs.searchParams, + ), + ); + } + + private waitForConnection( + connection: UnidirectionalStream, + onNotFound?: () => Promise>, + ): Promise> { + return new Promise((resolve, reject) => { + const cleanup = () => { + connection.removeEventListener("open", handleOpen); + connection.removeEventListener("error", handleError); + }; + + const handleOpen = () => { + cleanup(); + resolve(connection); + }; + + const handleError = (event: ErrorEvent) => { + cleanup(); + const is404 = + event.message?.includes("404") || + event.error?.message?.includes("404"); + + if (is404 && onNotFound) { + connection.close(); + onNotFound().then(resolve).catch(reject); + } else { + reject(event.error || new Error(event.message)); + } + }; + + connection.addEventListener("open", handleOpen); + connection.addEventListener("error", handleError); }); + } - webSocket.addEventListener("error", (event) => { - wsLogger.logError(event.error, event.message); + /** + * Create SSE fallback connection + */ + private async createSseFallback( + apiRoute: string, + searchParams?: Record | URLSearchParams, + ): Promise> { + this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`); + + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } + + const baseUrl = new URL(baseUrlRaw); + const sseConnection = new SseConnection({ + location: baseUrl, + apiRoute, + searchParams, + axiosInstance: this.getAxiosInstance(), }); - return webSocket; + this.attachStreamLogger(sseConnection); + return this.waitForConnection(sseConnection); } } diff --git a/src/api/streamingFetchAdapter.ts b/src/api/streamingFetchAdapter.ts new file mode 100644 index 00000000..79e3ee39 --- /dev/null +++ b/src/api/streamingFetchAdapter.ts @@ -0,0 +1,62 @@ +import { type AxiosInstance } from "axios"; +import { type FetchLikeInit, type FetchLikeResponse } from "eventsource"; +import { type IncomingMessage } from "http"; + +/** + * Creates a fetch adapter using an Axios instance that returns streaming responses. + * This is used by EventSource to make authenticated SSE connections. + */ +export function createStreamingFetchAdapter( + axiosInstance: AxiosInstance, +): (url: string | URL, init?: FetchLikeInit) => Promise { + return async ( + url: string | URL, + init?: FetchLikeInit, + ): Promise => { + const urlStr = url.toString(); + + const response = await axiosInstance.request({ + url: urlStr, + signal: init?.signal, + headers: init?.headers, + responseType: "stream", + validateStatus: () => true, // Don't throw on any status code + }); + + const stream = new ReadableStream({ + start(controller) { + response.data.on("data", (chunk: Buffer) => { + controller.enqueue(chunk); + }); + + response.data.on("end", () => { + controller.close(); + }); + + response.data.on("error", (err: Error) => { + controller.error(err); + }); + }, + + cancel() { + response.data.destroy(); + return Promise.resolve(); + }, + }); + + return { + body: { + getReader: () => stream.getReader(), + }, + url: urlStr, + status: response.status, + redirected: response.request?.res?.responseUrl !== urlStr, + headers: { + get: (name: string) => { + const value = response.headers[name.toLowerCase()]; + return value === undefined ? null : String(value); + }, + }, + }; + }; +} diff --git a/src/logging/wsLogger.ts b/src/logging/eventStreamLogger.ts similarity index 77% rename from src/logging/wsLogger.ts rename to src/logging/eventStreamLogger.ts index fd6acd00..224f52b7 100644 --- a/src/logging/wsLogger.ts +++ b/src/logging/eventStreamLogger.ts @@ -12,31 +12,35 @@ const numFormatter = new Intl.NumberFormat("en", { compactDisplay: "short", }); -export class WsLogger { +export class EventStreamLogger { private readonly logger: Logger; private readonly url: string; private readonly id: string; + private readonly protocol: string; private readonly startedAt: number; private openedAt?: number; private msgCount = 0; private byteCount = 0; private unknownByteCount = false; - constructor(logger: Logger, url: string) { + constructor(logger: Logger, url: string, protocol: "WS" | "SSE") { this.logger = logger; this.url = url; + this.protocol = protocol; this.id = createRequestId(); this.startedAt = Date.now(); } logConnecting(): void { - this.logger.trace(`→ WS ${shortId(this.id)} ${this.url}`); + this.logger.trace(`→ ${this.protocol} ${shortId(this.id)} ${this.url}`); } logOpen(): void { this.openedAt = Date.now(); const time = formatTime(this.openedAt - this.startedAt); - this.logger.trace(`← WS ${shortId(this.id)} connected ${this.url} ${time}`); + this.logger.trace( + `← ${this.protocol} ${shortId(this.id)} connected ${this.url} ${time}`, + ); } logMessage(data: unknown): void { @@ -62,7 +66,7 @@ export class WsLogger { const statsStr = ` [${stats.join(", ")}]`; this.logger.trace( - `▣ WS ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`, + `▣ ${this.protocol} ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`, ); } @@ -70,7 +74,7 @@ export class WsLogger { const time = formatTime(Date.now() - this.startedAt); const errorMsg = message || errToStr(error, "connection error"); this.logger.error( - `✗ WS ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`, + `✗ ${this.protocol} ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`, error, ); } diff --git a/src/websocket/eventStreamConnection.ts b/src/websocket/eventStreamConnection.ts new file mode 100644 index 00000000..2dc6514e --- /dev/null +++ b/src/websocket/eventStreamConnection.ts @@ -0,0 +1,51 @@ +import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; +import { + type CloseEvent, + type Event as WsEvent, + type ErrorEvent, + type MessageEvent, +} from "ws"; + +// Event payload types matching OneWayWebSocket +export type ParsedMessageEvent = Readonly< + | { + sourceEvent: MessageEvent; + parsedMessage: TData; + parseError: undefined; + } + | { + sourceEvent: MessageEvent; + parsedMessage: undefined; + parseError: Error; + } +>; + +export type EventPayloadMap = { + close: CloseEvent; + error: ErrorEvent; + message: ParsedMessageEvent; + open: WsEvent; +}; + +export type EventHandler = ( + payload: EventPayloadMap[TEvent], +) => void; + +/** + * Common interface for both WebSocket and SSE connections that handle event streams. + * Matches the OneWayWebSocket interface for compatibility. + */ +export interface UnidirectionalStream { + readonly url: string; + addEventListener( + eventType: TEvent, + callback: EventHandler, + ): void; + + removeEventListener( + eventType: TEvent, + callback: EventHandler, + ): void; + + close(code?: number, reason?: string): void; +} diff --git a/src/websocket/oneWayWebSocket.ts b/src/websocket/oneWayWebSocket.ts index 37965596..c27b1fe4 100644 --- a/src/websocket/oneWayWebSocket.ts +++ b/src/websocket/oneWayWebSocket.ts @@ -8,51 +8,13 @@ */ import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; -import Ws, { - type ClientOptions, - type CloseEvent, - type ErrorEvent, - type Event, - type MessageEvent, - type RawData, -} from "ws"; +import Ws, { type ClientOptions, type MessageEvent, type RawData } from "ws"; -export type OneWayMessageEvent = Readonly< - | { - sourceEvent: MessageEvent; - parsedMessage: TData; - parseError: undefined; - } - | { - sourceEvent: MessageEvent; - parsedMessage: undefined; - parseError: Error; - } ->; - -type OneWayEventPayloadMap = { - close: CloseEvent; - error: ErrorEvent; - message: OneWayMessageEvent; - open: Event; -}; - -type OneWayEventCallback = ( - payload: OneWayEventPayloadMap[TEvent], -) => void; - -interface OneWayWebSocketApi { - get url(): string; - addEventListener( - eventType: TEvent, - callback: OneWayEventCallback, - ): void; - removeEventListener( - eventType: TEvent, - callback: OneWayEventCallback, - ): void; - close(code?: number, reason?: string): void; -} +import { + type UnidirectionalStream, + type EventHandler, +} from "./eventStreamConnection"; +import { getQueryString } from "./utils"; export type OneWayWebSocketInit = { location: { protocol: string; host: string }; @@ -63,23 +25,18 @@ export type OneWayWebSocketInit = { }; export class OneWayWebSocket - implements OneWayWebSocketApi + implements UnidirectionalStream { readonly #socket: Ws; readonly #messageCallbacks = new Map< - OneWayEventCallback, + EventHandler, (data: RawData) => void >(); constructor(init: OneWayWebSocketInit) { const { location, apiRoute, protocols, options, searchParams } = init; - const formattedParams = - searchParams instanceof URLSearchParams - ? searchParams - : new URLSearchParams(searchParams); - const paramsString = formattedParams.toString(); - const paramsSuffix = paramsString ? `?${paramsString}` : ""; + const paramsSuffix = getQueryString(searchParams); const wsProtocol = location.protocol === "https:" ? "wss:" : "ws:"; const url = `${wsProtocol}//${location.host}${apiRoute}${paramsSuffix}`; @@ -92,10 +49,10 @@ export class OneWayWebSocket addEventListener( event: TEvent, - callback: OneWayEventCallback, + callback: EventHandler, ): void { if (event === "message") { - const messageCallback = callback as OneWayEventCallback; + const messageCallback = callback as EventHandler; if (this.#messageCallbacks.has(messageCallback)) { return; @@ -128,10 +85,10 @@ export class OneWayWebSocket removeEventListener( event: TEvent, - callback: OneWayEventCallback, + callback: EventHandler, ): void { if (event === "message") { - const messageCallback = callback as OneWayEventCallback; + const messageCallback = callback as EventHandler; const wrapper = this.#messageCallbacks.get(messageCallback); if (wrapper) { diff --git a/src/websocket/sseConnection.ts b/src/websocket/sseConnection.ts new file mode 100644 index 00000000..d3c9977f --- /dev/null +++ b/src/websocket/sseConnection.ts @@ -0,0 +1,191 @@ +import { type AxiosInstance } from "axios"; +import { type ServerSentEvent } from "coder/site/src/api/typesGenerated"; +import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; +import { EventSource } from "eventsource"; + +import { createStreamingFetchAdapter } from "../api/streamingFetchAdapter"; + +import { getQueryString } from "./utils"; + +import type { + CloseEvent as WsCloseEvent, + ErrorEvent as WsErrorEvent, + Event as WsEvent, + MessageEvent as WsMessageEvent, +} from "ws"; + +import type { + UnidirectionalStream, + ParsedMessageEvent, + EventHandler, +} from "./eventStreamConnection"; + +export type SseConnectionInit = { + location: { protocol: string; host: string }; + apiRoute: string; + searchParams?: Record | URLSearchParams; + axiosInstance: AxiosInstance; +}; + +export class SseConnection implements UnidirectionalStream { + private readonly eventSource: EventSource; + private readonly callbacks = { + open: new Set>(), + close: new Set>(), + error: new Set>(), + }; + // Original callback -> wrapped callback + private readonly messageWrappers = new Map< + EventHandler, + (event: MessageEvent) => void + >(); + + public readonly url: string; + + public constructor(init: SseConnectionInit) { + this.url = this.buildUrl(init); + this.eventSource = new EventSource(this.url, { + fetch: createStreamingFetchAdapter(init.axiosInstance), + }); + this.setupEventHandlers(); + } + + private buildUrl(init: SseConnectionInit): string { + const { location, apiRoute, searchParams } = init; + const queryString = getQueryString(searchParams); + return `${location.protocol}//${location.host}${apiRoute}${queryString}`; + } + + private setupEventHandlers(): void { + this.eventSource.addEventListener("open", () => + this.callbacks.open.forEach((cb) => cb({} as WsEvent)), + ); + + this.eventSource.addEventListener("message", (event: MessageEvent) => { + [...this.messageWrappers.values()].forEach((wrapper) => wrapper(event)); + }); + + this.eventSource.addEventListener("error", (error: Event | ErrorEvent) => { + this.callbacks.error.forEach((cb) => cb(this.createErrorEvent(error))); + + if (this.eventSource.readyState === EventSource.CLOSED) { + this.callbacks.close.forEach((cb) => + cb({ + code: 1006, + reason: "Connection lost", + wasClean: false, + } as WsCloseEvent), + ); + } + }); + } + + private createErrorEvent(event: Event | ErrorEvent): WsErrorEvent { + const errorMessage = + event instanceof ErrorEvent && event.message + ? event.message + : "SSE connection error"; + const error = event instanceof ErrorEvent ? event.error : undefined; + + return { + error: error, + message: errorMessage, + } as WsErrorEvent; + } + + public addEventListener( + event: TEvent, + callback: EventHandler, + ): void { + switch (event) { + case "close": + this.callbacks.close.add( + callback as EventHandler, + ); + break; + case "error": + this.callbacks.error.add( + callback as EventHandler, + ); + break; + case "message": { + const messageCallback = callback as EventHandler< + ServerSentEvent, + "message" + >; + if (!this.messageWrappers.has(messageCallback)) { + this.messageWrappers.set(messageCallback, (event: MessageEvent) => { + messageCallback(this.parseMessage(event)); + }); + } + break; + } + case "open": + this.callbacks.open.add( + callback as EventHandler, + ); + break; + } + } + + private parseMessage( + event: MessageEvent, + ): ParsedMessageEvent { + const wsEvent = { data: event.data } as WsMessageEvent; + try { + return { + sourceEvent: wsEvent, + parsedMessage: { type: "data", data: JSON.parse(event.data) }, + parseError: undefined, + }; + } catch (err) { + return { + sourceEvent: wsEvent, + parsedMessage: undefined, + parseError: err as Error, + }; + } + } + + public removeEventListener( + event: TEvent, + callback: EventHandler, + ): void { + switch (event) { + case "close": + this.callbacks.close.delete( + callback as EventHandler, + ); + break; + case "error": + this.callbacks.error.delete( + callback as EventHandler, + ); + break; + case "message": + this.messageWrappers.delete( + callback as EventHandler, + ); + break; + case "open": + this.callbacks.open.delete( + callback as EventHandler, + ); + break; + } + } + + public close(code?: number, reason?: string): void { + this.eventSource.close(); + this.callbacks.close.forEach((cb) => + cb({ + code: code ?? 1000, + reason: reason ?? "Normal closure", + wasClean: true, + } as WsCloseEvent), + ); + + Object.values(this.callbacks).forEach((callbackSet) => callbackSet.clear()); + this.messageWrappers.clear(); + } +} diff --git a/src/websocket/utils.ts b/src/websocket/utils.ts new file mode 100644 index 00000000..592ce45e --- /dev/null +++ b/src/websocket/utils.ts @@ -0,0 +1,15 @@ +/** + * Converts params to a query string. Returns empty string if no params, + * otherwise returns params prefixed with '?'. + */ +export function getQueryString( + params: Record | URLSearchParams | undefined, +): string { + if (!params) { + return ""; + } + const searchParams = + params instanceof URLSearchParams ? params : new URLSearchParams(params); + const str = searchParams.toString(); + return str ? `?${str}` : ""; +} diff --git a/src/workspace/workspaceMonitor.ts b/src/workspace/workspaceMonitor.ts index a761249a..ceea8a91 100644 --- a/src/workspace/workspaceMonitor.ts +++ b/src/workspace/workspaceMonitor.ts @@ -9,7 +9,7 @@ import { createWorkspaceIdentifier, errToStr } from "../api/api-helper"; import { type CoderApi } from "../api/coderApi"; import { type ContextManager } from "../core/contextManager"; import { type Logger } from "../logging/logger"; -import { type OneWayWebSocket } from "../websocket/oneWayWebSocket"; +import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; /** * Monitor a single workspace using a WebSocket for events like shutdown and deletion. @@ -17,7 +17,7 @@ import { type OneWayWebSocket } from "../websocket/oneWayWebSocket"; * workspace status is also shown in the status bar menu. */ export class WorkspaceMonitor implements vscode.Disposable { - private socket: OneWayWebSocket | undefined; + private socket: UnidirectionalStream | undefined; private disposed = false; // How soon in advance to notify about autostop and deletion. @@ -93,10 +93,12 @@ export class WorkspaceMonitor implements vscode.Disposable { return; } // Perhaps we need to parse this and validate it. - const newWorkspaceData = event.parsedMessage.data as Workspace; - monitor.update(newWorkspaceData); - monitor.maybeNotify(newWorkspaceData); - monitor.onChange.fire(newWorkspaceData); + const newWorkspaceData = event.parsedMessage.data as Workspace | null; + if (newWorkspaceData) { + monitor.update(newWorkspaceData); + monitor.maybeNotify(newWorkspaceData); + monitor.onChange.fire(newWorkspaceData); + } } catch (error) { monitor.notifyError(error); } diff --git a/test/unit/logging/wsLogger.test.ts b/test/unit/logging/eventStreamLogger.test.ts similarity index 50% rename from test/unit/logging/wsLogger.test.ts rename to test/unit/logging/eventStreamLogger.test.ts index 5bf9d5b1..352ccaac 100644 --- a/test/unit/logging/wsLogger.test.ts +++ b/test/unit/logging/eventStreamLogger.test.ts @@ -1,19 +1,23 @@ import { describe, expect, it } from "vitest"; -import { WsLogger } from "@/logging/wsLogger"; +import { EventStreamLogger } from "@/logging/eventStreamLogger"; import { createMockLogger } from "../../mocks/testHelpers"; -describe("WS Logger", () => { +describe("EventStreamLogger", () => { it("tracks message count and byte size", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); - wsLogger.logOpen(); - wsLogger.logMessage("hello"); - wsLogger.logMessage("world"); - wsLogger.logMessage(Buffer.from("test")); - wsLogger.logClose(); + eventStreamLogger.logOpen(); + eventStreamLogger.logMessage("hello"); + eventStreamLogger.logMessage("world"); + eventStreamLogger.logMessage(Buffer.from("test")); + eventStreamLogger.logClose(); expect(logger.trace).toHaveBeenCalledWith( expect.stringContaining("3 msgs"), @@ -23,12 +27,16 @@ describe("WS Logger", () => { it("handles unknown byte sizes with >= indicator", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); - wsLogger.logOpen(); - wsLogger.logMessage({ complex: "object" }); // Unknown size - no estimation - wsLogger.logMessage("known"); - wsLogger.logClose(); + eventStreamLogger.logOpen(); + eventStreamLogger.logMessage({ complex: "object" }); // Unknown size - no estimation + eventStreamLogger.logMessage("known"); + eventStreamLogger.logClose(); expect(logger.trace).toHaveBeenLastCalledWith( expect.stringContaining(">= 5 B"), @@ -37,22 +45,30 @@ describe("WS Logger", () => { it("handles close before open gracefully", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); // Closing without opening should not throw - expect(() => wsLogger.logClose()).not.toThrow(); + expect(() => eventStreamLogger.logClose()).not.toThrow(); expect(logger.trace).toHaveBeenCalled(); }); it("formats large message counts with compact notation", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); - wsLogger.logOpen(); + eventStreamLogger.logOpen(); for (let i = 0; i < 1100; i++) { - wsLogger.logMessage("x"); + eventStreamLogger.logMessage("x"); } - wsLogger.logClose(); + eventStreamLogger.logClose(); expect(logger.trace).toHaveBeenLastCalledWith( expect.stringMatching(/1[.,]1K\s*msgs/), @@ -61,10 +77,14 @@ describe("WS Logger", () => { it("logs errors with error object", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); const error = new Error("Connection failed"); - wsLogger.logError(error, "Failed to connect"); + eventStreamLogger.logError(error, "Failed to connect"); expect(logger.error).toHaveBeenCalledWith(expect.any(String), error); }); From 2292d9b0ae497d852b7426c3af04a2c6dca21743 Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Mon, 13 Oct 2025 23:03:05 +0300 Subject: [PATCH 2/5] Add CoderApi tests --- test/unit/api/coderApi.test.ts | 382 +++++++++++++++++++++++++++++++++ 1 file changed, 382 insertions(+) create mode 100644 test/unit/api/coderApi.test.ts diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts new file mode 100644 index 00000000..a0c3b68d --- /dev/null +++ b/test/unit/api/coderApi.test.ts @@ -0,0 +1,382 @@ +import axios, { AxiosError, AxiosHeaders } from "axios"; +import { type ProvisionerJobLog } from "coder/site/src/api/typesGenerated"; +import { EventSource } from "eventsource"; +import { ProxyAgent } from "proxy-agent"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import Ws from "ws"; + +import { CoderApi } from "@/api/coderApi"; +import { createHttpAgent } from "@/api/utils"; +import { CertificateError } from "@/error"; +import { getHeaders } from "@/headers"; +import { type RequestConfigWithMeta } from "@/logging/types"; +import { OneWayWebSocket } from "@/websocket/oneWayWebSocket"; +import { SseConnection } from "@/websocket/sseConnection"; + +import { + createMockLogger, + MockConfigurationProvider, +} from "../../mocks/testHelpers"; + +vi.mock("ws"); +vi.mock("eventsource"); +vi.mock("proxy-agent"); + +const mockAdapterImpl = vi.hoisted(() => (config: Record) => { + return Promise.resolve({ + data: config.data || "{}", + status: 200, + statusText: "OK", + headers: {}, + config, + }); +}); + +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + + const mockAdapter = vi.fn(mockAdapterImpl); + + const mockDefault = { + ...actual.default, + create: vi.fn((config) => { + const instance = actual.default.create({ + ...config, + adapter: mockAdapter, + }); + return instance; + }), + __mockAdapter: mockAdapter, + }; + + return { + ...actual, + default: mockDefault, + }; +}); + +vi.mock("@/headers", () => ({ + getHeaders: vi.fn().mockResolvedValue({}), + getHeaderCommand: vi.fn(), +})); + +vi.mock("@/api/utils", () => ({ + createHttpAgent: vi.fn(), +})); + +vi.mock("@/api/streamingFetchAdapter", () => ({ + createStreamingFetchAdapter: vi.fn(() => fetch), +})); + +describe("CoderApi", () => { + let mockLogger: ReturnType; + let mockConfig: MockConfigurationProvider; + let mockAdapter: ReturnType; + + beforeEach(() => { + vi.resetAllMocks(); + + const axiosMock = axios as typeof axios & { + __mockAdapter: ReturnType; + }; + mockAdapter = axiosMock.__mockAdapter; + mockAdapter.mockImplementation(mockAdapterImpl); + + vi.mocked(getHeaders).mockResolvedValue({}); + mockLogger = createMockLogger(); + mockConfig = new MockConfigurationProvider(); + mockConfig.set("coder.httpClientLogLevel", "BASIC"); + }); + + describe("HTTP Interceptors", () => { + it("adds custom headers and HTTP agent to requests", async () => { + const mockAgent = new ProxyAgent(); + vi.mocked(createHttpAgent).mockResolvedValue(mockAgent); + vi.mocked(getHeaders).mockResolvedValue({ + "X-Custom-Header": "custom-value", + "X-Another-Header": "another-value", + }); + + const api = CoderApi.create( + "https://coder.example.com", + "token", + mockLogger, + ); + + const response = await api.getAxiosInstance().get("/api/v2/users/me"); + + expect(response.config.headers["X-Custom-Header"]).toBe("custom-value"); + expect(response.config.headers["X-Another-Header"]).toBe("another-value"); + expect(response.config.httpsAgent).toBe(mockAgent); + expect(response.config.httpAgent).toBe(mockAgent); + expect(response.config.proxy).toBe(false); + }); + + it("wraps certificate errors in response interceptor", async () => { + const api = CoderApi.create( + "https://coder.example.com", + "token", + mockLogger, + ); + + const certError = new AxiosError( + "self signed certificate", + "DEPTH_ZERO_SELF_SIGNED_CERT", + ); + mockAdapter.mockRejectedValueOnce(certError); + + const thrownError = await api + .getAxiosInstance() + .get("/api/v2/users/me") + .catch((e) => e); + + expect(thrownError).toBeInstanceOf(CertificateError); + expect(thrownError.message).toContain("Secure connection"); + expect(thrownError.x509Err).toBeDefined(); + }); + + it("applies headers in correct precedence order", async () => { + vi.mocked(getHeaders).mockResolvedValue({ + "X-Custom-Header": "from-command", + "Coder-Session-Token": "from-header-command", + }); + + const api = CoderApi.create( + "https://coder.example.com", + "passed-token", + mockLogger, + ); + + const response = await api.getAxiosInstance().get("/api/v2/users/me", { + headers: new AxiosHeaders({ + "X-Custom-Header": "from-config", + "X-Extra": "extra-value", + "Coder-Session-Token": "ignored-token", + }), + }); + + expect(response.config.headers["X-Custom-Header"]).toBe("from-command"); + expect(response.config.headers["X-Extra"]).toBe("extra-value"); + expect(response.config.headers["Coder-Session-Token"]).toBe( + "from-header-command", + ); + }); + + it("logs requests and responses", async () => { + const api = CoderApi.create( + "https://coder.example.com", + "token", + mockLogger, + ); + + await api.getWorkspaces({}); + + expect(mockLogger.trace).toHaveBeenCalledWith( + expect.stringContaining("/api/v2/workspaces"), + ); + }); + + it("calculates request and response sizes in transforms", async () => { + const api = CoderApi.create( + "https://coder.example.com", + "token", + mockLogger, + ); + + const response = await api + .getAxiosInstance() + .post("/api/v2/workspaces", { name: "test" }); + + expect((response.config as RequestConfigWithMeta).rawRequestSize).toBe( + 15, + ); + // We return the same data we sent in the mock adapter. + expect((response.config as RequestConfigWithMeta).rawResponseSize).toBe( + 15, + ); + }); + }); + + describe("WebSocket Creation", () => { + const buildId = "build-123"; + const wsUrl = `wss://coder.example.com/api/v2/workspacebuilds/${buildId}/logs?follow=true`; + let api: CoderApi; + + beforeEach(() => { + api = CoderApi.create( + "https://coder.example.com", + "passed-token", + mockLogger, + ); + + // Mock all WS as "WatchBuildLogsByBuildId" + const mockWs = { + url: wsUrl, + on: vi.fn(), + off: vi.fn(), + close: vi.fn(), + } as Partial; + vi.mocked(Ws).mockImplementation(() => mockWs as Ws); + }); + + it("creates WebSocket with proper headers and configuration", async () => { + const mockAgent = new ProxyAgent(); + vi.mocked(getHeaders).mockResolvedValue({ + "X-Custom-Header": "custom-value", + }); + vi.mocked(createHttpAgent).mockResolvedValue(mockAgent); + + await api.watchBuildLogsByBuildId(buildId, []); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: mockAgent, + followRedirects: true, + headers: { + "X-Custom-Header": "custom-value", + "Coder-Session-Token": "passed-token", + }, + }); + }); + + it("applies headers in correct precedence order", async () => { + vi.mocked(getHeaders).mockResolvedValue({ + "X-Custom-Header": "from-command", + "Coder-Session-Token": "from-header-command", + }); + + await api.watchBuildLogsByBuildId(buildId, []); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: undefined, + followRedirects: true, + headers: { + "X-Custom-Header": "from-command", + "Coder-Session-Token": "passed-token", + }, + }); + }); + + it("logs WebSocket connections", async () => { + await api.watchBuildLogsByBuildId(buildId, []); + + expect(mockLogger.trace).toHaveBeenCalledWith( + expect.stringContaining(buildId), + ); + }); + + it("'watchBuildLogsByBuildId' includes after parameter for existing logs", async () => { + const jobLog: ProvisionerJobLog = { + created_at: new Date().toISOString(), + id: 1, + output: "log1", + log_source: "provisioner", + log_level: "info", + stage: "stage1", + }; + const existingLogs = [jobLog, { ...jobLog, id: 20 }]; + + await api.watchBuildLogsByBuildId(buildId, existingLogs); + + expect(Ws).toHaveBeenCalledWith( + expect.stringContaining("after=20"), + undefined, + expect.any(Object), + ); + }); + }); + + describe("SSE Fallback", () => { + let api: CoderApi; + + beforeEach(() => { + api = CoderApi.create("https://coder.example.com", "token", mockLogger); + + const mockEventSource = { + url: "https://coder.example.com/api/v2/workspaces/123/watch", + readyState: EventSource.CONNECTING, + addEventListener: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler(new Event("open"))); + } + }), + removeEventListener: vi.fn(), + close: vi.fn(), + }; + + vi.mocked(EventSource).mockImplementation( + () => mockEventSource as unknown as EventSource, + ); + }); + + it("uses WebSocket when no errors occur", async () => { + const mockWs: Partial = { + url: "wss://coder.example.com/api/v2/workspaceagents/agent-123/watch-metadata", + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + off: vi.fn(), + close: vi.fn(), + }; + vi.mocked(Ws).mockImplementation(() => mockWs as Ws); + + const connection = await api.watchAgentMetadata("agent-123"); + + expect(connection).toBeInstanceOf(OneWayWebSocket); + expect(EventSource).not.toHaveBeenCalled(); + }); + + it("falls back to SSE when WebSocket creation fails", async () => { + vi.mocked(Ws).mockImplementation(() => { + throw new Error("WebSocket creation failed"); + }); + + const connection = await api.watchAgentMetadata("agent-123"); + expect(connection).toBeInstanceOf(SseConnection); + expect(EventSource).toHaveBeenCalled(); + }); + + it("falls back to SSE on 404 error from WebSocket", async () => { + const mockWs: Partial = { + url: "wss://coder.example.com/api/v2/test", + on: vi.fn((event: string, handler: (e: unknown) => void) => { + if (event === "error") { + setImmediate(() => { + handler({ + error: new Error("404 Not Found"), + message: "404 Not Found", + }); + }); + } + return mockWs as Ws; + }), + off: vi.fn(), + close: vi.fn(), + }; + + vi.mocked(Ws).mockImplementation(() => mockWs as Ws); + + const connection = await api.watchAgentMetadata("agent-123"); + expect(connection).toBeInstanceOf(SseConnection); + expect(EventSource).toHaveBeenCalled(); + }); + }); + + describe("Error Handling", () => { + it("throws error when no base URL is set", async () => { + const api = CoderApi.create( + "https://coder.example.com", + "token", + mockLogger, + ); + + api.getAxiosInstance().defaults.baseURL = undefined; + + await expect( + api.watchBuildLogsByBuildId("build-123", []), + ).rejects.toThrow("No base URL set on REST client"); + }); + }); +}); From bd3a0d802db1c323514de47cd7e82d8662fb95ef Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Mon, 13 Oct 2025 23:21:44 +0300 Subject: [PATCH 3/5] Simplify test and fix header order precedence --- src/api/coderApi.ts | 18 +- test/unit/api/coderApi.test.ts | 321 +++++++++++++++++++-------------- 2 files changed, 196 insertions(+), 143 deletions(-) diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 6e8148ab..5d75c00e 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -107,6 +107,7 @@ export class CoderApi extends Api { watchBuildLogsByBuildId = async ( buildId: string, logs: ProvisionerJobLog[], + options?: ClientOptions, ) => { const searchParams = new URLSearchParams({ follow: "true" }); if (logs.length) { @@ -116,6 +117,7 @@ export class CoderApi extends Api { return this.createWebSocket({ apiRoute: `/api/v2/workspacebuilds/${buildId}/logs`, searchParams, + options, }); }; @@ -132,7 +134,7 @@ export class CoderApi extends Api { coderSessionTokenHeader ] as string | undefined; - const headers = await getHeaders( + const headersFromCommand = await getHeaders( baseUrlRaw, getHeaderCommand(vscode.workspace.getConfiguration()), this.output, @@ -142,18 +144,20 @@ export class CoderApi extends Api { vscode.workspace.getConfiguration(), ); + const headers = { + ...(token ? { [coderSessionTokenHeader]: token } : {}), + ...configs.options?.headers, + ...headersFromCommand, + }; + const webSocket = new OneWayWebSocket({ location: baseUrl, ...configs, options: { + ...configs.options, agent: httpAgent, followRedirects: true, - headers: { - ...(token ? { [coderSessionTokenHeader]: token } : {}), - ...configs.options?.headers, - ...headers, - }, - ...configs.options, + headers, }, }); diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts index a0c3b68d..0336d564 100644 --- a/test/unit/api/coderApi.test.ts +++ b/test/unit/api/coderApi.test.ts @@ -18,20 +18,15 @@ import { MockConfigurationProvider, } from "../../mocks/testHelpers"; +const CODER_URL = "https://coder.example.com"; +const AXIOS_TOKEN = "passed-token"; +const BUILD_ID = "build-123"; +const AGENT_ID = "agent-123"; + vi.mock("ws"); vi.mock("eventsource"); vi.mock("proxy-agent"); -const mockAdapterImpl = vi.hoisted(() => (config: Record) => { - return Promise.resolve({ - data: config.data || "{}", - status: 200, - statusText: "OK", - headers: {}, - config, - }); -}); - vi.mock("axios", async () => { const actual = await vi.importActual("axios"); @@ -72,6 +67,11 @@ describe("CoderApi", () => { let mockLogger: ReturnType; let mockConfig: MockConfigurationProvider; let mockAdapter: ReturnType; + let api: CoderApi; + + const createApi = (url = CODER_URL, token = AXIOS_TOKEN) => { + return CoderApi.create(url, token, mockLogger); + }; beforeEach(() => { vi.resetAllMocks(); @@ -97,12 +97,7 @@ describe("CoderApi", () => { "X-Another-Header": "another-value", }); - const api = CoderApi.create( - "https://coder.example.com", - "token", - mockLogger, - ); - + const api = createApi(); const response = await api.getAxiosInstance().get("/api/v2/users/me"); expect(response.config.headers["X-Custom-Header"]).toBe("custom-value"); @@ -113,12 +108,7 @@ describe("CoderApi", () => { }); it("wraps certificate errors in response interceptor", async () => { - const api = CoderApi.create( - "https://coder.example.com", - "token", - mockLogger, - ); - + const api = createApi(); const certError = new AxiosError( "self signed certificate", "DEPTH_ZERO_SELF_SIGNED_CERT", @@ -135,39 +125,54 @@ describe("CoderApi", () => { expect(thrownError.x509Err).toBeDefined(); }); - it("applies headers in correct precedence order", async () => { - vi.mocked(getHeaders).mockResolvedValue({ - "X-Custom-Header": "from-command", - "Coder-Session-Token": "from-header-command", - }); - - const api = CoderApi.create( - "https://coder.example.com", - "passed-token", - mockLogger, - ); + it("applies headers in correct precedence order (command > config > axios default)", async () => { + const api = createApi(CODER_URL, AXIOS_TOKEN); + // Test 1: Headers from config, default token from API creation const response = await api.getAxiosInstance().get("/api/v2/users/me", { headers: new AxiosHeaders({ "X-Custom-Header": "from-config", "X-Extra": "extra-value", - "Coder-Session-Token": "ignored-token", }), }); - expect(response.config.headers["X-Custom-Header"]).toBe("from-command"); + expect(response.config.headers["X-Custom-Header"]).toBe("from-config"); expect(response.config.headers["X-Extra"]).toBe("extra-value"); - expect(response.config.headers["Coder-Session-Token"]).toBe( - "from-header-command", + expect(response.config.headers["Coder-Session-Token"]).toBe(AXIOS_TOKEN); + + // Test 2: Token from request options overrides default + const responseWithToken = await api + .getAxiosInstance() + .get("/api/v2/users/me", { + headers: new AxiosHeaders({ + "Coder-Session-Token": "from-options", + }), + }); + + expect(responseWithToken.config.headers["Coder-Session-Token"]).toBe( + "from-options", ); + + // Test 3: Header command overrides everything + vi.mocked(getHeaders).mockResolvedValue({ + "Coder-Session-Token": "from-header-command", + }); + + const responseWithHeaderCommand = await api + .getAxiosInstance() + .get("/api/v2/users/me", { + headers: new AxiosHeaders({ + "Coder-Session-Token": "from-options", + }), + }); + + expect( + responseWithHeaderCommand.config.headers["Coder-Session-Token"], + ).toBe("from-header-command"); }); it("logs requests and responses", async () => { - const api = CoderApi.create( - "https://coder.example.com", - "token", - mockLogger, - ); + const api = createApi(); await api.getWorkspaces({}); @@ -177,12 +182,7 @@ describe("CoderApi", () => { }); it("calculates request and response sizes in transforms", async () => { - const api = CoderApi.create( - "https://coder.example.com", - "token", - mockLogger, - ); - + const api = createApi(); const response = await api .getAxiosInstance() .post("/api/v2/workspaces", { name: "test" }); @@ -190,7 +190,7 @@ describe("CoderApi", () => { expect((response.config as RequestConfigWithMeta).rawRequestSize).toBe( 15, ); - // We return the same data we sent in the mock adapter. + // We return the same data we sent in the mock adapter expect((response.config as RequestConfigWithMeta).rawResponseSize).toBe( 15, ); @@ -198,25 +198,12 @@ describe("CoderApi", () => { }); describe("WebSocket Creation", () => { - const buildId = "build-123"; - const wsUrl = `wss://coder.example.com/api/v2/workspacebuilds/${buildId}/logs?follow=true`; - let api: CoderApi; + const wsUrl = `wss://${CODER_URL.replace("https://", "")}/api/v2/workspacebuilds/${BUILD_ID}/logs?follow=true`; beforeEach(() => { - api = CoderApi.create( - "https://coder.example.com", - "passed-token", - mockLogger, - ); - - // Mock all WS as "WatchBuildLogsByBuildId" - const mockWs = { - url: wsUrl, - on: vi.fn(), - off: vi.fn(), - close: vi.fn(), - } as Partial; - vi.mocked(Ws).mockImplementation(() => mockWs as Ws); + api = createApi(CODER_URL, AXIOS_TOKEN); + const mockWs = createMockWebSocket(wsUrl); + setupWebSocketMock(mockWs); }); it("creates WebSocket with proper headers and configuration", async () => { @@ -226,41 +213,72 @@ describe("CoderApi", () => { }); vi.mocked(createHttpAgent).mockResolvedValue(mockAgent); - await api.watchBuildLogsByBuildId(buildId, []); + await api.watchBuildLogsByBuildId(BUILD_ID, []); expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { agent: mockAgent, followRedirects: true, headers: { "X-Custom-Header": "custom-value", - "Coder-Session-Token": "passed-token", + "Coder-Session-Token": AXIOS_TOKEN, }, }); }); - it("applies headers in correct precedence order", async () => { + it("applies headers in correct precedence order (command > config > axios default)", async () => { + // Test 1: Default token from API creation + await api.watchBuildLogsByBuildId(BUILD_ID, []); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: undefined, + followRedirects: true, + headers: { + "Coder-Session-Token": AXIOS_TOKEN, + }, + }); + + // Test 2: Token from config options overrides default + await api.watchBuildLogsByBuildId(BUILD_ID, [], { + headers: { + "X-Config-Header": "config-value", + "Coder-Session-Token": "from-config", + }, + }); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: undefined, + followRedirects: true, + headers: { + "Coder-Session-Token": "from-config", + "X-Config-Header": "config-value", + }, + }); + + // Test 3: Header command overrides everything vi.mocked(getHeaders).mockResolvedValue({ - "X-Custom-Header": "from-command", "Coder-Session-Token": "from-header-command", }); - await api.watchBuildLogsByBuildId(buildId, []); + await api.watchBuildLogsByBuildId(BUILD_ID, [], { + headers: { + "Coder-Session-Token": "from-config", + }, + }); expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { agent: undefined, followRedirects: true, headers: { - "X-Custom-Header": "from-command", - "Coder-Session-Token": "passed-token", + "Coder-Session-Token": "from-header-command", }, }); }); it("logs WebSocket connections", async () => { - await api.watchBuildLogsByBuildId(buildId, []); + await api.watchBuildLogsByBuildId(BUILD_ID, []); expect(mockLogger.trace).toHaveBeenCalledWith( - expect.stringContaining(buildId), + expect.stringContaining(BUILD_ID), ); }); @@ -273,12 +291,16 @@ describe("CoderApi", () => { log_level: "info", stage: "stage1", }; - const existingLogs = [jobLog, { ...jobLog, id: 20 }]; + const existingLogs = [ + jobLog, + { ...jobLog, id: 20 }, + { ...jobLog, id: 5 }, + ]; - await api.watchBuildLogsByBuildId(buildId, existingLogs); + await api.watchBuildLogsByBuildId(BUILD_ID, existingLogs); expect(Ws).toHaveBeenCalledWith( - expect.stringContaining("after=20"), + expect.stringContaining("after=5"), undefined, expect.any(Object), ); @@ -286,43 +308,29 @@ describe("CoderApi", () => { }); describe("SSE Fallback", () => { - let api: CoderApi; - beforeEach(() => { - api = CoderApi.create("https://coder.example.com", "token", mockLogger); - - const mockEventSource = { - url: "https://coder.example.com/api/v2/workspaces/123/watch", - readyState: EventSource.CONNECTING, - addEventListener: vi.fn((event, handler) => { - if (event === "open") { - setImmediate(() => handler(new Event("open"))); - } - }), - removeEventListener: vi.fn(), - close: vi.fn(), - }; - - vi.mocked(EventSource).mockImplementation( - () => mockEventSource as unknown as EventSource, + api = createApi(); + const mockEventSource = createMockEventSource( + `${CODER_URL}/api/v2/workspaces/123/watch`, ); + setupEventSourceMock(mockEventSource); }); it("uses WebSocket when no errors occur", async () => { - const mockWs: Partial = { - url: "wss://coder.example.com/api/v2/workspaceagents/agent-123/watch-metadata", - on: vi.fn((event, handler) => { - if (event === "open") { - setImmediate(() => handler()); - } - return mockWs as Ws; - }), - off: vi.fn(), - close: vi.fn(), - }; - vi.mocked(Ws).mockImplementation(() => mockWs as Ws); + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata`, + { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); - const connection = await api.watchAgentMetadata("agent-123"); + const connection = await api.watchAgentMetadata(AGENT_ID); expect(connection).toBeInstanceOf(OneWayWebSocket); expect(EventSource).not.toHaveBeenCalled(); @@ -333,32 +341,33 @@ describe("CoderApi", () => { throw new Error("WebSocket creation failed"); }); - const connection = await api.watchAgentMetadata("agent-123"); + const connection = await api.watchAgentMetadata(AGENT_ID); + expect(connection).toBeInstanceOf(SseConnection); expect(EventSource).toHaveBeenCalled(); }); it("falls back to SSE on 404 error from WebSocket", async () => { - const mockWs: Partial = { - url: "wss://coder.example.com/api/v2/test", - on: vi.fn((event: string, handler: (e: unknown) => void) => { - if (event === "error") { - setImmediate(() => { - handler({ - error: new Error("404 Not Found"), - message: "404 Not Found", + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/test`, + { + on: vi.fn((event: string, handler: (e: unknown) => void) => { + if (event === "error") { + setImmediate(() => { + handler({ + error: new Error("404 Not Found"), + message: "404 Not Found", + }); }); - }); - } - return mockWs as Ws; - }), - off: vi.fn(), - close: vi.fn(), - }; + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); - vi.mocked(Ws).mockImplementation(() => mockWs as Ws); + const connection = await api.watchAgentMetadata(AGENT_ID); - const connection = await api.watchAgentMetadata("agent-123"); expect(connection).toBeInstanceOf(SseConnection); expect(EventSource).toHaveBeenCalled(); }); @@ -366,17 +375,57 @@ describe("CoderApi", () => { describe("Error Handling", () => { it("throws error when no base URL is set", async () => { - const api = CoderApi.create( - "https://coder.example.com", - "token", - mockLogger, - ); - + const api = createApi(); api.getAxiosInstance().defaults.baseURL = undefined; - await expect( - api.watchBuildLogsByBuildId("build-123", []), - ).rejects.toThrow("No base URL set on REST client"); + await expect(api.watchBuildLogsByBuildId(BUILD_ID, [])).rejects.toThrow( + "No base URL set on REST client", + ); }); }); }); + +const mockAdapterImpl = vi.hoisted(() => (config: Record) => { + return Promise.resolve({ + data: config.data || "{}", + status: 200, + statusText: "OK", + headers: {}, + config, + }); +}); + +function createMockWebSocket( + url: string, + overrides?: Partial, +): Partial { + return { + url, + on: vi.fn(), + off: vi.fn(), + close: vi.fn(), + ...overrides, + }; +} + +function createMockEventSource(url: string): Partial { + return { + url, + readyState: EventSource.CONNECTING, + addEventListener: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler(new Event("open"))); + } + }), + removeEventListener: vi.fn(), + close: vi.fn(), + }; +} + +function setupWebSocketMock(ws: Partial): void { + vi.mocked(Ws).mockImplementation(() => ws as Ws); +} + +function setupEventSourceMock(es: Partial): void { + vi.mocked(EventSource).mockImplementation(() => es as EventSource); +} From 40b93bc0af24a177c8dfe4955289a7817e36cbb0 Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Tue, 14 Oct 2025 11:48:26 +0300 Subject: [PATCH 4/5] Attach the correct listener for data --- src/websocket/sseConnection.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/websocket/sseConnection.ts b/src/websocket/sseConnection.ts index d3c9977f..f68e0c73 100644 --- a/src/websocket/sseConnection.ts +++ b/src/websocket/sseConnection.ts @@ -61,7 +61,7 @@ export class SseConnection implements UnidirectionalStream { this.callbacks.open.forEach((cb) => cb({} as WsEvent)), ); - this.eventSource.addEventListener("message", (event: MessageEvent) => { + this.eventSource.addEventListener("data", (event: MessageEvent) => { [...this.messageWrappers.values()].forEach((wrapper) => wrapper(event)); }); From aef2b0a7403ea09c61b932e3c6818d5e936baf1d Mon Sep 17 00:00:00 2001 From: Ehab Younes Date: Wed, 15 Oct 2025 17:10:20 +0300 Subject: [PATCH 5/5] Review comments: better error handling and documentation --- src/api/coderApi.ts | 15 +++++++++- src/api/streamingFetchAdapter.ts | 15 ++++++++-- src/websocket/sseConnection.ts | 50 +++++++++++++++++++++++++------- 3 files changed, 66 insertions(+), 14 deletions(-) diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 5d75c00e..6509ac67 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -144,6 +144,12 @@ export class CoderApi extends Api { vscode.workspace.getConfiguration(), ); + /** + * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): + * 1. Headers from the header command + * 2. Any headers passed directly to this function + * 3. Coder session token from the Api client (if set) + */ const headers = { ...(token ? { [coderSessionTokenHeader]: token } : {}), ...configs.options?.headers, @@ -189,7 +195,9 @@ export class CoderApi extends Api { } /** - * Create a WebSocket connection with SSE fallback on 404 + * Create a WebSocket connection with SSE fallback on 404. + * + * Note: The fallback on SSE ignores all passed client options except the headers. */ private async createWebSocketWithFallback(configs: { apiRoute: string; @@ -209,6 +217,7 @@ export class CoderApi extends Api { return this.createSseFallback( configs.fallbackApiRoute, configs.searchParams, + configs.options?.headers, ); } @@ -216,6 +225,7 @@ export class CoderApi extends Api { this.createSseFallback( configs.fallbackApiRoute, configs.searchParams, + configs.options?.headers, ), ); } @@ -260,6 +270,7 @@ export class CoderApi extends Api { private async createSseFallback( apiRoute: string, searchParams?: Record | URLSearchParams, + optionsHeaders?: Record, ): Promise> { this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`); @@ -274,6 +285,8 @@ export class CoderApi extends Api { apiRoute, searchParams, axiosInstance: this.getAxiosInstance(), + optionsHeaders: optionsHeaders, + logger: this.output, }); this.attachStreamLogger(sseConnection); diff --git a/src/api/streamingFetchAdapter.ts b/src/api/streamingFetchAdapter.ts index 79e3ee39..f0730535 100644 --- a/src/api/streamingFetchAdapter.ts +++ b/src/api/streamingFetchAdapter.ts @@ -8,6 +8,7 @@ import { type IncomingMessage } from "http"; */ export function createStreamingFetchAdapter( axiosInstance: AxiosInstance, + configHeaders?: Record, ): (url: string | URL, init?: FetchLikeInit) => Promise { return async ( url: string | URL, @@ -18,7 +19,7 @@ export function createStreamingFetchAdapter( const response = await axiosInstance.request({ url: urlStr, signal: init?.signal, - headers: init?.headers, + headers: { ...init?.headers, ...configHeaders }, responseType: "stream", validateStatus: () => true, // Don't throw on any status code }); @@ -26,11 +27,19 @@ export function createStreamingFetchAdapter( const stream = new ReadableStream({ start(controller) { response.data.on("data", (chunk: Buffer) => { - controller.enqueue(chunk); + try { + controller.enqueue(chunk); + } catch { + // Stream already closed or errored, ignore + } }); response.data.on("end", () => { - controller.close(); + try { + controller.close(); + } catch { + // Stream already closed, ignore + } }); response.data.on("error", (err: Error) => { diff --git a/src/websocket/sseConnection.ts b/src/websocket/sseConnection.ts index f68e0c73..834100aa 100644 --- a/src/websocket/sseConnection.ts +++ b/src/websocket/sseConnection.ts @@ -4,6 +4,7 @@ import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; import { EventSource } from "eventsource"; import { createStreamingFetchAdapter } from "../api/streamingFetchAdapter"; +import { type Logger } from "../logging/logger"; import { getQueryString } from "./utils"; @@ -24,11 +25,14 @@ export type SseConnectionInit = { location: { protocol: string; host: string }; apiRoute: string; searchParams?: Record | URLSearchParams; + optionsHeaders?: Record; axiosInstance: AxiosInstance; + logger: Logger; }; export class SseConnection implements UnidirectionalStream { private readonly eventSource: EventSource; + private readonly logger: Logger; private readonly callbacks = { open: new Set>(), close: new Set>(), @@ -43,9 +47,13 @@ export class SseConnection implements UnidirectionalStream { public readonly url: string; public constructor(init: SseConnectionInit) { + this.logger = init.logger; this.url = this.buildUrl(init); this.eventSource = new EventSource(this.url, { - fetch: createStreamingFetchAdapter(init.axiosInstance), + fetch: createStreamingFetchAdapter( + init.axiosInstance, + init.optionsHeaders, + ), }); this.setupEventHandlers(); } @@ -58,28 +66,48 @@ export class SseConnection implements UnidirectionalStream { private setupEventHandlers(): void { this.eventSource.addEventListener("open", () => - this.callbacks.open.forEach((cb) => cb({} as WsEvent)), + this.invokeCallbacks(this.callbacks.open, {} as WsEvent, "open"), ); this.eventSource.addEventListener("data", (event: MessageEvent) => { - [...this.messageWrappers.values()].forEach((wrapper) => wrapper(event)); + this.invokeCallbacks(this.messageWrappers.values(), event, "message"); }); this.eventSource.addEventListener("error", (error: Event | ErrorEvent) => { - this.callbacks.error.forEach((cb) => cb(this.createErrorEvent(error))); + this.invokeCallbacks( + this.callbacks.error, + this.createErrorEvent(error), + "error", + ); if (this.eventSource.readyState === EventSource.CLOSED) { - this.callbacks.close.forEach((cb) => - cb({ + this.invokeCallbacks( + this.callbacks.close, + { code: 1006, reason: "Connection lost", wasClean: false, - } as WsCloseEvent), + } as WsCloseEvent, + "close", ); } }); } + private invokeCallbacks( + callbacks: Iterable<(event: T) => void>, + event: T, + eventType: string, + ): void { + for (const cb of callbacks) { + try { + cb(event); + } catch (err) { + this.logger.error(`Error in SSE ${eventType} callback:`, err); + } + } + } + private createErrorEvent(event: Event | ErrorEvent): WsErrorEvent { const errorMessage = event instanceof ErrorEvent && event.message @@ -177,12 +205,14 @@ export class SseConnection implements UnidirectionalStream { public close(code?: number, reason?: string): void { this.eventSource.close(); - this.callbacks.close.forEach((cb) => - cb({ + this.invokeCallbacks( + this.callbacks.close, + { code: code ?? 1000, reason: reason ?? "Normal closure", wasClean: true, - } as WsCloseEvent), + } as WsCloseEvent, + "close", ); Object.values(this.callbacks).forEach((callbackSet) => callbackSet.clear());