Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 131 additions & 27 deletions src/api/coderApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ import {
} from "axios";
import { Api } from "coder/site/src/api/api";
import {
type ServerSentEvent,
type GetInboxNotificationResponse,
type ProvisionerJobLog,
type ServerSentEvent,
type Workspace,
type WorkspaceAgent,
} from "coder/site/src/api/typesGenerated";
import * as vscode from "vscode";
import { type ClientOptions } from "ws";
import { type ClientOptions, type CloseEvent, type ErrorEvent } from "ws";

import { CertificateError } from "../error";
import { getHeaderCommand, getHeaders } from "../headers";
import { EventStreamLogger } from "../logging/eventStreamLogger";
import {
createRequestMeta,
logRequest,
Expand All @@ -29,11 +30,12 @@ import {
HttpClientLogLevel,
} from "../logging/types";
import { sizeOf } from "../logging/utils";
import { WsLogger } from "../logging/wsLogger";
import { type UnidirectionalStream } from "../websocket/eventStreamConnection";
import {
OneWayWebSocket,
type OneWayWebSocketInit,
} from "../websocket/oneWayWebSocket";
import { SseConnection } from "../websocket/sseConnection";

import { createHttpAgent } from "./utils";

Expand Down Expand Up @@ -84,8 +86,9 @@ export class CoderApi extends Api {
};

watchWorkspace = async (workspace: Workspace, options?: ClientOptions) => {
return this.createWebSocket<ServerSentEvent>({
return this.createWebSocketWithFallback<ServerSentEvent>({
apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`,
fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`,
options,
});
};
Expand All @@ -94,15 +97,17 @@ export class CoderApi extends Api {
agentId: WorkspaceAgent["id"],
options?: ClientOptions,
) => {
return this.createWebSocket<ServerSentEvent>({
return this.createWebSocketWithFallback<ServerSentEvent>({
apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`,
fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`,
options,
});
};

watchBuildLogsByBuildId = async (
buildId: string,
logs: ProvisionerJobLog[],
options?: ClientOptions,
) => {
const searchParams = new URLSearchParams({ follow: "true" });
if (logs.length) {
Expand All @@ -112,6 +117,7 @@ export class CoderApi extends Api {
return this.createWebSocket<ProvisionerJobLog>({
apiRoute: `/api/v2/workspacebuilds/${buildId}/logs`,
searchParams,
options,
});
};

Expand All @@ -128,7 +134,7 @@ export class CoderApi extends Api {
coderSessionTokenHeader
] as string | undefined;

const headers = await getHeaders(
const headersFromCommand = await getHeaders(
baseUrlRaw,
getHeaderCommand(vscode.workspace.getConfiguration()),
this.output,
Expand All @@ -137,43 +143,141 @@ export class CoderApi extends Api {
const httpAgent = await createHttpAgent(
vscode.workspace.getConfiguration(),
);

const headers = {
...(token ? { [coderSessionTokenHeader]: token } : {}),
...configs.options?.headers,
...headersFromCommand,
};

const webSocket = new OneWayWebSocket<TData>({
location: baseUrl,
...configs,
options: {
...configs.options,
agent: httpAgent,
followRedirects: true,
headers: {
...(token ? { [coderSessionTokenHeader]: token } : {}),
...configs.options?.headers,
...headers,
},
...configs.options,
headers,
},
});

const wsUrl = new URL(webSocket.url);
const pathWithQuery = wsUrl.pathname + wsUrl.search;
const wsLogger = new WsLogger(this.output, pathWithQuery);
wsLogger.logConnecting();
this.attachStreamLogger(webSocket);
return webSocket;
}

webSocket.addEventListener("open", () => {
wsLogger.logOpen();
});
private attachStreamLogger<TData>(
connection: UnidirectionalStream<TData>,
): void {
const url = new URL(connection.url);
const logger = new EventStreamLogger(
this.output,
url.pathname + url.search,
url.protocol.startsWith("http") ? "SSE" : "WS",
);
logger.logConnecting();

webSocket.addEventListener("message", (event) => {
wsLogger.logMessage(event.sourceEvent.data);
});
connection.addEventListener("open", () => logger.logOpen());
connection.addEventListener("close", (event: CloseEvent) =>
logger.logClose(event.code, event.reason),
);
connection.addEventListener("error", (event: ErrorEvent) =>
logger.logError(event.error, event.message),
);
connection.addEventListener("message", (event) =>
logger.logMessage(event.sourceEvent.data),
);
}

webSocket.addEventListener("close", (event) => {
wsLogger.logClose(event.code, event.reason);
/**
* Create a WebSocket connection with SSE fallback on 404
*/
private async createWebSocketWithFallback<TData = unknown>(configs: {
apiRoute: string;
fallbackApiRoute: string;
searchParams?: Record<string, string> | URLSearchParams;
options?: ClientOptions;
}): Promise<UnidirectionalStream<TData>> {
let webSocket: OneWayWebSocket<TData>;
try {
webSocket = await this.createWebSocket<TData>({
apiRoute: configs.apiRoute,
searchParams: configs.searchParams,
options: configs.options,
});
} catch {
// Failed to create WebSocket, use SSE fallback
return this.createSseFallback<TData>(
configs.fallbackApiRoute,
configs.searchParams,
);
}

return this.waitForConnection(webSocket, () =>
this.createSseFallback<TData>(
configs.fallbackApiRoute,
configs.searchParams,
),
);
}

private waitForConnection<TData>(
connection: UnidirectionalStream<TData>,
onNotFound?: () => Promise<UnidirectionalStream<TData>>,
): Promise<UnidirectionalStream<TData>> {
return new Promise((resolve, reject) => {
const cleanup = () => {
connection.removeEventListener("open", handleOpen);
connection.removeEventListener("error", handleError);
};

const handleOpen = () => {
cleanup();
resolve(connection);
};

const handleError = (event: ErrorEvent) => {
cleanup();
const is404 =
event.message?.includes("404") ||
event.error?.message?.includes("404");

if (is404 && onNotFound) {
connection.close();
onNotFound().then(resolve).catch(reject);
} else {
reject(event.error || new Error(event.message));
}
};

connection.addEventListener("open", handleOpen);
connection.addEventListener("error", handleError);
});
}

/**
* Create SSE fallback connection
*/
private async createSseFallback<TData = unknown>(
apiRoute: string,
searchParams?: Record<string, string> | URLSearchParams,
): Promise<UnidirectionalStream<TData>> {
this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`);

const baseUrlRaw = this.getAxiosInstance().defaults.baseURL;
if (!baseUrlRaw) {
throw new Error("No base URL set on REST client");
}

webSocket.addEventListener("error", (event) => {
wsLogger.logError(event.error, event.message);
const baseUrl = new URL(baseUrlRaw);
const sseConnection = new SseConnection({
location: baseUrl,
apiRoute,
searchParams,
axiosInstance: this.getAxiosInstance(),
});

return webSocket;
this.attachStreamLogger(sseConnection);
return this.waitForConnection(sseConnection);
}
}

Expand Down
62 changes: 62 additions & 0 deletions src/api/streamingFetchAdapter.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import { type AxiosInstance } from "axios";
import { type FetchLikeInit, type FetchLikeResponse } from "eventsource";
import { type IncomingMessage } from "http";

/**
* Creates a fetch adapter using an Axios instance that returns streaming responses.
* This is used by EventSource to make authenticated SSE connections.
*/
export function createStreamingFetchAdapter(
axiosInstance: AxiosInstance,
): (url: string | URL, init?: FetchLikeInit) => Promise<FetchLikeResponse> {
return async (
url: string | URL,
init?: FetchLikeInit,
): Promise<FetchLikeResponse> => {
const urlStr = url.toString();

const response = await axiosInstance.request<IncomingMessage>({
url: urlStr,
signal: init?.signal,
headers: init?.headers,
responseType: "stream",
validateStatus: () => true, // Don't throw on any status code
});

const stream = new ReadableStream({
start(controller) {
response.data.on("data", (chunk: Buffer) => {
controller.enqueue(chunk);
});

response.data.on("end", () => {
controller.close();
});

response.data.on("error", (err: Error) => {
controller.error(err);
});
},

cancel() {
response.data.destroy();
return Promise.resolve();
},
});

return {
body: {
getReader: () => stream.getReader(),
},
url: urlStr,
status: response.status,
redirected: response.request?.res?.responseUrl !== urlStr,
headers: {
get: (name: string) => {
const value = response.headers[name.toLowerCase()];
return value === undefined ? null : String(value);
},
},
};
};
}
16 changes: 10 additions & 6 deletions src/logging/wsLogger.ts → src/logging/eventStreamLogger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,35 @@ const numFormatter = new Intl.NumberFormat("en", {
compactDisplay: "short",
});

export class WsLogger {
export class EventStreamLogger {
private readonly logger: Logger;
private readonly url: string;
private readonly id: string;
private readonly protocol: string;
private readonly startedAt: number;
private openedAt?: number;
private msgCount = 0;
private byteCount = 0;
private unknownByteCount = false;

constructor(logger: Logger, url: string) {
constructor(logger: Logger, url: string, protocol: "WS" | "SSE") {
this.logger = logger;
this.url = url;
this.protocol = protocol;
this.id = createRequestId();
this.startedAt = Date.now();
}

logConnecting(): void {
this.logger.trace(`→ WS ${shortId(this.id)} ${this.url}`);
this.logger.trace(`→ ${this.protocol} ${shortId(this.id)} ${this.url}`);
}

logOpen(): void {
this.openedAt = Date.now();
const time = formatTime(this.openedAt - this.startedAt);
this.logger.trace(`← WS ${shortId(this.id)} connected ${this.url} ${time}`);
this.logger.trace(
`← ${this.protocol} ${shortId(this.id)} connected ${this.url} ${time}`,
);
}

logMessage(data: unknown): void {
Expand All @@ -62,15 +66,15 @@ export class WsLogger {
const statsStr = ` [${stats.join(", ")}]`;

this.logger.trace(
`▣ WS ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`,
`▣ ${this.protocol} ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`,
);
}

logError(error: unknown, message: string): void {
const time = formatTime(Date.now() - this.startedAt);
const errorMsg = message || errToStr(error, "connection error");
this.logger.error(
`✗ WS ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`,
`✗ ${this.protocol} ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`,
error,
);
}
Expand Down
Loading