Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,12 +309,23 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
}
}

// For non-initialize requests, try to reuse existing registered session
var session *streamableHttpSession
if !isInitializeRequest {
if sessionValue, ok := s.server.sessions.Load(sessionID); ok {
if existingSession, ok := sessionValue.(*streamableHttpSession); ok {
session = existingSession
}
}
}

// Check if a persistent session exists (for sampling support), otherwise create ephemeral session
// Persistent sessions are created by GET (continuous listening) connections
var session *streamableHttpSession
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
session = persistentSession
if session == nil {
if sessionInterface, exists := s.activeSessions.Load(sessionID); exists {
if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok {
session = persistentSession
}
}
}

Expand Down Expand Up @@ -417,6 +428,21 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
s.logger.Errorf("Failed to write response: %v", err)
}
}

// Register session after successful initialization
// Only register if not already registered (e.g., by a GET connection)
if isInitializeRequest && sessionID != "" {
if _, exists := s.server.sessions.Load(sessionID); !exists {
// Store in activeSessions to prevent duplicate registration from GET
s.activeSessions.Store(sessionID, session)
// Register the session with the MCPServer for notification support
if err := s.server.RegisterSession(ctx, session); err != nil {
s.logger.Errorf("Failed to register POST session: %v", err)
s.activeSessions.Delete(sessionID)
// Don't fail the request, just log the error
}
}
}
}

func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
Expand Down
167 changes: 167 additions & 0 deletions server/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1314,3 +1314,170 @@ func TestInsecureStatefulSessionIdManager(t *testing.T) {
}
})
}

func TestStreamableHTTP_SendNotificationToSpecificClient(t *testing.T) {
t.Run("POST session registration enables SendNotificationToSpecificClient", func(t *testing.T) {
hooks := &Hooks{}
var registeredSessionID string
var mu sync.Mutex
var sessionRegistered sync.WaitGroup
sessionRegistered.Add(1)

hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) {
mu.Lock()
registeredSessionID = session.SessionID()
mu.Unlock()
sessionRegistered.Done()
})

mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks))
testServer := NewTestStreamableHTTPServer(mcpServer)
defer testServer.Close()

// Send initialize request to register session
resp, err := postJSON(testServer.URL, initRequest)
if err != nil {
t.Fatalf("Failed to send initialize request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected status 200, got %d", resp.StatusCode)
}

// Get session ID from response header
sessionID := resp.Header.Get(HeaderKeySessionID)
if sessionID == "" {
t.Fatal("Expected session ID in response header")
}

// Wait for session registration
done := make(chan struct{})
go func() {
sessionRegistered.Wait()
close(done)
}()

select {
case <-done:
// Session registered successfully
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for session registration")
}

mu.Lock()
if registeredSessionID != sessionID {
t.Errorf("Expected registered session ID %s, got %s", sessionID, registeredSessionID)
}
mu.Unlock()

// Now test SendNotificationToSpecificClient
err = mcpServer.SendNotificationToSpecificClient(sessionID, "test/notification", map[string]any{
"message": "test notification",
})
if err != nil {
t.Errorf("SendNotificationToSpecificClient failed: %v", err)
}
})

t.Run("Session reuse for non-initialize requests", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")

// Add a tool that sends a notification
mcpServer.AddTool(mcp.NewTool("notify_tool"), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
session := ClientSessionFromContext(ctx)
if session == nil {
return mcp.NewToolResultError("no session in context"), nil
}

// Try to send notification to specific client
server := ServerFromContext(ctx)
err := server.SendNotificationToSpecificClient(session.SessionID(), "tool/notification", map[string]any{
"from": "tool",
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("notification failed: %v", err)), nil
}

return mcp.NewToolResultText("notification sent"), nil
})

testServer := NewTestStreamableHTTPServer(mcpServer)
defer testServer.Close()

// Initialize session
resp, err := postJSON(testServer.URL, initRequest)
if err != nil {
t.Fatalf("Failed to send initialize request: %v", err)
}
sessionID := resp.Header.Get(HeaderKeySessionID)
resp.Body.Close()

if sessionID == "" {
t.Fatal("Expected session ID in response header")
}

// Give time for registration to complete
time.Sleep(100 * time.Millisecond)

// Call tool with the session ID
toolCallRequest := map[string]any{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": map[string]any{
"name": "notify_tool",
},
}

jsonBody, _ := json.Marshal(toolCallRequest)
req, _ := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewBuffer(jsonBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set(HeaderKeySessionID, sessionID)

resp, err = http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to call tool: %v", err)
}
defer resp.Body.Close()

bodyBytes, _ := io.ReadAll(resp.Body)
bodyStr := string(bodyBytes)

// Response might be SSE format if notification was sent
var toolResponse jsonRPCResponse
if strings.HasPrefix(bodyStr, "event: message") {
// Parse SSE format
lines := strings.Split(bodyStr, "\n")
for _, line := range lines {
if strings.HasPrefix(line, "data: ") {
jsonData := strings.TrimPrefix(line, "data: ")
if err := json.Unmarshal([]byte(jsonData), &toolResponse); err == nil {
break
}
}
}
} else {
if err := json.Unmarshal(bodyBytes, &toolResponse); err != nil {
t.Fatalf("Failed to unmarshal response: %v. Body: %s", err, bodyStr)
}
}

if toolResponse.Error != nil {
t.Errorf("Tool call failed: %v", toolResponse.Error)
}

// Verify the tool result indicates success
if result, ok := toolResponse.Result["content"].([]any); ok {
if len(result) > 0 {
if content, ok := result[0].(map[string]any); ok {
if text, ok := content["text"].(string); ok {
if text != "notification sent" {
t.Errorf("Expected 'notification sent', got %s", text)
}
}
}
}
}
})
}