Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 134 additions & 40 deletions pkg/transport/stdio.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package transport

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -31,12 +32,27 @@ func NewStdio(command []string) *Stdio {
// Execute implements the Transport interface by spawning a subprocess
// and communicating with it via JSON-RPC over stdin/stdout.
func (t *Stdio) Execute(method string, params any) (map[string]any, error) {
if len(t.command) == 0 {
return nil, fmt.Errorf("no command specified for stdio transport")
stdin, stdout, cmd, stderrBuf, err := t.setupCommand()
if err != nil {
return nil, err
}

if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Executing command: %v\n", t.command)
fmt.Fprintf(os.Stderr, "DEBUG: Starting initialization\n")
}

if initErr := t.initialize(stdin, stdout); initErr != nil {
if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Initialization failed: %v\n", initErr)
if stderrBuf.Len() > 0 {
fmt.Fprintf(os.Stderr, "DEBUG: stderr during init: %s\n", stderrBuf.String())
}
}
return nil, initErr
}

if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Initialization successful, sending method request\n")
}

request := Request{
Expand All @@ -47,77 +63,155 @@ func (t *Stdio) Execute(method string, params any) (map[string]any, error) {
}
t.nextID++

requestJSON, err := json.Marshal(request)
if sendErr := t.sendRequest(stdin, request); sendErr != nil {
return nil, sendErr
}
_ = stdin.Close()

response, err := t.readResponse(stdout)
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
return nil, err
}

requestJSON = append(requestJSON, '\n')
waitErr := cmd.Wait()

if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Sending request: %s\n", string(requestJSON))
fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr)
if stderrBuf.Len() > 0 {
fmt.Fprintf(os.Stderr, "DEBUG: stderr output:\n%s\n", stderrBuf.String())
}
}

cmd := exec.Command(t.command[0], t.command[1:]...) // #nosec G204
if waitErr != nil && stderrBuf.Len() > 0 {
return nil, fmt.Errorf("command error: %w, stderr: %s", waitErr, stderrBuf.String())
}

return response.Result, nil
}

