diff --git a/src/browser/contexts/API.test.tsx b/src/browser/contexts/API.test.tsx new file mode 100644 index 0000000000..08bcc31989 --- /dev/null +++ b/src/browser/contexts/API.test.tsx @@ -0,0 +1,267 @@ +import { act, cleanup, render, waitFor } from "@testing-library/react"; +import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test"; +import { GlobalWindow } from "happy-dom"; + +// Mock WebSocket that we can control +class MockWebSocket { + static instances: MockWebSocket[] = []; + url: string; + readyState = 0; // CONNECTING + eventListeners = new Map void>>(); + + constructor(url: string) { + this.url = url; + MockWebSocket.instances.push(this); + } + + addEventListener(event: string, handler: (event?: unknown) => void) { + const handlers = this.eventListeners.get(event) ?? []; + handlers.push(handler); + this.eventListeners.set(event, handlers); + } + + close() { + this.readyState = 3; // CLOSED + } + + // Test helpers + simulateOpen() { + this.readyState = 1; // OPEN + this.eventListeners.get("open")?.forEach((h) => h()); + } + + simulateClose(code: number) { + this.readyState = 3; + this.eventListeners.get("close")?.forEach((h) => h({ code })); + } + + simulateError() { + this.eventListeners.get("error")?.forEach((h) => h()); + } + + static reset() { + MockWebSocket.instances = []; + } + + static lastInstance(): MockWebSocket | undefined { + return MockWebSocket.instances[MockWebSocket.instances.length - 1]; + } +} + +// Mock orpc client +void mock.module("@/common/orpc/client", () => ({ + createClient: () => ({ + general: { + ping: () => Promise.resolve("pong"), + }, + }), +})); + +void mock.module("@orpc/client/websocket", () => ({ + RPCLink: class {}, +})); + +void mock.module("@orpc/client/message-port", () => ({ + RPCLink: class {}, +})); + +void mock.module("@/browser/components/AuthTokenModal", () => ({ + getStoredAuthToken: () => null, + // eslint-disable-next-line @typescript-eslint/no-empty-function + clearStoredAuthToken: () => {}, +})); + +// Import the real API module types (not the mocked version) +import type { UseAPIResult as _UseAPIResult, APIProvider as APIProviderType } from "./API"; + +// IMPORTANT: Other test files mock @/browser/contexts/API with a fake APIProvider. +// Module mocks leak between test files in bun (https://github.com/oven-sh/bun/issues/12823). +// The query string creates a distinct module cache key, bypassing any mocked version. +/* eslint-disable @typescript-eslint/no-require-imports, @typescript-eslint/no-unsafe-assignment */ +const RealAPIModule: { + APIProvider: typeof APIProviderType; + useAPI: () => _UseAPIResult; +} = require("./API?real=1"); +/* eslint-enable @typescript-eslint/no-require-imports, @typescript-eslint/no-unsafe-assignment */ +const { APIProvider, useAPI } = RealAPIModule; +type UseAPIResult = _UseAPIResult; + +// Test component to observe API state +function APIStateObserver(props: { onState: (state: UseAPIResult) => void }) { + const apiState = useAPI(); + props.onState(apiState); + return null; +} + +// Factory that creates MockWebSocket instances (injected via prop) +const createMockWebSocket = (url: string) => new MockWebSocket(url) as unknown as WebSocket; + +describe("API reconnection", () => { + beforeEach(() => { + // Minimal DOM setup required by @testing-library/react + const happyWindow = new GlobalWindow(); + globalThis.window = happyWindow as unknown as Window & typeof globalThis; + globalThis.document = happyWindow.document as unknown as Document; + MockWebSocket.reset(); + }); + + afterEach(() => { + cleanup(); + MockWebSocket.reset(); + globalThis.window = undefined as unknown as Window & typeof globalThis; + globalThis.document = undefined as unknown as Document; + }); + + test("reconnects on close without showing auth_required when previously connected", async () => { + const states: string[] = []; + + render( + + states.push(s.status)} /> + + ); + + const ws1 = MockWebSocket.lastInstance(); + expect(ws1).toBeDefined(); + + // Simulate successful connection (open + ping success) + await act(async () => { + ws1!.simulateOpen(); + // Wait for ping promise to resolve + await new Promise((r) => setTimeout(r, 10)); + }); + + expect(states).toContain("connected"); + + // Simulate server restart (close code 1006 = abnormal closure) + act(() => { + ws1!.simulateClose(1006); + }); + + // Should be "reconnecting", NOT "auth_required" + await waitFor(() => { + expect(states).toContain("reconnecting"); + }); + + expect(states.filter((s) => s === "auth_required")).toHaveLength(0); + + // New WebSocket should be created for reconnect attempt (after delay) + await waitFor(() => { + expect(MockWebSocket.instances.length).toBeGreaterThan(1); + }); + }); + + test("shows auth_required on close with auth error codes (4401)", async () => { + const states: string[] = []; + + render( + + states.push(s.status)} /> + + ); + + const ws1 = MockWebSocket.lastInstance(); + expect(ws1).toBeDefined(); + + await act(async () => { + ws1!.simulateOpen(); + await new Promise((r) => setTimeout(r, 10)); + }); + + expect(states).toContain("connected"); + + act(() => { + ws1!.simulateClose(4401); + }); + + await waitFor(() => { + expect(states).toContain("auth_required"); + }); + }); + + test("shows auth_required on close with auth error codes (1008)", async () => { + const states: string[] = []; + + render( + + states.push(s.status)} /> + + ); + + const ws1 = MockWebSocket.lastInstance(); + expect(ws1).toBeDefined(); + + await act(async () => { + ws1!.simulateOpen(); + await new Promise((r) => setTimeout(r, 10)); + }); + + expect(states).toContain("connected"); + + act(() => { + ws1!.simulateClose(1008); + }); + + await waitFor(() => { + expect(states).toContain("auth_required"); + }); + }); + + test("shows auth_required on first connection failure without token", async () => { + const states: string[] = []; + + render( + + states.push(s.status)} /> + + ); + + const ws1 = MockWebSocket.lastInstance(); + expect(ws1).toBeDefined(); + + // First connection fails - browser fires error then close + act(() => { + ws1!.simulateError(); + ws1!.simulateClose(1006); + }); + + await waitFor(() => { + expect(states).toContain("auth_required"); + }); + + expect(states.filter((s) => s === "reconnecting")).toHaveLength(0); + }); + + test("reconnects on connection loss when previously connected", async () => { + const states: string[] = []; + + render( + + states.push(s.status)} /> + + ); + + const ws1 = MockWebSocket.lastInstance(); + expect(ws1).toBeDefined(); + + await act(async () => { + ws1!.simulateOpen(); + await new Promise((r) => setTimeout(r, 10)); + }); + + expect(states).toContain("connected"); + + // Connection lost after being connected + act(() => { + ws1!.simulateError(); + ws1!.simulateClose(1006); + }); + + await waitFor(() => { + expect(states).toContain("reconnecting"); + }); + + const authRequiredAfterConnected = states.slice(states.indexOf("connected") + 1); + expect(authRequiredAfterConnected.filter((s) => s === "auth_required")).toHaveLength(0); + }); +}); diff --git a/src/browser/contexts/API.tsx b/src/browser/contexts/API.tsx index 8dfe030530..4a920c2073 100644 --- a/src/browser/contexts/API.tsx +++ b/src/browser/contexts/API.tsx @@ -20,6 +20,7 @@ export type { APIClient }; export type APIState = | { status: "connecting"; api: null; error: null } | { status: "connected"; api: APIClient; error: null } + | { status: "reconnecting"; api: null; error: null; attempt: number } | { status: "auth_required"; api: null; error: string | null } | { status: "error"; api: null; error: string }; @@ -35,15 +36,23 @@ export type UseAPIResult = APIState & APIStateMethods; type ConnectionState = | { status: "connecting" } | { status: "connected"; client: APIClient; cleanup: () => void } + | { status: "reconnecting"; attempt: number } | { status: "auth_required"; error?: string } | { status: "error"; error: string }; +// Reconnection constants +const MAX_RECONNECT_ATTEMPTS = 10; +const BASE_DELAY_MS = 100; +const MAX_DELAY_MS = 10000; + const APIContext = createContext(null); interface APIProviderProps { children: React.ReactNode; /** Optional pre-created client. If provided, skips internal connection setup. */ client?: APIClient; + /** WebSocket factory for testing. Defaults to native WebSocket constructor. */ + createWebSocket?: (url: string) => WebSocket; } function getApiBase(): string { @@ -65,7 +74,10 @@ function createElectronClient(): { client: APIClient; cleanup: () => void } { }; } -function createBrowserClient(authToken: string | null): { +function createBrowserClient( + authToken: string | null, + createWebSocket: (url: string) => WebSocket +): { client: APIClient; cleanup: () => void; ws: WebSocket; @@ -77,7 +89,7 @@ function createBrowserClient(authToken: string | null): { ? `${WS_BASE}/orpc/ws?token=${encodeURIComponent(authToken)}` : `${WS_BASE}/orpc/ws`; - const ws = new WebSocket(wsUrl); + const ws = createWebSocket(wsUrl); const link = new WebSocketLink({ websocket: ws }); return { @@ -102,6 +114,15 @@ export const APIProvider = (props: APIProviderProps) => { }); const cleanupRef = useRef<(() => void) | null>(null); + const hasConnectedRef = useRef(false); + const reconnectAttemptRef = useRef(0); + const reconnectTimeoutRef = useRef | null>(null); + const scheduleReconnectRef = useRef<(() => void) | null>(null); + + const wsFactory = useMemo( + () => props.createWebSocket ?? ((url: string) => new WebSocket(url)), + [props.createWebSocket] + ); const connect = useCallback( (token: string | null) => { @@ -112,7 +133,8 @@ export const APIProvider = (props: APIProviderProps) => { return; } - if (window.api) { + // Skip Electron detection if custom WebSocket factory provided (for testing) + if (!props.createWebSocket && window.api) { const { client, cleanup } = createElectronClient(); window.__ORPC_CLIENT__ = client; cleanupRef.current = cleanup; @@ -121,12 +143,14 @@ export const APIProvider = (props: APIProviderProps) => { } setState({ status: "connecting" }); - const { client, cleanup, ws } = createBrowserClient(token); + const { client, cleanup, ws } = createBrowserClient(token, wsFactory); ws.addEventListener("open", () => { client.general .ping("auth-check") .then(() => { + hasConnectedRef.current = true; + reconnectAttemptRef.current = 0; window.__ORPC_CLIENT__ = client; cleanupRef.current = cleanup; setState({ status: "connected", client, cleanup }); @@ -149,31 +173,70 @@ export const APIProvider = (props: APIProviderProps) => { }); }); + // Note: Browser fires 'error' before 'close', so we handle reconnection + // only in 'close' to avoid double-scheduling. The 'error' event just + // signals that something went wrong; 'close' provides the final state. ws.addEventListener("error", () => { - cleanup(); - if (token) { - clearStoredAuthToken(); - setState({ status: "auth_required", error: "Connection failed - invalid token?" }); - } else { - setState({ status: "auth_required" }); - } + // Error occurred - close event will follow and handle reconnection + // We don't call cleanup() here since close handler will do it }); ws.addEventListener("close", (event) => { + cleanup(); + + // Auth-specific close codes if (event.code === 1008 || event.code === 4401) { - cleanup(); clearStoredAuthToken(); + hasConnectedRef.current = false; // Reset - need fresh auth setState({ status: "auth_required", error: "Authentication required" }); + return; + } + + // If we were previously connected, try to reconnect + if (hasConnectedRef.current) { + scheduleReconnectRef.current?.(); + return; + } + + // First connection failed - check if auth might be needed + if (token) { + clearStoredAuthToken(); + setState({ status: "auth_required", error: "Connection failed - invalid token?" }); + } else { + setState({ status: "auth_required" }); } }); }, - [props.client] + [props.client, props.createWebSocket, wsFactory] ); + // Schedule reconnection with exponential backoff + const scheduleReconnect = useCallback(() => { + const attempt = reconnectAttemptRef.current; + if (attempt >= MAX_RECONNECT_ATTEMPTS) { + setState({ status: "error", error: "Connection lost. Please refresh the page." }); + return; + } + + const delay = Math.min(BASE_DELAY_MS * Math.pow(2, attempt), MAX_DELAY_MS); + reconnectAttemptRef.current = attempt + 1; + setState({ status: "reconnecting", attempt: attempt + 1 }); + + reconnectTimeoutRef.current = setTimeout(() => { + connect(authToken); + }, delay); + }, [authToken, connect]); + + // Keep ref in sync with latest scheduleReconnect + scheduleReconnectRef.current = scheduleReconnect; + useEffect(() => { connect(authToken); return () => { cleanupRef.current?.(); + if (reconnectTimeoutRef.current) { + clearTimeout(reconnectTimeoutRef.current); + } }; // eslint-disable-next-line react-hooks/exhaustive-deps }, []); @@ -198,6 +261,8 @@ export const APIProvider = (props: APIProviderProps) => { return { status: "connecting", api: null, error: null, ...base }; case "connected": return { status: "connected", api: state.client, error: null, ...base }; + case "reconnecting": + return { status: "reconnecting", api: null, error: null, attempt: state.attempt, ...base }; case "auth_required": return { status: "auth_required", api: null, error: state.error ?? null, ...base }; case "error":