Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:

coverage:
runs-on: ubuntu-latest
if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target'
if: github.event_name == 'pull_request' || github.event_name == 'pull_request_target' || (github.event_name == 'push' && github.ref == 'refs/heads/main')
permissions:
contents: read
pull-requests: write
Expand All @@ -38,6 +38,7 @@ jobs:
with:
name: code-coverage
path: coverage.txt
retention-days: 30
- name: Generate coverage report
uses: fgrosse/go-coverage-report@v1.2.0

Expand Down
44 changes: 39 additions & 5 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ func WithStateLess(stateLess bool) StreamableHTTPOption {
}

// WithSessionIdManager sets a custom session id generator for the server.
// By default, the server uses InsecureStatefulSessionIdManager (UUID-based; insecure).
// By default, the server uses StatelessGeneratingSessionIdManager (generates IDs but no local validation).
// Note: Options are applied in order; the last one wins. If combined with
// WithStateLess or WithSessionIdManagerResolver, whichever is applied last takes effect.
func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
if manager == nil {
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{})
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{})
return
}
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(manager)
Expand All @@ -72,13 +72,23 @@ func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
func WithSessionIdManagerResolver(resolver SessionIdManagerResolver) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
if resolver == nil {
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{})
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&StatelessSessionIdManager{})
return
}
s.sessionIdManagerResolver = resolver
}
}

// WithStateful enables stateful session management using InsecureStatefulSessionIdManager.
// This requires sticky sessions in multi-instance deployments.
func WithStateful(stateful bool) StreamableHTTPOption {
return func(s *StreamableHTTPServer) {
if stateful {
s.sessionIdManagerResolver = NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{})
}
}
}

// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the
// server will send a heartbeat to the client through the GET connection, to keep
// the connection alive from being closed by the network infrastructure (e.g.
Expand Down Expand Up @@ -187,7 +197,7 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S
sessionTools: newSessionToolsStore(),
sessionLogLevels: newSessionLogLevelsStore(),
endpointPath: "/mcp",
sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&InsecureStatefulSessionIdManager{}),
sessionIdManagerResolver: NewDefaultSessionIdManagerResolver(&StatelessGeneratingSessionIdManager{}),
logger: util.DefaultLogger(),
sessionResources: newSessionResourcesStore(),
sessionResourceTemplates: newSessionResourceTemplatesStore(),
Expand Down Expand Up @@ -1244,7 +1254,7 @@ type DefaultSessionIdManagerResolver struct {
// NewDefaultSessionIdManagerResolver creates a new DefaultSessionIdManagerResolver with the given SessionIdManager
func NewDefaultSessionIdManagerResolver(manager SessionIdManager) *DefaultSessionIdManagerResolver {
if manager == nil {
manager = &InsecureStatefulSessionIdManager{}
manager = &StatelessSessionIdManager{}
}
return &DefaultSessionIdManagerResolver{manager: manager}
}
Expand All @@ -1270,6 +1280,30 @@ func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bo
return false, nil
}

// StatelessGeneratingSessionIdManager generates session IDs but doesn't validate them locally.
// This allows session IDs to be generated for clients while working across multiple instances.
type StatelessGeneratingSessionIdManager struct{}

func (s *StatelessGeneratingSessionIdManager) Generate() string {
return idPrefix + uuid.New().String()
}

func (s *StatelessGeneratingSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
// Only validate format, not existence - allows cross-instance operation
if !strings.HasPrefix(sessionID, idPrefix) {
return false, fmt.Errorf("invalid session id: %s", sessionID)
}
if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
return false, fmt.Errorf("invalid session id: %s", sessionID)
}
return false, nil
}

func (s *StatelessGeneratingSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
// No-op termination since we don't track sessions
return false, nil
}

// InsecureStatefulSessionIdManager generate id with uuid and tracks active sessions.
// It validates both format and existence of session IDs.
// For more secure session id, use a more complex generator, like a JWT.
Expand Down
140 changes: 101 additions & 39 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestStreamableHTTP_POST_InvalidContent(t *testing.T) {
func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) {
mcpServer := NewMCPServer("test-mcp-server", "1.0")
addSSETool(mcpServer)
server := NewTestStreamableHTTPServer(mcpServer)
server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true))
var sessionID string

