diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go
index 5549ee1c..f4ac86f6 100644
--- a/internal/jsonrpc2/conn.go
+++ b/internal/jsonrpc2/conn.go
@@ -361,19 +361,26 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async
 	if err := c.write(ctx, call); err != nil {
 		// Sending failed. We will never get a response, so deliver a fake one if it
 		// wasn't already retired by the connection breaking.
-		c.updateInFlight(func(s *inFlightState) {
-			if s.outgoingCalls[ac.id] == ac {
-				delete(s.outgoingCalls, ac.id)
-				ac.retire(&Response{ID: id, Error: err})
-			} else {
-				// ac was already retired by the readIncoming goroutine:
-				// perhaps our write raced with the Read side of the connection breaking.
-			}
-		})
+		c.Retire(ac, err)
 	}
 	return ac
 }
 
+// Retire stops tracking the call and reports err as its terminal error.
+//
+// Retire is safe to call multiple times: if the call is already no longer
+// tracked, Retire is a no-op.
+func (c *Connection) Retire(ac *AsyncCall, err error) {
+	c.updateInFlight(func(s *inFlightState) {
+		if s.outgoingCalls[ac.id] == ac {
+			delete(s.outgoingCalls, ac.id)
+			ac.retire(&Response{ID: ac.id, Error: err})
+		} else {
+			// ac was already retired elsewhere.
+		}
+	})
+}
+
 // Async, signals that the current jsonrpc2 request may be handled
 // asynchronously to subsequent requests, when ctx is the request context.
 //
@@ -437,6 +444,9 @@ func (ac *AsyncCall) IsReady() bool {
 }
 
 // retire processes the response to the call.
+//
+// It is an error to call retire more than once: retire is guarded by the
+// connection's outgoingCalls map.
 func (ac *AsyncCall) retire(response *Response) {
 	select {
 	case <-ac.ready:
@@ -450,6 +460,9 @@ func (ac *AsyncCall) retire(response *Response) {
 
 // Await waits for (and decodes) the results of a Call.
 // The response will be unmarshaled from JSON into the result.
+//
+// If ctx is cancelled before the call completes, the result is
+// ctx.Err().
 func (ac *AsyncCall) Await(ctx context.Context, result any) error {
 	select {
 	case <-ctx.Done():
diff --git a/mcp/streamable.go b/mcp/streamable.go
index 21c73848..9e210abb 100644
--- a/mcp/streamable.go
+++ b/mcp/streamable.go
@@ -25,6 +25,7 @@ import (
 
 	"github.com/modelcontextprotocol/go-sdk/auth"
 	"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
+	"github.com/modelcontextprotocol/go-sdk/internal/xcontext"
 	"github.com/modelcontextprotocol/go-sdk/jsonrpc"
 )
 
@@ -1336,12 +1337,17 @@ const (
 	// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
 	// It must be 1.0 or greater if MaxRetries is greater than 0.
 	reconnectGrowFactor = 1.5
-	// reconnectInitialDelay is the base delay for the first reconnect attempt.
-	reconnectInitialDelay = 1 * time.Second
 	// reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely.
 	reconnectMaxDelay = 30 * time.Second
 )
 
+var (
+	// reconnectInitialDelay is the base delay for the first reconnect attempt.
+	//
+	// Mutable for testing.
+	reconnectInitialDelay = 1 * time.Second
+)
+
 // Connect implements the [Transport] interface.
 //
 // The resulting [Connection] writes messages via POST requests to the
@@ -1364,7 +1370,20 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
 	// Create a new cancellable context that will manage the connection's lifecycle.
 	// This is crucial for cleanly shutting down the background SSE listener by
 	// cancelling its blocking network operations, which prevents hangs on exit.
-	connCtx, cancel := context.WithCancel(ctx)
+	//
+	// This context should be detached from the incoming context: the standalone
+	// SSE request should not break when the context passed to Connect is done.
+	//
+	// For example, consider that the user may want to wait at most 5s to connect
+	// to the server, and therefore uses a context with a 5s timeout when calling
+	// client.Connect. Let's suppose that Connect returns after 1s, and the user
+	// starts using the resulting session. If we didn't detach here, the session
+	// would break after 4s, when the background SSE stream is terminated.
+	//
+	// Instead, creating a cancellable context detached from the incoming context
+	// allows us to preserve context values (which may be necessary for auth
+	// middleware), yet only cancel the standalone stream when the connection is closed.
+	connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
 	conn := &streamableClientConn{
 		url:    t.Endpoint,
 		client: client,
@@ -1383,8 +1402,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
 type streamableClientConn struct {
 	url    string
 	client *http.Client
-	ctx    context.Context
-	cancel context.CancelFunc
+	ctx    context.Context    // connection context, detached from Connect
+	cancel context.CancelFunc // cancels ctx
 	incoming   chan jsonrpc.Message
 	maxRetries int
 	strict     bool // from [StreamableClientTransport.strict]
@@ -1447,9 +1466,13 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
 }
 
 func (c *streamableClientConn) connectStandaloneSSE() {
-	resp, err := c.connectSSE("", 0)
+	resp, err := c.connectSSE(c.ctx, "", 0, true)
 	if err != nil {
-		c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
+		// If the client didn't cancel the request, any failure breaks the
+		// logical session.
+		if c.ctx.Err() == nil {
+			c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
+		}
 		return
 	}
 
@@ -1481,7 +1504,7 @@ func (c *streamableClientConn) connectStandaloneSSE() {
 		c.fail(err)
 		return
 	}
-	go c.handleSSE(summary, resp, true, nil)
+	go c.handleSSE(c.ctx, summary, resp, true, nil)
 }
 
 // fail handles an asynchronous error while reading.
@@ -1616,7 +1639,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
 			forCall = jsonReq
 		}
 		// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
-		go c.handleSSE(requestSummary, resp, false, forCall)
+		go c.handleSSE(ctx, requestSummary, resp, false, forCall)
 
 	default:
 		resp.Body.Close()
@@ -1668,15 +1691,17 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
 //
 // If forCall is set, it is the call that initiated the stream, and the
 // stream is complete when we receive its response.
-func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
+func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
 	for {
 		// Connection was successful. Continue the loop with the new response.
-		// TODO: we should set a reasonable limit on the number of times we'll try
-		// getting a response for a given request.
+		//
+		// TODO(#679): we should set a reasonable limit on the number of times
+		// we'll try getting a response for a given request, or enforce that we
+		// actually make progress.
 		//
 		// Eventually, if we don't get the response, we should stop trying and
 		// fail the request.
- lastEventID, reconnectDelay, clientClosed := c.processStream(requestSummary, resp, forCall) + lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall) // If the connection was closed by the client, we're done. if clientClosed { @@ -1689,12 +1714,17 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, err := c.connectSSE(lastEventID, reconnectDelay) + newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false) if err != nil { - // All reconnection attempts failed: fail the connection. - c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + // If the client didn't cancel this request, any failure to execute it + // breaks the logical MCP session. + if ctx.Err() == nil { + // All reconnection attempts failed: fail the connection. + c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + } return } + resp = newResp if err := c.checkResponse(requestSummary, resp); err != nil { c.fail(err) @@ -1731,11 +1761,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { +func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { - // TODO: we should differentiate EOF from other errors here. + if ctx.Err() != nil { + return "", 0, true // don't reconnect: client cancelled + } break } @@ -1768,6 +1800,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R return "", 0, true } } + case <-c.done: // The connection was closed by the client; exit gracefully. return "", 0, true @@ -1777,6 +1810,9 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R // // If the lastEventID is "", the stream is not retryable and we should // report a synthetic error for the call. + // + // Note that this is different from the cancellation case above, since the + // caller is still waiting for a response that will never come. if lastEventID == "" && forCall != nil { errmsg := &jsonrpc2.Response{ ID: forCall.ID, @@ -1800,12 +1836,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R // // reconnectDelay is the delay set by the server using the SSE retry field, or // 0. -func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay time.Duration) (*http.Response, error) { +// +// If initial is set, this is the initial attempt. +// +// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()). +func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) { var finalErr error - // If lastEventID is set, we've already connected successfully once, so - // consider that to be the first attempt. 
 	attempt := 0
-	if lastEventID != "" {
+	if !initial {
+		// We've already connected successfully once, so delay subsequent
+		// reconnections. Otherwise, if the server returns 200 but terminates the
+		// connection, we'll reconnect as fast as we can, ad infinitum.
+		//
+		// TODO: we should consider also setting a limit on total attempts for one
+		// logical request.
 		attempt = 1
 	}
 	delay := calculateReconnectDelay(attempt)
@@ -1816,16 +1860,14 @@
 		select {
 		case <-c.done:
 			return nil, fmt.Errorf("connection closed by client during reconnect")
-		case <-c.ctx.Done():
-			// If the connection context is canceled, the request below will not
-			// succeed anyway.
-			//
-			// TODO(#662): we should not be using the connection context for
-			// reconnection: we should instead be using the call context (from
-			// Write).
-			return nil, fmt.Errorf("connection context closed")
+
+		case <-ctx.Done():
+			// If ctx is canceled, the request below will not
+			// succeed anyway.
+			return nil, ctx.Err()
+
 		case <-time.After(delay):
-			req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
+			req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
 			if err != nil {
 				return nil, err
 			}
diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go
index 6d3d83b1..dcdda322 100644
--- a/mcp/streamable_client_test.go
+++ b/mcp/streamable_client_test.go
@@ -13,6 +13,7 @@ import (
 	"strings"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
@@ -23,16 +24,18 @@ type streamableRequestKey struct {
 	httpMethod    string // http method
 	sessionID     string // session ID header
 	jsonrpcMethod string // jsonrpc method, or "" for non-requests
+	lastEventID   string // Last-Event-ID header
 }
 
 type header map[string]string
 
 type streamableResponse struct {
-	header              header // response headers
-	status              int    // or http.StatusOK
-	body                string // or ""
-	optional            bool   // if set, request need not be sent
-	wantProtocolVersion string // if "", unchecked
+	header              header        // response headers
+	status              int           // or http.StatusOK
+	body                string        // or ""
+	optional            bool          // if set, request need not be sent
+	wantProtocolVersion string        // if "", unchecked
+	done                chan struct{} // if set, receive from this channel before terminating the request
 }
 
 type fakeResponses map[streamableRequestKey]*streamableResponse
@@ -60,8 +63,9 @@ func (s *fakeStreamableServer) missingRequests() []streamableRequestKey {
 
 func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	key := streamableRequestKey{
-		httpMethod: req.Method,
-		sessionID:  req.Header.Get(sessionIDHeader),
+		httpMethod:  req.Method,
+		sessionID:   req.Header.Get(sessionIDHeader),
+		lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader
 	}
 	if req.Method == http.MethodPost {
 		body, err := io.ReadAll(req.Body)
@@ -102,11 +106,17 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
 		status = http.StatusOK
 	}
 	w.WriteHeader(status)
+	w.(http.Flusher).Flush() // flush response headers
 
 	if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" {
 		s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion)
 	}
 	w.Write([]byte(resp.body))
+	w.(http.Flusher).Flush() // flush response body
+
+	if resp.done != nil {
+		<-resp.done
+	}
 }
 
 var (
@@ -140,24 +150,24 @@ func TestStreamableClientTransportLifecycle(t *testing.T) {
 	fake := &fakeStreamableServer{
 		t: t,
 		responses: fakeResponses{
-			{"POST", "", methodInitialize}: {
+			{"POST", "", methodInitialize, ""}: {
 				header: header{
 					"Content-Type":  "application/json",
 					sessionIDHeader: "123",
 				},
 				body: jsonBody(t, initResp),
 			},
-			{"POST", "123", notificationInitialized}: {
+			{"POST", "123", notificationInitialized, ""}: {
 				status:              http.StatusAccepted,
 				wantProtocolVersion: latestProtocolVersion,
 			},
-			{"GET", "123", ""}: {
+			{"GET", "123", "", ""}: {
 				header: header{
 					"Content-Type": "text/event-stream",
 				},
 				wantProtocolVersion: latestProtocolVersion,
 			},
-			{"DELETE", "123", ""}: {},
+			{"DELETE", "123", "", ""}: {},
 		},
 	}
@@ -191,21 +201,21 @@ func TestStreamableClientRedundantDelete(t *testing.T) {
 	fake := &fakeStreamableServer{
 		t: t,
 		responses: fakeResponses{
-			{"POST", "", methodInitialize}: {
+			{"POST", "", methodInitialize, ""}: {
 				header: header{
 					"Content-Type":  "application/json",
 					sessionIDHeader: "123",
 				},
 				body: jsonBody(t, initResp),
 			},
-			{"POST", "123", notificationInitialized}: {
+			{"POST", "123", notificationInitialized, ""}: {
 				status:              http.StatusAccepted,
 				wantProtocolVersion: latestProtocolVersion,
 			},
-			{"GET", "123", ""}: {
+			{"GET", "123", "", ""}: {
 				status: http.StatusMethodNotAllowed,
 			},
-			{"POST", "123", methodListTools}: {
+			{"POST", "123", methodListTools, ""}: {
 				status: http.StatusNotFound,
 			},
 		},
 	}
@@ -251,25 +261,25 @@ func TestStreamableClientGETHandling(t *testing.T) {
 	fake := &fakeStreamableServer{
 		t: t,
 		responses: fakeResponses{
-			{"POST", "", methodInitialize}: {
+			{"POST", "", methodInitialize, ""}: {
 				header: header{
 					"Content-Type":  "application/json; charset=utf-8", // should ignore the charset
 					sessionIDHeader: "123",
 				},
 				body: jsonBody(t, initResp),
 			},
-			{"POST", "123", notificationInitialized}: {
+			{"POST", "123", notificationInitialized, ""}: {
 				status:              http.StatusAccepted,
 				wantProtocolVersion: latestProtocolVersion,
 			},
-			{"GET", "123", ""}: {
+			{"GET", "123", "", ""}: {
 				header: header{
 					"Content-Type": "text/event-stream",
 				},
 				status:              test.status,
 				wantProtocolVersion: latestProtocolVersion,
 			},
-			{"DELETE", "123", ""}: {optional: true},
+			{"DELETE", "123", "", ""}: {optional: true},
 		},
 	}
 	httpServer := httptest.NewServer(fake)
@@ -320,25 +330,25 @@ func TestStreamableClientStrictness(t *testing.T) {
 	fake := &fakeStreamableServer{
 		t: t,
 		responses: fakeResponses{
-			{"POST", "", methodInitialize}: {
+			{"POST", "", methodInitialize, ""}: {
 				header: header{
 					"Content-Type":  "application/json",
 					sessionIDHeader: "123",
 				},
 				body: jsonBody(t, initResp),
 			},
-			{"POST", "123", notificationInitialized}: {
+			{"POST", "123", notificationInitialized, ""}: {
 				status:              test.initializedStatus,
 				wantProtocolVersion: latestProtocolVersion,
 			},
-			{"GET", "123", ""}: {
+			{"GET", "123", "", ""}: {
 				header: header{
 					"Content-Type": "text/event-stream",
 				},
 				status:              test.getStatus,
 				wantProtocolVersion: latestProtocolVersion,
 			},
-			{"POST", "123", methodListTools}: {
+			{"POST", "123", methodListTools, ""}: {
 				header: header{
 					"Content-Type":  "application/json",
 					sessionIDHeader: "123",
 				},
@@ -346,7 +356,7 @@ func TestStreamableClientStrictness(t *testing.T) {
 				body:     jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)),
 				optional: true,
 			},
-			{"DELETE", "123", ""}: {optional: true},
+			{"DELETE", "123", "", ""}: {optional: true},
 		},
 	}
 	httpServer := httptest.NewServer(fake)
@@ -379,14 +389,14 @@ func TestStreamableClientUnresumableRequest(t *testing.T) {
 	fake := &fakeStreamableServer{
 		t: t,
 		responses: fakeResponses{
-			{"POST", "", methodInitialize}: {
+			{"POST", "", methodInitialize, ""}: {
 				header: header{
 					"Content-Type":  "text/event-stream",
 					sessionIDHeader: "123",
 				},
 				body: "",
 			},
-			{"DELETE", "123", ""}: {optional: true},
+			{"DELETE", "123", "", ""}: {optional: true},
 		},
 	}
 	httpServer := httptest.NewServer(fake)
@@ -406,3 +416,138 @@
 		t.Errorf("Connect: got error %v, want containing %q", err, msg)
 	}
 }
+
+func TestStreamableClientResumption_Cancelled(t *testing.T) {
+	// This test verifies that the resumed requests are closed when their context
+	// is cancelled (issue #662).
+
+	// This test (unfortunately) relies on timing, so may have false positives.
+	//
+	// Set the reconnect initial delay to some small(ish) value so that the test
+	// doesn't take too long. But this value must be large enough that we mostly
+	// avoid races in the tests below, where one test case is intended to fall in
+	// between the initial attempt and the first reconnection.
+	//
+	// For easier tuning (and debugging), factor out the tick size.
+	//
+	// TODO(#680): experiment with instead using synctest.
+	const tick = 10 * time.Millisecond
+	defer func(delay time.Duration) {
+		reconnectInitialDelay = delay
+	}(reconnectInitialDelay)
+	reconnectInitialDelay = 2 * tick
+
+	// The setup: terminate a request stream and make the resumed request hang
+	// indefinitely. CallTool should still exit when its context is canceled.
+	//
+	// This should work whether we're handling the initial request, waiting to
+	// retry, or handling the retry.
+	//
+	// Furthermore, closing the client connection should not hang, because there
+	// should be no ongoing requests.
+
+	tests := []struct {
+		label       string
+		cancelAfter time.Duration
+	}{
+		{"in process", 1 * tick}, // cancel while the request is being handled
+		// initial request terminates at 2 ticks (see below)
+		{"awaiting retry", 3 * tick}, // cancel in-between first and second attempt
+		// retry starts at 4 ticks (=2+2)
+		{"in retry", 5 * tick}, // cancel while second attempt is hanging
+	}
+
+	for _, test := range tests {
+		t.Run(test.label, func(t *testing.T) {
+			ctx := context.Background()
+
+			// allDone is closed when the test exits: used to simulate requests
+			// that hang indefinitely.
+			initialRequestDone := make(chan struct{}) // closed below
+			allDone := make(chan struct{})
+
+			fake := &fakeStreamableServer{
+				t: t,
+				responses: fakeResponses{
+					{"POST", "", methodInitialize, ""}: {
+						header: header{
+							"Content-Type":  "application/json",
+							sessionIDHeader: "123",
+						},
+						body: jsonBody(t, initResp),
+					},
+					{"POST", "123", notificationInitialized, ""}: {
+						status:              http.StatusAccepted,
+						wantProtocolVersion: latestProtocolVersion,
+					},
+					{"GET", "123", "", ""}: {
+						header: header{
+							"Content-Type": "text/event-stream",
+						},
+						status: http.StatusMethodNotAllowed, // don't allow the standalone stream
+					},
+					{"POST", "123", methodCallTool, ""}: {
+						header: header{
+							"Content-Type": "text/event-stream",
+						},
+						status: http.StatusOK,
+						body: `id: 1
+data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level": "error", "data": "bad" } }
+
+`,
+						done: initialRequestDone,
+					},
+					{"POST", "123", methodListTools, ""}: {
+						header: header{
+							"Content-Type":  "application/json",
+							sessionIDHeader: "123",
+						},
+						body: jsonBody(t, resp(3, &ListToolsResult{Tools: []*Tool{}}, nil)),
+					},
+					{"GET", "123", "", "1"}: {
+						header: header{
+							"Content-Type": "text/event-stream",
+						},
+						status: http.StatusOK,
+						done:   allDone, // hang indefinitely
+					},
+					{"POST", "123", notificationCancelled, ""}: {status: http.StatusAccepted},
+					{"DELETE", "123", "", ""}:                  {optional: true},
+				},
+			}
+			httpServer := httptest.NewServer(fake)
+			defer httpServer.Close()
+			defer close(allDone) // deferred *after* httpServer.Close so it runs first, avoiding deadlock
+
+			transport := &StreamableClientTransport{Endpoint: httpServer.URL}
+			client := NewClient(testImpl, nil)
+			cs, err := client.Connect(ctx, transport, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer cs.Close() // ensure the session is closed, though we're also closing below
+
+			// start the timer on the initial request
+			go func() {
+				<-time.After(2 * tick)
+				close(initialRequestDone)
+			}()
+
+			// start the timer on the call cancellation
+			timeoutCtx, cancel := context.WithTimeout(ctx, test.cancelAfter)
+			defer cancel()
+
+			if _, err := cs.CallTool(timeoutCtx, &CallToolParams{
+				Name: "tool",
+			}); err == nil {
+				t.Errorf("CallTool succeeded unexpectedly")
+			}
+
+			// ...but cancellation should not break the session.
+			// Check that an arbitrary request succeeds.
+			if _, err := cs.ListTools(ctx, nil); err != nil {
+				t.Errorf("ListTools failed after cancellation: %v", err)
+			}
+		})
+	}
+}
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go
index 727a2837..d2234b8e 100644
--- a/mcp/streamable_test.go
+++ b/mcp/streamable_test.go
@@ -1788,8 +1788,8 @@ func TestStreamableClientContextPropagation(t *testing.T) {
 	cancel()
 	select {
 	case <-streamableConn.ctx.Done():
-	case <-time.After(100 * time.Millisecond):
-		t.Error("Connection context was not cancelled when parent was cancelled")
+		t.Errorf("connection context was cancelled along with the parent context; Connect should detach it")
+	default:
 	}
 }
 
@@ -1945,7 +1945,7 @@ data: keepalive
 	}
 
 	// Process the stream
-	go conn.processStream("test", resp, testReq)
+	go conn.processStream(ctx, "test", resp, testReq)
 
 	// Collect messages with timeout
 	var messages []jsonrpc.Message
diff --git a/mcp/transport.go b/mcp/transport.go
index ec57b38c..25f1d5d0 100644
--- a/mcp/transport.go
+++ b/mcp/transport.go
@@ -69,7 +69,7 @@ type Connection interface {
 type clientConnection interface {
 	Connection
 
-	// SessionUpdated is called whenever the client session state changes.
+ // sessionUpdated is called whenever the client session state changes. sessionUpdated(clientSessionState) } @@ -216,6 +216,18 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Reason: ctx.Err().Error(), RequestID: call.ID().Raw(), }) + // By default, the jsonrpc2 library waits for graceful shutdown when the + // connection is closed, meaning it expects all outgoing and incoming + // requests to complete. However, for MCP this expectation is unrealistic, + // and can lead to hanging shutdown. For example, if a streamable client is + // killed, the server will not be able to detect this event, except via + // keepalive pings (if they are configured), and so outgoing calls may hang + // indefinitely. + // + // Therefore, we choose to eagerly retire calls, removing them from the + // outgoingCalls map, when the caller context is cancelled: if the caller + // will never receive the response, there's no need to track it. + conn.Retire(call, ctx.Err()) return errors.Join(ctx.Err(), err) case err != nil: return fmt.Errorf("calling %q: %w", method, err)
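
Note for reviewers: the new internal/xcontext dependency is used above (xcontext.Detach), but its implementation is not part of this diff. As a point of reference, a minimal sketch of a Detach helper consistent with that usage, assuming it follows the familiar pattern from golang.org/x/tools, keeps the parent's values while dropping its cancellation and deadline:

package xcontext

import (
	"context"
	"time"
)

// Detach returns a context that keeps the values of ctx, but is never
// cancelled and has no deadline.
func Detach(ctx context.Context) context.Context { return detachedContext{ctx} }

type detachedContext struct{ parent context.Context }

func (c detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (c detachedContext) Done() <-chan struct{}       { return nil }
func (c detachedContext) Err() error                  { return nil }
func (c detachedContext) Value(key any) any           { return c.parent.Value(key) }

With this shape, auth middleware can still read request-scoped values through connCtx, but cancelling the context passed to Connect never cancels connCtx; only the connection's own cancel func does.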
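
Similarly, calculateReconnectDelay is called in connectSSE but not shown here. Given the documented constants (reconnectInitialDelay, reconnectGrowFactor = 1.5, reconnectMaxDelay = 30s), a plausible sketch is capped exponential backoff; the real implementation may differ, for example by adding jitter or honoring the server-provided SSE retry value:

package main

import (
	"fmt"
	"math"
	"time"
)

const (
	reconnectGrowFactor = 1.5
	reconnectMaxDelay   = 30 * time.Second
)

var reconnectInitialDelay = 1 * time.Second

// calculateReconnectDelay (sketch): the delay grows geometrically with the
// attempt number and is capped at reconnectMaxDelay.
func calculateReconnectDelay(attempt int) time.Duration {
	d := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt)))
	return min(d, reconnectMaxDelay)
}

func main() {
	// Per connectSSE above, attempt 0 is used for the initial connection and
	// attempt 1+ for reconnections of an established stream.
	for attempt := 0; attempt < 5; attempt++ {
		fmt.Printf("attempt %d: %v\n", attempt, calculateReconnectDelay(attempt))
	}
}

Starting established-stream reconnects at attempt = 1 means the client always waits before re-dialing, which matches the comment in connectSSE about servers that return 200 and immediately terminate the stream.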