From a75c670e0a5392d45bf36a7fd62ee4ec82685901 Mon Sep 17 00:00:00 2001
From: Rob Findley
Date: Tue, 9 Sep 2025 14:36:19 +0000
Subject: [PATCH] mcp: improve error messages from Wait for streamable clients

Previously, the error message received from ClientSession.Wait would
only report the closeErr, which would often be nil even if the client
transport was broken. Wait should return the reason the session
terminated, if abnormal. I'm not sure of the exact semantics of this,
but surely returning nil is less useful than returning a meaningful
non-nil error. We can refine our handling of errors once we have more
feedback.

Also add a test for client termination on HTTP server shutdown,
described in #265. This should work as long as (1) the session is
stateful (with a hanging GET), or (2) the session is stateless but the
client has a keepalive ping.

Also: don't send DELETE if the session was terminated with 404; +test.

Fixes #265
---
 internal/jsonrpc2/conn.go     | 27 +++++++++++--
 mcp/streamable.go             | 70 +++++++++++++++++++++++++--------
 mcp/streamable_client_test.go | 52 +++++++++++++++++++++++-
 mcp/streamable_test.go        | 74 +++++++++++++++++++++++++++++++++++
 4 files changed, 202 insertions(+), 21 deletions(-)

diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go
index 49902b00..5549ee1c 100644
--- a/internal/jsonrpc2/conn.go
+++ b/internal/jsonrpc2/conn.go
@@ -483,10 +483,32 @@ func (c *Connection) Cancel(id ID) {
 
 // Wait blocks until the connection is fully closed, but does not close it.
 func (c *Connection) Wait() error {
+	return c.wait(true)
+}
+
+// wait waits for the connection to close, and reports the most likely cause
+// of its termination, if abnormal.
+//
+// The fromWait argument allows this logic to be shared with Close, where we
+// only want to expose the closeErr.
+//
+// (Previously, Wait also only returned the closeErr, which was misleading if
+// the connection was broken for another reason).
+func (c *Connection) wait(fromWait bool) error {
 	var err error
 	<-c.done
 	c.updateInFlight(func(s *inFlightState) {
-		err = s.closeErr
+		if fromWait {
+			if !errors.Is(s.readErr, io.EOF) {
+				err = s.readErr
+			}
+			if err == nil && !errors.Is(s.writeErr, io.EOF) {
+				err = s.writeErr
+			}
+		}
+		if err == nil {
+			err = s.closeErr
+		}
 	})
 	return err
 }
@@ -502,8 +524,7 @@ func (c *Connection) Close() error {
 	// Stop handling new requests, and interrupt the reader (by closing the
 	// connection) as soon as the active requests finish.
 	c.updateInFlight(func(s *inFlightState) { s.connClosing = true })
-
-	return c.Wait()
+	return c.wait(false)
 }
 
 // readIncoming collects inbound messages from the reader and delivers them, either responding
diff --git a/mcp/streamable.go b/mcp/streamable.go
index 1eef9a74..3e3cc6f8 100644
--- a/mcp/streamable.go
+++ b/mcp/streamable.go
@@ -1119,6 +1119,17 @@ type streamableClientConn struct {
 	sessionID string
 }
 
+// errSessionMissing indicates that the session is known to not be present on
+// the server (see [streamableClientConn.fail]).
+//
+// TODO(rfindley): should we expose this error value (and its corresponding
+// API) to the user?
+//
+// The spec says that if the server returns 404, clients should reestablish
+// a session. For now, we delegate that to the user, but do they need a way to
+// differentiate a 'NotFound' error from other errors? 
+var errSessionMissing = errors.New("session not found") + var _ clientConnection = (*streamableClientConn)(nil) func (c *streamableClientConn) sessionUpdated(state clientSessionState) { @@ -1146,6 +1157,10 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // // If err is non-nil, it is terminal, and subsequent (or pending) Reads will // fail. +// +// If err wraps errSessionMissing, the failure indicates that the session is no +// longer present on the server, and no final DELETE will be performed when +// closing the connection. func (c *streamableClientConn) fail(err error) { if err != nil { c.failOnce.Do(func() { @@ -1193,9 +1208,19 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } + var requestSummary string + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + data, err := jsonrpc.EncodeMessage(msg) if err != nil { - return err + return fmt.Errorf("%s: %v", requestSummary, err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) @@ -1208,9 +1233,21 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e resp, err := c.client.Do(req) if err != nil { + return fmt.Errorf("%s: %v", requestSummary, err) + } + + // Section 2.5.3: "The server MAY terminate the session at any time, after + // which it MUST respond to requests containing that session ID with HTTP + // 404 Not Found." + if resp.StatusCode == http.StatusNotFound { + // Fail the session immediately, rather than relying on jsonrpc2 to fail + // (and close) it, because we want the call to Close to know that this + // session is missing (and therefore not send the DELETE). + err := fmt.Errorf("%s: failed to send: %w", requestSummary, errSessionMissing) + c.fail(err) + resp.Body.Close() return err } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() return fmt.Errorf("broken session: %v", resp.Status) @@ -1233,16 +1270,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } - var requestSummary string - switch msg := msg.(type) { - case *jsonrpc.Request: - requestSummary = fmt.Sprintf("sending %q", msg.Method) - case *jsonrpc.Response: - requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) - default: - panic("unreachable") - } - switch ct := resp.Header.Get("Content-Type"); ct { case "application/json": go c.handleJSON(requestSummary, resp) @@ -1333,6 +1360,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt resp.Body.Close() return } + // (see equivalent handling in [streamableClientConn.Write]). + if resp.StatusCode == http.StatusNotFound { + c.fail(fmt.Errorf("%s: failed to reconnect: %w", requestSummary, errSessionMissing)) + return + } if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) @@ -1423,13 +1455,17 @@ func (c *streamableClientConn) Close() error { c.cancel() close(c.done) - req, err := http.NewRequest(http.MethodDelete, c.url, nil) - if err != nil { - c.closeErr = err + if errors.Is(c.failure(), errSessionMissing) { + // If the session is missing, no need to delete it. 
 		} else {
-			c.setMCPHeaders(req)
-			if _, err := c.client.Do(req); err != nil {
+			req, err := http.NewRequest(http.MethodDelete, c.url, nil)
+			if err != nil {
 				c.closeErr = err
+			} else {
+				c.setMCPHeaders(req)
+				if _, err := c.client.Do(req); err != nil {
+					c.closeErr = err
+				}
 			}
 		}
 	})
diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go
index fe87b21c..001d3a64 100644
--- a/mcp/streamable_client_test.go
+++ b/mcp/streamable_client_test.go
@@ -29,7 +29,7 @@ type streamableRequestKey struct {
 type header map[string]string
 
 type streamableResponse struct {
-	header   header
+	header   header // response headers
 	status   int    // or http.StatusOK
 	body     string // or ""
 	optional bool   // if set, request need not be sent
@@ -187,6 +187,56 @@ func TestStreamableClientTransportLifecycle(t *testing.T) {
 	}
 }
 
+func TestStreamableClientRedundantDelete(t *testing.T) {
+	ctx := context.Background()
+
+	// This test verifies client behavior when the server reports that the
+	// session is missing (HTTP 404):
+	//  - the failed request and the subsequent Wait must not hang
+	//  - no redundant session DELETE is sent for the missing session
+	fake := &fakeStreamableServer{
+		t: t,
+		responses: fakeResponses{
+			{"POST", "", methodInitialize}: {
+				header: header{
+					"Content-Type":  "application/json",
+					sessionIDHeader: "123",
+				},
+				body: jsonBody(t, initResp),
+			},
+			{"POST", "123", notificationInitialized}: {
+				status:              http.StatusAccepted,
+				wantProtocolVersion: latestProtocolVersion,
+			},
+			{"GET", "123", ""}: {
+				status:   http.StatusMethodNotAllowed,
+				optional: true,
+			},
+			{"POST", "123", methodListTools}: {
+				status: http.StatusNotFound,
+			},
+		},
+	}
+
+	httpServer := httptest.NewServer(fake)
+	defer httpServer.Close()
+
+	transport := &StreamableClientTransport{Endpoint: httpServer.URL}
+	client := NewClient(testImpl, nil)
+	session, err := client.Connect(ctx, transport, nil)
+	if err != nil {
+		t.Fatalf("client.Connect() failed: %v", err)
+	}
+	_, err = session.ListTools(ctx, nil)
+	if err == nil {
+		t.Errorf("Listing tools: got nil error, want non-nil")
+	}
+	_ = session.Wait() // must not hang
+	if missing := fake.missingRequests(); len(missing) > 0 {
+		t.Errorf("did not receive expected requests: %v", missing)
+	}
+}
+
 func TestStreamableClientGETHandling(t *testing.T) {
 	ctx := context.Background()
 
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go
index a85fbec0..f0da3dc9 100644
--- a/mcp/streamable_test.go
+++ b/mcp/streamable_test.go
@@ -190,6 +190,80 @@ func TestStreamableTransports(t *testing.T) {
 	}
 }
 
+func TestStreamableServerShutdown(t *testing.T) {
+	ctx := context.Background()
+
+	// This test checks that closing the streamable HTTP server actually results
+	// in client session termination, provided one of the following holds:
+	//  1. The server is stateful, and therefore the hanging GET fails the connection.
+	//  2. The server is stateless, and the client uses a KeepAlive.
+	tests := []struct {
+		name                 string
+		stateless, keepalive bool
+	}{
+		{"stateful", false, false},
+		{"stateless with keepalive", true, true},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			server := NewServer(testImpl, nil)
+			// Add a tool, just so we can check things are working.
+			AddTool(server, &Tool{Name: "greet"}, sayHi)
+
+			handler := NewStreamableHTTPHandler(
+				func(req *http.Request) *Server { return server },
+				&StreamableHTTPOptions{Stateless: test.stateless})
+
+			// When we shut down the server, we need to explicitly close ongoing
+			// connections. 
Otherwise, the hanging GET may never terminate. + httpServer := httptest.NewUnstartedServer(handler) + httpServer.Config.RegisterOnShutdown(func() { + for session := range server.Sessions() { + session.Close() + } + }) + httpServer.Start() + defer httpServer.Close() + + // Connect and run a tool. + var opts ClientOptions + if test.keepalive { + opts.KeepAlive = 50 * time.Millisecond + } + client := NewClient(testImpl, &opts) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: httpServer.URL, + MaxRetries: -1, // avoid slow tests during exponential retries + }, nil) + if err != nil { + t.Fatal(err) + } + + params := &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"name": "foo"}, + } + // Verify that we can call a tool. + if _, err := clientSession.CallTool(ctx, params); err != nil { + t.Fatalf("CallTool() failed: %v", err) + } + + // Shut down the server. Sessions should terminate. + go func() { + if err := httpServer.Config.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Errorf("closing http server: %v", err) + } + }() + + // Wait may return an error (after all, the connection failed), but it + // should not hang. + t.Log("Client waiting") + _ = clientSession.Wait() + }) + } +} + // TestClientReplay verifies that the client can recover from a mid-stream // network failure and receive replayed messages (if replay is configured). It // uses a proxy that is killed and restarted to simulate a recoverable network