Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions internal/jsonrpc2/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,19 +361,26 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async
if err := c.write(ctx, call); err != nil {
// Sending failed. We will never get a response, so deliver a fake one if it
// wasn't already retired by the connection breaking.
c.updateInFlight(func(s *inFlightState) {
if s.outgoingCalls[ac.id] == ac {
delete(s.outgoingCalls, ac.id)
ac.retire(&Response{ID: id, Error: err})
} else {
// ac was already retired by the readIncoming goroutine:
// perhaps our write raced with the Read side of the connection breaking.
}
})
c.Retire(ac, err)
}
return ac
}

// Retire stops tracking the call, and reports err as its terminal error.
//
// Retire is safe to call multiple times: if the call is already no longer
// tracked, Retire is a no op.
func (c *Connection) Retire(ac *AsyncCall, err error) {
c.updateInFlight(func(s *inFlightState) {
if s.outgoingCalls[ac.id] == ac {
delete(s.outgoingCalls, ac.id)
ac.retire(&Response{ID: ac.id, Error: err})
} else {
// ac was already retired elsewhere.
}
})
}

// Async, signals that the current jsonrpc2 request may be handled
// asynchronously to subsequent requests, when ctx is the request context.
//
Expand Down Expand Up @@ -437,6 +444,9 @@ func (ac *AsyncCall) IsReady() bool {
}

