diff --git a/mcp/streamable.go b/mcp/streamable.go index 7f5ce21b..e3d80bc3 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -329,6 +329,7 @@ func (c *streamableServerConn) SessionID() string { // A stream is a single logical stream of SSE events within a server session. // A stream begins with a client request, or with a client GET that has // no Last-Event-ID header. +// // A stream ends only when its session ends; we cannot determine its end otherwise, // since a client may send a GET with a Last-Event-ID that references the stream // at any time. @@ -529,6 +530,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } c.mu.Unlock() stream.signal.Store(signalChanPtr()) + defer stream.signal.Store(nil) } // Publish incoming messages. @@ -857,27 +859,27 @@ type StreamableReconnectOptions struct { // MaxRetries is the maximum number of times to attempt a reconnect before giving up. // A value of 0 or less means never retry. MaxRetries int - - // growFactor is the multiplicative factor by which the delay increases after each attempt. - // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. - // It must be 1.0 or greater if MaxRetries is greater than 0. - growFactor float64 - - // initialDelay is the base delay for the first reconnect attempt. - initialDelay time.Duration - - // maxDelay caps the backoff delay, preventing it from growing indefinitely. - maxDelay time.Duration } // DefaultReconnectOptions provides sensible defaults for reconnect logic. var DefaultReconnectOptions = &StreamableReconnectOptions{ - MaxRetries: 5, - growFactor: 1.5, - initialDelay: 1 * time.Second, - maxDelay: 30 * time.Second, + MaxRetries: 5, } +// These settings are not (yet) exposed to the user in +// StreamableReconnectOptions. Since they're invisible, keep them const rather +// than requiring the user to start from DefaultReconnectOptions and mutate. 
+const ( + // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt. + // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. + // It must be 1.0 or greater if MaxRetries is greater than 0. + reconnectGrowFactor = 1.5 + // reconnectInitialDelay is the base delay for the first reconnect attempt. + reconnectInitialDelay = 1 * time.Second + // reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely. + reconnectMaxDelay = 30 * time.Second +) + // StreamableClientTransportOptions provides options for the // [NewStreamableClientTransport] constructor. // @@ -928,7 +930,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er conn := &streamableClientConn{ url: t.Endpoint, client: client, - incoming: make(chan []byte, 100), + incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), ReconnectOptions: reconnOpts, ctx: connCtx, @@ -944,7 +946,7 @@ type streamableClientConn struct { client *http.Client ctx context.Context cancel context.CancelFunc - incoming chan []byte + incoming chan jsonrpc.Message // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once @@ -988,7 +990,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // § 2.5: A server using the Streamable HTTP transport MAY assign a session // ID at initialization time, by including it in an Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - go c.handleSSE(nil, true) + go c.handleSSE(nil, true, nil) } // fail handles an asynchronous error while reading. 
@@ -1031,8 +1033,8 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error return nil, c.failure() case <-c.done: return nil, io.EOF - case data := <-c.incoming: - return jsonrpc2.DecodeMessage(data) + case msg := <-c.incoming: + return msg, nil } } @@ -1042,7 +1044,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return err } - data, err := jsonrpc2.EncodeMessage(msg) + data, err := jsonrpc.EncodeMessage(msg) if err != nil { return err } @@ -1088,7 +1090,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e go c.handleJSON(resp) case "text/event-stream": - go c.handleSSE(resp, false) + jsonReq, _ := msg.(*jsonrpc.Request) + go c.handleSSE(resp, false, jsonReq) default: resp.Body.Close() @@ -1116,8 +1119,13 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) { c.fail(err) return } + msg, err := jsonrpc.DecodeMessage(body) + if err != nil { + c.fail(fmt.Errorf("failed to decode response: %v", err)) + return + } select { - case c.incoming <- body: + case c.incoming <- msg: case <-c.done: // The connection was closed by the client; exit gracefully. } @@ -1125,21 +1133,26 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) { // handleSSE manages the lifecycle of an SSE connection. It can be either // persistent (for the main GET listener) or temporary (for a POST response). -func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) { +// +// If forReq is set, it is the request that initiated the stream, and the +// stream is complete when we receive its response. 
+func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc.Request) { resp := initialResp var lastEventID string for { - eventID, clientClosed := c.processStream(resp) - lastEventID = eventID + if resp != nil { + eventID, clientClosed := c.processStream(resp, forReq) + lastEventID = eventID - // If the connection was closed by the client, we're done. - if clientClosed { - return - } - // If the stream has ended, then do not reconnect if the stream is - // temporary (POST initiated SSE). - if lastEventID == "" && !persistent { - return + // If the connection was closed by the client, we're done. + if clientClosed { + return + } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). + if lastEventID == "" && !persistent { + return + } } // The stream was interrupted or ended by the server. Attempt to reconnect. @@ -1159,12 +1172,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent // incoming channel. It returns the ID of the last processed event and a flag // indicating if the connection was closed by the client. If resp is nil, it // returns "", false. -func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { - if resp == nil { - // TODO(rfindley): avoid this special handling. 
- return "", false - } - +func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) { defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { @@ -1175,8 +1183,21 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s lastEventID = evt.ID } + msg, err := jsonrpc.DecodeMessage(evt.Data) + if err != nil { + c.fail(fmt.Errorf("failed to decode event: %v", err)) + return "", true + } + select { - case c.incoming <- evt.Data: + case c.incoming <- msg: + if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil { + // TODO: we should never get a response when forReq is nil (the hanging GET). + // We should detect this case, and eliminate the 'persistent' flag arguments. + if jsonResp.ID == forReq.ID { + return "", true + } + } case <-c.done: // The connection was closed by the client; exit gracefully. return "", true @@ -1192,11 +1213,20 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) { var finalErr error - for attempt := 0; attempt < c.ReconnectOptions.MaxRetries; attempt++ { + // We can reach the 'reconnect' path through the hanging GET, in which case + // lastEventID will be "". + // + // In this case, we need an initial attempt. + attempt := 0 + if lastEventID != "" { + attempt = 1 + } + + for ; attempt <= c.ReconnectOptions.MaxRetries; attempt++ { select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") - case <-time.After(calculateReconnectDelay(c.ReconnectOptions, attempt)): + case <-time.After(calculateReconnectDelay(attempt)): resp, err := c.establishSSE(lastEventID) if err != nil { finalErr = err // Store the error and try again. 
@@ -1267,11 +1297,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response, } // calculateReconnectDelay calculates a delay using exponential backoff with full jitter. -func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration { +func calculateReconnectDelay(attempt int) time.Duration { // Calculate the exponential backoff using the grow factor. - backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))) + backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt))) // Cap the backoffDuration at maxDelay. - backoffDuration = min(backoffDuration, opts.maxDelay) + backoffDuration = min(backoffDuration, reconnectMaxDelay) // Use a full jitter using backoffDuration jitter := rand.N(backoffDuration) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 55aadb6a..11600fbc 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -16,6 +16,7 @@ import ( "net/http/httptest" "net/http/httputil" "net/url" + "sort" "strings" "sync" "sync/atomic" @@ -186,12 +187,30 @@ func TestStreamableTransports(t *testing.T) { } } -// TestClientReplay verifies that the client can recover from a -// mid-stream network failure and receive replayed messages. It uses a proxy -// that is killed and restarted to simulate a recoverable network outage. +// TestClientReplay verifies that the client can recover from a mid-stream +// network failure and receive replayed messages (if replay is configured). It +// uses a proxy that is killed and restarted to simulate a recoverable network +// outage. 
func TestClientReplay(t *testing.T) { + for _, test := range []clientReplayTest{ + {"default", nil, true}, + {"no retries", &StreamableReconnectOptions{}, false}, + } { + t.Run(test.name, func(t *testing.T) { + testClientReplay(t, test) + }) + } +} + +type clientReplayTest struct { + name string + options *StreamableReconnectOptions + wantRecovered bool +} + +func testClientReplay(t *testing.T, test clientReplayTest) { notifications := make(chan string) - // 1. Configure the real MCP server. + // Configure the real MCP server. server := NewServer(testImpl, nil) // Use a channel to synchronize the server's message sending with the test's @@ -200,23 +219,24 @@ func TestClientReplay(t *testing.T) { serverClosed := make(chan struct{}) server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { - go func() { - bgCtx := context.Background() - // Send the first two messages immediately. - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"}) - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) - - // Signal the test that it can now kill the proxy. - close(serverReadyToKillProxy) - <-serverClosed - - // These messages should be queued for replay by the server after - // the client's connection drops. - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"}) - req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) - }() - return &CallToolResult{}, nil + // Send one message to the request context, and another to a background + // context (which will end up on the hanging GET). + + bgCtx := context.Background() + req.Session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg1"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"}) + + // Signal the test that it can now kill the proxy. 
+ close(serverReadyToKillProxy) + <-serverClosed + + // These messages should be queued for replay by the server after + // the client's connection drops. + req.Session.NotifyProgress(ctx, &ProgressNotificationParams{Message: "msg3"}) + req.Session.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"}) + return new(CallToolResult), nil }) + realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) defer realServer.Close() realServerURL, err := url.Parse(realServer.URL) @@ -224,12 +244,12 @@ func TestClientReplay(t *testing.T) { t.Fatalf("Failed to parse real server URL: %v", err) } - // 2. Configure a proxy that sits between the client and the real server. + // Configure a proxy that sits between the client and the real server. proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL) proxy := httptest.NewServer(proxyHandler) proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later. - // 3. Configure the client to connect to the proxy with default options. + // Configure the client to connect to the proxy with the reconnect options under test. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client := NewClient(testImpl, &ClientOptions{ @@ -237,20 +257,24 @@ func TestClientReplay(t *testing.T) { notifications <- req.Params.Message }, }) - clientSession, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: proxy.URL}, nil) + clientSession, err := client.Connect(ctx, &StreamableClientTransport{ + Endpoint: proxy.URL, + ReconnectOptions: test.options, + }, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err) } defer clientSession.Close() - clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) - - // 4. Read and verify messages until the server signals it's ready for the proxy kill. 
- receivedNotifications := readNotifications(t, ctx, notifications, 2) - wantReceived := []string{"msg1", "msg2"} - if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { - t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) - } + var ( + wg sync.WaitGroup + callErr error + ) + wg.Add(1) + go func() { + defer wg.Done() + _, callErr = clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) + }() select { case <-serverReadyToKillProxy: @@ -259,32 +283,50 @@ func TestClientReplay(t *testing.T) { t.Fatalf("Context timed out before server was ready to kill proxy") } - // 5. Simulate a total network failure by closing the proxy. + // We should always get the first two notifications. + msgs := readNotifications(t, ctx, notifications, 2) + sort.Strings(msgs) // notifications may arrive in either order + want := []string{"msg1", "msg2"} + if diff := cmp.Diff(want, msgs); diff != "" { + t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) + } + + // Simulate a total network failure by closing the proxy. t.Log("--- Killing proxy to simulate network failure ---") proxy.CloseClientConnections() proxy.Close() close(serverClosed) - // 6. Simulate network recovery by restarting the proxy on the same address. + // Simulate network recovery by restarting the proxy on the same address. t.Logf("--- Restarting proxy on %s ---", proxyAddr) listener, err := net.Listen("tcp", proxyAddr) if err != nil { t.Fatalf("Failed to listen on proxy address: %v", err) } + restartedProxy := &http.Server{Handler: proxyHandler} go restartedProxy.Serve(listener) defer restartedProxy.Close() - // 7. Continue reading from the same connection object. - // Its internal logic should successfully retry, reconnect to the new proxy, - // and receive the replayed messages. - recoveredNotifications := readNotifications(t, ctx, notifications, 2) + wg.Wait() - // 8. Verify the correct messages were received on the recovered connection. 
- wantRecovered := []string{"msg3", "msg4"} - - if diff := cmp.Diff(wantRecovered, recoveredNotifications); diff != "" { - t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + if test.wantRecovered { + // If we've recovered, we should get all 4 notifications and the tool call + // should have succeeded. + msgs := readNotifications(t, ctx, notifications, 2) + sort.Strings(msgs) + want := []string{"msg3", "msg4"} + if diff := cmp.Diff(want, msgs); diff != "" { + t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff) + } + if callErr != nil { + t.Errorf("CallTool failed unexpectedly: %v", callErr) + } + } else { + // Otherwise, the call should fail. + if callErr == nil { + t.Errorf("CallTool succeeded unexpectedly") + } + } }