t.Run("initialize", func(t *testing.T) {
Expand Down Expand Up @@ -595,6 +595,7 @@ func TestStreamableHttpResourceGet(t *testing.T) {

testServer := NewTestStreamableHTTPServer(
s,
WithStateful(true),
WithHTTPContextFunc(func(ctx context.Context, r *http.Request) context.Context {
session := ClientSessionFromContext(ctx)

Expand Down Expand Up @@ -1014,7 +1015,7 @@ func TestStreamableHTTP_SessionWithLogging(t *testing.T) {
})

mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks), WithLogging())
testServer := NewTestStreamableHTTPServer(mcpServer)
testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true))
defer testServer.Close()

// obtain a valid session ID first
Expand Down Expand Up @@ -1404,7 +1405,7 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) {
server := NewTestStreamableHTTPServer(mcpServer)
defer server.Close()

t.Run("Reject tool call with fake session ID", func(t *testing.T) {
t.Run("Accept tool call with properly formatted session ID", func(t *testing.T) {
toolCallRequest := map[string]any{
"jsonrpc": "2.0",
"id": 1,
Expand All @@ -1425,13 +1426,29 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) {
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected status 400, got %d", resp.StatusCode)
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}

body, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(body), "Invalid session ID") {
t.Errorf("Expected 'Invalid session ID' error, got: %s", string(body))
var response map[string]any
if err := json.Unmarshal(body, &response); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}

if result, ok := response["result"].(map[string]any); ok {
if content, ok := result["content"].([]any); ok && len(content) > 0 {
if textContent, ok := content[0].(map[string]any); ok {
if text, ok := textContent["text"].(string); ok {
// Should be a valid timestamp response
if text == "" {
t.Error("Expected non-empty timestamp response")
}
}
}
}
} else {
t.Errorf("Expected result in response, got: %s", string(body))
}
})

Expand Down Expand Up @@ -1508,22 +1525,45 @@ func TestStreamableHTTP_SessionValidation(t *testing.T) {
}
})

t.Run("Reject tool call with terminated session ID", func(t *testing.T) {
t.Run("Reject tool call with terminated session ID (stateful mode)", func(t *testing.T) {
mcpServer := NewMCPServer("test-server", "1.0.0")
// Use explicit stateful mode for this test since termination requires local tracking
server := NewTestStreamableHTTPServer(mcpServer, WithStateful(true))
defer server.Close()

// First, initialize a session
initRequest := map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-03-26",
"capabilities": map[string]any{
"tools": map[string]any{},
},
"clientInfo": map[string]any{
"name": "test-client",
"version": "1.0.0",
},
},
}

jsonBody, _ := json.Marshal(initRequest)
req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")

resp, err := server.Client().Do(req)
if err != nil {
t.Fatalf("Failed to initialize: %v", err)
t.Fatalf("Failed to initialize session: %v", err)
}
resp.Body.Close()

sessionID := resp.Header.Get(HeaderKeySessionID)
if sessionID == "" {
t.Fatal("Expected session ID in response header")
}
resp.Body.Close()

// Now terminate the session
req, _ = http.NewRequest(http.MethodDelete, server.URL, nil)
req.Header.Set(HeaderKeySessionID, sessionID)

Expand Down Expand Up @@ -1780,13 +1820,19 @@ func TestDefaultSessionIdManagerResolver(t *testing.T) {
t.Error("Expected resolver to return a non-nil manager")
}

// Test that the resolved manager works (generates valid session IDs)
// Test that the resolved manager works (stateless behavior)
sessionID := resolved.Generate()
if sessionID == "" {
t.Error("Expected default manager to generate non-empty session ID")
if sessionID != "" {
t.Error("Expected stateless manager to generate empty session ID")
}
if !strings.HasPrefix(sessionID, idPrefix) {
t.Error("Expected default manager to generate session ID with correct prefix")

// Test that validation accepts any session ID (stateless behavior)
isTerminated, err := resolved.Validate("any-session-id")
if err != nil {
t.Errorf("Expected stateless manager to accept any session ID, got error: %v", err)
}
if isTerminated {
t.Error("Expected stateless manager to not terminate sessions")
}
})
}
Expand Down Expand Up @@ -1865,17 +1911,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) {

server := NewStreamableHTTPServer(mcpServer, WithStateLess(false))

// Test that the default manager is still used (InsecureStatefulSessionIdManager)
// Test that the default manager is still used (StatelessGeneratingSessionIdManager)
req, _ := http.NewRequest("POST", "/test", nil)
resolved := server.sessionIdManagerResolver.ResolveSessionIdManager(req)

// Verify it's NOT a stateless manager
// Verify it's a generating manager (default behavior)
sessionID := resolved.Generate()
if sessionID == "" {
t.Error("Expected stateful manager when WithStateLess(false)")
t.Error("Expected generating manager to generate session ID by default")
}
if !strings.HasPrefix(sessionID, idPrefix) {
t.Error("Expected stateful session ID format")
t.Error("Expected generating manager to generate session ID with correct prefix")
}
})

