@@ -36,6 +36,10 @@ import {
3636 OneWayWebSocket ,
3737 type OneWayWebSocketInit ,
3838} from "../websocket/oneWayWebSocket" ;
39+ import {
40+ ReconnectingWebSocket ,
41+ type SocketFactory ,
42+ } from "../websocket/reconnectingWebSocket" ;
3943import { SseConnection } from "../websocket/sseConnection" ;
4044
4145import { createHttpAgent } from "./utils" ;
@@ -47,6 +51,10 @@ const coderSessionTokenHeader = "Coder-Session-Token";
4751 * and WebSocket methods for real-time functionality.
4852 */
4953export class CoderApi extends Api {
54+ private readonly reconnectingSockets = new Set <
55+ ReconnectingWebSocket < unknown >
56+ > ( ) ;
57+
5058 private constructor ( private readonly output : Logger ) {
5159 super ( ) ;
5260 }
@@ -70,6 +78,30 @@ export class CoderApi extends Api {
7078 return client ;
7179 }
7280
81+ setSessionToken = ( token : string ) : void => {
82+ const currentToken =
83+ this . getAxiosInstance ( ) . defaults . headers . common [ coderSessionTokenHeader ] ;
84+ this . getAxiosInstance ( ) . defaults . headers . common [ coderSessionTokenHeader ] =
85+ token ;
86+
87+ if ( currentToken !== token ) {
88+ for ( const socket of this . reconnectingSockets ) {
89+ socket . reconnect ( ) ;
90+ }
91+ }
92+ } ;
93+
94+ setHost = ( host : string | undefined ) : void => {
95+ const currentHost = this . getAxiosInstance ( ) . defaults . baseURL ;
96+ this . getAxiosInstance ( ) . defaults . baseURL = host ;
97+
98+ if ( currentHost !== host ) {
99+ for ( const socket of this . reconnectingSockets ) {
100+ socket . reconnect ( ) ;
101+ }
102+ }
103+ } ;
104+
73105 watchInboxNotifications = async (
74106 watchTemplates : string [ ] ,
75107 watchTargets : string [ ] ,
@@ -83,6 +115,7 @@ export class CoderApi extends Api {
83115 targets : watchTargets . join ( "," ) ,
84116 } ,
85117 options,
118+ enableRetry : true ,
86119 } ) ;
87120 } ;
88121
@@ -91,6 +124,7 @@ export class CoderApi extends Api {
91124 apiRoute : `/api/v2/workspaces/${ workspace . id } /watch-ws` ,
92125 fallbackApiRoute : `/api/v2/workspaces/${ workspace . id } /watch` ,
93126 options,
127+ enableRetry : true ,
94128 } ) ;
95129 } ;
96130
@@ -102,6 +136,7 @@ export class CoderApi extends Api {
102136 apiRoute : `/api/v2/workspaceagents/${ agentId } /watch-metadata-ws` ,
103137 fallbackApiRoute : `/api/v2/workspaceagents/${ agentId } /watch-metadata` ,
104138 options,
139+ enableRetry : true ,
105140 } ) ;
106141 } ;
107142
@@ -148,53 +183,73 @@ export class CoderApi extends Api {
148183 }
149184
150185 private async createWebSocket < TData = unknown > (
151- configs : Omit < OneWayWebSocketInit , "location" > ,
152- ) {
153- const baseUrlRaw = this . getAxiosInstance ( ) . defaults . baseURL ;
154- if ( ! baseUrlRaw ) {
155- throw new Error ( "No base URL set on REST client" ) ;
156- }
186+ configs : Omit < OneWayWebSocketInit , "location" > & { enableRetry ?: boolean } ,
187+ ) : Promise < UnidirectionalStream < TData > > {
188+ const { enableRetry, ...socketConfigs } = configs ;
189+
190+ const socketFactory : SocketFactory < TData > = async ( ) => {
191+ const baseUrlRaw = this . getAxiosInstance ( ) . defaults . baseURL ;
192+ if ( ! baseUrlRaw ) {
193+ throw new Error ( "No base URL set on REST client" ) ;
194+ }
195+
196+ const baseUrl = new URL ( baseUrlRaw ) ;
197+ const token = this . getAxiosInstance ( ) . defaults . headers . common [
198+ coderSessionTokenHeader
199+ ] as string | undefined ;
200+
201+ const headersFromCommand = await getHeaders (
202+ baseUrlRaw ,
203+ getHeaderCommand ( vscode . workspace . getConfiguration ( ) ) ,
204+ this . output ,
205+ ) ;
157206
158- const baseUrl = new URL ( baseUrlRaw ) ;
159- const token = this . getAxiosInstance ( ) . defaults . headers . common [
160- coderSessionTokenHeader
161- ] as string | undefined ;
207+ const httpAgent = await createHttpAgent (
208+ vscode . workspace . getConfiguration ( ) ,
209+ ) ;
162210
163- const headersFromCommand = await getHeaders (
164- baseUrlRaw ,
165- getHeaderCommand ( vscode . workspace . getConfiguration ( ) ) ,
166- this . output ,
167- ) ;
211+ /**
212+ * Similar to the REST client, we want to prioritize headers in this order (highest to lowest):
213+ * 1. Headers from the header command
214+ * 2. Any headers passed directly to this function
215+ * 3. Coder session token from the Api client (if set)
216+ */
217+ const headers = {
218+ ...( token ? { [ coderSessionTokenHeader ] : token } : { } ) ,
219+ ...configs . options ?. headers ,
220+ ...headersFromCommand ,
221+ } ;
168222
169- const httpAgent = await createHttpAgent (
170- vscode . workspace . getConfiguration ( ) ,
171- ) ;
223+ const webSocket = new OneWayWebSocket < TData > ( {
224+ location : baseUrl ,
225+ ...socketConfigs ,
226+ options : {
227+ ...configs . options ,
228+ agent : httpAgent ,
229+ followRedirects : true ,
230+ headers,
231+ } ,
232+ } ) ;
172233
173- /**
174- * Similar to the REST client, we want to prioritize headers in this order (highest to lowest):
175- * 1. Headers from the header command
176- * 2. Any headers passed directly to this function
177- * 3. Coder session token from the Api client (if set)
178- */
179- const headers = {
180- ...( token ? { [ coderSessionTokenHeader ] : token } : { } ) ,
181- ...configs . options ?. headers ,
182- ...headersFromCommand ,
234+ this . attachStreamLogger ( webSocket ) ;
235+ return webSocket ;
183236 } ;
184237
185- const webSocket = new OneWayWebSocket < TData > ( {
186- location : baseUrl ,
187- ... configs ,
188- options : {
189- ... configs . options ,
190- agent : httpAgent ,
191- followRedirects : true ,
192- headers ,
193- } ,
194- } ) ;
238+ if ( enableRetry ) {
239+ const reconnectingSocket = await ReconnectingWebSocket . create < TData > (
240+ socketFactory ,
241+ this . output ,
242+ configs . apiRoute ,
243+ ) ;
244+
245+ this . reconnectingSockets . add (
246+ reconnectingSocket as ReconnectingWebSocket < unknown > ,
247+ ) ;
195248
196- this . attachStreamLogger ( webSocket ) ;
197- return webSocket ;
249+ return reconnectingSocket ;
250+ } else {
251+ return socketFactory ( ) ;
252+ }
198253 }
199254
200255 private attachStreamLogger < TData > (
@@ -230,13 +285,15 @@ export class CoderApi extends Api {
230285 fallbackApiRoute : string ;
231286 searchParams ?: Record < string , string > | URLSearchParams ;
232287 options ?: ClientOptions ;
288+ enableRetry ?: boolean ;
233289 } ) : Promise < UnidirectionalStream < TData > > {
234- let webSocket : OneWayWebSocket < TData > ;
290+ let webSocket : UnidirectionalStream < TData > ;
235291 try {
236292 webSocket = await this . createWebSocket < TData > ( {
237293 apiRoute : configs . apiRoute ,
238294 searchParams : configs . searchParams ,
239295 options : configs . options ,
296+ enableRetry : configs . enableRetry ,
240297 } ) ;
241298 } catch {
242299 // Failed to create WebSocket, use SSE fallback
0 commit comments