From 6b915a7c90dd689b036936abe8f8c26f371fbc8a Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 7 Aug 2025 19:22:43 +0000 Subject: [PATCH] mcp: lock down Params and Result Add unexported methods to the Params and Result interface, so that they're harder to implement outside the mcp package. It looks like these are the only two interfaces we need to lock down: others are either intentionally open (Transport, Connection), or already closed (Session). Fixes #263 --- mcp/protocol.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ mcp/shared.go | 7 +++++++ 2 files changed, 53 insertions(+) diff --git a/mcp/protocol.go b/mcp/protocol.go index 3ca6cb5e..944b73db 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -78,6 +78,8 @@ type CallToolResultFor[Out any] struct { IsError bool `json:"isError,omitempty"` } +func (*CallToolResultFor[Out]) mcpResult() {} + // UnmarshalJSON handles the unmarshalling of content into the Content // interface. func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { @@ -97,6 +99,7 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { return nil } +func (x *CallToolParamsFor[Out]) mcpParams() {} func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) } func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) } @@ -114,6 +117,7 @@ type CancelledParams struct { RequestID any `json:"requestId"` } +func (x *CancelledParams) mcpParams() {} func (x *CancelledParams) GetProgressToken() any { return getProgressToken(x) } func (x *CancelledParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -207,6 +211,8 @@ type CompleteParams struct { Ref *CompleteReference `json:"ref"` } +func (*CompleteParams) mcpParams() {} + type CompletionResultDetails struct { HasMore bool `json:"hasMore,omitempty"` Total int `json:"total,omitempty"` @@ -221,6 +227,8 @@ type CompleteResult struct { Completion CompletionResultDetails `json:"completion"` } +func (*CompleteResult) mcpResult() {} + type CreateMessageParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -245,6 +253,7 @@ type CreateMessageParams struct { Temperature float64 `json:"temperature,omitempty"` } +func (x *CreateMessageParams) mcpParams() {} func (x *CreateMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *CreateMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -264,6 +273,7 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } +func (*CreateMessageResult) mcpResult() {} func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { type result CreateMessageResult // avoid recursion var wire struct { @@ -291,6 +301,7 @@ type GetPromptParams struct { Name string `json:"name"` } +func (x *GetPromptParams) mcpParams() {} func (x *GetPromptParams) GetProgressToken() any { return getProgressToken(x) } func (x *GetPromptParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -304,6 +315,8 @@ type GetPromptResult struct { Messages []*PromptMessage `json:"messages"` } +func (*GetPromptResult) mcpResult() {} + type InitializeParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -315,6 +328,7 @@ type InitializeParams struct { ProtocolVersion string `json:"protocolVersion"` } +func (x *InitializeParams) mcpParams() {} func (x *InitializeParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializeParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -338,12 +352,15 @@ type InitializeResult struct { ServerInfo *Implementation `json:"serverInfo"` } +func (*InitializeResult) mcpResult() {} + type InitializedParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` } +func (x *InitializedParams) mcpParams() {} func (x *InitializedParams) GetProgressToken() any { return getProgressToken(x) } func (x *InitializedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -356,6 +373,7 @@ type ListPromptsParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListPromptsParams) mcpParams() {} func (x *ListPromptsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListPromptsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListPromptsParams) cursorPtr() *string { return &x.Cursor } @@ -371,6 +389,7 @@ type ListPromptsResult struct { Prompts []*Prompt `json:"prompts"` } +func (x *ListPromptsResult) mcpResult() {} func (x *ListPromptsResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourceTemplatesParams struct { @@ -382,6 +401,7 @@ type ListResourceTemplatesParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListResourceTemplatesParams) mcpParams() {} func (x *ListResourceTemplatesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourceTemplatesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourceTemplatesParams) cursorPtr() *string { return &x.Cursor } @@ -397,6 +417,7 @@ type ListResourceTemplatesResult struct { ResourceTemplates []*ResourceTemplate `json:"resourceTemplates"` } +func (x *ListResourceTemplatesResult) mcpResult() {} func (x *ListResourceTemplatesResult) nextCursorPtr() *string { return &x.NextCursor } type ListResourcesParams struct { @@ -408,6 +429,7 @@ type ListResourcesParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListResourcesParams) mcpParams() {} func (x *ListResourcesParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListResourcesParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListResourcesParams) cursorPtr() *string { return &x.Cursor } @@ -423,6 +445,7 @@ type ListResourcesResult struct { Resources []*Resource `json:"resources"` } +func (x *ListResourcesResult) mcpResult() {} func (x *ListResourcesResult) nextCursorPtr() *string { return &x.NextCursor } type ListRootsParams struct { @@ -431,6 +454,7 @@ type ListRootsParams struct { Meta `json:"_meta,omitempty"` } +func (x *ListRootsParams) mcpParams() {} func (x *ListRootsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListRootsParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -444,6 +468,8 @@ type ListRootsResult struct { Roots []*Root `json:"roots"` } +func (*ListRootsResult) mcpResult() {} + type ListToolsParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. @@ -453,6 +479,7 @@ type ListToolsParams struct { Cursor string `json:"cursor,omitempty"` } +func (x *ListToolsParams) mcpParams() {} func (x *ListToolsParams) GetProgressToken() any { return getProgressToken(x) } func (x *ListToolsParams) SetProgressToken(t any) { setProgressToken(x, t) } func (x *ListToolsParams) cursorPtr() *string { return &x.Cursor } @@ -468,6 +495,7 @@ type ListToolsResult struct { Tools []*Tool `json:"tools"` } +func (x *ListToolsResult) mcpResult() {} func (x *ListToolsResult) nextCursorPtr() *string { return &x.NextCursor } // The severity of a log message. @@ -489,6 +517,7 @@ type LoggingMessageParams struct { Logger string `json:"logger,omitempty"` } +func (x *LoggingMessageParams) mcpParams() {} func (x *LoggingMessageParams) GetProgressToken() any { return getProgressToken(x) } func (x *LoggingMessageParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -550,6 +579,7 @@ type PingParams struct { Meta `json:"_meta,omitempty"` } +func (x *PingParams) mcpParams() {} func (x *PingParams) GetProgressToken() any { return getProgressToken(x) } func (x *PingParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -569,6 +599,8 @@ type ProgressNotificationParams struct { Total float64 `json:"total,omitempty"` } +func (*ProgressNotificationParams) mcpParams() {} + // A prompt or prompt template that the server offers. type Prompt struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -606,6 +638,7 @@ type PromptListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *PromptListChangedParams) mcpParams() {} func (x *PromptListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *PromptListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -646,6 +679,7 @@ type ReadResourceParams struct { URI string `json:"uri"` } +func (x *ReadResourceParams) mcpParams() {} func (x *ReadResourceParams) GetProgressToken() any { return getProgressToken(x) } func (x *ReadResourceParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -657,6 +691,8 @@ type ReadResourceResult struct { Contents []*ResourceContents `json:"contents"` } +func (*ReadResourceResult) mcpResult() {} + // A known resource that the server is capable of reading. type Resource struct { // See [specification/2025-06-18/basic/index#general-fields] for notes on _meta @@ -697,6 +733,7 @@ type ResourceListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *ResourceListChangedParams) mcpParams() {} func (x *ResourceListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ResourceListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -754,6 +791,7 @@ type RootsListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *RootsListChangedParams) mcpParams() {} func (x *RootsListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -798,6 +836,7 @@ type SetLevelParams struct { Level LoggingLevel `json:"level"` } +func (x *SetLevelParams) mcpParams() {} func (x *SetLevelParams) GetProgressToken() any { return getProgressToken(x) } func (x *SetLevelParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -873,6 +912,7 @@ type ToolListChangedParams struct { Meta `json:"_meta,omitempty"` } +func (x *ToolListChangedParams) mcpParams() {} func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } @@ -886,6 +926,8 @@ type SubscribeParams struct { URI string `json:"uri"` } +func (*SubscribeParams) mcpParams() {} + // Sent from the client to request cancellation of resources/updated // notifications from the server. This should follow a previous // resources/subscribe request. @@ -897,6 +939,8 @@ type UnsubscribeParams struct { URI string `json:"uri"` } +func (*UnsubscribeParams) mcpParams() {} + // A notification from the server to the client, informing it that a resource // has changed and may need to be read again. This should only be sent if the // client previously sent a resources/subscribe request. @@ -908,6 +952,8 @@ type ResourceUpdatedNotificationParams struct { URI string `json:"uri"` } +func (*ResourceUpdatedNotificationParams) mcpParams() {} + // TODO(jba): add CompleteRequest and related types. // TODO(jba): add ElicitRequest and related types. diff --git a/mcp/shared.go b/mcp/shared.go index 319071f2..8d0ceb1c 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -335,6 +335,9 @@ func setProgressToken(p Params, pt any) { // Params is a parameter (input) type for an MCP call or notification. type Params interface { + // mcpParams discourages implementation of Params outside of this package. + mcpParams() + // GetMeta returns metadata from a value. GetMeta() map[string]any // SetMeta sets the metadata on a value. @@ -356,6 +359,9 @@ type RequestParams interface { // Result is a result of an MCP call. type Result interface { + // mcpResult discourages implementation of Result outside of this package. + mcpResult() + // GetMeta returns metadata from a value. GetMeta() map[string]any // SetMeta sets the metadata on a value. @@ -366,6 +372,7 @@ type Result interface { // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} +func (*emptyResult) mcpResult() {} func (*emptyResult) GetMeta() map[string]any { panic("should never be called") } func (*emptyResult) SetMeta(map[string]any) { panic("should never be called") }