diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a4cd0ac0..1fe659cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: coverage: runs-on: ubuntu-latest - if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' + if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' || (github.event_name == 'push' && github.ref == 'refs/heads/main') permissions: contents: read pull-requests: write @@ -38,6 +38,7 @@ jobs: with: name: code-coverage path: coverage.txt + retention-days: 30 - name: Generate coverage report uses: fgrosse/go-coverage-report@v1.2.0 diff --git a/server/streamable_http.go b/server/streamable_http.go index 5a596467..b7d46408 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -52,13 +52,13 @@ func WithStateLess(stateLess bool) StreamableHTTPOption { } // WithSessionIdManager sets a custom session id generator for the server. -// By default, the server uses InsecureStatefulSessionIdManager (UUID-based; insecure). +// By default, the server uses StatelessGeneratingSessionIdManager (generates IDs but no local validation). // Note: Options are applied in order; the last one wins. If combined with // WithStateLess or WithSessionIdManagerResolver, whichever is applied last takes effect. func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption { return func(s *StreamableHTTPServer) { if manager == nil { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) + s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{}) return } s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(manager) @@ -72,13 +72,23 @@ func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption { func WithSessionIdManagerResolver(resolver SessionIdManagerResolver) StreamableHTTPOption { return func(s *StreamableHTTPServer) { if resolver == nil { - s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) + s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{}) return } s.sessionIdManagerResolver = resolver } } +// WithStateful enables stateful session management using InsecureStatefulSessionIdManager. +// This requires sticky sessions in multi-instance deployments. +func WithStateful(stateful bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + if stateful { + s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}) + } + } +} + // WithHeartbeatInterval sets the heartbeat interval. Positive interval means the // server will send a heartbeat to the client through the GET connection, to keep // the connection alive from being closed by the network infrastructure (e.g. @@ -187,7 +197,7 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S sessionTools: newSessionToolsStore(), sessionLogLevels: newSessionLogLevelsStore(), endpointPath: "/mcp", - sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}), + sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&StatelessGeneratingSessionIdManager{}), logger: util.DefaultLogger(), sessionResources: newSessionResourcesStore(), sessionResourceTemplates: newSessionResourceTemplatesStore(), @@ -1244,7 +1254,7 @@ type DefaultSessionIdManagerResolver struct { // NewDefaultSessionIdManagerResolver creates a new DefaultSessionIdManagerResolver with the given SessionIdManager func NewDefaultSessionIdManagerResolver(manager SessionIdManager) *DefaultSessionIdManagerResolver { if manager == nil { - manager = &InsecureStatefulSessionIdManager{} + manager = &StatelessSessionIdManager{} } return &DefaultSessionIdManagerResolver{manager: manager} } @@ -1270,6 +1280,30 @@ func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bo return false, nil } +// StatelessGeneratingSessionIdManager generates session IDs but doesn't validate them locally. +// This allows session IDs to be generated for clients while working across multiple instances. +type StatelessGeneratingSessionIdManager struct{} + +func (s *StatelessGeneratingSessionIdManager) Generate() string { + return idPrefix + uuid.New().String() +} + +func (s *StatelessGeneratingSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // Only validate format, not existence - allows cross-instance operation + if !strings.HasPrefix(sessionID, idPrefix) { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + return false, nil +} + +func (s *StatelessGeneratingSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + // No-op termination since we don't track sessions + return false, nil +} + // InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions. // It validates both format and existence of session IDs. // For more secure session id, use a more complex generator, like a JWT. diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index f83a95eb..9e444c53 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -125,7 +125,7 @@ func TestStreamableHTTP_POST_InvalidContent(t *testing.T) { func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { mcpServer := NewMCPServer("test-mcp-server", "1.0") addSSETool(mcpServer) - server := NewTestStreamableHTTPServer(mcpServer) + server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) var sessionID string t.Run("initialize", func(t *testing.T) { @@ -595,6 +595,7 @@ func TestStreamableHttpResourceGet(t *testing.T) { testServer := NewTestStreamableHTTPServer( s, + WithStateful(true), WithHTTPContextFunc(func(ctx context.Context, r *http.Request) context.Context { session := ClientSessionFromContext(ctx) @@ -1014,7 +1015,7 @@ func TestStreamableHTTP_SessionWithLogging(t *testing.T) { }) mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks), WithLogging()) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) defer testServer.Close() // obtain a valid session ID first @@ -1404,7 +1405,7 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) { server := NewTestStreamableHTTPServer(mcpServer) defer server.Close() - t.Run("Reject tool call with fake session ID", func(t *testing.T) { + t.Run("Accept tool call with properly formatted session ID", func(t *testing.T) { toolCallRequest := map[string]any{ "jsonrpc": "2.0", "id": 1, @@ -1425,13 +1426,29 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) { } defer resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("Expected status 400, got %d", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) } body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), "Invalid session ID") { - t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body)) + var response map[string]any + if err := json.Unmarshal(body, &response); err != nil { + t.Fatalf("Failed to unmarshal response: %v", err) + } + + if result, ok := response["result"].(map[string]any); ok { + if content, ok := result["content"].([]any); ok && len(content) > 0 { + if textContent, ok := content[0].(map[string]any); ok { + if text, ok := textContent["text"].(string); ok { + // Should be a valid timestamp response + if text == "" { + t.Error("Expected non-empty timestamp response") + } + } + } + } + } else { + t.Errorf("Expected result in response, got: %s", string(body)) } }) @@ -1508,22 +1525,45 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) { } }) - t.Run("Reject tool call with terminated session ID", func(t *testing.T) { + t.Run("Reject tool call with terminated session ID (stateful mode)", func(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + // Use explicit stateful mode for this test since termination requires local tracking + server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) + defer server.Close() + + // First, initialize a session + initRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]any{ + "tools": map[string]any{}, + }, + "clientInfo": map[string]any{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + jsonBody, _ := json.Marshal(initRequest) req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody)) req.Header.Set("Content-Type", "application/json") resp, err := server.Client().Do(req) if err != nil { - t.Fatalf("Failed to initialize: %v", err) + t.Fatalf("Failed to initialize session: %v", err) } - resp.Body.Close() sessionID := resp.Header.Get(HeaderKeySessionID) if sessionID == "" { t.Fatal("Expected session ID in response header") } + resp.Body.Close() + // Now terminate the session req, _ = http.NewRequest(http.MethodDelete, server.URL, nil) req.Header.Set(HeaderKeySessionID, sessionID) @@ -1780,13 +1820,19 @@ func TestDefaultSessionIdManagerResolver(t *testing.T) { t.Error("Expected resolver to return a non-nil manager") } - // Test that the resolved manager works (generates valid session IDs) + // Test that the resolved manager works (stateless behavior) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected default manager to generate non-empty session ID") + if sessionID != "" { + t.Error("Expected stateless manager to generate empty session ID") } - if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected default manager to generate session ID with correct prefix") + + // Test that validation accepts any session ID (stateless behavior) + isTerminated, err := resolved.Validate("any-session-id") + if err != nil { + t.Errorf("Expected stateless manager to accept any session ID, got error: %v", err) + } + if isTerminated { + t.Error("Expected stateless manager to not terminate sessions") } }) } @@ -1865,17 +1911,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { server := NewStreamableHTTPServer(mcpServer, WithStateLess(false)) - // Test that the default manager is still used (InsecureStatefulSessionIdManager) + // Test that the default manager is still used (StatelessGeneratingSessionIdManager) req, _ := http.NewRequest("POST", "/test", nil) resolved := server.sessionIdManagerResolver.ResolveSessionIdManager(req) - // Verify it's NOT a stateless manager + // Verify it's a generating manager (default behavior) sessionID := resolved.Generate() if sessionID == "" { - t.Error("Expected stateful manager when WithStateLess(false)") + t.Error("Expected generating manager to generate session ID by default") } if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected stateful session ID format") + t.Error("Expected generating manager to generate session ID with correct prefix") } }) @@ -1929,7 +1975,7 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Run("WithSessionIdManagerResolver handles nil resolver defensively", func(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") - // This should not panic and should fall back to default behavior + // This should not panic and should fall back to StatelessSessionIdManager (safe default) server := NewStreamableHTTPServer(mcpServer, WithSessionIdManagerResolver(nil)) req, _ := http.NewRequest("POST", "/test", nil) @@ -1938,20 +1984,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Error("Expected nil resolver to be replaced with default") } - // Test that the resolved manager works (should be default stateful manager) + // Test that the resolved manager works (should be default stateless manager) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected default manager to generate non-empty session ID") - } - if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected default manager to generate session ID with correct prefix") + if sessionID != "" { + t.Error("Expected default stateless manager to generate empty session ID") } }) t.Run("WithSessionIdManager handles nil manager defensively", func(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") - // This should not panic and should fall back to default behavior + // This should not panic and should fall back to StatelessSessionIdManager (safe default) server := NewStreamableHTTPServer(mcpServer, WithSessionIdManager(nil)) req, _ := http.NewRequest("POST", "/test", nil) @@ -1960,20 +2003,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Error("Expected nil manager to be replaced with default") } - // Test that the resolved manager works (should be default stateful manager) + // Test that the resolved manager works (should be default stateless manager) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected default manager to generate non-empty session ID") - } - if !strings.HasPrefix(sessionID, idPrefix) { - t.Error("Expected default manager to generate session ID with correct prefix") + if sessionID != "" { + t.Error("Expected default stateless manager to generate empty session ID") } }) t.Run("Multiple nil options fall back safely", func(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") - // Chain multiple nil options - last one should win with safe fallback + // Chain multiple nil options - last one should win with StatelessSessionIdManager fallback server := NewStreamableHTTPServer(mcpServer, WithSessionIdManager(nil), WithSessionIdManagerResolver(nil), @@ -1985,10 +2025,10 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { t.Error("Expected chained nil options to fall back safely") } - // Verify it generates valid session IDs + // Verify it uses stateless behavior (default) sessionID := resolved.Generate() - if sessionID == "" { - t.Error("Expected fallback manager to generate non-empty session ID") + if sessionID != "" { + t.Error("Expected fallback stateless manager to generate empty session ID") } }) @@ -2021,6 +2061,28 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) { } _ = resp.Body.Close() }) + + t.Run("WithStateful enables stateful manager", func(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + server := NewStreamableHTTPServer(mcpServer, WithStateful(true)) + + req, _ := http.NewRequest("POST", "/test", nil) + resolved := server.sessionIdManagerResolver.ResolveSessionIdManager(req) + + sessionID := resolved.Generate() + if sessionID == "" { + t.Error("Expected stateful manager to generate session ID") + } + if !strings.HasPrefix(sessionID, idPrefix) { + t.Error("Expected stateful session ID format") + } + + // Test that stateful manager validates session existence (unlike default) + _, err := resolved.Validate("unknown-session-id") + if err == nil { + t.Error("Expected stateful manager to reject unknown session ID") + } + }) } func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { @@ -2039,7 +2101,7 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { }) mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) defer testServer.Close() // Send initialize request to register session @@ -2110,7 +2172,7 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) { return mcp.NewToolResultText("notification sent"), nil }) - testServer := NewTestStreamableHTTPServer(mcpServer) + testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true)) defer testServer.Close() // Initialize session