diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 69557668..e4f26857 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -9,6 +9,7 @@ import ( "io" "os" "os/exec" + "strings" "sync" "github.com/mark3labs/mcp-go/mcp" @@ -27,7 +28,7 @@ type Stdio struct { cmd *exec.Cmd cmdFunc CommandFunc stdin io.WriteCloser - stdout *bufio.Scanner + stdout *bufio.Reader stderr io.ReadCloser responses map[string]chan *JSONRPCResponse mu sync.RWMutex @@ -72,7 +73,7 @@ func WithCommandLogger(logger util.Logger) StdioOption { func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio { return &Stdio{ stdin: output, - stdout: bufio.NewScanner(input), + stdout: bufio.NewReader(input), stderr: logging, responses: make(map[string]chan *JSONRPCResponse), @@ -180,7 +181,7 @@ func (c *Stdio) spawnCommand(ctx context.Context) error { c.cmd = cmd c.stdin = stdin c.stderr = stderr - c.stdout = bufio.NewScanner(stdout) + c.stdout = bufio.NewReader(stdout) if err := cmd.Start(); err != nil { return fmt.Errorf("failed to start command: %w", err) @@ -251,15 +252,15 @@ func (c *Stdio) readResponses() { case <-c.done: return default: - if !c.stdout.Scan() { - err := c.stdout.Err() - if err != nil && !errors.Is(err, context.Canceled) { + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF && !errors.Is(err, context.Canceled) { c.logger.Errorf("Error reading from stdout: %v", err) } return } - line := c.stdout.Text() + line = strings.TrimRight(line, "\r\n") // First try to parse as a generic message to check for ID field var baseMessage struct { JSONRPC string `json:"jsonrpc"` diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 56867acf..04bf18a6 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -703,3 +703,120 @@ func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) { require.NotNil(t, stdio) require.True(t, configured, "option was not applied") } + +func TestStdio_LargeMessages(t *testing.T) { + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + + if runtime.GOOS == "windows" { + os.Remove(mockServerPath) + mockServerPath += ".exe" + } + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) + } + defer os.Remove(mockServerPath) + + stdio := NewStdio(mockServerPath, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + startErr := stdio.Start(ctx) + if startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) + } + defer stdio.Close() + + testCases := []struct { + name string + dataSize int + description string + }{ + {"SmallMessage_1KB", 1024, "Small message under scanner default limit"}, + {"MediumMessage_32KB", 32 * 1024, "Medium message under scanner default limit"}, + {"AtLimit_64KB", 64 * 1024, "Message at default scanner limit"}, + {"OverLimit_128KB", 128 * 1024, "Message over default scanner limit - would fail with Scanner"}, + {"Large_256KB", 256 * 1024, "Large message well over scanner limit"}, + {"VeryLarge_1MB", 1024 * 1024, "Very large message"}, + {"Huge_5MB", 5 * 1024 * 1024, "Huge message to stress test"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + largeString := generateRandomString(tc.dataSize) + + params := map[string]any{ + "data": largeString, + "size": len(largeString), + } + + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: "debug/echo", + Params: params, + } + + response, err := stdio.SendRequest(ctx, request) + if err != nil { + t.Fatalf("SendRequest failed for %s: %v", tc.description, err) + } + + var result struct { + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + + if err := json.Unmarshal(response.Result, &result); err != nil { + t.Fatalf("Failed to unmarshal result for %s: %v", tc.description, err) + } + + if result.JSONRPC != "2.0" { + t.Errorf("Expected JSONRPC value '2.0', got '%s'", result.JSONRPC) + } + + returnedData, ok := result.Params["data"].(string) + if !ok { + t.Fatalf("Expected data to be string, got %T", result.Params["data"]) + } + + if returnedData != largeString { + t.Errorf("Data mismatch for %s: expected length %d, got length %d", + tc.description, len(largeString), len(returnedData)) + } + + returnedSize, ok := result.Params["size"].(float64) + if !ok { + t.Fatalf("Expected size to be number, got %T", result.Params["size"]) + } + + if int(returnedSize) != tc.dataSize { + t.Errorf("Size mismatch for %s: expected %d, got %d", + tc.description, tc.dataSize, int(returnedSize)) + } + + t.Logf("Successfully handled %s message of size %d bytes", tc.name, tc.dataSize) + }) + } +} + +func generateRandomString(size int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 " + + b := make([]byte, size) + for i := range b { + b[i] = charset[i%len(charset)] + } + return string(b) +} diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index b7bebc24..6a83f184 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "os" + "strings" "github.com/mark3labs/mcp-go/mcp" ) @@ -18,19 +19,26 @@ type JSONRPCRequest struct { } type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID *mcp.RequestId `json:"id,omitempty"` - Result any `json:"result,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Result any `json:"result,omitempty"` Error *mcp.JSONRPCErrorDetails `json:"error,omitempty"` } func main() { logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{})) logger.Info("launch successful") - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { + reader := bufio.NewReader(os.Stdin) + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + + line = strings.TrimRight(line, "\r\n") + var request JSONRPCRequest - if err := json.Unmarshal(scanner.Bytes(), &request); err != nil { + if err := json.Unmarshal([]byte(line), &request); err != nil { continue }