From da43fe549c6ddde2f33434c50db4a7313c62dfe9 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Mon, 18 Aug 2025 18:15:06 +0000 Subject: [PATCH] mcp: implement a concurrency model for calls Implement the concurrency model described in #26: notifications are synchronous, but calls are asynchronous (except for 'initialize'). To achieve this, implement jsonrpc2.Async(ctx) to signal asynchronous handling. This is simpler to use than returning ErrAsyncResponse and calling Respond, and since this is an internal detail we don't need to worry too much about whether it's idiomatic. Add tests that verify both features, for both client and server. Also: - replace req.ID.IsValid with req.IsCall - remove the methodHandler type as we can just use MethodHandler Fixes #26 --- internal/jsonrpc2/conn.go | 82 ++++++++++-------- internal/jsonrpc2/jsonrpc2.go | 7 -- internal/jsonrpc2/jsonrpc2_test.go | 16 ++-- mcp/client.go | 7 +- mcp/conformance_test.go | 6 +- mcp/content.go | 2 +- mcp/mcp_test.go | 133 ++++++++++++++++++++++++++--- mcp/server.go | 13 ++- mcp/shared.go | 24 ++---- mcp/streamable.go | 2 +- mcp/streamable_test.go | 2 +- 11 files changed, 208 insertions(+), 86 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 6f48c9ba..963350e7 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -374,6 +374,46 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async return ac } +// Async, signals that the current jsonrpc2 request may be handled +// asynchronously to subsequent requests, when ctx is the request context. +// +// Async must be called at most once on each request's context (and its +// descendants). +func Async(ctx context.Context) { + if r, ok := ctx.Value(asyncKey).(*releaser); ok { + r.release(false) + } +} + +type asyncKeyType struct{} + +var asyncKey = asyncKeyType{} + +// A releaser implements concurrency safe 'releasing' of async requests. (A +// request is released when it is allowed to run concurrent with other +// requests, via a call to [Async].) +type releaser struct { + mu sync.Mutex + ch chan struct{} + released bool +} + +// release closes the associated channel. If soft is set, multiple calls to +// release are allowed. +func (r *releaser) release(soft bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.released { + if !soft { + panic("jsonrpc2.Async called multiple times") + } + } else { + close(r.ch) + r.released = true + } +} + type AsyncCall struct { id ID ready chan struct{} // closed after response has been set @@ -425,28 +465,6 @@ func (ac *AsyncCall) Await(ctx context.Context, result any) error { return json.Unmarshal(ac.response.Result, result) } -// Respond delivers a response to an incoming Call. -// -// Respond must be called exactly once for any message for which a handler -// returns ErrAsyncResponse. It must not be called for any other message. -func (c *Connection) Respond(id ID, result any, err error) error { - var req *incomingRequest - c.updateInFlight(func(s *inFlightState) { - req = s.incomingByID[id] - }) - if req == nil { - return c.internalErrorf("Request not found for ID %v", id) - } - - if err == ErrAsyncResponse { - // Respond is supposed to supply the asynchronous response, so it would be - // confusing to call Respond with an error that promises to call Respond - // again. - err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method) - } - return c.processResult("Respond", req, result, err) -} - // Cancel cancels the Context passed to the Handle call for the inbound message // with the given ID. // @@ -576,11 +594,6 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter if preempter != nil { result, err := preempter.Preempt(req.ctx, req.Request) - if req.IsCall() && errors.Is(err, ErrAsyncResponse) { - // This request will remain in flight until Respond is called for it. - return - } - if !errors.Is(err, ErrNotHandled) { c.processResult("Preempt", req, result, err) return @@ -655,19 +668,20 @@ func (c *Connection) handleAsync() { continue } - result, err := c.handler.Handle(req.ctx, req.Request) - c.processResult(c.handler, req, result, err) + releaser := &releaser{ch: make(chan struct{})} + ctx := context.WithValue(req.ctx, asyncKey, releaser) + go func() { + defer releaser.release(true) + result, err := c.handler.Handle(ctx, req.Request) + c.processResult(c.handler, req, result, err) + }() + <-releaser.ch } } // processResult processes the result of a request and, if appropriate, sends a response. func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error { switch err { - case ErrAsyncResponse: - if !req.IsCall() { - return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method) - } - return nil // This request is still in flight, so don't record the result yet. case ErrNotHandled, ErrMethodNotFound: // Add detail describing the unhandled method. err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method) diff --git a/internal/jsonrpc2/jsonrpc2.go b/internal/jsonrpc2/jsonrpc2.go index b9c320c8..234e6ee3 100644 --- a/internal/jsonrpc2/jsonrpc2.go +++ b/internal/jsonrpc2/jsonrpc2.go @@ -22,13 +22,6 @@ var ( // If a Handler returns ErrNotHandled, the server replies with // ErrMethodNotFound. ErrNotHandled = errors.New("JSON RPC not handled") - - // ErrAsyncResponse is returned from a handler to indicate it will generate a - // response asynchronously. - // - // ErrAsyncResponse must not be returned for notifications, - // which do not receive responses. - ErrAsyncResponse = errors.New("JSON RPC asynchronous response") ) // Preempter handles messages on a connection before they are queued to the main diff --git a/internal/jsonrpc2/jsonrpc2_test.go b/internal/jsonrpc2/jsonrpc2_test.go index 16a5039b..8c79300c 100644 --- a/internal/jsonrpc2/jsonrpc2_test.go +++ b/internal/jsonrpc2/jsonrpc2_test.go @@ -371,16 +371,14 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error if err := json.Unmarshal(req.Params, &name); err != nil { return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err) } + jsonrpc2.Async(ctx) waitFor := h.waiter(name) - go func() { - select { - case <-waitFor: - h.conn.Respond(req.ID, true, nil) - case <-ctx.Done(): - h.conn.Respond(req.ID, nil, ctx.Err()) - } - }() - return nil, jsonrpc2.ErrAsyncResponse + select { + case <-waitFor: + return true, nil + case <-ctx.Done(): + return nil, ctx.Err() + } default: return nil, jsonrpc2.ErrNotHandled } diff --git a/mcp/client.go b/mcp/client.go index 65a7a954..33530e05 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -329,16 +329,19 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { } func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { + if req.IsCall() { + jsonrpc2.Async(ctx) + } return handleReceive(ctx, cs, req) } -func (cs *ClientSession) sendingMethodHandler() methodHandler { +func (cs *ClientSession) sendingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.sendingMethodHandler_ } -func (cs *ClientSession) receivingMethodHandler() methodHandler { +func (cs *ClientSession) receivingMethodHandler() MethodHandler { cs.client.mu.Lock() defer cs.client.mu.Unlock() return cs.client.receivingMethodHandler_ diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index 8e6ea1be..9bd8b8f6 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -183,7 +183,7 @@ func runServerTest(t *testing.T, test *conformanceTest) { return nil, err, false } serverMessages = append(serverMessages, msg) - if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() { + if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() { // Pair up the next outgoing response with this request. // We assume requests arrive in the same order every time. if len(outResponses) == 0 { @@ -201,8 +201,8 @@ func runServerTest(t *testing.T, test *conformanceTest) { // Synthetic peer interacts with real peer. for _, req := range outRequests { writeMsg(req) - if req.ID.IsValid() { - // A request (as opposed to a notification). Wait for the response. + if req.IsCall() { + // A call (as opposed to a notification). Wait for the response. res, err, ok := nextResponse() if err != nil { t.Fatalf("reading server messages failed: %v", err) diff --git a/mcp/content.go b/mcp/content.go index f8777154..108b0271 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -253,7 +253,7 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { if wire == nil { - return nil, fmt.Errorf("content wire is nil") + return nil, fmt.Errorf("nil content") } if allow != nil && !allow[wire.Type] { return nil, fmt.Errorf("invalid content type %q", wire.Type) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index d04235b8..91d350e8 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -17,6 +17,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "testing" "time" @@ -549,31 +550,47 @@ func errorCode(err error) int64 { // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) { +func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *ServerSession) { + return basicClientServerConnection(t, nil, nil, config) +} + +// basicClientServerConnection creates a basic connection between client and +// server. If either client or server is nil, empty implementations are used. +// +// The provided function may be used to configure features on the resulting +// server, prior to connection. +// +// The caller should cancel either the client connection or server connection +// when the connections are no longer needed. +func basicClientServerConnection(t *testing.T, client *Client, server *Server, config func(*Server)) (*ClientSession, *ServerSession) { t.Helper() ctx := context.Background() ct, st := NewInMemoryTransports() - s := NewServer(testImpl, nil) + if server == nil { + server = NewServer(testImpl, nil) + } if config != nil { - config(s) + config(server) } - ss, err := s.Connect(ctx, st, nil) + ss, err := server.Connect(ctx, st, nil) if err != nil { t.Fatal(err) } - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) + if client == nil { + client = NewClient(testImpl, nil) + } + cs, err := client.Connect(ctx, ct, nil) if err != nil { t.Fatal(err) } - return ss, cs + return cs, ss } func TestServerClosing(t *testing.T) { - cc, cs := basicConnection(t, func(s *Server) { + cs, ss := basicConnection(t, func(s *Server) { AddTool(s, greetTool(), sayHi) }) defer cs.Close() @@ -593,7 +610,7 @@ func TestServerClosing(t *testing.T) { }); err != nil { t.Fatalf("after connecting: %v", err) } - cc.Close() + ss.Close() wg.Wait() if _, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", @@ -656,7 +673,7 @@ func TestCancellation(t *testing.T) { } return nil, nil } - _, cs := basicConnection(t, func(s *Server) { + cs, _ := basicConnection(t, func(s *Server) { AddTool(s, &Tool{Name: "slow"}, slowRequest) }) defer cs.Close() @@ -940,7 +957,7 @@ func TestKeepAliveFailure(t *testing.T) { func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { // Adding the same tool pointer twice should not panic and should not // produce duplicates in the server's tool list. - _, cs := basicConnection(t, func(s *Server) { + cs, _ := basicConnection(t, func(s *Server) { // Use two distinct Tool instances with the same name but different // descriptions to ensure the second replaces the first // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors @@ -972,4 +989,98 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { } } +func TestSynchronousNotifications(t *testing.T) { + var toolsChanged atomic.Bool + clientOpts := &ClientOptions{ + ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) { + toolsChanged.Store(true) + }, + CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + if !toolsChanged.Load() { + return nil, fmt.Errorf("didn't get a tools changed notification") + } + // TODO(rfindley): investigate the error returned from this test if + // CreateMessageResult is new(CreateMessageResult): it's a mysterious + // unmarshalling error that we should improve. + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + + var rootsChanged atomic.Bool + serverOpts := &ServerOptions{ + RootsListChangedHandler: func(_ context.Context, req *ServerRequest[*RootsListChangedParams]) { + rootsChanged.Store(true) + }, + } + server := NewServer(testImpl, serverOpts) + cs, ss := basicClientServerConnection(t, client, server, func(s *Server) { + AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + if !rootsChanged.Load() { + return nil, fmt.Errorf("didn't get root change notification") + } + return new(CallToolResult), nil + }) + }) + + t.Run("from client", func(t *testing.T) { + client.AddRoots(&Root{Name: "myroot", URI: "file://foo"}) + res, err := cs.CallTool(context.Background(), &CallToolParams{Name: "tool"}) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + if res.IsError { + t.Errorf("tool error: %v", res.Content[0].(*TextContent).Text) + } + }) + + t.Run("from server", func(t *testing.T) { + server.RemoveTools("tool") + if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil { + t.Errorf("CreateMessage failed: %v", err) + } + }) +} + +func TestNoDistributedDeadlock(t *testing.T) { + // This test verifies that calls are asynchronous, and so it's not possible + // to have a distributed deadlock. + // + // The setup creates potential deadlock for both the client and server: the + // client sends a call to tool1, which itself calls createMessage, which in + // turn calls tool2, which calls ping. + // + // If the server were not asynchronous, the call to tool2 would hang. If the + // client were not asynchronous, the call to ping would hang. + // + // Such a scenario is unlikely in practice, but is still theoretically + // possible, and in any case making tool calls asynchronous by default + // delegates synchronization to the user. + clientOpts := &ClientOptions{ + CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}) + return &CreateMessageResult{Content: &TextContent{}}, nil + }, + } + client := NewClient(testImpl, clientOpts) + cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) { + AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + req.Session.CreateMessage(ctx, new(CreateMessageParams)) + return new(CallToolResult), nil + }) + AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + req.Session.Ping(ctx, nil) + return new(CallToolResult), nil + }) + }) + defer cs.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := cs.CallTool(ctx, &CallToolParams{Name: "tool1"}); err != nil { + // should not deadlock + t.Fatalf("CallTool failed: %v", err) + } +} + var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} diff --git a/mcp/server.go b/mcp/server.go index 5bc626b3..85c2b5ed 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -779,14 +779,14 @@ func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return cli func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } -func (ss *ServerSession) sendingMethodHandler() methodHandler { +func (ss *ServerSession) sendingMethodHandler() MethodHandler { s := ss.server s.mu.Lock() defer s.mu.Unlock() return s.sendingMethodHandler_ } -func (ss *ServerSession) receivingMethodHandler() methodHandler { +func (ss *ServerSession) receivingMethodHandler() MethodHandler { s := ss.server s.mu.Lock() defer s.mu.Unlock() @@ -801,6 +801,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, ss.mu.Lock() initialized := ss.state.InitializedParams != nil ss.mu.Unlock() + // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." @@ -811,6 +812,14 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) } } + + // modelcontextprotocol/go-sdk#26: handle calls asynchronously, and + // notifications synchronously, except for 'initialize' which shouldn't be + // asynchronous to other + if req.IsCall() && req.Method != methodInitialize { + jsonrpc2.Async(ctx) + } + // For the streamable transport, we need the request ID to correlate // server->client calls and notifications to the incoming request from which // they originated. See [idContextKey] for details. diff --git a/mcp/shared.go b/mcp/shared.go index ca062214..608e2aaf 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -38,12 +38,6 @@ var supportedProtocolVersions = []string{ // For notifications, both must be nil. type MethodHandler func(ctx context.Context, method string, req Request) (result Result, err error) -// A methodHandler is a MethodHandler[Session] for some session. -// We need to give up type safety here, or we will end up with a type cycle somewhere -// else. For example, if Session.methodHandler returned a MethodHandler[Session], -// the compiler would complain. -type methodHandler any // MethodHandler[*ClientSession] | MethodHandler[*ServerSession] - // A Session is either a [ClientSession] or a [ServerSession]. type Session interface { // ID returns the session ID, or the empty string if there is none. @@ -51,8 +45,8 @@ type Session interface { sendingMethodInfos() map[string]methodInfo receivingMethodInfos() map[string]methodInfo - sendingMethodHandler() methodHandler - receivingMethodHandler() methodHandler + sendingMethodHandler() MethodHandler + receivingMethodHandler() MethodHandler getConn() *jsonrpc2.Connection } @@ -95,13 +89,13 @@ func orZero[T any, P *U, U any](p P) T { } func handleNotify(ctx context.Context, method string, req Request) error { - mh := req.GetSession().sendingMethodHandler().(MethodHandler) + mh := req.GetSession().sendingMethodHandler() _, err := mh(ctx, method, req) return err } func handleSend[R Result](ctx context.Context, method string, req Request) (R, error) { - mh := req.GetSession().sendingMethodHandler().(MethodHandler) + mh := req.GetSession().sendingMethodHandler() // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, method, req) if err != nil { @@ -118,7 +112,7 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, method string // This can be called from user code, with an arbitrary value for method. return nil, jsonrpc2.ErrNotHandled } - return info.handleMethod.(MethodHandler)(ctx, method, req) + return info.handleMethod(ctx, method, req) } func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Request) (Result, error) { @@ -131,7 +125,7 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ return nil, fmt.Errorf("handling '%s': %w", jreq.Method, err) } - mh := session.receivingMethodHandler().(MethodHandler) + mh := session.receivingMethodHandler() req := info.newRequest(session, params) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, jreq.Method, req) @@ -154,10 +148,10 @@ func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo if !ok { return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method) } - if info.flags¬ification != 0 && req.ID.IsValid() { + if info.flags¬ification != 0 && req.IsCall() { return methodInfo{}, fmt.Errorf("%w: unexpected id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } - if info.flags¬ification == 0 && !req.ID.IsValid() { + if info.flags¬ification == 0 && !req.IsCall() { return methodInfo{}, fmt.Errorf("%w: missing id for %q", jsonrpc2.ErrInvalidRequest, req.Method) } // missingParamsOK is checked here to catch the common case where "params" is @@ -182,7 +176,7 @@ type methodInfo struct { newRequest func(Session, Params) Request // Run the code when a call to the method is received. // Used on the receive side. - handleMethod methodHandler + handleMethod MethodHandler // Create a pointer to a Result struct. // Used on the send side. newResult func() Result diff --git a/mcp/streamable.go b/mcp/streamable.go index 9ae20c02..1ecf201f 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -509,7 +509,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } - if req.ID.IsValid() { + if req.IsCall() { requests[req.ID] = struct{}{} } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 4181303f..e0b00cc6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -689,7 +689,7 @@ func TestStreamableServerTransport(t *testing.T) { defer wg.Done() for m := range out { - if req, ok := m.(*jsonrpc.Request); ok && req.ID.IsValid() { + if req, ok := m.(*jsonrpc.Request); ok && req.IsCall() { // Encountered a server->client request. We should have a // response queued. Otherwise, we may deadlock. mu.Lock()