diff --git a/src/main/core/conversations/conversation-session-supervisor.test.ts b/src/main/core/conversations/conversation-session-supervisor.test.ts new file mode 100644 index 0000000000..863a3a7867 --- /dev/null +++ b/src/main/core/conversations/conversation-session-supervisor.test.ts @@ -0,0 +1,82 @@ +import { describe, expect, it, vi } from 'vitest'; +import type { Pty } from '@main/core/pty/pty'; +import { + CONVERSATION_REPLACEMENT_SUSTAINED_MS, + ConversationSessionSupervisor, +} from './conversation-session-supervisor'; + +function fakePty(): Pty { + return { + write: vi.fn(), + resize: vi.fn(), + kill: vi.fn(), + onData: vi.fn(), + onExit: vi.fn(), + }; +} + +describe('ConversationSessionSupervisor', () => { + it('rejects and kills a spawn that returns after explicit stop invalidated its generation', () => { + const supervisor = new ConversationSessionSupervisor(); + const token = supervisor.beginStart('session-1'); + expect(token).toBeDefined(); + + supervisor.stop('session-1'); + + const pty = fakePty(); + expect(supervisor.acceptSpawn('session-1', token!, pty)).toBe(false); + }); + + it('allows one replacement inside a failure window and then fails until sustained running resets it', () => { + vi.useFakeTimers(); + try { + const supervisor = new ConversationSessionSupervisor(); + const first = fakePty(); + const second = fakePty(); + const firstToken = supervisor.beginStart('session-1'); + expect(supervisor.acceptSpawn('session-1', firstToken!, first)).toBe(true); + + expect(supervisor.handleExit('session-1', first)).toEqual({ kind: 'replace' }); + const secondToken = supervisor.beginStart('session-1', { + requireDesired: true, + }); + expect(supervisor.acceptSpawn('session-1', secondToken!, second)).toBe(true); + expect(supervisor.handleExit('session-1', second)).toEqual({ kind: 'failed' }); + + const third = fakePty(); + const thirdToken = supervisor.beginStart('session-2'); + expect(supervisor.acceptSpawn('session-2', thirdToken!, third)).toBe(true); + expect(supervisor.handleExit('session-2', third)).toEqual({ kind: 'replace' }); + const fourth = fakePty(); + const fourthToken = supervisor.beginStart('session-2', { + requireDesired: true, + }); + expect(supervisor.acceptSpawn('session-2', fourthToken!, fourth)).toBe(true); + vi.advanceTimersByTime(CONVERSATION_REPLACEMENT_SUSTAINED_MS); + expect(supervisor.handleExit('session-2', fourth)).toEqual({ kind: 'replace' }); + } finally { + vi.useRealTimers(); + } + }); + + it('clears the replacement failure window after a spawn failure', () => { + const supervisor = new ConversationSessionSupervisor(); + const first = fakePty(); + const firstToken = supervisor.beginStart('session-1'); + expect(supervisor.acceptSpawn('session-1', firstToken!, first)).toBe(true); + + expect(supervisor.handleExit('session-1', first)).toEqual({ kind: 'replace' }); + const failedToken = supervisor.beginStart('session-1', { + requireDesired: true, + }); + expect(failedToken).toBeDefined(); + supervisor.failSpawn('session-1', failedToken!); + + const retry = fakePty(); + const retryToken = supervisor.beginStart('session-1', { + requireDesired: true, + }); + expect(supervisor.acceptSpawn('session-1', retryToken!, retry)).toBe(true); + expect(supervisor.handleExit('session-1', retry)).toEqual({ kind: 'replace' }); + }); +}); diff --git a/src/main/core/conversations/conversation-session-supervisor.ts b/src/main/core/conversations/conversation-session-supervisor.ts new file mode 100644 index 0000000000..95b3f7ddf0 --- /dev/null +++ b/src/main/core/conversations/conversation-session-supervisor.ts @@ -0,0 +1,133 @@ +import type { Pty } from '@main/core/pty/pty'; + +export type ConversationSpawnToken = { + generation: number; +}; + +type ConversationRuntime = { + desired: boolean; + pty?: Pty; + spawnInFlightGeneration?: number; + replacementGeneration: number; + replacementAttemptedInWindow: boolean; + stableTimer?: ReturnType; +}; + +export const CONVERSATION_REPLACEMENT_SUSTAINED_MS = 5_000; + +export type ExitDecision = + | { kind: 'stale' } + | { kind: 'stopped' } + | { kind: 'replace' } + | { kind: 'failed' }; + +export class ConversationSessionSupervisor { + private runtimes = new Map(); + + beginStart( + sessionId: string, + options: { requireDesired?: boolean } = {} + ): ConversationSpawnToken | undefined { + const runtime = this.getOrCreateRuntime(sessionId); + if (runtime.pty || runtime.spawnInFlightGeneration !== undefined) return undefined; + if (options.requireDesired === true && !runtime.desired) return undefined; + + runtime.desired = true; + runtime.replacementGeneration += 1; + runtime.spawnInFlightGeneration = runtime.replacementGeneration; + + return { generation: runtime.replacementGeneration }; + } + + acceptSpawn(sessionId: string, token: ConversationSpawnToken, pty: Pty): boolean { + const runtime = this.runtimes.get(sessionId); + if (!runtime || runtime.spawnInFlightGeneration !== token.generation) return false; + + runtime.spawnInFlightGeneration = undefined; + if (!runtime.desired || runtime.replacementGeneration !== token.generation) return false; + + runtime.pty = pty; + this.armStableTimer(runtime); + return true; + } + + failSpawn(sessionId: string, token: ConversationSpawnToken): void { + const runtime = this.runtimes.get(sessionId); + if (!runtime || runtime.spawnInFlightGeneration !== token.generation) return; + runtime.spawnInFlightGeneration = undefined; + runtime.replacementAttemptedInWindow = false; + } + + stop(sessionId: string): Pty | undefined { + const runtime = this.runtimes.get(sessionId); + if (!runtime) return undefined; + + runtime.desired = false; + runtime.replacementGeneration += 1; + runtime.spawnInFlightGeneration = undefined; + this.clearStableTimer(runtime); + + const pty = runtime.pty; + runtime.pty = undefined; + runtime.replacementAttemptedInWindow = false; + return pty; + } + + isDesired(sessionId: string): boolean { + return this.runtimes.get(sessionId)?.desired === true; + } + + handleExit(sessionId: string, pty: Pty): ExitDecision { + const runtime = this.runtimes.get(sessionId); + if (!runtime || runtime.pty !== pty) return { kind: 'stale' }; + + runtime.pty = undefined; + runtime.spawnInFlightGeneration = undefined; + this.clearStableTimer(runtime); + + if (!runtime.desired) return { kind: 'stopped' }; + + if (runtime.replacementAttemptedInWindow) { + runtime.desired = false; + runtime.replacementAttemptedInWindow = false; + return { kind: 'failed' }; + } + + runtime.replacementAttemptedInWindow = true; + return { kind: 'replace' }; + } + + forget(sessionId: string): void { + const runtime = this.runtimes.get(sessionId); + if (runtime) this.clearStableTimer(runtime); + this.runtimes.delete(sessionId); + } + + private getOrCreateRuntime(sessionId: string): ConversationRuntime { + let runtime = this.runtimes.get(sessionId); + if (!runtime) { + runtime = { + desired: false, + replacementGeneration: 0, + replacementAttemptedInWindow: false, + }; + this.runtimes.set(sessionId, runtime); + } + return runtime; + } + + private armStableTimer(runtime: ConversationRuntime): void { + this.clearStableTimer(runtime); + runtime.stableTimer = setTimeout(() => { + runtime.replacementAttemptedInWindow = false; + runtime.stableTimer = undefined; + }, CONVERSATION_REPLACEMENT_SUSTAINED_MS); + } + + private clearStableTimer(runtime: ConversationRuntime): void { + if (runtime.stableTimer !== undefined) { + clearTimeout(runtime.stableTimer); + runtime.stableTimer = undefined; + } + } +} diff --git a/src/main/core/conversations/impl/conversation-provider-respawn.test.ts b/src/main/core/conversations/impl/conversation-provider-respawn.test.ts index 2aba199074..1e9d6682b9 100644 --- a/src/main/core/conversations/impl/conversation-provider-respawn.test.ts +++ b/src/main/core/conversations/impl/conversation-provider-respawn.test.ts @@ -1,7 +1,9 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import type { Pty, PtyExitInfo } from '@main/core/pty/pty'; +import { ptySessionRegistry } from '@main/core/pty/pty-session-registry'; import type { Conversation } from '@shared/conversations'; import { agentSessionExitedChannel } from '@shared/events/agentEvents'; +import { ptyExitChannel } from '@shared/events/ptyEvents'; import { makePtySessionId } from '@shared/ptySessionId'; import { LocalConversationProvider } from './local-conversation'; import { SshConversationProvider } from './ssh-conversation'; @@ -83,9 +85,9 @@ vi.mock('@main/core/settings/settings-service', () => ({ })); const { events } = await import('@main/lib/events'); +const { buildAgentSessionCommand } = await import('./agent-command'); type RespawnState = { - respawnCounts: Map; knownSessionIds: Set; sessions: Map; }; @@ -111,7 +113,10 @@ function sshProvider( { tmux = false, ctx = {} as never, - }: { tmux?: boolean; ctx?: ConstructorParameters[0]['ctx'] } = {} + }: { + tmux?: boolean; + ctx?: ConstructorParameters[0]['ctx']; + } = {} ) { return new SshConversationProvider({ projectId: 'project-1', @@ -152,6 +157,8 @@ describe('conversation provider respawn state', () => { spawnLocalPty.mockReset(); openSsh2Pty.mockReset(); vi.mocked(events.emit).mockClear(); + vi.mocked(buildAgentSessionCommand).mockClear(); + ptySessionRegistry.unregister('project-1:task-1:conversation-1'); }); it('passes global editor variables to local agent sessions', async () => { @@ -175,73 +182,149 @@ describe('conversation provider respawn state', () => { } }); - it('preserves resume mode when a local resumed session respawns within budget', async () => { + it('replaces a local conversation after clean exit by resuming the same provider session', async () => { vi.useFakeTimers(); try { - const exitHandlers: Array<(info: PtyExitInfo) => void> = []; - spawnLocalPty.mockReturnValue(fakePty(exitHandlers)); + const exitHandlers: Array void>> = []; + spawnLocalPty.mockImplementation(() => { + const handlers: Array<(info: PtyExitInfo) => void> = []; + exitHandlers.push(handlers); + return fakePty(handlers); + }); const provider = localProvider(); const size = { cols: 100, rows: 40 }; const initialPrompt = 'continue'; - const item = { ...conversation(), providerSessionId: undefined }; + const item = conversation(); await provider.startSession(item, size, true, initialPrompt); - const respawn = vi.spyOn(provider, 'startSession').mockResolvedValue(undefined); - - for (const handler of exitHandlers) handler({ exitCode: 1 }); + for (const handler of exitHandlers[0] ?? []) handler({ exitCode: 0 }); await vi.advanceTimersByTimeAsync(500); - expect(respawn).toHaveBeenCalledWith(item, size, true, initialPrompt); + expect(spawnLocalPty).toHaveBeenCalledTimes(2); + expect(buildAgentSessionCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ isResuming: true }) + ); } finally { vi.useRealTimers(); } }); - it('preserves resume mode when an SSH resumed session respawns within budget', async () => { + it('replaces an SSH conversation after clean exit by resuming the same provider session', async () => { vi.useFakeTimers(); try { - const exitHandlers: Array<(info: PtyExitInfo) => void> = []; - openSsh2Pty.mockResolvedValue({ success: true, data: fakePty(exitHandlers) }); + const exitHandlers: Array void>> = []; + openSsh2Pty.mockImplementation(() => { + const handlers: Array<(info: PtyExitInfo) => void> = []; + exitHandlers.push(handlers); + return Promise.resolve({ success: true, data: fakePty(handlers) }); + }); const provider = sshProvider(); const size = { cols: 100, rows: 40 }; const initialPrompt = 'continue'; - const item = { ...conversation(), providerSessionId: undefined }; + const item = conversation(); await provider.startSession(item, size, true, initialPrompt); - const respawn = vi.spyOn(provider, 'startSession').mockResolvedValue(undefined); - - for (const handler of exitHandlers) handler({ exitCode: 1 }); + for (const handler of exitHandlers[0] ?? []) handler({ exitCode: 0 }); await vi.advanceTimersByTimeAsync(500); - expect(respawn).toHaveBeenCalledWith(item, size, true, initialPrompt); + expect(openSsh2Pty).toHaveBeenCalledTimes(2); + expect(buildAgentSessionCommand).toHaveBeenLastCalledWith( + expect.objectContaining({ isResuming: true }) + ); } finally { vi.useRealTimers(); } }); - it('preserves resume mode on immediate exit within budget', async () => { + it('emits PTY exit when a local conversation unregisters before the registry exit handler runs', async () => { + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + const exitInfo = { exitCode: 0 }; + spawnLocalPty.mockReturnValue(fakePty(exitHandlers)); + const provider = localProvider(); + const item = conversation(); + const sessionId = makePtySessionId(item.projectId, item.taskId, item.id); + + await provider.startSession(item); + vi.mocked(events.emit).mockClear(); + for (const handler of exitHandlers) handler(exitInfo); + + expect(events.emit).toHaveBeenCalledWith(ptyExitChannel, exitInfo, sessionId); + }); + + it('emits PTY exit when an SSH conversation unregisters before the registry exit handler runs', async () => { + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + const exitInfo = { exitCode: 0 }; + openSsh2Pty.mockResolvedValue({ + success: true, + data: fakePty(exitHandlers), + }); + const provider = sshProvider(); + const item = conversation(); + const sessionId = makePtySessionId(item.projectId, item.taskId, item.id); + + await provider.startSession(item); + vi.mocked(events.emit).mockClear(); + for (const handler of exitHandlers) handler(exitInfo); + + expect(events.emit).toHaveBeenCalledWith(ptyExitChannel, exitInfo, sessionId); + }); + + it('uses the last observed terminal size when replacing a local conversation', async () => { vi.useFakeTimers(); try { - const exitHandlers: Array<(info: PtyExitInfo) => void> = []; - spawnLocalPty.mockReturnValue(fakePty(exitHandlers)); + const exitHandlers: Array void>> = []; + spawnLocalPty.mockImplementation(() => { + const handlers: Array<(info: PtyExitInfo) => void> = []; + exitHandlers.push(handlers); + return fakePty(handlers); + }); const provider = localProvider(); - const size = { cols: 100, rows: 40 }; - const initialPrompt = 'continue'; const item = conversation(); + const sessionId = makePtySessionId(item.projectId, item.taskId, item.id); - await provider.startSession(item, size, true, initialPrompt); - const respawn = vi.spyOn(provider, 'startSession').mockResolvedValue(undefined); + await provider.startSession(item, { cols: 100, rows: 40 }, true); + ptySessionRegistry.resize(sessionId, 68, 42); + for (const handler of exitHandlers[0] ?? []) handler({ exitCode: 0 }); + await vi.advanceTimersByTimeAsync(500); + + expect(spawnLocalPty).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ cols: 68, rows: 42 }) + ); + } finally { + vi.useRealTimers(); + } + }); + + it('uses the last observed terminal size when replacing an SSH conversation', async () => { + vi.useFakeTimers(); + try { + const exitHandlers: Array void>> = []; + openSsh2Pty.mockImplementation(() => { + const handlers: Array<(info: PtyExitInfo) => void> = []; + exitHandlers.push(handlers); + return Promise.resolve({ success: true, data: fakePty(handlers) }); + }); + const provider = sshProvider(); + const item = conversation(); + const sessionId = makePtySessionId(item.projectId, item.taskId, item.id); - for (const handler of exitHandlers) handler({ exitCode: 1 }); + await provider.startSession(item, { cols: 100, rows: 40 }, true); + ptySessionRegistry.resize(sessionId, 68, 42); + for (const handler of exitHandlers[0] ?? []) handler({ exitCode: 0 }); await vi.advanceTimersByTimeAsync(500); - expect(respawn).toHaveBeenCalledWith(item, size, true, initialPrompt); + expect(openSsh2Pty).toHaveBeenNthCalledWith( + 2, + expect.anything(), + expect.objectContaining({ cols: 68, rows: 42 }) + ); } finally { vi.useRealTimers(); } }); - it('falls back to fresh local session after resume exceeds respawn budget', async () => { + it('does not loop if the replacement exits inside the failure window', async () => { vi.useFakeTimers(); try { const exitHandlers: Array void>> = []; @@ -251,57 +334,112 @@ describe('conversation provider respawn state', () => { return fakePty(handlers); }); const provider = localProvider(); - const startSession = vi.spyOn(provider, 'startSession'); - const size = { cols: 100, rows: 40 }; - const initialPrompt = 'continue'; const item = conversation(); - await provider.startSession(item, size, true, initialPrompt); + await provider.startSession(item); for (const handler of exitHandlers[0] ?? []) handler({ exitCode: 1 }); await vi.advanceTimersByTimeAsync(500); - expect(startSession).toHaveBeenLastCalledWith(item, size, true, initialPrompt); + expect(spawnLocalPty).toHaveBeenCalledTimes(2); for (const handler of exitHandlers[1] ?? []) handler({ exitCode: 1 }); await vi.advanceTimersByTimeAsync(500); - expect(startSession).toHaveBeenLastCalledWith(item, size, true, initialPrompt); - for (const handler of exitHandlers[2] ?? []) handler({ exitCode: 1 }); - await vi.advanceTimersByTimeAsync(500); - expect(startSession).toHaveBeenLastCalledWith(item, size, false, initialPrompt); + expect(spawnLocalPty).toHaveBeenCalledTimes(2); } finally { vi.useRealTimers(); } }); - it('falls back to fresh SSH session after resume exceeds respawn budget', async () => { + it('does not loop when local replacement spawn fails', async () => { vi.useFakeTimers(); try { - const exitHandlers: Array void>> = []; - openSsh2Pty.mockImplementation(() => { - const handlers: Array<(info: PtyExitInfo) => void> = []; - exitHandlers.push(handlers); - return Promise.resolve({ success: true, data: fakePty(handlers) }); + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + spawnLocalPty.mockReturnValueOnce(fakePty(exitHandlers)).mockImplementationOnce(() => { + throw new Error('spawn failed'); }); - const provider = sshProvider(); - const startSession = vi.spyOn(provider, 'startSession'); - const size = { cols: 100, rows: 40 }; - const initialPrompt = 'continue'; + const provider = localProvider(); const item = conversation(); - await provider.startSession(item, size, true, initialPrompt); + await provider.startSession(item); + for (const handler of exitHandlers) handler({ exitCode: 0 }); + await vi.advanceTimersByTimeAsync(500); - for (const handler of exitHandlers[0] ?? []) handler({ exitCode: 1 }); + expect(spawnLocalPty).toHaveBeenCalledTimes(2); + } finally { + vi.useRealTimers(); + } + }); + + it('does not start a delayed local replacement after explicit stop', async () => { + vi.useFakeTimers(); + try { + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + spawnLocalPty.mockReturnValue(fakePty(exitHandlers)); + const provider = localProvider(); + const item = conversation(); + + await provider.startSession(item); + for (const handler of exitHandlers) handler({ exitCode: 0 }); + await provider.stopSession(item.id); await vi.advanceTimersByTimeAsync(500); - expect(startSession).toHaveBeenLastCalledWith(item, size, true, initialPrompt); - for (const handler of exitHandlers[1] ?? []) handler({ exitCode: 1 }); + expect(spawnLocalPty).toHaveBeenCalledTimes(1); + } finally { + vi.useRealTimers(); + } + }); + + it('does not replace a local tmux attachment after it exits', async () => { + vi.useFakeTimers(); + try { + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + spawnLocalPty.mockReturnValue(fakePty(exitHandlers)); + const provider = localProvider({ tmux: true }); + const item = conversation(); + + await provider.startSession(item); + vi.mocked(events.emit).mockClear(); + for (const handler of exitHandlers) handler({ exitCode: 0 }); await vi.advanceTimersByTimeAsync(500); - expect(startSession).toHaveBeenLastCalledWith(item, size, true, initialPrompt); - for (const handler of exitHandlers[2] ?? []) handler({ exitCode: 1 }); + expect(spawnLocalPty).toHaveBeenCalledTimes(1); + expect(events.emit).toHaveBeenCalledWith( + agentSessionExitedChannel, + expect.objectContaining({ + conversationId: item.id, + taskId: item.taskId, + }) + ); + } finally { + vi.useRealTimers(); + } + }); + + it('does not replace an SSH tmux attachment after it exits', async () => { + vi.useFakeTimers(); + try { + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + openSsh2Pty.mockResolvedValue({ + success: true, + data: fakePty(exitHandlers), + }); + const provider = sshProvider(undefined, { tmux: true }); + const item = conversation(); + + await provider.startSession(item); + vi.mocked(events.emit).mockClear(); + for (const handler of exitHandlers) handler({ exitCode: 0 }); await vi.advanceTimersByTimeAsync(500); - expect(startSession).toHaveBeenLastCalledWith(item, size, false, initialPrompt); + + expect(openSsh2Pty).toHaveBeenCalledTimes(1); + expect(events.emit).toHaveBeenCalledWith( + agentSessionExitedChannel, + expect.objectContaining({ + conversationId: item.id, + taskId: item.taskId, + }) + ); } finally { vi.useRealTimers(); } @@ -313,8 +451,14 @@ describe('conversation provider respawn state', () => { const firstExitHandlers: Array<(info: PtyExitInfo) => void> = []; const secondExitHandlers: Array<(info: PtyExitInfo) => void> = []; openSsh2Pty - .mockResolvedValueOnce({ success: true, data: fakePty(firstExitHandlers) }) - .mockResolvedValueOnce({ success: true, data: fakePty(secondExitHandlers) }); + .mockResolvedValueOnce({ + success: true, + data: fakePty(firstExitHandlers), + }) + .mockResolvedValueOnce({ + success: true, + data: fakePty(secondExitHandlers), + }); const proxy = { getRemoteShellProfile: vi.fn(async () => ({})), refreshRemoteShellProfile: vi.fn(async () => ({})), @@ -328,10 +472,7 @@ describe('conversation provider respawn state', () => { expect(proxy.refreshRemoteShellProfile).toHaveBeenCalledTimes(1); expect(openSsh2Pty).toHaveBeenCalledTimes(2); - expect(events.emit).not.toHaveBeenCalledWith( - agentSessionExitedChannel, - expect.objectContaining({ exitCode: 127 }) - ); + expect(events.emit).not.toHaveBeenCalledWith(agentSessionExitedChannel, expect.anything()); for (const handler of secondExitHandlers) handler({ exitCode: 127 }); await vi.advanceTimersByTimeAsync(500); @@ -340,31 +481,41 @@ describe('conversation provider respawn state', () => { expect(openSsh2Pty).toHaveBeenCalledTimes(2); expect(events.emit).toHaveBeenCalledWith( agentSessionExitedChannel, - expect.objectContaining({ exitCode: 127 }) + expect.objectContaining({ + conversationId: item.id, + taskId: item.taskId, + }) ); } finally { vi.useRealTimers(); } }); - it('clears local respawn counts when explicitly stopping a session', async () => { - const provider = localProvider(); - const sessionId = makePtySessionId('project-1', 'task-1', 'conversation-1'); - (provider as unknown as RespawnState).respawnCounts.set(sessionId, 3); - - await provider.stopSession('conversation-1'); - - expect((provider as unknown as RespawnState).respawnCounts.has(sessionId)).toBe(false); - }); - - it('clears SSH respawn counts when explicitly stopping a session', async () => { - const provider = sshProvider(); - const sessionId = makePtySessionId('project-1', 'task-1', 'conversation-1'); - (provider as unknown as RespawnState).respawnCounts.set(sessionId, 3); + it('does not refresh the SSH shell profile after explicit stop cancels a missing-command retry', async () => { + vi.useFakeTimers(); + try { + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + openSsh2Pty.mockResolvedValue({ + success: true, + data: fakePty(exitHandlers), + }); + const proxy = { + getRemoteShellProfile: vi.fn(async () => ({})), + refreshRemoteShellProfile: vi.fn(async () => ({})), + }; + const provider = sshProvider(proxy); + const item = conversation(); - await provider.stopSession('conversation-1'); + await provider.startSession(item); + for (const handler of exitHandlers) handler({ exitCode: 127 }); + await provider.stopSession(item.id); + await vi.advanceTimersByTimeAsync(500); - expect((provider as unknown as RespawnState).respawnCounts.has(sessionId)).toBe(false); + expect(proxy.refreshRemoteShellProfile).not.toHaveBeenCalled(); + expect(openSsh2Pty).toHaveBeenCalledTimes(1); + } finally { + vi.useRealTimers(); + } }); it('detaches local tmux conversations without killing the tmux session', async () => { @@ -385,10 +536,7 @@ describe('conversation provider respawn state', () => { expect(pty.kill).toHaveBeenCalledTimes(1); expect(ctx.exec).not.toHaveBeenCalled(); - expect(events.emit).not.toHaveBeenCalledWith( - agentSessionExitedChannel, - expect.objectContaining({ sessionId }) - ); + expect(events.emit).not.toHaveBeenCalledWith(agentSessionExitedChannel, expect.anything()); expect((provider as unknown as RespawnState).knownSessionIds.has(sessionId)).toBe(true); }); @@ -410,10 +558,7 @@ describe('conversation provider respawn state', () => { expect(pty.kill).toHaveBeenCalledTimes(1); expect(ctx.exec).not.toHaveBeenCalled(); - expect(events.emit).not.toHaveBeenCalledWith( - agentSessionExitedChannel, - expect.objectContaining({ sessionId }) - ); + expect(events.emit).not.toHaveBeenCalledWith(agentSessionExitedChannel, expect.anything()); expect((provider as unknown as RespawnState).knownSessionIds.has(sessionId)).toBe(true); }); @@ -441,7 +586,10 @@ describe('conversation provider respawn state', () => { it('kills tmux when explicitly stopping a detached SSH conversation', async () => { const exitHandlers: Array<(info: PtyExitInfo) => void> = []; - openSsh2Pty.mockResolvedValue({ success: true, data: fakePty(exitHandlers) }); + openSsh2Pty.mockResolvedValue({ + success: true, + data: fakePty(exitHandlers), + }); const ctx = { exec: vi.fn(async () => ({ stdout: '', stderr: '' })), }; @@ -478,10 +626,7 @@ describe('conversation provider respawn state', () => { for (const handler of firstExitHandlers) handler({ exitCode: 0 }); expect((provider as unknown as RespawnState).sessions.get(sessionId)).toBe(secondPty); - expect(events.emit).not.toHaveBeenCalledWith( - agentSessionExitedChannel, - expect.objectContaining({ sessionId }) - ); + expect(events.emit).not.toHaveBeenCalledWith(agentSessionExitedChannel, expect.anything()); }); it('ignores stale SSH attach exits after a tmux conversation is rehydrated', async () => { @@ -503,9 +648,6 @@ describe('conversation provider respawn state', () => { for (const handler of firstExitHandlers) handler({ exitCode: 0 }); expect((provider as unknown as RespawnState).sessions.get(sessionId)).toBe(secondPty); - expect(events.emit).not.toHaveBeenCalledWith( - agentSessionExitedChannel, - expect.objectContaining({ sessionId }) - ); + expect(events.emit).not.toHaveBeenCalledWith(agentSessionExitedChannel, expect.anything()); }); }); diff --git a/src/main/core/conversations/impl/local-conversation.ts b/src/main/core/conversations/impl/local-conversation.ts index 683b8d9f87..bc4dcfc4a1 100644 --- a/src/main/core/conversations/impl/local-conversation.ts +++ b/src/main/core/conversations/impl/local-conversation.ts @@ -3,10 +3,11 @@ import { agentHookService } from '@main/core/agent-hooks/agent-hook-service'; import { wireAgentClassifier } from '@main/core/agent-hooks/classifier-wiring'; import { claudeTrustService } from '@main/core/agent-hooks/claude-trust-service'; import { HookConfigWriter } from '@main/core/agent-hooks/hook-config'; +import { ConversationSessionSupervisor } from '@main/core/conversations/conversation-session-supervisor'; +import { resolveAgentSessionCommandArgs } from '@main/core/conversations/resolve-agent-session-command'; import type { ConversationProvider } from '@main/core/conversations/types'; import type { IExecutionContext } from '@main/core/execution-context/types'; import { LocalFileSystem } from '@main/core/fs/impl/local-fs'; -import { isUnexpectedPtyExit } from '@main/core/pty/exit-classification'; import { spawnLocalPty } from '@main/core/pty/local-pty'; import type { Pty } from '@main/core/pty/pty'; import { buildAgentEnv } from '@main/core/pty/pty-env'; @@ -23,7 +24,6 @@ import type { Conversation } from '@shared/conversations'; import { agentSessionExitedChannel } from '@shared/events/agentEvents'; import { makePtyId } from '@shared/ptyId'; import { makePtySessionId } from '@shared/ptySessionId'; -import { resolveAgentSessionCommandArgs } from '../resolve-agent-session-command'; import { buildAgentSessionCommand } from './agent-command'; import { syncGrokThemeWithAppTheme } from './grok-theme-config'; import { scheduleInitialPromptInjection } from './keystroke-injection'; @@ -31,13 +31,12 @@ import { resolveProviderEnv } from './provider-env'; const DEFAULT_COLS = 80; const DEFAULT_ROWS = 24; -const MAX_RESPAWNS = 2; +const RESPAWN_DELAY_MS = 500; export class LocalConversationProvider implements ConversationProvider { private sessions = new Map(); private knownSessionIds = new Set(); - private respawnCounts = new Map(); - private suppressedExitPtys = new WeakSet(); + private supervisor = new ConversationSessionSupervisor(); private readonly projectId: string; private readonly taskPath: string; private readonly taskId: string; @@ -80,9 +79,22 @@ export class LocalConversationProvider implements ConversationProvider { async startSession( conversation: Conversation, - initialSize: { cols: number; rows: number } = { cols: DEFAULT_COLS, rows: DEFAULT_ROWS }, + initialSize: { cols: number; rows: number } = { + cols: DEFAULT_COLS, + rows: DEFAULT_ROWS, + }, isResuming: boolean = false, initialPrompt?: string + ): Promise { + return this.startSessionInternal(conversation, initialSize, isResuming, initialPrompt, false); + } + + private async startSessionInternal( + conversation: Conversation, + initialSize: { cols: number; rows: number }, + isResuming: boolean, + initialPrompt: string | undefined, + requireDesired: boolean ): Promise { const sessionId = makePtySessionId( conversation.projectId, @@ -90,154 +102,169 @@ export class LocalConversationProvider implements ConversationProvider { conversation.id ); this.knownSessionIds.add(sessionId); - if (this.sessions.has(sessionId)) return; - await claudeTrustService.maybeAutoTrustLocal({ - providerId: conversation.providerId, - cwd: this.taskPath, - homedir: homedir(), - }); - const hooksAvailable = await this.prepareHookConfig(conversation.providerId); + const spawnSize = ptySessionRegistry.getLastSize(sessionId) ?? initialSize; + const spawnToken = this.supervisor.beginStart(sessionId, { requireDesired }); + if (!spawnToken) return; - const providerConfig = await providerOverrideSettings.getItem(conversation.providerId); - const providerDef = getProvider(conversation.providerId); - const agentSession = resolveAgentSessionCommandArgs(conversation, isResuming); - const { command, args } = buildAgentSessionCommand({ - providerId: conversation.providerId, - providerConfig, - autoApprove: conversation.autoApprove, - sessionId: conversation.id, - providerSessionId: conversation.providerSessionId, - isResuming: agentSession.isResuming, - initialPrompt, - }); - const providerEnv = resolveProviderEnv(providerConfig, { - providerId: conversation.providerId, - autoApprove: conversation.autoApprove, - }); - if (conversation.providerId === 'grok') { - await syncGrokThemeWithAppTheme({ env: providerEnv }); - } - - const tmuxSessionName = this.tmux ? makeTmuxSessionName(sessionId) : undefined; - - const resolved = resolveLocalPtySpawn({ - platform: process.platform, - env: process.env, - intent: { - kind: 'run-command', + try { + await claudeTrustService.maybeAutoTrustLocal({ + providerId: conversation.providerId, cwd: this.taskPath, - command: { kind: 'argv', command, args }, - shellSetup: this.shellSetup, - tmuxSessionName, - }, - }); + homedir: homedir(), + }); + const hooksAvailable = await this.prepareHookConfig(conversation.providerId); - logLocalPtySpawnWarnings('LocalConversationProvider', resolved.warnings, { - conversationId: conversation.id, - sessionId, - }); + const providerConfig = await providerOverrideSettings.getItem(conversation.providerId); + const providerDef = getProvider(conversation.providerId); + const agentSession = resolveAgentSessionCommandArgs(conversation, isResuming); + const { command, args } = buildAgentSessionCommand({ + providerId: conversation.providerId, + providerConfig, + autoApprove: conversation.autoApprove, + sessionId: agentSession.sessionId, + providerSessionId: conversation.providerSessionId, + isResuming: agentSession.isResuming, + initialPrompt, + }); + const providerEnv = resolveProviderEnv(providerConfig, { + providerId: conversation.providerId, + autoApprove: conversation.autoApprove, + }); + if (conversation.providerId === 'grok') { + await syncGrokThemeWithAppTheme({ env: providerEnv }); + } - const ptyId = makePtyId(conversation.providerId, conversation.id); - const port = agentHookService.getPort(); - const token = agentHookService.getToken(); - const hookActive = port > 0; - const ampHooksAvailable = - hookActive && - conversation.providerId === 'amp' && - providerDef?.supportsHooks && - hooksAvailable; - const pty = spawnLocalPty({ - id: sessionId, - command: resolved.command, - args: resolved.args, - cwd: resolved.cwd, - env: { - ...buildAgentEnv({ - hook: port > 0 ? { port, ptyId, token } : undefined, - providerVars: providerEnv, - }), - ...this.taskEnvVars, - ...(ampHooksAvailable && !this.taskEnvVars['PLUGINS'] ? { PLUGINS: 'all' } : {}), - }, - cols: initialSize.cols, - rows: initialSize.rows, - }); + const tmuxSessionName = this.tmux ? makeTmuxSessionName(sessionId) : undefined; - /* - * Codex hooks can be skipped by the CLI in some live-session edge cases. - * Amp hooks only cover lifecycle events today. Keep the output classifier - * active as a fallback so the UI can leave "working" and catch prompts. - */ - const useHooksOnly = - hookActive && - providerDef?.supportsHooks && - hooksAvailable && - conversation.providerId !== 'codex' && - conversation.providerId !== 'amp'; + const resolved = resolveLocalPtySpawn({ + platform: process.platform, + env: process.env, + intent: { + kind: 'run-command', + cwd: this.taskPath, + command: { kind: 'argv', command, args }, + shellSetup: this.shellSetup, + tmuxSessionName, + }, + }); - if (!useHooksOnly) { - wireAgentClassifier({ - pty, - providerId: conversation.providerId, - projectId: conversation.projectId, - taskId: conversation.taskId, + logLocalPtySpawnWarnings('LocalConversationProvider', resolved.warnings, { conversationId: conversation.id, + sessionId, }); - } - pty.onExit(({ exitCode, signal }) => { - const currentPty = this.sessions.get(sessionId); - if (currentPty !== undefined && currentPty !== pty) return; + const ptyId = makePtyId(conversation.providerId, conversation.id); + const port = agentHookService.getPort(); + const token = agentHookService.getToken(); + const hookActive = port > 0; + const ampHooksAvailable = + hookActive && + conversation.providerId === 'amp' && + providerDef?.supportsHooks && + hooksAvailable; + const pty = spawnLocalPty({ + id: sessionId, + command: resolved.command, + args: resolved.args, + cwd: resolved.cwd, + env: { + ...buildAgentEnv({ + hook: port > 0 ? { port, ptyId, token } : undefined, + providerVars: providerEnv, + }), + ...this.taskEnvVars, + ...(ampHooksAvailable && !this.taskEnvVars['PLUGINS'] ? { PLUGINS: 'all' } : {}), + }, + cols: spawnSize.cols, + rows: spawnSize.rows, + }); - ptySessionRegistry.unregister(sessionId); - const shouldRespawn = currentPty === pty && isUnexpectedPtyExit({ exitCode, signal }); - this.sessions.delete(sessionId); - const suppressExitEvent = this.suppressedExitPtys.has(pty); - if (!suppressExitEvent) { - events.emit(agentSessionExitedChannel, { - sessionId, + /* + * Codex hooks can be skipped by the CLI in some live-session edge cases. + * Amp hooks only cover lifecycle events today. Keep the output classifier + * active as a fallback so the UI can leave "working" and catch prompts. + */ + const useHooksOnly = + hookActive && + providerDef?.supportsHooks && + hooksAvailable && + conversation.providerId !== 'codex' && + conversation.providerId !== 'amp'; + + if (!useHooksOnly) { + wireAgentClassifier({ + pty, + providerId: conversation.providerId, projectId: conversation.projectId, - conversationId: conversation.id, taskId: conversation.taskId, - exitCode, + conversationId: conversation.id, }); } - if (shouldRespawn && !this.tmux) { - const count = (this.respawnCounts.get(sessionId) ?? 0) + 1; - this.respawnCounts.set(sessionId, count); - if (count > MAX_RESPAWNS && !isResuming) { - log.error('LocalConversationProvider: respawn limit reached, giving up', { - conversationId: conversation.id, - }); - this.respawnCounts.delete(sessionId); + pty.onExit((info) => { + const decision = this.supervisor.handleExit(sessionId, pty); + if (decision.kind === 'stale') return; + const replacementSize = ptySessionRegistry.getLastSize(sessionId) ?? spawnSize; + + ptySessionRegistry.unregister(sessionId, { pty, exitInfo: info }); + this.sessions.delete(sessionId); + if (decision.kind === 'stopped') return; + + events.emit(agentSessionExitedChannel, { + conversationId: conversation.id, + taskId: conversation.taskId, + }); + + if (decision.kind === 'failed') { return; } - const resumeNext = isResuming && count <= MAX_RESPAWNS; - setTimeout(() => { - this.startSession(conversation, initialSize, resumeNext, initialPrompt).catch((e) => { - log.error('LocalConversationProvider: respawn failed', { - conversationId: conversation.id, - error: String(e), - }); + if (this.tmux) { + return; + } + + if (this.supervisor.isDesired(sessionId)) { + this.scheduleReplacement({ + conversation, + initialSize: replacementSize, }); - }, 500); + } + }); + + if (!this.supervisor.acceptSpawn(sessionId, spawnToken, pty)) { + try { + pty.kill(); + } catch {} + if (ptySessionRegistry.get(sessionId) === pty) { + ptySessionRegistry.unregister(sessionId); + } + return; } - }); - ptySessionRegistry.register(sessionId, pty, { - metadata: { providerId: conversation.providerId, title: conversation.title }, - }); - this.sessions.set(sessionId, pty); - scheduleInitialPromptInjection({ pty, conversation, initialPrompt, isResuming }); - telemetryService.capture('agent_run_started', { - provider: conversation.providerId, - project_id: conversation.projectId, - task_id: conversation.taskId, - conversation_id: conversation.id, - }); + ptySessionRegistry.register(sessionId, pty, { + metadata: { + providerId: conversation.providerId, + title: conversation.title, + }, + }); + this.sessions.set(sessionId, pty); + scheduleInitialPromptInjection({ + pty, + conversation, + initialPrompt, + isResuming: agentSession.isResuming, + }); + telemetryService.capture('agent_run_started', { + provider: conversation.providerId, + project_id: conversation.projectId, + task_id: conversation.taskId, + conversation_id: conversation.id, + }); + } catch (error) { + this.supervisor.failSpawn(sessionId, spawnToken); + throw error; + } } private async prepareHookConfig(providerId: Conversation['providerId']): Promise { @@ -252,7 +279,10 @@ export class LocalConversationProvider implements ConversationProvider { const hooksAvailable = await this.hookConfigWriter.writeForProvider(providerId, { writeGitIgnoreEntries, }); - this.preparedHookProviders.set(providerId, { writeGitIgnoreEntries, hooksAvailable }); + this.preparedHookProviders.set(providerId, { + writeGitIgnoreEntries, + hooksAvailable, + }); return hooksAvailable; } catch (error) { log.warn('LocalConversationProvider: failed to prepare hook config', { @@ -265,38 +295,50 @@ export class LocalConversationProvider implements ConversationProvider { } private detachPty(sessionId: string): void { - this.respawnCounts.delete(sessionId); - const pty = this.sessions.get(sessionId); + const pty = this.supervisor.stop(sessionId) ?? this.sessions.get(sessionId); this.sessions.delete(sessionId); ptySessionRegistry.unregister(sessionId); if (pty) { try { pty.kill(); } catch (e) { - log.warn('LocalAgentProvider: error killing PTY', { sessionId, error: String(e) }); + log.warn('LocalAgentProvider: error killing PTY', { + sessionId, + error: String(e), + }); } } } async detachSession(conversationId: string): Promise { const sessionId = makePtySessionId(this.projectId, this.taskId, conversationId); - const pty = this.sessions.get(sessionId); - if (this.tmux && pty) { - this.suppressedExitPtys.add(pty); - } this.detachPty(sessionId); if (!this.tmux) { this.knownSessionIds.delete(sessionId); + this.supervisor.forget(sessionId); } } async stopSession(conversationId: string): Promise { const sessionId = makePtySessionId(this.projectId, this.taskId, conversationId); this.knownSessionIds.delete(sessionId); - this.detachPty(sessionId); + const pty = this.supervisor.stop(sessionId) ?? this.sessions.get(sessionId); + this.sessions.delete(sessionId); + ptySessionRegistry.unregister(sessionId); + if (pty) { + try { + pty.kill(); + } catch (e) { + log.warn('LocalAgentProvider: error killing PTY', { + sessionId, + error: String(e), + }); + } + } if (this.tmux) { await killTmuxSession(this.ctx, makeTmuxSessionName(sessionId)); } + this.supervisor.forget(sessionId); } async destroyAll(): Promise { @@ -305,20 +347,37 @@ export class LocalConversationProvider implements ConversationProvider { if (this.tmux) { await Promise.all(sessionIds.map((id) => killTmuxSession(this.ctx, makeTmuxSessionName(id)))); } + for (const sessionId of sessionIds) { + this.supervisor.forget(sessionId); + } this.knownSessionIds.clear(); } async detachAll(): Promise { for (const [sessionId, pty] of this.sessions) { - if (this.tmux) { - this.suppressedExitPtys.add(pty); - } + this.supervisor.stop(sessionId); try { pty.kill(); } catch {} ptySessionRegistry.unregister(sessionId); } this.sessions.clear(); - this.respawnCounts.clear(); + } + + private scheduleReplacement({ + conversation, + initialSize, + }: { + conversation: Conversation; + initialSize: { cols: number; rows: number }; + }): void { + setTimeout(() => { + this.startSessionInternal(conversation, initialSize, true, undefined, true).catch((e) => { + log.error('LocalConversationProvider: replacement failed', { + conversationId: conversation.id, + error: String(e), + }); + }); + }, RESPAWN_DELAY_MS); } } diff --git a/src/main/core/conversations/impl/ssh-conversation.ts b/src/main/core/conversations/impl/ssh-conversation.ts index deb63a713b..64b87b9cc5 100644 --- a/src/main/core/conversations/impl/ssh-conversation.ts +++ b/src/main/core/conversations/impl/ssh-conversation.ts @@ -1,9 +1,10 @@ import { wireAgentClassifier } from '@main/core/agent-hooks/classifier-wiring'; import { claudeTrustService } from '@main/core/agent-hooks/claude-trust-service'; +import { ConversationSessionSupervisor } from '@main/core/conversations/conversation-session-supervisor'; +import { resolveAgentSessionCommandArgs } from '@main/core/conversations/resolve-agent-session-command'; import type { ConversationProvider } from '@main/core/conversations/types'; import type { IExecutionContext } from '@main/core/execution-context/types'; import { SshFileSystem } from '@main/core/fs/impl/ssh-fs'; -import { isUnexpectedPtyExit } from '@main/core/pty/exit-classification'; import type { Pty } from '@main/core/pty/pty'; import { ptySessionRegistry } from '@main/core/pty/pty-session-registry'; import { resolveSshCommand } from '@main/core/pty/spawn-utils'; @@ -18,20 +19,18 @@ import type { AgentSessionConfig } from '@shared/agent-session'; import type { Conversation } from '@shared/conversations'; import { agentSessionExitedChannel } from '@shared/events/agentEvents'; import { makePtySessionId } from '@shared/ptySessionId'; -import { resolveAgentSessionCommandArgs } from '../resolve-agent-session-command'; import { buildAgentSessionCommand } from './agent-command'; import { scheduleInitialPromptInjection } from './keystroke-injection'; import { resolveProviderEnv } from './provider-env'; const DEFAULT_COLS = 80; const DEFAULT_ROWS = 24; -const MAX_RESPAWNS = 2; +const RESPAWN_DELAY_MS = 500; export class SshConversationProvider implements ConversationProvider { private sessions = new Map(); private knownSessionIds = new Set(); - private respawnCounts = new Map(); - private suppressedExitPtys = new WeakSet(); + private supervisor = new ConversationSessionSupervisor(); private readonly projectId: string; private readonly taskPath: string; private readonly taskId: string; @@ -72,11 +71,14 @@ export class SshConversationProvider implements ConversationProvider { async startSession( conversation: Conversation, - initialSize: { cols: number; rows: number } = { cols: DEFAULT_COLS, rows: DEFAULT_ROWS }, + initialSize: { cols: number; rows: number } = { + cols: DEFAULT_COLS, + rows: DEFAULT_ROWS, + }, isResuming: boolean = false, initialPrompt?: string ): Promise { - return this.startSessionInternal(conversation, initialSize, isResuming, initialPrompt, { + return this.startSessionInternal(conversation, initialSize, isResuming, initialPrompt, false, { shellRefreshRetried: false, }); } @@ -86,6 +88,7 @@ export class SshConversationProvider implements ConversationProvider { initialSize: { cols: number; rows: number }, isResuming: boolean, initialPrompt: string | undefined, + requireDesired: boolean, options: { shellRefreshRetried: boolean } ): Promise { const sessionId = makePtySessionId( @@ -95,193 +98,216 @@ export class SshConversationProvider implements ConversationProvider { ); this.knownSessionIds.add(sessionId); - if (this.sessions.has(sessionId)) return; + const spawnSize = ptySessionRegistry.getLastSize(sessionId) ?? initialSize; + const spawnToken = this.supervisor.beginStart(sessionId, { requireDesired }); + if (!spawnToken) return; - await claudeTrustService.maybeAutoTrustSsh({ - providerId: conversation.providerId, - cwd: this.taskPath, - ctx: this.ctx, - remoteFs: new SshFileSystem(this.proxy, '/'), - }); + try { + await claudeTrustService.maybeAutoTrustSsh({ + providerId: conversation.providerId, + cwd: this.taskPath, + ctx: this.ctx, + remoteFs: new SshFileSystem(this.proxy, '/'), + }); - const providerConfig = await providerOverrideSettings.getItem(conversation.providerId); - const agentSession = resolveAgentSessionCommandArgs(conversation, isResuming, { - requireProviderSessionId: false, - }); - const { command, args } = buildAgentSessionCommand({ - providerId: conversation.providerId, - providerConfig, - autoApprove: conversation.autoApprove, - sessionId: conversation.id, - providerSessionId: conversation.providerSessionId, - isResuming: agentSession.isResuming, - initialPrompt, - }); - const providerEnv = resolveProviderEnv(providerConfig, { - providerId: conversation.providerId, - autoApprove: conversation.autoApprove, - }); + const providerConfig = await providerOverrideSettings.getItem(conversation.providerId); + const agentSession = resolveAgentSessionCommandArgs(conversation, isResuming, { + requireProviderSessionId: false, + }); + const { command, args } = buildAgentSessionCommand({ + providerId: conversation.providerId, + providerConfig, + autoApprove: conversation.autoApprove, + sessionId: agentSession.sessionId, + providerSessionId: conversation.providerSessionId, + isResuming: agentSession.isResuming, + initialPrompt, + }); + const providerEnv = resolveProviderEnv(providerConfig, { + providerId: conversation.providerId, + autoApprove: conversation.autoApprove, + }); - const tmuxSessionName = this.tmux ? makeTmuxSessionName(sessionId) : undefined; + const tmuxSessionName = this.tmux ? makeTmuxSessionName(sessionId) : undefined; - const cfg: AgentSessionConfig = { - taskId: this.taskId, - conversationId: conversation.id, - providerId: conversation.providerId, - command, - args, - cwd: this.taskPath, - shellSetup: this.shellSetup, - tmuxSessionName, - autoApprove: conversation.autoApprove ?? false, - resume: isResuming, - }; + const cfg: AgentSessionConfig = { + taskId: this.taskId, + conversationId: conversation.id, + providerId: conversation.providerId, + command, + args, + cwd: this.taskPath, + shellSetup: this.shellSetup, + tmuxSessionName, + autoApprove: conversation.autoApprove ?? false, + resume: agentSession.isResuming, + }; - const profile = await this.proxy.getRemoteShellProfile(); - const sshCommand = resolveSshCommand( - 'agent', - cfg, - { ...providerEnv, ...this.taskEnvVars }, - profile - ); + const profile = await this.proxy.getRemoteShellProfile(); + const sshCommand = resolveSshCommand( + 'agent', + cfg, + { ...providerEnv, ...this.taskEnvVars }, + profile + ); - const result = await openSsh2Pty(this.proxy, { - id: sessionId, - command: sshCommand, - cols: initialSize.cols, - rows: initialSize.rows, - }); + const result = await openSsh2Pty(this.proxy, { + id: sessionId, + command: sshCommand, + cols: spawnSize.cols, + rows: spawnSize.rows, + }); + + if (!result.success) { + log.error('SshConversationProvider: failed to open SSH channel', { + sessionId, + error: result.error.message, + }); + throw new Error(result.error.message); + } - if (!result.success) { - log.error('SshConversationProvider: failed to open SSH channel', { - sessionId, - error: result.error.message, + const pty = result.data; + + // hooks not supported yet, rely on classifier for visual indicator + wireAgentClassifier({ + pty, + providerId: conversation.providerId, + projectId: conversation.projectId, + taskId: conversation.taskId, + conversationId: conversation.id, }); - throw new Error(result.error.message); - } - const pty = result.data; + pty.onExit((info) => { + const { exitCode } = info; + const decision = this.supervisor.handleExit(sessionId, pty); + if (decision.kind === 'stale') return; + const replacementSize = ptySessionRegistry.getLastSize(sessionId) ?? spawnSize; - // hooks not supported yet, rely on classifier for visual indicator - wireAgentClassifier({ - pty, - providerId: conversation.providerId, - projectId: conversation.projectId, - taskId: conversation.taskId, - conversationId: conversation.id, - }); + ptySessionRegistry.unregister(sessionId, { pty, exitInfo: info }); + this.sessions.delete(sessionId); + if (decision.kind === 'stopped') return; - pty.onExit(({ exitCode, signal }) => { - const currentPty = this.sessions.get(sessionId); - if (currentPty !== undefined && currentPty !== pty) return; + if (decision.kind === 'failed') { + events.emit(agentSessionExitedChannel, { + conversationId: conversation.id, + taskId: conversation.taskId, + }); + return; + } - ptySessionRegistry.unregister(sessionId); - const sessionWasActive = this.sessions.has(sessionId); - const shouldRetryAfterShellRefresh = - sessionWasActive && !this.tmux && !options.shellRefreshRetried && exitCode === 127; - const shouldRespawn = - sessionWasActive && exitCode !== 127 && isUnexpectedPtyExit({ exitCode, signal }); - this.sessions.delete(sessionId); - if (shouldRetryAfterShellRefresh) { - setTimeout(() => { - this.proxy - .refreshRemoteShellProfile() - .then(() => - this.startSessionInternal(conversation, initialSize, isResuming, initialPrompt, { - shellRefreshRetried: true, - }) - ) - .catch((e) => { - log.error('SshConversationProvider: shell refresh retry failed', { - conversationId: conversation.id, - error: String(e), - }); - }); - }, 500); - return; - } + if (this.tmux) { + events.emit(agentSessionExitedChannel, { + conversationId: conversation.id, + taskId: conversation.taskId, + }); + return; + } + + if (!options.shellRefreshRetried && exitCode === 127) { + this.scheduleShellRefreshRetry({ + conversation, + sessionId, + initialSize: replacementSize, + isResuming, + initialPrompt, + }); + return; + } - const suppressExitEvent = this.suppressedExitPtys.has(pty); - if (!suppressExitEvent) { events.emit(agentSessionExitedChannel, { - sessionId, - projectId: conversation.projectId, conversationId: conversation.id, taskId: conversation.taskId, - exitCode, }); - } - if (shouldRespawn && !this.tmux) { - const count = (this.respawnCounts.get(sessionId) ?? 0) + 1; - this.respawnCounts.set(sessionId, count); - - if (count > MAX_RESPAWNS && !isResuming) { - log.error('SshConversationProvider: respawn limit reached, giving up', { - conversationId: conversation.id, + if (this.supervisor.isDesired(sessionId)) { + this.scheduleReplacement({ + conversation, + initialSize: replacementSize, }); - this.respawnCounts.delete(sessionId); - return; } + }); - const resumeNext = isResuming && count <= MAX_RESPAWNS; - setTimeout(() => { - this.startSession(conversation, initialSize, resumeNext, initialPrompt).catch((e) => { - log.error('SshConversationProvider: respawn failed', { - conversationId: conversation.id, - error: String(e), - }); - }); - }, 500); + if (!this.supervisor.acceptSpawn(sessionId, spawnToken, pty)) { + try { + pty.kill(); + } catch {} + if (ptySessionRegistry.get(sessionId) === pty) { + ptySessionRegistry.unregister(sessionId); + } + return; } - }); - ptySessionRegistry.register(sessionId, pty, { - metadata: { providerId: conversation.providerId, title: conversation.title, isRemote: true }, - }); - this.sessions.set(sessionId, pty); - scheduleInitialPromptInjection({ pty, conversation, initialPrompt, isResuming }); - telemetryService.capture('agent_run_started', { - provider: conversation.providerId, - project_id: conversation.projectId, - task_id: conversation.taskId, - conversation_id: conversation.id, - }); + ptySessionRegistry.register(sessionId, pty, { + metadata: { + providerId: conversation.providerId, + title: conversation.title, + isRemote: true, + }, + }); + this.sessions.set(sessionId, pty); + scheduleInitialPromptInjection({ + pty, + conversation, + initialPrompt, + isResuming: agentSession.isResuming, + }); + telemetryService.capture('agent_run_started', { + provider: conversation.providerId, + project_id: conversation.projectId, + task_id: conversation.taskId, + conversation_id: conversation.id, + }); + } catch (error) { + this.supervisor.failSpawn(sessionId, spawnToken); + throw error; + } } private detachPty(sessionId: string): void { - this.respawnCounts.delete(sessionId); - const pty = this.sessions.get(sessionId); + const pty = this.supervisor.stop(sessionId) ?? this.sessions.get(sessionId); this.sessions.delete(sessionId); ptySessionRegistry.unregister(sessionId); if (pty) { try { pty.kill(); } catch (e) { - log.warn('SshAgentProvider: error killing PTY', { sessionId, error: String(e) }); + log.warn('SshAgentProvider: error killing PTY', { + sessionId, + error: String(e), + }); } } } async detachSession(conversationId: string): Promise { const sessionId = makePtySessionId(this.projectId, this.taskId, conversationId); - const pty = this.sessions.get(sessionId); - if (this.tmux && pty) { - this.suppressedExitPtys.add(pty); - } this.detachPty(sessionId); if (!this.tmux) { this.knownSessionIds.delete(sessionId); + this.supervisor.forget(sessionId); } } async stopSession(conversationId: string): Promise { const sessionId = makePtySessionId(this.projectId, this.taskId, conversationId); this.knownSessionIds.delete(sessionId); - this.detachPty(sessionId); + const pty = this.supervisor.stop(sessionId) ?? this.sessions.get(sessionId); + this.sessions.delete(sessionId); + ptySessionRegistry.unregister(sessionId); + if (pty) { + try { + pty.kill(); + } catch (e) { + log.warn('SshAgentProvider: error killing PTY', { + sessionId, + error: String(e), + }); + } + } if (this.tmux) { await killTmuxSession(this.ctx, makeTmuxSessionName(sessionId)); } + this.supervisor.forget(sessionId); } async destroyAll(): Promise { @@ -290,20 +316,76 @@ export class SshConversationProvider implements ConversationProvider { if (this.tmux) { await Promise.all(sessionIds.map((id) => killTmuxSession(this.ctx, makeTmuxSessionName(id)))); } + for (const sessionId of sessionIds) { + this.supervisor.forget(sessionId); + } this.knownSessionIds.clear(); } async detachAll(): Promise { for (const [sessionId, pty] of this.sessions) { - if (this.tmux) { - this.suppressedExitPtys.add(pty); - } + this.supervisor.stop(sessionId); try { pty.kill(); } catch {} ptySessionRegistry.unregister(sessionId); } this.sessions.clear(); - this.respawnCounts.clear(); + } + + private scheduleShellRefreshRetry({ + conversation, + sessionId, + initialSize, + isResuming, + initialPrompt, + }: { + conversation: Conversation; + sessionId: string; + initialSize: { cols: number; rows: number }; + isResuming: boolean; + initialPrompt: string | undefined; + }): void { + setTimeout(() => { + if (!this.supervisor.isDesired(sessionId)) return; + this.proxy + .refreshRemoteShellProfile() + .then(() => { + if (!this.supervisor.isDesired(sessionId)) return; + return this.startSessionInternal( + conversation, + initialSize, + isResuming, + initialPrompt, + true, + { shellRefreshRetried: true } + ); + }) + .catch((e) => { + log.error('SshConversationProvider: shell refresh retry failed', { + conversationId: conversation.id, + error: String(e), + }); + }); + }, RESPAWN_DELAY_MS); + } + + private scheduleReplacement({ + conversation, + initialSize, + }: { + conversation: Conversation; + initialSize: { cols: number; rows: number }; + }): void { + setTimeout(() => { + this.startSessionInternal(conversation, initialSize, true, undefined, true, { + shellRefreshRetried: false, + }).catch((e) => { + log.error('SshConversationProvider: replacement failed', { + conversationId: conversation.id, + error: String(e), + }); + }); + }, RESPAWN_DELAY_MS); } } diff --git a/src/main/core/conversations/resolve-agent-session-command.test.ts b/src/main/core/conversations/resolve-agent-session-command.test.ts index 59658f6fa3..0b04993378 100644 --- a/src/main/core/conversations/resolve-agent-session-command.test.ts +++ b/src/main/core/conversations/resolve-agent-session-command.test.ts @@ -1,5 +1,7 @@ import { describe, expect, it } from 'vitest'; +import { getProvider } from '@shared/agent-provider-registry'; import type { Conversation } from '@shared/conversations'; +import { buildAgentSessionCommand } from './impl/agent-command'; import { resolveAgentSessionCommandArgs } from './resolve-agent-session-command'; function makeConversation(overrides: Partial = {}): Conversation { @@ -52,4 +54,49 @@ describe('resolveAgentSessionCommandArgs', () => { ) ).toEqual({ sessionId: 'conv-1', isResuming: true }); }); + + it('builds a Claude replacement resume command from the logical conversation id', () => { + const conversation = makeConversation({ + id: '6fac6620-9fa8-4604-b7e0-1fe361589104', + providerId: 'claude', + }); + const spawnPlan = resolveAgentSessionCommandArgs(conversation, true); + + expect( + buildAgentSessionCommand({ + providerId: conversation.providerId, + providerConfig: getProvider(conversation.providerId), + autoApprove: false, + sessionId: spawnPlan.sessionId, + providerSessionId: conversation.providerSessionId, + isResuming: spawnPlan.isResuming, + }) + ).toEqual({ + command: 'claude', + args: ['--resume', conversation.id], + }); + }); + + it('builds a Codex replacement resume command from the stored provider session id', () => { + const conversation = makeConversation({ + id: '6fac6620-9fa8-4604-b7e0-1fe361589104', + providerId: 'codex', + providerSessionId: 'provider-session-1', + }); + const spawnPlan = resolveAgentSessionCommandArgs(conversation, true); + + expect( + buildAgentSessionCommand({ + providerId: conversation.providerId, + providerConfig: getProvider(conversation.providerId), + autoApprove: false, + sessionId: spawnPlan.sessionId, + providerSessionId: conversation.providerSessionId, + isResuming: spawnPlan.isResuming, + }) + ).toEqual({ + command: 'codex', + args: ['resume', 'provider-session-1'], + }); + }); }); diff --git a/src/main/core/pty/controller.ts b/src/main/core/pty/controller.ts index 13e278140a..1683983ccd 100644 --- a/src/main/core/pty/controller.ts +++ b/src/main/core/pty/controller.ts @@ -43,9 +43,8 @@ export const ptyController = createRPCController({ /** Resize a PTY session to the given terminal dimensions. */ resize: (sessionId: string, cols: number, rows: number) => { - const pty = ptySessionRegistry.get(sessionId); - if (!pty) return err({ type: 'not_found' as const }); - pty.resize(cols, rows); + const resized = ptySessionRegistry.resize(sessionId, cols, rows); + if (!resized) return err({ type: 'not_found' as const }); return ok(); }, diff --git a/src/main/core/pty/local-pty.integration.test.ts b/src/main/core/pty/local-pty.integration.test.ts new file mode 100644 index 0000000000..0201d4821f --- /dev/null +++ b/src/main/core/pty/local-pty.integration.test.ts @@ -0,0 +1,81 @@ +import { mkdtemp, rm } from 'node:fs/promises'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import { afterEach, describe, expect, it } from 'vitest'; +import { spawnLocalPty } from './local-pty'; +import type { PtyExitInfo } from './pty'; + +const tempDirs: string[] = []; + +async function tempCwd(): Promise { + const dir = await mkdtemp(join(tmpdir(), 'emdash-local-pty-')); + tempDirs.push(dir); + return dir; +} + +function waitForExit( + register: (handler: (info: PtyExitInfo) => void) => void +): Promise { + return new Promise((resolve) => register(resolve)); +} + +function waitForData( + register: (handler: (data: string) => void) => void, + pattern: RegExp +): Promise { + return new Promise((resolve) => { + let buffer = ''; + register((data) => { + buffer += data; + if (pattern.test(buffer)) resolve(buffer); + }); + }); +} + +describe('spawnLocalPty integration', () => { + afterEach(async () => { + await Promise.all(tempDirs.splice(0).map((dir) => rm(dir, { recursive: true, force: true }))); + }); + + it('runs a real local PTY process, streams output, resizes, and reports exit', async () => { + const cwd = await tempCwd(); + const pty = spawnLocalPty({ + id: 'local-integration', + command: process.execPath, + args: [ + '-e', + 'setTimeout(() => { process.stdout.write(`ready ${process.cwd()}\\n`); process.exit(7); }, 25);', + ], + cwd, + env: { ...process.env, TERM: 'xterm-256color' } as Record, + cols: 80, + rows: 24, + }); + + const data = await waitForData((handler) => pty.onData(handler), /ready/); + pty.resize(120, 50); + const exit = await waitForExit((handler) => pty.onExit(handler)); + + expect(data).toContain('ready'); + expect(data).toContain(cwd); + expect(exit.exitCode).toBe(7); + }); + + it('can kill a real long-running local PTY process', async () => { + const pty = spawnLocalPty({ + id: 'local-kill-integration', + command: process.execPath, + args: ['-e', 'setInterval(() => {}, 1000);'], + cwd: await tempCwd(), + env: { ...process.env, TERM: 'xterm-256color' } as Record, + cols: 80, + rows: 24, + }); + + const exitPromise = waitForExit((handler) => pty.onExit(handler)); + pty.kill(); + const exit = await exitPromise; + + expect(exit.exitCode ?? exit.signal).toBeDefined(); + }); +}); diff --git a/src/main/core/pty/pty-session-registry.test.ts b/src/main/core/pty/pty-session-registry.test.ts new file mode 100644 index 0000000000..2ef7d29b0a --- /dev/null +++ b/src/main/core/pty/pty-session-registry.test.ts @@ -0,0 +1,164 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { ptyStartedChannel } from '@shared/events/appEvents'; +import { ptyDataChannel, ptyExitChannel } from '@shared/events/ptyEvents'; +import type { Pty, PtyExitInfo } from './pty'; +import { PtySessionRegistry } from './pty-session-registry'; + +vi.mock('@main/lib/events', () => ({ + events: { + emit: vi.fn(), + on: vi.fn(() => () => {}), + }, +})); + +const { events } = await import('@main/lib/events'); + +function fakePty(): Pty & { + emitData(data: string): void; + emitExit(info: PtyExitInfo): void; +} { + const dataHandlers: Array<(data: string) => void> = []; + const exitHandlers: Array<(info: PtyExitInfo) => void> = []; + return { + write: vi.fn(), + resize: vi.fn(), + kill: vi.fn(), + onData: vi.fn((handler) => dataHandlers.push(handler)), + onExit: vi.fn((handler) => exitHandlers.push(handler)), + emitData(data: string) { + for (const handler of dataHandlers) handler(data); + }, + emitExit(info: PtyExitInfo) { + for (const handler of exitHandlers) handler(info); + }, + }; +} + +describe('PtySessionRegistry', () => { + beforeEach(() => { + vi.mocked(events.emit).mockClear(); + vi.mocked(events.on).mockClear(); + }); + + it('ignores stale data and exit cleanup from a replaced PTY', () => { + const registry = new PtySessionRegistry(); + const first = fakePty(); + const second = fakePty(); + + registry.register('session-1', first); + registry.register('session-1', second); + + first.emitData('old output'); + first.emitExit({ exitCode: 0 }); + + expect(registry.get('session-1')).toBe(second); + expect(events.emit).not.toHaveBeenCalledWith( + expect.objectContaining({ name: 'pty:data' }), + 'old output', + 'session-1' + ); + + second.emitExit({ exitCode: 0 }); + + expect(registry.get('session-1')).toBeUndefined(); + }); + + it('does not flush buffered output from an old PTY after replacement', async () => { + vi.useFakeTimers(); + try { + const registry = new PtySessionRegistry(); + const first = fakePty(); + const second = fakePty(); + + registry.register('session-1', first); + first.emitData('old buffered output'); + registry.register('session-1', second); + vi.mocked(events.emit).mockClear(); + + await vi.advanceTimersByTimeAsync(16); + + expect(events.emit).not.toHaveBeenCalledWith( + expect.objectContaining({ name: 'pty:data' }), + 'old buffered output', + 'session-1' + ); + expect(registry.get('session-1')).toBe(second); + } finally { + vi.useRealTimers(); + } + }); + + it('flushes buffered output when unregistering the current PTY before the flush timer fires', () => { + const registry = new PtySessionRegistry(); + const pty = fakePty(); + + registry.register('session-1', pty); + pty.emitData('final output'); + registry.unregister('session-1'); + + expect(events.emit).toHaveBeenCalledWith(ptyDataChannel, 'final output', 'session-1'); + }); + + it('emits exit when unregistering the current PTY with exit info', () => { + const registry = new PtySessionRegistry(); + const pty = fakePty(); + const exitInfo = { exitCode: 0 }; + + registry.register('session-1', pty); + registry.unregister('session-1', { pty, exitInfo }); + + expect(events.emit).toHaveBeenCalledWith(ptyExitChannel, exitInfo, 'session-1'); + }); + + it('does not emit exit or unregister when unregister is called for a stale PTY', () => { + const registry = new PtySessionRegistry(); + const first = fakePty(); + const second = fakePty(); + const exitInfo = { exitCode: 0 }; + + registry.register('session-1', first); + registry.register('session-1', second); + vi.mocked(events.emit).mockClear(); + + registry.unregister('session-1', { pty: first, exitInfo }); + + expect(registry.get('session-1')).toBe(second); + expect(events.emit).not.toHaveBeenCalledWith(ptyExitChannel, exitInfo, 'session-1'); + }); + + it('records resize dimensions before forwarding to the current PTY', () => { + const registry = new PtySessionRegistry(); + const pty = fakePty(); + + registry.register('session-1', pty); + const resized = registry.resize('session-1', 120, 50); + + expect(resized).toBe(true); + expect(pty.resize).toHaveBeenCalledWith(120, 50); + expect(registry.getLastSize('session-1')).toEqual({ cols: 120, rows: 50 }); + }); + + it('clears last observed size when preserving output after exit', () => { + const registry = new PtySessionRegistry(); + const pty = fakePty(); + + registry.register('session-1', pty, { preserveBufferOnExit: true }); + registry.resize('session-1', 120, 50); + pty.emitExit({ exitCode: 0 }); + + expect(registry.get('session-1')).toBeUndefined(); + expect(registry.getLastSize('session-1')).toBeUndefined(); + }); + + it('emits a monotonically increasing epoch for every registered PTY', () => { + const registry = new PtySessionRegistry(); + + registry.register('session-1', fakePty()); + registry.register('session-1', fakePty()); + registry.register('session-2', fakePty()); + + expect(events.emit).toHaveBeenCalledWith(ptyStartedChannel, { id: 'session-1', epoch: 1 }); + expect(events.emit).toHaveBeenCalledWith(ptyStartedChannel, { id: 'session-1', epoch: 2 }); + expect(events.emit).toHaveBeenCalledWith(ptyStartedChannel, { id: 'session-2', epoch: 3 }); + }); +}); diff --git a/src/main/core/pty/pty-session-registry.ts b/src/main/core/pty/pty-session-registry.ts index 908f982861..f447722031 100644 --- a/src/main/core/pty/pty-session-registry.ts +++ b/src/main/core/pty/pty-session-registry.ts @@ -1,7 +1,8 @@ import { events } from '@main/lib/events'; import type { AgentProviderId } from '@shared/agent-provider-registry'; +import { ptyStartedChannel } from '@shared/events/appEvents'; import { ptyDataChannel, ptyExitChannel, ptyInputChannel } from '@shared/events/ptyEvents'; -import type { Pty } from './pty'; +import type { Pty, PtyExitInfo } from './pty'; export interface PtySessionMetadata { providerId?: AgentProviderId; @@ -18,6 +19,9 @@ export class PtySessionRegistry { private ringBuffers: Map = new Map(); private activeConsumers: Set = new Set(); private metadata: Map = new Map(); + private lastSizes: Map = new Map(); + private pendingFlushes: Map void> = new Map(); + private epoch = 0; register( sessionId: string, @@ -27,25 +31,37 @@ export class PtySessionRegistry { const preserveBufferOnExit = options?.preserveBufferOnExit ?? false; // Clear any stale ring buffer and consumer from a previous PTY at this sessionId (respawn) + this.ptyInputSubscriptions.get(sessionId)?.(); + this.ptyInputSubscriptions.delete(sessionId); + this.pendingFlushes.delete(sessionId); this.ringBuffers.delete(sessionId); this.activeConsumers.delete(sessionId); this.metadata.delete(sessionId); if (options?.metadata) this.metadata.set(sessionId, options.metadata); this.ptyMap.set(sessionId, pty); + this.epoch += 1; + const epoch = this.epoch; let buffer = ''; let flushTimer: ReturnType | null = null; const flush = () => { + if (this.ptyMap.get(sessionId) !== pty) { + buffer = ''; + flushTimer = null; + return; + } if (buffer) { events.emit(ptyDataChannel, buffer, sessionId); buffer = ''; } flushTimer = null; }; + this.pendingFlushes.set(sessionId, flush); pty.onData((data) => { + if (this.ptyMap.get(sessionId) !== pty) return; buffer += data; if (!flushTimer) { flushTimer = setTimeout(flush, FLUSH_INTERVAL_MS); @@ -57,6 +73,9 @@ export class PtySessionRegistry { }); pty.onExit((info) => { + const isCurrentPty = this.ptyMap.get(sessionId) === pty; + if (!isCurrentPty) return; + // Flush any buffered output before emitting exit if (flushTimer !== null) { clearTimeout(flushTimer); @@ -68,6 +87,8 @@ export class PtySessionRegistry { this.ptyMap.delete(sessionId); this.ptyInputSubscriptions.get(sessionId)?.(); this.ptyInputSubscriptions.delete(sessionId); + this.pendingFlushes.delete(sessionId); + this.lastSizes.delete(sessionId); } else { this.unregister(sessionId); } @@ -82,15 +103,23 @@ export class PtySessionRegistry { ); this.ptyInputSubscriptions.set(sessionId, off); + events.emit(ptyStartedChannel, { id: sessionId, epoch }); } - unregister(sessionId: string): void { + unregister(sessionId: string, options: { pty?: Pty; exitInfo?: PtyExitInfo } = {}): void { + if (options.pty !== undefined && this.ptyMap.get(sessionId) !== options.pty) return; + this.pendingFlushes.get(sessionId)?.(); + if (options.exitInfo !== undefined) { + events.emit(ptyExitChannel, options.exitInfo, sessionId); + } this.ptyMap.delete(sessionId); this.ptyInputSubscriptions.get(sessionId)?.(); this.ptyInputSubscriptions.delete(sessionId); + this.pendingFlushes.delete(sessionId); this.ringBuffers.delete(sessionId); this.activeConsumers.delete(sessionId); this.metadata.delete(sessionId); + this.lastSizes.delete(sessionId); } get(sessionId: string): Pty | undefined { @@ -121,6 +150,18 @@ export class PtySessionRegistry { return this.metadata.get(sessionId); } + resize(sessionId: string, cols: number, rows: number): boolean { + const pty = this.ptyMap.get(sessionId); + if (!pty) return false; + this.lastSizes.set(sessionId, { cols, rows }); + pty.resize(cols, rows); + return true; + } + + getLastSize(sessionId: string): { cols: number; rows: number } | undefined { + return this.lastSizes.get(sessionId); + } + /** Active PTYs with local OS PID; SSH entries have `pid: undefined`. */ listActiveSessions(): Array<{ sessionId: string; diff --git a/src/main/core/pty/ssh2-pty.test.ts b/src/main/core/pty/ssh2-pty.test.ts new file mode 100644 index 0000000000..3de79a095b --- /dev/null +++ b/src/main/core/pty/ssh2-pty.test.ts @@ -0,0 +1,45 @@ +import { EventEmitter } from 'node:events'; +import { describe, expect, it, vi } from 'vitest'; +import { Ssh2PtySession } from './ssh2-pty'; + +class FakeClientChannel extends EventEmitter { + writes: string[] = []; + windows: Array<{ rows: number; cols: number; height: number; width: number }> = []; + closed = false; + + write(data: string): boolean { + this.writes.push(data); + return true; + } + + setWindow(rows: number, cols: number, height: number, width: number): void { + this.windows.push({ rows, cols, height, width }); + } + + close(): void { + this.closed = true; + this.emit('close', 0, undefined); + } +} + +describe('Ssh2PtySession', () => { + it('wraps SSH channel data, input, resize, close, and exit semantics', () => { + const channel = new FakeClientChannel(); + const session = new Ssh2PtySession('ssh-session', channel as never); + const dataHandler = vi.fn(); + const exitHandler = vi.fn(); + + session.onData(dataHandler); + session.onExit(exitHandler); + session.write('hello'); + session.resize(132, 43); + channel.emit('data', Buffer.from('remote output')); + session.kill(); + + expect(channel.writes).toEqual(['hello']); + expect(channel.windows).toEqual([{ rows: 43, cols: 132, height: 0, width: 0 }]); + expect(dataHandler).toHaveBeenCalledWith('remote output'); + expect(channel.closed).toBe(true); + expect(exitHandler).toHaveBeenCalledWith({ exitCode: 0, signal: undefined }); + }); +}); diff --git a/src/main/core/terminals/impl/local-terminal-provider.ts b/src/main/core/terminals/impl/local-terminal-provider.ts index d73f109843..6f8e623fde 100644 --- a/src/main/core/terminals/impl/local-terminal-provider.ts +++ b/src/main/core/terminals/impl/local-terminal-provider.ts @@ -173,14 +173,15 @@ export class LocalTerminalProvider implements TerminalProvider { wireTerminalDevServerWatcher({ pty, scopeId: this.scopeId, terminalId: terminal.id }); } - pty.onExit(({ exitCode, signal }) => { + pty.onExit((info) => { + const { exitCode, signal } = info; const shouldRespawn = policy.respawnOnExit && this.sessions.has(sessionId) && isUnexpectedPtyExit({ exitCode, signal }); this.sessions.delete(sessionId); if (!policy.preserveBufferOnExit) { - ptySessionRegistry.unregister(sessionId); + ptySessionRegistry.unregister(sessionId, { pty, exitInfo: info }); } if (shouldRespawn && !this.tmux) { const count = (this.respawnCounts.get(sessionId) ?? 0) + 1; diff --git a/src/main/core/terminals/impl/ssh-terminal-provider.ts b/src/main/core/terminals/impl/ssh-terminal-provider.ts index 9dfb7c3316..9672de8719 100644 --- a/src/main/core/terminals/impl/ssh-terminal-provider.ts +++ b/src/main/core/terminals/impl/ssh-terminal-provider.ts @@ -192,14 +192,15 @@ export class SshTerminalProvider implements TerminalProvider { }); } - pty.onExit(({ exitCode, signal }) => { + pty.onExit((info) => { + const { exitCode, signal } = info; const shouldRespawn = policy.respawnOnExit && this.sessions.has(sessionId) && isUnexpectedPtyExit({ exitCode, signal }); this.sessions.delete(sessionId); if (!policy.preserveBufferOnExit) { - ptySessionRegistry.unregister(sessionId); + ptySessionRegistry.unregister(sessionId, { pty, exitInfo: info }); } if (shouldRespawn && !this.tmux) { const count = (this.respawnCounts.get(sessionId) ?? 0) + 1; diff --git a/src/renderer/features/tasks/conversations/conversation-manager.ts b/src/renderer/features/tasks/conversations/conversation-manager.ts index 324827602f..2380379060 100644 --- a/src/renderer/features/tasks/conversations/conversation-manager.ts +++ b/src/renderer/features/tasks/conversations/conversation-manager.ts @@ -218,7 +218,7 @@ export class ConversationManagerStore implements IDisposable { try { await rpc.conversations.deleteConversation(this.projectId, this.taskId, conversationId); - session?.dispose(); + session?.destroy(); } catch (err) { runInAction(() => { this.conversations.set(conversationId, store); @@ -257,7 +257,7 @@ export class ConversationManagerStore implements IDisposable { this.offConversationChanges?.(); this.offConversationChanges = null; for (const session of this.sessions.values()) { - session.dispose(); + session.destroy(); } } diff --git a/src/renderer/features/tasks/stores/lifecycle-scripts.test.ts b/src/renderer/features/tasks/stores/lifecycle-scripts.test.ts index 8e0ec915e1..abd5209339 100644 --- a/src/renderer/features/tasks/stores/lifecycle-scripts.test.ts +++ b/src/renderer/features/tasks/stores/lifecycle-scripts.test.ts @@ -38,6 +38,7 @@ vi.mock('@renderer/lib/pty/pty-session', () => ({ connect = vi.fn(async () => {}); dispose = vi.fn(); + destroy = vi.fn(); }, })); diff --git a/src/renderer/features/tasks/stores/lifecycle-scripts.ts b/src/renderer/features/tasks/stores/lifecycle-scripts.ts index 9f8caf59fb..48f3f2d9a6 100644 --- a/src/renderer/features/tasks/stores/lifecycle-scripts.ts +++ b/src/renderer/features/tasks/stores/lifecycle-scripts.ts @@ -75,7 +75,7 @@ export class LifecycleScriptStore { dispose() { this.offStatus?.(); this.offStatus = null; - this.session.dispose(); + this.session.destroy(); } } diff --git a/src/renderer/features/tasks/terminals/terminal-manager.ts b/src/renderer/features/tasks/terminals/terminal-manager.ts index b8077f105b..37bddb63c0 100644 --- a/src/renderer/features/tasks/terminals/terminal-manager.ts +++ b/src/renderer/features/tasks/terminals/terminal-manager.ts @@ -59,7 +59,7 @@ export class TerminalManagerStore implements IDisposable { // Remove stale entries. const staleIds = Array.from(this.terminals.keys()).filter((id) => !incomingIds.has(id)); for (const id of staleIds) { - this.sessions.get(id)?.dispose(); + this.sessions.get(id)?.destroy(); this.sessions.delete(id); this.terminals.delete(id); } @@ -99,7 +99,7 @@ export class TerminalManagerStore implements IDisposable { return terminal; } catch (err) { runInAction(() => { - this.sessions.get(params.id)?.dispose(); + this.sessions.get(params.id)?.destroy(); this.sessions.delete(params.id); this.terminals.delete(params.id); }); @@ -137,7 +137,7 @@ export class TerminalManagerStore implements IDisposable { taskId: this.taskId, terminalId, }); - session?.dispose(); + session?.destroy(); } catch (err) { runInAction(() => { this.terminals.set(terminalId, store); @@ -160,7 +160,7 @@ export class TerminalManagerStore implements IDisposable { dispose(): void { this._disposeReaction(); for (const session of this.sessions.values()) { - session.dispose(); + session.destroy(); } this.list.dispose(); } diff --git a/src/renderer/lib/pty/pty-session.test.ts b/src/renderer/lib/pty/pty-session.test.ts index caefd8d1c1..c434835fe9 100644 --- a/src/renderer/lib/pty/pty-session.test.ts +++ b/src/renderer/lib/pty/pty-session.test.ts @@ -1,12 +1,23 @@ -import { describe, expect, it, vi } from 'vitest'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { events } from '@renderer/lib/ipc'; +import { ptyStartedChannel } from '@shared/events/appEvents'; import { PtySession } from './pty-session'; const frontendConnect = vi.hoisted(() => vi.fn()); const frontendDispose = vi.hoisted(() => vi.fn()); +const frontendInstances = vi.hoisted(() => [] as Array<{ sessionId: string }>); + +vi.mock('@renderer/lib/ipc', () => ({ + events: { + on: vi.fn(), + }, +})); vi.mock('@renderer/lib/pty/pty', () => ({ FrontendPty: class { - constructor(readonly sessionId: string) {} + constructor(readonly sessionId: string) { + frontendInstances.push(this); + } connect = frontendConnect; dispose = frontendDispose; @@ -23,7 +34,22 @@ function deferred() { return { promise, resolve, reject }; } +function ptyStartedListeners() { + return vi + .mocked(events.on) + .mock.calls.filter(([channel]) => channel === ptyStartedChannel) + .map(([, listener]) => listener as (event: { id: string; epoch: number }) => void); +} + describe('PtySession', () => { + beforeEach(() => { + frontendConnect.mockReset(); + frontendDispose.mockReset(); + frontendInstances.length = 0; + vi.mocked(events.on).mockReset(); + vi.mocked(events.on).mockReturnValue(() => {}); + }); + it('does not mark the session ready when disposed while connect is in flight', async () => { const connect = deferred(); frontendConnect.mockReturnValue(connect.promise); @@ -47,4 +73,90 @@ describe('PtySession', () => { expect(session.pty).toBeNull(); expect(frontendDispose).toHaveBeenCalledTimes(1); }); + + it('unsubscribes from backend start events only when destroyed', () => { + const offPtyStarted = vi.fn(); + vi.mocked(events.on).mockReturnValue(offPtyStarted); + const session = new PtySession('session-1'); + + session.dispose(); + expect(offPtyStarted).not.toHaveBeenCalled(); + + session.destroy(); + expect(offPtyStarted).toHaveBeenCalledTimes(1); + }); + + it('recreates the frontend PTY when the backend starts a newer epoch for the session', async () => { + frontendConnect.mockResolvedValue(undefined); + const session = new PtySession('session-1'); + + await session.connect(); + const initialPty = session.pty; + expect(initialPty).not.toBeNull(); + + for (const listener of ptyStartedListeners()) { + listener({ id: 'session-1', epoch: 2 }); + } + await Promise.resolve(); + + expect(frontendDispose).toHaveBeenCalledTimes(1); + expect(frontendInstances).toHaveLength(2); + expect(session.pty).not.toBe(initialPty); + expect(session.status).toBe('ready'); + }); + + it('does not recreate for the first backend epoch that arrives during initial connect', async () => { + const connect = deferred(); + frontendConnect.mockReturnValue(connect.promise); + + const session = new PtySession('session-1'); + const connectPromise = session.connect(); + await Promise.resolve(); + const initialPty = session.pty; + + for (const listener of ptyStartedListeners()) { + listener({ id: 'session-1', epoch: 42 }); + } + await Promise.resolve(); + + expect(frontendDispose).not.toHaveBeenCalled(); + expect(frontendInstances).toHaveLength(1); + expect(session.pty).toBe(initialPty); + + connect.resolve(); + await connectPromise; + expect(session.status).toBe('ready'); + }); + + it('lets backend replacement win over an in-flight connect for an older frontend PTY', async () => { + const firstConnect = deferred(); + const secondConnect = deferred(); + frontendConnect + .mockReturnValueOnce(firstConnect.promise) + .mockReturnValueOnce(secondConnect.promise); + + const session = new PtySession('session-1'); + const firstConnectPromise = session.connect(); + await Promise.resolve(); + const initialPty = session.pty; + expect(initialPty).not.toBeNull(); + expect(session.status).toBe('connecting'); + + for (const listener of ptyStartedListeners()) listener({ id: 'session-1', epoch: 42 }); + for (const listener of ptyStartedListeners()) listener({ id: 'session-1', epoch: 43 }); + await Promise.resolve(); + const replacementPty = session.pty; + expect(replacementPty).not.toBe(initialPty); + expect(frontendDispose).toHaveBeenCalledTimes(1); + + firstConnect.resolve(); + await firstConnectPromise; + expect(session.pty).toBe(replacementPty); + expect(session.status).toBe('connecting'); + + secondConnect.resolve(); + await Promise.resolve(); + expect(session.pty).toBe(replacementPty); + expect(session.status).toBe('ready'); + }); }); diff --git a/src/renderer/lib/pty/pty-session.ts b/src/renderer/lib/pty/pty-session.ts index 34be87ed93..cb565503b8 100644 --- a/src/renderer/lib/pty/pty-session.ts +++ b/src/renderer/lib/pty/pty-session.ts @@ -1,5 +1,7 @@ import { makeAutoObservable, onBecomeObserved, runInAction } from 'mobx'; +import { events } from '@renderer/lib/ipc'; import { FrontendPty } from '@renderer/lib/pty/pty'; +import { ptyStartedChannel } from '@shared/events/appEvents'; export type PtySessionStatus = 'disconnected' | 'connecting' | 'ready'; @@ -8,6 +10,8 @@ export class PtySession { status: PtySessionStatus = 'disconnected'; private connectPromise: Promise | null = null; private version = 0; + private lastSeenEpoch = 0; + private offPtyStarted: (() => void) | null = null; constructor( readonly sessionId: string, @@ -18,6 +22,10 @@ export class PtySession { makeAutoObservable(this, { pty: false, }); + this.offPtyStarted = events.on(ptyStartedChannel, (event) => { + if (event.id !== this.sessionId) return; + void this.handleBackendStarted(event.epoch); + }); // Lazy connect: auto-connects the first time any observer reads status. // Sessions are created at data-load time without connecting; this fires // when the session is first rendered as the active conversation or terminal. @@ -36,12 +44,13 @@ export class PtySession { if (version !== this.version) return; if (this.pty) return; const pty = new FrontendPty(this.sessionId, undefined, this.onOpenFile, this.onOpenExternal); - this.pty = pty; runInAction(() => { + this.pty = pty; this.status = 'connecting'; }); await pty.connect(); if (version !== this.version || this.pty !== pty) return; + if (this.lastSeenEpoch === 0) this.lastSeenEpoch = 1; runInAction(() => { this.status = 'ready'; }); @@ -60,4 +69,51 @@ export class PtySession { this.status = 'disconnected'; }); } + + destroy() { + this.dispose(); + this.offPtyStarted?.(); + this.offPtyStarted = null; + } + + private async handleBackendStarted(epoch: number): Promise { + if (epoch <= this.lastSeenEpoch) return; + if (this.lastSeenEpoch === 0 && (this.status === 'connecting' || this.pty === null)) { + this.lastSeenEpoch = epoch; + return; + } + if (!this.pty && this.status === 'disconnected') { + this.lastSeenEpoch = epoch; + return; + } + + this.lastSeenEpoch = epoch; + this.version++; + this.connectPromise = null; + this.pty?.dispose(); + + const version = this.version; + const pty = new FrontendPty(this.sessionId, undefined, this.onOpenFile, this.onOpenExternal); + runInAction(() => { + this.pty = pty; + this.status = 'connecting'; + }); + + try { + await pty.connect(); + if (version === this.version && this.pty === pty) { + runInAction(() => { + this.status = 'ready'; + }); + } + } catch { + if (version === this.version && this.pty === pty) { + pty.dispose(); + runInAction(() => { + this.pty = null; + this.status = 'disconnected'; + }); + } + } + } } diff --git a/src/shared/events/agentEvents.ts b/src/shared/events/agentEvents.ts index 0e14bfde50..4adbe13f3b 100644 --- a/src/shared/events/agentEvents.ts +++ b/src/shared/events/agentEvents.ts @@ -45,12 +45,8 @@ export type SoundEvent = 'needs_attention' | 'task_complete'; export const agentEventChannel = defineEvent('agent:event'); export interface AgentSessionExited { - /** PTY session ID (= conversationId for agent sessions). */ - projectId: string; - sessionId: string; conversationId: string; taskId: string; - exitCode: number | undefined; } /** Emitted when an agent PTY session exits. Topic = taskId. */ diff --git a/src/shared/events/appEvents.ts b/src/shared/events/appEvents.ts index 3f51727965..f314abf759 100644 --- a/src/shared/events/appEvents.ts +++ b/src/shared/events/appEvents.ts @@ -28,6 +28,7 @@ export const notificationFocusTaskChannel = defineEvent<{ export const ptyStartedChannel = defineEvent<{ id: string; + epoch: number; }>('pty:started'); export type PlanEvent = {