Skip to content

Commit f9b1f25

Browse files
authored
Add SSE fallback to some one way WS connections (#623)
Add SSE fallback to some WS connections: * `/api/v2/workspaces/${workspace.id}/watch-ws` -> `/api/v2/workspaceagents/${agentId}/watch-metadata` * `/api/v2/workspaceagents/${agentId}/watch-metadata-ws` -> `/api/v2/workspaceagents/${agentId}/watch-metadata` Restored the previous code regarding `createStreamingFetchAdapter` to stream in SSE events. * Implemented a unified interface for WS and SSE to be similar to the `OneWayWebSocket`. * Added unified logging for WS and SSE. * Fixed issue with headers order precedence * Add tests for `CoderApi` Closes #620
1 parent 5165ade commit f9b1f25

File tree

10 files changed

+1005
-116
lines changed

10 files changed

+1005
-116
lines changed

src/api/coderApi.ts

Lines changed: 144 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@ import {
66
} from "axios";
77
import { Api } from "coder/site/src/api/api";
88
import {
9+
type ServerSentEvent,
910
type GetInboxNotificationResponse,
1011
type ProvisionerJobLog,
11-
type ServerSentEvent,
1212
type Workspace,
1313
type WorkspaceAgent,
1414
} from "coder/site/src/api/typesGenerated";
1515
import * as vscode from "vscode";
16-
import { type ClientOptions } from "ws";
16+
import { type ClientOptions, type CloseEvent, type ErrorEvent } from "ws";
1717

1818
import { CertificateError } from "../error";
1919
import { getHeaderCommand, getHeaders } from "../headers";
20+
import { EventStreamLogger } from "../logging/eventStreamLogger";
2021
import {
2122
createRequestMeta,
2223
logRequest,
@@ -29,11 +30,12 @@ import {
2930
HttpClientLogLevel,
3031
} from "../logging/types";
3132
import { sizeOf } from "../logging/utils";
32-
import { WsLogger } from "../logging/wsLogger";
33+
import { type UnidirectionalStream } from "../websocket/eventStreamConnection";
3334
import {
3435
OneWayWebSocket,
3536
type OneWayWebSocketInit,
3637
} from "../websocket/oneWayWebSocket";
38+
import { SseConnection } from "../websocket/sseConnection";
3739

3840
import { createHttpAgent } from "./utils";
3941

@@ -84,8 +86,9 @@ export class CoderApi extends Api {
8486
};
8587

8688
watchWorkspace = async (workspace: Workspace, options?: ClientOptions) => {
87-
return this.createWebSocket<ServerSentEvent>({
89+
return this.createWebSocketWithFallback<ServerSentEvent>({
8890
apiRoute: `/api/v2/workspaces/${workspace.id}/watch-ws`,
91+
fallbackApiRoute: `/api/v2/workspaces/${workspace.id}/watch`,
8992
options,
9093
});
9194
};
@@ -94,15 +97,17 @@ export class CoderApi extends Api {
9497
agentId: WorkspaceAgent["id"],
9598
options?: ClientOptions,
9699
) => {
97-
return this.createWebSocket<ServerSentEvent>({
100+
return this.createWebSocketWithFallback<ServerSentEvent>({
98101
apiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata-ws`,
102+
fallbackApiRoute: `/api/v2/workspaceagents/${agentId}/watch-metadata`,
99103
options,
100104
});
101105
};
102106

103107
watchBuildLogsByBuildId = async (
104108
buildId: string,
105109
logs: ProvisionerJobLog[],
110+
options?: ClientOptions,
106111
) => {
107112
const searchParams = new URLSearchParams({ follow: "true" });
108113
if (logs.length) {
@@ -112,6 +117,7 @@ export class CoderApi extends Api {
112117
return this.createWebSocket<ProvisionerJobLog>({
113118
apiRoute: `/api/v2/workspacebuilds/${buildId}/logs`,
114119
searchParams,
120+
options,
115121
});
116122
};
117123

@@ -128,7 +134,7 @@ export class CoderApi extends Api {
128134
coderSessionTokenHeader
129135
] as string | undefined;
130136

131-
const headers = await getHeaders(
137+
const headersFromCommand = await getHeaders(
132138
baseUrlRaw,
133139
getHeaderCommand(vscode.workspace.getConfiguration()),
134140
this.output,
@@ -137,43 +143,154 @@ export class CoderApi extends Api {
137143
const httpAgent = await createHttpAgent(
138144
vscode.workspace.getConfiguration(),
139145
);
146+
147+
/**
148+
* Similar to the REST client, we want to prioritize headers in this order (highest to lowest):
149+
* 1. Headers from the header command
150+
* 2. Any headers passed directly to this function
151+
* 3. Coder session token from the Api client (if set)
152+
*/
153+
const headers = {
154+
...(token ? { [coderSessionTokenHeader]: token } : {}),
155+
...configs.options?.headers,
156+
...headersFromCommand,
157+
};
158+
140159
const webSocket = new OneWayWebSocket<TData>({
141160
location: baseUrl,
142161
...configs,
143162
options: {
163+
...configs.options,
144164
agent: httpAgent,
145165
followRedirects: true,
146-
headers: {
147-
...(token ? { [coderSessionTokenHeader]: token } : {}),
148-
...configs.options?.headers,
149-
...headers,
150-
},
151-
...configs.options,
166+
headers,
152167
},
153168
});
154169

155-
const wsUrl = new URL(webSocket.url);
156-
const pathWithQuery = wsUrl.pathname + wsUrl.search;
157-
const wsLogger = new WsLogger(this.output, pathWithQuery);
158-
wsLogger.logConnecting();
170+
this.attachStreamLogger(webSocket);
171+
return webSocket;
172+
}
159173

160-
webSocket.addEventListener("open", () => {
161-
wsLogger.logOpen();
162-
});
174+
private attachStreamLogger<TData>(
175+
connection: UnidirectionalStream<TData>,
176+
): void {
177+
const url = new URL(connection.url);
178+
const logger = new EventStreamLogger(
179+
this.output,
180+
url.pathname + url.search,
181+
url.protocol.startsWith("http") ? "SSE" : "WS",
182+
);
183+
logger.logConnecting();
163184

164-
webSocket.addEventListener("message", (event) => {
165-
wsLogger.logMessage(event.sourceEvent.data);
166-
});
185+
connection.addEventListener("open", () => logger.logOpen());
186+
connection.addEventListener("close", (event: CloseEvent) =>
187+
logger.logClose(event.code, event.reason),
188+
);
189+
connection.addEventListener("error", (event: ErrorEvent) =>
190+
logger.logError(event.error, event.message),
191+
);
192+
connection.addEventListener("message", (event) =>
193+
logger.logMessage(event.sourceEvent.data),
194+
);
195+
}
167196

168-
webSocket.addEventListener("close", (event) => {
169-
wsLogger.logClose(event.code, event.reason);
197+
/**
198+
* Create a WebSocket connection with SSE fallback on 404.
199+
*
200+
* Note: The fallback on SSE ignores all passed client options except the headers.
201+
*/
202+
private async createWebSocketWithFallback<TData = unknown>(configs: {
203+
apiRoute: string;
204+
fallbackApiRoute: string;
205+
searchParams?: Record<string, string> | URLSearchParams;
206+
options?: ClientOptions;
207+
}): Promise<UnidirectionalStream<TData>> {
208+
let webSocket: OneWayWebSocket<TData>;
209+
try {
210+
webSocket = await this.createWebSocket<TData>({
211+
apiRoute: configs.apiRoute,
212+
searchParams: configs.searchParams,
213+
options: configs.options,
214+
});
215+
} catch {
216+
// Failed to create WebSocket, use SSE fallback
217+
return this.createSseFallback<TData>(
218+
configs.fallbackApiRoute,
219+
configs.searchParams,
220+
configs.options?.headers,
221+
);
222+
}
223+
224+
return this.waitForConnection(webSocket, () =>
225+
this.createSseFallback<TData>(
226+
configs.fallbackApiRoute,
227+
configs.searchParams,
228+
configs.options?.headers,
229+
),
230+
);
231+
}
232+
233+
private waitForConnection<TData>(
234+
connection: UnidirectionalStream<TData>,
235+
onNotFound?: () => Promise<UnidirectionalStream<TData>>,
236+
): Promise<UnidirectionalStream<TData>> {
237+
return new Promise((resolve, reject) => {
238+
const cleanup = () => {
239+
connection.removeEventListener("open", handleOpen);
240+
connection.removeEventListener("error", handleError);
241+
};
242+
243+
const handleOpen = () => {
244+
cleanup();
245+
resolve(connection);
246+
};
247+
248+
const handleError = (event: ErrorEvent) => {
249+
cleanup();
250+
const is404 =
251+
event.message?.includes("404") ||
252+
event.error?.message?.includes("404");
253+
254+
if (is404 && onNotFound) {
255+
connection.close();
256+
onNotFound().then(resolve).catch(reject);
257+
} else {
258+
reject(event.error || new Error(event.message));
259+
}
260+
};
261+
262+
connection.addEventListener("open", handleOpen);
263+
connection.addEventListener("error", handleError);
170264
});
265+
}
266+
267+
/**
268+
* Create SSE fallback connection
269+
*/
270+
private async createSseFallback<TData = unknown>(
271+
apiRoute: string,
272+
searchParams?: Record<string, string> | URLSearchParams,
273+
optionsHeaders?: Record<string, string>,
274+
): Promise<UnidirectionalStream<TData>> {
275+
this.output.warn(`WebSocket failed, using SSE fallback: ${apiRoute}`);
276+
277+
const baseUrlRaw = this.getAxiosInstance().defaults.baseURL;
278+
if (!baseUrlRaw) {
279+
throw new Error("No base URL set on REST client");
280+
}
171281

172-
webSocket.addEventListener("error", (event) => {
173-
wsLogger.logError(event.error, event.message);
282+
const baseUrl = new URL(baseUrlRaw);
283+
const sseConnection = new SseConnection({
284+
location: baseUrl,
285+
apiRoute,
286+
searchParams,
287+
axiosInstance: this.getAxiosInstance(),
288+
optionsHeaders: optionsHeaders,
289+
logger: this.output,
174290
});
175291

176-
return webSocket;
292+
this.attachStreamLogger(sseConnection);
293+
return this.waitForConnection(sseConnection);
177294
}
178295
}
179296

src/api/streamingFetchAdapter.ts

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import { type AxiosInstance } from "axios";
2+
import { type FetchLikeInit, type FetchLikeResponse } from "eventsource";
3+
import { type IncomingMessage } from "http";
4+
5+
/**
6+
* Creates a fetch adapter using an Axios instance that returns streaming responses.
7+
* This is used by EventSource to make authenticated SSE connections.
8+
*/
9+
export function createStreamingFetchAdapter(
10+
axiosInstance: AxiosInstance,
11+
configHeaders?: Record<string, string>,
12+
): (url: string | URL, init?: FetchLikeInit) => Promise<FetchLikeResponse> {
13+
return async (
14+
url: string | URL,
15+
init?: FetchLikeInit,
16+
): Promise<FetchLikeResponse> => {
17+
const urlStr = url.toString();
18+
19+
const response = await axiosInstance.request<IncomingMessage>({
20+
url: urlStr,
21+
signal: init?.signal,
22+
headers: { ...init?.headers, ...configHeaders },
23+
responseType: "stream",
24+
validateStatus: () => true, // Don't throw on any status code
25+
});
26+
27+
const stream = new ReadableStream({
28+
start(controller) {
29+
response.data.on("data", (chunk: Buffer) => {
30+
try {
31+
controller.enqueue(chunk);
32+
} catch {
33+
// Stream already closed or errored, ignore
34+
}
35+
});
36+
37+
response.data.on("end", () => {
38+
try {
39+
controller.close();
40+
} catch {
41+
// Stream already closed, ignore
42+
}
43+
});
44+
45+
response.data.on("error", (err: Error) => {
46+
controller.error(err);
47+
});
48+
},
49+
50+
cancel() {
51+
response.data.destroy();
52+
return Promise.resolve();
53+
},
54+
});
55+
56+
return {
57+
body: {
58+
getReader: () => stream.getReader(),
59+
},
60+
url: urlStr,
61+
status: response.status,
62+
redirected: response.request?.res?.responseUrl !== urlStr,
63+
headers: {
64+
get: (name: string) => {
65+
const value = response.headers[name.toLowerCase()];
66+
return value === undefined ? null : String(value);
67+
},
68+
},
69+
};
70+
};
71+
}

src/logging/wsLogger.ts renamed to src/logging/eventStreamLogger.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,35 @@ const numFormatter = new Intl.NumberFormat("en", {
1212
compactDisplay: "short",
1313
});
1414

15-
export class WsLogger {
15+
export class EventStreamLogger {
1616
private readonly logger: Logger;
1717
private readonly url: string;
1818
private readonly id: string;
19+
private readonly protocol: string;
1920
private readonly startedAt: number;
2021
private openedAt?: number;
2122
private msgCount = 0;
2223
private byteCount = 0;
2324
private unknownByteCount = false;
2425

25-
constructor(logger: Logger, url: string) {
26+
constructor(logger: Logger, url: string, protocol: "WS" | "SSE") {
2627
this.logger = logger;
2728
this.url = url;
29+
this.protocol = protocol;
2830
this.id = createRequestId();
2931
this.startedAt = Date.now();
3032
}
3133

3234
logConnecting(): void {
33-
this.logger.trace(`→ WS ${shortId(this.id)} ${this.url}`);
35+
this.logger.trace(`→ ${this.protocol} ${shortId(this.id)} ${this.url}`);
3436
}
3537

3638
logOpen(): void {
3739
this.openedAt = Date.now();
3840
const time = formatTime(this.openedAt - this.startedAt);
39-
this.logger.trace(`← WS ${shortId(this.id)} connected ${this.url} ${time}`);
41+
this.logger.trace(
42+
`← ${this.protocol} ${shortId(this.id)} connected ${this.url} ${time}`,
43+
);
4044
}
4145

4246
logMessage(data: unknown): void {
@@ -62,15 +66,15 @@ export class WsLogger {
6266
const statsStr = ` [${stats.join(", ")}]`;
6367

6468
this.logger.trace(
65-
`▣ WS ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`,
69+
`▣ ${this.protocol} ${shortId(this.id)} closed ${this.url}${codeStr}${reasonStr}${statsStr}`,
6670
);
6771
}
6872

6973
logError(error: unknown, message: string): void {
7074
const time = formatTime(Date.now() - this.startedAt);
7175
const errorMsg = message || errToStr(error, "connection error");
7276
this.logger.error(
73-
`✗ WS ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`,
77+
`✗ ${this.protocol} ${shortId(this.id)} error ${this.url} ${time} - ${errorMsg}`,
7478
error,
7579
);
7680
}

0 commit comments

Comments
 (0)