Expand Down Expand Up @@ -1929,7 +1975,7 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) {
t.Run("WithSessionIdManagerResolver handles nil resolver defensively", func(t *testing.T) {
mcpServer := NewMCPServer("test-server", "1.0.0")

// This should not panic and should fall back to default behavior
// This should not panic and should fall back to StatelessSessionIdManager (safe default)
server := NewStreamableHTTPServer(mcpServer, WithSessionIdManagerResolver(nil))

req, _ := http.NewRequest("POST", "/test", nil)
Expand All @@ -1938,20 +1984,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) {
t.Error("Expected nil resolver to be replaced with default")
}

// Test that the resolved manager works (should be default stateful manager)
// Test that the resolved manager works (should be default stateless manager)
sessionID := resolved.Generate()
if sessionID == "" {
t.Error("Expected default manager to generate non-empty session ID")
}
if !strings.HasPrefix(sessionID, idPrefix) {
t.Error("Expected default manager to generate session ID with correct prefix")
if sessionID != "" {
t.Error("Expected default stateless manager to generate empty session ID")
}
})

t.Run("WithSessionIdManager handles nil manager defensively", func(t *testing.T) {
mcpServer := NewMCPServer("test-server", "1.0.0")

// This should not panic and should fall back to default behavior
// This should not panic and should fall back to StatelessSessionIdManager (safe default)
server := NewStreamableHTTPServer(mcpServer, WithSessionIdManager(nil))

req, _ := http.NewRequest("POST", "/test", nil)
Expand All @@ -1960,20 +2003,17 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) {
t.Error("Expected nil manager to be replaced with default")
}

// Test that the resolved manager works (should be default stateful manager)
// Test that the resolved manager works (should be default stateless manager)
sessionID := resolved.Generate()
if sessionID == "" {
t.Error("Expected default manager to generate non-empty session ID")
}
if !strings.HasPrefix(sessionID, idPrefix) {
t.Error("Expected default manager to generate session ID with correct prefix")
if sessionID != "" {
t.Error("Expected default stateless manager to generate empty session ID")
}
})

t.Run("Multiple nil options fall back safely", func(t *testing.T) {
mcpServer := NewMCPServer("test-server", "1.0.0")

// Chain multiple nil options - last one should win with safe fallback
// Chain multiple nil options - last one should win with StatelessSessionIdManager fallback
server := NewStreamableHTTPServer(mcpServer,
WithSessionIdManager(nil),
WithSessionIdManagerResolver(nil),
Expand All @@ -1985,10 +2025,10 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) {
t.Error("Expected chained nil options to fall back safely")
}

// Verify it generates valid session IDs
// Verify it uses stateless behavior (default)
sessionID := resolved.Generate()
if sessionID == "" {
t.Error("Expected fallback manager to generate non-empty session ID")
if sessionID != "" {
t.Error("Expected fallback stateless manager to generate empty session ID")
}
})

Expand Down Expand Up @@ -2021,6 +2061,28 @@ func TestSessionIdManagerResolver_Integration(t *testing.T) {
}
_ = resp.Body.Close()
})

t.Run("WithStateful enables stateful manager", func(t *testing.T) {
mcpServer := NewMCPServer("test-server", "1.0.0")
server := NewStreamableHTTPServer(mcpServer, WithStateful(true))

req, _ := http.NewRequest("POST", "/test", nil)
resolved := server.sessionIdManagerResolver.ResolveSessionIdManager(req)

sessionID := resolved.Generate()
if sessionID == "" {
t.Error("Expected stateful manager to generate session ID")
}
if !strings.HasPrefix(sessionID, idPrefix) {
t.Error("Expected stateful session ID format")
}

// Test that stateful manager validates session existence (unlike default)
_, err := resolved.Validate("unknown-session-id")
if err == nil {
t.Error("Expected stateful manager to reject unknown session ID")
}
})
}

func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) {
Expand All @@ -2039,7 +2101,7 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) {
})

mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
testServer := NewTestStreamableHTTPServer(mcpServer)
testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true))
defer testServer.Close()

// Send initialize request to register session
Expand Down Expand Up @@ -2110,7 +2172,7 @@ func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) {
return mcp.NewToolResultText("notification sent"), nil
})

testServer := NewTestStreamableHTTPServer(mcpServer)
testServer := NewTestStreamableHTTPServer(mcpServer, WithStateful(true))
defer testServer.Close()

// Initialize session
Expand Down
Loading