From 069ab41774432201ec85692e01e48f09655d11b3 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 13 Aug 2025 15:37:56 +0000 Subject: [PATCH] mcp: fix cancellation for HTTP transport In #202, I added the checkRequest helper to validate incoming requests, and invoked it in the stremable transports to preemptively reject invalid HTTP requests, so that a jsonrpc error could be translated to an HTTP error. However, this introduced a bug: since cancellation was handled in the jsonrpc2 layer, we never had to validate it in the mcp layer, and therefore never added methodInfo. As a result, it was reported as an invalid request in the http layer. Add a test, and a fix. The simplest fix was to create stubs that are placeholders for cancellation. This was discovered in the course of investigating #285. --- mcp/client.go | 10 +++++++++ mcp/mcp_test.go | 5 ++--- mcp/server.go | 10 +++++++++ mcp/streamable.go | 20 ++++++++--------- mcp/streamable_test.go | 51 ++++++++++++++++++++++++++++++++---------- mcp/transport.go | 2 +- 6 files changed, 72 insertions(+), 26 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index b0db1d64..88eea7da 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -305,6 +305,7 @@ var clientMethodInfos = map[string]methodInfo{ methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), notificationResourceListChanged: newClientMethodInfo(clientMethod((*Client).callResourceChangedHandler), notification|missingParamsOK), @@ -344,6 +345,15 @@ func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) { return &emptyResult{}, nil } +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. +// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (*ClientSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + func newClientRequest[P Params](cs *ClientSession, params P) *ClientRequest[P] { return &ClientRequest[P]{Session: cs, Params: params} } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 58b0377e..66ad7e0e 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -646,8 +646,7 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - - slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { + slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -658,7 +657,7 @@ func TestCancellation(t *testing.T) { return nil, nil } _, cs := basicConnection(t, func(s *Server) { - s.AddTool(&Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest) + AddTool(s, &Tool{Name: "slow"}, slowRequest) }) defer cs.Close() diff --git a/mcp/server.go b/mcp/server.go index e39372dc..5b7538a1 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -760,6 +760,7 @@ var serverMethodInfos = map[string]methodInfo{ methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), + notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), notificationRootsListChanged: newServerMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), notificationProgress: newServerMethodInfo(serverSessionMethod((*ServerSession).callProgressNotificationHandler), notification), @@ -838,6 +839,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error return &emptyResult{}, nil } +// cancel is a placeholder: cancellation is handled the jsonrpc2 package. +// +// It should never be invoked in practice because cancellation is preempted, +// but having its signature here facilitates the construction of methodInfo +// that can be used to validate incoming cancellation notifications. +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { + return nil, nil +} + func (ss *ServerSession) setLevel(_ context.Context, params *SetLevelParams) (*emptyResult, error) { ss.updateState(func(state *ServerSessionState) { state.LogLevel = params.Level diff --git a/mcp/streamable.go b/mcp/streamable.go index 048b99aa..7a9efead 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -115,12 +115,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - var session *StreamableServerTransport + var transport *StreamableServerTransport if id := req.Header.Get(sessionIDHeader); id != "" { h.mu.Lock() - session, _ = h.transports[id] + transport, _ = h.transports[id] h.mu.Unlock() - if session == nil { + if transport == nil { http.Error(w, "session not found", http.StatusNotFound) return } @@ -129,22 +129,22 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // TODO(rfindley): simplify the locking so that each request has only one // critical section. if req.Method == http.MethodDelete { - if session == nil { + if transport == nil { // => Mcp-Session-Id was not set; else we'd have returned NotFound above. http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } h.mu.Lock() - delete(h.transports, session.SessionID) + delete(h.transports, transport.SessionID) h.mu.Unlock() - session.connection.Close() + transport.connection.Close() w.WriteHeader(http.StatusNoContent) return } switch req.Method { case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && session == nil { + if req.Method == http.MethodGet && transport == nil { http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) return } @@ -154,7 +154,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } - if session == nil { + if transport == nil { server := h.getServer(req) if server == nil { // The getServer argument to NewStreamableHTTPHandler returned nil. @@ -191,10 +191,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque h.transports[s.SessionID] = s h.mu.Unlock() } - session = s + transport = s } - session.ServeHTTP(w, req) + transport.ServeHTTP(w, req) } // StreamableServerTransportOptions configures the stramable server transport. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 25dd224e..ca5e5a5c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -37,9 +37,26 @@ func TestStreamableTransports(t *testing.T) { for _, useJSON := range []bool{false, true} { t.Run(fmt.Sprintf("JSONResponse=%v", useJSON), func(t *testing.T) { - // 1. Create a server with a simple "greet" tool. + // Create a server with some simple tools. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + // The "hang" tool checks that context cancellation is propagated. + // It hangs until the context is cancelled. + var ( + start = make(chan struct{}) + cancelled = make(chan struct{}, 1) // don't block the request + ) + hang := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + start <- struct{}{} + select { + case <-ctx.Done(): + cancelled <- struct{}{} + case <-time.After(5 * time.Second): + return nil, nil + } + return nil, nil + } + AddTool(server, &Tool{Name: "hang"}, hang) AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { // Test that we can make sampling requests during tool handling. // @@ -60,7 +77,7 @@ func TestStreamableTransports(t *testing.T) { return &CallToolResultFor[any]{}, nil }) - // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a + // Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, &StreamableHTTPOptions{ jsonResponse: useJSON, @@ -84,7 +101,7 @@ func TestStreamableTransports(t *testing.T) { })) defer httpServer.Close() - // 3. Create a client and connect it to the server using our StreamableClientTransport. + // Create a client and connect it to the server using our StreamableClientTransport. // Check that all requests honor a custom client. jar, err := cookiejar.New(nil) if err != nil { @@ -117,10 +134,13 @@ func TestStreamableTransports(t *testing.T) { if g, w := session.mcpConn.(*streamableClientConn).initializedResult.ProtocolVersion, latestProtocolVersion; g != w { t.Fatalf("got protocol version %q, want %q", g, w) } - // 4. The client calls the "greet" tool. + + // Verify the behavior of various tools. + + // The "greet" tool should just work. params := &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "streamy"}, + Arguments: map[string]any{"name": "foo"}, } got, err := session.CallTool(ctx, params) if err != nil { @@ -132,19 +152,26 @@ func TestStreamableTransports(t *testing.T) { if g, w := lastHeader.Get(protocolVersionHeader), latestProtocolVersion; g != w { t.Errorf("got protocol version header %q, want %q", g, w) } - - // 5. Verify that the correct response is received. want := &CallToolResult{ - Content: []Content{ - &TextContent{Text: "hi streamy"}, - }, + Content: []Content{&TextContent{Text: "hi foo"}}, } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) } - // 6. Run the "sampling" tool and verify that the streamable server can - // call tools. + // The "hang" tool should be cancellable. + ctx2, cancel := context.WithCancel(context.Background()) + go session.CallTool(ctx2, &CallToolParams{Name: "hang"}) + <-start + cancel() + select { + case <-cancelled: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for cancellation") + } + + // The "sampling" tool should be able to issue sampling requests during + // tool operation. result, err := session.CallTool(ctx, &CallToolParams{ Name: "sample", Arguments: map[string]any{}, diff --git a/mcp/transport.go b/mcp/transport.go index 76b79986..6d25de33 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -171,7 +171,7 @@ type canceller struct { // Preempt implements [jsonrpc2.Preempter]. func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result any, err error) { - if req.Method == "notifications/cancelled" { + if req.Method == notificationCancelled { var params CancelledParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { return nil, err