diff --git a/Dockerfile b/Dockerfile index 30b7eed3..8b463a20 100644 --- a/Dockerfile +++ b/Dockerfile @@ -59,8 +59,11 @@ FROM base AS runner ARG UV_VERSION=latest ARG OPENCODE_VERSION=latest +# Bump TOOLS_CACHEBUST (e.g. via --build-arg) to force a fresh uv/opencode +# install without invalidating the rest of the build cache. +ARG TOOLS_CACHEBUST=0 -RUN echo "Installing uv=${UV_VERSION} opencode=${OPENCODE_VERSION}" && \ +RUN echo "Installing uv=${UV_VERSION} opencode=${OPENCODE_VERSION} (cachebust=${TOOLS_CACHEBUST})" && \ curl -LsSf https://astral.sh/uv/install.sh | UV_NO_MODIFY_PATH=1 sh && \ mv /root/.local/bin/uv /usr/local/bin/uv && \ mv /root/.local/bin/uvx /usr/local/bin/uvx && \ diff --git a/backend/src/routes/sse-writer.ts b/backend/src/routes/sse-writer.ts new file mode 100644 index 00000000..6da4303c --- /dev/null +++ b/backend/src/routes/sse-writer.ts @@ -0,0 +1,49 @@ +export interface QueuedSSEWriterInput { + write: (chunk: Uint8Array) => Promise | void + onError: (error: unknown) => void +} + +export interface QueuedSSEWriter { + writeSSE: (event: string, data: string) => void + writeFrame: (frame: Uint8Array) => void + close: () => void +} + +const sharedEncoder = new TextEncoder() + +export function encodeSSEFrame(event: string, data: string): Uint8Array { + const head = event ? `event: ${event}\n` : '' + return sharedEncoder.encode(`${head}data: ${data}\n\n`) +} + +export function createQueuedSSEWriter(input: QueuedSSEWriterInput): QueuedSSEWriter { + let chain = Promise.resolve() + let closed = false + + const writeFrame = (frame: Uint8Array) => { + if (closed) return + + chain = chain + .then(async () => { + if (closed) return + await input.write(frame) + }) + .catch((error) => { + if (!closed) { + closed = true + input.onError(error) + } + }) + } + + const writeSSE = (event: string, data: string) => { + if (closed) return + writeFrame(encodeSSEFrame(event, data)) + } + + const close = () => { + closed = true + } + + return { writeSSE, writeFrame, close } +} diff --git a/backend/src/routes/sse.ts b/backend/src/routes/sse.ts index 636e3473..e1af25cf 100644 --- a/backend/src/routes/sse.ts +++ b/backend/src/routes/sse.ts @@ -4,6 +4,7 @@ import { sseAggregator } from '../services/sse-aggregator' import { SSESubscribeSchema, SSEVisibilitySchema } from '@opencode-manager/shared/schemas' import { logger } from '../utils/logger' import { DEFAULTS } from '@opencode-manager/shared/config' +import { createQueuedSSEWriter } from './sse-writer' const { HEARTBEAT_INTERVAL_MS } = DEFAULTS.SSE @@ -21,42 +22,39 @@ export function createSSERoutes() { c.header('X-Accel-Buffering', 'no') return stream(c, async (writer) => { - const encoder = new TextEncoder() - const writeSSE = (event: string, data: string) => { - const lines = [] - if (event) lines.push(`event: ${event}`) - lines.push(`data: ${data}`) - lines.push('') - lines.push('') - writer.write(encoder.encode(lines.join('\n'))) - } - - const cleanup = sseAggregator.addClient( + let cleanup: () => void = () => {} + + const queuedWriter = createQueuedSSEWriter({ + write: (chunk) => writer.write(chunk), + onError: (error) => { + logger.error(`SSE write failed for ${clientId}:`, error) + clearInterval(heartbeatInterval) + cleanup() + }, + }) + + const heartbeatInterval = setInterval(() => { + queuedWriter.writeSSE('heartbeat', JSON.stringify({ timestamp: Date.now() })) + }, HEARTBEAT_INTERVAL_MS) + + cleanup = sseAggregator.addClient( clientId, (event, data) => { - writeSSE(event, data) + queuedWriter.writeSSE(event, data) + }, + (frame) => { + queuedWriter.writeFrame(frame) }, directories ) - const heartbeatInterval = setInterval(() => { - try { - writeSSE('heartbeat', JSON.stringify({ timestamp: Date.now() })) - } catch { - clearInterval(heartbeatInterval) - } - }, HEARTBEAT_INTERVAL_MS) - writer.onAbort(() => { + queuedWriter.close() clearInterval(heartbeatInterval) cleanup() }) - try { - writeSSE('connected', JSON.stringify({ clientId, directories, ...sseAggregator.getConnectionStatus() })) - } catch (err) { - logger.error(`Failed to send SSE connected event for ${clientId}:`, err) - } + queuedWriter.writeSSE('connected', JSON.stringify({ clientId, directories, ...sseAggregator.getConnectionStatus() })) await new Promise(() => {}) }) diff --git a/backend/src/services/assistant-mode.ts b/backend/src/services/assistant-mode.ts index 3052e02c..ca562592 100644 --- a/backend/src/services/assistant-mode.ts +++ b/backend/src/services/assistant-mode.ts @@ -1090,7 +1090,7 @@ export async function warmAssistantWorkspace(deps: { db: deps.db, apiBaseUrl: deps.apiBaseUrl, }) - await deps.openCodeClient.getJson('/session', { + await deps.openCodeClient.getJson('/api/session?limit=1&order=desc', { directory: status.directory, signal: AbortSignal.timeout(ASSISTANT_WARMUP_OPENCODE_TIMEOUT_MS), }) diff --git a/backend/src/services/opencode-single-server.ts b/backend/src/services/opencode-single-server.ts index 49827ef3..8105886e 100644 --- a/backend/src/services/opencode-single-server.ts +++ b/backend/src/services/opencode-single-server.ts @@ -29,6 +29,7 @@ import { writeFileContent } from './file-operations' const MIN_OPENCODE_VERSION = '1.0.137' const MAX_STDERR_SIZE = 10240 const HEALTH_CHECK_TIMEOUT_MS = 3000 +const PLUGIN_INSTALL_TIMEOUT_MS = 120000 const DEPRECATED_PLUGIN_PACKAGES = ['opencode-openai-codex-auth', 'opencode-copilot-auth'] type StartupValidationIssue = { @@ -585,19 +586,24 @@ class OpenCodeServerManager { await fs.mkdir(installDir, { recursive: true }) if (!await fs.access(path.join(installDir, 'package.json')).then(() => true).catch(() => false)) { - const init = spawnSync('bun', ['init', '-y'], { cwd: installDir, encoding: 'utf8' }) + const init = spawnSync('bun', ['init', '-y'], { cwd: installDir, encoding: 'utf8', timeout: 30000 }) if (init.status !== 0) { logger.warn(`Failed to initialize OpenCode plugin cache for ${plugin}: ${init.stderr || init.stdout}`) continue } } - const result = spawnSync('bun', ['add', '--ignore-scripts', installSpec], { cwd: installDir, encoding: 'utf8' }) + const result = spawnSync('bun', ['add', '--ignore-scripts', installSpec], { cwd: installDir, encoding: 'utf8', timeout: PLUGIN_INSTALL_TIMEOUT_MS }) if (result.status === 0) { logger.info(`Installed OpenCode plugin: ${plugin}`) continue } + if (result.error) { + logger.warn(`Failed to install OpenCode plugin ${plugin}: ${result.error.message}`) + continue + } + logger.warn(`Failed to install OpenCode plugin ${plugin}: ${result.stderr || result.stdout}`) } } diff --git a/backend/src/services/sse-aggregator.ts b/backend/src/services/sse-aggregator.ts index b763495b..a1ee8ad1 100644 --- a/backend/src/services/sse-aggregator.ts +++ b/backend/src/services/sse-aggregator.ts @@ -3,13 +3,16 @@ import { logger } from '../utils/logger' import { ENV } from '@opencode-manager/shared/config/env' import { DEFAULTS } from '@opencode-manager/shared/config' import { getOpenCodeBasicAuthHeader, type OpenCodePasswordResolver } from './opencode/auth' +import { encodeSSEFrame } from '../routes/sse-writer' type SSEClientCallback = (event: string, data: string) => void +type SSEClientFrameWriter = (frame: Uint8Array) => void type SSEEventListener = (directory: string, event: SSEEvent) => void interface SSEClient { id: string callback: SSEClientCallback + writeFrame: SSEClientFrameWriter directories: Set visible: boolean activeSessionId: string | null @@ -95,10 +98,11 @@ class SSEAggregator { void this.connectUpstream() } - addClient(id: string, callback: SSEClientCallback, directories: string[]): () => void { + addClient(id: string, callback: SSEClientCallback, writeFrame: SSEClientFrameWriter, directories: string[]): () => void { const client: SSEClient = { id, callback, + writeFrame, directories: new Set(directories), visible: false, activeSessionId: null @@ -211,7 +215,7 @@ class SSEAggregator { if (!client || !client.directories.has(directory)) return for (const item of items) { - const payload = JSON.stringify({ type, properties: item, directory }) + const payload = JSON.stringify({ directory, payload: { type, properties: item } }) try { client.callback('message', payload) } catch (error) { @@ -316,11 +320,13 @@ class SSEAggregator { try { listener(directory, parsed) } catch { /* ignore listener errors */ } }) - const clientData = JSON.stringify({ ...parsed, directory }) + let frame: Uint8Array | undefined + const getFrame = (): Uint8Array => (frame ??= encodeSSEFrame('message', data)) + this.clients.forEach((client) => { if (client.directories.has(directory)) { try { - client.callback('message', clientData) + client.writeFrame(getFrame()) } catch (error) { logger.error(`Failed to send to client ${client.id}:`, error) } @@ -478,8 +484,10 @@ export const sseAggregator = SSEAggregator.getInstance() export function broadcastSSHHostKeyRequest(data: Record): void { const event = JSON.stringify({ - type: 'ssh.host-key-request', - properties: data, + payload: { + type: 'ssh.host-key-request', + properties: data, + }, }) sseAggregator.broadcastToAll('message', event) } diff --git a/backend/test/routes/sse-writer.test.ts b/backend/test/routes/sse-writer.test.ts new file mode 100644 index 00000000..724a0b9a --- /dev/null +++ b/backend/test/routes/sse-writer.test.ts @@ -0,0 +1,88 @@ +import { describe, it, expect, vi } from 'vitest' +import { createQueuedSSEWriter, encodeSSEFrame } from '../../src/routes/sse-writer' + +describe('encodeSSEFrame', () => { + const decoder = new TextDecoder() + + it('encodes an event frame', () => { + expect(decoder.decode(encodeSSEFrame('message', '{"n":1}'))).toBe('event: message\ndata: {"n":1}\n\n') + }) + + it('omits the event line when event is empty', () => { + expect(decoder.decode(encodeSSEFrame('', '{"n":1}'))).toBe('data: {"n":1}\n\n') + }) +}) + +describe('createQueuedSSEWriter', () => { + describe('writeFrame', () => { + it('writes a pre-encoded frame through the serialized chain', async () => { + const writes: Uint8Array[] = [] + const write = vi.fn((chunk: Uint8Array) => { writes.push(chunk) }) + const onError = vi.fn() + + const writer = createQueuedSSEWriter({ write, onError }) + const frame = encodeSSEFrame('message', '{"shared":true}') + writer.writeFrame(frame) + + await vi.waitFor(() => expect(write).toHaveBeenCalledTimes(1)) + expect(writes[0]).toBe(frame) + expect(onError).not.toHaveBeenCalled() + }) + }) + + describe('serializes frames in enqueue order', () => { + it('should not execute second write until first resolves', async () => { + let firstWriteResolve!: () => void + const writes: Uint8Array[] = [] + const write = vi.fn((chunk: Uint8Array) => { + writes.push(chunk) + if (writes.length === 1) { + return new Promise((resolve) => { + firstWriteResolve = resolve + }) + } + }) + const onError = vi.fn() + + const writer = createQueuedSSEWriter({ write, onError }) + + writer.writeSSE('message', '{"n":1}') + writer.writeSSE('message', '{"n":2}') + + await new Promise((resolve) => setTimeout(resolve, 0)) + expect(write).toHaveBeenCalledTimes(1) + expect(onError).not.toHaveBeenCalled() + + firstWriteResolve() + await new Promise((resolve) => setTimeout(resolve, 0)) + expect(write).toHaveBeenCalledTimes(2) + + const decoder = new TextDecoder() + expect(decoder.decode(writes[0])).toBe('event: message\ndata: {"n":1}\n\n') + expect(decoder.decode(writes[1])).toBe('event: message\ndata: {"n":2}\n\n') + }) + }) + + describe('stops writing after a write failure', () => { + it('should call onError and skip subsequent writes', async () => { + const write = vi.fn().mockRejectedValueOnce(new Error('write failed')) + const onError = vi.fn() + + const writer = createQueuedSSEWriter({ write, onError }) + + writer.writeSSE('message', '{"n":1}') + + await vi.waitFor(() => { + expect(onError).toHaveBeenCalledTimes(1) + }) + expect(onError).toHaveBeenCalledWith(new Error('write failed')) + expect(write).toHaveBeenCalledTimes(1) + + writer.writeSSE('message', '{"n":2}') + + await vi.waitFor(() => { + expect(write).toHaveBeenCalledTimes(1) + }) + }) + }) +}) diff --git a/backend/test/services/assistant-mode.test.ts b/backend/test/services/assistant-mode.test.ts index a35c7b24..08266773 100644 --- a/backend/test/services/assistant-mode.test.ts +++ b/backend/test/services/assistant-mode.test.ts @@ -620,7 +620,7 @@ describe('warmAssistantWorkspace', () => { }) afterEach(async () => { await ws.cleanup() }) - it('provisions the workspace and triggers a directory-scoped session request', async () => { + it('provisions the workspace and triggers a bounded directory-scoped session request', async () => { const getJsonCalls: Array<{ path: string; directory?: string }> = [] const client = { getJson: async (requestPath: string, opts?: { directory?: string }) => { @@ -634,7 +634,7 @@ describe('warmAssistantWorkspace', () => { const opencodeJson = await readFile(path.join(ws.assistantDir, 'opencode.json'), 'utf8') expect(JSON.parse(opencodeJson).default_agent).toBe('assistant') expect(getJsonCalls).toHaveLength(1) - expect(getJsonCalls[0]?.path).toBe('/session') + expect(getJsonCalls[0]?.path).toBe('/api/session?limit=1&order=desc') expect(getJsonCalls[0]?.directory).toBe(ws.assistantDir) }) diff --git a/backend/test/services/opencode-single-server.test.ts b/backend/test/services/opencode-single-server.test.ts index 6429d809..100e1a41 100644 --- a/backend/test/services/opencode-single-server.test.ts +++ b/backend/test/services/opencode-single-server.test.ts @@ -17,6 +17,8 @@ const spawnMock = vi.hoisted(() => vi.fn(() => ({ on: vi.fn(), }))) +const spawnSyncMock = vi.hoisted(() => vi.fn()) + vi.mock('bun:sqlite', () => ({ Database: vi.fn(), })) @@ -62,6 +64,7 @@ vi.mock('fs', () => ({ vi.mock('child_process', () => ({ execSync: vi.fn(), spawn: spawnMock, + spawnSync: spawnSyncMock, })) vi.mock('../../src/services/opencode/config-recovery', () => ({ @@ -73,7 +76,7 @@ vi.mock('../../src/services/opencode/client', () => ({ })) import { promises as fs } from 'fs' -import { execSync } from 'child_process' +import { execSync, spawnSync } from 'child_process' import { ConfigReloadError } from '../../src/services/opencode-single-server' import { encryptSecret } from '../../src/utils/crypto' import { ENV } from '@opencode-manager/shared/config/env' @@ -89,6 +92,7 @@ vi.mock('../../src/utils/logger', () => ({ const mkdirMock = fs.mkdir as any const accessMock = fs.access as any const execSyncMock = execSync as any +const childSpawnSyncMock = spawnSync as any // Reset singleton before any tests run to clear any polluted state from previous test files beforeAll(async () => { @@ -413,3 +417,40 @@ describe('OpenCodeServerManager - checkHealth', () => { expect(aborted).toBe(true) }, 5000) }) + +describe('OpenCodeServerManager - configured plugin install', () => { + beforeEach(async () => { + const { OpenCodeServerManager } = await import('../../src/services/opencode-single-server') + OpenCodeServerManager.resetInstance() + vi.clearAllMocks() + }) + + afterEach(async () => { + const { OpenCodeServerManager } = await import('../../src/services/opencode-single-server') + OpenCodeServerManager.resetInstance() + vi.clearAllMocks() + }) + + it('bounds first-run plugin installation with a timeout', async () => { + const { opencodeServerManager } = await import('../../src/services/opencode-single-server') + const { logger } = await import('../../src/utils/logger') + + accessMock.mockImplementation((filePath: string) => { + const error = new Error('Not found') as NodeJS.ErrnoException + error.code = 'ENOENT' + return filePath.includes('package.json') ? Promise.reject(error) : Promise.resolve() + }) + childSpawnSyncMock + .mockReturnValueOnce({ status: 0, stdout: '', stderr: '' }) + .mockReturnValueOnce({ status: null, stdout: '', stderr: '', error: new Error('spawnSync bun ETIMEDOUT') }) + + await (opencodeServerManager as any).installConfiguredPlugins(['test-plugin']) + + expect(childSpawnSyncMock).toHaveBeenCalledWith( + 'bun', + ['add', '--ignore-scripts', 'test-plugin@latest'], + expect.objectContaining({ timeout: 120000 }), + ) + expect(logger.warn).toHaveBeenCalledWith('Failed to install OpenCode plugin test-plugin: spawnSync bun ETIMEDOUT') + }) +}) diff --git a/backend/test/services/sse-aggregator.test.ts b/backend/test/services/sse-aggregator.test.ts index 0e6ac900..4d6a3fd7 100644 --- a/backend/test/services/sse-aggregator.test.ts +++ b/backend/test/services/sse-aggregator.test.ts @@ -23,10 +23,15 @@ interface CapturedEvent { function createCapturingClient() { const events: CapturedEvent[] = [] + const frames: string[] = [] + const decoder = new TextDecoder() const callback = (event: string, data: string) => { events.push({ event, data }) } - return { callback, events } + const writeFrame = (frame: Uint8Array) => { + frames.push(decoder.decode(frame)) + } + return { callback, writeFrame, events, frames } } function makeFetcher(map: Record): PendingActionsFetcher { @@ -69,29 +74,29 @@ describe('SSEAggregator pending replay on connect', () => { }) sseAggregator.setPendingActionsFetcher(fetcher) - const { callback, events } = createCapturingClient() - sseAggregator.addClient('client-1', callback, ['/repo/a', '/repo/b']) + const { callback, writeFrame, events } = createCapturingClient() + sseAggregator.addClient('client-1', callback, writeFrame, ['/repo/a', '/repo/b']) await flushReplay() expect(events).toHaveLength(4) - const parsed = events.map(e => JSON.parse(e.data) as { type: string; properties: { id: string }; directory: string }) + const parsed = events.map(e => JSON.parse(e.data) as { directory: string; payload: { type: string; properties: { id: string } } }) - expect(parsed.filter(p => p.type === 'permission.asked' && p.directory === '/repo/a').map(p => p.properties.id)).toEqual([ + expect(parsed.filter(p => p.payload.type === 'permission.asked' && p.directory === '/repo/a').map(p => p.payload.properties.id)).toEqual([ 'perm-1', 'perm-2', ]) - expect(parsed.filter(p => p.type === 'question.asked' && p.directory === '/repo/a').map(p => p.properties.id)).toEqual(['q-1']) - expect(parsed.filter(p => p.type === 'permission.asked' && p.directory === '/repo/b').map(p => p.properties.id)).toEqual([ + expect(parsed.filter(p => p.payload.type === 'question.asked' && p.directory === '/repo/a').map(p => p.payload.properties.id)).toEqual(['q-1']) + expect(parsed.filter(p => p.payload.type === 'permission.asked' && p.directory === '/repo/b').map(p => p.payload.properties.id)).toEqual([ 'perm-3', ]) - expect(parsed.filter(p => p.type === 'question.asked' && p.directory === '/repo/b')).toHaveLength(0) + expect(parsed.filter(p => p.payload.type === 'question.asked' && p.directory === '/repo/b')).toHaveLength(0) }) it('does not replay when no fetcher is configured', async () => { - const { callback, events } = createCapturingClient() - sseAggregator.addClient('client-2', callback, ['/repo/a']) + const { callback, writeFrame, events } = createCapturingClient() + sseAggregator.addClient('client-2', callback, writeFrame, ['/repo/a']) await flushReplay() @@ -107,8 +112,8 @@ describe('SSEAggregator pending replay on connect', () => { const clientA = createCapturingClient() const clientB = createCapturingClient() - sseAggregator.addClient('a', clientA.callback, ['/repo/a']) - sseAggregator.addClient('b', clientB.callback, []) + sseAggregator.addClient('a', clientA.callback, clientA.writeFrame, ['/repo/a']) + sseAggregator.addClient('b', clientB.callback, clientB.writeFrame, []) await flushReplay() @@ -123,8 +128,8 @@ describe('SSEAggregator pending replay on connect', () => { }) sseAggregator.setPendingActionsFetcher(fetcher) - const { callback, events } = createCapturingClient() - sseAggregator.addClient('client-3', callback, ['/repo/a']) + const { callback, writeFrame, events } = createCapturingClient() + sseAggregator.addClient('client-3', callback, writeFrame, ['/repo/a']) await flushReplay() const initialCount = events.length @@ -134,11 +139,11 @@ describe('SSEAggregator pending replay on connect', () => { await flushReplay() const newEvents = events.slice(initialCount) - const parsed = newEvents.map(e => JSON.parse(e.data) as { type: string; directory: string; properties: { id: string } }) + const parsed = newEvents.map(e => JSON.parse(e.data) as { directory: string; payload: { type: string; properties: { id: string } } }) expect(parsed).toHaveLength(1) const [first] = parsed expect(first?.directory).toBe('/repo/b') - expect(first?.properties.id).toBe('perm-2') + expect(first?.payload.properties.id).toBe('perm-2') }) it('survives upstream fetch failures for one directory and still replays the others', async () => { @@ -155,15 +160,15 @@ describe('SSEAggregator pending replay on connect', () => { } sseAggregator.setPendingActionsFetcher(fetcher) - const { callback, events } = createCapturingClient() - sseAggregator.addClient('client-4', callback, ['/repo/broken', '/repo/ok']) + const { callback, writeFrame, events } = createCapturingClient() + sseAggregator.addClient('client-4', callback, writeFrame, ['/repo/broken', '/repo/ok']) await flushReplay() - const parsed = events.map(e => JSON.parse(e.data) as { directory: string; properties: { id: string } }) + const parsed = events.map(e => JSON.parse(e.data) as { directory: string; payload: { properties: { id: string } } }) expect(parsed).toHaveLength(1) const [first] = parsed expect(first?.directory).toBe('/repo/ok') - expect(first?.properties.id).toBe('perm-ok') + expect(first?.payload.properties.id).toBe('perm-ok') }) it('does not deliver replay events to a client that no longer subscribes to that directory', async () => { @@ -180,8 +185,8 @@ describe('SSEAggregator pending replay on connect', () => { } sseAggregator.setPendingActionsFetcher(fetcher) - const { callback, events } = createCapturingClient() - sseAggregator.addClient('client-5', callback, ['/repo/a']) + const { callback, writeFrame, events } = createCapturingClient() + sseAggregator.addClient('client-5', callback, writeFrame, ['/repo/a']) sseAggregator.removeDirectories('client-5', ['/repo/a']) resolvePermissions([{ id: 'late', sessionID: 's' }]) diff --git a/docker-compose.yml b/docker-compose.yml index 0be7cb50..00a02334 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,6 +3,8 @@ services: build: context: . dockerfile: Dockerfile + args: + TOOLS_CACHEBUST: ${TOOLS_CACHEBUST:-0} container_name: opencode-manager ports: - "5003:5003" diff --git a/frontend/src/components/message/MessageThread.tsx b/frontend/src/components/message/MessageThread.tsx index 4d102d03..97f34362 100644 --- a/frontend/src/components/message/MessageThread.tsx +++ b/frontend/src/components/message/MessageThread.tsx @@ -135,8 +135,7 @@ const findLastMessageByRole = ( interface MessageRowProps { msgWithParts: MessageWithParts - index: number - messages: MessageWithParts[] + nextAssistantMessage: MessageWithParts | undefined pendingAssistantId: string | undefined lastUserMessageId: string | undefined isSessionBusy: boolean @@ -157,8 +156,7 @@ interface MessageRowProps { const MessageRow = memo(function MessageRow({ msgWithParts, - index, - messages, + nextAssistantMessage, pendingAssistantId, lastUserMessageId, isSessionBusy, @@ -183,7 +181,6 @@ const MessageRow = memo(function MessageRow({ const isLastUserMessage = msg.role === 'user' && msg.id === lastUserMessageId const messageTextContent = getMessageTextContent(parts) - const nextAssistantMessage = messages.slice(index + 1).find(m => m.info.role === 'assistant') const nextAssistantMsg = nextAssistantMessage?.info const isUserBeforeAssistant = msg.role === 'user' && nextAssistantMessage const canEditUserMessage = isLastUserMessage && isUserBeforeAssistant && !isSessionBusy @@ -352,6 +349,20 @@ export const MessageThread = memo(function MessageThread({ return findLastMessageByRole(messages, 'user') }, [messages]) + const nextAssistantByMessageId = useMemo(() => { + const map = new Map() + if (!messages) return map + let next: MessageWithParts | undefined + for (let i = messages.length - 1; i >= 0; i--) { + const msg = messages[i] + map.set(msg.info.id, next) + if (msg.info.role === 'assistant') { + next = msg + } + } + return map + }, [messages]) + const isSessionBusy = !!pendingAssistantId || isSessionInRetry(sessionStatus) const setSessionTodos = useSessionTodos((state) => state.setTodos) @@ -415,12 +426,11 @@ export const MessageThread = memo(function MessageThread({ return (
- {messages.map((msgWithParts, index) => ( + {messages.map((msgWithParts) => ( { contentVersion: newMessages.length, onScrollStateChange, }) + }) + + act(() => { vi.advanceTimersByTime(100) }) @@ -162,6 +165,7 @@ describe('useAutoScroll', () => { contentVersion: messages.length + 1, onScrollStateChange, }) + vi.advanceTimersByTime(100) }) expect(containerHarness.getScrollTop()).toBe(containerHarness.div.scrollHeight - containerHarness.div.clientHeight) @@ -231,6 +235,29 @@ describe('useAutoScroll', () => { expect(containerHarness.getScrollTop()).toBe(userPosition) }) + it('cancels pending bottom scroll when user wheel-scrolls up before pending frames complete', () => { + const messages = [createMessage('1', 'user'), createMessage('2', 'assistant')] + const { renderResult, containerHarness } = setupHook(messages) + + act(() => { + renderResult.result.current.scrollToBottom() + }) + + const userPosition = 150 + act(() => { + containerHarness.div.dispatchEvent( + new WheelEvent('wheel', { + deltaY: -50, + bubbles: true, + }) + ) + containerHarness.setScrollTop(userPosition) + vi.runOnlyPendingTimers() + }) + + expect(containerHarness.getScrollTop()).toBe(userPosition) + }) + it('does not show scroll button on tiny upward drag from bottom', () => { const messages = [createMessage('1', 'user')] const { containerHarness, onScrollStateChange } = setupHook(messages) @@ -431,6 +458,10 @@ describe('useAutoScroll', () => { }) }) + act(() => { + vi.advanceTimersByTime(100) + }) + expect(containerHarness.getScrollTop()).toBe(containerHarness.div.scrollHeight - containerHarness.div.clientHeight) }) @@ -459,6 +490,10 @@ describe('useAutoScroll', () => { }) }) + act(() => { + vi.advanceTimersByTime(100) + }) + expect(containerHarness.getScrollTop()).toBe(containerHarness.div.scrollHeight - containerHarness.div.clientHeight) }) }) diff --git a/frontend/src/hooks/useAutoScroll.ts b/frontend/src/hooks/useAutoScroll.ts index 202bb91c..df2f9ab1 100644 --- a/frontend/src/hooks/useAutoScroll.ts +++ b/frontend/src/hooks/useAutoScroll.ts @@ -46,17 +46,13 @@ export function useAutoScroll({ isScrollButtonVisibleRef.current = false const scrollRequestId = scrollRequestIdRef.current + 1 scrollRequestIdRef.current = scrollRequestId - const scroll = () => { - if (!containerRef?.current) return - containerRef.current.scrollTop = containerRef.current.scrollHeight - } - scroll() let frameCount = 0 const scrollAfterLayout = () => { if (scrollRequestIdRef.current !== scrollRequestId) return + if (!containerRef?.current) return + containerRef.current.scrollTop = containerRef.current.scrollHeight frameCount += 1 - scroll() if (frameCount < SCROLL_TO_BOTTOM_FRAME_COUNT) { requestAnimationFrame(scrollAfterLayout) } diff --git a/frontend/src/hooks/useSSE.test.tsx b/frontend/src/hooks/useSSE.test.tsx index 72ad8953..e1c9795c 100644 --- a/frontend/src/hooks/useSSE.test.tsx +++ b/frontend/src/hooks/useSSE.test.tsx @@ -4,6 +4,7 @@ import type { ReactNode } from 'react' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { useSSE } from './useSSE' import { useSessionStatus } from '../stores/sessionStatusStore' +import type { Part, MessageWithParts } from '@/api/types' const mocks = vi.hoisted(() => ({ getSessionStatuses: vi.fn(), @@ -274,4 +275,181 @@ describe('useSSE', () => { unmount() }) + + it('routes streamed part deltas to the event directory in multi-directory subscriptions', async () => { + const origRAF = window.requestAnimationFrame + window.requestAnimationFrame = ((cb: FrameRequestCallback) => { + cb(0) + return 0 + }) as typeof window.requestAnimationFrame + + try { + const queryClient = new QueryClient({ + defaultOptions: { queries: { retry: false } }, + }) + + // Seed both directory caches before rendering the hook + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-a'], + [{ + ...assistantMessage('session-1', 'message-1'), + parts: [textPart('session-1', 'message-1', 'part-1', 'A')], + }], + ) + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-b'], + [{ + ...assistantMessage('session-1', 'message-1'), + parts: [textPart('session-1', 'message-1', 'part-1', 'B')], + }], + ) + + const wrapper = ({ children }: { children: ReactNode }) => ( + {children} + ) + + // Use stable reference to avoid re-render loop with inline array + const directories = ['/repo-a', '/repo-b'] + const { result, unmount } = renderHook( + () => useSSE('http://localhost:5551', directories, 'session-1'), + { wrapper }, + ) + + await waitFor(() => { + expect(MockEventSource.instances.length).toBeGreaterThanOrEqual(1) + }) + + const eventSource = MockEventSource.instances[MockEventSource.instances.length - 1] + + act(() => { + eventSource.emit('connected', { clientId: 'client-1' }) + }) + + await waitFor(() => expect(result.current.isConnected).toBe(true)) + + act(() => { + eventSource.emit('message', { + type: 'message.part.delta', + directory: '/repo-b', + properties: { + sessionID: 'session-1', + messageID: 'message-1', + partID: 'part-1', + field: 'text', + delta: ' + streamed', + }, + }) + }) + + const repoBData = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-b', + ]) + expect(repoBData![0].parts[0]).toHaveProperty('text', 'B + streamed') + + const repoAData = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-a', + ]) + expect(repoAData![0].parts[0]).toHaveProperty('text', 'A') + + unmount() + } finally { + window.requestAnimationFrame = origRAF + } + }) + + it('processes part deltas when directory transitions from undefined to a real value', async () => { + const origRAF = window.requestAnimationFrame + window.requestAnimationFrame = ((cb: FrameRequestCallback) => { + cb(0) + return 0 + }) as typeof window.requestAnimationFrame + + try { + const queryClient = new QueryClient({ + defaultOptions: { queries: { retry: false } }, + }) + const wrapper = ({ children }: { children: ReactNode }) => ( + {children} + ) + + // Seed cache with an empty part + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + [{ + ...assistantMessage('session-1', 'message-1'), + parts: [textPart('session-1', 'message-1', 'part-1', '')], + }], + ) + + // Initial render with directory=undefined — batcher should be created eagerly + const { rerender, unmount } = renderHook( + ({ directory }) => useSSE('http://localhost:5551', directory, 'session-1'), + { wrapper, initialProps: { directory: undefined as string | undefined } }, + ) + + // No SSE subscription yet because directoriesList is empty + expect(MockEventSource.instances).toHaveLength(0) + + // Re-render with a real directory to start the SSE subscription + rerender({ directory: '/repo' }) + + await waitFor(() => expect(MockEventSource.instances).toHaveLength(1)) + const eventSource = MockEventSource.instances[0] + + act(() => { + eventSource.emit('connected', { clientId: 'client-1' }) + }) + + // Emit a part delta — the batcher was created on the initial mount, + // so it should process the event even though directory was undefined at mount time + act(() => { + eventSource.emit('message', { + type: 'message.part.delta', + directory: '/repo', + properties: { + sessionID: 'session-1', + messageID: 'message-1', + partID: 'part-1', + field: 'text', + delta: 'streamed content', + }, + }) + }) + + await waitFor(() => { + const data = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo', + ]) + expect(data![0].parts[0]).toHaveProperty('text', 'streamed content') + }) + + unmount() + } finally { + window.requestAnimationFrame = origRAF + } + }) }) + +function assistantMessage(sessionID: string, messageID: string): MessageWithParts { + return { + info: { + id: messageID, + sessionID, + role: 'assistant', + time: { created: Date.now() }, + parentID: '', + modelID: 'test-model', + providerID: 'test-provider', + mode: 'test', + agent: 'test-agent', + path: { cwd: '/test', root: '/test' }, + cost: 0, + tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + }, + parts: [], + } +} + +function textPart(sessionID: string, messageID: string, partID: string, text: string): Part { + return { id: partID, sessionID, messageID, type: 'text', text } as Part +} diff --git a/frontend/src/hooks/useSSE.ts b/frontend/src/hooks/useSSE.ts index 020d13f7..661acc59 100644 --- a/frontend/src/hooks/useSSE.ts +++ b/frontend/src/hooks/useSSE.ts @@ -73,19 +73,21 @@ export const useSSE = (opcodeUrl: string | null | undefined, directory?: string const batcherRef = useRef | null>(null) useEffect(() => { - if (!opcodeUrl || !primaryDirectory) { + if (!opcodeUrl) { batcherRef.current?.destroy() batcherRef.current = null return } - batcherRef.current = createPartsBatcher(queryClient, opcodeUrl, primaryDirectory) + if (!batcherRef.current) { + batcherRef.current = createPartsBatcher(queryClient, opcodeUrl) + } return () => { batcherRef.current?.destroy() batcherRef.current = null } - }, [queryClient, opcodeUrl, primaryDirectory]) + }, [queryClient, opcodeUrl]) const resolveCacheDirectory = useCallback( (eventDirectory: string | undefined): string | undefined => { @@ -134,14 +136,14 @@ export const useSSE = (opcodeUrl: string | null | undefined, directory?: string case 'messagev2.part.updated': { if (!('part' in event.properties)) break const { part } = event.properties - batcherRef.current?.queuePartUpdate(part.sessionID, part) + batcherRef.current?.queuePartUpdate(part.sessionID, part, cacheDirectory) break } case 'message.part.delta': { if (!('sessionID' in event.properties && 'messageID' in event.properties && 'partID' in event.properties && 'field' in event.properties && 'delta' in event.properties)) break const { sessionID, messageID, partID, field, delta } = event.properties - batcherRef.current?.queuePartDelta(sessionID, messageID, partID, field, delta) + batcherRef.current?.queuePartDelta(sessionID, messageID, partID, field, delta, cacheDirectory) break } @@ -166,15 +168,15 @@ export const useSSE = (opcodeUrl: string | null | undefined, directory?: string ? currentData.filter(msgWithParts => !msgWithParts.info.id.startsWith('optimistic_')) : currentData queryClient.setQueryData(messagesQueryKey, [...filteredData, { info, parts: [] }]) - return + } else { + const updated = currentData.map(msgWithParts => { + if (msgWithParts.info.id !== info.id) return msgWithParts + return { ...msgWithParts, info: { ...info } } + }) + queryClient.setQueryData(messagesQueryKey, updated) } - const updated = currentData.map(msgWithParts => { - if (msgWithParts.info.id !== info.id) return msgWithParts - return { ...msgWithParts, info: { ...info } } - }) - - queryClient.setQueryData(messagesQueryKey, updated) + batcherRef.current?.flush({ sessionID, directory: cacheDirectory }) break } @@ -200,7 +202,7 @@ export const useSSE = (opcodeUrl: string | null | undefined, directory?: string const { sessionID, messageID, partID } = event.properties - batcherRef.current?.queuePartRemoval(sessionID, messageID, partID) + batcherRef.current?.queuePartRemoval(sessionID, messageID, partID, cacheDirectory) break } @@ -224,11 +226,14 @@ export const useSSE = (opcodeUrl: string | null | undefined, directory?: string setSessionStatus(sessionID, { type: 'idle' }) - batcherRef.current?.flush() + batcherRef.current?.flush({ sessionID, directory: cacheDirectory }) const messagesQueryKey = ['opencode', 'messages', opcodeUrl, sessionID, cacheDirectory] const currentData = queryClient.getQueryData(messagesQueryKey) - if (!currentData) break + if (!currentData) { + queryClient.invalidateQueries({ queryKey: messagesQueryKey }) + break + } const now = Date.now() const updated = currentData.map(msgWithParts => { diff --git a/frontend/src/lib/opencode-event-stream/__tests__/openCodeEventStream.test.ts b/frontend/src/lib/opencode-event-stream/__tests__/openCodeEventStream.test.ts index 8ae73320..69848f46 100644 --- a/frontend/src/lib/opencode-event-stream/__tests__/openCodeEventStream.test.ts +++ b/frontend/src/lib/opencode-event-stream/__tests__/openCodeEventStream.test.ts @@ -19,7 +19,7 @@ describe('OpenCodeEventStream', () => { stream.subscribeGlobalMonitor({ directories: ['/repo'], onEvent }) transport.openConnection() transport.connected() - transport.message({ type: 'permission.asked', properties: { sessionID: 'session-1' }, directory: '/repo' }) + transport.message({ directory: '/repo', payload: { type: 'permission.asked', properties: { sessionID: 'session-1' } } }) expect(onEvent).toHaveBeenCalledWith({ type: 'permission.asked', @@ -79,6 +79,56 @@ describe('OpenCodeEventStream', () => { }) }) + it('opens the initial EventSource URL with first subscriber directories', () => { + const transport = new TestEventStreamTransport() + const stream = new OpenCodeEventStream({ transport }) + + stream.subscribeGlobalMonitor({ directories: ['/repo'], onEvent: vi.fn() }) + + expect(transport.openedUrls).toHaveLength(1) + const url = transport.openedUrls[0] + expect(url).toContain('/api/sse/stream') + expect(url).toContain(encodeURIComponent('/repo')) + }) + + it('opens the initial EventSource URL with all first subscriber directories', () => { + const transport = new TestEventStreamTransport() + const stream = new OpenCodeEventStream({ transport }) + + stream.subscribeGlobalMonitor({ directories: ['/repo-a', '/repo-b'], onEvent: vi.fn() }) + + expect(transport.openedUrls).toHaveLength(1) + const url = transport.openedUrls[0] + expect(url).toContain('/api/sse/stream') + expect(url).toContain(encodeURIComponent('/repo-a')) + expect(url).toContain(encodeURIComponent('/repo-b')) + }) + + it('marks health unhealthy when backend connected event reports no upstream connection', () => { + const transport = new TestEventStreamTransport() + const stream = new OpenCodeEventStream({ transport }) + const healthStates: EventStreamHealthState[] = [] + + stream.subscribeGlobalMonitor({ + directories: [], + onEvent: vi.fn(), + onHealthChange: (health) => healthStates.push(health), + }) + + transport.openConnection() + transport.connectedPayload({ clientId: 'client-1', connected: 0, total: 1 }) + + const unhealthyState = healthStates.at(-1) + expect(unhealthyState?.isConnected).toBe(true) + expect(unhealthyState?.isHealthy).toBe(false) + + transport.connectedPayload({ clientId: 'client-1', connected: 1, total: 1 }) + + const healthyState = healthStates.at(-1) + expect(healthyState?.isConnected).toBe(true) + expect(healthyState?.isHealthy).toBe(true) + }) + it('reports visibility through the transport adapter', async () => { const transport = new TestEventStreamTransport() const stream = new OpenCodeEventStream({ transport }) diff --git a/frontend/src/lib/opencode-event-stream/openCodeEventStream.ts b/frontend/src/lib/opencode-event-stream/openCodeEventStream.ts index 62c6d4c2..b9946f68 100644 --- a/frontend/src/lib/opencode-event-stream/openCodeEventStream.ts +++ b/frontend/src/lib/opencode-event-stream/openCodeEventStream.ts @@ -35,6 +35,8 @@ export class OpenCodeEventStream { private clientId: string | null = null private lastEventAt: number | null = null private watchdogTimer: ReturnType | null = null + private upstreamConnectedCount: number | null = null + private upstreamTotalCount: number | null = null constructor(options: OpenCodeEventStreamOptions = {}) { this.transport = options.transport ?? createBrowserEventStreamTransport() @@ -46,8 +48,7 @@ export class OpenCodeEventStream { onStatusChange?: EventStreamStatusHandler onHealthChange?: (state: EventStreamHealthState) => void }): GlobalMonitorSubscription { - const id = this.addSubscriber(input.onEvent, input.onStatusChange, input.onHealthChange) - this.updateSubscriberDirectories(id, input.directories) + const id = this.addSubscriber(input.onEvent, input.onStatusChange, input.onHealthChange, input.directories) return { updateDirectories: (directories) => this.updateSubscriberDirectories(id, directories), @@ -65,14 +66,31 @@ export class OpenCodeEventStream { onEvent: OpenCodeEventHandler, onStatusChange?: EventStreamStatusHandler, onHealthChange?: (state: EventStreamHealthState) => void, + directories: string[] = [], ): string { const id = `sub_${++this.subscriberIdCounter}` + + const initialDirectories = new Set(directories.filter(Boolean)) + const newDirectories: string[] = [] + + for (const directory of initialDirectories) { + const currentCount = this.directoryRefCounts.get(directory) ?? 0 + this.directoryRefCounts.set(directory, currentCount + 1) + if (currentCount === 0) { + if (this.clientId && this.connected) { + newDirectories.push(directory) + } else { + this.pendingDirectories.add(directory) + } + } + } + this.subscribers.set(id, { id, onEvent, onStatusChange, onHealthChange, - directories: new Set(), + directories: initialDirectories, }) onStatusChange?.(this.connected) @@ -80,6 +98,8 @@ export class OpenCodeEventStream { if (this.subscribers.size === 1) { this.connect() + } else if (newDirectories.length > 0) { + void this.subscribeToRemoteDirectories(newDirectories) } return id @@ -186,6 +206,8 @@ export class OpenCodeEventStream { private handleError(): void { this.connected = false this.clientId = null + this.upstreamConnectedCount = null + this.upstreamTotalCount = null this.stopWatchdog() this.lastEventAt = null @@ -205,7 +227,7 @@ export class OpenCodeEventStream { private handleMessage(data: string): void { try { this.markActivity() - this.broadcast(JSON.parse(data)) + this.broadcast(flattenEventEnvelope(JSON.parse(data))) } catch { this.markActivity() } @@ -213,10 +235,16 @@ export class OpenCodeEventStream { private handleConnected(data: string): void { try { - const parsed = JSON.parse(data) as { clientId?: unknown } + const parsed = JSON.parse(data) as { clientId?: unknown; connected?: unknown; total?: unknown } if (typeof parsed.clientId === 'string') { this.clientId = parsed.clientId } + if (typeof parsed.connected === 'number') { + this.upstreamConnectedCount = parsed.connected + } + if (typeof parsed.total === 'number') { + this.upstreamTotalCount = parsed.total + } } catch { this.clientId = null } @@ -239,6 +267,8 @@ export class OpenCodeEventStream { this.connection = null this.connected = false this.clientId = null + this.upstreamConnectedCount = null + this.upstreamTotalCount = null this.lastEventAt = null this.pendingDirectories.clear() this.notifyHealth() @@ -263,6 +293,8 @@ export class OpenCodeEventStream { this.connection = null this.connected = false this.clientId = null + this.upstreamConnectedCount = null + this.upstreamTotalCount = null this.lastEventAt = null this.pendingDirectories = new Set(this.directoryRefCounts.keys()) this.notifyStatusChange(false) @@ -308,9 +340,10 @@ export class OpenCodeEventStream { private buildHealth(): EventStreamHealthState { const isStalled = this.connected && this.lastEventAt != null && Date.now() - this.lastEventAt > STALL_THRESHOLD_MS + const upstreamDisconnected = this.upstreamTotalCount != null && this.upstreamTotalCount > 0 && this.upstreamConnectedCount === 0 return { isConnected: this.connected, - isHealthy: this.connected && this.lastEventAt != null && !isStalled, + isHealthy: this.connected && this.lastEventAt != null && !isStalled && !upstreamDisconnected, lastEventAt: this.lastEventAt, isStalled, } @@ -381,4 +414,18 @@ export class OpenCodeEventStream { } } +function flattenEventEnvelope(parsed: unknown): unknown { + if ( + parsed !== null && + typeof parsed === 'object' && + 'payload' in parsed && + (parsed as { payload: unknown }).payload !== null && + typeof (parsed as { payload: unknown }).payload === 'object' + ) { + const { payload, directory } = parsed as { payload: object; directory?: unknown } + return { ...payload, directory } + } + return parsed +} + export const openCodeEventStream = new OpenCodeEventStream() diff --git a/frontend/src/lib/opencode-event-stream/testTransport.ts b/frontend/src/lib/opencode-event-stream/testTransport.ts index d7e7511e..68086192 100644 --- a/frontend/src/lib/opencode-event-stream/testTransport.ts +++ b/frontend/src/lib/opencode-event-stream/testTransport.ts @@ -2,11 +2,13 @@ import type { EventStreamConnection, EventStreamTransport, EventStreamTransportH export class TestEventStreamTransport implements EventStreamTransport { readonly posts: Array<{ path: string; body: unknown }> = [] + readonly openedUrls: string[] = [] closeCount = 0 private handlers: EventStreamTransportHandlers | null = null private connection: EventStreamConnection | null = null - open(_url: string, handlers: EventStreamTransportHandlers): EventStreamConnection { + open(url: string, handlers: EventStreamTransportHandlers): EventStreamConnection { + this.openedUrls.push(url) this.handlers = handlers this.connection = { close: () => { @@ -29,7 +31,11 @@ export class TestEventStreamTransport implements EventStreamTransport { } connected(clientId = 'test-client'): void { - this.handlers?.onConnected(JSON.stringify({ clientId })) + this.connectedPayload({ clientId }) + } + + connectedPayload(payload: unknown): void { + this.handlers?.onConnected(JSON.stringify(payload)) } message(data: unknown): void { diff --git a/frontend/src/lib/partsBatcher.test.ts b/frontend/src/lib/partsBatcher.test.ts new file mode 100644 index 00000000..ee0356ca --- /dev/null +++ b/frontend/src/lib/partsBatcher.test.ts @@ -0,0 +1,250 @@ +import { QueryClient } from '@tanstack/react-query' +import { describe, it, expect, vi } from 'vitest' +import { createPartsBatcher } from './partsBatcher' +import type { Part, MessageWithParts } from '@/api/types' + +function assistantMessage(sessionID: string, messageID: string): MessageWithParts { + return { + info: { + id: messageID, + sessionID, + role: 'assistant', + time: { created: Date.now() }, + parentID: '', + modelID: 'test-model', + providerID: 'test-provider', + mode: 'test', + agent: 'test-agent', + path: { cwd: '/test', root: '/test' }, + cost: 0, + tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + }, + parts: [], + } +} + +function textPart(sessionID: string, messageID: string, partID: string, text: string): Part { + return { id: partID, sessionID, messageID, type: 'text', text } as Part +} + +function createManyCachedMessages(count: number, sessionID: string): MessageWithParts[] { + const messages: MessageWithParts[] = [] + for (let i = 0; i < count; i++) { + const msg = assistantMessage(sessionID, `msg-${i}`) + msg.parts = [textPart(sessionID, `msg-${i}`, `part-${i}`, `base text ${i}`)] + messages.push(msg) + } + return messages +} + +describe('createPartsBatcher', () => { + it('invalidates when part deltas arrive before message cache exists and applies a later authoritative upsert', () => { + const queryClient = new QueryClient() + const invalidateSpy = vi.spyOn(queryClient, 'invalidateQueries') + const batcher = createPartsBatcher(queryClient, 'http://localhost:5551') + + batcher.queuePartDelta('session-1', 'message-1', 'part-1', 'text', 'Hello', '/repo') + batcher.flush() + + expect( + queryClient.getQueryData(['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo']), + ).toBeUndefined() + + expect(invalidateSpy).toHaveBeenCalledWith({ + queryKey: ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + }) + + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + [assistantMessage('session-1', 'message-1')], + ) + + batcher.queuePartUpdate('session-1', textPart('session-1', 'message-1', 'part-1', 'Hello world'), '/repo') + batcher.flush() + + const data = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo', + ]) + expect(data).toHaveLength(1) + expect(data![0].parts).toHaveLength(1) + expect(data![0].parts[0]).toHaveProperty('text', 'Hello world') + }) + + it('does not replay stale deltas after authoritative upsert resolves the part', () => { + const queryClient = new QueryClient() + const batcher = createPartsBatcher(queryClient, 'http://localhost:5551') + + batcher.queuePartDelta('session-1', 'message-1', 'part-1', 'text', 'stale delta ', '/repo') + batcher.flush() + + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + [assistantMessage('session-1', 'message-1')], + ) + + batcher.queuePartUpdate('session-1', textPart('session-1', 'message-1', 'part-1', 'authoritative text'), '/repo') + batcher.flush() + + const data = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo', + ]) + expect(data![0].parts).toHaveLength(1) + expect(data![0].parts[0]).toHaveProperty('text', 'authoritative text') + + batcher.flush() + const data2 = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo', + ]) + expect(data2![0].parts[0]).toHaveProperty('text', 'authoritative text') + }) + + it('does not replay unapplied deltas onto refetched authoritative data', () => { + const queryClient = new QueryClient() + const invalidateSpy = vi.spyOn(queryClient, 'invalidateQueries') + const batcher = createPartsBatcher(queryClient, 'http://localhost:5551') + + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + [assistantMessage('session-1', 'message-1')], + ) + + batcher.queuePartDelta('session-1', 'message-1', 'part-1', 'text', ' stale', '/repo') + batcher.flush() + + expect(invalidateSpy).toHaveBeenCalledWith({ + queryKey: ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + }) + + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + [{ ...assistantMessage('session-1', 'message-1'), parts: [textPart('session-1', 'message-1', 'part-1', 'fresh')] }], + ) + + batcher.flush() + + const data = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo', + ]) + expect(data![0].parts[0]).toHaveProperty('text', 'fresh') + }) + + it('applies deltas queued after an authoritative upsert in the same batch', () => { + const queryClient = new QueryClient() + const batcher = createPartsBatcher(queryClient, 'http://localhost:5551') + + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + [assistantMessage('session-1', 'message-1')], + ) + + batcher.queuePartUpdate('session-1', textPart('session-1', 'message-1', 'part-1', 'snapshot'), '/repo') + batcher.queuePartDelta('session-1', 'message-1', 'part-1', 'text', ' later', '/repo') + batcher.flush() + + const data = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo', + ]) + expect(data![0].parts).toHaveLength(1) + expect(data![0].parts[0]).toHaveProperty('text', 'snapshot later') + }) + + it('applies many queued part deltas with one cache write and no invalidation storm', () => { + const queryClient = new QueryClient() + const batcher = createPartsBatcher(queryClient, 'http://localhost:5551') + + const sessionID = 'session-1' + const directory = '/repo' + const messageCount = 1000 + const deltaCount = 500 + + const messages = createManyCachedMessages(messageCount, sessionID) + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', sessionID, directory], + messages, + ) + + const setQueryDataSpy = vi.spyOn(queryClient, 'setQueryData') + const invalidateSpy = vi.spyOn(queryClient, 'invalidateQueries') + + for (let i = 0; i < deltaCount; i++) { + batcher.queuePartDelta(sessionID, `msg-${i}`, `part-${i}`, 'text', ` delta ${i}`, directory) + } + + batcher.flush() + + expect(invalidateSpy).not.toHaveBeenCalled() + + const data = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', sessionID, directory, + ]) + expect(data).toHaveLength(messageCount) + + for (let i = 0; i < deltaCount; i++) { + expect(data![i].parts[0]).toHaveProperty('text', `base text ${i} delta ${i}`) + } + + for (let i = deltaCount; i < messageCount; i++) { + expect(data![i].parts[0]).toHaveProperty('text', `base text ${i}`) + } + + const setQueryDataCalls = setQueryDataSpy.mock.calls.filter( + ([key]) => JSON.stringify(key).includes('opencode'), + ) + expect(setQueryDataCalls.length).toBe(1) + }) + + it('does not apply same-batch deltas for removed parts to shifted parts', () => { + const queryClient = new QueryClient() + const batcher = createPartsBatcher(queryClient, 'http://localhost:5551') + + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo'], + [{ + ...assistantMessage('session-1', 'message-1'), + parts: [ + textPart('session-1', 'message-1', 'part-1', 'first'), + textPart('session-1', 'message-1', 'part-2', 'second'), + ], + }], + ) + + batcher.queuePartRemoval('session-1', 'message-1', 'part-1', '/repo') + batcher.queuePartDelta('session-1', 'message-1', 'part-1', 'text', ' stale', '/repo') + batcher.flush() + + const data = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo', + ]) + + expect(data![0].parts).toHaveLength(1) + expect(data![0].parts[0]).toHaveProperty('id', 'part-2') + expect(data![0].parts[0]).toHaveProperty('text', 'second') + }) + + it('applies deltas to the directory they were queued for', () => { + const queryClient = new QueryClient() + const batcher = createPartsBatcher(queryClient, 'http://localhost:5551') + + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-a'], + [assistantMessage('session-1', 'message-1')], + ) + queryClient.setQueryData( + ['opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-b'], + [{ ...assistantMessage('session-1', 'message-1'), parts: [textPart('session-1', 'message-1', 'part-1', 'B')] }], + ) + + batcher.queuePartDelta('session-1', 'message-1', 'part-1', 'text', ' + chunk', '/repo-b') + batcher.flush() + + const repoBData = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-b', + ]) + expect(repoBData![0].parts[0]).toHaveProperty('text', 'B + chunk') + + const repoAData = queryClient.getQueryData([ + 'opencode', 'messages', 'http://localhost:5551', 'session-1', '/repo-a', + ]) + expect(repoAData![0].parts).toHaveLength(0) + }) +}) diff --git a/frontend/src/lib/partsBatcher.ts b/frontend/src/lib/partsBatcher.ts index 8c879ed0..f14590da 100644 --- a/frontend/src/lib/partsBatcher.ts +++ b/frontend/src/lib/partsBatcher.ts @@ -2,10 +2,10 @@ import type { QueryClient } from '@tanstack/react-query' import type { Part, MessageWithParts } from '@/api/types' interface PartsBatcher { - queuePartUpdate: (sessionID: string, part: Part) => void - queuePartDelta: (sessionID: string, messageID: string, partID: string, field: string, delta: string) => void - queuePartRemoval: (sessionID: string, messageID: string, partID: string) => void - flush: () => void + queuePartUpdate: (sessionID: string, part: Part, directory?: string) => void + queuePartDelta: (sessionID: string, messageID: string, partID: string, field: string, delta: string, directory?: string) => void + queuePartRemoval: (sessionID: string, messageID: string, partID: string, directory?: string) => void + flush: (target?: { sessionID?: string; directory?: string }) => void destroy: () => void } @@ -14,12 +14,21 @@ type PartOperation = | { type: 'delta'; messageID: string; partID: string; field: string; delta: string } | { type: 'remove'; messageID: string; partID: string } +type OperationGroup = { + sessionID: string + directory?: string + operations: PartOperation[] +} + +function groupKey(sessionID: string, directory?: string): string { + return `${directory ?? ''}\0${sessionID}` +} + export function createPartsBatcher( queryClient: QueryClient, opcodeUrl: string, - directory?: string ): PartsBatcher { - const pendingOperations = new Map() + const pendingOperations = new Map() let pendingFrameId: number | null = null const scheduleFlush = () => { @@ -30,78 +39,180 @@ export function createPartsBatcher( }) } - const flush = () => { + const flush = (target?: { sessionID?: string; directory?: string }) => { if (pendingOperations.size === 0) return - for (const [sessionID, operations] of pendingOperations.entries()) { + const groupsToDelete: string[] = [] + const invalidatedGroupKeys = new Set() + + for (const [key, group] of pendingOperations.entries()) { + if (target) { + if (target.sessionID !== undefined && group.sessionID !== target.sessionID) continue + if (target.directory !== undefined && group.directory !== target.directory) continue + } + + const { sessionID, directory } = group const queryKey = ['opencode', 'messages', opcodeUrl, sessionID, directory] const currentData = queryClient.getQueryData(queryKey) - if (!currentData) continue + if (!currentData) { + if (!invalidatedGroupKeys.has(key)) { + invalidatedGroupKeys.add(key) + queryClient.invalidateQueries({ queryKey }) + } + groupsToDelete.push(key) + continue + } - let updatedData = [...currentData] + let updatedData = currentData + let dataMutated = false + const unapplied: PartOperation[] = [] + const supersededPartIDs = new Set() - for (const operation of operations) { - updatedData = updatedData.map((msgWithParts) => { - if (operation.type === 'upsert') { - if (msgWithParts.info.id !== operation.part.messageID) return msgWithParts + const messageIdxById = new Map() + for (let i = 0; i < currentData.length; i++) { + messageIdxById.set(currentData[i].info.id, i) + } - const existingIdx = msgWithParts.parts.findIndex((part) => part.id === operation.part.id) - const parts = [...msgWithParts.parts] - if (existingIdx >= 0) { - parts[existingIdx] = operation.part - } else { - parts.push(operation.part) - } + const partIdxCache = new Map>() - return { ...msgWithParts, parts } + const ensurePartIdx = (msgIdx: number, parts: Part[]): Map => { + let cache = partIdxCache.get(msgIdx) + if (!cache) { + cache = new Map() + for (let i = 0; i < parts.length; i++) { + cache.set(parts[i].id, i) } + partIdxCache.set(msgIdx, cache) + } + return cache + } - if (msgWithParts.info.id !== operation.messageID) return msgWithParts - - if (operation.type === 'remove') { - return { - ...msgWithParts, - parts: msgWithParts.parts.filter((part) => part.id !== operation.partID), - } + for (const operation of group.operations) { + if (operation.type === 'upsert') { + const msgIdx = messageIdxById.get(operation.part.messageID) + if (msgIdx === undefined) { + unapplied.push(operation) + continue + } + if (!dataMutated) { + updatedData = [...currentData] + dataMutated = true + } + const msg = updatedData[msgIdx] + const pIdx = ensurePartIdx(msgIdx, msg.parts) + const existingPartIdx = pIdx.get(operation.part.id) + let nextParts: Part[] + if (existingPartIdx !== undefined) { + nextParts = [...msg.parts] + nextParts[existingPartIdx] = operation.part + } else { + nextParts = [...msg.parts, operation.part] + pIdx.set(operation.part.id, nextParts.length - 1) + } + updatedData[msgIdx] = { ...msg, parts: nextParts } + supersededPartIDs.add(operation.part.id) + continue + } + + if (operation.type === 'remove') { + const msgIdx = messageIdxById.get(operation.messageID) + if (msgIdx === undefined) { + unapplied.push(operation) + continue + } + if (!dataMutated) { + updatedData = [...currentData] + dataMutated = true + } + const msg = updatedData[msgIdx] + const pIdx = ensurePartIdx(msgIdx, msg.parts) + if (pIdx.get(operation.partID) === undefined) { + unapplied.push(operation) + continue } + const nextParts = msg.parts.filter((part) => part.id !== operation.partID) + updatedData[msgIdx] = { ...msg, parts: nextParts } + partIdxCache.delete(msgIdx) + supersededPartIDs.add(operation.partID) + continue + } + + const msgIdx = messageIdxById.get(operation.messageID) + if (msgIdx === undefined) { + unapplied.push(operation) + continue + } + if (!dataMutated) { + updatedData = [...currentData] + dataMutated = true + } + const msg = updatedData[msgIdx] + const pIdx = ensurePartIdx(msgIdx, msg.parts) + const pIdxResult = pIdx.get(operation.partID) + if (pIdxResult === undefined) { + unapplied.push(operation) + continue + } + const targetPart = msg.parts[pIdxResult] + if (!targetPart) { + unapplied.push(operation) + continue + } + const nextParts = [...msg.parts] + const currentValue = (targetPart as Record)[operation.field] + const nextValue = `${typeof currentValue === 'string' ? currentValue : ''}${operation.delta}` + nextParts[pIdxResult] = { ...targetPart, [operation.field]: nextValue } as Part + updatedData[msgIdx] = { ...msg, parts: nextParts } + } - return { - ...msgWithParts, - parts: msgWithParts.parts.map((part) => { - if (part.id !== operation.partID) return part + if (dataMutated) { + queryClient.setQueryData(queryKey, updatedData) + } - const currentValue = (part as Record)[operation.field] - const nextValue = `${typeof currentValue === 'string' ? currentValue : ''}${operation.delta}` - return { ...part, [operation.field]: nextValue } as Part - }), - } - }) + const filteredUnapplied = unapplied.filter((op) => { + if (op.type === 'delta' || op.type === 'remove') { + return !supersededPartIDs.has(op.partID) + } + return true + }) + + if (filteredUnapplied.length > 0) { + if (!invalidatedGroupKeys.has(key)) { + invalidatedGroupKeys.add(key) + queryClient.invalidateQueries({ queryKey }) + } } - queryClient.setQueryData(queryKey, updatedData) + groupsToDelete.push(key) } - pendingOperations.clear() + for (const key of groupsToDelete) { + pendingOperations.delete(key) + } } - const queueOperation = (sessionID: string, operation: PartOperation) => { - const operations = pendingOperations.get(sessionID) ?? [] - operations.push(operation) - pendingOperations.set(sessionID, operations) + const queueOperation = (sessionID: string, operation: PartOperation, directory?: string) => { + const key = groupKey(sessionID, directory) + let group = pendingOperations.get(key) + if (!group) { + group = { sessionID, directory, operations: [] } + pendingOperations.set(key, group) + } + group.operations.push(operation) scheduleFlush() } - const queuePartUpdate = (sessionID: string, part: Part) => { - queueOperation(sessionID, { type: 'upsert', part }) + const queuePartUpdate = (sessionID: string, part: Part, directory?: string) => { + queueOperation(sessionID, { type: 'upsert', part }, directory) } - const queuePartDelta = (sessionID: string, messageID: string, partID: string, field: string, delta: string) => { - queueOperation(sessionID, { type: 'delta', messageID, partID, field, delta }) + const queuePartDelta = (sessionID: string, messageID: string, partID: string, field: string, delta: string, directory?: string) => { + queueOperation(sessionID, { type: 'delta', messageID, partID, field, delta }, directory) } - const queuePartRemoval = (sessionID: string, messageID: string, partID: string) => { - queueOperation(sessionID, { type: 'remove', messageID, partID }) + const queuePartRemoval = (sessionID: string, messageID: string, partID: string, directory?: string) => { + queueOperation(sessionID, { type: 'remove', messageID, partID }, directory) } const destroy = () => { diff --git a/scripts/docker-upgrade.sh b/scripts/docker-upgrade.sh index a630828b..a2f8ffec 100755 --- a/scripts/docker-upgrade.sh +++ b/scripts/docker-upgrade.sh @@ -13,9 +13,35 @@ fi NO_PULL=false SHOW_LOGS=false +MODE="cached" # cached | tools | full + +usage() { + cat <