diff --git a/client/client.go b/client/client.go index 8c7a9adc..220786b6 100644 --- a/client/client.go +++ b/client/client.go @@ -3,7 +3,6 @@ package client import ( "context" "encoding/json" - "errors" "fmt" "slices" "sync" @@ -166,7 +165,7 @@ func (c *Client) sendRequest( } if response.Error != nil { - return nil, errors.New(response.Error.Message) + return nil, response.Error.AsError() } return &response.Result, nil @@ -524,11 +523,7 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra } // Create the transport response - response := &transport.JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: request.ID, - Result: json.RawMessage(resultBytes), - } + response := transport.NewJSONRPCResultResponse(request.ID, json.RawMessage(resultBytes)) return response, nil } @@ -572,22 +567,14 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request } // Create the transport response - response := &transport.JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: request.ID, - Result: json.RawMessage(resultBytes), - } + response := transport.NewJSONRPCResultResponse(request.ID, resultBytes) return response, nil } func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { b, _ := json.Marshal(&mcp.EmptyResult{}) - return &transport.JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: request.ID, - Result: b, - }, nil + return transport.NewJSONRPCResultResponse(request.ID, b), nil } func listByPage[T any]( diff --git a/client/elicitation_test.go b/client/elicitation_test.go index d21b1f07..858c6bcd 100644 --- a/client/elicitation_test.go +++ b/client/elicitation_test.go @@ -155,10 +155,7 @@ func TestClient_Initialize_WithElicitationHandler(t *testing.T) { } resultBytes, _ := json.Marshal(result) - return &transport.JSONRPCResponse{ - ID: request.ID, - Result: json.RawMessage(resultBytes), - }, nil + return transport.NewJSONRPCResultResponse(request.ID, resultBytes), nil }, sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error { return nil @@ -183,7 +180,6 @@ func TestClient_Initialize_WithElicitationHandler(t *testing.T) { Capabilities: mcp.ClientCapabilities{}, }, }) - if err != nil { t.Fatalf("failed to initialize: %v", err) } diff --git a/client/protocol_negotiation_test.go b/client/protocol_negotiation_test.go index 022b7fc6..15629bc5 100644 --- a/client/protocol_negotiation_test.go +++ b/client/protocol_negotiation_test.go @@ -30,11 +30,7 @@ func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transpo return nil, fmt.Errorf("no mock response for method %s", request.Method) } - return &transport.JSONRPCResponse{ - JSONRPC: "2.0", - ID: request.ID, - Result: json.RawMessage(responseStr), - }, nil + return transport.NewJSONRPCResultResponse(request.ID, json.RawMessage(responseStr)), nil } func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { diff --git a/client/sampling_test.go b/client/sampling_test.go index 60f53322..d423cfef 100644 --- a/client/sampling_test.go +++ b/client/sampling_test.go @@ -175,11 +175,10 @@ func TestClient_Initialize_WithSampling(t *testing.T) { } // Prepare mock response for initialization - initResponse := &transport.JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: mcp.NewRequestId(1), - Result: []byte(`{"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{},"resources":{},"tools":{}},"serverInfo":{"name":"test-server","version":"1.0.0"}}`), - } + initResponse := transport.NewJSONRPCResultResponse( + mcp.NewRequestId(1), + []byte(`{"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{},"resources":{},"tools":{}},"serverInfo":{"name":"test-server","version":"1.0.0"}}`), + ) // Send the response in a goroutine go func() { diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 3757664a..f69fa7b5 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -82,7 +82,7 @@ func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCReq if err != nil { return nil, fmt.Errorf("failed to marshal response message: %w", err) } - rpcResp := JSONRPCResponse{} + var rpcResp JSONRPCResponse err = json.Unmarshal(respByte, &rpcResp) if err != nil { return nil, fmt.Errorf("failed to unmarshal response message: %w", err) diff --git a/client/transport/interface.go b/client/transport/interface.go index e6feeb74..a877e49d 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -61,13 +61,13 @@ type JSONRPCRequest struct { Params any `json:"params,omitempty"` } +// JSONRPCResponse represents a JSON-RPC 2.0 response message. +// Use NewJSONRPCResultResponse to create a JSONRPCResponse with a result. +// Use NewJSONRPCErrorResponse to create a JSONRPCResponse with an error. type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID mcp.RequestId `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data"` - } `json:"error,omitempty"` + JSONRPC string `json:"jsonrpc"` + ID mcp.RequestId `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *mcp.JSONRPCErrorDetails `json:"error,omitempty"` } + diff --git a/client/transport/stdio.go b/client/transport/stdio.go index ddb11c0d..69557668 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -402,18 +402,12 @@ func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { if handler == nil { // Send error response if no handler is configured - errorResponse := JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: request.ID, - Error: &struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data"` - }{ - Code: mcp.METHOD_NOT_FOUND, - Message: "No request handler configured", - }, - } + errorResponse := *NewJSONRPCErrorResponse( + request.ID, + mcp.METHOD_NOT_FOUND, + "No request handler configured", + nil, + ) c.sendResponse(errorResponse) return } @@ -427,18 +421,7 @@ func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { // Check if context is already cancelled before processing select { case <-ctx.Done(): - errorResponse := JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: request.ID, - Error: &struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data"` - }{ - Code: mcp.INTERNAL_ERROR, - Message: ctx.Err().Error(), - }, - } + errorResponse := *NewJSONRPCErrorResponse(request.ID, mcp.INTERNAL_ERROR, ctx.Err().Error(), nil) c.sendResponse(errorResponse) return default: @@ -446,18 +429,7 @@ func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { response, err := handler(ctx, request) if err != nil { - errorResponse := JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: request.ID, - Error: &struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data"` - }{ - Code: mcp.INTERNAL_ERROR, - Message: err.Error(), - }, - } + errorResponse := *NewJSONRPCErrorResponse(request.ID, mcp.INTERNAL_ERROR, err.Error(), nil) c.sendResponse(errorResponse) return } diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 18aa932e..56867acf 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -536,11 +536,7 @@ func TestStdioErrors(t *testing.T) { t.Cleanup(func() { _ = stdio.Close() }) stdio.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { - return &JSONRPCResponse{ - JSONRPC: "2.0", - ID: request.ID, - Result: json.RawMessage(`"test response"`), - }, nil + return NewJSONRPCResultResponse(request.ID, json.RawMessage(`"test response"`)), nil }) doneChan := make(chan struct{}) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 9894224a..d8bb11a5 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -683,18 +683,12 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON if handler == nil { c.logger.Errorf("received request from server but no handler set: %s", request.Method) // Send method not found error - errorResponse := &JSONRPCResponse{ - JSONRPC: "2.0", - ID: request.ID, - Error: &struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data"` - }{ - Code: -32601, // Method not found - Message: fmt.Sprintf("no handler configured for method: %s", request.Method), - }, - } + errorResponse := NewJSONRPCErrorResponse( + request.ID, + mcp.METHOD_NOT_FOUND, + fmt.Sprintf("no handler configured for method: %s", request.Method), + nil, + ) c.sendResponseToServer(ctx, errorResponse) return } @@ -715,36 +709,25 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // Check for specific sampling-related errors if errors.Is(err, context.Canceled) { - errorCode = -32800 // Request cancelled + errorCode = mcp.REQUEST_INTERRUPTED errorMessage = "request was cancelled" } else if errors.Is(err, context.DeadlineExceeded) { - errorCode = -32800 // Request timeout + errorCode = mcp.REQUEST_INTERRUPTED errorMessage = "request timed out" } else { // Generic error cases switch request.Method { case string(mcp.MethodSamplingCreateMessage): - errorCode = -32603 // Internal error + errorCode = mcp.INTERNAL_ERROR errorMessage = fmt.Sprintf("sampling request failed: %v", err) default: - errorCode = -32603 // Internal error + errorCode = mcp.INTERNAL_ERROR errorMessage = err.Error() } } // Send error response - errorResponse := &JSONRPCResponse{ - JSONRPC: "2.0", - ID: request.ID, - Error: &struct { - Code int `json:"code"` - Message string `json:"message"` - Data json.RawMessage `json:"data"` - }{ - Code: errorCode, - Message: errorMessage, - }, - } + errorResponse := NewJSONRPCErrorResponse(request.ID, errorCode, errorMessage, nil) c.sendResponseToServer(requestCtx, errorResponse) return } diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go index 4a38f280..0f30335e 100644 --- a/client/transport/streamable_http_sampling_test.go +++ b/client/transport/streamable_http_sampling_test.go @@ -50,11 +50,7 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { resultBytes, _ := json.Marshal(result) - return &JSONRPCResponse{ - JSONRPC: "2.0", - ID: request.ID, - Result: resultBytes, - }, nil + return NewJSONRPCResultResponse(request.ID, resultBytes), nil }) // Start the client @@ -360,11 +356,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { resultBytes, _ := json.Marshal(result) - return &JSONRPCResponse{ - JSONRPC: "2.0", - ID: request.ID, - Result: resultBytes, - }, nil + return NewJSONRPCResultResponse(request.ID, resultBytes), nil }) // Start the client diff --git a/client/transport/utils.go b/client/transport/utils.go new file mode 100644 index 00000000..d36d7472 --- /dev/null +++ b/client/transport/utils.go @@ -0,0 +1,26 @@ +package transport + +import ( + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// NewJSONRPCErrorResponse creates a new JSONRPCResponse with an error. +func NewJSONRPCErrorResponse(id mcp.RequestId, code int, message string, data any) *JSONRPCResponse { + details := mcp.NewJSONRPCErrorDetails(code, message, data) + return &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Error: &details, + } +} + +// NewJSONRPCResultResponse creates a new JSONRPCResponse with a result. +func NewJSONRPCResultResponse(id mcp.RequestId, result json.RawMessage) *JSONRPCResponse { + return &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Result: result, + } +} \ No newline at end of file diff --git a/client/transport/utils_test.go b/client/transport/utils_test.go new file mode 100644 index 00000000..d24d50c6 --- /dev/null +++ b/client/transport/utils_test.go @@ -0,0 +1,140 @@ +package transport + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestNewJSONRPCErrorResponse(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + id mcp.RequestId + code int + message string + data any + want *JSONRPCResponse + }{ + "basic error response": { + id: mcp.NewRequestId(1), + code: mcp.METHOD_NOT_FOUND, + message: "Method not found", + data: nil, + want: &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(1), + Result: nil, + Error: &mcp.JSONRPCErrorDetails{ + Code: mcp.METHOD_NOT_FOUND, + Message: "Method not found", + Data: nil, + }, + }, + }, + "error response with data": { + id: mcp.NewRequestId("test"), + code: mcp.INVALID_PARAMS, + message: "Invalid parameters", + data: map[string]any{"field": "value"}, + want: &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId("test"), + Result: nil, + Error: &mcp.JSONRPCErrorDetails{ + Code: mcp.INVALID_PARAMS, + Message: "Invalid parameters", + Data: map[string]any{"field": "value"}, + }, + }, + }, + "error response with empty message": { + id: mcp.NewRequestId(42), + code: mcp.INTERNAL_ERROR, + message: "", + data: nil, + want: &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(42), + Result: nil, + Error: &mcp.JSONRPCErrorDetails{ + Code: mcp.INTERNAL_ERROR, + Message: "", + Data: nil, + }, + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := NewJSONRPCErrorResponse(tc.id, tc.code, tc.message, tc.data) + require.Equal(t, tc.want, got) + }) + } +} + +func TestNewJSONRPCResultResponse(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + id mcp.RequestId + result json.RawMessage + want *JSONRPCResponse + }{ + "basic result response": { + id: mcp.NewRequestId(1), + result: json.RawMessage(`{"success": true}`), + want: &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(1), + Result: json.RawMessage(`{"success": true}`), + Error: nil, + }, + }, + "result response with string ID": { + id: mcp.NewRequestId("test-id"), + result: json.RawMessage(`"simple string result"`), + want: &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId("test-id"), + Result: json.RawMessage(`"simple string result"`), + Error: nil, + }, + }, + "result response with empty result": { + id: mcp.NewRequestId(0), + result: json.RawMessage(`{}`), + want: &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(0), + Result: json.RawMessage(`{}`), + Error: nil, + }, + }, + "result response with null result": { + id: mcp.NewRequestId(999), + result: json.RawMessage(`null`), + want: &JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(999), + Result: json.RawMessage(`null`), + Error: nil, + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := NewJSONRPCResultResponse(tc.id, tc.result) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/mcp/errors.go b/mcp/errors.go index 01888bf5..c2902428 100644 --- a/mcp/errors.go +++ b/mcp/errors.go @@ -1,6 +1,33 @@ package mcp -import "fmt" +import ( + "errors" + "fmt" +) + +// Sentinel errors for common JSON-RPC error codes. +var ( + // ErrParseError indicates a JSON parsing error (code: PARSE_ERROR). + ErrParseError = errors.New("parse error") + + // ErrInvalidRequest indicates an invalid JSON-RPC request (code: INVALID_REQUEST). + ErrInvalidRequest = errors.New("invalid request") + + // ErrMethodNotFound indicates the requested method does not exist (code: METHOD_NOT_FOUND). + ErrMethodNotFound = errors.New("method not found") + + // ErrInvalidParams indicates invalid method parameters (code: INVALID_PARAMS). + ErrInvalidParams = errors.New("invalid params") + + // ErrInternalError indicates an internal JSON-RPC error (code: INTERNAL_ERROR). + ErrInternalError = errors.New("internal error") + + // ErrRequestInterrupted indicates a request was cancelled or timed out (code: REQUEST_INTERRUPTED). + ErrRequestInterrupted = errors.New("request interrupted") + + // ErrResourceNotFound indicates a requested resource was not found (code: RESOURCE_NOT_FOUND). + ErrResourceNotFound = errors.New("resource not found") +) // UnsupportedProtocolVersionError is returned when the server responds with // a protocol version that the client doesn't support. @@ -23,3 +50,37 @@ func IsUnsupportedProtocolVersion(err error) bool { _, ok := err.(UnsupportedProtocolVersionError) return ok } + +// AsError maps JSONRPCErrorDetails to a Go error. +// Returns sentinel errors wrapped with custom messages for known codes. +// Defaults to a generic error with the original message when the code is not mapped. +func (e *JSONRPCErrorDetails) AsError() error { + var err error + + switch e.Code { + case PARSE_ERROR: + err = ErrParseError + case INVALID_REQUEST: + err = ErrInvalidRequest + case METHOD_NOT_FOUND: + err = ErrMethodNotFound + case INVALID_PARAMS: + err = ErrInvalidParams + case INTERNAL_ERROR: + err = ErrInternalError + case REQUEST_INTERRUPTED: + err = ErrRequestInterrupted + case RESOURCE_NOT_FOUND: + err = ErrResourceNotFound + default: + return errors.New(e.Message) + } + + // Wrap the sentinel error with the custom message if it differs from the sentinel. + if e.Message != "" && e.Message != err.Error() { + return fmt.Errorf("%w: %s", err, e.Message) + } + + return err +} + diff --git a/mcp/errors_test.go b/mcp/errors_test.go new file mode 100644 index 00000000..22556da1 --- /dev/null +++ b/mcp/errors_test.go @@ -0,0 +1,171 @@ +package mcp + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestJSONRPCErrorDetails_AsError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + details JSONRPCErrorDetails + expectedType error + expectedMessage string + }{ + { + name: "parse error with custom message", + details: JSONRPCErrorDetails{ + Code: PARSE_ERROR, + Message: "Custom parse error message", + }, + expectedType: ErrParseError, + expectedMessage: "parse error: Custom parse error message", + }, + { + name: "parse error with standard message", + details: JSONRPCErrorDetails{ + Code: PARSE_ERROR, + Message: "parse error", + }, + expectedType: ErrParseError, + expectedMessage: "parse error", + }, + { + name: "method not found with custom message", + details: JSONRPCErrorDetails{ + Code: METHOD_NOT_FOUND, + Message: "Custom method not found message", + }, + expectedType: ErrMethodNotFound, + expectedMessage: "method not found: Custom method not found message", + }, + { + name: "method not found with standard message", + details: JSONRPCErrorDetails{ + Code: METHOD_NOT_FOUND, + Message: "method not found", + }, + expectedType: ErrMethodNotFound, + expectedMessage: "method not found", + }, + { + name: "request interrupted with custom message", + details: JSONRPCErrorDetails{ + Code: REQUEST_INTERRUPTED, + Message: "request was cancelled", + }, + expectedType: ErrRequestInterrupted, + expectedMessage: "request interrupted: request was cancelled", + }, + { + name: "resource not found with custom message", + details: JSONRPCErrorDetails{ + Code: RESOURCE_NOT_FOUND, + Message: "resource 'foo' not found", + }, + expectedType: ErrResourceNotFound, + expectedMessage: "resource not found: resource 'foo' not found", + }, + { + name: "unknown error code", + details: JSONRPCErrorDetails{ + Code: -99999, + Message: "Unknown error occurred", + }, + expectedType: nil, // No sentinel error for unknown codes + expectedMessage: "Unknown error occurred", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + result := tc.details.AsError() + + // Check the full error message + require.EqualError(t, result, tc.expectedMessage) + + // Check the error type (if expected) + if tc.expectedType != nil { + require.True(t, errors.Is(result, tc.expectedType), + "Expected error to be of type %v", tc.expectedType) + } + }) + } +} + +func TestJSONRPCErrorDetails_AsError_WithPointer(t *testing.T) { + t.Parallel() + + details := &JSONRPCErrorDetails{ + Code: METHOD_NOT_FOUND, + Message: "Method not found", + Data: map[string]string{"extra": "info"}, + } + + result := details.AsError() + require.True(t, errors.Is(result, ErrMethodNotFound)) +} + + +func TestSentinelErrors(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + err error + msg string + }{ + "ErrParseError": { + err: ErrParseError, + msg: "parse error", + }, + "ErrInvalidRequest": { + err: ErrInvalidRequest, + msg: "invalid request", + }, + "ErrMethodNotFound": { + err: ErrMethodNotFound, + msg: "method not found", + }, + "ErrInvalidParams": { + err: ErrInvalidParams, + msg: "invalid params", + }, + "ErrInternalError": { + err: ErrInternalError, + msg: "internal error", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tc.msg, tc.err.Error()) + }) + } +} + +func TestErrorChaining(t *testing.T) { + t.Parallel() + + // Test that errors.Is works correctly with wrapped errors + details := &JSONRPCErrorDetails{ + Code: METHOD_NOT_FOUND, + Message: "Method 'foo' not found", + } + + err := details.AsError() + wrappedErr := errors.New("failed to call method: " + err.Error()) + + // The wrapped error should not match our sentinel error + require.False(t, errors.Is(wrappedErr, ErrMethodNotFound)) + + // But the original error should + require.True(t, errors.Is(err, ErrMethodNotFound)) +} diff --git a/mcp/types.go b/mcp/types.go index e28e72e9..69ea73ff 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -6,9 +6,8 @@ import ( "encoding/json" "fmt" "maps" - "strconv" - "net/http" + "strconv" "github.com/yosida95/uritemplate/v3" ) @@ -298,7 +297,6 @@ func (r RequestId) MarshalJSON() ([]byte, error) { } func (r *RequestId) UnmarshalJSON(data []byte) error { - if string(data) == "null" { r.value = nil return nil @@ -348,31 +346,48 @@ type JSONRPCResponse struct { // JSONRPCError represents a non-successful (error) response to a request. type JSONRPCError struct { - JSONRPC string `json:"jsonrpc"` - ID RequestId `json:"id"` - Error struct { - // The error type that occurred. - Code int `json:"code"` - // A short description of the error. The message SHOULD be limited - // to a concise single sentence. - Message string `json:"message"` - // Additional information about the error. The value of this member - // is defined by the sender (e.g. detailed error information, nested errors etc.). - Data any `json:"data,omitempty"` - } `json:"error"` + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Error JSONRPCErrorDetails `json:"error"` +} + +// JSONRPCErrorDetails represents a JSON-RPC error for Go error handling. +// This is separate from the JSONRPCError type which represents the full JSON-RPC error response structure. +type JSONRPCErrorDetails struct { + // The error type that occurred. + Code int `json:"code"` + // A short description of the error. The message SHOULD be limited + // to a concise single sentence. + Message string `json:"message"` + // Additional information about the error. The value of this member + // is defined by the sender (e.g. detailed error information, nested errors etc.). + Data any `json:"data,omitempty"` } // Standard JSON-RPC error codes const ( - PARSE_ERROR = -32700 - INVALID_REQUEST = -32600 + // PARSE_ERROR indicates invalid JSON was received by the server. + PARSE_ERROR = -32700 + + // INVALID_REQUEST indicates the JSON sent is not a valid Request object. + INVALID_REQUEST = -32600 + + // METHOD_NOT_FOUND indicates the method does not exist/is not available. METHOD_NOT_FOUND = -32601 - INVALID_PARAMS = -32602 - INTERNAL_ERROR = -32603 + + // INVALID_PARAMS indicates invalid method parameter(s). + INVALID_PARAMS = -32602 + + // INTERNAL_ERROR indicates internal JSON-RPC error. + INTERNAL_ERROR = -32603 + + // REQUEST_INTERRUPTED indicates a request was cancelled or timed out. + REQUEST_INTERRUPTED = -32800 ) // MCP error codes const ( + // RESOURCE_NOT_FOUND indicates a requested resource was not found. RESOURCE_NOT_FOUND = -32002 ) diff --git a/mcp/utils.go b/mcp/utils.go index 19cfa890..feeabe1e 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -8,54 +8,66 @@ import ( ) // ClientRequest types -var _ ClientRequest = (*PingRequest)(nil) -var _ ClientRequest = (*InitializeRequest)(nil) -var _ ClientRequest = (*CompleteRequest)(nil) -var _ ClientRequest = (*SetLevelRequest)(nil) -var _ ClientRequest = (*GetPromptRequest)(nil) -var _ ClientRequest = (*ListPromptsRequest)(nil) -var _ ClientRequest = (*ListResourcesRequest)(nil) -var _ ClientRequest = (*ReadResourceRequest)(nil) -var _ ClientRequest = (*SubscribeRequest)(nil) -var _ ClientRequest = (*UnsubscribeRequest)(nil) -var _ ClientRequest = (*CallToolRequest)(nil) -var _ ClientRequest = (*ListToolsRequest)(nil) +var ( + _ ClientRequest = (*PingRequest)(nil) + _ ClientRequest = (*InitializeRequest)(nil) + _ ClientRequest = (*CompleteRequest)(nil) + _ ClientRequest = (*SetLevelRequest)(nil) + _ ClientRequest = (*GetPromptRequest)(nil) + _ ClientRequest = (*ListPromptsRequest)(nil) + _ ClientRequest = (*ListResourcesRequest)(nil) + _ ClientRequest = (*ReadResourceRequest)(nil) + _ ClientRequest = (*SubscribeRequest)(nil) + _ ClientRequest = (*UnsubscribeRequest)(nil) + _ ClientRequest = (*CallToolRequest)(nil) + _ ClientRequest = (*ListToolsRequest)(nil) +) // ClientNotification types -var _ ClientNotification = (*CancelledNotification)(nil) -var _ ClientNotification = (*ProgressNotification)(nil) -var _ ClientNotification = (*InitializedNotification)(nil) -var _ ClientNotification = (*RootsListChangedNotification)(nil) +var ( + _ ClientNotification = (*CancelledNotification)(nil) + _ ClientNotification = (*ProgressNotification)(nil) + _ ClientNotification = (*InitializedNotification)(nil) + _ ClientNotification = (*RootsListChangedNotification)(nil) +) // ClientResult types -var _ ClientResult = (*EmptyResult)(nil) -var _ ClientResult = (*CreateMessageResult)(nil) -var _ ClientResult = (*ListRootsResult)(nil) +var ( + _ ClientResult = (*EmptyResult)(nil) + _ ClientResult = (*CreateMessageResult)(nil) + _ ClientResult = (*ListRootsResult)(nil) +) // ServerRequest types -var _ ServerRequest = (*PingRequest)(nil) -var _ ServerRequest = (*CreateMessageRequest)(nil) -var _ ServerRequest = (*ListRootsRequest)(nil) +var ( + _ ServerRequest = (*PingRequest)(nil) + _ ServerRequest = (*CreateMessageRequest)(nil) + _ ServerRequest = (*ListRootsRequest)(nil) +) // ServerNotification types -var _ ServerNotification = (*CancelledNotification)(nil) -var _ ServerNotification = (*ProgressNotification)(nil) -var _ ServerNotification = (*LoggingMessageNotification)(nil) -var _ ServerNotification = (*ResourceUpdatedNotification)(nil) -var _ ServerNotification = (*ResourceListChangedNotification)(nil) -var _ ServerNotification = (*ToolListChangedNotification)(nil) -var _ ServerNotification = (*PromptListChangedNotification)(nil) +var ( + _ ServerNotification = (*CancelledNotification)(nil) + _ ServerNotification = (*ProgressNotification)(nil) + _ ServerNotification = (*LoggingMessageNotification)(nil) + _ ServerNotification = (*ResourceUpdatedNotification)(nil) + _ ServerNotification = (*ResourceListChangedNotification)(nil) + _ ServerNotification = (*ToolListChangedNotification)(nil) + _ ServerNotification = (*PromptListChangedNotification)(nil) +) // ServerResult types -var _ ServerResult = (*EmptyResult)(nil) -var _ ServerResult = (*InitializeResult)(nil) -var _ ServerResult = (*CompleteResult)(nil) -var _ ServerResult = (*GetPromptResult)(nil) -var _ ServerResult = (*ListPromptsResult)(nil) -var _ ServerResult = (*ListResourcesResult)(nil) -var _ ServerResult = (*ReadResourceResult)(nil) -var _ ServerResult = (*CallToolResult)(nil) -var _ ServerResult = (*ListToolsResult)(nil) +var ( + _ ServerResult = (*EmptyResult)(nil) + _ ServerResult = (*InitializeResult)(nil) + _ ServerResult = (*CompleteResult)(nil) + _ ServerResult = (*GetPromptResult)(nil) + _ ServerResult = (*ListPromptsResult)(nil) + _ ServerResult = (*ListResourcesResult)(nil) + _ ServerResult = (*ReadResourceResult)(nil) + _ ServerResult = (*CallToolResult)(nil) + _ ServerResult = (*ListToolsResult)(nil) +) // Helper functions for type assertions @@ -100,7 +112,10 @@ func AsBlobResourceContents(content any) (*BlobResourceContents, bool) { // Helper function for JSON-RPC -// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result +// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result. +// NOTE: This function expects a Result struct, but JSONRPCResponse.Result is typed as `any`. +// The Result struct wraps the actual result data with optional metadata. +// For direct result assignment, use NewJSONRPCResultResponse instead. func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse { return JSONRPCResponse{ JSONRPC: JSONRPC_VERSION, @@ -109,6 +124,25 @@ func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse { } } +// NewJSONRPCResultResponse creates a new JSONRPCResponse with the given id and result. +// This function accepts any type for the result, matching the JSONRPCResponse.Result field type. +func NewJSONRPCResultResponse(id RequestId, result any) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +// NewJSONRPCErrorDetails creates a new JSONRPCErrorDetails with the given code, message, and data. +func NewJSONRPCErrorDetails(code int, message string, data any) JSONRPCErrorDetails { + return JSONRPCErrorDetails{ + Code: code, + Message: message, + Data: data, + } +} + // NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message func NewJSONRPCError( id RequestId, @@ -119,15 +153,7 @@ func NewJSONRPCError( return JSONRPCError{ JSONRPC: JSONRPC_VERSION, ID: id, - Error: struct { - Code int `json:"code"` - Message string `json:"message"` - Data any `json:"data,omitempty"` - }{ - Code: code, - Message: message, - Data: data, - }, + Error: NewJSONRPCErrorDetails(code, message, data), } } diff --git a/mcp/utils_test.go b/mcp/utils_test.go new file mode 100644 index 00000000..e171a4c2 --- /dev/null +++ b/mcp/utils_test.go @@ -0,0 +1,80 @@ +package mcp + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewJSONRPCResultResponse(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + id RequestId + result any + want JSONRPCResponse + }{ + "string result": { + id: NewRequestId(1), + result: "test result", + want: JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: NewRequestId(1), + Result: "test result", + }, + }, + "map result": { + id: NewRequestId("test-id"), + result: map[string]any{"key": "value"}, + want: JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: NewRequestId("test-id"), + Result: map[string]any{"key": "value"}, + }, + }, + "nil result": { + id: NewRequestId(42), + result: nil, + want: JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: NewRequestId(42), + Result: nil, + }, + }, + "struct result": { + id: NewRequestId(0), + result: struct{ Name string }{Name: "test"}, + want: JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: NewRequestId(0), + Result: struct{ Name string }{Name: "test"}, + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := NewJSONRPCResultResponse(tc.id, tc.result) + require.Equal(t, tc.want, got) + }) + } +} + +func TestNewJSONRPCResponse(t *testing.T) { + t.Parallel() + + // Test the existing constructor that takes Result struct + id := NewRequestId(1) + result := Result{Meta: &Meta{}} + + got := NewJSONRPCResponse(id, result) + want := JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Result: result, + } + + require.Equal(t, want, got) +} \ No newline at end of file diff --git a/server/server.go b/server/server.go index 95e95fe0..e9b48e6d 100644 --- a/server/server.go +++ b/server/server.go @@ -124,14 +124,7 @@ func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, ID: mcp.NewRequestId(e.id), - Error: struct { - Code int `json:"code"` - Message string `json:"message"` - Data any `json:"data,omitempty"` - }{ - Code: e.code, - Message: e.err.Error(), - }, + Error: mcp.NewJSONRPCErrorDetails(e.code, e.err.Error(), nil), } } @@ -1211,11 +1204,7 @@ func (s *MCPServer) handleNotification( } func createResponse(id any, result any) mcp.JSONRPCMessage { - return mcp.JSONRPCResponse{ - JSONRPC: mcp.JSONRPC_VERSION, - ID: mcp.NewRequestId(id), - Result: result, - } + return mcp.NewJSONRPCResultResponse(mcp.NewRequestId(id), result) } func createErrorResponse( @@ -1226,13 +1215,6 @@ func createErrorResponse( return mcp.JSONRPCError{ JSONRPC: mcp.JSONRPC_VERSION, ID: mcp.NewRequestId(id), - Error: struct { - Code int `json:"code"` - Message string `json:"message"` - Data any `json:"data,omitempty"` - }{ - Code: code, - Message: message, - }, + Error: mcp.NewJSONRPCErrorDetails(code, message, nil), } } diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index 30bf0c00..b7bebc24 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -21,10 +21,7 @@ type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID *mcp.RequestId `json:"id,omitempty"` Result any `json:"result,omitempty"` - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error,omitempty"` + Error *mcp.JSONRPCErrorDetails `json:"error,omitempty"` } func main() { @@ -157,21 +154,11 @@ func handleRequest(request JSONRPCRequest) JSONRPCResponse { case "debug/echo_error_string": all, _ := json.Marshal(request) - response.Error = &struct { - Code int `json:"code"` - Message string `json:"message"` - }{ - Code: -32601, - Message: string(all), - } + details := mcp.NewJSONRPCErrorDetails(mcp.METHOD_NOT_FOUND, string(all), nil) + response.Error = &details default: - response.Error = &struct { - Code int `json:"code"` - Message string `json:"message"` - }{ - Code: -32601, - Message: "Method not found", - } + details := mcp.NewJSONRPCErrorDetails(mcp.METHOD_NOT_FOUND, "Method not found", nil) + response.Error = &details } return response