diff --git a/.changeset/fix-progress-monotonicity-timeout.md b/.changeset/fix-progress-monotonicity-timeout.md new file mode 100644 index 000000000..2aa328afe --- /dev/null +++ b/.changeset/fix-progress-monotonicity-timeout.md @@ -0,0 +1,5 @@ +--- +"@modelcontextprotocol/core": patch +--- + +fix: enforce monotonic progress notifications diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 361bd6fc7..7a623a3d5 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -316,6 +316,7 @@ export abstract class Protocol { private _notificationHandlers: Map Promise> = new Map(); private _responseHandlers: Map void> = new Map(); private _progressHandlers: Map = new Map(); + private _progressValues: Map = new Map(); private _timeoutInfo: Map = new Map(); private _pendingDebouncedNotifications = new Set(); @@ -383,7 +384,7 @@ export abstract class Protocol { request: (request, resultSchema, options) => this._requestWithSchema(request, resultSchema, options), notification: (notification, options) => this.notification(notification, options), reportError: error => this._onerror(error), - removeProgressHandler: token => this._progressHandlers.delete(token), + removeProgressHandler: token => this._removeProgressHandler(token), registerHandler: (method, handler) => { const schema = getRequestSchema(method as RequestMethod); this._requestHandlers.set(method, (request, ctx) => { @@ -460,6 +461,11 @@ export abstract class Protocol { } } + private _removeProgressHandler(messageId: number): void { + this._progressHandlers.delete(messageId); + this._progressValues.delete(messageId); + } + /** * Attaches to the given transport, starts it, and starts listening for messages. * @@ -506,6 +512,7 @@ export abstract class Protocol { const responseHandlers = this._responseHandlers; this._responseHandlers = new Map(); this._progressHandlers.clear(); + this._progressValues.clear(); this._taskManager.onClose(); this._pendingDebouncedNotifications.clear(); @@ -702,6 +709,16 @@ export abstract class Protocol { const responseHandler = this._responseHandlers.get(messageId); const timeoutInfo = this._timeoutInfo.get(messageId); + const lastProgress = this._progressValues.get(messageId); + if (lastProgress !== undefined && params.progress <= lastProgress) { + this._onerror( + new Error( + `Received a non-increasing progress notification for token ${progressToken}: ${params.progress} <= ${lastProgress}` + ) + ); + return; + } + this._progressValues.set(messageId, params.progress); if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { try { @@ -709,7 +726,7 @@ export abstract class Protocol { } catch (error) { // Clean up if maxTotalTimeout was exceeded this._responseHandlers.delete(messageId); - this._progressHandlers.delete(messageId); + this._removeProgressHandler(messageId); this._cleanupTimeout(messageId); responseHandler(error as Error); return; @@ -738,7 +755,7 @@ export abstract class Protocol { // Keep progress handler alive for CreateTaskResult responses if (!preserveProgress) { - this._progressHandlers.delete(messageId); + this._removeProgressHandler(messageId); } if (isJSONRPCResultResponse(response)) { @@ -890,7 +907,7 @@ export abstract class Protocol { if (responseReceived) { return; } - this._progressHandlers.delete(messageId); + this._removeProgressHandler(messageId); this._transport ?.send( @@ -951,14 +968,14 @@ export abstract class Protocol { let outboundQueued = false; try { const taskResult = this._taskManager.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => { - this._progressHandlers.delete(messageId); + this._removeProgressHandler(messageId); reject(error); }); if (taskResult.queued) { outboundQueued = true; } } catch (error) { - this._progressHandlers.delete(messageId); + this._removeProgressHandler(messageId); reject(error); return; } @@ -966,7 +983,7 @@ export abstract class Protocol { if (!outboundQueued) { // No related task or no module - send through transport normally this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { - this._progressHandlers.delete(messageId); + this._removeProgressHandler(messageId); reject(error); }); } diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 619e09376..d167a931f 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -556,6 +556,63 @@ describe('protocol tests', () => { await expect(requestPromise).resolves.toEqual({ result: 'success' }); }); + test('should not reset timeout for non-increasing progress notifications', async () => { + await protocol.connect(transport); + const request = { method: 'example', params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string() + }); + const onErrorMock = vi.fn(); + const onProgressMock = vi.fn(); + protocol.onerror = onErrorMock; + + const requestPromise = testRequest(protocol, request, mockSchema, { + timeout: 1000, + resetTimeoutOnProgress: true, + onprogress: onProgressMock + }); + + vi.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 50, + total: 100 + } + }); + } + await Promise.resolve(); + + expect(onProgressMock).toHaveBeenCalledOnce(); + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 50, + total: 100 + }); + + vi.advanceTimersByTime(800); + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: '2.0', + method: 'notifications/progress', + params: { + progressToken: 0, + progress: 25, + total: 100 + } + }); + } + await Promise.resolve(); + + expect(onErrorMock).toHaveBeenCalledWith(expect.objectContaining({ message: expect.stringContaining('non-increasing') })); + expect(onProgressMock).toHaveBeenCalledOnce(); + + vi.advanceTimersByTime(201); + await expect(requestPromise).rejects.toThrow('Request timed out'); + }); + test('should respect maxTotalTimeout', async () => { await protocol.connect(transport); const request = { method: 'example', params: {} };