diff --git a/mcp/streamable.go b/mcp/streamable.go index e3d80bc3..9ae20c02 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -849,26 +849,15 @@ func (c *streamableServerConn) Close() error { // endpoint serving the streamable HTTP transport defined by the 2025-03-26 // version of the spec. type StreamableClientTransport struct { - Endpoint string - HTTPClient *http.Client - ReconnectOptions *StreamableReconnectOptions -} - -// StreamableReconnectOptions defines parameters for client reconnect attempts. -type StreamableReconnectOptions struct { + Endpoint string + HTTPClient *http.Client // MaxRetries is the maximum number of times to attempt a reconnect before giving up. - // A value of 0 or less means never retry. + // It defaults to 5. To disable retries, use a negative number. MaxRetries int } -// DefaultReconnectOptions provides sensible defaults for reconnect logic. -var DefaultReconnectOptions = &StreamableReconnectOptions{ - MaxRetries: 5, -} - // These settings are not (yet) exposed to the user in -// StreamableReconnectOptions. Since they're invisible, keep them const rather -// than requiring the user to start from DefaultReconnectOptions and mutate. +// StreamableClientTransport. const ( // reconnectGrowFactor is the multiplicative factor by which the delay increases after each attempt. // A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. @@ -887,8 +876,10 @@ const ( type StreamableClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. - HTTPClient *http.Client - ReconnectOptions *StreamableReconnectOptions + HTTPClient *http.Client + // MaxRetries is the maximum number of times to attempt a reconnect before giving up. + // It defaults to 5. To disable retries, use a negative number. + MaxRetries int } // NewStreamableClientTransport returns a new client transport that connects to @@ -901,7 +892,7 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt t := &StreamableClientTransport{Endpoint: url} if opts != nil { t.HTTPClient = opts.HTTPClient - t.ReconnectOptions = opts.ReconnectOptions + t.MaxRetries = opts.MaxRetries } return t } @@ -919,34 +910,36 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er if client == nil { client = http.DefaultClient } - reconnOpts := t.ReconnectOptions - if reconnOpts == nil { - reconnOpts = DefaultReconnectOptions + maxRetries := t.MaxRetries + if maxRetries == 0 { + maxRetries = 5 + } else if maxRetries < 0 { + maxRetries = 0 } // Create a new cancellable context that will manage the connection's lifecycle. // This is crucial for cleanly shutting down the background SSE listener by // cancelling its blocking network operations, which prevents hangs on exit. connCtx, cancel := context.WithCancel(context.Background()) conn := &streamableClientConn{ - url: t.Endpoint, - client: client, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - ReconnectOptions: reconnOpts, - ctx: connCtx, - cancel: cancel, - failed: make(chan struct{}), + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), } return conn, nil } type streamableClientConn struct { - url string - ReconnectOptions *StreamableReconnectOptions - client *http.Client - ctx context.Context - cancel context.CancelFunc - incoming chan jsonrpc.Message + url string + client *http.Client + ctx context.Context + cancel context.CancelFunc + incoming chan jsonrpc.Message + maxRetries int // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once @@ -1222,7 +1215,7 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er attempt = 1 } - for ; attempt <= c.ReconnectOptions.MaxRetries; attempt++ { + for ; attempt <= c.maxRetries; attempt++ { select { case <-c.done: return nil, fmt.Errorf("connection closed by client during reconnect") @@ -1244,9 +1237,9 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er } // If the loop completes, all retries have failed. if finalErr != nil { - return nil, fmt.Errorf("connection failed after %d attempts: %w", c.ReconnectOptions.MaxRetries, finalErr) + return nil, fmt.Errorf("connection failed after %d attempts: %w", c.maxRetries, finalErr) } - return nil, fmt.Errorf("connection failed after %d attempts", c.ReconnectOptions.MaxRetries) + return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries) } // isResumable checks if an HTTP response indicates a valid SSE stream that can be processed. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 11600fbc..9c7f5f9c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -193,8 +193,8 @@ func TestStreamableTransports(t *testing.T) { // outage. func TestClientReplay(t *testing.T) { for _, test := range []clientReplayTest{ - {"default", nil, true}, - {"no retries", &StreamableReconnectOptions{}, false}, + {"default", 0, true}, + {"no retries", -1, false}, } { t.Run(test.name, func(t *testing.T) { testClientReplay(t, test) @@ -204,7 +204,7 @@ func TestClientReplay(t *testing.T) { type clientReplayTest struct { name string - options *StreamableReconnectOptions + maxRetries int wantRecovered bool } @@ -258,8 +258,8 @@ func testClientReplay(t *testing.T, test clientReplayTest) { }, }) clientSession, err := client.Connect(ctx, &StreamableClientTransport{ - Endpoint: proxy.URL, - ReconnectOptions: test.options, + Endpoint: proxy.URL, + MaxRetries: test.maxRetries, }, nil) if err != nil { t.Fatalf("client.Connect() failed: %v", err)