diff --git a/mcp/streamable.go b/mcp/streamable.go index 0921c22d..4c1c8167 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -735,6 +735,12 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er ctx: connCtx, cancel: cancel, } + // Start the persistent SSE listener right away. + // Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint. + // This can be used to open an SSE stream, allowing the server to + // communicate to the client, without the client first sending data via HTTP POST. + go conn.handleSSE(nil, true) + return conn, nil } @@ -859,7 +865,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string switch ct := resp.Header.Get("Content-Type"); ct { case "text/event-stream": // Section 2.1: The SSE stream is initiated after a POST. - go s.handleSSE(resp) + go s.handleSSE(resp, false) case "application/json": body, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -879,13 +885,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string return sessionID, nil } -// handleSSE manages the entire lifecycle of an SSE connection. It processes -// an incoming Server-Sent Events stream and automatically handles reconnection -// logic if the stream breaks. -func (s *streamableClientConn) handleSSE(initialResp *http.Response) { +// handleSSE manages the lifecycle of an SSE connection. It can be either +// persistent (for the main GET listener) or temporary (for a POST response). +func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) { resp := initialResp var lastEventID string - for { eventID, clientClosed := s.processStream(resp) lastEventID = eventID @@ -894,6 +898,11 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) { if clientClosed { return } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). + if lastEventID == "" && !persistent { + return + } // The stream was interrupted or ended by the server. Attempt to reconnect. newResp, err := s.reconnect(lastEventID) @@ -915,9 +924,13 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) { // processStream reads from a single response body, sending events to the // incoming channel. It returns the ID of the last processed event, any error // that occurred, and a flag indicating if the connection was closed by the client. +// If resp is nil, it returns "", false. func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { - defer resp.Body.Close() + if resp == nil { + return "", false + } + defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { return lastEventID, false @@ -931,13 +944,11 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s case s.incoming <- evt.Data: case <-s.done: // The connection was closed by the client; exit gracefully. - return lastEventID, true + return "", true } } - // The loop finished without an error, indicating the server closed the stream. - // We'll attempt to reconnect, so this is not a client-side close. - return lastEventID, false + return "", false } // reconnect handles the logic of retrying a connection with an exponential diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 185bc638..69934314 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -168,7 +168,7 @@ func TestClientReplay(t *testing.T) { clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) // 4. Read and verify messages until the server signals it's ready for the proxy kill. - receivedNotifications := readProgressNotifications(t, ctx, notifications, 2) + receivedNotifications := readNotifications(t, ctx, notifications, 2) wantReceived := []string{"msg1", "msg2"} if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { @@ -201,7 +201,7 @@ func TestClientReplay(t *testing.T) { // 7. Continue reading from the same connection object. // Its internal logic should successfully retry, reconnect to the new proxy, // and receive the replayed messages. - recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2) + recoveredNotifications := readNotifications(t, ctx, notifications, 2) // 8. Verify the correct messages were received on the recovered connection. wantRecovered := []string{"msg3", "msg4"} @@ -211,8 +211,39 @@ func TestClientReplay(t *testing.T) { } } -// Helper to read a specific number of progress notifications. -func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { +// TestServerInitiatedSSE verifies that the persistent SSE connection remains +// open and can receive server-initiated events. +func TestServerInitiatedSSE(t *testing.T) { + notifications := make(chan string) + server := NewServer(testImpl, nil) + + httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { + notifications <- "toolListChanged" + }, + }) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil)) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, + func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + receivedNotifications := readNotifications(t, ctx, notifications, 1) + wantReceived := []string{"toolListChanged"} + if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { + t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) + } +} + +// Helper to read a specific number of notifications. +func readNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { t.Helper() var collectedNotifications []string for {