From 8c85ee70ec90a82b03ef323086a336b3cd967c3a Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 4 Sep 2025 15:34:39 +0000 Subject: [PATCH 1/2] mcp: refactor the streamable client test to be more flexible Use a fake streamable server to facilitate testing client behavior. For this commit, just update the existing test (moved to a new file for isolation). Subsequent CLs will add more tests. Improve one client error message that occurred while debuging tests. For #393 --- mcp/streamable.go | 9 +- mcp/streamable_client_test.go | 185 ++++++++++++++++++++++++++++++++++ mcp/streamable_test.go | 71 ------------- 3 files changed, 193 insertions(+), 72 deletions(-) create mode 100644 mcp/streamable_client_test.go diff --git a/mcp/streamable.go b/mcp/streamable.go index e7777eb0..7a407538 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1234,7 +1234,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e default: resp.Body.Close() - return fmt.Errorf("unsupported content type %q", ct) + switch msg := msg.(type) { + case *jsonrpc.Request: + return fmt.Errorf("unsupported content type %q when sending %q (status: %d)", ct, msg.Method, resp.StatusCode) + case *jsonrpc.Response: + return fmt.Errorf("unsupported content type %q when sending jsonrpc response #%d (status: %d)", ct, msg.ID, resp.StatusCode) + default: + panic("unreachable") + } } return nil } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go new file mode 100644 index 00000000..ee89df5a --- /dev/null +++ b/mcp/streamable_client_test.go @@ -0,0 +1,185 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +type streamableRequestKey struct { + httpMethod string // http method + sessionID string // session ID header + jsonrpcMethod string // jsonrpc method, or "" for non-requests +} + +type header map[string]string + +type streamableResponse struct { + header header + status int // or http.StatusOK + body string // or "" + optional bool // if set, request need not be sent + wantProtocolVersion string // if "", unchecked + callback func() // if set, called after the request is handled +} + +type fakeResponses map[streamableRequestKey]*streamableResponse + +type fakeStreamableServer struct { + t *testing.T + responses fakeResponses + + callMu sync.Mutex + calls map[streamableRequestKey]int +} + +func (s *fakeStreamableServer) missingRequests() []streamableRequestKey { + s.callMu.Lock() + defer s.callMu.Unlock() + + var unused []streamableRequestKey + for k, resp := range s.responses { + if s.calls[k] == 0 && !resp.optional { + unused = append(unused, k) + } + } + return unused +} + +func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + key := streamableRequestKey{ + httpMethod: req.Method, + sessionID: req.Header.Get(sessionIDHeader), + } + if req.Method == http.MethodPost { + body, err := io.ReadAll(req.Body) + if err != nil { + s.t.Errorf("failed to read body: %v", err) + http.Error(w, "failed to read body", http.StatusInternalServerError) + return + } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + s.t.Errorf("invalid body: %v", err) + http.Error(w, "invalid body", http.StatusInternalServerError) + return + } + if r, ok := msg.(*jsonrpc.Request); ok { + key.jsonrpcMethod = r.Method + } + } + + s.callMu.Lock() + if s.calls == nil { + s.calls = make(map[streamableRequestKey]int) + } + s.calls[key]++ + s.callMu.Unlock() + + resp, ok := s.responses[key] + if !ok { + s.t.Errorf("missing response for %v", key) + http.Error(w, "no response", http.StatusInternalServerError) + return + } + if resp.callback != nil { + defer resp.callback() + } + for k, v := range resp.header { + w.Header().Set(k, v) + } + status := resp.status + if status == 0 { + status = http.StatusOK + } + w.WriteHeader(status) + + if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" { + s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion) + } + w.Write([]byte(resp.body)) +} + +var ( + initResult = &InitializeResult{ + Capabilities: &ServerCapabilities{ + Completions: &CompletionCapabilities{}, + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: latestProtocolVersion, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + } + initResp = resp(1, initResult, nil) +) + +func jsonBody(t *testing.T, msg jsonrpc2.Message) string { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + t.Fatalf("encoding failed: %v", err) + } + return string(data) +} + +func TestStreamableClientTransportLifecycle(t *testing.T) { + ctx := context.Background() + + // The lifecycle test verifies various behavior of the streamable client + // initialization: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + optional: true, + wantProtocolVersion: latestProtocolVersion, + }, + {"DELETE", "123", ""}: {}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests: %v", missing) + } + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0d171d83..2963a04d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1035,77 +1035,6 @@ func mustMarshal(v any) json.RawMessage { return data } -func TestStreamableClientTransport(t *testing.T) { - // This test verifies various behavior of the streamable client transport: - // - check that it can handle application/json responses - // - check that it sends the negotiated protocol version - // - // TODO(rfindley): make this test more comprehensive, similar to - // [TestStreamableServerTransport]. - ctx := context.Background() - resp := func(id int64, result any, err error) *jsonrpc.Response { - return &jsonrpc.Response{ - ID: jsonrpc2.Int64ID(id), - Result: mustMarshal(result), - Error: err, - } - } - initResult := &InitializeResult{ - Capabilities: &ServerCapabilities{ - Completions: &CompletionCapabilities{}, - Logging: &LoggingCapabilities{}, - Tools: &ToolCapabilities{ListChanged: true}, - }, - ProtocolVersion: latestProtocolVersion, - ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, - } - initResp := resp(1, initResult, nil) - - var reqN atomic.Int32 // request count - serverHandler := func(w http.ResponseWriter, r *http.Request) { - rN := reqN.Add(1) - - // TODO(rfindley): if the status code is NoContent or Accepted, we should - // probably be tolerant of when the content type is not application/json. - w.Header().Set("Content-Type", "application/json") - if rN == 1 { - data, err := jsonrpc2.EncodeMessage(initResp) - if err != nil { - t.Errorf("encoding failed: %v", err) - } - w.Header().Set("Mcp-Session-Id", "123") - w.Write(data) - } else { - if v := r.Header.Get(protocolVersionHeader); v != latestProtocolVersion { - t.Errorf("bad protocol version header: got %q, want %q", v, latestProtocolVersion) - } - } - } - - httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) - defer httpServer.Close() - - transport := &StreamableClientTransport{Endpoint: httpServer.URL} - client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport, nil) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) - } - if err := session.Close(); err != nil { - t.Errorf("closing session: %v", err) - } - - if got, want := reqN.Load(), int32(3); got < want { - // Expect at least 3 requests: initialize, initialized, and DELETE. - // We may or may not observe the GET, depending on timing. - t.Errorf("unexpected number of requests: got %d, want at least %d", got, want) - } - - if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { - t.Errorf("mismatch (-want, +got):\n%s", diff) - } -} - func TestEventID(t *testing.T) { tests := []struct { sid StreamID From c2bc0de1fe6eb4ca9f398e78f6ba76fcbc83776b Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 4 Sep 2025 17:16:14 +0000 Subject: [PATCH 2/2] mcp: systematically improve streamable client errors The streamable client connection can break for a variety of reasons, asynchronously to the client's request. Decorate these failures with additional context to clarify why they occurred. Add a test for the failure message of #393. Fixes #393 --- mcp/streamable.go | 43 +++++++++-------- mcp/streamable_client_test.go | 90 +++++++++++++++++++++++++++++++++++ mcp/transport.go | 2 +- 3 files changed, 114 insertions(+), 21 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 7a407538..25efe31a 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1130,7 +1130,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // ยง 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE(nil, true, nil) + go c.handleSSE("hanging GET", nil, true, nil) } // fail handles an asynchronous error while reading. @@ -1224,24 +1224,27 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } + var requestSummary string + switch msg := msg.(type) { + case *jsonrpc.Request: + requestSummary = fmt.Sprintf("sending %q", msg.Method) + case *jsonrpc.Response: + requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) + default: + panic("unreachable") + } + switch ct := resp.Header.Get("Content-Type"); ct { case "application/json": - go c.handleJSON(resp) + go c.handleJSON(requestSummary, resp) case "text/event-stream": jsonReq, _ := msg.(*jsonrpc.Request) - go c.handleSSE(resp, false, jsonReq) + go c.handleSSE(requestSummary, resp, false, jsonReq) default: resp.Body.Close() - switch msg := msg.(type) { - case *jsonrpc.Request: - return fmt.Errorf("unsupported content type %q when sending %q (status: %d)", ct, msg.Method, resp.StatusCode) - case *jsonrpc.Response: - return fmt.Errorf("unsupported content type %q when sending jsonrpc response #%d (status: %d)", ct, msg.ID, resp.StatusCode) - default: - panic("unreachable") - } + return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct) } return nil } @@ -1265,16 +1268,16 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) { } } -func (c *streamableClientConn) handleJSON(resp *http.Response) { +func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - c.fail(err) + c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err)) return } msg, err := jsonrpc.DecodeMessage(body) if err != nil { - c.fail(fmt.Errorf("failed to decode response: %v", err)) + c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err)) return } select { @@ -1289,12 +1292,12 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) { // // If forReq is set, it is the request that initiated the stream, and the // stream is complete when we receive its response. -func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { +func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) { resp := initialResp var lastEventID string for { if resp != nil { - eventID, clientClosed := c.processStream(resp, forReq) + eventID, clientClosed := c.processStream(requestSummary, resp, forReq) lastEventID = eventID // If the connection was closed by the client, we're done. @@ -1312,7 +1315,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent newResp, err := c.reconnect(lastEventID) if err != nil { // All reconnection attempts failed: fail the connection. - c.fail(err) + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err)) return } resp = newResp @@ -1323,7 +1326,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent } if resp.StatusCode < 200 || resp.StatusCode >= 300 { resp.Body.Close() - c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode))) + c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode))) return } // Reconnection was successful. Continue the loop with the new response. @@ -1334,7 +1337,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { +func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { @@ -1347,7 +1350,7 @@ func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrp msg, err := jsonrpc.DecodeMessage(evt.Data) if err != nil { - c.fail(fmt.Errorf("failed to decode event: %v", err)) + c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err)) return "", true } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index ee89df5a..fe87b21c 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -6,11 +6,14 @@ package mcp import ( "context" + "fmt" "io" "net/http" "net/http/httptest" + "strings" "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -183,3 +186,90 @@ func TestStreamableClientTransportLifecycle(t *testing.T) { t.Errorf("mismatch (-want, +got):\n%s", diff) } } + +func TestStreamableClientGETHandling(t *testing.T) { + ctx := context.Background() + + tests := []struct { + status int + wantErrorContaining string + }{ + {http.StatusOK, ""}, + {http.StatusMethodNotAllowed, ""}, + {http.StatusBadRequest, "hanging GET"}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("status=%d", test.status), func(t *testing.T) { + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + status: test.status, + wantProtocolVersion: latestProtocolVersion, + }, + {"POST", "123", methodListTools}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)), + optional: true, + }, + {"DELETE", "123", ""}: {optional: true}, + }, + } + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + // wait for all required requests to be handled, with exponential + // backoff. + start := time.Now() + delay := 1 * time.Millisecond + for range 10 { + if len(fake.missingRequests()) == 0 { + break + } + time.Sleep(delay) + delay *= 2 + } + if missing := fake.missingRequests(); len(missing) > 0 { + t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing) + } + + _, err = session.ListTools(ctx, nil) + if (err != nil) != (test.wantErrorContaining != "") { + t.Errorf("After initialization, got error %v, want %v", err, test.wantErrorContaining) + } else if err != nil { + if !strings.Contains(err.Error(), test.wantErrorContaining) { + t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining) + } + } + + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + }) + } +} diff --git a/mcp/transport.go b/mcp/transport.go index fac640a6..5c7ca130 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -194,7 +194,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params err := call.Await(ctx, result) switch { case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): - return fmt.Errorf("calling %q: %w", method, ErrConnectionClosed) + return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) case ctx.Err() != nil: // Notify the peer of cancellation. err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{