diff --git a/mcp/server.go b/mcp/server.go index 88021336..71fea7c6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -850,18 +850,11 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam state.InitializeParams = params }) - // If we support the client's version, reply with it. Otherwise, reply with our - // latest version. - version := params.ProtocolVersion - if !slices.Contains(supportedProtocolVersions, params.ProtocolVersion) { - version = latestProtocolVersion - } - s := ss.server return &InitializeResult{ // TODO(rfindley): alter behavior when falling back to an older version: // reject unsupported features. - ProtocolVersion: version, + ProtocolVersion: negotiatedVersion(params.ProtocolVersion), Capabilities: s.capabilities(), Instructions: s.opts.Instructions, ServerInfo: s.impl, diff --git a/mcp/shared.go b/mcp/shared.go index 608e2aaf..e2caf100 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -23,14 +23,36 @@ import ( "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) -// latestProtocolVersion is the latest protocol version that this version of the SDK supports. -// It is the version that the client sends in the initialization request. -const latestProtocolVersion = "2025-06-18" +const ( + // latestProtocolVersion is the latest protocol version that this version of + // the SDK supports. + // + // It is the version that the client sends in the initialization request, and + // the default version used by the server. + latestProtocolVersion = protocolVersion20250618 + protocolVersion20250618 = "2025-06-18" + protocolVersion20250326 = "2025-03-26" + protocolVersion20251105 = "2024-11-05" +) var supportedProtocolVersions = []string{ - latestProtocolVersion, - "2025-03-26", - "2024-11-05", + protocolVersion20250618, + protocolVersion20250326, + protocolVersion20251105, +} + +// negotiatedVersion returns the effective protocol version to use, given a +// client version. +func negotiatedVersion(clientVersion string) string { + // In general, prefer to use the clientVersion, but if we don't support the + // client's version, use the latest version. + // + // This handles the case where a new spec version is released, and the SDK + // does not support it yet. + if !slices.Contains(supportedProtocolVersions, clientVersion) { + return latestProtocolVersion + } + return clientVersion } // A MethodHandler handles MCP messages. diff --git a/mcp/streamable.go b/mcp/streamable.go index 572fe5de..cf9dccaa 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -15,6 +15,7 @@ import ( "math" "math/rand/v2" "net/http" + "slices" "strconv" "strings" "sync" @@ -152,7 +153,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque if req.Method == http.MethodDelete { if sessionID == "" { - http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) + http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest) return } if transport != nil { // transport may be nil in stateless mode @@ -172,8 +173,45 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } default: - w.Header().Set("Allow", "GET, POST") - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + w.Header().Set("Allow", "GET, POST, DELETE") + http.Error(w, "Method Not Allowed: streamable MCP servers support GET, POST, and DELETE requests", http.StatusMethodNotAllowed) + return + } + + // Section 2.7 of the spec (2025-06-18) states: + // + // "If using HTTP, the client MUST include the MCP-Protocol-Version: + // HTTP header on all subsequent requests to the MCP + // server, allowing the MCP server to respond based on the MCP protocol + // version. + // + // For example: MCP-Protocol-Version: 2025-06-18 + // The protocol version sent by the client SHOULD be the one negotiated during + // initialization. + // + // For backwards compatibility, if the server does not receive an + // MCP-Protocol-Version header, and has no other way to identify the version - + // for example, by relying on the protocol version negotiated during + // initialization - the server SHOULD assume protocol version 2025-03-26. + // + // If the server receives a request with an invalid or unsupported + // MCP-Protocol-Version, it MUST respond with 400 Bad Request." + // + // Since this wasn't present in the 2025-03-26 version of the spec, this + // effectively means: + // 1. IF the client provides a version header, it must be a supported + // version. + // 2. In stateless mode, where we've lost the state of the initialize + // request, we assume that whatever the client tells us is the truth (or + // assume 2025-03-26 if the client doesn't say anything). + // + // This logic matches the typescript SDK. + protocolVersion := req.Header.Get(protocolVersionHeader) + if protocolVersion == "" { + protocolVersion = protocolVersion20250326 + } + if !slices.Contains(supportedProtocolVersions, protocolVersion) { + http.Error(w, fmt.Sprintf("Bad Request: Unsupported protocol version (supported versions: %s)", strings.Join(supportedProtocolVersions, ",")), http.StatusBadRequest) return } @@ -234,7 +272,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque // set the initial state to a default value. state := new(ServerSessionState) if !hasInitialize { - state.InitializeParams = new(InitializeParams) + state.InitializeParams = &InitializeParams{ + ProtocolVersion: protocolVersion, + } } if !hasInitialized { state.InitializedParams = new(InitializedParams) @@ -377,11 +417,12 @@ type streamableServerConn struct { eventStore EventStore incoming chan jsonrpc.Message // messages from the client to the server - done chan struct{} - mu sync.Mutex + mu sync.Mutex // guards all fields below + // Sessions are closed exactly once. isDone bool + done chan struct{} // Sessions can have multiple logical connections (which we call streams), // corresponding to HTTP requests. Additionally, streams may be resumed by diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 8334bc0d..9ac3b66c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -10,6 +10,7 @@ import ( "encoding/json" "fmt" "io" + "maps" "net" "net/http" "net/http/cookiejar" @@ -459,21 +460,21 @@ func TestStreamableServerTransport(t *testing.T) { // Test various accept headers. { method: "POST", - accept: []string{"text/plain", "application/*"}, + headers: http.Header{"Accept": {"text/plain", "application/*"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing text/event-stream wantSessionID: false, }, { method: "POST", - accept: []string{"text/event-stream"}, + headers: http.Header{"Accept": {"text/event-stream"}}, messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusBadRequest, // missing application/json wantSessionID: false, }, { method: "POST", - accept: []string{"text/plain", "*/*"}, + headers: http.Header{"Accept": {"text/plain", "*/*"}}, messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, @@ -481,7 +482,7 @@ func TestStreamableServerTransport(t *testing.T) { }, { method: "POST", - accept: []string{"text/*, application/*"}, + headers: http.Header{"Accept": {"text/*, application/*"}}, messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "tool"})}, wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(4, &CallToolResult{}, nil)}, @@ -489,6 +490,21 @@ func TestStreamableServerTransport(t *testing.T) { }, }, }, + { + name: "protocol version headers", + requests: []streamableRequest{ + initialize, + initialized, + { + method: "POST", + headers: http.Header{"mcp-protocol-version": {"2025-01-01"}}, // an invalid protocol version + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "2025-03-26", // a supported version + wantSessionID: false, // could be true, but shouldn't matter + }, + }, + }, { name: "tool notification", tool: func(t *testing.T, ctx context.Context, ss *ServerSession) { @@ -729,7 +745,7 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream } }() - gotSessionID, gotStatusCode, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) + gotSessionID, gotStatusCode, gotBody, err := request.do(ctx, httpServer.URL, sessionID.Load().(string), out) // Don't fail on cancelled requests: error (if any) is handled // elsewhere. @@ -745,7 +761,12 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream } wg.Wait() - if !request.ignoreResponse { + if request.wantBodyContaining != "" { + body := string(gotBody) + if !strings.Contains(body, request.wantBodyContaining) { + t.Errorf("body does not contain %q:\n%s", request.wantBodyContaining, body) + } + } else { transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() }) if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" { t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff) @@ -793,14 +814,14 @@ type streamableRequest struct { // Request attributes method string // HTTP request method (required) - accept []string // if non-empty, the Accept header to use; otherwise the default header is used + headers http.Header // additional headers to set, overlaid on top of the default headers messages []jsonrpc.Message // messages to send - closeAfter int // if nonzero, close after receiving this many messages - wantStatusCode int // expected status code - ignoreResponse bool // if set, don't check the response messages - wantMessages []jsonrpc.Message // expected messages to receive - wantSessionID bool // whether or not a session ID is expected in the response + closeAfter int // if nonzero, close after receiving this many messages + wantStatusCode int // expected status code + wantBodyContaining string // if set, expect the response body to contain this text; overrides wantMessages + wantMessages []jsonrpc.Message // expected messages to receive; ignored if wantBodyContaining is set + wantSessionID bool // whether or not a session ID is expected in the response } // streamingRequest makes a request to the given streamable server with the @@ -816,14 +837,14 @@ type streamableRequest struct { // Returns the sessionID and http status code from the response. If an error is // returned, sessionID and status code may still be set if the error occurs // after the response headers have been received. -func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, out chan<- jsonrpc.Message) (string, int, error) { +func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, out chan<- jsonrpc.Message) (string, int, []byte, error) { defer close(out) var body []byte if len(s.messages) == 1 { data, err := jsonrpc2.EncodeMessage(s.messages[0]) if err != nil { - return "", 0, fmt.Errorf("encoding message: %w", err) + return "", 0, nil, fmt.Errorf("encoding message: %w", err) } body = data } else { @@ -831,68 +852,93 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, for _, msg := range s.messages { data, err := jsonrpc2.EncodeMessage(msg) if err != nil { - return "", 0, fmt.Errorf("encoding message: %w", err) + return "", 0, nil, fmt.Errorf("encoding message: %w", err) } rawMsgs = append(rawMsgs, data) } data, err := json.Marshal(rawMsgs) if err != nil { - return "", 0, fmt.Errorf("marshaling batch: %w", err) + return "", 0, nil, fmt.Errorf("marshaling batch: %w", err) } body = data } req, err := http.NewRequestWithContext(ctx, s.method, serverURL, bytes.NewReader(body)) if err != nil { - return "", 0, fmt.Errorf("creating request: %w", err) + return "", 0, nil, fmt.Errorf("creating request: %w", err) } if sessionID != "" { req.Header.Set("Mcp-Session-Id", sessionID) } req.Header.Set("Content-Type", "application/json") - if len(s.accept) > 0 { - for _, accept := range s.accept { - req.Header.Add("Accept", accept) - } - } else { - req.Header.Add("Accept", "application/json, text/event-stream") - } + req.Header.Set("Accept", "application/json, text/event-stream") + maps.Copy(req.Header, s.headers) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", 0, fmt.Errorf("request failed: %v", err) + return "", 0, nil, fmt.Errorf("request failed: %v", err) } defer resp.Body.Close() newSessionID := resp.Header.Get("Mcp-Session-Id") contentType := resp.Header.Get("Content-Type") + var respBody []byte if strings.HasPrefix(contentType, "text/event-stream") { - for evt, err := range scanEvents(resp.Body) { + r := readerInto{resp.Body, new(bytes.Buffer)} + for evt, err := range scanEvents(r) { if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("reading events: %v", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading events: %v", err) } // TODO(rfindley): do we need to check evt.name? // Does the MCP spec say anything about this? msg, err := jsonrpc2.DecodeMessage(evt.Data) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("decoding message: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) } out <- msg } + respBody = r.w.Bytes() } else if strings.HasPrefix(contentType, "application/json") { data, err := io.ReadAll(resp.Body) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("reading json body: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading json body: %w", err) } + respBody = data msg, err := jsonrpc2.DecodeMessage(data) if err != nil { - return newSessionID, resp.StatusCode, fmt.Errorf("decoding message: %w", err) + return newSessionID, resp.StatusCode, nil, fmt.Errorf("decoding message: %w", err) } out <- msg + } else { + respBody, err = io.ReadAll(resp.Body) + if err != nil { + return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading response: %v", err) + } } - return newSessionID, resp.StatusCode, nil + return newSessionID, resp.StatusCode, respBody, nil +} + +// readerInto is an io.Reader that writes any bytes read from r into w. +type readerInto struct { + r io.Reader + w *bytes.Buffer +} + +// Read implements io.Reader. +func (r readerInto) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if err == nil || err == io.EOF { + n2, err2 := r.w.Write(p[:n]) + if err2 != nil { + return n, fmt.Errorf("failed to write: %v", err) + } + if n2 != n { + return n, fmt.Errorf("short write: %d != %d", n2, n) + } + } + return n, err } func mustMarshal(v any) json.RawMessage { @@ -906,8 +952,13 @@ func mustMarshal(v any) json.RawMessage { return data } -func TestStreamableClientTransportApplicationJSON(t *testing.T) { - // Test handling of application/json responses. +func TestStreamableClientTransport(t *testing.T) { + // This test verifies various behavior of the streamable client transport: + // - check that it can handle application/json responses + // - check that it sends the negotiated protocol version + // + // TODO(rfindley): make this test more comprehensive, similar to + // [TestStreamableServerTransport]. ctx := context.Background() resp := func(id int64, result any, err error) *jsonrpc.Response { return &jsonrpc.Response{ @@ -927,14 +978,25 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { } initResp := resp(1, initResult, nil) + var reqN atomic.Int32 // request count serverHandler := func(w http.ResponseWriter, r *http.Request) { - data, err := jsonrpc2.EncodeMessage(initResp) - if err != nil { - t.Fatal(err) - } + rN := reqN.Add(1) + + // TODO(rfindley): if the status code is NoContent or Accepted, we should + // probably be tolerant of when the content type is not application/json. w.Header().Set("Content-Type", "application/json") - w.Header().Set("Mcp-Session-Id", "123") - w.Write(data) + if rN == 1 { + data, err := jsonrpc2.EncodeMessage(initResp) + if err != nil { + t.Errorf("encoding failed: %v", err) + } + w.Header().Set("Mcp-Session-Id", "123") + w.Write(data) + } else { + if v := r.Header.Get(protocolVersionHeader); v != latestProtocolVersion { + t.Errorf("bad protocol version header: got %q, want %q", v, latestProtocolVersion) + } + } } httpServer := httptest.NewServer(http.HandlerFunc(serverHandler)) @@ -946,7 +1008,16 @@ func TestStreamableClientTransportApplicationJSON(t *testing.T) { if err != nil { t.Fatalf("client.Connect() failed: %v", err) } - defer session.Close() + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + + if got, want := reqN.Load(), int32(3); got < want { + // Expect at least 3 requests: initialize, initialized, and DELETE. + // We may or may not observe the GET, depending on timing. + t.Errorf("unexpected number of requests: got %d, want at least %d", got, want) + } + if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } @@ -1010,11 +1081,11 @@ func TestStreamableStateless(t *testing.T) { requests := []streamableRequest{ { - method: "POST", - wantStatusCode: http.StatusOK, - messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, - ignoreResponse: true, - wantSessionID: false, + method: "POST", + wantStatusCode: http.StatusOK, + messages: []jsonrpc.Message{req(1, "tools/list", struct{}{})}, + wantBodyContaining: "greet", + wantSessionID: false, }, { method: "POST",