Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 4 additions & 17 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package client
import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"sync"
Expand Down Expand Up @@ -166,7 +165,7 @@ func (c *Client) sendRequest(
}

if response.Error != nil {
return nil, errors.New(response.Error.Message)
return nil, response.Error.AsError()
}

return &response.Result, nil
Expand Down Expand Up @@ -524,11 +523,7 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra
}

// Create the transport response
response := &transport.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Result: json.RawMessage(resultBytes),
}
response := transport.NewJSONRPCResultResponse(request.ID, json.RawMessage(resultBytes))

return response, nil
}
Expand Down Expand Up @@ -572,22 +567,14 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request
}

// Create the transport response
response := &transport.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Result: json.RawMessage(resultBytes),
}
response := transport.NewJSONRPCResultResponse(request.ID, resultBytes)

return response, nil
}

func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
b, _ := json.Marshal(&mcp.EmptyResult{})
return &transport.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Result: b,
}, nil
return transport.NewJSONRPCResultResponse(request.ID, b), nil
}

func listByPage[T any](
Expand Down
6 changes: 1 addition & 5 deletions client/elicitation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,7 @@ func TestClient_Initialize_WithElicitationHandler(t *testing.T) {
}

resultBytes, _ := json.Marshal(result)
return &transport.JSONRPCResponse{
ID: request.ID,
Result: json.RawMessage(resultBytes),
}, nil
return transport.NewJSONRPCResultResponse(request.ID, resultBytes), nil
},
sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error {
return nil
Expand All @@ -183,7 +180,6 @@ func TestClient_Initialize_WithElicitationHandler(t *testing.T) {
Capabilities: mcp.ClientCapabilities{},
},
})

if err != nil {
t.Fatalf("failed to initialize: %v", err)
}
Expand Down
6 changes: 1 addition & 5 deletions client/protocol_negotiation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transpo
return nil, fmt.Errorf("no mock response for method %s", request.Method)
}

return &transport.JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Result: json.RawMessage(responseStr),
}, nil
return transport.NewJSONRPCResultResponse(request.ID, json.RawMessage(responseStr)), nil
}

func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
Expand Down
9 changes: 4 additions & 5 deletions client/sampling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,10 @@ func TestClient_Initialize_WithSampling(t *testing.T) {
}

// Prepare mock response for initialization
initResponse := &transport.JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: mcp.NewRequestId(1),
Result: []byte(`{"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{},"resources":{},"tools":{}},"serverInfo":{"name":"test-server","version":"1.0.0"}}`),
}
initResponse := transport.NewJSONRPCResultResponse(
mcp.NewRequestId(1),
[]byte(`{"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{},"resources":{},"tools":{}},"serverInfo":{"name":"test-server","version":"1.0.0"}}`),
)

// Send the response in a goroutine
go func() {
Expand Down
2 changes: 1 addition & 1 deletion client/transport/inprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCReq
if err != nil {
return nil, fmt.Errorf("failed to marshal response message: %w", err)
}
rpcResp := JSONRPCResponse{}
var rpcResp JSONRPCResponse
err = json.Unmarshal(respByte, &rpcResp)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal response message: %w", err)
Expand Down
16 changes: 8 additions & 8 deletions client/transport/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ type JSONRPCRequest struct {
Params any `json:"params,omitempty"`
}

// JSONRPCResponse represents a JSON-RPC 2.0 response message.
// Use NewJSONRPCResultResponse to create a JSONRPCResponse with a result.
// Use NewJSONRPCErrorResponse to create a JSONRPCResponse with an error.
type JSONRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID mcp.RequestId `json:"id"`
Result json.RawMessage `json:"result,omitempty"`
Error *struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
} `json:"error,omitempty"`
JSONRPC string `json:"jsonrpc"`
ID mcp.RequestId `json:"id"`
Result json.RawMessage `json:"result,omitempty"`
Error *mcp.JSONRPCErrorDetails `json:"error,omitempty"`
}

