From 1dd9c324f0c43ffa4ead1eeb16c7d714ac310f9c Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 30 Jul 2025 17:18:36 +0000 Subject: [PATCH 1/3] mcp/streamable: add persistent SSE GET listener This CL adds the optional persistent SSE GET listener as specified in section 2.2. https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server This enables server initiated SSE streams. --- mcp/streamable.go | 33 ++++++++++++++++++++++----------- mcp/streamable_test.go | 40 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 0921c22d..00a0e0c3 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -735,6 +735,12 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er ctx: connCtx, cancel: cancel, } + // Start the persistent SSE listener right away. + // Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint. + // This can be used to open an SSE stream, allowing the server to + // communicate to the client, without the client first sending data via HTTP POST. + go conn.handleSSE(nil, false) + return conn, nil } @@ -859,7 +865,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string switch ct := resp.Header.Get("Content-Type"); ct { case "text/event-stream": // Section 2.1: The SSE stream is initiated after a POST. - go s.handleSSE(resp) + go s.handleSSE(resp, true) case "application/json": body, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -879,13 +885,11 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string return sessionID, nil } -// handleSSE manages the entire lifecycle of an SSE connection. It processes -// an incoming Server-Sent Events stream and automatically handles reconnection -// logic if the stream breaks. -func (s *streamableClientConn) handleSSE(initialResp *http.Response) { +// handleSSE manages the lifecycle of an SSE connection. It can be either +// temporary (for a POST response) or persistent (for the main GET listener). +func (s *streamableClientConn) handleSSE(initialResp *http.Response, temporary bool) { resp := initialResp var lastEventID string - for { eventID, clientClosed := s.processStream(resp) lastEventID = eventID @@ -894,6 +898,11 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) { if clientClosed { return } + // If the stream has ended, then do not reconnect if the stream is + // temporary (POST initiated SSE). + if lastEventID == "" && temporary { + return + } // The stream was interrupted or ended by the server. Attempt to reconnect. newResp, err := s.reconnect(lastEventID) @@ -915,9 +924,13 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response) { // processStream reads from a single response body, sending events to the // incoming channel. It returns the ID of the last processed event, any error // that occurred, and a flag indicating if the connection was closed by the client. +// If resp is nil, it returns "", false. func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) { - defer resp.Body.Close() + if resp == nil { + return "", false + } + defer resp.Body.Close() for evt, err := range scanEvents(resp.Body) { if err != nil { return lastEventID, false @@ -931,13 +944,11 @@ func (s *streamableClientConn) processStream(resp *http.Response) (lastEventID s case s.incoming <- evt.Data: case <-s.done: // The connection was closed by the client; exit gracefully. - return lastEventID, true + return "", true } } - // The loop finished without an error, indicating the server closed the stream. - // We'll attempt to reconnect, so this is not a client-side close. - return lastEventID, false + return "", false } // reconnect handles the logic of retrying a connection with an exponential diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 185bc638..16804929 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -168,7 +168,7 @@ func TestClientReplay(t *testing.T) { clientSession.CallTool(ctx, &CallToolParams{Name: "multiMessageTool"}) // 4. Read and verify messages until the server signals it's ready for the proxy kill. - receivedNotifications := readProgressNotifications(t, ctx, notifications, 2) + receivedNotifications := readNotifications(t, ctx, notifications, 2) wantReceived := []string{"msg1", "msg2"} if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { @@ -201,7 +201,7 @@ func TestClientReplay(t *testing.T) { // 7. Continue reading from the same connection object. // Its internal logic should successfully retry, reconnect to the new proxy, // and receive the replayed messages. - recoveredNotifications := readProgressNotifications(t, ctx, notifications, 2) + recoveredNotifications := readNotifications(t, ctx, notifications, 2) // 8. Verify the correct messages were received on the recovered connection. wantRecovered := []string{"msg3", "msg4"} @@ -211,8 +211,40 @@ func TestClientReplay(t *testing.T) { } } -// Helper to read a specific number of progress notifications. -func readProgressNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { +// TestServerInitiatedSSE verifies that the persistent SSE connection remains +// open and can receive multiple, non-consecutive, server-initiated events. +func TestServerInitiatedSSE(t *testing.T) { + notifications := make(chan string) + server := NewServer(testImpl, &ServerOptions{}) + + httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := NewClient(testImpl, &ClientOptions{ToolListChangedHandler: func(ctx context.Context, cc *ClientSession, params *ToolListChangedParams) { + notifications <- "toolListChanged" + }, + }) + clientSession, err := client.Connect(ctx, NewStreamableClientTransport(httpServer.URL, nil)) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer clientSession.Close() + time.Sleep(50 * time.Millisecond) + server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, + func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + receivedNotifications := readNotifications(t, ctx, notifications, 1) + wantReceived := []string{"toolListChanged"} + if diff := cmp.Diff(wantReceived, receivedNotifications); diff != "" { + t.Errorf("Received notifications mismatch (-want +got):\n%s", diff) + } +} + +// Helper to read a specific number of notifications. +func readNotifications(t *testing.T, ctx context.Context, notifications chan string, count int) []string { t.Helper() var collectedNotifications []string for { From 1dd5e9788ed096b8d378e4bce9a57cda5328ab60 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 30 Jul 2025 18:57:48 +0000 Subject: [PATCH 2/3] mcp/streamable: update bool name to persistent --- mcp/streamable.go | 10 +++++----- mcp/streamable_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 00a0e0c3..4c1c8167 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -739,7 +739,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint. // This can be used to open an SSE stream, allowing the server to // communicate to the client, without the client first sending data via HTTP POST. - go conn.handleSSE(nil, false) + go conn.handleSSE(nil, true) return conn, nil } @@ -865,7 +865,7 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string switch ct := resp.Header.Get("Content-Type"); ct { case "text/event-stream": // Section 2.1: The SSE stream is initiated after a POST. - go s.handleSSE(resp, true) + go s.handleSSE(resp, false) case "application/json": body, err := io.ReadAll(resp.Body) resp.Body.Close() @@ -886,8 +886,8 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string } // handleSSE manages the lifecycle of an SSE connection. It can be either -// temporary (for a POST response) or persistent (for the main GET listener). -func (s *streamableClientConn) handleSSE(initialResp *http.Response, temporary bool) { +// persistent (for the main GET listener) or temporary (for a POST response). +func (s *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) { resp := initialResp var lastEventID string for { @@ -900,7 +900,7 @@ func (s *streamableClientConn) handleSSE(initialResp *http.Response, temporary b } // If the stream has ended, then do not reconnect if the stream is // temporary (POST initiated SSE). - if lastEventID == "" && temporary { + if lastEventID == "" && !persistent { return } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 16804929..56169e6f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -212,10 +212,10 @@ func TestClientReplay(t *testing.T) { } // TestServerInitiatedSSE verifies that the persistent SSE connection remains -// open and can receive multiple, non-consecutive, server-initiated events. +// open and can receive server-initiated events. func TestServerInitiatedSSE(t *testing.T) { notifications := make(chan string) - server := NewServer(testImpl, &ServerOptions{}) + server := NewServer(testImpl, nil) httpServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)) defer httpServer.Close() From 24a8551fbaf2515bdf63e75076428fb051424f2a Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 30 Jul 2025 19:57:43 +0000 Subject: [PATCH 3/3] mcp/streamable_test: remove sleep --- mcp/streamable_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 56169e6f..69934314 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -231,7 +231,6 @@ func TestServerInitiatedSSE(t *testing.T) { t.Fatalf("client.Connect() failed: %v", err) } defer clientSession.Close() - time.Sleep(50 * time.Millisecond) server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { return &CallToolResult{}, nil