diff --git a/src/api/coderApi.ts b/src/api/coderApi.ts index 99976ff7..6509ac67 100644 --- a/src/api/coderApi.ts +++ b/src/api/coderApi.ts @@ -6,17 +6,18 @@ import { } from "axios"; import { Api } from "coder/site/src/api/api"; import { + type ServerSentEvent, type GetInboxNotificationResponse, type ProvisionerJobLog, - type ServerSentEvent, type Workspace, type WorkspaceAgent, } from "coder/site/src/api/typesGenerated"; import * as vscode from "vscode"; -import { type ClientOptions } from "ws"; +import { type ClientOptions, type CloseEvent, type ErrorEvent } from "ws"; import { CertificateError } from "../error"; import { getHeaderCommand, getHeaders } from "../headers"; +import { EventStreamLogger } from "../logging/eventStreamLogger"; import { createRequestMeta, logRequest, @@ -29,11 +30,12 @@ import { HttpClientLogLevel, } from "../logging/types"; import { sizeOf } from "../logging/utils"; -import { WsLogger } from "../logging/wsLogger"; +import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; import { OneWayWebSocket, type OneWayWebSocketInit, } from "../websocket/oneWayWebSocket"; +import { SseConnection } from "../websocket/sseConnection"; import { createHttpAgent } from "./utils"; @@ -84,8 +86,9 @@ export class CoderApi extends Api { }; watchWorkspace = async (workspace: Workspace, options?: ClientOptions) => { - return this.createWebSocket({ + return this.createWebSocketWithFallback({ apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`, + fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`, options, }); }; @@ -94,8 +97,9 @@ export class CoderApi extends Api { agentId: WorkspaceAgent["id"], options?: ClientOptions, ) => { - return this.createWebSocket({ + return this.createWebSocketWithFallback({ apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`, + fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`, options, }); }; @@ -103,6 +107,7 @@ export class CoderApi extends Api { watchBuildLogsByBuildId = async ( buildId: string, logs: ProvisionerJobLog[], + options?: ClientOptions, ) => { const searchParams = new URLSearchParams({ follow: "true" }); if (logs.length) { @@ -112,6 +117,7 @@ export class CoderApi extends Api { return this.createWebSocket({ apiRoute: `/api/v2/workspacebuilds/${buildId}/logs`, searchParams, + options, }); }; @@ -128,7 +134,7 @@ export class CoderApi extends Api { coderSessionTokenHeader ] as string | undefined; - const headers = await getHeaders( + const headersFromCommand = await getHeaders( baseUrlRaw, getHeaderCommand(vscode.workspace.getConfiguration()), this.output, @@ -137,43 +143,154 @@ export class CoderApi extends Api { const httpAgent = await createHttpAgent( vscode.workspace.getConfiguration(), ); + + /** + * Similar to the REST client, we want to prioritize headers in this order (highest to lowest): + * 1. Headers from the header command + * 2. Any headers passed directly to this function + * 3. Coder session token from the Api client (if set) + */ + const headers = { + ...(token ? { [coderSessionTokenHeader]: token } : {}), + ...configs.options?.headers, + ...headersFromCommand, + }; + const webSocket = new OneWayWebSocket({ location: baseUrl, ...configs, options: { + ...configs.options, agent: httpAgent, followRedirects: true, - headers: { - ...(token ? { [coderSessionTokenHeader]: token } : {}), - ...configs.options?.headers, - ...headers, - }, - ...configs.options, + headers, }, }); - const wsUrl = new URL(webSocket.url); - const pathWithQuery = wsUrl.pathname + wsUrl.search; - const wsLogger = new WsLogger(this.output, pathWithQuery); - wsLogger.logConnecting(); + this.attachStreamLogger(webSocket); + return webSocket; + } - webSocket.addEventListener("open", () => { - wsLogger.logOpen(); - }); + private attachStreamLogger( + connection: UnidirectionalStream, + ): void { + const url = new URL(connection.url); + const logger = new EventStreamLogger( + this.output, + url.pathname + url.search, + url.protocol.startsWith("http") ? "SSE" : "WS", + ); + logger.logConnecting(); - webSocket.addEventListener("message", (event) => { - wsLogger.logMessage(event.sourceEvent.data); - }); + connection.addEventListener("open", () => logger.logOpen()); + connection.addEventListener("close", (event: CloseEvent) => + logger.logClose(event.code, event.reason), + ); + connection.addEventListener("error", (event: ErrorEvent) => + logger.logError(event.error, event.message), + ); + connection.addEventListener("message", (event) => + logger.logMessage(event.sourceEvent.data), + ); + } - webSocket.addEventListener("close", (event) => { - wsLogger.logClose(event.code, event.reason); + /** + * Create a WebSocket connection with SSE fallback on 404. + * + * Note: The fallback on SSE ignores all passed client options except the headers. + */ + private async createWebSocketWithFallback(configs: { + apiRoute: string; + fallbackApiRoute: string; + searchParams?: Record | URLSearchParams; + options?: ClientOptions; + }): Promise> { + let webSocket: OneWayWebSocket; + try { + webSocket = await this.createWebSocket({ + apiRoute: configs.apiRoute, + searchParams: configs.searchParams, + options: configs.options, + }); + } catch { + // Failed to create WebSocket, use SSE fallback + return this.createSseFallback( + configs.fallbackApiRoute, + configs.searchParams, + configs.options?.headers, + ); + } + + return this.waitForConnection(webSocket, () => + this.createSseFallback( + configs.fallbackApiRoute, + configs.searchParams, + configs.options?.headers, + ), + ); + } + + private waitForConnection( + connection: UnidirectionalStream, + onNotFound?: () => Promise>, + ): Promise> { + return new Promise((resolve, reject) => { + const cleanup = () => { + connection.removeEventListener("open", handleOpen); + connection.removeEventListener("error", handleError); + }; + + const handleOpen = () => { + cleanup(); + resolve(connection); + }; + + const handleError = (event: ErrorEvent) => { + cleanup(); + const is404 = + event.message?.includes("404") || + event.error?.message?.includes("404"); + + if (is404 && onNotFound) { + connection.close(); + onNotFound().then(resolve).catch(reject); + } else { + reject(event.error || new Error(event.message)); + } + }; + + connection.addEventListener("open", handleOpen); + connection.addEventListener("error", handleError); }); + } + + /** + * Create SSE fallback connection + */ + private async createSseFallback( + apiRoute: string, + searchParams?: Record | URLSearchParams, + optionsHeaders?: Record, + ): Promise> { + this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`); + + const baseUrlRaw = this.getAxiosInstance().defaults.baseURL; + if (!baseUrlRaw) { + throw new Error("No base URL set on REST client"); + } - webSocket.addEventListener("error", (event) => { - wsLogger.logError(event.error, event.message); + const baseUrl = new URL(baseUrlRaw); + const sseConnection = new SseConnection({ + location: baseUrl, + apiRoute, + searchParams, + axiosInstance: this.getAxiosInstance(), + optionsHeaders: optionsHeaders, + logger: this.output, }); - return webSocket; + this.attachStreamLogger(sseConnection); + return this.waitForConnection(sseConnection); } } diff --git a/src/api/streamingFetchAdapter.ts b/src/api/streamingFetchAdapter.ts new file mode 100644 index 00000000..f0730535 --- /dev/null +++ b/src/api/streamingFetchAdapter.ts @@ -0,0 +1,71 @@ +import { type AxiosInstance } from "axios"; +import { type FetchLikeInit, type FetchLikeResponse } from "eventsource"; +import { type IncomingMessage } from "http"; + +/** + * Creates a fetch adapter using an Axios instance that returns streaming responses. + * This is used by EventSource to make authenticated SSE connections. + */ +export function createStreamingFetchAdapter( + axiosInstance: AxiosInstance, + configHeaders?: Record, +): (url: string | URL, init?: FetchLikeInit) => Promise { + return async ( + url: string | URL, + init?: FetchLikeInit, + ): Promise => { + const urlStr = url.toString(); + + const response = await axiosInstance.request({ + url: urlStr, + signal: init?.signal, + headers: { ...init?.headers, ...configHeaders }, + responseType: "stream", + validateStatus: () => true, // Don't throw on any status code + }); + + const stream = new ReadableStream({ + start(controller) { + response.data.on("data", (chunk: Buffer) => { + try { + controller.enqueue(chunk); + } catch { + // Stream already closed or errored, ignore + } + }); + + response.data.on("end", () => { + try { + controller.close(); + } catch { + // Stream already closed, ignore + } + }); + + response.data.on("error", (err: Error) => { + controller.error(err); + }); + }, + + cancel() { + response.data.destroy(); + return Promise.resolve(); + }, + }); + + return { + body: { + getReader: () => stream.getReader(), + }, + url: urlStr, + status: response.status, + redirected: response.request?.res?.responseUrl !== urlStr, + headers: { + get: (name: string) => { + const value = response.headers[name.toLowerCase()]; + return value === undefined ? null : String(value); + }, + }, + }; + }; +} diff --git a/src/logging/wsLogger.ts b/src/logging/eventStreamLogger.ts similarity index 77% rename from src/logging/wsLogger.ts rename to src/logging/eventStreamLogger.ts index fd6acd00..224f52b7 100644 --- a/src/logging/wsLogger.ts +++ b/src/logging/eventStreamLogger.ts @@ -12,31 +12,35 @@ const numFormatter = new Intl.NumberFormat("en", { compactDisplay: "short", }); -export class WsLogger { +export class EventStreamLogger { private readonly logger: Logger; private readonly url: string; private readonly id: string; + private readonly protocol: string; private readonly startedAt: number; private openedAt?: number; private msgCount = 0; private byteCount = 0; private unknownByteCount = false; - constructor(logger: Logger, url: string) { + constructor(logger: Logger, url: string, protocol: "WS" | "SSE") { this.logger = logger; this.url = url; + this.protocol = protocol; this.id = createRequestId(); this.startedAt = Date.now(); } logConnecting(): void { - this.logger.trace(`→ WS ${shortId(this.id)} ${this.url}`); + this.logger.trace(`→ ${this.protocol} ${shortId(this.id)} ${this.url}`); } logOpen(): void { this.openedAt = Date.now(); const time = formatTime(this.openedAt - this.startedAt); - this.logger.trace(`← WS ${shortId(this.id)} connected ${this.url} ${time}`); + this.logger.trace( + `← ${this.protocol} ${shortId(this.id)} connected ${this.url} ${time}`, + ); } logMessage(data: unknown): void { @@ -62,7 +66,7 @@ export class WsLogger { const statsStr = ` [${stats.join(", ")}]`; this.logger.trace( - `▣ WS ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`, + `▣ ${this.protocol} ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`, ); } @@ -70,7 +74,7 @@ export class WsLogger { const time = formatTime(Date.now() - this.startedAt); const errorMsg = message || errToStr(error, "connection error"); this.logger.error( - `✗ WS ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`, + `✗ ${this.protocol} ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`, error, ); } diff --git a/src/websocket/eventStreamConnection.ts b/src/websocket/eventStreamConnection.ts new file mode 100644 index 00000000..2dc6514e --- /dev/null +++ b/src/websocket/eventStreamConnection.ts @@ -0,0 +1,51 @@ +import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; +import { + type CloseEvent, + type Event as WsEvent, + type ErrorEvent, + type MessageEvent, +} from "ws"; + +// Event payload types matching OneWayWebSocket +export type ParsedMessageEvent = Readonly< + | { + sourceEvent: MessageEvent; + parsedMessage: TData; + parseError: undefined; + } + | { + sourceEvent: MessageEvent; + parsedMessage: undefined; + parseError: Error; + } +>; + +export type EventPayloadMap = { + close: CloseEvent; + error: ErrorEvent; + message: ParsedMessageEvent; + open: WsEvent; +}; + +export type EventHandler = ( + payload: EventPayloadMap[TEvent], +) => void; + +/** + * Common interface for both WebSocket and SSE connections that handle event streams. + * Matches the OneWayWebSocket interface for compatibility. + */ +export interface UnidirectionalStream { + readonly url: string; + addEventListener( + eventType: TEvent, + callback: EventHandler, + ): void; + + removeEventListener( + eventType: TEvent, + callback: EventHandler, + ): void; + + close(code?: number, reason?: string): void; +} diff --git a/src/websocket/oneWayWebSocket.ts b/src/websocket/oneWayWebSocket.ts index 37965596..c27b1fe4 100644 --- a/src/websocket/oneWayWebSocket.ts +++ b/src/websocket/oneWayWebSocket.ts @@ -8,51 +8,13 @@ */ import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; -import Ws, { - type ClientOptions, - type CloseEvent, - type ErrorEvent, - type Event, - type MessageEvent, - type RawData, -} from "ws"; +import Ws, { type ClientOptions, type MessageEvent, type RawData } from "ws"; -export type OneWayMessageEvent = Readonly< - | { - sourceEvent: MessageEvent; - parsedMessage: TData; - parseError: undefined; - } - | { - sourceEvent: MessageEvent; - parsedMessage: undefined; - parseError: Error; - } ->; - -type OneWayEventPayloadMap = { - close: CloseEvent; - error: ErrorEvent; - message: OneWayMessageEvent; - open: Event; -}; - -type OneWayEventCallback = ( - payload: OneWayEventPayloadMap[TEvent], -) => void; - -interface OneWayWebSocketApi { - get url(): string; - addEventListener( - eventType: TEvent, - callback: OneWayEventCallback, - ): void; - removeEventListener( - eventType: TEvent, - callback: OneWayEventCallback, - ): void; - close(code?: number, reason?: string): void; -} +import { + type UnidirectionalStream, + type EventHandler, +} from "./eventStreamConnection"; +import { getQueryString } from "./utils"; export type OneWayWebSocketInit = { location: { protocol: string; host: string }; @@ -63,23 +25,18 @@ export type OneWayWebSocketInit = { }; export class OneWayWebSocket - implements OneWayWebSocketApi + implements UnidirectionalStream { readonly #socket: Ws; readonly #messageCallbacks = new Map< - OneWayEventCallback, + EventHandler, (data: RawData) => void >(); constructor(init: OneWayWebSocketInit) { const { location, apiRoute, protocols, options, searchParams } = init; - const formattedParams = - searchParams instanceof URLSearchParams - ? searchParams - : new URLSearchParams(searchParams); - const paramsString = formattedParams.toString(); - const paramsSuffix = paramsString ? `?${paramsString}` : ""; + const paramsSuffix = getQueryString(searchParams); const wsProtocol = location.protocol === "https:" ? "wss:" : "ws:"; const url = `${wsProtocol}//${location.host}${apiRoute}${paramsSuffix}`; @@ -92,10 +49,10 @@ export class OneWayWebSocket addEventListener( event: TEvent, - callback: OneWayEventCallback, + callback: EventHandler, ): void { if (event === "message") { - const messageCallback = callback as OneWayEventCallback; + const messageCallback = callback as EventHandler; if (this.#messageCallbacks.has(messageCallback)) { return; @@ -128,10 +85,10 @@ export class OneWayWebSocket removeEventListener( event: TEvent, - callback: OneWayEventCallback, + callback: EventHandler, ): void { if (event === "message") { - const messageCallback = callback as OneWayEventCallback; + const messageCallback = callback as EventHandler; const wrapper = this.#messageCallbacks.get(messageCallback); if (wrapper) { diff --git a/src/websocket/sseConnection.ts b/src/websocket/sseConnection.ts new file mode 100644 index 00000000..834100aa --- /dev/null +++ b/src/websocket/sseConnection.ts @@ -0,0 +1,221 @@ +import { type AxiosInstance } from "axios"; +import { type ServerSentEvent } from "coder/site/src/api/typesGenerated"; +import { type WebSocketEventType } from "coder/site/src/utils/OneWayWebSocket"; +import { EventSource } from "eventsource"; + +import { createStreamingFetchAdapter } from "../api/streamingFetchAdapter"; +import { type Logger } from "../logging/logger"; + +import { getQueryString } from "./utils"; + +import type { + CloseEvent as WsCloseEvent, + ErrorEvent as WsErrorEvent, + Event as WsEvent, + MessageEvent as WsMessageEvent, +} from "ws"; + +import type { + UnidirectionalStream, + ParsedMessageEvent, + EventHandler, +} from "./eventStreamConnection"; + +export type SseConnectionInit = { + location: { protocol: string; host: string }; + apiRoute: string; + searchParams?: Record | URLSearchParams; + optionsHeaders?: Record; + axiosInstance: AxiosInstance; + logger: Logger; +}; + +export class SseConnection implements UnidirectionalStream { + private readonly eventSource: EventSource; + private readonly logger: Logger; + private readonly callbacks = { + open: new Set>(), + close: new Set>(), + error: new Set>(), + }; + // Original callback -> wrapped callback + private readonly messageWrappers = new Map< + EventHandler, + (event: MessageEvent) => void + >(); + + public readonly url: string; + + public constructor(init: SseConnectionInit) { + this.logger = init.logger; + this.url = this.buildUrl(init); + this.eventSource = new EventSource(this.url, { + fetch: createStreamingFetchAdapter( + init.axiosInstance, + init.optionsHeaders, + ), + }); + this.setupEventHandlers(); + } + + private buildUrl(init: SseConnectionInit): string { + const { location, apiRoute, searchParams } = init; + const queryString = getQueryString(searchParams); + return `${location.protocol}//${location.host}${apiRoute}${queryString}`; + } + + private setupEventHandlers(): void { + this.eventSource.addEventListener("open", () => + this.invokeCallbacks(this.callbacks.open, {} as WsEvent, "open"), + ); + + this.eventSource.addEventListener("data", (event: MessageEvent) => { + this.invokeCallbacks(this.messageWrappers.values(), event, "message"); + }); + + this.eventSource.addEventListener("error", (error: Event | ErrorEvent) => { + this.invokeCallbacks( + this.callbacks.error, + this.createErrorEvent(error), + "error", + ); + + if (this.eventSource.readyState === EventSource.CLOSED) { + this.invokeCallbacks( + this.callbacks.close, + { + code: 1006, + reason: "Connection lost", + wasClean: false, + } as WsCloseEvent, + "close", + ); + } + }); + } + + private invokeCallbacks( + callbacks: Iterable<(event: T) => void>, + event: T, + eventType: string, + ): void { + for (const cb of callbacks) { + try { + cb(event); + } catch (err) { + this.logger.error(`Error in SSE ${eventType} callback:`, err); + } + } + } + + private createErrorEvent(event: Event | ErrorEvent): WsErrorEvent { + const errorMessage = + event instanceof ErrorEvent && event.message + ? event.message + : "SSE connection error"; + const error = event instanceof ErrorEvent ? event.error : undefined; + + return { + error: error, + message: errorMessage, + } as WsErrorEvent; + } + + public addEventListener( + event: TEvent, + callback: EventHandler, + ): void { + switch (event) { + case "close": + this.callbacks.close.add( + callback as EventHandler, + ); + break; + case "error": + this.callbacks.error.add( + callback as EventHandler, + ); + break; + case "message": { + const messageCallback = callback as EventHandler< + ServerSentEvent, + "message" + >; + if (!this.messageWrappers.has(messageCallback)) { + this.messageWrappers.set(messageCallback, (event: MessageEvent) => { + messageCallback(this.parseMessage(event)); + }); + } + break; + } + case "open": + this.callbacks.open.add( + callback as EventHandler, + ); + break; + } + } + + private parseMessage( + event: MessageEvent, + ): ParsedMessageEvent { + const wsEvent = { data: event.data } as WsMessageEvent; + try { + return { + sourceEvent: wsEvent, + parsedMessage: { type: "data", data: JSON.parse(event.data) }, + parseError: undefined, + }; + } catch (err) { + return { + sourceEvent: wsEvent, + parsedMessage: undefined, + parseError: err as Error, + }; + } + } + + public removeEventListener( + event: TEvent, + callback: EventHandler, + ): void { + switch (event) { + case "close": + this.callbacks.close.delete( + callback as EventHandler, + ); + break; + case "error": + this.callbacks.error.delete( + callback as EventHandler, + ); + break; + case "message": + this.messageWrappers.delete( + callback as EventHandler, + ); + break; + case "open": + this.callbacks.open.delete( + callback as EventHandler, + ); + break; + } + } + + public close(code?: number, reason?: string): void { + this.eventSource.close(); + this.invokeCallbacks( + this.callbacks.close, + { + code: code ?? 1000, + reason: reason ?? "Normal closure", + wasClean: true, + } as WsCloseEvent, + "close", + ); + + Object.values(this.callbacks).forEach((callbackSet) => callbackSet.clear()); + this.messageWrappers.clear(); + } +} diff --git a/src/websocket/utils.ts b/src/websocket/utils.ts new file mode 100644 index 00000000..592ce45e --- /dev/null +++ b/src/websocket/utils.ts @@ -0,0 +1,15 @@ +/** + * Converts params to a query string. Returns empty string if no params, + * otherwise returns params prefixed with '?'. + */ +export function getQueryString( + params: Record | URLSearchParams | undefined, +): string { + if (!params) { + return ""; + } + const searchParams = + params instanceof URLSearchParams ? params : new URLSearchParams(params); + const str = searchParams.toString(); + return str ? `?${str}` : ""; +} diff --git a/src/workspace/workspaceMonitor.ts b/src/workspace/workspaceMonitor.ts index a761249a..ceea8a91 100644 --- a/src/workspace/workspaceMonitor.ts +++ b/src/workspace/workspaceMonitor.ts @@ -9,7 +9,7 @@ import { createWorkspaceIdentifier, errToStr } from "../api/api-helper"; import { type CoderApi } from "../api/coderApi"; import { type ContextManager } from "../core/contextManager"; import { type Logger } from "../logging/logger"; -import { type OneWayWebSocket } from "../websocket/oneWayWebSocket"; +import { type UnidirectionalStream } from "../websocket/eventStreamConnection"; /** * Monitor a single workspace using a WebSocket for events like shutdown and deletion. @@ -17,7 +17,7 @@ import { type OneWayWebSocket } from "../websocket/oneWayWebSocket"; * workspace status is also shown in the status bar menu. */ export class WorkspaceMonitor implements vscode.Disposable { - private socket: OneWayWebSocket | undefined; + private socket: UnidirectionalStream | undefined; private disposed = false; // How soon in advance to notify about autostop and deletion. @@ -93,10 +93,12 @@ export class WorkspaceMonitor implements vscode.Disposable { return; } // Perhaps we need to parse this and validate it. - const newWorkspaceData = event.parsedMessage.data as Workspace; - monitor.update(newWorkspaceData); - monitor.maybeNotify(newWorkspaceData); - monitor.onChange.fire(newWorkspaceData); + const newWorkspaceData = event.parsedMessage.data as Workspace | null; + if (newWorkspaceData) { + monitor.update(newWorkspaceData); + monitor.maybeNotify(newWorkspaceData); + monitor.onChange.fire(newWorkspaceData); + } } catch (error) { monitor.notifyError(error); } diff --git a/test/unit/api/coderApi.test.ts b/test/unit/api/coderApi.test.ts new file mode 100644 index 00000000..0336d564 --- /dev/null +++ b/test/unit/api/coderApi.test.ts @@ -0,0 +1,431 @@ +import axios, { AxiosError, AxiosHeaders } from "axios"; +import { type ProvisionerJobLog } from "coder/site/src/api/typesGenerated"; +import { EventSource } from "eventsource"; +import { ProxyAgent } from "proxy-agent"; +import { describe, it, expect, vi, beforeEach } from "vitest"; +import Ws from "ws"; + +import { CoderApi } from "@/api/coderApi"; +import { createHttpAgent } from "@/api/utils"; +import { CertificateError } from "@/error"; +import { getHeaders } from "@/headers"; +import { type RequestConfigWithMeta } from "@/logging/types"; +import { OneWayWebSocket } from "@/websocket/oneWayWebSocket"; +import { SseConnection } from "@/websocket/sseConnection"; + +import { + createMockLogger, + MockConfigurationProvider, +} from "../../mocks/testHelpers"; + +const CODER_URL = "https://coder.example.com"; +const AXIOS_TOKEN = "passed-token"; +const BUILD_ID = "build-123"; +const AGENT_ID = "agent-123"; + +vi.mock("ws"); +vi.mock("eventsource"); +vi.mock("proxy-agent"); + +vi.mock("axios", async () => { + const actual = await vi.importActual("axios"); + + const mockAdapter = vi.fn(mockAdapterImpl); + + const mockDefault = { + ...actual.default, + create: vi.fn((config) => { + const instance = actual.default.create({ + ...config, + adapter: mockAdapter, + }); + return instance; + }), + __mockAdapter: mockAdapter, + }; + + return { + ...actual, + default: mockDefault, + }; +}); + +vi.mock("@/headers", () => ({ + getHeaders: vi.fn().mockResolvedValue({}), + getHeaderCommand: vi.fn(), +})); + +vi.mock("@/api/utils", () => ({ + createHttpAgent: vi.fn(), +})); + +vi.mock("@/api/streamingFetchAdapter", () => ({ + createStreamingFetchAdapter: vi.fn(() => fetch), +})); + +describe("CoderApi", () => { + let mockLogger: ReturnType; + let mockConfig: MockConfigurationProvider; + let mockAdapter: ReturnType; + let api: CoderApi; + + const createApi = (url = CODER_URL, token = AXIOS_TOKEN) => { + return CoderApi.create(url, token, mockLogger); + }; + + beforeEach(() => { + vi.resetAllMocks(); + + const axiosMock = axios as typeof axios & { + __mockAdapter: ReturnType; + }; + mockAdapter = axiosMock.__mockAdapter; + mockAdapter.mockImplementation(mockAdapterImpl); + + vi.mocked(getHeaders).mockResolvedValue({}); + mockLogger = createMockLogger(); + mockConfig = new MockConfigurationProvider(); + mockConfig.set("coder.httpClientLogLevel", "BASIC"); + }); + + describe("HTTP Interceptors", () => { + it("adds custom headers and HTTP agent to requests", async () => { + const mockAgent = new ProxyAgent(); + vi.mocked(createHttpAgent).mockResolvedValue(mockAgent); + vi.mocked(getHeaders).mockResolvedValue({ + "X-Custom-Header": "custom-value", + "X-Another-Header": "another-value", + }); + + const api = createApi(); + const response = await api.getAxiosInstance().get("/api/v2/users/me"); + + expect(response.config.headers["X-Custom-Header"]).toBe("custom-value"); + expect(response.config.headers["X-Another-Header"]).toBe("another-value"); + expect(response.config.httpsAgent).toBe(mockAgent); + expect(response.config.httpAgent).toBe(mockAgent); + expect(response.config.proxy).toBe(false); + }); + + it("wraps certificate errors in response interceptor", async () => { + const api = createApi(); + const certError = new AxiosError( + "self signed certificate", + "DEPTH_ZERO_SELF_SIGNED_CERT", + ); + mockAdapter.mockRejectedValueOnce(certError); + + const thrownError = await api + .getAxiosInstance() + .get("/api/v2/users/me") + .catch((e) => e); + + expect(thrownError).toBeInstanceOf(CertificateError); + expect(thrownError.message).toContain("Secure connection"); + expect(thrownError.x509Err).toBeDefined(); + }); + + it("applies headers in correct precedence order (command > config > axios default)", async () => { + const api = createApi(CODER_URL, AXIOS_TOKEN); + + // Test 1: Headers from config, default token from API creation + const response = await api.getAxiosInstance().get("/api/v2/users/me", { + headers: new AxiosHeaders({ + "X-Custom-Header": "from-config", + "X-Extra": "extra-value", + }), + }); + + expect(response.config.headers["X-Custom-Header"]).toBe("from-config"); + expect(response.config.headers["X-Extra"]).toBe("extra-value"); + expect(response.config.headers["Coder-Session-Token"]).toBe(AXIOS_TOKEN); + + // Test 2: Token from request options overrides default + const responseWithToken = await api + .getAxiosInstance() + .get("/api/v2/users/me", { + headers: new AxiosHeaders({ + "Coder-Session-Token": "from-options", + }), + }); + + expect(responseWithToken.config.headers["Coder-Session-Token"]).toBe( + "from-options", + ); + + // Test 3: Header command overrides everything + vi.mocked(getHeaders).mockResolvedValue({ + "Coder-Session-Token": "from-header-command", + }); + + const responseWithHeaderCommand = await api + .getAxiosInstance() + .get("/api/v2/users/me", { + headers: new AxiosHeaders({ + "Coder-Session-Token": "from-options", + }), + }); + + expect( + responseWithHeaderCommand.config.headers["Coder-Session-Token"], + ).toBe("from-header-command"); + }); + + it("logs requests and responses", async () => { + const api = createApi(); + + await api.getWorkspaces({}); + + expect(mockLogger.trace).toHaveBeenCalledWith( + expect.stringContaining("/api/v2/workspaces"), + ); + }); + + it("calculates request and response sizes in transforms", async () => { + const api = createApi(); + const response = await api + .getAxiosInstance() + .post("/api/v2/workspaces", { name: "test" }); + + expect((response.config as RequestConfigWithMeta).rawRequestSize).toBe( + 15, + ); + // We return the same data we sent in the mock adapter + expect((response.config as RequestConfigWithMeta).rawResponseSize).toBe( + 15, + ); + }); + }); + + describe("WebSocket Creation", () => { + const wsUrl = `wss://${CODER_URL.replace("https://", "")}/api/v2/workspacebuilds/${BUILD_ID}/logs?follow=true`; + + beforeEach(() => { + api = createApi(CODER_URL, AXIOS_TOKEN); + const mockWs = createMockWebSocket(wsUrl); + setupWebSocketMock(mockWs); + }); + + it("creates WebSocket with proper headers and configuration", async () => { + const mockAgent = new ProxyAgent(); + vi.mocked(getHeaders).mockResolvedValue({ + "X-Custom-Header": "custom-value", + }); + vi.mocked(createHttpAgent).mockResolvedValue(mockAgent); + + await api.watchBuildLogsByBuildId(BUILD_ID, []); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: mockAgent, + followRedirects: true, + headers: { + "X-Custom-Header": "custom-value", + "Coder-Session-Token": AXIOS_TOKEN, + }, + }); + }); + + it("applies headers in correct precedence order (command > config > axios default)", async () => { + // Test 1: Default token from API creation + await api.watchBuildLogsByBuildId(BUILD_ID, []); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: undefined, + followRedirects: true, + headers: { + "Coder-Session-Token": AXIOS_TOKEN, + }, + }); + + // Test 2: Token from config options overrides default + await api.watchBuildLogsByBuildId(BUILD_ID, [], { + headers: { + "X-Config-Header": "config-value", + "Coder-Session-Token": "from-config", + }, + }); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: undefined, + followRedirects: true, + headers: { + "Coder-Session-Token": "from-config", + "X-Config-Header": "config-value", + }, + }); + + // Test 3: Header command overrides everything + vi.mocked(getHeaders).mockResolvedValue({ + "Coder-Session-Token": "from-header-command", + }); + + await api.watchBuildLogsByBuildId(BUILD_ID, [], { + headers: { + "Coder-Session-Token": "from-config", + }, + }); + + expect(Ws).toHaveBeenCalledWith(wsUrl, undefined, { + agent: undefined, + followRedirects: true, + headers: { + "Coder-Session-Token": "from-header-command", + }, + }); + }); + + it("logs WebSocket connections", async () => { + await api.watchBuildLogsByBuildId(BUILD_ID, []); + + expect(mockLogger.trace).toHaveBeenCalledWith( + expect.stringContaining(BUILD_ID), + ); + }); + + it("'watchBuildLogsByBuildId' includes after parameter for existing logs", async () => { + const jobLog: ProvisionerJobLog = { + created_at: new Date().toISOString(), + id: 1, + output: "log1", + log_source: "provisioner", + log_level: "info", + stage: "stage1", + }; + const existingLogs = [ + jobLog, + { ...jobLog, id: 20 }, + { ...jobLog, id: 5 }, + ]; + + await api.watchBuildLogsByBuildId(BUILD_ID, existingLogs); + + expect(Ws).toHaveBeenCalledWith( + expect.stringContaining("after=5"), + undefined, + expect.any(Object), + ); + }); + }); + + describe("SSE Fallback", () => { + beforeEach(() => { + api = createApi(); + const mockEventSource = createMockEventSource( + `${CODER_URL}/api/v2/workspaces/123/watch`, + ); + setupEventSourceMock(mockEventSource); + }); + + it("uses WebSocket when no errors occur", async () => { + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/workspaceagents/${AGENT_ID}/watch-metadata`, + { + on: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler()); + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); + + const connection = await api.watchAgentMetadata(AGENT_ID); + + expect(connection).toBeInstanceOf(OneWayWebSocket); + expect(EventSource).not.toHaveBeenCalled(); + }); + + it("falls back to SSE when WebSocket creation fails", async () => { + vi.mocked(Ws).mockImplementation(() => { + throw new Error("WebSocket creation failed"); + }); + + const connection = await api.watchAgentMetadata(AGENT_ID); + + expect(connection).toBeInstanceOf(SseConnection); + expect(EventSource).toHaveBeenCalled(); + }); + + it("falls back to SSE on 404 error from WebSocket", async () => { + const mockWs = createMockWebSocket( + `wss://${CODER_URL.replace("https://", "")}/api/v2/test`, + { + on: vi.fn((event: string, handler: (e: unknown) => void) => { + if (event === "error") { + setImmediate(() => { + handler({ + error: new Error("404 Not Found"), + message: "404 Not Found", + }); + }); + } + return mockWs as Ws; + }), + }, + ); + setupWebSocketMock(mockWs); + + const connection = await api.watchAgentMetadata(AGENT_ID); + + expect(connection).toBeInstanceOf(SseConnection); + expect(EventSource).toHaveBeenCalled(); + }); + }); + + describe("Error Handling", () => { + it("throws error when no base URL is set", async () => { + const api = createApi(); + api.getAxiosInstance().defaults.baseURL = undefined; + + await expect(api.watchBuildLogsByBuildId(BUILD_ID, [])).rejects.toThrow( + "No base URL set on REST client", + ); + }); + }); +}); + +const mockAdapterImpl = vi.hoisted(() => (config: Record) => { + return Promise.resolve({ + data: config.data || "{}", + status: 200, + statusText: "OK", + headers: {}, + config, + }); +}); + +function createMockWebSocket( + url: string, + overrides?: Partial, +): Partial { + return { + url, + on: vi.fn(), + off: vi.fn(), + close: vi.fn(), + ...overrides, + }; +} + +function createMockEventSource(url: string): Partial { + return { + url, + readyState: EventSource.CONNECTING, + addEventListener: vi.fn((event, handler) => { + if (event === "open") { + setImmediate(() => handler(new Event("open"))); + } + }), + removeEventListener: vi.fn(), + close: vi.fn(), + }; +} + +function setupWebSocketMock(ws: Partial): void { + vi.mocked(Ws).mockImplementation(() => ws as Ws); +} + +function setupEventSourceMock(es: Partial): void { + vi.mocked(EventSource).mockImplementation(() => es as EventSource); +} diff --git a/test/unit/logging/wsLogger.test.ts b/test/unit/logging/eventStreamLogger.test.ts similarity index 50% rename from test/unit/logging/wsLogger.test.ts rename to test/unit/logging/eventStreamLogger.test.ts index 5bf9d5b1..352ccaac 100644 --- a/test/unit/logging/wsLogger.test.ts +++ b/test/unit/logging/eventStreamLogger.test.ts @@ -1,19 +1,23 @@ import { describe, expect, it } from "vitest"; -import { WsLogger } from "@/logging/wsLogger"; +import { EventStreamLogger } from "@/logging/eventStreamLogger"; import { createMockLogger } from "../../mocks/testHelpers"; -describe("WS Logger", () => { +describe("EventStreamLogger", () => { it("tracks message count and byte size", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); - wsLogger.logOpen(); - wsLogger.logMessage("hello"); - wsLogger.logMessage("world"); - wsLogger.logMessage(Buffer.from("test")); - wsLogger.logClose(); + eventStreamLogger.logOpen(); + eventStreamLogger.logMessage("hello"); + eventStreamLogger.logMessage("world"); + eventStreamLogger.logMessage(Buffer.from("test")); + eventStreamLogger.logClose(); expect(logger.trace).toHaveBeenCalledWith( expect.stringContaining("3 msgs"), @@ -23,12 +27,16 @@ describe("WS Logger", () => { it("handles unknown byte sizes with >= indicator", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); - wsLogger.logOpen(); - wsLogger.logMessage({ complex: "object" }); // Unknown size - no estimation - wsLogger.logMessage("known"); - wsLogger.logClose(); + eventStreamLogger.logOpen(); + eventStreamLogger.logMessage({ complex: "object" }); // Unknown size - no estimation + eventStreamLogger.logMessage("known"); + eventStreamLogger.logClose(); expect(logger.trace).toHaveBeenLastCalledWith( expect.stringContaining(">= 5 B"), @@ -37,22 +45,30 @@ describe("WS Logger", () => { it("handles close before open gracefully", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); // Closing without opening should not throw - expect(() => wsLogger.logClose()).not.toThrow(); + expect(() => eventStreamLogger.logClose()).not.toThrow(); expect(logger.trace).toHaveBeenCalled(); }); it("formats large message counts with compact notation", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); - wsLogger.logOpen(); + eventStreamLogger.logOpen(); for (let i = 0; i < 1100; i++) { - wsLogger.logMessage("x"); + eventStreamLogger.logMessage("x"); } - wsLogger.logClose(); + eventStreamLogger.logClose(); expect(logger.trace).toHaveBeenLastCalledWith( expect.stringMatching(/1[.,]1K\s*msgs/), @@ -61,10 +77,14 @@ describe("WS Logger", () => { it("logs errors with error object", () => { const logger = createMockLogger(); - const wsLogger = new WsLogger(logger, "wss://example.com"); + const eventStreamLogger = new EventStreamLogger( + logger, + "wss://example.com", + "WS", + ); const error = new Error("Connection failed"); - wsLogger.logError(error, "Failed to connect"); + eventStreamLogger.logError(error, "Failed to connect"); expect(logger.error).toHaveBeenCalledWith(expect.any(String), error); });