// retire processes the response to the call.
//
// It is an error to call retire more than once: retire is guarded by the
// connection's outgoingCalls map.
func (ac *AsyncCall) retire(response *Response) {
select {
case <-ac.ready:
Expand All @@ -450,6 +460,9 @@ func (ac *AsyncCall) retire(response *Response) {

// Await waits for (and decodes) the results of a Call.
// The response will be unmarshaled from JSON into the result.
//
// If the call is cancelled due to context cancellation, the result is
// ctx.Err().
func (ac *AsyncCall) Await(ctx context.Context, result any) error {
select {
case <-ctx.Done():
Expand Down
100 changes: 71 additions & 29 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/internal/xcontext"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

Expand Down Expand Up @@ -1336,12 +1337,17 @@ const (
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
// It must be 1.0 or greater if MaxRetries is greater than 0.
reconnectGrowFactor = 1.5
// reconnectInitialDelay is the base delay for the first reconnect attempt.
reconnectInitialDelay = 1 * time.Second
// reconnectMaxDelay caps the backoff delay, preventing it from growing indefinitely.
reconnectMaxDelay = 30 * time.Second
)

var (
// reconnectInitialDelay is the base delay for the first reconnect attempt.
//
// Mutable for testing.
reconnectInitialDelay = 1 * time.Second
)

// Connect implements the [Transport] interface.
//
// The resulting [Connection] writes messages via POST requests to the
Expand All @@ -1364,7 +1370,20 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// Create a new cancellable context that will manage the connection's lifecycle.
// This is crucial for cleanly shutting down the background SSE listener by
// cancelling its blocking network operations, which prevents hangs on exit.
connCtx, cancel := context.WithCancel(ctx)
//
// This context should be detached from the incoming context: the standalone
// SSE request should not break when the connection context is done.
//
// For example, consider that the user may want to wait at most 5s to connect
// to the server, and therefore uses a context with a 5s timeout when calling
// client.Connect. Let's suppose that Connect returns after 1s, and the user
// starts using the resulting session. If we didn't detach here, the session
// would break after 4s, when the background SSE stream is terminated.
//
// Instead, creating a cancellable context detached from the incoming context
// allows us to preserve context values (which may be necessary for auth
// middleware), yet only cancel the standalone stream when the connection is closed.
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
Expand All @@ -1383,8 +1402,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
type streamableClientConn struct {
url string
client *http.Client
ctx context.Context
cancel context.CancelFunc
ctx context.Context // connection context, detached from Connect
cancel context.CancelFunc // cancels ctx
incoming chan jsonrpc.Message
maxRetries int
strict bool // from [StreamableClientTransport.strict]
Expand Down Expand Up @@ -1447,9 +1466,13 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
}

func (c *streamableClientConn) connectStandaloneSSE() {
resp, err := c.connectSSE("", 0)
resp, err := c.connectSSE(c.ctx, "", 0, true)
if err != nil {
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
// If the client didn't cancel the request, and failure breaks the logical
// session.
if c.ctx.Err() == nil {
c.fail(fmt.Errorf("standalone SSE request failed (session ID: %v): %v", c.sessionID, err))
}
return
}

Expand Down Expand Up @@ -1481,7 +1504,7 @@ func (c *streamableClientConn) connectStandaloneSSE() {
c.fail(err)
return
}
go c.handleSSE(summary, resp, true, nil)
go c.handleSSE(c.ctx, summary, resp, true, nil)
}

// fail handles an asynchronous error while reading.
Expand Down Expand Up @@ -1616,7 +1639,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
forCall = jsonReq
}
// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
go c.handleSSE(requestSummary, resp, false, forCall)
go c.handleSSE(ctx, requestSummary, resp, false, forCall)

default:
resp.Body.Close()
Expand Down Expand Up @@ -1668,15 +1691,17 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
//
// If forCall is set, it is the call that initiated the stream, and the
// stream is complete when we receive its response.
func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary string, resp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
for {
// Connection was successful. Continue the loop with the new response.
// TODO: we should set a reasonable limit on the number of times we'll try
// getting a response for a given request.
//
// TODO(#679): we should set a reasonable limit on the number of times
// we'll try getting a response for a given request, or enforce that we
// actually make progress.
//
// Eventually, if we don't get the response, we should stop trying and
// fail the request.
lastEventID, reconnectDelay, clientClosed := c.processStream(requestSummary, resp, forCall)
lastEventID, reconnectDelay, clientClosed := c.processStream(ctx, requestSummary, resp, forCall)

// If the connection was closed by the client, we're done.
if clientClosed {
Expand All @@ -1689,12 +1714,17 @@ func (c *streamableClientConn) handleSSE(requestSummary string, resp *http.Respo
}

// The stream was interrupted or ended by the server. Attempt to reconnect.
newResp, err := c.connectSSE(lastEventID, reconnectDelay)
newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false)
if err != nil {
// All reconnection attempts failed: fail the connection.
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
// If the client didn't cancel this request, any failure to execute it
// breaks the logical MCP session.
if ctx.Err() == nil {
// All reconnection attempts failed: fail the connection.
c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err))
}
return
}

resp = newResp
if err := c.checkResponse(requestSummary, resp); err != nil {
c.fail(err)
Expand Down Expand Up @@ -1731,11 +1761,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R
// incoming channel. It returns the ID of the last processed event and a flag
// indicating if the connection was closed by the client. If resp is nil, it
// returns "", false.
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) {
func (c *streamableClientConn) processStream(ctx context.Context, requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, reconnectDelay time.Duration, clientClosed bool) {
defer resp.Body.Close()
for evt, err := range scanEvents(resp.Body) {
if err != nil {
// TODO: we should differentiate EOF from other errors here.
if ctx.Err() != nil {
return "", 0, true // don't reconnect: client cancelled
}
break
}

Expand Down Expand Up @@ -1768,6 +1800,7 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
return "", 0, true
}
}

case <-c.done:
// The connection was closed by the client; exit gracefully.
return "", 0, true
Expand All @@ -1777,6 +1810,9 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
//
// If the lastEventID is "", the stream is not retryable and we should
// report a synthetic error for the call.
//
// Note that this is different from the cancellation case above, since the
// caller is still waiting for a response that will never come.
if lastEventID == "" && forCall != nil {
errmsg := &jsonrpc2.Response{
ID: forCall.ID,
Expand All @@ -1800,12 +1836,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
//
// reconnectDelay is the delay set by the server using the SSE retry field, or
// 0.
func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay time.Duration) (*http.Response, error) {
//
// If initial is set, this is the initial attempt.
//
// If connectSSE exits due to context cancellation, the result is (nil, ctx.Err()).
func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID string, reconnectDelay time.Duration, initial bool) (*http.Response, error) {
var finalErr error
// If lastEventID is set, we've already connected successfully once, so
// consider that to be the first attempt.
attempt := 0
if lastEventID != "" {
if !initial {
// We've already connected successfully once, so delay subsequent
// reconnections. Otherwise, if the server returns 200 but terminates the
// connection, we'll reconnect as fast as we can, ad infinitum.
//
// TODO: we should consider also setting a limit on total attempts for one
// logical request.
attempt = 1
}
delay := calculateReconnectDelay(attempt)
Expand All @@ -1816,16 +1860,14 @@ func (c *streamableClientConn) connectSSE(lastEventID string, reconnectDelay tim
select {
case <-c.done:
return nil, fmt.Errorf("connection closed by client during reconnect")
case <-c.ctx.Done():

case <-ctx.Done():
// If the connection context is canceled, the request below will not
// succeed anyway.
//
// TODO(#662): we should not be using the connection context for
// reconnection: we should instead be using the call context (from
// Write).
return nil, fmt.Errorf("connection context closed")
return nil, ctx.Err()

case <-time.After(delay):
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.url, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
if err != nil {
return nil, err
}
Expand Down
Loading