diff --git a/src/__fixtures__/serverThatHangs.ts b/src/__fixtures__/serverThatHangs.ts new file mode 100644 index 000000000..82c244aa2 --- /dev/null +++ b/src/__fixtures__/serverThatHangs.ts @@ -0,0 +1,42 @@ +import { setInterval } from 'node:timers'; +import process from 'node:process'; +import { McpServer } from '../server/mcp.js'; +import { StdioServerTransport } from '../server/stdio.js'; + +const transport = new StdioServerTransport(); + +const server = new McpServer( + { + name: 'server-that-hangs', + title: 'Test Server that hangs', + version: '1.0.0' + }, + { + capabilities: { + logging: {} + } + } +); + +await server.connect(transport); + +// Keep process alive even after stdin closes +const keepAlive = setInterval(() => {}, 60_000); + +// Prevent transport close from exiting +transport.onclose = () => { + // Intentionally ignore - we want to test the signal handling +}; + +const doNotExitImmediately = async (signal: NodeJS.Signals) => { + await server.sendLoggingMessage({ + level: 'debug', + data: `received signal ${signal}` + }); + // Clear keepalive but delay exit to simulate slow shutdown + clearInterval(keepAlive); + setInterval(() => {}, 30_000); +}; + +process.on('SIGINT', doNotExitImmediately); +process.on('SIGTERM', doNotExitImmediately); diff --git a/src/__fixtures__/testServer.ts b/src/__fixtures__/testServer.ts new file mode 100644 index 000000000..6401d0f83 --- /dev/null +++ b/src/__fixtures__/testServer.ts @@ -0,0 +1,19 @@ +import { McpServer } from '../server/mcp.js'; +import { StdioServerTransport } from '../server/stdio.js'; + +const transport = new StdioServerTransport(); + +const server = new McpServer({ + name: 'test-server', + version: '1.0.0' +}); + +await server.connect(transport); + +const exit = async () => { + await server.close(); + process.exit(0); +}; + +process.on('SIGINT', exit); +process.on('SIGTERM', exit); diff --git a/src/shared/zodTestMatrix.ts b/src/__fixtures__/zodTestMatrix.ts similarity index 100% rename from src/shared/zodTestMatrix.ts rename to src/__fixtures__/zodTestMatrix.ts diff --git a/src/client/stdio.ts b/src/client/stdio.ts index d62a3aeb6..e488dcd24 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -91,7 +91,6 @@ export function getDefaultEnvironment(): Record { */ export class StdioClientTransport implements Transport { private _process?: ChildProcess; - private _abortController: AbortController = new AbortController(); private _readBuffer: ReadBuffer = new ReadBuffer(); private _serverParams: StdioServerParameters; private _stderrStream: PassThrough | null = null; @@ -126,18 +125,11 @@ export class StdioClientTransport implements Transport { }, stdio: ['pipe', 'pipe', this._serverParams.stderr ?? 'inherit'], shell: false, - signal: this._abortController.signal, windowsHide: process.platform === 'win32' && isElectron(), cwd: this._serverParams.cwd }); this._process.on('error', error => { - if (error.name === 'AbortError') { - // Expected when close() is called. - this.onclose?.(); - return; - } - reject(error); this.onerror?.(error); }); @@ -210,8 +202,43 @@ export class StdioClientTransport implements Transport { } async close(): Promise { - this._abortController.abort(); - this._process = undefined; + if (this._process) { + const processToClose = this._process; + this._process = undefined; + + const closePromise = new Promise(resolve => { + processToClose.once('close', () => { + resolve(); + }); + }); + + try { + processToClose.stdin?.end(); + } catch { + // ignore + } + + await Promise.race([closePromise, new Promise(resolve => setTimeout(resolve, 2_000).unref())]); + + if (processToClose.exitCode === null) { + try { + processToClose.kill('SIGTERM'); + } catch { + // ignore + } + + await Promise.race([closePromise, new Promise(resolve => setTimeout(resolve, 2_000).unref())]); + } + + if (processToClose.exitCode === null) { + try { + processToClose.kill('SIGKILL'); + } catch { + // ignore + } + } + } + this._readBuffer.clear(); } diff --git a/src/integration-tests/process-cleanup.test.ts b/src/integration-tests/process-cleanup.test.ts deleted file mode 100644 index e90ec7e24..000000000 --- a/src/integration-tests/process-cleanup.test.ts +++ /dev/null @@ -1,28 +0,0 @@ -import { Server } from '../server/index.js'; -import { StdioServerTransport } from '../server/stdio.js'; - -describe('Process cleanup', () => { - vi.setConfig({ testTimeout: 5000 }); // 5 second timeout - - it('should exit cleanly after closing transport', async () => { - const server = new Server( - { - name: 'test-server', - version: '1.0.0' - }, - { - capabilities: {} - } - ); - - const transport = new StdioServerTransport(); - await server.connect(transport); - - // Close the transport - await transport.close(); - - // If we reach here without hanging, the test passes - // The test runner will fail if the process hangs - expect(true).toBe(true); - }); -}); diff --git a/src/integration-tests/processCleanup.test.ts b/src/integration-tests/processCleanup.test.ts new file mode 100644 index 000000000..7579bebdc --- /dev/null +++ b/src/integration-tests/processCleanup.test.ts @@ -0,0 +1,113 @@ +import path from 'node:path'; +import { Readable, Writable } from 'node:stream'; +import { Client } from '../client/index.js'; +import { StdioClientTransport } from '../client/stdio.js'; +import { Server } from '../server/index.js'; +import { StdioServerTransport } from '../server/stdio.js'; +import { LoggingMessageNotificationSchema } from '../types.js'; + +const FIXTURES_DIR = path.resolve(__dirname, '../__fixtures__'); + +describe('Process cleanup', () => { + vi.setConfig({ testTimeout: 5000 }); // 5 second timeout + + it('server should exit cleanly after closing transport', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: {} + } + ); + + const mockReadable = new Readable({ + read() { + this.push(null); // signal EOF + } + }), + mockWritable = new Writable({ + write(chunk, encoding, callback) { + callback(); + } + }); + + // Attach mock streams to process for the server transport + const transport = new StdioServerTransport(mockReadable, mockWritable); + await server.connect(transport); + + // Close the transport + await transport.close(); + + // ensure a proper disposal mock streams + mockReadable.destroy(); + mockWritable.destroy(); + + // If we reach here without hanging, the test passes + // The test runner will fail if the process hangs + expect(true).toBe(true); + }); + + it('onclose should be called exactly once', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StdioClientTransport({ + command: 'node', + args: ['--import', 'tsx', 'testServer.ts'], + cwd: FIXTURES_DIR + }); + + await client.connect(transport); + + let onCloseWasCalled = 0; + client.onclose = () => { + onCloseWasCalled++; + }; + + await client.close(); + + // A short delay to allow the close event to propagate + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(onCloseWasCalled).toBe(1); + }); + + it('should exit cleanly for a server that hangs', async () => { + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StdioClientTransport({ + command: 'node', + args: ['--import', 'tsx', 'serverThatHangs.ts'], + cwd: FIXTURES_DIR + }); + + await client.connect(transport); + await client.setLoggingLevel('debug'); + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + console.debug('server log: ' + notification.params.data); + }); + const serverPid = transport.pid!; + + await client.close(); + + // A short delay to allow the close event to propagate + await new Promise(resolve => setTimeout(resolve, 50)); + + try { + process.kill(serverPid, 9); + throw new Error('Expected server to be dead but it is alive'); + } catch (err: unknown) { + // 'ESRCH' the process doesn't exist + if (err && typeof err === 'object' && 'code' in err && err.code === 'ESRCH') { + // success + } else throw err; + } + }); +}); diff --git a/src/integration-tests/stateManagementStreamableHttp.test.ts b/src/integration-tests/stateManagementStreamableHttp.test.ts index 3294df4d4..fe79ff9ee 100644 --- a/src/integration-tests/stateManagementStreamableHttp.test.ts +++ b/src/integration-tests/stateManagementStreamableHttp.test.ts @@ -12,7 +12,7 @@ import { ListPromptsResultSchema, LATEST_PROTOCOL_VERSION } from '../types.js'; -import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +import { zodTestMatrix, type ZodMatrixEntry } from '../__fixtures__/zodTestMatrix.js'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/src/integration-tests/taskResumability.test.ts b/src/integration-tests/taskResumability.test.ts index 3c357d171..bf0d4bc46 100644 --- a/src/integration-tests/taskResumability.test.ts +++ b/src/integration-tests/taskResumability.test.ts @@ -7,7 +7,7 @@ import { McpServer } from '../server/mcp.js'; import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; import { CallToolResultSchema, LoggingMessageNotificationSchema } from '../types.js'; import { InMemoryEventStore } from '../examples/shared/inMemoryEventStore.js'; -import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +import { zodTestMatrix, type ZodMatrixEntry } from '../__fixtures__/zodTestMatrix.js'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/src/server/completable.test.ts b/src/server/completable.test.ts index e0d2aba99..69dd67d02 100644 --- a/src/server/completable.test.ts +++ b/src/server/completable.test.ts @@ -1,5 +1,5 @@ import { completable, getCompleter } from './completable.js'; -import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +import { zodTestMatrix, type ZodMatrixEntry } from '../__fixtures__/zodTestMatrix.js'; describe.each(zodTestMatrix)('completable with $zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 7a34a08bc..cfec318af 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -22,7 +22,7 @@ import { import { completable } from './completable.js'; import { McpServer, ResourceTemplate } from './mcp.js'; import { InMemoryTaskStore } from '../experimental/tasks/stores/in-memory.js'; -import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +import { zodTestMatrix, type ZodMatrixEntry } from '../__fixtures__/zodTestMatrix.js'; function createLatch() { let latch = false; diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 304f7e860..b95490c13 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -6,7 +6,7 @@ import { McpServer } from './mcp.js'; import { createServer, type Server } from 'node:http'; import { AddressInfo } from 'node:net'; import { CallToolResult, JSONRPCMessage } from '../types.js'; -import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +import { zodTestMatrix, type ZodMatrixEntry } from '../__fixtures__/zodTestMatrix.js'; const createMockResponse = () => { const res = { diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 80ee04d67..39c2e5805 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -5,7 +5,7 @@ import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from './ import { McpServer } from './mcp.js'; import { CallToolResult, JSONRPCMessage } from '../types.js'; import { AuthInfo } from './auth/types.js'; -import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +import { zodTestMatrix, type ZodMatrixEntry } from '../__fixtures__/zodTestMatrix.js'; async function getFreePort() { return new Promise(res => { diff --git a/src/server/title.test.ts b/src/server/title.test.ts index 9eb99b992..2af3de3c0 100644 --- a/src/server/title.test.ts +++ b/src/server/title.test.ts @@ -2,7 +2,7 @@ import { Server } from './index.js'; import { Client } from '../client/index.js'; import { InMemoryTransport } from '../inMemory.js'; import { McpServer, ResourceTemplate } from './mcp.js'; -import { zodTestMatrix, type ZodMatrixEntry } from '../shared/zodTestMatrix.js'; +import { zodTestMatrix, type ZodMatrixEntry } from '../__fixtures__/zodTestMatrix.js'; describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { const { z } = entry; diff --git a/tsconfig.cjs.json b/tsconfig.cjs.json index 3b46f11c4..ed5f7fe3e 100644 --- a/tsconfig.cjs.json +++ b/tsconfig.cjs.json @@ -5,5 +5,5 @@ "moduleResolution": "node", "outDir": "./dist/cjs" }, - "exclude": ["**/*.test.ts", "src/__mocks__/**/*"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*", "src/__fixtures__/**/*"] } diff --git a/tsconfig.prod.json b/tsconfig.prod.json index fcf2e951c..a07311af7 100644 --- a/tsconfig.prod.json +++ b/tsconfig.prod.json @@ -3,5 +3,5 @@ "compilerOptions": { "outDir": "./dist/esm" }, - "exclude": ["**/*.test.ts", "src/__mocks__/**/*", "src/server/zodTestMatrix.ts"] + "exclude": ["**/*.test.ts", "src/__mocks__/**/*", "src/__fixtures__/**/*"] }