44 changes: 8 additions & 36 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,18 +402,12 @@ func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {

if handler == nil {
// Send error response if no handler is configured
errorResponse := JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}{
Code: mcp.METHOD_NOT_FOUND,
Message: "No request handler configured",
},
}
errorResponse := *NewJSONRPCErrorResponse(
request.ID,
mcp.METHOD_NOT_FOUND,
"No request handler configured",
nil,
)
c.sendResponse(errorResponse)
return
}
Expand All @@ -427,37 +421,15 @@ func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) {
// Check if context is already cancelled before processing
select {
case <-ctx.Done():
errorResponse := JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}{
Code: mcp.INTERNAL_ERROR,
Message: ctx.Err().Error(),
},
}
errorResponse := *NewJSONRPCErrorResponse(request.ID, mcp.INTERNAL_ERROR, ctx.Err().Error(), nil)
c.sendResponse(errorResponse)
return
default:
}

response, err := handler(ctx, request)
if err != nil {
errorResponse := JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: request.ID,
Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}{
Code: mcp.INTERNAL_ERROR,
Message: err.Error(),
},
}
errorResponse := *NewJSONRPCErrorResponse(request.ID, mcp.INTERNAL_ERROR, err.Error(), nil)
c.sendResponse(errorResponse)
return
}
Expand Down
6 changes: 1 addition & 5 deletions client/transport/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,7 @@ func TestStdioErrors(t *testing.T) {
t.Cleanup(func() { _ = stdio.Close() })

stdio.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
return &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Result: json.RawMessage(`"test response"`),
}, nil
return NewJSONRPCResultResponse(request.ID, json.RawMessage(`"test response"`)), nil
})

doneChan := make(chan struct{})
Expand Down
39 changes: 11 additions & 28 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -683,18 +683,12 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON
if handler == nil {
c.logger.Errorf("received request from server but no handler set: %s", request.Method)
// Send method not found error
errorResponse := &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}{
Code: -32601, // Method not found
Message: fmt.Sprintf("no handler configured for method: %s", request.Method),
},
}
errorResponse := NewJSONRPCErrorResponse(
request.ID,
mcp.METHOD_NOT_FOUND,
fmt.Sprintf("no handler configured for method: %s", request.Method),
nil,
)
c.sendResponseToServer(ctx, errorResponse)
return
}
Expand All @@ -715,36 +709,25 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON

// Check for specific sampling-related errors
if errors.Is(err, context.Canceled) {
errorCode = -32800 // Request cancelled
errorCode = mcp.REQUEST_INTERRUPTED
errorMessage = "request was cancelled"
} else if errors.Is(err, context.DeadlineExceeded) {
errorCode = -32800 // Request timeout
errorCode = mcp.REQUEST_INTERRUPTED
errorMessage = "request timed out"
} else {
// Generic error cases
switch request.Method {
case string(mcp.MethodSamplingCreateMessage):
errorCode = -32603 // Internal error
errorCode = mcp.INTERNAL_ERROR
errorMessage = fmt.Sprintf("sampling request failed: %v", err)
default:
errorCode = -32603 // Internal error
errorCode = mcp.INTERNAL_ERROR
errorMessage = err.Error()
}
}

// Send error response
errorResponse := &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}{
Code: errorCode,
Message: errorMessage,
},
}
errorResponse := NewJSONRPCErrorResponse(request.ID, errorCode, errorMessage, nil)
c.sendResponseToServer(requestCtx, errorResponse)
return
}
Expand Down
12 changes: 2 additions & 10 deletions client/transport/streamable_http_sampling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) {

resultBytes, _ := json.Marshal(result)

return &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Result: resultBytes,
}, nil
return NewJSONRPCResultResponse(request.ID, resultBytes), nil
})

// Start the client
Expand Down Expand Up @@ -360,11 +356,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) {

resultBytes, _ := json.Marshal(result)

return &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Result: resultBytes,
}, nil
return NewJSONRPCResultResponse(request.ID, resultBytes), nil
})

// Start the client
Expand Down
26 changes: 26 additions & 0 deletions client/transport/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package transport

import (
"encoding/json"

"github.com/mark3labs/mcp-go/mcp"
)

// NewJSONRPCErrorResponse creates a new JSONRPCResponse with an error.
func NewJSONRPCErrorResponse(id mcp.RequestId, code int, message string, data any) *JSONRPCResponse {
details := mcp.NewJSONRPCErrorDetails(code, message, data)
return &JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: id,
Error: &details,
}
}

// NewJSONRPCResultResponse creates a new JSONRPCResponse with a result.
func NewJSONRPCResultResponse(id mcp.RequestId, result json.RawMessage) *JSONRPCResponse {
return &JSONRPCResponse{
JSONRPC: mcp.JSONRPC_VERSION,
ID: id,
Result: result,
}
}
Loading