diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7baf10c83..992f0ab3c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,28 @@ jobs: with: go-version-file: 'go.mod' - run: go test ./... -race + + coverage: + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + - name: Run tests with coverage + run: | + go test -coverprofile=coverage.txt -covermode=atomic $(go list ./... | grep -v '/examples/' | grep -v '/testdata' | grep -v '/mcptest' | grep -v '/server/internal/gen') + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 + with: + name: code-coverage + path: coverage.txt + - name: Generate coverage report + uses: fgrosse/go-coverage-report@v1.2.0 + if: github.event_name == 'pull_request' verify-codegen: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index b575ab67e..1d4dcd5cb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ .idea .opencode .claude +coverage.out +coverage.txt diff --git a/client/client.go b/client/client.go index 220786b68..192588b52 100644 --- a/client/client.go +++ b/client/client.go @@ -502,6 +502,19 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra } } + // Fix content parsing - HTTP transport unmarshals TextContent as map[string]any + // Use the helper function to properly handle content from different transports + for i := range params.Messages { + if contentMap, ok := params.Messages[i].Content.(map[string]any); ok { + // Parse the content map into a proper Content type + content, err := mcp.ParseContent(contentMap) + if err != nil { + return nil, fmt.Errorf("failed to parse content for message %d: %w", i, err) + } + params.Messages[i].Content = content + } + } + // Create the MCP request mcpRequest := mcp.CreateMessageRequest{ Request: mcp.Request{ diff --git a/client/client_edge_cases_test.go b/client/client_edge_cases_test.go new file mode 100644 index 000000000..0f1c86b90 --- /dev/null +++ b/client/client_edge_cases_test.go @@ -0,0 +1,188 @@ +package client + +import ( + "context" + "testing" + "time" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestClient_UnsupportedProtocolVersionResponse tests that client rejects unsupported protocol versions +func TestClient_UnsupportedProtocolVersionResponse(t *testing.T) { + // Create mock transport + mockTrans := newMockTransport() + + // Create client + client := &Client{ + transport: mockTrans, + } + + ctx := context.Background() + err := client.Start(ctx) + require.NoError(t, err) + + // Server responds with an unsupported/invalid protocol version + initResponse := transport.NewJSONRPCResultResponse( + mcp.NewRequestId(1), + []byte(`{"protocolVersion":"9999-99-99","capabilities":{},"serverInfo":{"name":"test-server","version":"1.0.0"}}`), + ) + + go func() { + mockTrans.responseChan <- initResponse + }() + + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + }, + } + + _, err = client.Initialize(ctx, initRequest) + require.Error(t, err) + + // Should be an UnsupportedProtocolVersionError + var unsupportedErr mcp.UnsupportedProtocolVersionError + assert.ErrorAs(t, err, &unsupportedErr) + assert.Equal(t, "9999-99-99", unsupportedErr.Version) +} + +// TestClient_OperationsBeforeInitialize tests operations fail before initialization +func TestClient_OperationsBeforeInitialize(t *testing.T) { + mockTrans := newMockTransport() + client := &Client{ + transport: mockTrans, + } + + ctx := context.Background() + err := client.Start(ctx) + require.NoError(t, err) + + // Try to send request before initialization + err = client.Ping(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "not initialized") + + // List tools should also fail + _, err = client.ListTools(ctx, mcp.ListToolsRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not initialized") + + // List resources should also fail + _, err = client.ListResources(ctx, mcp.ListResourcesRequest{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "not initialized") +} + +// TestClient_NotificationHandlers tests notification handler behavior +func TestClient_NotificationHandlers(t *testing.T) { + t.Run("multiple handlers called in order", func(t *testing.T) { + mockTrans := newMockTransport() + client := &Client{ + transport: mockTrans, + } + + ctx := context.Background() + err := client.Start(ctx) + require.NoError(t, err) + + var callOrder []int + var handlerCalls int + + // Register multiple handlers + for i := 0; i < 3; i++ { + handlerID := i + client.OnNotification(func(notification mcp.JSONRPCNotification) { + callOrder = append(callOrder, handlerID) + handlerCalls++ + }) + } + + // Simulate notification via the handler + notif := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "test-method", + }, + } + + // Manually trigger the handlers we registered on the client + // Access them through the read lock + client.notifyMu.RLock() + handlers := make([]func(mcp.JSONRPCNotification), len(client.notifications)) + copy(handlers, client.notifications) + client.notifyMu.RUnlock() + + for _, h := range handlers { + h(notif) + } + + // Wait a bit for handlers to execute + time.Sleep(50 * time.Millisecond) + + // All handlers should have been called in order + assert.Equal(t, []int{0, 1, 2}, callOrder) + assert.Equal(t, 3, handlerCalls) + }) +} + +// TestClient_GetSessionId tests session ID retrieval +func TestClient_GetSessionId(t *testing.T) { + mockTrans := newMockTransport() + client := &Client{ + transport: mockTrans, + } + + // Should return the transport's session ID + sessionID := client.GetSessionId() + assert.Equal(t, "mock-session-id", sessionID) +} + +// TestClient_IsInitialized tests initialization state tracking +func TestClient_IsInitialized(t *testing.T) { + mockTrans := newMockTransport() + client := &Client{ + transport: mockTrans, + } + + // Should not be initialized initially + assert.False(t, client.IsInitialized()) + + ctx := context.Background() + err := client.Start(ctx) + require.NoError(t, err) + + // Still not initialized after Start + assert.False(t, client.IsInitialized()) + + // Initialize + initResponse := transport.NewJSONRPCResultResponse( + mcp.NewRequestId(1), + []byte(`{"protocolVersion":"2025-03-26","capabilities":{},"serverInfo":{"name":"test-server","version":"1.0.0"}}`), + ) + go func() { + mockTrans.responseChan <- initResponse + mockTrans.responseChan <- transport.NewJSONRPCResultResponse(mcp.NewRequestId(2), []byte(`{}`)) + }() + + _, err = client.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + }, + }) + require.NoError(t, err) + + // Should be initialized now + assert.True(t, client.IsInitialized()) +} diff --git a/client/transport/interface.go b/client/transport/interface.go index a877e49dc..b00210e52 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -70,4 +70,3 @@ type JSONRPCResponse struct { Result json.RawMessage `json:"result,omitempty"` Error *mcp.JSONRPCErrorDetails `json:"error,omitempty"` } - diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 69557668c..26c1d73a0 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -9,7 +9,10 @@ import ( "io" "os" "os/exec" + "strings" "sync" + "syscall" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/util" @@ -24,21 +27,22 @@ type Stdio struct { args []string env []string - cmd *exec.Cmd - cmdFunc CommandFunc - stdin io.WriteCloser - stdout *bufio.Scanner - stderr io.ReadCloser - responses map[string]chan *JSONRPCResponse - mu sync.RWMutex - done chan struct{} - onNotification func(mcp.JSONRPCNotification) - notifyMu sync.RWMutex - onRequest RequestHandler - requestMu sync.RWMutex - ctx context.Context - ctxMu sync.RWMutex - logger util.Logger + cmd *exec.Cmd + cmdFunc CommandFunc + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + responses map[string]chan *JSONRPCResponse + mu sync.RWMutex + done chan struct{} + onNotification func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + onRequest RequestHandler + requestMu sync.RWMutex + ctx context.Context + ctxMu sync.RWMutex + logger util.Logger + terminateDuration time.Duration } // StdioOption defines a function that configures a Stdio transport instance. @@ -66,19 +70,27 @@ func WithCommandLogger(logger util.Logger) StdioOption { } } +// WithTerminateDuration sets the duration to wait for graceful shutdown before sending SIGTERM. +func WithTerminateDuration(duration time.Duration) StdioOption { + return func(s *Stdio) { + s.terminateDuration = duration + } +} + // NewIO returns a new stdio-based transport using existing input, output, and // logging streams instead of spawning a subprocess. // This is useful for testing and simulating client behavior. func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio { return &Stdio{ stdin: output, - stdout: bufio.NewScanner(input), + stdout: bufio.NewReader(input), stderr: logging, - responses: make(map[string]chan *JSONRPCResponse), - done: make(chan struct{}), - ctx: context.Background(), - logger: util.DefaultLogger(), + responses: make(map[string]chan *JSONRPCResponse), + done: make(chan struct{}), + ctx: context.Background(), + logger: util.DefaultLogger(), + terminateDuration: 5 * time.Second, // Default 5 second timeout } } @@ -109,10 +121,11 @@ func NewStdioWithOptions( args: args, env: env, - responses: make(map[string]chan *JSONRPCResponse), - done: make(chan struct{}), - ctx: context.Background(), - logger: util.DefaultLogger(), + responses: make(map[string]chan *JSONRPCResponse), + done: make(chan struct{}), + ctx: context.Background(), + logger: util.DefaultLogger(), + terminateDuration: 5 * time.Second, // Default 5 second timeout } for _, opt := range opts { @@ -180,7 +193,7 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { c.cmd = cmd c.stdin = stdin c.stderr = stderr - c.stdout = bufio.NewScanner(stdout) + c.stdout = bufio.NewReader(stdout) if err := cmd.Start(); err != nil { return fmt.Errorf("failed to start command: %w", err) @@ -189,8 +202,10 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { return nil } -// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. -// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +// Close closes the input stream to the child process, and awaits normal +// termination of the command. If the command does not exit, it is signalled to +// terminate, and then eventually killed. This follows the MCP specification +// for stdio transport shutdown. func (c *Stdio) Close() error { select { case <-c.done: @@ -200,6 +215,8 @@ func (c *Stdio) Close() error { // cancel all in-flight request close(c.done) + // For the stdio transport, the client SHOULD initiate shutdown by: + // First, closing the input stream to the child process (the server) if c.stdin != nil { if err := c.stdin.Close(); err != nil { return fmt.Errorf("failed to close stdin: %w", err) @@ -212,7 +229,40 @@ func (c *Stdio) Close() error { } if c.cmd != nil { - return c.cmd.Wait() + resChan := make(chan error, 1) + go func() { + resChan <- c.cmd.Wait() + }() + + // Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time + wait := func() (error, bool) { + select { + case err := <-resChan: + return err, true + case <-time.After(c.terminateDuration): + } + return nil, false + } + + if err, ok := wait(); ok { + return err + } + + // Note the condition here: if sending SIGTERM fails, don't wait and just + // move on to SIGKILL. + if err := c.cmd.Process.Signal(syscall.SIGTERM); err == nil { + if err, ok := wait(); ok { + return err + } + } + // Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM + if err := c.cmd.Process.Kill(); err != nil { + return err + } + if err, ok := wait(); ok { + return err + } + return fmt.Errorf("unresponsive subprocess") } return nil @@ -251,15 +301,15 @@ func (c *Stdio) readResponses() { case <-c.done: return default: - if !c.stdout.Scan() { - err := c.stdout.Err() - if err != nil && !errors.Is(err, context.Canceled) { + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF && !errors.Is(err, context.Canceled) { c.logger.Errorf("Error reading from stdout: %v", err) } return } - line := c.stdout.Text() + line = strings.TrimRight(line, "\r\n") // First try to parse as a generic message to check for ID field var baseMessage struct { JSONRPC string `json:"jsonrpc"` diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 56867acf4..04bf18a65 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -703,3 +703,120 @@ func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) { require.NotNil(t, stdio) require.True(t, configured, "option was not applied") } + +func TestStdio_LargeMessages(t *testing.T) { + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + + if runtime.GOOS == "windows" { + os.Remove(mockServerPath) + mockServerPath += ".exe" + } + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) + } + defer os.Remove(mockServerPath) + + stdio := NewStdio(mockServerPath, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + startErr := stdio.Start(ctx) + if startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) + } + defer stdio.Close() + + testCases := []struct { + name string + dataSize int + description string + }{ + {"SmallMessage_1KB", 1024, "Small message under scanner default limit"}, + {"MediumMessage_32KB", 32 * 1024, "Medium message under scanner default limit"}, + {"AtLimit_64KB", 64 * 1024, "Message at default scanner limit"}, + {"OverLimit_128KB", 128 * 1024, "Message over default scanner limit - would fail with Scanner"}, + {"Large_256KB", 256 * 1024, "Large message well over scanner limit"}, + {"VeryLarge_1MB", 1024 * 1024, "Very large message"}, + {"Huge_5MB", 5 * 1024 * 1024, "Huge message to stress test"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + largeString := generateRandomString(tc.dataSize) + + params := map[string]any{ + "data": largeString, + "size": len(largeString), + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: "debug/echo", + Params: params, + } + + response, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed for %s: %v", tc.description, err) + } + + var result struct { + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result for %s: %v", tc.description, err) + } + + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + + returnedData, ok := result.Params["data"].(string) + if !ok { + t.Fatalf("Expected data to be string, got %T", result.Params["data"]) + } + + if returnedData != largeString { + t.Errorf("Data mismatch for %s: expected length %d, got length %d", + tc.description, len(largeString), len(returnedData)) + } + + returnedSize, ok := result.Params["size"].(float64) + if !ok { + t.Fatalf("Expected size to be number, got %T", result.Params["size"]) + } + + if int(returnedSize) != tc.dataSize { + t.Errorf("Size mismatch for %s: expected %d, got %d", + tc.description, tc.dataSize, int(returnedSize)) + } + + t.Logf("Successfully handled %s message of size %d bytes", tc.name, tc.dataSize) + }) + } +} + +func generateRandomString(size int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 " + + b := make([]byte, size) + for i := range b { + b[i] = charset[i%len(charset)] + } + return string(b) +} diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 9d5218139..6339b6110 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -594,10 +594,14 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool { func (c *StreamableHTTP) listenForever(ctx context.Context) { c.logger.Infof("listening to server forever") for { - connectCtx, cancel := context.WithCancel(ctx) - err := c.createGETConnectionToServer(connectCtx) - cancel() - + // Use the original context for continuous listening - no per-iteration timeout + // The SSE connection itself will detect disconnections via the underlying HTTP transport, + // and the context cancellation will propagate from the parent to stop listening gracefully. + // We don't add an artificial timeout here because: + // 1. Persistent SSE connections are meant to stay open indefinitely + // 2. Network-level timeouts and keep-alives handle connection health + // 3. Context cancellation (user-initiated or system shutdown) provides clean shutdown + err := c.createGETConnectionToServer(ctx) if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") diff --git a/client/transport/utils.go b/client/transport/utils.go index d36d74722..2d0e847da 100644 --- a/client/transport/utils.go +++ b/client/transport/utils.go @@ -23,4 +23,4 @@ func NewJSONRPCResultResponse(id mcp.RequestId, result json.RawMessage) *JSONRPC ID: id, Result: result, } -} \ No newline at end of file +} diff --git a/e2e/sampling_http_test.go b/e2e/sampling_http_test.go new file mode 100644 index 000000000..fe1914ab3 --- /dev/null +++ b/e2e/sampling_http_test.go @@ -0,0 +1,549 @@ +package e2e + +import ( + "context" + "fmt" + "log" + "net" + "net/http" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// TestSamplingHandler implements client.SamplingHandler for e2e testing +type TestSamplingHandler struct { + responses map[string]string + mutex sync.RWMutex +} + +func NewTestSamplingHandler() *TestSamplingHandler { + return &TestSamplingHandler{ + responses: make(map[string]string), + } +} + +func (h *TestSamplingHandler) SetResponse(question, response string) { + h.mutex.Lock() + defer h.mutex.Unlock() + h.responses[question] = response +} + +func (h *TestSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + log.Printf("[TestSamplingHandler] *** CLIENT RECEIVED SAMPLING REQUEST *** with %d messages", len(request.Messages)) + + if len(request.Messages) == 0 { + log.Printf("[TestSamplingHandler] ERROR: no messages provided") + return nil, fmt.Errorf("no messages provided") + } + + // Get the last user message + lastMessage := request.Messages[len(request.Messages)-1] + userText := "" + if textContent, ok := lastMessage.Content.(mcp.TextContent); ok { + userText = textContent.Text + } + + log.Printf("[TestSamplingHandler] CLIENT processing user text: '%s'", userText) + + h.mutex.RLock() + response, exists := h.responses[userText] + h.mutex.RUnlock() + + if !exists { + response = fmt.Sprintf("Test response to: '%s'", userText) + } + + log.Printf("[TestSamplingHandler] CLIENT Question: %s -> Response: %s", userText, response) + + result := &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: response, + }, + }, + Model: "test-model-v1", + StopReason: "endTurn", + } + + log.Printf("[TestSamplingHandler] *** CLIENT SENDING SAMPLING RESPONSE *** with model: %s", result.Model) + return result, nil +} + +// getAvailablePort finds an available port for testing +func getAvailablePort() (int, error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, err + } + defer listener.Close() + return listener.Addr().(*net.TCPAddr).Port, nil +} + +func TestSamplingHTTPE2E(t *testing.T) { + if testing.Short() { + t.Skip("Skipping e2e test in short mode") + } + + log.Printf("[E2E Test] Starting Sampling HTTP E2E Test") + + // Get available port for HTTP server + port, err := getAvailablePort() + if err != nil { + t.Fatalf("Failed to get available port: %v", err) + } + + serverURL := fmt.Sprintf("http://localhost:%d", port) + serverAddr := fmt.Sprintf(":%d", port) + + // Create test sampling handler with predefined responses + samplingHandler := NewTestSamplingHandler() + samplingHandler.SetResponse("What is the capital of France?", "Paris is the capital of France.") + samplingHandler.SetResponse("What is 2+2?", "2+2 equals 4.") + + // Create MCP server with sampling capability + mcpServer := server.NewMCPServer("e2e-sampling-server", "1.0.0") + mcpServer.EnableSampling() + + // Add tool that uses sampling - this is the "question" tool + mcpServer.AddTool(mcp.Tool{ + Name: "question", + Description: "Ask a question and get an answer using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + log.Printf("[E2E Test] Tool handler processing question: %s", question) + + // Create sampling request to send back to client + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + MaxTokens: 500, + Temperature: 0.7, + }, + } + + log.Printf("[E2E Test] *** SERVER SENDING SAMPLING REQUEST *** for question: %s", question) + + // Request sampling from client with timeout + samplingCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + serverFromCtx := server.ServerFromContext(ctx) + if serverFromCtx == nil { + log.Printf("[E2E Test] ERROR: No server in context") + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Error: No server in context", + }, + }, + IsError: true, + }, nil + } + + log.Printf("[E2E Test] SERVER calling RequestSampling...") + + // Check what session we have + session := server.ClientSessionFromContext(ctx) + if session != nil { + log.Printf("[E2E Test] SERVER session ID: %s", session.SessionID()) + } else { + log.Printf("[E2E Test] SERVER ERROR: No session in context") + } + + // This creates the sampling request to the client + result, err := serverFromCtx.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + log.Printf("[E2E Test] *** SERVER SAMPLING REQUEST FAILED ***: %v", err) + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + log.Printf("[E2E Test] *** SERVER RECEIVED SAMPLING RESPONSE ***, model: %s", result.Model) + + // Extract response text + var responseText string + if textContent, ok := result.Content.(mcp.TextContent); ok { + responseText = textContent.Text + } else { + responseText = fmt.Sprintf("%v", result.Content) + } + + // Return sampling response as the question tool response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Answer: %s (Model: %s)", responseText, result.Model), + }, + }, + }, nil + }) + + // Start HTTP server + httpServer := server.NewStreamableHTTPServer(mcpServer) + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + log.Printf("[E2E Test] Starting HTTP server on %s", serverAddr) + if err := httpServer.Start(serverAddr); err != nil && err != http.ErrServerClosed { + log.Printf("[E2E Test] Server error: %v", err) + } + }() + + // Wait for server to start and be ready + time.Sleep(2 * time.Second) + + // Create HTTP transport for client connection to server - enable continuous listening for sampling + httpTransport, err := transport.NewStreamableHTTP(serverURL+"/mcp", transport.WithContinuousListening()) + if err != nil { + t.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + log.Printf("[E2E Test] HTTP transport created, will connect to: %s", serverURL+"/mcp") + + // Create HTTP client with sampling handler - this is the actual client connecting over HTTP + httpClient := client.NewClient(httpTransport, client.WithSamplingHandler(samplingHandler)) + defer httpClient.Close() + + // Start the HTTP client + ctx := context.Background() + if err := httpClient.Start(ctx); err != nil { + t.Fatalf("Failed to start HTTP client: %v", err) + } + + // Initialize MCP session over HTTP + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "e2e-http-test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by WithSamplingHandler + }, + }, + } + + initResponse, err := httpClient.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize HTTP session: %v", err) + } + + log.Printf("[E2E Test] HTTP session initialized. Server capabilities: %+v", initResponse.Capabilities) + log.Printf("[E2E Test] Client session ID: %s", httpTransport.GetSessionId()) + + // Verify sampling capability is supported + if initResponse.Capabilities.Sampling == nil { + t.Fatal("Server should support sampling capability") + } + + // Wait a bit more for continuous listening to establish + log.Printf("[E2E Test] Waiting for continuous listening connection to be established...") + time.Sleep(3 * time.Second) + log.Printf("[E2E Test] Continuous listening should now be established, proceeding with tests...") + + // Test Case 1: HTTP client calls "question" tool - complete e2e flow + t.Run("HTTPClientCallsQuestionTool", func(t *testing.T) { + log.Printf("[E2E Test] HTTP client calling 'question' tool") + + // Client calls "question" tool over HTTP + result, err := httpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "question", + Arguments: map[string]any{ + "question": "What is the capital of France?", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to call question tool over HTTP: %v", err) + } + + if result.IsError { + t.Fatalf("Question tool returned error: %v", result.Content) + } + + if len(result.Content) == 0 { + t.Fatal("Question tool result should have content") + } + + // Verify response content + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("Expected TextContent, got %T", result.Content[0]) + } + + responseText := textContent.Text + log.Printf("[E2E Test] Question tool response over HTTP: %s", responseText) + + // Verify the complete flow worked: client->server->sampling_request->client->sampling_response->server->tool_response->client + if !strings.Contains(responseText, "Paris is the capital of France") { + t.Errorf("Expected response to contain 'Paris is the capital of France', got: %s", responseText) + } + + if !strings.Contains(responseText, "test-model-v1") { + t.Errorf("Expected response to contain model name, got: %s", responseText) + } + }) + + // Test Case 2: Multiple HTTP sampling requests + t.Run("MultipleHTTPSamplingRequests", func(t *testing.T) { + questions := []string{ + "What is 2+2?", + "What is the capital of France?", + } + + expectedAnswers := []string{ + "2+2 equals 4", + "Paris is the capital of France", + } + + for i, question := range questions { + log.Printf("[E2E Test] HTTP client calling question tool with: %s", question) + result, err := httpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "question", + Arguments: map[string]any{ + "question": question, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to call question tool for '%s': %v", question, err) + } + + if result.IsError { + t.Fatalf("Question tool returned error for '%s': %v", question, result.Content) + } + + if len(result.Content) == 0 { + t.Fatal("Question tool result should have content") + } + + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("Expected TextContent, got %T", result.Content[0]) + } + + responseText := textContent.Text + log.Printf("[E2E Test] HTTP Response for '%s': %s", question, responseText) + + if !strings.Contains(responseText, expectedAnswers[i]) { + t.Errorf("Expected response to contain '%s', got: %s", expectedAnswers[i], responseText) + } + } + }) + + // Cleanup + httpClient.Close() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + t.Logf("Server shutdown error: %v", err) + } + + <-serverDone + log.Printf("[E2E Test] HTTP E2E test completed successfully") +} + +// TestSamplingHTTPBasic creates a basic HTTP sampling test +func TestSamplingHTTPBasic(t *testing.T) { + if testing.Short() { + t.Skip("Skipping HTTP test in short mode") + } + + log.Printf("[E2E HTTP Test] Starting basic HTTP sampling test") + + // Get available port + port, err := getAvailablePort() + if err != nil { + t.Fatalf("Failed to get available port: %v", err) + } + + serverURL := fmt.Sprintf("http://localhost:%d", port) + serverAddr := fmt.Sprintf(":%d", port) + + // Create MCP server with sampling capability + mcpServer := server.NewMCPServer("e2e-http-server", "1.0.0") + mcpServer.EnableSampling() + + // Add simple echo tool (no sampling needed) + mcpServer.AddTool(mcp.Tool{ + Name: "echo", + Description: "Echo back the input message", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + message := request.GetString("message", "") + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s", message), + }, + }, + }, nil + }) + + // Start HTTP server + httpServer := server.NewStreamableHTTPServer(mcpServer) + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + log.Printf("[E2E HTTP Test] Starting server on %s", serverAddr) + if err := httpServer.Start(serverAddr); err != nil && err != http.ErrServerClosed { + log.Printf("[E2E HTTP Test] Server error: %v", err) + } + }() + + // Wait for server to start + time.Sleep(500 * time.Millisecond) + + // Create HTTP transport (no continuous listening for simple test) + httpTransport, err := transport.NewStreamableHTTP(serverURL + "/mcp") + if err != nil { + t.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + // Create simple client (no sampling handler) + mcpClient := client.NewClient(httpTransport) + + // Start client + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err = mcpClient.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize MCP session + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: "e2e-http-test-client", + Version: "1.0.0", + }, + }, + } + + initResponse, err := mcpClient.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize MCP session: %v", err) + } + + log.Printf("[E2E HTTP Test] Session initialized. Server capabilities: %+v", initResponse.Capabilities) + + // Test basic tool call over HTTP + result, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{ + "message": "Hello HTTP MCP!", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to call echo tool: %v", err) + } + + if result.IsError { + t.Fatalf("Tool returned error: %v", result.Content) + } + + if len(result.Content) == 0 { + t.Fatal("Tool result should have content") + } + + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("Expected TextContent, got %T", result.Content[0]) + } + + responseText := textContent.Text + log.Printf("[E2E HTTP Test] Tool response: %s", responseText) + + if !strings.Contains(responseText, "Hello HTTP MCP!") { + t.Errorf("Expected response to contain 'Hello HTTP MCP!', got: %s", responseText) + } + + // Cleanup + mcpClient.Close() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + t.Logf("Server shutdown error: %v", err) + } + + <-serverDone + log.Printf("[E2E HTTP Test] HTTP test completed successfully") +} + +// TestMain sets up test environment +func TestMain(m *testing.M) { + // Enable debug logging for better visibility during tests + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + + code := m.Run() + os.Exit(code) +} diff --git a/examples/sampling_server/main.go b/examples/sampling_server/main.go index ea887c588..a2ca13baf 100644 --- a/examples/sampling_server/main.go +++ b/examples/sampling_server/main.go @@ -83,7 +83,7 @@ func main() { Content: []mcp.Content{ mcp.TextContent{ Type: "text", - Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, getTextFromContent(result.Content)), + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, mcp.GetTextFromContent(result.Content)), }, }, }, nil @@ -125,21 +125,3 @@ func main() { log.Fatalf("Server error: %v", err) } } - -// Helper function to extract text from content -func getTextFromContent(content any) string { - switch c := content.(type) { - case mcp.TextContent: - return c.Text - case map[string]any: - // Handle JSON unmarshaled content - if text, ok := c["text"].(string); ok { - return text - } - return fmt.Sprintf("%v", content) - case string: - return c - default: - return fmt.Sprintf("%v", content) - } -} diff --git a/mcp/errors.go b/mcp/errors.go index c29024282..aead24744 100644 --- a/mcp/errors.go +++ b/mcp/errors.go @@ -83,4 +83,3 @@ func (e *JSONRPCErrorDetails) AsError() error { return err } - diff --git a/mcp/errors_additional_test.go b/mcp/errors_additional_test.go new file mode 100644 index 000000000..b8a43fa02 --- /dev/null +++ b/mcp/errors_additional_test.go @@ -0,0 +1,218 @@ +package mcp + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnsupportedProtocolVersionError_Is(t *testing.T) { + err1 := UnsupportedProtocolVersionError{Version: "1.0"} + err2 := UnsupportedProtocolVersionError{Version: "2.0"} + + t.Run("matches same type", func(t *testing.T) { + assert.True(t, err1.Is(UnsupportedProtocolVersionError{})) + assert.True(t, err2.Is(UnsupportedProtocolVersionError{Version: "different"})) + }) + + t.Run("does not match different type", func(t *testing.T) { + assert.False(t, err1.Is(errors.New("different error"))) + assert.False(t, err1.Is(ErrMethodNotFound)) + }) +} + +func TestIsUnsupportedProtocolVersion(t *testing.T) { + t.Run("returns true for UnsupportedProtocolVersionError", func(t *testing.T) { + err := UnsupportedProtocolVersionError{Version: "1.0"} + assert.True(t, IsUnsupportedProtocolVersion(err)) + }) + + t.Run("returns false for other errors", func(t *testing.T) { + assert.False(t, IsUnsupportedProtocolVersion(errors.New("other error"))) + assert.False(t, IsUnsupportedProtocolVersion(ErrMethodNotFound)) + }) + + t.Run("returns false for wrapped errors", func(t *testing.T) { + // Create a wrapped error - IsUnsupportedProtocolVersion checks direct type, not wrapped + err := UnsupportedProtocolVersionError{Version: "1.0"} + wrapped := errors.New("wrapped: " + err.Error()) + assert.False(t, IsUnsupportedProtocolVersion(wrapped)) + }) +} + +func TestJSONRPCErrorDetails_AsError_EmptyMessage(t *testing.T) { + t.Run("with empty message", func(t *testing.T) { + details := JSONRPCErrorDetails{ + Code: METHOD_NOT_FOUND, + Message: "", + } + + err := details.AsError() + // Should return the sentinel error when message is empty + assert.Equal(t, ErrMethodNotFound, err) + }) + + t.Run("with message matching sentinel", func(t *testing.T) { + details := JSONRPCErrorDetails{ + Code: PARSE_ERROR, + Message: "parse error", + } + + err := details.AsError() + assert.Equal(t, ErrParseError, err) + }) +} + +func TestJSONRPCErrorDetails_AsError_AllCodes(t *testing.T) { + tests := []struct { + name string + code int + message string + sentinel error + shouldMatch bool + }{ + { + name: "PARSE_ERROR", + code: PARSE_ERROR, + message: "custom parse error", + sentinel: ErrParseError, + shouldMatch: true, + }, + { + name: "INVALID_REQUEST", + code: INVALID_REQUEST, + message: "custom invalid request", + sentinel: ErrInvalidRequest, + shouldMatch: true, + }, + { + name: "METHOD_NOT_FOUND", + code: METHOD_NOT_FOUND, + message: "custom method not found", + sentinel: ErrMethodNotFound, + shouldMatch: true, + }, + { + name: "INVALID_PARAMS", + code: INVALID_PARAMS, + message: "custom invalid params", + sentinel: ErrInvalidParams, + shouldMatch: true, + }, + { + name: "INTERNAL_ERROR", + code: INTERNAL_ERROR, + message: "custom internal error", + sentinel: ErrInternalError, + shouldMatch: true, + }, + { + name: "REQUEST_INTERRUPTED", + code: REQUEST_INTERRUPTED, + message: "custom interrupted", + sentinel: ErrRequestInterrupted, + shouldMatch: true, + }, + { + name: "RESOURCE_NOT_FOUND", + code: RESOURCE_NOT_FOUND, + message: "custom resource not found", + sentinel: ErrResourceNotFound, + shouldMatch: true, + }, + { + name: "unknown code", + code: -99999, + message: "unknown error", + sentinel: nil, + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + details := JSONRPCErrorDetails{ + Code: tt.code, + Message: tt.message, + } + + err := details.AsError() + require.NotNil(t, err) + + if tt.shouldMatch { + assert.True(t, errors.Is(err, tt.sentinel)) + // Custom message should be wrapped + assert.Contains(t, err.Error(), tt.message) + } else { + // Unknown codes just return the message + assert.Equal(t, tt.message, err.Error()) + } + }) + } +} + +func TestErrorChaining_WithAs(t *testing.T) { + t.Run("errors.As does not work with wrapped sentinel", func(t *testing.T) { + details := &JSONRPCErrorDetails{ + Code: METHOD_NOT_FOUND, + Message: "Method 'foo' not found", + } + + err := details.AsError() + + // Since we wrap with fmt.Errorf, errors.As won't find the exact type + // but errors.Is will work because of the %w verb + assert.True(t, errors.Is(err, ErrMethodNotFound)) + }) +} + +func TestSentinelErrors_Comparison(t *testing.T) { + // Ensure all sentinel errors are distinct + sentinels := []error{ + ErrParseError, + ErrInvalidRequest, + ErrMethodNotFound, + ErrInvalidParams, + ErrInternalError, + ErrRequestInterrupted, + ErrResourceNotFound, + } + + for i, err1 := range sentinels { + for j, err2 := range sentinels { + if i == j { + assert.True(t, errors.Is(err1, err2), "Same sentinel should match itself") + } else { + assert.False(t, errors.Is(err1, err2), "Different sentinels should not match") + } + } + } +} + +func TestUnsupportedProtocolVersionError_Error(t *testing.T) { + err := UnsupportedProtocolVersionError{Version: "3.0"} + assert.Equal(t, `unsupported protocol version: "3.0"`, err.Error()) +} + +func TestJSONRPCErrorDetails_WithData(t *testing.T) { + details := &JSONRPCErrorDetails{ + Code: INVALID_PARAMS, + Message: "Invalid parameter 'foo'", + Data: map[string]any{ + "param": "foo", + "expected": "string", + "got": "number", + }, + } + + err := details.AsError() + + // The error should still wrap properly + assert.True(t, errors.Is(err, ErrInvalidParams)) + assert.Contains(t, err.Error(), "Invalid parameter 'foo'") + + // Data is not included in the error string, but it's preserved in the details + assert.NotNil(t, details.Data) +} diff --git a/mcp/errors_test.go b/mcp/errors_test.go index 22556da12..00ce4dc53 100644 --- a/mcp/errors_test.go +++ b/mcp/errors_test.go @@ -112,7 +112,6 @@ func TestJSONRPCErrorDetails_AsError_WithPointer(t *testing.T) { require.True(t, errors.Is(result, ErrMethodNotFound)) } - func TestSentinelErrors(t *testing.T) { t.Parallel() diff --git a/mcp/prompts_test.go b/mcp/prompts_test.go new file mode 100644 index 000000000..0c35ea311 --- /dev/null +++ b/mcp/prompts_test.go @@ -0,0 +1,287 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewPrompt(t *testing.T) { + tests := []struct { + name string + prompt Prompt + expected Prompt + }{ + { + name: "basic prompt", + prompt: NewPrompt("test-prompt"), + expected: Prompt{ + Name: "test-prompt", + }, + }, + { + name: "prompt with description", + prompt: NewPrompt("test-prompt", + WithPromptDescription("A test prompt")), + expected: Prompt{ + Name: "test-prompt", + Description: "A test prompt", + }, + }, + { + name: "prompt with single argument", + prompt: NewPrompt("test-prompt", + WithPromptDescription("Test prompt with arg"), + WithArgument("query", + ArgumentDescription("Search query"), + RequiredArgument())), + expected: Prompt{ + Name: "test-prompt", + Description: "Test prompt with arg", + Arguments: []PromptArgument{ + { + Name: "query", + Description: "Search query", + Required: true, + }, + }, + }, + }, + { + name: "prompt with multiple arguments", + prompt: NewPrompt("search-prompt", + WithPromptDescription("Search with filters"), + WithArgument("query", + ArgumentDescription("Search query"), + RequiredArgument()), + WithArgument("limit", + ArgumentDescription("Max results")), + WithArgument("offset", + ArgumentDescription("Starting position"))), + expected: Prompt{ + Name: "search-prompt", + Description: "Search with filters", + Arguments: []PromptArgument{ + { + Name: "query", + Description: "Search query", + Required: true, + }, + { + Name: "limit", + Description: "Max results", + Required: false, + }, + { + Name: "offset", + Description: "Starting position", + Required: false, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected.Name, tt.prompt.Name) + assert.Equal(t, tt.expected.Description, tt.prompt.Description) + assert.Equal(t, tt.expected.Arguments, tt.prompt.Arguments) + }) + } +} + +func TestPromptGetName(t *testing.T) { + prompt := NewPrompt("my-prompt", + WithPromptDescription("Test prompt")) + + assert.Equal(t, "my-prompt", prompt.GetName()) +} + +func TestPromptJSONMarshaling(t *testing.T) { + tests := []struct { + name string + prompt Prompt + }{ + { + name: "simple prompt", + prompt: NewPrompt("simple", + WithPromptDescription("A simple prompt")), + }, + { + name: "prompt with required argument", + prompt: NewPrompt("with-arg", + WithPromptDescription("Prompt with argument"), + WithArgument("name", + ArgumentDescription("User name"), + RequiredArgument())), + }, + { + name: "prompt with optional arguments", + prompt: NewPrompt("optional-args", + WithArgument("field1", ArgumentDescription("First field")), + WithArgument("field2", ArgumentDescription("Second field"))), + }, + { + name: "complex prompt", + prompt: NewPrompt("complex", + WithPromptDescription("Complex prompt template"), + WithArgument("query", ArgumentDescription("Search query"), RequiredArgument()), + WithArgument("limit", ArgumentDescription("Result limit")), + WithArgument("sort", ArgumentDescription("Sort order"), RequiredArgument())), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal to JSON + data, err := json.Marshal(tt.prompt) + require.NoError(t, err) + + // Unmarshal back + var unmarshaled Prompt + err = json.Unmarshal(data, &unmarshaled) + require.NoError(t, err) + + // Compare + assert.Equal(t, tt.prompt.Name, unmarshaled.Name) + assert.Equal(t, tt.prompt.Description, unmarshaled.Description) + assert.Equal(t, tt.prompt.Arguments, unmarshaled.Arguments) + }) + } +} + +func TestWithArgument(t *testing.T) { + t.Run("single argument", func(t *testing.T) { + prompt := NewPrompt("test") + opt := WithArgument("arg1", ArgumentDescription("First arg")) + opt(&prompt) + + require.Len(t, prompt.Arguments, 1) + assert.Equal(t, "arg1", prompt.Arguments[0].Name) + assert.Equal(t, "First arg", prompt.Arguments[0].Description) + assert.False(t, prompt.Arguments[0].Required) + }) + + t.Run("multiple arguments", func(t *testing.T) { + prompt := NewPrompt("test") + opt1 := WithArgument("arg1", RequiredArgument()) + opt2 := WithArgument("arg2", ArgumentDescription("Second")) + + opt1(&prompt) + opt2(&prompt) + + require.Len(t, prompt.Arguments, 2) + assert.Equal(t, "arg1", prompt.Arguments[0].Name) + assert.True(t, prompt.Arguments[0].Required) + assert.Equal(t, "arg2", prompt.Arguments[1].Name) + assert.Equal(t, "Second", prompt.Arguments[1].Description) + }) + + t.Run("argument with no options", func(t *testing.T) { + prompt := NewPrompt("test") + opt := WithArgument("simple") + opt(&prompt) + + require.Len(t, prompt.Arguments, 1) + assert.Equal(t, "simple", prompt.Arguments[0].Name) + assert.Empty(t, prompt.Arguments[0].Description) + assert.False(t, prompt.Arguments[0].Required) + }) +} + +func TestArgumentDescription(t *testing.T) { + arg := PromptArgument{} + opt := ArgumentDescription("Test description") + opt(&arg) + + assert.Equal(t, "Test description", arg.Description) +} + +func TestRequiredArgument(t *testing.T) { + arg := PromptArgument{} + opt := RequiredArgument() + opt(&arg) + + assert.True(t, arg.Required) +} + +func TestWithPromptDescription(t *testing.T) { + prompt := Prompt{} + opt := WithPromptDescription("Test prompt description") + opt(&prompt) + + assert.Equal(t, "Test prompt description", prompt.Description) +} + +func TestPromptMessageCreation(t *testing.T) { + tests := []struct { + name string + role Role + content Content + }{ + { + name: "user text message", + role: RoleUser, + content: NewTextContent("Hello"), + }, + { + name: "assistant text message", + role: RoleAssistant, + content: NewTextContent("Hi there"), + }, + { + name: "user image message", + role: RoleUser, + content: NewImageContent("base64data", "image/png"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := NewPromptMessage(tt.role, tt.content) + + assert.Equal(t, tt.role, msg.Role) + assert.Equal(t, tt.content, msg.Content) + }) + } +} + +func TestPromptJSONStructure(t *testing.T) { + prompt := NewPrompt("test-prompt", + WithPromptDescription("Test description"), + WithArgument("arg1", + ArgumentDescription("First argument"), + RequiredArgument()), + WithArgument("arg2", + ArgumentDescription("Second argument"))) + + data, err := json.Marshal(prompt) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + // Verify structure + assert.Equal(t, "test-prompt", result["name"]) + assert.Equal(t, "Test description", result["description"]) + + args, ok := result["arguments"].([]any) + require.True(t, ok) + require.Len(t, args, 2) + + arg1, ok := args[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "arg1", arg1["name"]) + assert.Equal(t, "First argument", arg1["description"]) + assert.Equal(t, true, arg1["required"]) + + arg2, ok := args[1].(map[string]any) + require.True(t, ok) + assert.Equal(t, "arg2", arg2["name"]) + assert.Equal(t, "Second argument", arg2["description"]) + // Optional arguments may not have "required" field or it's false +} diff --git a/mcp/resources_test.go b/mcp/resources_test.go new file mode 100644 index 000000000..b4439fb7d --- /dev/null +++ b/mcp/resources_test.go @@ -0,0 +1,315 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewResource(t *testing.T) { + tests := []struct { + name string + resource Resource + expected Resource + }{ + { + name: "basic resource", + resource: NewResource("file:///test.txt", "test.txt"), + expected: Resource{ + URI: "file:///test.txt", + Name: "test.txt", + }, + }, + { + name: "resource with description", + resource: NewResource("file:///doc.md", "doc.md", + WithResourceDescription("A markdown document")), + expected: Resource{ + URI: "file:///doc.md", + Name: "doc.md", + Description: "A markdown document", + }, + }, + { + name: "resource with MIME type", + resource: NewResource("file:///image.png", "image.png", + WithMIMEType("image/png")), + expected: Resource{ + URI: "file:///image.png", + Name: "image.png", + MIMEType: "image/png", + }, + }, + { + name: "resource with annotations", + resource: NewResource("file:///data.json", "data.json", + WithAnnotations([]Role{RoleUser, RoleAssistant}, 1.5)), + expected: Resource{ + URI: "file:///data.json", + Name: "data.json", + Annotated: Annotated{ + Annotations: &Annotations{ + Audience: []Role{RoleUser, RoleAssistant}, + Priority: 1.5, + }, + }, + }, + }, + { + name: "resource with all options", + resource: NewResource("file:///complete.txt", "complete.txt", + WithResourceDescription("Complete resource"), + WithMIMEType("text/plain"), + WithAnnotations([]Role{RoleUser}, 2.0)), + expected: Resource{ + URI: "file:///complete.txt", + Name: "complete.txt", + Description: "Complete resource", + MIMEType: "text/plain", + Annotated: Annotated{ + Annotations: &Annotations{ + Audience: []Role{RoleUser}, + Priority: 2.0, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected.URI, tt.resource.URI) + assert.Equal(t, tt.expected.Name, tt.resource.Name) + assert.Equal(t, tt.expected.Description, tt.resource.Description) + assert.Equal(t, tt.expected.MIMEType, tt.resource.MIMEType) + assert.Equal(t, tt.expected.Annotations, tt.resource.Annotations) + }) + } +} + +func TestNewResourceTemplate(t *testing.T) { + tests := []struct { + name string + template ResourceTemplate + validate func(t *testing.T, template ResourceTemplate) + }{ + { + name: "basic template", + template: NewResourceTemplate("file:///{path}", "files"), + validate: func(t *testing.T, template ResourceTemplate) { + assert.NotNil(t, template.URITemplate) + assert.Equal(t, "files", template.Name) + }, + }, + { + name: "template with description", + template: NewResourceTemplate("file:///{dir}/{file}", "directory-files", + WithTemplateDescription("Files in directories")), + validate: func(t *testing.T, template ResourceTemplate) { + assert.Equal(t, "directory-files", template.Name) + assert.Equal(t, "Files in directories", template.Description) + }, + }, + { + name: "template with MIME type", + template: NewResourceTemplate("file:///{name}.txt", "text-files", + WithTemplateMIMEType("text/plain")), + validate: func(t *testing.T, template ResourceTemplate) { + assert.Equal(t, "text-files", template.Name) + assert.Equal(t, "text/plain", template.MIMEType) + }, + }, + { + name: "template with annotations", + template: NewResourceTemplate("file:///{id}", "resources", + WithTemplateAnnotations([]Role{RoleUser}, 1.0)), + validate: func(t *testing.T, template ResourceTemplate) { + assert.Equal(t, "resources", template.Name) + require.NotNil(t, template.Annotations) + assert.Equal(t, []Role{RoleUser}, template.Annotations.Audience) + assert.Equal(t, 1.0, template.Annotations.Priority) + }, + }, + { + name: "template with all options", + template: NewResourceTemplate("api:///{version}/{resource}", "api-resources", + WithTemplateDescription("API resources"), + WithTemplateMIMEType("application/json"), + WithTemplateAnnotations([]Role{RoleUser, RoleAssistant}, 2.5)), + validate: func(t *testing.T, template ResourceTemplate) { + assert.Equal(t, "api-resources", template.Name) + assert.Equal(t, "API resources", template.Description) + assert.Equal(t, "application/json", template.MIMEType) + require.NotNil(t, template.Annotations) + assert.Equal(t, []Role{RoleUser, RoleAssistant}, template.Annotations.Audience) + assert.Equal(t, 2.5, template.Annotations.Priority) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.validate(t, tt.template) + }) + } +} + +func TestWithResourceDescription(t *testing.T) { + resource := Resource{} + opt := WithResourceDescription("Test resource") + opt(&resource) + + assert.Equal(t, "Test resource", resource.Description) +} + +func TestWithMIMEType(t *testing.T) { + resource := Resource{} + opt := WithMIMEType("application/json") + opt(&resource) + + assert.Equal(t, "application/json", resource.MIMEType) +} + +func TestWithAnnotations(t *testing.T) { + tests := []struct { + name string + audience []Role + priority float64 + }{ + { + name: "user audience", + audience: []Role{RoleUser}, + priority: 1.0, + }, + { + name: "multiple audiences", + audience: []Role{RoleUser, RoleAssistant}, + priority: 2.5, + }, + { + name: "empty audience", + audience: []Role{}, + priority: 0.5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resource := Resource{} + opt := WithAnnotations(tt.audience, tt.priority) + opt(&resource) + + require.NotNil(t, resource.Annotations) + assert.Equal(t, tt.audience, resource.Annotations.Audience) + assert.Equal(t, tt.priority, resource.Annotations.Priority) + }) + } +} + +func TestWithTemplateDescription(t *testing.T) { + template := ResourceTemplate{} + opt := WithTemplateDescription("Test template") + opt(&template) + + assert.Equal(t, "Test template", template.Description) +} + +func TestWithTemplateMIMEType(t *testing.T) { + template := ResourceTemplate{} + opt := WithTemplateMIMEType("text/html") + opt(&template) + + assert.Equal(t, "text/html", template.MIMEType) +} + +func TestWithTemplateAnnotations(t *testing.T) { + tests := []struct { + name string + audience []Role + priority float64 + }{ + { + name: "assistant audience", + audience: []Role{RoleAssistant}, + priority: 3.0, + }, + { + name: "both audiences", + audience: []Role{RoleUser, RoleAssistant}, + priority: 1.5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + template := ResourceTemplate{} + opt := WithTemplateAnnotations(tt.audience, tt.priority) + opt(&template) + + require.NotNil(t, template.Annotations) + assert.Equal(t, tt.audience, template.Annotations.Audience) + assert.Equal(t, tt.priority, template.Annotations.Priority) + }) + } +} + +func TestResourceJSONMarshaling(t *testing.T) { + resource := NewResource("file:///test.txt", "test.txt", + WithResourceDescription("Test file"), + WithMIMEType("text/plain"), + WithAnnotations([]Role{RoleUser}, 1.0)) + + data, err := json.Marshal(resource) + require.NoError(t, err) + + var unmarshaled Resource + err = json.Unmarshal(data, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, resource.URI, unmarshaled.URI) + assert.Equal(t, resource.Name, unmarshaled.Name) + assert.Equal(t, resource.Description, unmarshaled.Description) + assert.Equal(t, resource.MIMEType, unmarshaled.MIMEType) +} + +func TestResourceTemplateJSONMarshaling(t *testing.T) { + template := NewResourceTemplate("file:///{path}", "files", + WithTemplateDescription("File resources"), + WithTemplateMIMEType("text/plain")) + + data, err := json.Marshal(template) + require.NoError(t, err) + + var unmarshaled ResourceTemplate + err = json.Unmarshal(data, &unmarshaled) + require.NoError(t, err) + + assert.Equal(t, template.Name, unmarshaled.Name) + assert.Equal(t, template.Description, unmarshaled.Description) + assert.Equal(t, template.MIMEType, unmarshaled.MIMEType) + assert.NotNil(t, unmarshaled.URITemplate) +} + +func TestAnnotationsCreationFromNil(t *testing.T) { + // Test that annotations are created when nil + resource := Resource{} + opt := WithAnnotations([]Role{RoleUser}, 1.0) + opt(&resource) + + require.NotNil(t, resource.Annotations) + assert.Equal(t, []Role{RoleUser}, resource.Annotations.Audience) + assert.Equal(t, 1.0, resource.Annotations.Priority) +} + +func TestTemplateAnnotationsCreationFromNil(t *testing.T) { + // Test that annotations are created when nil + template := ResourceTemplate{} + opt := WithTemplateAnnotations([]Role{RoleAssistant}, 2.0) + opt(&template) + + require.NotNil(t, template.Annotations) + assert.Equal(t, []Role{RoleAssistant}, template.Annotations.Audience) + assert.Equal(t, 2.0, template.Annotations.Priority) +} diff --git a/mcp/tools_additional_test.go b/mcp/tools_additional_test.go new file mode 100644 index 000000000..4ce723d9e --- /dev/null +++ b/mcp/tools_additional_test.go @@ -0,0 +1,463 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test edge cases for CallToolRequest methods + +func TestCallToolRequest_GetArgumentsWithNilArguments(t *testing.T) { + req := CallToolRequest{} + req.Params.Name = "test-tool" + req.Params.Arguments = nil + + args := req.GetArguments() + assert.Nil(t, args) +} + +func TestCallToolRequest_BindArgumentsWithInvalidTarget(t *testing.T) { + req := CallToolRequest{} + req.Params.Arguments = map[string]any{"key": "value"} + + t.Run("nil target", func(t *testing.T) { + err := req.BindArguments(nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-nil pointer") + }) + + t.Run("non-pointer target", func(t *testing.T) { + var target struct{ Key string } + err := req.BindArguments(target) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-nil pointer") + }) +} + +func TestCallToolRequest_TypeConversionEdgeCases(t *testing.T) { + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "string_as_int": "not a number", + "string_as_float": "not a float", + "string_as_bool": "maybe", + "int_as_bool": 5, + "float_as_bool": 3.14, + "object_val": map[string]any{"nested": "value"}, + } + + t.Run("GetInt with invalid string", func(t *testing.T) { + result := req.GetInt("string_as_int", 42) + assert.Equal(t, 42, result) + }) + + t.Run("GetFloat with invalid string", func(t *testing.T) { + result := req.GetFloat("string_as_float", 1.5) + assert.Equal(t, 1.5, result) + }) + + t.Run("GetBool with invalid string", func(t *testing.T) { + result := req.GetBool("string_as_bool", false) + assert.Equal(t, false, result) + }) + + t.Run("GetBool with non-zero int", func(t *testing.T) { + result := req.GetBool("int_as_bool", false) + assert.True(t, result) + }) + + t.Run("GetBool with non-zero float", func(t *testing.T) { + result := req.GetBool("float_as_bool", false) + assert.True(t, result) + }) + + t.Run("RequireInt with wrong type", func(t *testing.T) { + _, err := req.RequireInt("object_val") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an int") + }) + + t.Run("RequireFloat with wrong type", func(t *testing.T) { + _, err := req.RequireFloat("object_val") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not a float64") + }) + + t.Run("RequireBool with wrong type", func(t *testing.T) { + _, err := req.RequireBool("object_val") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not a bool") + }) +} + +func TestCallToolRequest_SliceWithMixedTypes(t *testing.T) { + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "mixed_string_slice": []any{"valid", 123, "another"}, + "mixed_int_slice": []any{1, "not a number", 3}, + "mixed_float_slice": []any{1.1, "not a float", 3.3}, + "mixed_bool_slice": []any{true, "not a bool", false}, + } + + t.Run("GetStringSlice with mixed types", func(t *testing.T) { + result := req.GetStringSlice("mixed_string_slice", nil) + // Should only include valid strings + assert.Equal(t, []string{"valid", "another"}, result) + }) + + t.Run("RequireStringSlice with non-string element", func(t *testing.T) { + _, err := req.RequireStringSlice("mixed_string_slice") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not a string") + }) + + t.Run("GetIntSlice with mixed types", func(t *testing.T) { + result := req.GetIntSlice("mixed_int_slice", nil) + // Should only include convertible values + assert.Equal(t, []int{1, 3}, result) + }) + + t.Run("RequireIntSlice with non-convertible element", func(t *testing.T) { + _, err := req.RequireIntSlice("mixed_int_slice") + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot be converted to int") + }) + + t.Run("GetFloatSlice with mixed types", func(t *testing.T) { + result := req.GetFloatSlice("mixed_float_slice", nil) + assert.Equal(t, []float64{1.1, 3.3}, result) + }) + + t.Run("RequireFloatSlice with non-convertible element", func(t *testing.T) { + _, err := req.RequireFloatSlice("mixed_float_slice") + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot be converted to float64") + }) + + t.Run("GetBoolSlice with mixed types", func(t *testing.T) { + result := req.GetBoolSlice("mixed_bool_slice", nil) + // Should skip invalid values + assert.Equal(t, []bool{true, false}, result) + }) + + t.Run("RequireBoolSlice with non-convertible element", func(t *testing.T) { + _, err := req.RequireBoolSlice("mixed_bool_slice") + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot be converted to bool") + }) +} + +// Test Property Options + +func TestPropertyOptions(t *testing.T) { + t.Run("MaxProperties", func(t *testing.T) { + schema := make(map[string]any) + opt := MaxProperties(10) + opt(schema) + assert.Equal(t, 10, schema["maxProperties"]) + }) + + t.Run("MinProperties", func(t *testing.T) { + schema := make(map[string]any) + opt := MinProperties(2) + opt(schema) + assert.Equal(t, 2, schema["minProperties"]) + }) + + t.Run("PropertyNames", func(t *testing.T) { + schema := make(map[string]any) + nameSchema := map[string]any{"pattern": "^[a-z]+$"} + opt := PropertyNames(nameSchema) + opt(schema) + assert.Equal(t, nameSchema, schema["propertyNames"]) + }) + + t.Run("AdditionalProperties with bool", func(t *testing.T) { + schema := make(map[string]any) + opt := AdditionalProperties(false) + opt(schema) + assert.Equal(t, false, schema["additionalProperties"]) + }) + + t.Run("AdditionalProperties with schema", func(t *testing.T) { + schema := make(map[string]any) + propSchema := map[string]any{"type": "string"} + opt := AdditionalProperties(propSchema) + opt(schema) + assert.Equal(t, propSchema, schema["additionalProperties"]) + }) +} + +// Test Array Options + +func TestArrayOptions(t *testing.T) { + t.Run("MinItems", func(t *testing.T) { + schema := make(map[string]any) + opt := MinItems(1) + opt(schema) + assert.Equal(t, 1, schema["minItems"]) + }) + + t.Run("MaxItems", func(t *testing.T) { + schema := make(map[string]any) + opt := MaxItems(100) + opt(schema) + assert.Equal(t, 100, schema["maxItems"]) + }) + + t.Run("UniqueItems", func(t *testing.T) { + schema := make(map[string]any) + opt := UniqueItems(true) + opt(schema) + assert.Equal(t, true, schema["uniqueItems"]) + }) + + t.Run("Items with custom schema", func(t *testing.T) { + schema := make(map[string]any) + itemSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + } + opt := Items(itemSchema) + opt(schema) + assert.Equal(t, itemSchema, schema["items"]) + }) +} + +// Test Tool Annotations + +func TestToolAnnotations(t *testing.T) { + t.Run("WithTitleAnnotation", func(t *testing.T) { + tool := NewTool("test") + opt := WithTitleAnnotation("Test Tool") + opt(&tool) + assert.Equal(t, "Test Tool", tool.Annotations.Title) + }) + + t.Run("WithReadOnlyHintAnnotation", func(t *testing.T) { + tool := NewTool("test") + opt := WithReadOnlyHintAnnotation(true) + opt(&tool) + require.NotNil(t, tool.Annotations.ReadOnlyHint) + assert.True(t, *tool.Annotations.ReadOnlyHint) + }) + + t.Run("WithDestructiveHintAnnotation", func(t *testing.T) { + tool := NewTool("test") + opt := WithDestructiveHintAnnotation(false) + opt(&tool) + require.NotNil(t, tool.Annotations.DestructiveHint) + assert.False(t, *tool.Annotations.DestructiveHint) + }) + + t.Run("WithIdempotentHintAnnotation", func(t *testing.T) { + tool := NewTool("test") + opt := WithIdempotentHintAnnotation(true) + opt(&tool) + require.NotNil(t, tool.Annotations.IdempotentHint) + assert.True(t, *tool.Annotations.IdempotentHint) + }) + + t.Run("WithOpenWorldHintAnnotation", func(t *testing.T) { + tool := NewTool("test") + opt := WithOpenWorldHintAnnotation(false) + opt(&tool) + require.NotNil(t, tool.Annotations.OpenWorldHint) + assert.False(t, *tool.Annotations.OpenWorldHint) + }) + + t.Run("WithToolAnnotation full", func(t *testing.T) { + tool := NewTool("test") + annotation := ToolAnnotation{ + Title: "Custom Tool", + ReadOnlyHint: ToBoolPtr(true), + DestructiveHint: ToBoolPtr(false), + IdempotentHint: ToBoolPtr(true), + OpenWorldHint: ToBoolPtr(false), + } + opt := WithToolAnnotation(annotation) + opt(&tool) + + assert.Equal(t, "Custom Tool", tool.Annotations.Title) + require.NotNil(t, tool.Annotations.ReadOnlyHint) + assert.True(t, *tool.Annotations.ReadOnlyHint) + require.NotNil(t, tool.Annotations.DestructiveHint) + assert.False(t, *tool.Annotations.DestructiveHint) + require.NotNil(t, tool.Annotations.IdempotentHint) + assert.True(t, *tool.Annotations.IdempotentHint) + require.NotNil(t, tool.Annotations.OpenWorldHint) + assert.False(t, *tool.Annotations.OpenWorldHint) + }) +} + +// Test Tool with both InputSchema and OutputSchema + +func TestToolWithBothSchemas(t *testing.T) { + type Input struct { + Query string `json:"query"` + Limit int `json:"limit"` + } + + type Output struct { + Results []string `json:"results"` + Count int `json:"count"` + } + + tool := NewTool("search", + WithDescription("Search with typed input and output"), + WithInputSchema[Input](), + WithOutputSchema[Output]()) + + // Verify tool can be marshaled + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + // Verify both schemas exist + assert.Contains(t, result, "inputSchema") + assert.Contains(t, result, "outputSchema") + + // Verify outputSchema structure + outputSchema, ok := result["outputSchema"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "object", outputSchema["type"]) +} + +// Test RawOutputSchema conflict + +func TestToolWithBothOutputSchemasError(t *testing.T) { + tool := NewTool("test", + WithString("input", Required())) + + // Set OutputSchema via DSL + tool.OutputSchema = ToolOutputSchema{ + Type: "object", + Properties: map[string]any{"result": map[string]any{"type": "string"}}, + } + + // Also set RawOutputSchema - should conflict + tool.RawOutputSchema = json.RawMessage(`{"type":"string"}`) + + // Attempt to marshal + _, err := json.Marshal(tool) + assert.ErrorIs(t, err, errToolSchemaConflict) +} + +// Test array property options in tool + +func TestToolWithArrayConstraints(t *testing.T) { + tool := NewTool("list-tool", + WithDescription("Tool with constrained arrays"), + WithArray("tags", + Description("List of tags"), + Required(), + WithStringItems(MinLength(1), MaxLength(50)), + MinItems(1), + MaxItems(10), + UniqueItems(true)), + WithArray("scores", + Description("List of scores"), + WithNumberItems(Min(0), Max(100)), + MinItems(1))) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + schema := result["inputSchema"].(map[string]any) + properties := schema["properties"].(map[string]any) + + // Verify tags array + tags := properties["tags"].(map[string]any) + assert.Equal(t, "array", tags["type"]) + assert.Equal(t, float64(1), tags["minItems"]) + assert.Equal(t, float64(10), tags["maxItems"]) + assert.Equal(t, true, tags["uniqueItems"]) + + // Verify items schema for tags + tagsItems := tags["items"].(map[string]any) + assert.Equal(t, "string", tagsItems["type"]) + assert.Equal(t, float64(1), tagsItems["minLength"]) + assert.Equal(t, float64(50), tagsItems["maxLength"]) + + // Verify scores array + scores := properties["scores"].(map[string]any) + assert.Equal(t, "array", scores["type"]) + assert.Equal(t, float64(1), scores["minItems"]) + + // Verify items schema for scores + scoresItems := scores["items"].(map[string]any) + assert.Equal(t, "number", scoresItems["type"]) + assert.Equal(t, float64(0), scoresItems["minimum"]) + assert.Equal(t, float64(100), scoresItems["maximum"]) +} + +// Test object property options + +func TestToolWithObjectConstraints(t *testing.T) { + tool := NewTool("object-tool", + WithDescription("Tool with object constraints"), + WithObject("metadata", + Description("Metadata object"), + Properties(map[string]any{ + "created": map[string]any{"type": "string"}, + "updated": map[string]any{"type": "string"}, + }), + MinProperties(1), + MaxProperties(5), + AdditionalProperties(false))) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + schema := result["inputSchema"].(map[string]any) + properties := schema["properties"].(map[string]any) + + metadata := properties["metadata"].(map[string]any) + assert.Equal(t, "object", metadata["type"]) + assert.Equal(t, float64(1), metadata["minProperties"]) + assert.Equal(t, float64(5), metadata["maxProperties"]) + assert.Equal(t, false, metadata["additionalProperties"]) + + // Verify nested properties + metaProps := metadata["properties"].(map[string]any) + assert.Contains(t, metaProps, "created") + assert.Contains(t, metaProps, "updated") +} + +// Test BindArguments with json.RawMessage + +func TestCallToolRequest_BindArgumentsWithRawJSON(t *testing.T) { + type Args struct { + Name string `json:"name"` + Value int `json:"value"` + } + + rawJSON := json.RawMessage(`{"name": "test", "value": 42}`) + + req := CallToolRequest{} + req.Params.Arguments = rawJSON + + var args Args + err := req.BindArguments(&args) + require.NoError(t, err) + + assert.Equal(t, "test", args.Name) + assert.Equal(t, 42, args.Value) +} diff --git a/mcp/tools_properties_test.go b/mcp/tools_properties_test.go new file mode 100644 index 000000000..9edb799b4 --- /dev/null +++ b/mcp/tools_properties_test.go @@ -0,0 +1,312 @@ +package mcp + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test property option functions with 0% coverage + +func TestDefaultString(t *testing.T) { + schema := make(map[string]any) + opt := DefaultString("default value") + opt(schema) + assert.Equal(t, "default value", schema["default"]) +} + +func TestEnum(t *testing.T) { + schema := make(map[string]any) + opt := Enum("red", "green", "blue") + opt(schema) + assert.Equal(t, []string{"red", "green", "blue"}, schema["enum"]) +} + +func TestPattern(t *testing.T) { + schema := make(map[string]any) + opt := Pattern("^[a-z]+$") + opt(schema) + assert.Equal(t, "^[a-z]+$", schema["pattern"]) +} + +func TestDefaultNumber(t *testing.T) { + schema := make(map[string]any) + opt := DefaultNumber(42.5) + opt(schema) + assert.Equal(t, 42.5, schema["default"]) +} + +func TestMultipleOf(t *testing.T) { + schema := make(map[string]any) + opt := MultipleOf(5.0) + opt(schema) + assert.Equal(t, 5.0, schema["multipleOf"]) +} + +func TestDefaultBool(t *testing.T) { + schema := make(map[string]any) + opt := DefaultBool(true) + opt(schema) + assert.Equal(t, true, schema["default"]) +} + +func TestDefaultArray(t *testing.T) { + schema := make(map[string]any) + opt := DefaultArray([]string{"a", "b", "c"}) + opt(schema) + assert.Equal(t, []string{"a", "b", "c"}, schema["default"]) +} + +func TestTitle(t *testing.T) { + schema := make(map[string]any) + opt := Title("Field Title") + opt(schema) + assert.Equal(t, "Field Title", schema["title"]) +} + +func TestWithBoolean(t *testing.T) { + tool := NewTool("test", + WithBoolean("enabled", + Description("Enable feature"), + Required(), + DefaultBool(false))) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + schema := result["inputSchema"].(map[string]any) + properties := schema["properties"].(map[string]any) + enabled := properties["enabled"].(map[string]any) + + assert.Equal(t, "boolean", enabled["type"]) + assert.Equal(t, "Enable feature", enabled["description"]) + assert.Equal(t, false, enabled["default"]) + + required := schema["required"].([]any) + assert.Contains(t, required, "enabled") +} + +func TestWithNumber(t *testing.T) { + tool := NewTool("test", + WithNumber("score", + Description("Score value"), + Required(), + DefaultNumber(0.0), + Min(0.0), + Max(100.0), + MultipleOf(0.5))) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + schema := result["inputSchema"].(map[string]any) + properties := schema["properties"].(map[string]any) + score := properties["score"].(map[string]any) + + assert.Equal(t, "number", score["type"]) + assert.Equal(t, "Score value", score["description"]) + assert.Equal(t, 0.0, score["default"]) + assert.Equal(t, 0.0, score["minimum"]) + assert.Equal(t, 100.0, score["maximum"]) + assert.Equal(t, 0.5, score["multipleOf"]) + + required := schema["required"].([]any) + assert.Contains(t, required, "score") +} + +func TestWithRawInputSchema(t *testing.T) { + rawSchema := json.RawMessage(`{ + "type": "object", + "properties": { + "custom": {"type": "string"} + } + }`) + + // Use NewToolWithRawSchema instead of NewTool to avoid conflict + tool := NewToolWithRawSchema("test", "Tool with raw input schema", rawSchema) + + // Verify RawInputSchema is set + assert.NotNil(t, tool.RawInputSchema) + assert.Equal(t, rawSchema, tool.RawInputSchema) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + inputSchema := result["inputSchema"].(map[string]any) + assert.Equal(t, "object", inputSchema["type"]) + + properties := inputSchema["properties"].(map[string]any) + assert.Contains(t, properties, "custom") +} + +func TestWithRawInputSchemaOption(t *testing.T) { + // Test the WithRawInputSchema option function directly + rawSchema := json.RawMessage(`{"type": "string"}`) + + tool := Tool{} + opt := WithRawInputSchema(rawSchema) + opt(&tool) + + assert.Equal(t, rawSchema, tool.RawInputSchema) +} + +func TestWithRawOutputSchema(t *testing.T) { + rawSchema := json.RawMessage(`{ + "type": "object", + "properties": { + "result": {"type": "string"} + } + }`) + + tool := NewTool("test", + WithString("input", Required()), + WithRawOutputSchema(rawSchema)) + + // Verify RawOutputSchema is set + assert.NotNil(t, tool.RawOutputSchema) + assert.Equal(t, rawSchema, tool.RawOutputSchema) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + outputSchema := result["outputSchema"].(map[string]any) + assert.Equal(t, "object", outputSchema["type"]) + + properties := outputSchema["properties"].(map[string]any) + assert.Contains(t, properties, "result") +} + +func TestToolGetName(t *testing.T) { + tool := NewTool("my-tool") + assert.Equal(t, "my-tool", tool.GetName()) +} + +func TestToolInputSchemaMarshalJSON(t *testing.T) { + schema := ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "name": map[string]any{"type": "string"}, + }, + Required: []string{"name"}, + Defs: map[string]any{ + "CustomType": map[string]any{"type": "string"}, + }, + } + + data, err := json.Marshal(schema) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + assert.Equal(t, "object", result["type"]) + assert.Contains(t, result, "properties") + assert.Contains(t, result, "required") + assert.Contains(t, result, "$defs") +} + +func TestWithStringWithAllOptions(t *testing.T) { + tool := NewTool("test", + WithString("name", + Description("User name"), + Title("Name"), + Required(), + DefaultString("John"), + MinLength(1), + MaxLength(50), + Pattern("^[A-Za-z]+$"), + Enum("John", "Jane", "Bob"))) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + schema := result["inputSchema"].(map[string]any) + properties := schema["properties"].(map[string]any) + name := properties["name"].(map[string]any) + + assert.Equal(t, "string", name["type"]) + assert.Equal(t, "User name", name["description"]) + assert.Equal(t, "Name", name["title"]) + assert.Equal(t, "John", name["default"]) + assert.Equal(t, float64(1), name["minLength"]) + assert.Equal(t, float64(50), name["maxLength"]) + assert.Equal(t, "^[A-Za-z]+$", name["pattern"]) + + enum := name["enum"].([]any) + assert.Len(t, enum, 3) + assert.Contains(t, enum, "John") + assert.Contains(t, enum, "Jane") + assert.Contains(t, enum, "Bob") +} + +func TestWithObjectWithAllOptions(t *testing.T) { + tool := NewTool("test", + WithObject("config", + Description("Configuration object"), + Title("Config"), + Required(), + Properties(map[string]any{ + "host": map[string]any{"type": "string"}, + "port": map[string]any{"type": "number"}, + }), + MinProperties(1), + MaxProperties(10), + AdditionalProperties(map[string]any{"type": "string"}), + PropertyNames(map[string]any{"pattern": "^[a-z]+$"}))) + + // Marshal and verify + data, err := json.Marshal(tool) + require.NoError(t, err) + + var result map[string]any + err = json.Unmarshal(data, &result) + require.NoError(t, err) + + schema := result["inputSchema"].(map[string]any) + properties := schema["properties"].(map[string]any) + config := properties["config"].(map[string]any) + + assert.Equal(t, "object", config["type"]) + assert.Equal(t, "Configuration object", config["description"]) + assert.Equal(t, "Config", config["title"]) + assert.Equal(t, float64(1), config["minProperties"]) + assert.Equal(t, float64(10), config["maxProperties"]) + + configProps := config["properties"].(map[string]any) + assert.Contains(t, configProps, "host") + assert.Contains(t, configProps, "port") + + additionalProps := config["additionalProperties"].(map[string]any) + assert.Equal(t, "string", additionalProps["type"]) + + propertyNames := config["propertyNames"].(map[string]any) + assert.Equal(t, "^[a-z]+$", propertyNames["pattern"]) +} diff --git a/mcp/typed_tools_additional_test.go b/mcp/typed_tools_additional_test.go new file mode 100644 index 000000000..9dfa1cc21 --- /dev/null +++ b/mcp/typed_tools_additional_test.go @@ -0,0 +1,329 @@ +package mcp + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTypedToolHandler(t *testing.T) { + type Args struct { + Name string `json:"name"` + Count int `json:"count"` + } + + t.Run("successful execution", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (*CallToolResult, error) { + return NewToolResultText("Name: " + args.Name), nil + } + + typedHandler := NewTypedToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "name": "test", + "count": 5, + } + + result, err := typedHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, "Name: test", textContent.Text) + }) + + t.Run("bind arguments error", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (*CallToolResult, error) { + return NewToolResultText("Should not reach here"), nil + } + + typedHandler := NewTypedToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = "invalid arguments" // Not a map + + result, err := typedHandler(context.Background(), req) + require.NoError(t, err) // Handler returns result, not error + require.NotNil(t, result) + + // Should return error result + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Contains(t, textContent.Text, "failed to bind arguments") + }) + + t.Run("handler returns error", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (*CallToolResult, error) { + return nil, errors.New("handler error") + } + + typedHandler := NewTypedToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "name": "test", + "count": 5, + } + + result, err := typedHandler(context.Background(), req) + assert.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("with complex arguments", func(t *testing.T) { + type ComplexArgs struct { + Items []string `json:"items"` + Options map[string]string `json:"options"` + Nested struct { + Value int `json:"value"` + } `json:"nested"` + } + + handler := func(ctx context.Context, request CallToolRequest, args ComplexArgs) (*CallToolResult, error) { + return NewToolResultText("OK"), nil + } + + typedHandler := NewTypedToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "items": []any{"a", "b", "c"}, + "options": map[string]any{"key": "value"}, + "nested": map[string]any{"value": 42}, + } + + result, err := typedHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + }) +} + +func TestNewStructuredToolHandler(t *testing.T) { + type Args struct { + Query string `json:"query"` + Limit int `json:"limit"` + } + + type Result struct { + Results []string `json:"results"` + Count int `json:"count"` + } + + t.Run("successful execution", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (Result, error) { + return Result{ + Results: []string{"result1", "result2"}, + Count: 2, + }, nil + } + + structuredHandler := NewStructuredToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "query": "test", + "limit": 10, + } + + result, err := structuredHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + + // Should have text content (JSON fallback) + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + + // Text should be JSON representation + var jsonResult map[string]any + err = json.Unmarshal([]byte(textContent.Text), &jsonResult) + require.NoError(t, err) + + // Should have structured content + require.NotNil(t, result.StructuredContent) + structuredMap, ok := result.StructuredContent.(Result) + require.True(t, ok) + assert.Equal(t, 2, structuredMap.Count) + assert.Len(t, structuredMap.Results, 2) + }) + + t.Run("bind arguments error", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (Result, error) { + return Result{}, errors.New("should not reach here") + } + + structuredHandler := NewStructuredToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = "invalid" // Not a map + + result, err := structuredHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + + // Should return error result + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Contains(t, textContent.Text, "failed to bind arguments") + }) + + t.Run("handler execution error", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (Result, error) { + return Result{}, errors.New("execution failed") + } + + structuredHandler := NewStructuredToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "query": "test", + "limit": 10, + } + + result, err := structuredHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + + // Should return error result + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Contains(t, textContent.Text, "tool execution failed") + assert.Contains(t, textContent.Text, "execution failed") + }) + + t.Run("with primitive result", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (string, error) { + return "simple result", nil + } + + structuredHandler := NewStructuredToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "query": "test", + "limit": 10, + } + + result, err := structuredHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + + require.NotNil(t, result.StructuredContent) + strResult, ok := result.StructuredContent.(string) + require.True(t, ok) + assert.Equal(t, "simple result", strResult) + }) + + t.Run("with map result", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (map[string]any, error) { + return map[string]any{ + "status": "success", + "data": []int{1, 2, 3}, + }, nil + } + + structuredHandler := NewStructuredToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "query": "test", + "limit": 10, + } + + result, err := structuredHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + + require.NotNil(t, result.StructuredContent) + mapResult, ok := result.StructuredContent.(map[string]any) + require.True(t, ok) + assert.Equal(t, "success", mapResult["status"]) + }) + + t.Run("with empty struct args", func(t *testing.T) { + type EmptyArgs struct{} + + handler := func(ctx context.Context, request CallToolRequest, args EmptyArgs) (string, error) { + return "no args needed", nil + } + + structuredHandler := NewStructuredToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{} + + result, err := structuredHandler(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, result) + }) +} + +func TestTypedToolHandler_ContextPropagation(t *testing.T) { + type Args struct { + Value string `json:"value"` + } + + type contextKey string + + t.Run("context is passed to handler", func(t *testing.T) { + ctxKey := contextKey("test-key") + ctxValue := "test-value" + + handler := func(ctx context.Context, request CallToolRequest, args Args) (*CallToolResult, error) { + // Verify context value is available + val := ctx.Value(ctxKey) + if val == nil { + return NewToolResultError("context value missing"), nil + } + return NewToolResultText("context value: " + val.(string)), nil + } + + typedHandler := NewTypedToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{"value": "test"} + + ctx := context.WithValue(context.Background(), ctxKey, ctxValue) + result, err := typedHandler(ctx, req) + require.NoError(t, err) + require.NotNil(t, result) + + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, "context value: test-value", textContent.Text) + }) + + t.Run("cancelled context", func(t *testing.T) { + handler := func(ctx context.Context, request CallToolRequest, args Args) (*CallToolResult, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return NewToolResultText("completed"), nil + } + } + + typedHandler := NewTypedToolHandler(handler) + + req := CallToolRequest{} + req.Params.Arguments = map[string]any{"value": "test"} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + result, err := typedHandler(ctx, req) + assert.Error(t, err) + assert.Nil(t, result) + }) +} diff --git a/mcp/types.go b/mcp/types.go index 69ea73ff5..0f97821b4 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -739,8 +739,9 @@ type ResourceContents interface { } type TextResourceContents struct { - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` + // Raw per‑resource metadata; pass‑through as defined by MCP. Not the same as mcp.Meta. + // Allows _meta to be used for MCP-UI features for example. Does not assume any specific format. + Meta map[string]any `json:"_meta,omitempty"` // The URI of this resource. URI string `json:"uri"` // The MIME type of this resource, if known. @@ -753,8 +754,9 @@ type TextResourceContents struct { func (TextResourceContents) isResourceContents() {} type BlobResourceContents struct { - // Meta is a metadata object that is reserved by MCP for storing additional information. - Meta *Meta `json:"_meta,omitempty"` + // Raw per‑resource metadata; pass‑through as defined by MCP. Not the same as mcp.Meta. + // Allows _meta to be used for MCP-UI features for example. Does not assume any specific format. + Meta map[string]any `json:"_meta,omitempty"` // The URI of this resource. URI string `json:"uri"` // The MIME type of this resource, if known. diff --git a/mcp/types_test.go b/mcp/types_test.go index c1453de60..0d052f2d8 100644 --- a/mcp/types_test.go +++ b/mcp/types_test.go @@ -138,3 +138,238 @@ func TestCallToolResultWithResourceLink(t *testing.T) { assert.Equal(t, "A test document", resourceLink.Description) assert.Equal(t, "application/pdf", resourceLink.MIMEType) } + +func TestResourceContentsMetaField(t *testing.T) { + tests := []struct { + name string + inputJSON string + expectedType string + expectedMeta map[string]any + }{ + { + name: "TextResourceContents with empty _meta", + inputJSON: `{ + "uri":"file://empty-meta.txt", + "mimeType":"text/plain", + "text":"x", + "_meta": {} + }`, + expectedType: "text", + expectedMeta: map[string]any{}, + }, + { + name: "TextResourceContents with _meta field", + inputJSON: `{ + "uri": "file://test.txt", + "mimeType": "text/plain", + "text": "Hello World", + "_meta": { + "mcpui.dev/ui-preferred-frame-size": ["800px", "600px"], + "mcpui.dev/ui-initial-render-data": { + "test": "value" + } + } + }`, + expectedType: "text", + expectedMeta: map[string]any{ + "mcpui.dev/ui-preferred-frame-size": []interface{}{"800px", "600px"}, + "mcpui.dev/ui-initial-render-data": map[string]any{ + "test": "value", + }, + }, + }, + { + name: "BlobResourceContents with _meta field", + inputJSON: `{ + "uri": "file://image.png", + "mimeType": "image/png", + "blob": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==", + "_meta": { + "width": 100, + "height": 100, + "format": "PNG" + } + }`, + expectedType: "blob", + expectedMeta: map[string]any{ + "width": float64(100), // JSON numbers are always float64 + "height": float64(100), + "format": "PNG", + }, + }, + { + name: "TextResourceContents without _meta field", + inputJSON: `{ + "uri": "file://simple.txt", + "mimeType": "text/plain", + "text": "Simple content" + }`, + expectedType: "text", + expectedMeta: nil, + }, + { + name: "BlobResourceContents without _meta field", + inputJSON: `{ + "uri": "file://simple.png", + "mimeType": "image/png", + "blob": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + }`, + expectedType: "blob", + expectedMeta: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Parse the JSON as a generic map first + var contentMap map[string]any + err := json.Unmarshal([]byte(tc.inputJSON), &contentMap) + require.NoError(t, err) + + // Use ParseResourceContents to convert to ResourceContents + resourceContent, err := ParseResourceContents(contentMap) + require.NoError(t, err) + require.NotNil(t, resourceContent) + + switch tc.expectedType { + case "text": + textContent, ok := resourceContent.(TextResourceContents) + require.True(t, ok, "Expected TextResourceContents") + + assert.Equal(t, contentMap["uri"], textContent.URI) + assert.Equal(t, contentMap["mimeType"], textContent.MIMEType) + assert.Equal(t, contentMap["text"], textContent.Text) + + assert.Equal(t, tc.expectedMeta, textContent.Meta) + + case "blob": + blobContent, ok := resourceContent.(BlobResourceContents) + require.True(t, ok, "Expected BlobResourceContents") + + assert.Equal(t, contentMap["uri"], blobContent.URI) + assert.Equal(t, contentMap["mimeType"], blobContent.MIMEType) + assert.Equal(t, contentMap["blob"], blobContent.Blob) + + assert.Equal(t, tc.expectedMeta, blobContent.Meta) + } + + // Test round-trip marshaling to ensure _meta is preserved + marshaledJSON, err := json.Marshal(resourceContent) + require.NoError(t, err) + + var marshaledMap map[string]any + err = json.Unmarshal(marshaledJSON, &marshaledMap) + require.NoError(t, err) + + // Verify _meta field is preserved in marshaled output + v, ok := marshaledMap["_meta"] + if tc.expectedMeta != nil { + // Special case: empty maps are omitted due to omitempty tag + if len(tc.expectedMeta) == 0 { + assert.False(t, ok, "_meta should be omitted when empty due to omitempty") + } else { + require.True(t, ok, "_meta should be present") + assert.Equal(t, tc.expectedMeta, v) + } + } else { + assert.False(t, ok, "_meta should be omitted when nil") + } + }) + } +} + +func TestParseResourceContentsInvalidMeta(t *testing.T) { + tests := []struct { + name string + inputJSON string + expectedErr string + }{ + { + name: "TextResourceContents with invalid _meta (string)", + inputJSON: `{ + "uri": "file://test.txt", + "mimeType": "text/plain", + "text": "Hello World", + "_meta": "invalid_meta_string" + }`, + expectedErr: "_meta must be an object", + }, + { + name: "TextResourceContents with invalid _meta (number)", + inputJSON: `{ + "uri": "file://test.txt", + "mimeType": "text/plain", + "text": "Hello World", + "_meta": 123 + }`, + expectedErr: "_meta must be an object", + }, + { + name: "TextResourceContents with invalid _meta (array)", + inputJSON: `{ + "uri": "file://test.txt", + "mimeType": "text/plain", + "text": "Hello World", + "_meta": ["invalid", "array"] + }`, + expectedErr: "_meta must be an object", + }, + { + name: "TextResourceContents with invalid _meta (boolean)", + inputJSON: `{ + "uri": "file://test.txt", + "mimeType": "text/plain", + "text": "Hello World", + "_meta": true + }`, + expectedErr: "_meta must be an object", + }, + { + name: "TextResourceContents with invalid _meta (null)", + inputJSON: `{ + "uri": "file://test.txt", + "mimeType": "text/plain", + "text": "Hello World", + "_meta": null + }`, + expectedErr: "_meta must be an object", + }, + { + name: "BlobResourceContents with invalid _meta (string)", + inputJSON: `{ + "uri": "file://image.png", + "mimeType": "image/png", + "blob": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==", + "_meta": "invalid_meta_string" + }`, + expectedErr: "_meta must be an object", + }, + { + name: "BlobResourceContents with invalid _meta (number)", + inputJSON: `{ + "uri": "file://image.png", + "mimeType": "image/png", + "blob": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==", + "_meta": 456 + }`, + expectedErr: "_meta must be an object", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Parse the JSON as a generic map first + var contentMap map[string]any + err := json.Unmarshal([]byte(tc.inputJSON), &contentMap) + require.NoError(t, err) + + // Use ParseResourceContents to convert to ResourceContents + resourceContent, err := ParseResourceContents(contentMap) + + // Expect an error + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErr) + assert.Nil(t, resourceContent) + }) + } +} diff --git a/mcp/utils.go b/mcp/utils.go index 0a3cde236..904a3dd6b 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -767,8 +767,15 @@ func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) mimeType := ExtractString(contentMap, "mimeType") + meta := ExtractMap(contentMap, "_meta") + + if _, present := contentMap["_meta"]; present && meta == nil { + return nil, fmt.Errorf("_meta must be an object") + } + if text := ExtractString(contentMap, "text"); text != "" { return TextResourceContents{ + Meta: meta, URI: uri, MIMEType: mimeType, Text: text, @@ -777,6 +784,7 @@ func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) if blob := ExtractString(contentMap, "blob"); blob != "" { return BlobResourceContents{ + Meta: meta, URI: uri, MIMEType: mimeType, Blob: blob, @@ -941,3 +949,31 @@ func ParseStringMap(request CallToolRequest, key string, defaultValue map[string func ToBoolPtr(b bool) *bool { return &b } + +// GetTextFromContent extracts text from a Content interface that might be a TextContent struct +// or a map[string]any that was unmarshaled from JSON. This is useful when dealing with content +// that comes from different transport layers that may handle JSON differently. +// +// This function uses fallback behavior for non-text content - it returns a string representation +// via fmt.Sprintf for any content that cannot be extracted as text. This is a lossy operation +// intended for convenience in logging and display scenarios. +// +// For strict type validation, use ParseContent() instead, which returns an error for invalid content. +func GetTextFromContent(content any) string { + switch c := content.(type) { + case TextContent: + return c.Text + case map[string]any: + // Handle JSON unmarshaled content + if contentType, exists := c["type"]; exists && contentType == "text" { + if text, exists := c["text"].(string); exists { + return text + } + } + return fmt.Sprintf("%v", content) + case string: + return c + default: + return fmt.Sprintf("%v", content) + } +} diff --git a/mcp/utils_additional_test.go b/mcp/utils_additional_test.go new file mode 100644 index 000000000..37309e05e --- /dev/null +++ b/mcp/utils_additional_test.go @@ -0,0 +1,604 @@ +package mcp + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test AsXXXContent type assertion helpers + +func TestAsTextContent(t *testing.T) { + t.Run("valid TextContent", func(t *testing.T) { + content := TextContent{Type: ContentTypeText, Text: "hello"} + result, ok := AsTextContent(content) + assert.True(t, ok) + require.NotNil(t, result) + assert.Equal(t, "hello", result.Text) + }) + + t.Run("invalid type", func(t *testing.T) { + content := ImageContent{Type: ContentTypeImage, Data: "data"} + result, ok := AsTextContent(content) + assert.False(t, ok) + assert.Nil(t, result) + }) + + t.Run("wrong type string", func(t *testing.T) { + result, ok := AsTextContent("not a text content") + assert.False(t, ok) + assert.Nil(t, result) + }) +} + +func TestAsImageContent(t *testing.T) { + t.Run("valid ImageContent", func(t *testing.T) { + content := ImageContent{Type: ContentTypeImage, Data: "base64", MIMEType: "image/png"} + result, ok := AsImageContent(content) + assert.True(t, ok) + require.NotNil(t, result) + assert.Equal(t, "base64", result.Data) + assert.Equal(t, "image/png", result.MIMEType) + }) + + t.Run("invalid type", func(t *testing.T) { + content := TextContent{Type: ContentTypeText, Text: "text"} + result, ok := AsImageContent(content) + assert.False(t, ok) + assert.Nil(t, result) + }) +} + +func TestAsAudioContent(t *testing.T) { + t.Run("valid AudioContent", func(t *testing.T) { + content := AudioContent{Type: ContentTypeAudio, Data: "base64", MIMEType: "audio/mp3"} + result, ok := AsAudioContent(content) + assert.True(t, ok) + require.NotNil(t, result) + assert.Equal(t, "base64", result.Data) + assert.Equal(t, "audio/mp3", result.MIMEType) + }) + + t.Run("invalid type", func(t *testing.T) { + result, ok := AsAudioContent(123) + assert.False(t, ok) + assert.Nil(t, result) + }) +} + +func TestAsEmbeddedResource(t *testing.T) { + t.Run("valid EmbeddedResource", func(t *testing.T) { + resource := TextResourceContents{URI: "file:///test.txt", Text: "content"} + content := EmbeddedResource{Type: ContentTypeResource, Resource: resource} + result, ok := AsEmbeddedResource(content) + assert.True(t, ok) + require.NotNil(t, result) + assert.Equal(t, resource, result.Resource) + }) + + t.Run("invalid type", func(t *testing.T) { + result, ok := AsEmbeddedResource(nil) + assert.False(t, ok) + assert.Nil(t, result) + }) +} + +func TestAsTextResourceContents(t *testing.T) { + t.Run("valid TextResourceContents", func(t *testing.T) { + content := TextResourceContents{URI: "file:///test.txt", Text: "hello"} + result, ok := AsTextResourceContents(content) + assert.True(t, ok) + require.NotNil(t, result) + assert.Equal(t, "hello", result.Text) + }) + + t.Run("invalid type", func(t *testing.T) { + content := BlobResourceContents{URI: "file:///test.bin", Blob: "data"} + result, ok := AsTextResourceContents(content) + assert.False(t, ok) + assert.Nil(t, result) + }) +} + +func TestAsBlobResourceContents(t *testing.T) { + t.Run("valid BlobResourceContents", func(t *testing.T) { + content := BlobResourceContents{URI: "file:///test.bin", Blob: "base64"} + result, ok := AsBlobResourceContents(content) + assert.True(t, ok) + require.NotNil(t, result) + assert.Equal(t, "base64", result.Blob) + }) + + t.Run("invalid type", func(t *testing.T) { + result, ok := AsBlobResourceContents([]byte{1, 2, 3}) + assert.False(t, ok) + assert.Nil(t, result) + }) +} + +// Test NewJSONRPCError and NewJSONRPCErrorDetails + +func TestNewJSONRPCError(t *testing.T) { + id := NewRequestId(123) + code := METHOD_NOT_FOUND + message := "Method not found" + data := map[string]any{"method": "unknown"} + + result := NewJSONRPCError(id, code, message, data) + + assert.Equal(t, JSONRPC_VERSION, result.JSONRPC) + assert.Equal(t, id, result.ID) + assert.Equal(t, code, result.Error.Code) + assert.Equal(t, message, result.Error.Message) + assert.Equal(t, data, result.Error.Data) +} + +func TestNewJSONRPCErrorDetails(t *testing.T) { + code := INVALID_PARAMS + message := "Invalid parameters" + data := "Additional error info" + + result := NewJSONRPCErrorDetails(code, message, data) + + assert.Equal(t, code, result.Code) + assert.Equal(t, message, result.Message) + assert.Equal(t, data, result.Data) +} + +// Test helper content creation functions + +func TestNewAudioContent(t *testing.T) { + result := NewAudioContent("audiodata", "audio/mp3") + + assert.Equal(t, ContentTypeAudio, result.Type) + assert.Equal(t, "audiodata", result.Data) + assert.Equal(t, "audio/mp3", result.MIMEType) +} + +func TestNewResourceLink(t *testing.T) { + result := NewResourceLink("file:///test.txt", "test.txt", "A test file", "text/plain") + + assert.Equal(t, ContentTypeLink, result.Type) + assert.Equal(t, "file:///test.txt", result.URI) + assert.Equal(t, "test.txt", result.Name) + assert.Equal(t, "A test file", result.Description) + assert.Equal(t, "text/plain", result.MIMEType) +} + +func TestNewEmbeddedResource(t *testing.T) { + resource := TextResourceContents{URI: "file:///test.txt", Text: "content"} + result := NewEmbeddedResource(resource) + + assert.Equal(t, ContentTypeResource, result.Type) + assert.Equal(t, resource, result.Resource) +} + +// Test ParseResourceContents + +func TestParseResourceContents(t *testing.T) { + t.Run("text resource", func(t *testing.T) { + contentMap := map[string]any{ + "uri": "file:///test.txt", + "mimeType": "text/plain", + "text": "hello world", + } + + result, err := ParseResourceContents(contentMap) + require.NoError(t, err) + + textRes, ok := result.(TextResourceContents) + require.True(t, ok) + assert.Equal(t, "file:///test.txt", textRes.URI) + assert.Equal(t, "text/plain", textRes.MIMEType) + assert.Equal(t, "hello world", textRes.Text) + }) + + t.Run("blob resource", func(t *testing.T) { + contentMap := map[string]any{ + "uri": "file:///test.bin", + "mimeType": "application/octet-stream", + "blob": "base64data", + } + + result, err := ParseResourceContents(contentMap) + require.NoError(t, err) + + blobRes, ok := result.(BlobResourceContents) + require.True(t, ok) + assert.Equal(t, "file:///test.bin", blobRes.URI) + assert.Equal(t, "application/octet-stream", blobRes.MIMEType) + assert.Equal(t, "base64data", blobRes.Blob) + }) + + t.Run("missing uri", func(t *testing.T) { + contentMap := map[string]any{ + "text": "hello", + } + + _, err := ParseResourceContents(contentMap) + assert.Error(t, err) + assert.Contains(t, err.Error(), "uri is missing") + }) + + t.Run("no text or blob", func(t *testing.T) { + contentMap := map[string]any{ + "uri": "file:///test", + } + + _, err := ParseResourceContents(contentMap) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported resource type") + }) +} + +// Test ParseGetPromptResult with malformed JSON + +func TestParseGetPromptResult_Errors(t *testing.T) { + t.Run("nil raw message", func(t *testing.T) { + _, err := ParseGetPromptResult(nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nil") + }) + + t.Run("invalid JSON", func(t *testing.T) { + raw := json.RawMessage(`{invalid json}`) + _, err := ParseGetPromptResult(&raw) + assert.Error(t, err) + }) + + t.Run("messages not array", func(t *testing.T) { + raw := json.RawMessage(`{"messages": "not an array"}`) + _, err := ParseGetPromptResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an array") + }) + + t.Run("message not object", func(t *testing.T) { + raw := json.RawMessage(`{"messages": ["not an object"]}`) + _, err := ParseGetPromptResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an object") + }) + + t.Run("unsupported role", func(t *testing.T) { + raw := json.RawMessage(`{ + "messages": [{ + "role": "system", + "content": {"type": "text", "text": "hello"} + }] + }`) + _, err := ParseGetPromptResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported role") + }) + + t.Run("content not object", func(t *testing.T) { + raw := json.RawMessage(`{ + "messages": [{ + "role": "user", + "content": "not an object" + }] + }`) + _, err := ParseGetPromptResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an object") + }) +} + +// Test ParseCallToolResult with malformed JSON + +func TestParseCallToolResult_Errors(t *testing.T) { + t.Run("nil raw message", func(t *testing.T) { + _, err := ParseCallToolResult(nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nil") + }) + + t.Run("invalid JSON", func(t *testing.T) { + raw := json.RawMessage(`{invalid}`) + _, err := ParseCallToolResult(&raw) + assert.Error(t, err) + }) + + t.Run("missing content", func(t *testing.T) { + raw := json.RawMessage(`{"isError": false}`) + _, err := ParseCallToolResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "content is missing") + }) + + t.Run("content not array", func(t *testing.T) { + raw := json.RawMessage(`{"content": "not an array"}`) + _, err := ParseCallToolResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an array") + }) + + t.Run("content item not object", func(t *testing.T) { + raw := json.RawMessage(`{"content": ["not an object"]}`) + _, err := ParseCallToolResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an object") + }) +} + +// Test ParseReadResourceResult with malformed JSON + +func TestParseReadResourceResult_Errors(t *testing.T) { + t.Run("nil raw message", func(t *testing.T) { + _, err := ParseReadResourceResult(nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nil") + }) + + t.Run("invalid JSON", func(t *testing.T) { + raw := json.RawMessage(`{bad json}`) + _, err := ParseReadResourceResult(&raw) + assert.Error(t, err) + }) + + t.Run("missing contents", func(t *testing.T) { + raw := json.RawMessage(`{}`) + _, err := ParseReadResourceResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "contents is missing") + }) + + t.Run("contents not array", func(t *testing.T) { + raw := json.RawMessage(`{"contents": "not an array"}`) + _, err := ParseReadResourceResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an array") + }) + + t.Run("content item not object", func(t *testing.T) { + raw := json.RawMessage(`{"contents": [123]}`) + _, err := ParseReadResourceResult(&raw) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not an object") + }) +} + +// Test ParseStringMap + +func TestParseStringMap(t *testing.T) { + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "valid_map": map[string]any{ + "key1": "value1", + "key2": 123, + }, + "not_a_map": "string value", + } + + t.Run("valid map", func(t *testing.T) { + result := ParseStringMap(req, "valid_map", nil) + require.NotNil(t, result) + assert.Equal(t, "value1", result["key1"]) + assert.Equal(t, 123, result["key2"]) + }) + + t.Run("invalid type returns empty map", func(t *testing.T) { + defaultMap := map[string]any{"default": "value"} + result := ParseStringMap(req, "not_a_map", defaultMap) + // cast.ToStringMap returns empty map when it can't convert + assert.Equal(t, map[string]any{}, result) + }) + + t.Run("missing key returns converted default", func(t *testing.T) { + defaultMap := map[string]any{"default": "value"} + result := ParseStringMap(req, "missing", defaultMap) + // ParseArgument returns the default, which is then converted by cast.ToStringMap + assert.Equal(t, defaultMap, result) + }) +} + +// Test ExtractMap and ExtractString + +func TestExtractMap(t *testing.T) { + data := map[string]any{ + "nested": map[string]any{ + "key": "value", + }, + "not_map": "string", + } + + t.Run("valid map", func(t *testing.T) { + result := ExtractMap(data, "nested") + require.NotNil(t, result) + assert.Equal(t, "value", result["key"]) + }) + + t.Run("not a map", func(t *testing.T) { + result := ExtractMap(data, "not_map") + assert.Nil(t, result) + }) + + t.Run("missing key", func(t *testing.T) { + result := ExtractMap(data, "missing") + assert.Nil(t, result) + }) +} + +func TestExtractString(t *testing.T) { + data := map[string]any{ + "string_val": "hello", + "int_val": 123, + } + + t.Run("valid string", func(t *testing.T) { + result := ExtractString(data, "string_val") + assert.Equal(t, "hello", result) + }) + + t.Run("not a string", func(t *testing.T) { + result := ExtractString(data, "int_val") + assert.Equal(t, "", result) + }) + + t.Run("missing key", func(t *testing.T) { + result := ExtractString(data, "missing") + assert.Equal(t, "", result) + }) +} + +// Test all ParseXXX functions + +func TestParseIntVariants(t *testing.T) { + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "valid": "42", + } + + t.Run("ParseInt32", func(t *testing.T) { + result := ParseInt32(req, "valid", 0) + assert.Equal(t, int32(42), result) + + result = ParseInt32(req, "missing", 10) + assert.Equal(t, int32(10), result) + }) + + t.Run("ParseInt16", func(t *testing.T) { + result := ParseInt16(req, "valid", 0) + assert.Equal(t, int16(42), result) + }) + + t.Run("ParseInt8", func(t *testing.T) { + result := ParseInt8(req, "valid", 0) + assert.Equal(t, int8(42), result) + }) + + t.Run("ParseUInt", func(t *testing.T) { + result := ParseUInt(req, "valid", 0) + assert.Equal(t, uint(42), result) + }) + + t.Run("ParseUInt64", func(t *testing.T) { + result := ParseUInt64(req, "valid", 0) + assert.Equal(t, uint64(42), result) + }) + + t.Run("ParseUInt32", func(t *testing.T) { + result := ParseUInt32(req, "valid", 0) + assert.Equal(t, uint32(42), result) + }) + + t.Run("ParseUInt16", func(t *testing.T) { + result := ParseUInt16(req, "valid", 0) + assert.Equal(t, uint16(42), result) + }) + + t.Run("ParseUInt8", func(t *testing.T) { + result := ParseUInt8(req, "valid", 0) + assert.Equal(t, uint8(42), result) + }) +} + +func TestParseFloatVariants(t *testing.T) { + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "valid": "3.14", + } + + t.Run("ParseFloat32", func(t *testing.T) { + result := ParseFloat32(req, "valid", 0.0) + assert.InDelta(t, float32(3.14), result, 0.001) + + result = ParseFloat32(req, "missing", 1.5) + assert.Equal(t, float32(1.5), result) + }) + + t.Run("ParseFloat64", func(t *testing.T) { + result := ParseFloat64(req, "valid", 0.0) + assert.InDelta(t, 3.14, result, 0.001) + }) +} + +func TestParseString(t *testing.T) { + req := CallToolRequest{} + req.Params.Arguments = map[string]any{ + "valid": "hello", + "int": 123, + } + + t.Run("valid string", func(t *testing.T) { + result := ParseString(req, "valid", "") + assert.Equal(t, "hello", result) + }) + + t.Run("converts int to string", func(t *testing.T) { + result := ParseString(req, "int", "") + assert.Equal(t, "123", result) + }) + + t.Run("missing returns default", func(t *testing.T) { + result := ParseString(req, "missing", "default") + assert.Equal(t, "default", result) + }) +} + +// Test ToBoolPtr + +func TestToBoolPtr(t *testing.T) { + t.Run("true", func(t *testing.T) { + result := ToBoolPtr(true) + require.NotNil(t, result) + assert.True(t, *result) + }) + + t.Run("false", func(t *testing.T) { + result := ToBoolPtr(false) + require.NotNil(t, result) + assert.False(t, *result) + }) +} + +// Test NewToolResultJSON with error + +func TestNewToolResultJSON_Error(t *testing.T) { + // Create a type that can't be marshaled + type BadType struct { + Func func() // functions can't be marshaled + } + + _, err := NewToolResultJSON(BadType{Func: func() {}}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unable to marshal JSON") +} + +// Test FormatNumberResult + +func TestFormatNumberResult(t *testing.T) { + result := FormatNumberResult(42.5678) + + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, "42.57", textContent.Text) +} + +// Test NewToolResultErrorFromErr with nil error + +func TestNewToolResultErrorFromErr_NilError(t *testing.T) { + result := NewToolResultErrorFromErr("test error", nil) + + assert.True(t, result.IsError) + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, "test error", textContent.Text) +} + +func TestNewToolResultErrorFromErr_WithError(t *testing.T) { + result := NewToolResultErrorFromErr("test error", errors.New("underlying error")) + + assert.True(t, result.IsError) + require.Len(t, result.Content, 1) + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Contains(t, textContent.Text, "test error") + assert.Contains(t, textContent.Text, "underlying error") +} diff --git a/mcp/utils_helpers_test.go b/mcp/utils_helpers_test.go new file mode 100644 index 000000000..5f7354562 --- /dev/null +++ b/mcp/utils_helpers_test.go @@ -0,0 +1,210 @@ +package mcp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test helper functions with 0% coverage + +func TestNewProgressNotification(t *testing.T) { + token := ProgressToken("test-token") + progress := 50.0 + total := 100.0 + message := "Processing..." + + result := NewProgressNotification(token, progress, &total, &message) + + assert.Equal(t, "notifications/progress", result.Method) + assert.Equal(t, token, result.Params.ProgressToken) + assert.Equal(t, progress, result.Params.Progress) + assert.Equal(t, total, result.Params.Total) + assert.Equal(t, message, result.Params.Message) +} + +func TestNewProgressNotification_WithNils(t *testing.T) { + token := ProgressToken("test-token") + progress := 50.0 + + result := NewProgressNotification(token, progress, nil, nil) + + assert.Equal(t, "notifications/progress", result.Method) + assert.Equal(t, token, result.Params.ProgressToken) + assert.Equal(t, progress, result.Params.Progress) + assert.Equal(t, 0.0, result.Params.Total) + assert.Equal(t, "", result.Params.Message) +} + +func TestNewLoggingMessageNotification(t *testing.T) { + level := LoggingLevelInfo + logger := "test-logger" + data := map[string]any{"key": "value"} + + result := NewLoggingMessageNotification(level, logger, data) + + assert.Equal(t, "notifications/message", result.Method) + assert.Equal(t, level, result.Params.Level) + assert.Equal(t, logger, result.Params.Logger) + assert.Equal(t, data, result.Params.Data) +} + +func TestNewToolResultImage(t *testing.T) { + text := "Image result" + imageData := "base64imagedata" + mimeType := "image/png" + + result := NewToolResultImage(text, imageData, mimeType) + + require.Len(t, result.Content, 2) + + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, text, textContent.Text) + + imageContent, ok := result.Content[1].(ImageContent) + require.True(t, ok) + assert.Equal(t, imageData, imageContent.Data) + assert.Equal(t, mimeType, imageContent.MIMEType) +} + +func TestNewToolResultAudio(t *testing.T) { + text := "Audio result" + audioData := "base64audiodata" + mimeType := "audio/mp3" + + result := NewToolResultAudio(text, audioData, mimeType) + + require.Len(t, result.Content, 2) + + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, text, textContent.Text) + + audioContent, ok := result.Content[1].(AudioContent) + require.True(t, ok) + assert.Equal(t, audioData, audioContent.Data) + assert.Equal(t, mimeType, audioContent.MIMEType) +} + +func TestNewToolResultResource(t *testing.T) { + text := "Resource result" + resource := TextResourceContents{ + URI: "file:///test.txt", + MIMEType: "text/plain", + Text: "content", + } + + result := NewToolResultResource(text, resource) + + require.Len(t, result.Content, 2) + + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, text, textContent.Text) + + embeddedResource, ok := result.Content[1].(EmbeddedResource) + require.True(t, ok) + assert.Equal(t, resource, embeddedResource.Resource) +} + +func TestNewToolResultErrorf(t *testing.T) { + result := NewToolResultErrorf("error code: %d, message: %s", 404, "not found") + + assert.True(t, result.IsError) + require.Len(t, result.Content, 1) + + textContent, ok := result.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, "error code: 404, message: not found", textContent.Text) +} + +func TestNewListResourcesResult(t *testing.T) { + resources := []Resource{ + {URI: "file:///test1.txt", Name: "test1.txt"}, + {URI: "file:///test2.txt", Name: "test2.txt"}, + } + cursor := Cursor("next-page") + + result := NewListResourcesResult(resources, cursor) + + assert.Equal(t, resources, result.Resources) + assert.Equal(t, cursor, result.NextCursor) +} + +func TestNewListResourceTemplatesResult(t *testing.T) { + templates := []ResourceTemplate{ + {Name: "template1"}, + {Name: "template2"}, + } + cursor := Cursor("next-page") + + result := NewListResourceTemplatesResult(templates, cursor) + + assert.Equal(t, templates, result.ResourceTemplates) + assert.Equal(t, cursor, result.NextCursor) +} + +func TestNewReadResourceResult(t *testing.T) { + text := "file content" + + result := NewReadResourceResult(text) + + require.Len(t, result.Contents, 1) + textContents, ok := result.Contents[0].(TextResourceContents) + require.True(t, ok) + assert.Equal(t, text, textContents.Text) +} + +func TestNewListPromptsResult(t *testing.T) { + prompts := []Prompt{ + {Name: "prompt1"}, + {Name: "prompt2"}, + } + cursor := Cursor("next-page") + + result := NewListPromptsResult(prompts, cursor) + + assert.Equal(t, prompts, result.Prompts) + assert.Equal(t, cursor, result.NextCursor) +} + +func TestNewGetPromptResult(t *testing.T) { + description := "Test prompt" + messages := []PromptMessage{ + {Role: RoleUser, Content: TextContent{Text: "Hello"}}, + } + + result := NewGetPromptResult(description, messages) + + assert.Equal(t, description, result.Description) + assert.Equal(t, messages, result.Messages) +} + +func TestNewListToolsResult(t *testing.T) { + tools := []Tool{ + {Name: "tool1"}, + {Name: "tool2"}, + } + cursor := Cursor("next-page") + + result := NewListToolsResult(tools, cursor) + + assert.Equal(t, tools, result.Tools) + assert.Equal(t, cursor, result.NextCursor) +} + +func TestNewInitializeResult(t *testing.T) { + version := "1.0" + capabilities := ServerCapabilities{} + serverInfo := Implementation{Name: "test-server", Version: "1.0"} + instructions := "Use this server carefully" + + result := NewInitializeResult(version, capabilities, serverInfo, instructions) + + assert.Equal(t, version, result.ProtocolVersion) + assert.Equal(t, capabilities, result.Capabilities) + assert.Equal(t, serverInfo, result.ServerInfo) + assert.Equal(t, instructions, result.Instructions) +} diff --git a/mcp/utils_test.go b/mcp/utils_test.go index aad89b64e..fb6ca2aaf 100644 --- a/mcp/utils_test.go +++ b/mcp/utils_test.go @@ -1,9 +1,9 @@ package mcp import ( - "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/require" + "testing" ) func TestParseAnnotations(t *testing.T) { diff --git a/server/server.go b/server/server.go index e9b48e6dd..5234bd6c3 100644 --- a/server/server.go +++ b/server/server.go @@ -124,7 +124,7 @@ func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, ID: mcp.NewRequestId(e.id), - Error: mcp.NewJSONRPCErrorDetails(e.code, e.err.Error(), nil), + Error: mcp.NewJSONRPCErrorDetails(e.code, e.err.Error(), nil), } } @@ -1215,6 +1215,6 @@ func createErrorResponse( return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, ID: mcp.NewRequestId(id), - Error: mcp.NewJSONRPCErrorDetails(code, message, nil), + Error: mcp.NewJSONRPCErrorDetails(code, message, nil), } } diff --git a/server/server_additional_test.go b/server/server_additional_test.go new file mode 100644 index 000000000..4da5708a0 --- /dev/null +++ b/server/server_additional_test.go @@ -0,0 +1,889 @@ +package server + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestMCPServer_MiddlewarePanicRecovery tests that panics in middleware are properly recovered +func TestMCPServer_MiddlewarePanicRecovery(t *testing.T) { + t.Run("tool handler panic with recovery middleware", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithRecovery()) + + server.AddTool( + mcp.NewTool("panic-tool"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + panic("intentional panic in tool handler") + }, + ) + + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "panic-tool" + } + }`)) + + errorResponse, ok := response.(mcp.JSONRPCError) + require.True(t, ok) + assert.Equal(t, mcp.INTERNAL_ERROR, errorResponse.Error.Code) + assert.Contains(t, errorResponse.Error.Message, "panic recovered") + assert.Contains(t, errorResponse.Error.Message, "intentional panic in tool handler") + }) + + t.Run("resource handler panic with recovery middleware", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(false, false), + WithResourceRecovery(), + ) + + server.AddResource( + mcp.Resource{URI: "test://panic-resource", Name: "Panic Resource"}, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + panic("intentional panic in resource handler") + }, + ) + + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": { + "uri": "test://panic-resource" + } + }`)) + + errorResponse, ok := response.(mcp.JSONRPCError) + require.True(t, ok) + assert.Equal(t, mcp.INTERNAL_ERROR, errorResponse.Error.Code) + assert.Contains(t, errorResponse.Error.Message, "panic recovered") + assert.Contains(t, errorResponse.Error.Message, "intentional panic in resource handler") + }) +} + +// TestMCPServer_ConcurrentOperations tests thread safety of Add/Delete operations +func TestMCPServer_ConcurrentOperations(t *testing.T) { + t.Run("concurrent tool add/delete", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(false)) + + var wg sync.WaitGroup + numGoroutines := 10 + operationsPerGoroutine := 50 + + // Concurrent add operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerGoroutine; j++ { + toolName := fmt.Sprintf("tool-%d-%d", id, j) + server.AddTool(mcp.NewTool(toolName), nil) + } + }(i) + } + + wg.Wait() + + // Verify all tools were added + tools := server.ListTools() + assert.Len(t, tools, numGoroutines*operationsPerGoroutine) + + // Concurrent delete operations + wg = sync.WaitGroup{} + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerGoroutine; j++ { + toolName := fmt.Sprintf("tool-%d-%d", id, j) + server.DeleteTools(toolName) + } + }(i) + } + + wg.Wait() + + // Verify all tools were deleted + tools = server.ListTools() + assert.Nil(t, tools) + }) + + t.Run("concurrent resource add/delete", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, false)) + + var wg sync.WaitGroup + numGoroutines := 10 + operationsPerGoroutine := 50 + + // Concurrent add operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerGoroutine; j++ { + uri := fmt.Sprintf("test://resource-%d-%d", id, j) + server.AddResource( + mcp.Resource{URI: uri, Name: uri}, + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return nil, nil + }, + ) + } + }(i) + } + + wg.Wait() + + // Concurrent delete operations + wg = sync.WaitGroup{} + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerGoroutine; j++ { + uri := fmt.Sprintf("test://resource-%d-%d", id, j) + server.DeleteResources(uri) + } + }(i) + } + + wg.Wait() + }) + + t.Run("concurrent prompt add/delete", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(false)) + + var wg sync.WaitGroup + numGoroutines := 10 + operationsPerGoroutine := 50 + + // Concurrent add operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerGoroutine; j++ { + promptName := fmt.Sprintf("prompt-%d-%d", id, j) + server.AddPrompt( + mcp.Prompt{Name: promptName, Description: promptName}, + nil, + ) + } + }(i) + } + + wg.Wait() + + // Concurrent delete operations + wg = sync.WaitGroup{} + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerGoroutine; j++ { + promptName := fmt.Sprintf("prompt-%d-%d", id, j) + server.DeletePrompts(promptName) + } + }(i) + } + + wg.Wait() + }) +} + +// TestMCPServer_PaginationEdgeCases tests pagination boundary conditions +func TestMCPServer_PaginationEdgeCases(t *testing.T) { + t.Run("malformed cursor - invalid base64", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(false, false), + WithPaginationLimit(5), + ) + + // Add some resources + for i := 0; i < 10; i++ { + uri := fmt.Sprintf("test://resource-%d", i) + server.AddResource( + mcp.Resource{URI: uri, Name: fmt.Sprintf("Resource %d", i)}, + nil, + ) + } + + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list", + "params": { + "cursor": "not-valid-base64!!!" + } + }`)) + + errorResponse, ok := response.(mcp.JSONRPCError) + require.True(t, ok) + assert.Equal(t, mcp.INVALID_PARAMS, errorResponse.Error.Code) + }) + + t.Run("cursor pointing beyond list", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", + WithToolCapabilities(false), + WithPaginationLimit(5), + ) + + // Add 3 tools + for i := 0; i < 3; i++ { + server.AddTool(mcp.NewTool(fmt.Sprintf("tool-%d", i)), nil) + } + + // Create cursor that points beyond the list + beyondCursor := base64.StdEncoding.EncodeToString([]byte("tool-99")) + + response := server.HandleMessage(context.Background(), []byte(fmt.Sprintf(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": { + "cursor": "%s" + } + }`, beyondCursor))) + + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok := resp.Result.(mcp.ListToolsResult) + require.True(t, ok) + + // Should return empty list with no cursor + assert.Empty(t, result.Tools) + assert.Empty(t, result.NextCursor) + }) + + t.Run("pagination with exactly paginationLimit items", func(t *testing.T) { + limit := 5 + server := NewMCPServer("test-server", "1.0.0", + WithPromptCapabilities(false), + WithPaginationLimit(limit), + ) + + // Add exactly paginationLimit prompts + for i := 0; i < limit; i++ { + server.AddPrompt( + mcp.Prompt{Name: fmt.Sprintf("prompt-%d", i), Description: "Test"}, + nil, + ) + } + + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/list" + }`)) + + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok := resp.Result.(mcp.ListPromptsResult) + require.True(t, ok) + + // Should return all items with cursor pointing to last item + assert.Len(t, result.Prompts, limit) + assert.NotEmpty(t, result.NextCursor, "Cursor should be set when exactly at limit") + + // Request next page - should be empty + response = server.HandleMessage(context.Background(), []byte(fmt.Sprintf(`{ + "jsonrpc": "2.0", + "id": 2, + "method": "prompts/list", + "params": { + "cursor": "%s" + } + }`, result.NextCursor))) + + resp, ok = response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok = resp.Result.(mcp.ListPromptsResult) + require.True(t, ok) + + assert.Empty(t, result.Prompts) + assert.Empty(t, result.NextCursor) + }) + + t.Run("empty list pagination", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", + WithResourceCapabilities(false, false), + WithPaginationLimit(10), + ) + + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list" + }`)) + + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok := resp.Result.(mcp.ListResourcesResult) + require.True(t, ok) + + assert.Empty(t, result.Resources) + assert.Empty(t, result.NextCursor) + }) +} + +// TestMCPServer_SessionUnregistrationDuringNotification tests race conditions +func TestMCPServer_SessionUnregistrationDuringNotification(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + // Create multiple sessions + numSessions := 10 + sessions := make([]*sessionTestClient, numSessions) + for i := 0; i < numSessions; i++ { + sessions[i] = &sessionTestClient{ + sessionID: fmt.Sprintf("session-%d", i), + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + } + sessions[i].Initialize() + err := server.RegisterSession(context.Background(), sessions[i]) + require.NoError(t, err) + } + + var wg sync.WaitGroup + stopCh := make(chan struct{}) + + // Goroutine that continuously sends notifications + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stopCh: + return + default: + server.SendNotificationToAllClients("test-notification", map[string]any{ + "data": "test", + }) + time.Sleep(1 * time.Millisecond) + } + } + }() + + // Goroutine that unregisters sessions + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < numSessions; i++ { + time.Sleep(5 * time.Millisecond) + server.UnregisterSession(context.Background(), sessions[i].SessionID()) + } + }() + + // Let it run for a bit + time.Sleep(100 * time.Millisecond) + close(stopCh) + wg.Wait() + + // Should complete without panic or deadlock +} + +// TestMCPServer_DuplicateSessionRegistration tests that duplicate session IDs are rejected +func TestMCPServer_DuplicateSessionRegistration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + session1 := &sessionTestClient{ + sessionID: "duplicate-id", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + } + + session2 := &sessionTestClient{ + sessionID: "duplicate-id", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + } + + // First registration should succeed + err := server.RegisterSession(context.Background(), session1) + require.NoError(t, err) + + // Second registration with same ID should fail + err = server.RegisterSession(context.Background(), session2) + require.Error(t, err) + assert.ErrorIs(t, err, ErrSessionExists) +} + +// TestMCPServer_SessionToolOperationsAfterUnregister tests operations on removed sessions +func TestMCPServer_SessionToolOperationsAfterUnregister(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + session := &sessionTestClientWithTools{ + sessionID: "test-session", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + } + + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Add a tool to the session + err = server.AddSessionTool(session.SessionID(), mcp.NewTool("test-tool"), nil) + require.NoError(t, err) + + // Unregister the session + server.UnregisterSession(context.Background(), session.SessionID()) + + // Try to add tool to unregistered session + err = server.AddSessionTool(session.SessionID(), mcp.NewTool("another-tool"), nil) + require.Error(t, err) + assert.ErrorIs(t, err, ErrSessionNotFound) + + // Try to delete tool from unregistered session + err = server.DeleteSessionTools(session.SessionID(), "test-tool") + require.Error(t, err) + assert.ErrorIs(t, err, ErrSessionNotFound) +} + +// TestMCPServer_ResourceTemplateURIMatching tests URI template edge cases +func TestMCPServer_ResourceTemplateURIMatching(t *testing.T) { + tests := []struct { + name string + templateURI string + requestURI string + shouldMatch bool + expectedArgs map[string]any + setupTemplate func(*MCPServer, string) + validateRequest func(*testing.T, mcp.ReadResourceRequest) + }{ + { + name: "exact match no variables", + templateURI: "test://fixed/path", + requestURI: "test://fixed/path", + shouldMatch: true, + setupTemplate: func(s *MCPServer, uri string) { + s.AddResourceTemplate( + mcp.NewResourceTemplate(uri, "Test"), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: request.Params.URI, + Text: "matched", + }, + }, nil + }, + ) + }, + }, + { + name: "single variable match", + templateURI: "test://users/{id}", + requestURI: "test://users/123", + shouldMatch: true, + expectedArgs: map[string]any{ + "id": []string{"123"}, + }, + setupTemplate: func(s *MCPServer, uri string) { + s.AddResourceTemplate( + mcp.NewResourceTemplate(uri, "User"), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: request.Params.URI, + Text: fmt.Sprintf("user-id: %v", request.Params.Arguments["id"]), + }, + }, nil + }, + ) + }, + validateRequest: func(t *testing.T, request mcp.ReadResourceRequest) { + assert.NotNil(t, request.Params.Arguments) + assert.Equal(t, []string{"123"}, request.Params.Arguments["id"]) + }, + }, + { + name: "path explosion match", + templateURI: "test://files{/path*}", + requestURI: "test://files/a/b/c", + shouldMatch: true, + expectedArgs: map[string]any{ + "path": []string{"a", "b", "c"}, + }, + setupTemplate: func(s *MCPServer, uri string) { + s.AddResourceTemplate( + mcp.NewResourceTemplate(uri, "Files"), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: request.Params.URI, + Text: fmt.Sprintf("path: %v", request.Params.Arguments["path"]), + }, + }, nil + }, + ) + }, + validateRequest: func(t *testing.T, request mcp.ReadResourceRequest) { + assert.NotNil(t, request.Params.Arguments) + pathParts, ok := request.Params.Arguments["path"].([]string) + require.True(t, ok) + assert.Equal(t, []string{"a", "b", "c"}, pathParts) + }, + }, + { + name: "no match - different scheme", + templateURI: "test://resource", + requestURI: "other://resource", + shouldMatch: false, + setupTemplate: func(s *MCPServer, uri string) { + s.AddResourceTemplate( + mcp.NewResourceTemplate(uri, "Test"), + func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return nil, nil + }, + ) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithResourceCapabilities(false, false)) + tt.setupTemplate(server, tt.templateURI) + + requestBytes, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": map[string]any{ + "uri": tt.requestURI, + }, + }) + require.NoError(t, err) + + response := server.HandleMessage(context.Background(), requestBytes) + + if tt.shouldMatch { + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok, "Expected successful response for matching URI") + + result, ok := resp.Result.(mcp.ReadResourceResult) + require.True(t, ok) + require.NotEmpty(t, result.Contents) + + // Validate request if validator provided + if tt.validateRequest != nil { + // We need to capture the request in the handler to validate it + // This is a bit tricky, so we'll just check the expected args + if tt.expectedArgs != nil { + content := result.Contents[0].(mcp.TextResourceContents) + // The text should contain our arguments + for key, expectedVal := range tt.expectedArgs { + assert.Contains(t, content.Text, fmt.Sprintf("%v", expectedVal), + "Response should contain argument %s=%v", key, expectedVal) + } + } + } + } else { + errorResp, ok := response.(mcp.JSONRPCError) + require.True(t, ok, "Expected error response for non-matching URI") + assert.Equal(t, mcp.RESOURCE_NOT_FOUND, errorResp.Error.Code) + } + }) + } +} + +// TestMCPServer_UnsupportedProtocolVersions tests client/server version negotiation +func TestMCPServer_UnsupportedProtocolVersions(t *testing.T) { + tests := []struct { + name string + clientVersion string + expectedVersion string + description string + }{ + { + name: "ancient unsupported version", + clientVersion: "2020-01-01", + expectedVersion: mcp.LATEST_PROTOCOL_VERSION, + description: "Server should respond with its latest version", + }, + { + name: "future unsupported version", + clientVersion: "2030-12-31", + expectedVersion: mcp.LATEST_PROTOCOL_VERSION, + description: "Server should respond with its latest version", + }, + { + name: "supported version", + clientVersion: "2024-11-05", + expectedVersion: "2024-11-05", + description: "Server should respond with client's version if supported", + }, + { + name: "latest supported version", + clientVersion: mcp.LATEST_PROTOCOL_VERSION, + expectedVersion: mcp.LATEST_PROTOCOL_VERSION, + description: "Server should respond with matching version", + }, + { + name: "empty version defaults to 2025-03-26", + clientVersion: "", + expectedVersion: "2025-03-26", + description: "Backward compatibility: empty version defaults to 2025-03-26", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = tt.clientVersion + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + requestBytes, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": initRequest.Params, + }) + require.NoError(t, err) + + response := server.HandleMessage(context.Background(), requestBytes) + + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok := resp.Result.(mcp.InitializeResult) + require.True(t, ok) + + assert.Equal(t, tt.expectedVersion, result.ProtocolVersion, tt.description) + }) + } +} + +// TestMCPServer_HooksWithNilSession tests that hooks handle nil session contexts gracefully +func TestMCPServer_HooksWithNilSession(t *testing.T) { + hookCalled := false + var receivedSession ClientSession + + hooks := &Hooks{} + hooks.AddBeforeAny(func(ctx context.Context, id any, method mcp.MCPMethod, message any) { + hookCalled = true + receivedSession = ClientSessionFromContext(ctx) + }) + + server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) + + // Make request without session context + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "ping" + }`)) + + require.NotNil(t, response) + assert.True(t, hookCalled, "Hook should be called even without session") + assert.Nil(t, receivedSession, "Session should be nil when not in context") +} + +// TestMCPServer_CapabilityImplicitRegistration tests edge cases in capability registration +func TestMCPServer_CapabilityImplicitRegistration(t *testing.T) { + t.Run("implicit registration after explicit false", func(t *testing.T) { + // If user explicitly sets listChanged=false, adding tools shouldn't override it + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(false)) + + server.capabilitiesMu.RLock() + initialListChanged := server.capabilities.tools.listChanged + server.capabilitiesMu.RUnlock() + assert.False(t, initialListChanged) + + // Add a tool - should not change listChanged to true + server.AddTool(mcp.NewTool("test-tool"), nil) + + server.capabilitiesMu.RLock() + finalListChanged := server.capabilities.tools.listChanged + server.capabilitiesMu.RUnlock() + assert.False(t, finalListChanged, "Explicit false should not be overridden by implicit registration") + }) + + t.Run("implicit registration when no capability set", func(t *testing.T) { + // If no capability was set, adding tools should enable it with listChanged=true + server := NewMCPServer("test-server", "1.0.0") + + server.capabilitiesMu.RLock() + initialTools := server.capabilities.tools + server.capabilitiesMu.RUnlock() + assert.Nil(t, initialTools) + + // Add a tool - should implicitly register with listChanged=true + server.AddTool(mcp.NewTool("test-tool"), nil) + + server.capabilitiesMu.RLock() + finalTools := server.capabilities.tools + server.capabilitiesMu.RUnlock() + require.NotNil(t, finalTools) + assert.True(t, finalTools.listChanged) + }) + + t.Run("resources implicit registration", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + // Initially nil + server.capabilitiesMu.RLock() + initialResources := server.capabilities.resources + server.capabilitiesMu.RUnlock() + assert.Nil(t, initialResources) + + // Add resource - should implicitly register + server.AddResource( + mcp.Resource{URI: "test://resource", Name: "Test"}, + nil, + ) + + server.capabilitiesMu.RLock() + finalResources := server.capabilities.resources + server.capabilitiesMu.RUnlock() + require.NotNil(t, finalResources) + // For resources, implicit registration doesn't set listChanged + assert.False(t, finalResources.listChanged) + }) + + t.Run("prompts implicit registration", func(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + // Initially nil + server.capabilitiesMu.RLock() + initialPrompts := server.capabilities.prompts + server.capabilitiesMu.RUnlock() + assert.Nil(t, initialPrompts) + + // Add prompt - should implicitly register + server.AddPrompt( + mcp.Prompt{Name: "test-prompt", Description: "Test"}, + nil, + ) + + server.capabilitiesMu.RLock() + finalPrompts := server.capabilities.prompts + server.capabilitiesMu.RUnlock() + require.NotNil(t, finalPrompts) + // For prompts, implicit registration doesn't set listChanged + assert.False(t, finalPrompts.listChanged) + }) +} + +// TestMCPServer_ConcurrentCapabilityChecks tests thread safety of capability access +func TestMCPServer_ConcurrentCapabilityChecks(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + var wg sync.WaitGroup + stopCh := make(chan struct{}) + errorCount := atomic.Int32{} + + // Goroutine that adds tools (triggers capability registration) + wg.Add(1) + go func() { + defer wg.Done() + i := 0 + for { + select { + case <-stopCh: + return + default: + server.AddTool(mcp.NewTool(fmt.Sprintf("tool-%d", i)), nil) + i++ + time.Sleep(1 * time.Millisecond) + } + } + }() + + // Goroutines that check capabilities + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stopCh: + return + default: + server.capabilitiesMu.RLock() + _ = server.capabilities.tools + server.capabilitiesMu.RUnlock() + time.Sleep(1 * time.Millisecond) + } + } + }() + } + + // Let it run for a bit + time.Sleep(50 * time.Millisecond) + close(stopCh) + wg.Wait() + + assert.Equal(t, int32(0), errorCount.Load(), "Should complete without errors") +} + +// TestMCPServer_PaginationCursorStability tests pagination behavior when items change +func TestMCPServer_PaginationCursorStability(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", + WithToolCapabilities(false), + WithPaginationLimit(5), + ) + + // Add initial tools + for i := 0; i < 10; i++ { + server.AddTool(mcp.NewTool(fmt.Sprintf("tool-%02d", i)), nil) + } + + // Get first page + response := server.HandleMessage(context.Background(), []byte(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + }`)) + + resp, ok := response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok := resp.Result.(mcp.ListToolsResult) + require.True(t, ok) + + assert.Len(t, result.Tools, 5) + cursor := result.NextCursor + require.NotEmpty(t, cursor) + + // Modify list (add and remove tools) + server.AddTool(mcp.NewTool("tool-new-1"), nil) + server.DeleteTools("tool-05") + + // Get second page with original cursor + response = server.HandleMessage(context.Background(), []byte(fmt.Sprintf(`{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": { + "cursor": "%s" + } + }`, cursor))) + + resp, ok = response.(mcp.JSONRPCResponse) + require.True(t, ok) + + result, ok = resp.Result.(mcp.ListToolsResult) + require.True(t, ok) + + // Should handle gracefully (may have different results due to modifications) + // The key is that it shouldn't crash or return errors + assert.NotNil(t, result.Tools) +} diff --git a/server/stdio.go b/server/stdio.go index d80941c3d..80131f06c 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -606,7 +606,18 @@ func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { if err := json.Unmarshal(response.Result, &result); err != nil { samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) } else { - samplingResp.result = &result + // Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent) + if contentMap, ok := result.Content.(map[string]any); ok { + content, err := mcp.ParseContent(contentMap) + if err != nil { + samplingResp.err = fmt.Errorf("failed to parse sampling response content: %w", err) + } else { + result.Content = content + samplingResp.result = &result + } + } else { + samplingResp.result = &result + } } } diff --git a/server/streamable_http.go b/server/streamable_http.go index 056dc876c..9ad37fea1 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -309,7 +309,19 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } - session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + // Check if a persistent session exists (for sampling support), otherwise create ephemeral session + // Persistent sessions are created by GET (continuous listening) connections + var session *streamableHttpSession + if sessionInterface, exists := s.activeSessions.Load(sessionID); exists { + if persistentSession, ok := sessionInterface.(*streamableHttpSession); ok { + session = persistentSession + } + } + + // Create ephemeral session if no persistent session exists + if session == nil { + session = newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + } // Set the client context before handling the message ctx := s.server.WithContext(r.Context(), session) @@ -420,16 +432,23 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) sessionID = uuid.New().String() } - session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) - if err := s.server.RegisterSession(r.Context(), session); err != nil { - http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) - return + // Get or create session atomically to prevent TOCTOU races + // where concurrent GETs could both create and register duplicate sessions + var session *streamableHttpSession + newSession := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + actual, loaded := s.activeSessions.LoadOrStore(sessionID, newSession) + session = actual.(*streamableHttpSession) + + if !loaded { + // We created a new session, need to register it + if err := s.server.RegisterSession(r.Context(), session); err != nil { + s.activeSessions.Delete(sessionID) + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) + return + } + defer s.server.UnregisterSession(r.Context(), sessionID) + defer s.activeSessions.Delete(sessionID) } - defer s.server.UnregisterSession(r.Context(), sessionID) - - // Register session for sampling response delivery - s.activeSessions.Store(sessionID, session) - defer s.activeSessions.Delete(sessionID) // Set the client context before handling the message w.Header().Set("Content-Type", "text/event-stream") @@ -626,7 +645,8 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) } } else if responseMessage.Result != nil { - // Parse result + // Store the result to be unmarshaled later + response.result = responseMessage.Result } else { response.err = fmt.Errorf("sampling response has neither result nor error") } @@ -900,6 +920,17 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp if err := json.Unmarshal(response.result, &result); err != nil { return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err) } + + // Parse content from map[string]any to proper Content type (TextContent, ImageContent, AudioContent) + // HTTP transport unmarshals Content as map[string]any, we need to convert it to the proper type + if contentMap, ok := result.Content.(map[string]any); ok { + content, err := mcp.ParseContent(contentMap) + if err != nil { + return nil, fmt.Errorf("failed to parse sampling response content: %w", err) + } + result.Content = content + } + return &result, nil case <-ctx.Done(): return nil, ctx.Err() diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index b7bebc248..6a83f1841 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "os" + "strings" "github.com/mark3labs/mcp-go/mcp" ) @@ -18,19 +19,26 @@ type JSONRPCRequest struct { } type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID *mcp.RequestId `json:"id,omitempty"` - Result any `json:"result,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Result any `json:"result,omitempty"` Error *mcp.JSONRPCErrorDetails `json:"error,omitempty"` } func main() { logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) logger.Info("launch successful") - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { + reader := bufio.NewReader(os.Stdin) + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + + line = strings.TrimRight(line, "\r\n") + var request JSONRPCRequest - if err := json.Unmarshal(scanner.Bytes(), &request); err != nil { + if err := json.Unmarshal([]byte(line), &request); err != nil { continue } diff --git a/www/docs/pages/clients/advanced-sampling.mdx b/www/docs/pages/clients/advanced-sampling.mdx index 81a4cc9aa..02fc959a2 100644 --- a/www/docs/pages/clients/advanced-sampling.mdx +++ b/www/docs/pages/clients/advanced-sampling.mdx @@ -6,6 +6,24 @@ Learn how to implement MCP clients that can handle sampling requests from server Sampling allows MCP clients to respond to LLM completion requests from servers. When a server needs to generate content, answer questions, or perform reasoning tasks, it can send a sampling request to the client, which then processes it using an LLM and returns the result. +:::danger[Critical Security Requirement] +Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling#user-interaction-model), sampling implementations **SHOULD** always include a human in the loop with the ability to deny sampling requests. + +**You MUST implement approval flows that:** +- Present each sampling request to the user for review before execution +- Allow users to view and edit prompts before sending to the LLM +- Display generated responses for user approval before returning to the server +- Provide clear UI to accept or reject requests at each stage + +**Without human approval, your implementation:** +- Allows servers to make unauthorized LLM requests without user consent +- May expose sensitive information through unreviewed prompts +- Creates uncontrolled API costs from automated sampling +- Violates user trust and security best practices + +The examples below show basic handler implementation. **You must add approval logic** before using in production. +::: + ## Implementing a Sampling Handler Create a sampling handler by implementing the `SamplingHandler` interface: diff --git a/www/docs/pages/servers/advanced-sampling.mdx b/www/docs/pages/servers/advanced-sampling.mdx index 1bc05eb6e..4ae98fcd5 100644 --- a/www/docs/pages/servers/advanced-sampling.mdx +++ b/www/docs/pages/servers/advanced-sampling.mdx @@ -6,6 +6,25 @@ Learn how to implement MCP servers that can request LLM completions from clients Sampling allows MCP servers to request LLM completions from clients, enabling bidirectional communication where servers can leverage client-side LLM capabilities. This is particularly useful for tools that need to generate content, answer questions, or perform reasoning tasks. +:::info[User Consent Required] +Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling#user-interaction-model), clients **SHOULD** implement human-in-the-loop approval for sampling requests. + +When you request sampling from a client: +- The user will typically be prompted to review and approve your request +- The user may modify your prompts before sending to their LLM +- The user may reject your request entirely +- Response times may be longer due to user interaction + +**Design your tools accordingly:** +- Provide clear descriptions of why sampling is needed +- Use descriptive system prompts explaining the purpose +- Handle rejection errors gracefully +- Consider timeouts for user approval delays +- Don't assume immediate or automatic approval + +Well-designed sampling requests improve user trust and approval rates. +::: + ## Enabling Sampling To enable sampling in your server, call `EnableSampling()` during server setup: diff --git a/www/docs/pages/transports/http.mdx b/www/docs/pages/transports/http.mdx index 39089ad6b..24ed0a392 100644 --- a/www/docs/pages/transports/http.mdx +++ b/www/docs/pages/transports/http.mdx @@ -689,6 +689,160 @@ This works for all MCP request types including: The headers are automatically populated by the transport layer and are available in your handlers without any additional configuration. +## Sampling Support + +StreamableHTTP transport now supports bidirectional sampling, allowing servers to request LLM completions from clients. This enables advanced scenarios where servers can leverage client-side LLM capabilities. + +:::warning[Security: Human-in-the-Loop Required] +Per the [MCP specification](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling), implementations **SHOULD** always include a human in the loop with the ability to deny sampling requests. + +**Your sampling handler implementation MUST:** +- Present sampling requests to users for review before execution +- Allow users to view and edit prompts before sending to the LLM +- Display generated responses for approval before returning to the server +- Provide clear UI to accept or reject sampling requests + +Failing to implement approval flows creates serious security and trust risks, including: +- Servers making unauthorized LLM requests on behalf of users +- Exposure of sensitive data through unreviewed prompts +- Uncontrolled API costs from automated sampling +- Lack of user consent for AI interactions + +See the [example implementation](#example-with-approval-flow) below for a reference approval pattern. +::: + +### Requirements for Sampling + +To enable sampling with StreamableHTTP transport, the client **must** use the `WithContinuousListening()` option: + +```go +// Client setup with sampling support +httpTransport, err := transport.NewStreamableHTTP( + serverURL, + transport.WithContinuousListening(), // Required for sampling +) + +// Create client with sampling handler +mcpClient := client.NewClient(httpTransport, + client.WithSamplingHandler(samplingHandler)) +``` + +Without `WithContinuousListening()`, the client won't maintain a persistent connection to receive sampling requests from the server. + +### Server-Side Implementation + +Enable sampling in your StreamableHTTP server: + +```go +mcpServer := server.NewMCPServer("HTTP Sampling Server", "1.0.0") +mcpServer.EnableSampling() + +// Add a tool that uses sampling +mcpServer.AddTool(mcp.Tool{ + Name: "ask-llm", + Description: "Ask the LLM a question", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "Question to ask", + }, + }, + Required: []string{"question"}, + }, +}, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question := mcp.ParseString(req, "question", "") + + // Request sampling from client + samplingRequest := mcp.CreateMessageRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + MaxTokens: 1000, + }, + } + + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("Sampling failed: %v", err)), nil + } + + return mcp.NewToolResultText(mcp.GetTextFromContent(result.Content)), nil +}) +``` + +### How It Works + +1. **Persistent Connection**: When `WithContinuousListening()` is enabled, the client maintains a persistent SSE connection to the server +2. **Bidirectional Communication**: The server can send sampling requests through the SSE stream +3. **Response Channel**: The client responds to sampling requests via HTTP POST to the same endpoint +4. **Session Correlation**: Responses are correlated using session IDs to ensure they reach the correct handler + +### Limitations + +- Sampling requires `WithContinuousListening()` to maintain the SSE connection +- Without continuous listening, the transport operates in stateless request/response mode only +- Network interruptions may require reconnection and re-establishment of the sampling channel + +### Example with Approval Flow + +Here's a reference implementation showing proper human-in-the-loop approval: + +```go +type ApprovalSamplingHandler struct { + llmClient LLMClient // Your actual LLM client + ui UserInterface // Your UI for presenting requests to users +} + +func (h *ApprovalSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Step 1: Present the sampling request to the user for review + approved, modifiedRequest, err := h.ui.PresentSamplingRequest(ctx, request) + if err != nil { + return nil, fmt.Errorf("failed to get user approval: %w", err) + } + + if !approved { + return nil, fmt.Errorf("user rejected sampling request") + } + + // Step 2: Send the approved/modified request to the LLM + response, err := h.llmClient.CreateCompletion(ctx, modifiedRequest) + if err != nil { + return nil, fmt.Errorf("LLM request failed: %w", err) + } + + // Step 3: Present the response to the user for final approval + approved, modifiedResponse, err := h.ui.PresentSamplingResponse(ctx, response) + if err != nil { + return nil, fmt.Errorf("failed to get response approval: %w", err) + } + + if !approved { + return nil, fmt.Errorf("user rejected sampling response") + } + + // Step 4: Return the approved response to the server + return modifiedResponse, nil +} +``` + +**Key Points:** +- Users must explicitly approve both the request (before sending to LLM) and the response (before returning to server) +- Users can modify prompts or responses before approval +- Rejection at any stage returns an error to the server +- The UI should clearly display what the server is requesting and why + ## Next Steps - **[In-Process Transport](/transports/inprocess)** - Learn about embedded scenarios