stdin, stdinErr := cmd.StdinPipe()
if stdinErr != nil {
return nil, fmt.Errorf("error getting stdin pipe: %w", stdinErr)
// setupCommand prepares and starts the command, returning the stdin/stdout pipes and any error.
func (t *Stdio) setupCommand() (stdin io.WriteCloser, stdout io.ReadCloser, cmd *exec.Cmd, stderrBuf *bytes.Buffer, err error) {
if len(t.command) == 0 {
return nil, nil, nil, nil, fmt.Errorf("no command specified for stdio transport")
}

stdout, stdoutErr := cmd.StdoutPipe()
if stdoutErr != nil {
return nil, fmt.Errorf("error getting stdout pipe: %w", stdoutErr)
if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Executing command: %v\n", t.command)
}

var stderrBuf bytes.Buffer
cmd.Stderr = &stderrBuf
cmd = exec.Command(t.command[0], t.command[1:]...) // #nosec G204

stdin, err = cmd.StdinPipe()
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("error getting stdin pipe: %w", err)
}

if startErr := cmd.Start(); startErr != nil {
return nil, fmt.Errorf("error starting command: %w", startErr)
stdout, err = cmd.StdoutPipe()
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("error getting stdout pipe: %w", err)
}

if _, writeErr := stdin.Write(requestJSON); writeErr != nil {
return nil, fmt.Errorf("error writing to stdin: %w", writeErr)
stderrBuf = &bytes.Buffer{}
cmd.Stderr = stderrBuf

if err = cmd.Start(); err != nil {
return nil, nil, nil, nil, fmt.Errorf("error starting command: %w", err)
}

return stdin, stdout, cmd, stderrBuf, nil
}

// initialize sends the initialization request and waits for response and then sends the initialized
// notification.
func (t *Stdio) initialize(stdin io.WriteCloser, stdout io.ReadCloser) error {
initRequest := Request{
JSONRPC: "2.0",
Method: "initialize",
ID: t.nextID,
Params: map[string]any{
"clientInfo": map[string]any{
"name": "f/mcptools",
"version": "beta",
},
"protocolVersion": protocolVersion,
"capabilities": map[string]any{},
},
}
t.nextID++

if err := t.sendRequest(stdin, initRequest); err != nil {
return fmt.Errorf("init request failed: %w", err)
}
_ = stdin.Close()

_, err := t.readResponse(stdout)
if err != nil {
return fmt.Errorf("init response failed: %w", err)
}

initNotification := Request{
JSONRPC: "2.0",
Method: "notifications/initialized",
}

if sendErr := t.sendRequest(stdin, initNotification); sendErr != nil {
return fmt.Errorf("init notification failed: %w", sendErr)
}

return nil
}

// sendRequest sends a JSON-RPC request and returns the marshaled request.
func (t *Stdio) sendRequest(stdin io.WriteCloser, request Request) error {
requestJSON, err := json.Marshal(request)
if err != nil {
return fmt.Errorf("error marshaling request: %w", err)
}
requestJSON = append(requestJSON, '\n')

if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Wrote request to stdin\n")
fmt.Fprintf(os.Stderr, "DEBUG: Preparing to send request: %s\n", string(requestJSON))
}

var respBytes bytes.Buffer
if _, copyErr := io.Copy(&respBytes, stdout); copyErr != nil {
return nil, fmt.Errorf("error reading from stdout: %w", copyErr)
writer := bufio.NewWriter(stdin)
n, err := writer.Write(requestJSON)
if err != nil {
return fmt.Errorf("error writing bytes to stdin: %w", err)
}

if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Read from stdout: %s\n", respBytes.String())
fmt.Fprintf(os.Stderr, "DEBUG: Wrote %d bytes\n", n)
}

waitErr := cmd.Wait()
if flushErr := writer.Flush(); flushErr != nil {
return fmt.Errorf("error flushing bytes to stdin: %w", flushErr)
}

if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Command completed with err: %v\n", waitErr)
if stderrBuf.Len() > 0 {
fmt.Fprintf(os.Stderr, "DEBUG: stderr output: %s\n", stderrBuf.String())
}
fmt.Fprintf(os.Stderr, "DEBUG: Successfully flushed bytes\n")
}

if waitErr != nil && stderrBuf.Len() > 0 {
return nil, fmt.Errorf("command error: %w, stderr: %s", waitErr, stderrBuf.String())
return nil
}

// readResponse reads and parses a JSON-RPC response.
func (t *Stdio) readResponse(stdout io.ReadCloser) (*Response, error) {
reader := bufio.NewReader(stdout)
line, err := reader.ReadBytes('\n')
if err != nil {
return nil, fmt.Errorf("error reading from stdout: %w", err)
}

if respBytes.Len() == 0 {
if stderrBuf.Len() > 0 {
return nil, fmt.Errorf("no response from command, stderr: %s", stderrBuf.String())
}
if t.debug {
fmt.Fprintf(os.Stderr, "DEBUG: Read from stdout: %s", string(line))
}

if len(line) == 0 {
return nil, fmt.Errorf("no response from command")
}

var response Response
if unmarshalErr := json.Unmarshal(respBytes.Bytes(), &response); unmarshalErr != nil {
return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, respBytes.String())
if unmarshalErr := json.Unmarshal(line, &response); unmarshalErr != nil {
return nil, fmt.Errorf("error unmarshaling response: %w, response: %s", unmarshalErr, string(line))
}

if response.Error != nil {
Expand All @@ -128,5 +222,5 @@ func (t *Stdio) Execute(method string, params any) (map[string]any, error) {
fmt.Fprintf(os.Stderr, "DEBUG: Successfully parsed response\n")
}

return response.Result, nil
return &response, nil
}
6 changes: 5 additions & 1 deletion pkg/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import (
"io"
)

const (
protocolVersion = "2024-11-05"
)

// Transport defines the interface for communicating with MCP servers.
// Implementations should handle the specifics of communication protocols.
type Transport interface {
Expand All @@ -17,7 +21,7 @@ type Request struct {
Params any `json:"params,omitempty"`
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
ID int `json:"id"`
ID int `json:"id,omitempty"`
}

// Response represents a JSON-RPC 2.0 response.
Expand Down