Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-progress-monotonicity-timeout.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@modelcontextprotocol/core": patch
---

fix: enforce monotonic progress notifications
31 changes: 24 additions & 7 deletions packages/core/src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
private _notificationHandlers: Map<string, (notification: JSONRPCNotification) => Promise<void>> = new Map();
private _responseHandlers: Map<number, (response: JSONRPCResultResponse | Error) => void> = new Map();
private _progressHandlers: Map<number, ProgressCallback> = new Map();
private _progressValues: Map<number, number> = new Map();
private _timeoutInfo: Map<number, TimeoutInfo> = new Map();
private _pendingDebouncedNotifications = new Set<string>();

Expand Down Expand Up @@ -383,7 +384,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
request: (request, resultSchema, options) => this._requestWithSchema(request, resultSchema, options),
notification: (notification, options) => this.notification(notification, options),
reportError: error => this._onerror(error),
removeProgressHandler: token => this._progressHandlers.delete(token),
removeProgressHandler: token => this._removeProgressHandler(token),
registerHandler: (method, handler) => {
const schema = getRequestSchema(method as RequestMethod);
this._requestHandlers.set(method, (request, ctx) => {
Expand Down Expand Up @@ -460,6 +461,11 @@ export abstract class Protocol<ContextT extends BaseContext> {
}
}

private _removeProgressHandler(messageId: number): void {
this._progressHandlers.delete(messageId);
this._progressValues.delete(messageId);
}

/**
* Attaches to the given transport, starts it, and starts listening for messages.
*
Expand Down Expand Up @@ -506,6 +512,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
const responseHandlers = this._responseHandlers;
this._responseHandlers = new Map();
this._progressHandlers.clear();
this._progressValues.clear();
this._taskManager.onClose();
this._pendingDebouncedNotifications.clear();

Expand Down Expand Up @@ -702,14 +709,24 @@ export abstract class Protocol<ContextT extends BaseContext> {

const responseHandler = this._responseHandlers.get(messageId);
const timeoutInfo = this._timeoutInfo.get(messageId);
const lastProgress = this._progressValues.get(messageId);
if (lastProgress !== undefined && params.progress <= lastProgress) {
this._onerror(
new Error(
`Received a non-increasing progress notification for token ${progressToken}: ${params.progress} <= ${lastProgress}`
)
);
return;
}
this._progressValues.set(messageId, params.progress);

if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) {
try {
this._resetTimeout(messageId);
} catch (error) {
// Clean up if maxTotalTimeout was exceeded
this._responseHandlers.delete(messageId);
this._progressHandlers.delete(messageId);
this._removeProgressHandler(messageId);
this._cleanupTimeout(messageId);
responseHandler(error as Error);
return;
Expand Down Expand Up @@ -738,7 +755,7 @@ export abstract class Protocol<ContextT extends BaseContext> {

// Keep progress handler alive for CreateTaskResult responses
if (!preserveProgress) {
this._progressHandlers.delete(messageId);
this._removeProgressHandler(messageId);
}

if (isJSONRPCResultResponse(response)) {
Expand Down Expand Up @@ -890,7 +907,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
if (responseReceived) {
return;
}
this._progressHandlers.delete(messageId);
this._removeProgressHandler(messageId);

this._transport
?.send(
Expand Down Expand Up @@ -951,22 +968,22 @@ export abstract class Protocol<ContextT extends BaseContext> {
let outboundQueued = false;
try {
const taskResult = this._taskManager.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => {
this._progressHandlers.delete(messageId);
this._removeProgressHandler(messageId);
reject(error);
});
if (taskResult.queued) {
outboundQueued = true;
}
} catch (error) {
this._progressHandlers.delete(messageId);
this._removeProgressHandler(messageId);
reject(error);
return;
}

if (!outboundQueued) {
// No related task or no module - send through transport normally
this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => {
this._progressHandlers.delete(messageId);
this._removeProgressHandler(messageId);
reject(error);
});
}
Expand Down
57 changes: 57 additions & 0 deletions packages/core/test/shared/protocol.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,63 @@ describe('protocol tests', () => {
await expect(requestPromise).resolves.toEqual({ result: 'success' });
});

test('should not reset timeout for non-increasing progress notifications', async () => {
await protocol.connect(transport);
const request = { method: 'example', params: {} };
const mockSchema: ZodType<{ result: string }> = z.object({
result: z.string()
});
const onErrorMock = vi.fn();
const onProgressMock = vi.fn();
protocol.onerror = onErrorMock;

const requestPromise = testRequest(protocol, request, mockSchema, {
timeout: 1000,
resetTimeoutOnProgress: true,
onprogress: onProgressMock
});

vi.advanceTimersByTime(800);
if (transport.onmessage) {
transport.onmessage({
jsonrpc: '2.0',
method: 'notifications/progress',
params: {
progressToken: 0,
progress: 50,
total: 100
}
});
}
await Promise.resolve();

expect(onProgressMock).toHaveBeenCalledOnce();
expect(onProgressMock).toHaveBeenCalledWith({
progress: 50,
total: 100
});

vi.advanceTimersByTime(800);
if (transport.onmessage) {
transport.onmessage({
jsonrpc: '2.0',
method: 'notifications/progress',
params: {
progressToken: 0,
progress: 25,
total: 100
}
});
}
await Promise.resolve();

expect(onErrorMock).toHaveBeenCalledWith(expect.objectContaining({ message: expect.stringContaining('non-increasing') }));
expect(onProgressMock).toHaveBeenCalledOnce();

vi.advanceTimersByTime(201);
await expect(requestPromise).rejects.toThrow('Request timed out');
});

test('should respect maxTotalTimeout', async () => {
await protocol.connect(transport);
const request = { method: 'example', params: {} };
Expand Down
Loading