Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions client/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type StdioMCPClient struct {
cmd *exec.Cmd
stdin io.WriteCloser
stdout *bufio.Reader
stderr io.ReadCloser
requestID atomic.Int64
responses map[int64]chan RPCResponse
mu sync.RWMutex
Expand Down Expand Up @@ -58,9 +59,15 @@ func NewStdioMCPClient(
return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
}

stderr, err := cmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
}

client := &StdioMCPClient{
cmd: cmd,
stdin: stdin,
stderr: stderr,
stdout: bufio.NewReader(stdout),
responses: make(map[int64]chan RPCResponse),
done: make(chan struct{}),
Expand Down Expand Up @@ -88,9 +95,18 @@ func (c *StdioMCPClient) Close() error {
if err := c.stdin.Close(); err != nil {
return fmt.Errorf("failed to close stdin: %w", err)
}
if err := c.stderr.Close(); err != nil {
return fmt.Errorf("failed to close stderr: %w", err)
}
return c.cmd.Wait()
}

// Stderr returns a reader for the stderr output of the subprocess.
// This can be used to capture error messages or logs from the subprocess.
func (c *StdioMCPClient) Stderr() io.Reader {
return c.stderr
}

// OnNotification registers a handler function to be called when notifications are received.
// Multiple handlers can be registered and will be called in the order they were added.
func (c *StdioMCPClient) OnNotification(
Expand Down
42 changes: 41 additions & 1 deletion client/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package client

import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -38,7 +41,23 @@ func TestStdioMCPClient(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer client.Close()
var logRecords []map[string]any
var logRecordsMu sync.RWMutex
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
dec := json.NewDecoder(client.Stderr())
for {
var record map[string]any
if err := dec.Decode(&record); err != nil {
return
}
logRecordsMu.Lock()
logRecords = append(logRecords, record)
logRecordsMu.Unlock()
}
}()

t.Run("Initialize", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
Expand Down Expand Up @@ -238,4 +257,25 @@ func TestStdioMCPClient(t *testing.T) {
)
}
})

client.Close()
wg.Wait()

t.Run("CheckLogs", func(t *testing.T) {
logRecordsMu.RLock()
defer logRecordsMu.RUnlock()

if len(logRecords) != 1 {
t.Errorf("Expected 1 log record, got %d", len(logRecords))
return
}

msg, ok := logRecords[0][slog.MessageKey].(string)
if !ok {
t.Errorf("Expected log record to have message key")
}
if msg != "launch successful" {
t.Errorf("Expected log message 'launch successful', got '%s'", msg)
}
})
}
3 changes: 3 additions & 0 deletions testdata/mockstdio_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
"log/slog"
"os"
)

Expand All @@ -25,6 +26,8 @@ type JSONRPCResponse struct {
}

func main() {
logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{}))
logger.Info("launch successful")
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
var request JSONRPCRequest
Expand Down