From 67a6b52c5204c47d9feb53b6528c34559ff040c8 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 1 Dec 2025 21:30:43 +0000 Subject: [PATCH 1/2] mcp: better handling for streamable context cancellation After walking through our handling of streamable client context cancellation (due to encountering a shutdown deadlock), I think I've settled on a more coherent strategy for handling call cancellation: - In our call handler, retire the request if the call exits due to cancellation: the caller will never see the actual result anyway. - In connectSSE, use the actual request context (the same context used in Write) for the client request, so that it terminates when the context is cancelled. Thread through the initialization context for the standalone SSE request. Also, a couple minor improvements: - Use a detached context for the background context of the client connection. We want to preserve context values (see #513), but it is not right to cancel the connection after Connect has already returned, if the context times out. - Don't use Last-Event-ID != "" as the signal for whether the connectSSE call is initial: if the standalone SSE stream disconnects without an event ID, we'll still reconnect it, and don't want to do so without a delay. + tests, updating the streamable client connection test harness to accomodate the new aspects being exercised. Fixes #662 --- internal/jsonrpc2/conn.go | 31 ++++-- mcp/streamable.go | 84 +++++++++----- mcp/streamable_client_test.go | 199 +++++++++++++++++++++++++++++----- mcp/streamable_test.go | 6 +- mcp/transport.go | 14 ++- 5 files changed, 268 insertions(+), 66 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 5549ee1c..f4ac86f6 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -361,19 +361,26 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async if err := c.write(ctx, call); err != nil { // Sending failed. We will never get a response, so deliver a fake one if it // wasn't already retired by the connection breaking. - c.updateInFlight(func(s *inFlightState) { - if s.outgoingCalls[ac.id] == ac { - delete(s.outgoingCalls, ac.id) - ac.retire(&Response{ID: id, Error: err}) - } else { - // ac was already retired by the readIncoming goroutine: - // perhaps our write raced with the Read side of the connection breaking. - } - }) + c.Retire(ac, err) } return ac } +// Retire stops tracking the call, and reports err as its terminal error. +// +// Retire is safe to call multiple times: if the call is already no longer +// tracked, Retire is a no op. +func (c *Connection) Retire(ac *AsyncCall, err error) { + c.updateInFlight(func(s *inFlightState) { + if s.outgoingCalls[ac.id] == ac { + delete(s.outgoingCalls, ac.id) + ac.retire(&Response{ID: ac.id, Error: err}) + } else { + // ac was already retired elsewhere. + } + }) +} + // Async, signals that the current jsonrpc2 request may be handled // asynchronously to subsequent requests, when ctx is the request context. // @@ -437,6 +444,9 @@ func (ac *AsyncCall) IsReady() bool { } // retire processes the response to the call. +// +// It is an error to call retire more than once: retire is guarded by the +// connection's outgoingCalls map. func (ac *AsyncCall) retire(response *Response) { select { case <-ac.ready: @@ -450,6 +460,9 @@ func (ac *AsyncCall) retire(response *Response) { // Await waits for (and decodes) the results of a Call. // The response will be unmarshaled from JSON into the result. +// +// If the call is cancelled due to context cancellation, the result is +// ctx.Err(). func (ac *AsyncCall) Await(ctx context.Context, result any) error { select { case <-ctx.Done(): diff --git a/mcp/streamable.go b/mcp/streamable.go index 21c73848..cb8d4b44 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -25,6 +25,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -1336,12 +1337,17 @@ const ( // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. // It must be 1.0 or greater if MaxRetries is greater than 0. reconnectGrowFactor = 1.5 - // reconnectInitialDelay is the base delay for the first reconnect attempt. - reconnectInitialDelay = 1 * time.Second // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely. reconnectMaxDelay = 30 * time.Second ) +var ( + // reconnectInitialDelay is the base delay for the first reconnect attempt. + // + // Mutable for testing. + reconnectInitialDelay = 1 * time.Second +) + // Connect implements the [Transport] interface. // // The resulting [Connection] writes messages via POST requests to the @@ -1364,7 +1370,10 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // Create a new cancellable context that will manage the connection's lifecycle. // This is crucial for cleanly shutting down the background SSE listener by // cancelling its blocking network operations, which prevents hangs on exit. - connCtx, cancel := context.WithCancel(ctx) + // + // This context should be detached, to decouple the standalone SSE from the + // call to Connect. + connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) conn := &streamableClientConn{ url: t.Endpoint, client: client, @@ -1383,8 +1392,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er type streamableClientConn struct { url string client *http.Client - ctx context.Context - cancel context.CancelFunc + ctx context.Context // connection context, detached from Connect + cancel context.CancelFunc // cancels ctx incoming chan jsonrpc.Message maxRetries int strict bool // from [StreamableClientTransport.strict] @@ -1447,9 +1456,13 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { } func (c *streamableClientConn) connectStandaloneSSE() { - resp, err := c.connectSSE("", 0) + resp, err := c.connectSSE(c.ctx, "", 0, true) if err != nil { - c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err)) + // If the client didn't cancel the request, and failure breaks the logical + // session. + if c.ctx.Err() == nil { + c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err)) + } return } @@ -1481,7 +1494,7 @@ func (c *streamableClientConn) connectStandaloneSSE() { c.fail(err) return } - go c.handleSSE(summary, resp, true, nil) + go c.handleSSE(c.ctx, summary, resp, true, nil) } // fail handles an asynchronous error while reading. @@ -1616,7 +1629,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e forCall = jsonReq } // TODO: should we cancel this logical SSE request if/when jsonReq is canceled? - go c.handleSSE(requestSummary, resp, false, forCall) + go c.handleSSE(ctx, requestSummary, resp, false, forCall) default: resp.Body.Close() @@ -1668,7 +1681,7 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp // // If forCall is set, it is the call that initiated the stream, and the // stream is complete when we receive its response. -func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) { +func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) { for { // Connection was successful. Continue the loop with the new response. // TODO: we should set a reasonable limit on the number of times we'll try @@ -1676,7 +1689,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo // // Eventually, if we don't get the response, we should stop trying and // fail the request. - lastEventID, reconnectDelay, clientClosed := c.processStream(requestSummary, resp, forCall) + lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall) // If the connection was closed by the client, we're done. if clientClosed { @@ -1689,12 +1702,17 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, err := c.connectSSE(lastEventID, reconnectDelay) + newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false) if err != nil { - // All reconnection attempts failed: fail the connection. - c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + // If the client didn't cancel this request, any failure to execute it + // breaks the logical MCP session. + if ctx.Err() == nil { + // All reconnection attempts failed: fail the connection. + c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + } return } + resp = newResp if err := c.checkResponse(requestSummary, resp); err != nil { c.fail(err) @@ -1731,11 +1749,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { +func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { - // TODO: we should differentiate EOF from other errors here. + if ctx.Err() != nil { + return "", 0, true // don't reconnect: client cancelled + } break } @@ -1768,6 +1788,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R return "", 0, true } } + case <-c.done: // The connection was closed by the client; exit gracefully. return "", 0, true @@ -1777,6 +1798,9 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R // // If the lastEventID is "", the stream is not retryable and we should // report a synthetic error for the call. + // + // Note that this is different from the cancellation case above, since the + // caller is still waiting for a response that will never come. if lastEventID == "" && forCall != nil { errmsg := &jsonrpc2.Response{ ID: forCall.ID, @@ -1800,12 +1824,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R // // reconnectDelay is the delay set by the server using the SSE retry field, or // 0. -func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay time.Duration) (*http.Response, error) { +// +// If initial is set, this is the initial attempt. +// +// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()). +func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) { var finalErr error - // If lastEventID is set, we've already connected successfully once, so - // consider that to be the first attempt. attempt := 0 - if lastEventID != "" { + if !initial { + // We've already connected successfully once, so delay subsequent + // reconnections. Otherwise, if the server returns 200 but terminates the + // connection, we'll reconnect as fast as we can, ad infinitum. + // + // TODO: we should consider also setting a limit on total attempts for one + // logical request. attempt = 1 } delay := calculateReconnectDelay(attempt) @@ -1816,16 +1848,14 @@ func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay tim select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") - case <-c.ctx.Done(): + + case <-ctx.Done(): // If the connection context is canceled, the request below will not // succeed anyway. - // - // TODO(#662): we should not be using the connection context for - // reconnection: we should instead be using the call context (from - // Write). - return nil, fmt.Errorf("connection context closed") + return nil, ctx.Err() + case <-time.After(delay): - req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) if err != nil { return nil, err } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 6d3d83b1..174e2d46 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -13,6 +13,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -23,16 +24,18 @@ type streamableRequestKey struct { httpMethod string // http method sessionID string // session ID header jsonrpcMethod string // jsonrpc method, or "" for non-requests + lastEventID string // Last-Event-ID header } type header map[string]string type streamableResponse struct { - header header // response headers - status int // or http.StatusOK - body string // or "" - optional bool // if set, request need not be sent - wantProtocolVersion string // if "", unchecked + header header // response headers + status int // or http.StatusOK + body string // or "" + optional bool // if set, request need not be sent + wantProtocolVersion string // if "", unchecked + done chan struct{} // if set, receive from this channel before terminating the request } type fakeResponses map[streamableRequestKey]*streamableResponse @@ -60,8 +63,9 @@ func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { key := streamableRequestKey{ - httpMethod: req.Method, - sessionID: req.Header.Get(sessionIDHeader), + httpMethod: req.Method, + sessionID: req.Header.Get(sessionIDHeader), + lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader } if req.Method == http.MethodPost { body, err := io.ReadAll(req.Body) @@ -102,11 +106,17 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques status = http.StatusOK } w.WriteHeader(status) + w.(http.Flusher).Flush() // flush response headers if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) } w.Write([]byte(resp.body)) + w.(http.Flusher).Flush() // flush response + + if resp.done != nil { + <-resp.done + } } var ( @@ -140,24 +150,24 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { fake := &fakeStreamableServer{ t: t, responses: fakeResponses{ - {"POST", "", methodInitialize}: { + {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", sessionIDHeader: "123", }, body: jsonBody(t, initResp), }, - {"POST", "123", notificationInitialized}: { + {"POST", "123", notificationInitialized, ""}: { status: http.StatusAccepted, wantProtocolVersion: latestProtocolVersion, }, - {"GET", "123", ""}: { + {"GET", "123", "", ""}: { header: header{ "Content-Type": "text/event-stream", }, wantProtocolVersion: latestProtocolVersion, }, - {"DELETE", "123", ""}: {}, + {"DELETE", "123", "", ""}: {}, }, } @@ -191,21 +201,21 @@ func TestStreamableClientRedundantDelete(t *testing.T) { fake := &fakeStreamableServer{ t: t, responses: fakeResponses{ - {"POST", "", methodInitialize}: { + {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", sessionIDHeader: "123", }, body: jsonBody(t, initResp), }, - {"POST", "123", notificationInitialized}: { + {"POST", "123", notificationInitialized, ""}: { status: http.StatusAccepted, wantProtocolVersion: latestProtocolVersion, }, - {"GET", "123", ""}: { + {"GET", "123", "", ""}: { status: http.StatusMethodNotAllowed, }, - {"POST", "123", methodListTools}: { + {"POST", "123", methodListTools, ""}: { status: http.StatusNotFound, }, }, @@ -251,25 +261,25 @@ func TestStreamableClientGETHandling(t *testing.T) { fake := &fakeStreamableServer{ t: t, responses: fakeResponses{ - {"POST", "", methodInitialize}: { + {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json; charset=utf-8", // should ignore the charset sessionIDHeader: "123", }, body: jsonBody(t, initResp), }, - {"POST", "123", notificationInitialized}: { + {"POST", "123", notificationInitialized, ""}: { status: http.StatusAccepted, wantProtocolVersion: latestProtocolVersion, }, - {"GET", "123", ""}: { + {"GET", "123", "", ""}: { header: header{ "Content-Type": "text/event-stream", }, status: test.status, wantProtocolVersion: latestProtocolVersion, }, - {"DELETE", "123", ""}: {optional: true}, + {"DELETE", "123", "", ""}: {optional: true}, }, } httpServer := httptest.NewServer(fake) @@ -320,25 +330,25 @@ func TestStreamableClientStrictness(t *testing.T) { fake := &fakeStreamableServer{ t: t, responses: fakeResponses{ - {"POST", "", methodInitialize}: { + {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "application/json", sessionIDHeader: "123", }, body: jsonBody(t, initResp), }, - {"POST", "123", notificationInitialized}: { + {"POST", "123", notificationInitialized, ""}: { status: test.initializedStatus, wantProtocolVersion: latestProtocolVersion, }, - {"GET", "123", ""}: { + {"GET", "123", "", ""}: { header: header{ "Content-Type": "text/event-stream", }, status: test.getStatus, wantProtocolVersion: latestProtocolVersion, }, - {"POST", "123", methodListTools}: { + {"POST", "123", methodListTools, ""}: { header: header{ "Content-Type": "application/json", sessionIDHeader: "123", @@ -346,7 +356,7 @@ func TestStreamableClientStrictness(t *testing.T) { body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), optional: true, }, - {"DELETE", "123", ""}: {optional: true}, + {"DELETE", "123", "", ""}: {optional: true}, }, } httpServer := httptest.NewServer(fake) @@ -379,14 +389,14 @@ func TestStreamableClientUnresumableRequest(t *testing.T) { fake := &fakeStreamableServer{ t: t, responses: fakeResponses{ - {"POST", "", methodInitialize}: { + {"POST", "", methodInitialize, ""}: { header: header{ "Content-Type": "text/event-stream", sessionIDHeader: "123", }, body: "", }, - {"DELETE", "123", ""}: {optional: true}, + {"DELETE", "123", "", ""}: {optional: true}, }, } httpServer := httptest.NewServer(fake) @@ -406,3 +416,140 @@ func TestStreamableClientUnresumableRequest(t *testing.T) { t.Errorf("Connect: got error %v, want containing %q", err, msg) } } + +func TestStreamableClientResumption_Cancelled(t *testing.T) { + // This test verifies that the resumed requests are closed when their context + // is cancelled (issue #662). + + // This test (unfortunately) relies on timing, so may have false positives. + // + // Set the reconnect initial delay to some small(ish) value so that the test + // doesn't take too long. But this value must be large enough that we mostly + // avoid races in the tests below, where one test cases is intended to be in + // between the initial attempt and first reconnection. + // + // For easier tuning (and debugging), factor out the tick size. + const tick = 10 * time.Millisecond + defer func(delay time.Duration) { + reconnectInitialDelay = delay + }(reconnectInitialDelay) + reconnectInitialDelay = 2 * tick + + // The setup: terminate a request stream and make the resumed request hang + // indefinitely. CallTool should still exit when its context is canceled. + // + // This should work whether we're handling the initial request, waiting to + // retry, or handling the retry. + // + // Furthermore, closing the client connection should not hang, because there + // should be no ongoing requests. + + tests := []struct { + label string + cancelAfter time.Duration + }{ + {"in process", 1 * tick}, // cancel while the request is being handled + // initial request terminates at 2 ticks (see below) + {"awaiting retry", 3 * tick}, // cancel in-between first and second attempt + // retry starts at 4 ticks (=2+2) + {"in retry", 5 * tick}, // cancel while second attempt is hanging + } + + for _, test := range tests { + t.Run(test.label, func(t *testing.T) { + ctx := context.Background() + + // done will be closed when the test exits: used to simulate requests that + // hang indefinitely. + initialRequestDone := make(chan struct{}) // closed below + allDone := make(chan struct{}) + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", "", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: http.StatusMethodNotAllowed, // don't allow the standalone stream + }, + {"POST", "123", methodCallTool, ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: http.StatusOK, + body: `id: 1 +data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level": "error", "data": "bad" } } + +`, + done: initialRequestDone, + }, + {"POST", "123", methodListTools, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, resp(3, &ListToolsResult{Tools: []*Tool{}}, nil)), + }, + {"GET", "123", "", "1"}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: http.StatusOK, + done: allDone, // hang indefinitely + }, + {"POST", "123", notificationCancelled, ""}: {status: http.StatusAccepted}, + {"DELETE", "123", "", ""}: {optional: true}, + }, + } + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + defer close(allDone) // must be deferred *after* httpServer.Close, to avoid deadlock + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + cs, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() // ensure the session is closed, though we're also closing below + + // start the timer on the initial request + go func() { + <-time.After(2 * tick) + close(initialRequestDone) + }() + + // start the timer on the call cancellation + timeoutCtx, cancel := context.WithTimeout(ctx, test.cancelAfter) + defer cancel() + + go func() { + <-timeoutCtx.Done() + }() + + if _, err := cs.CallTool(timeoutCtx, &CallToolParams{ + Name: "tool", + }); err == nil { + t.Errorf("CallTool succeeded unexpectedly") + } + + // ...but cancellation should not break the session. + // Check that an arbitrary request succeeds. + if _, err := cs.ListTools(ctx, nil); err != nil { + t.Errorf("ListTools failed after cancellation") + } + }) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 727a2837..d2234b8e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1788,8 +1788,8 @@ func TestStreamableClientContextPropagation(t *testing.T) { cancel() select { case <-streamableConn.ctx.Done(): - case <-time.After(100 * time.Millisecond): - t.Error("Connection context was not cancelled when parent was cancelled") + t.Errorf("cancelling the connection context after successful connection broke the connection") + default: } } @@ -1945,7 +1945,7 @@ data: keepalive } // Process the stream - go conn.processStream("test", resp, testReq) + go conn.processStream(ctx, "test", resp, testReq) // Collect messages with timeout var messages []jsonrpc.Message diff --git a/mcp/transport.go b/mcp/transport.go index ec57b38c..25f1d5d0 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -69,7 +69,7 @@ type Connection interface { type clientConnection interface { Connection - // SessionUpdated is called whenever the client session state changes. + // sessionUpdated is called whenever the client session state changes. sessionUpdated(clientSessionState) } @@ -216,6 +216,18 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Reason: ctx.Err().Error(), RequestID: call.ID().Raw(), }) + // By default, the jsonrpc2 library waits for graceful shutdown when the + // connection is closed, meaning it expects all outgoing and incoming + // requests to complete. However, for MCP this expectation is unrealistic, + // and can lead to hanging shutdown. For example, if a streamable client is + // killed, the server will not be able to detect this event, except via + // keepalive pings (if they are configured), and so outgoing calls may hang + // indefinitely. + // + // Therefore, we choose to eagerly retire calls, removing them from the + // outgoingCalls map, when the caller context is cancelled: if the caller + // will never receive the response, there's no need to track it. + conn.Retire(call, ctx.Err()) return errors.Join(ctx.Err(), err) case err != nil: return fmt.Errorf("calling %q: %w", method, err) From fac0f5aebe5e6d0d50e12d13cd3428f04a799b88 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 3 Dec 2025 16:43:31 +0000 Subject: [PATCH 2/2] address review comments --- mcp/streamable.go | 20 ++++++++++++++++---- mcp/streamable_client_test.go | 2 ++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index cb8d4b44..9e210abb 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1371,8 +1371,18 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // This is crucial for cleanly shutting down the background SSE listener by // cancelling its blocking network operations, which prevents hangs on exit. // - // This context should be detached, to decouple the standalone SSE from the - // call to Connect. + // This context should be detached from the incoming context: the standalone + // SSE request should not break when the connection context is done. + // + // For example, consider that the user may want to wait at most 5s to connect + // to the server, and therefore uses a context with a 5s timeout when calling + // client.Connect. Let's suppose that Connect returns after 1s, and the user + // starts using the resulting session. If we didn't detach here, the session + // would break after 4s, when the background SSE stream is terminated. + // + // Instead, creating a cancellable context detached from the incoming context + // allows us to preserve context values (which may be necessary for auth + // middleware), yet only cancel the standalone stream when the connection is closed. connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) conn := &streamableClientConn{ url: t.Endpoint, @@ -1684,8 +1694,10 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) { for { // Connection was successful. Continue the loop with the new response. - // TODO: we should set a reasonable limit on the number of times we'll try - // getting a response for a given request. + // + // TODO(#679): we should set a reasonable limit on the number of times + // we'll try getting a response for a given request, or enforce that we + // actually make progress. // // Eventually, if we don't get the response, we should stop trying and // fail the request. diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 174e2d46..dcdda322 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -429,6 +429,8 @@ func TestStreamableClientResumption_Cancelled(t *testing.T) { // between the initial attempt and first reconnection. // // For easier tuning (and debugging), factor out the tick size. + // + // TODO(#680): experiment with instead using synctest. const tick = 10 * time.Millisecond defer func(delay time.Duration) { reconnectInitialDelay = delay