From 508b13289b12e8a9e3f55f7ac134d57e1d44e3b1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 12 Aug 2025 13:50:38 -0400 Subject: [PATCH 1/4] mcp: pass TokenInfo to server handler If there is a TokenInfo in the request context of a StreamableServerTransport, then propagate it through to the ServerRequest that is passed to server methods like callTool. --- auth/auth.go | 27 ++++++++++++++++--- internal/jsonrpc2/messages.go | 6 +++++ mcp/shared.go | 25 +++++++++++------- mcp/streamable.go | 12 ++++++++- mcp/streamable_test.go | 49 +++++++++++++++++++++++++++++++++++ mcp/tool.go | 5 ++-- 6 files changed, 107 insertions(+), 17 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 68873b48..14ad28c7 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -13,22 +13,41 @@ import ( "time" ) +// TokenInfo holds information from a bearer token. type TokenInfo struct { Scopes []string Expiration time.Time + // TODO: add standard JWT fields + Extra map[string]any } +// The error that a TokenVerifier should return if the token cannot be verified. +var ErrInvalidToken = errors.New("invalid token") + +// A TokenVerifier checks the validity of a bearer token, and extracts information +// from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error) +// RequireBearerTokenOptions are options for [RequireBearerToken]. type RequireBearerTokenOptions struct { - Scopes []string + // The URL for the resource server metadata OAuth flow, to be returned as part + // of the WWW-Authenticate header. ResourceMetadataURL string + // The required scopes. + Scopes []string } -var ErrInvalidToken = errors.New("invalid token") - type tokenInfoKey struct{} +// TokenInfoFromContext returns the [TokenInfo] stored in ctx, or nil if none. +func TokenInfoFromContext(ctx context.Context) *TokenInfo { + ti := ctx.Value(tokenInfoKey{}) + if ti == nil { + return nil + } + return ti.(*TokenInfo) +} + // RequireBearerToken returns a piece of middleware that verifies a bearer token using the verifier. // If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. // If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header @@ -75,7 +94,7 @@ func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerToke return nil, err.Error(), http.StatusInternalServerError } - // Check scopes. + // Check scopes. All must be present. if opts != nil { // Note: quadratic, but N is small. for _, s := range opts.Scopes { diff --git a/internal/jsonrpc2/messages.go b/internal/jsonrpc2/messages.go index 2de3d4f0..9c0d5d69 100644 --- a/internal/jsonrpc2/messages.go +++ b/internal/jsonrpc2/messages.go @@ -56,6 +56,9 @@ type Request struct { Method string // Params is either a struct or an array with the parameters of the method. Params json.RawMessage + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the application to the underlying transport. + Extra any } // Response is a Message used as a reply to a call Request. @@ -67,6 +70,9 @@ type Response struct { Error error // id of the request this is a response to. ID ID + // Extra is additional information that does not appear on the wire. It can be + // used to pass information from the underlying transport to the application. + Extra any } // StringID creates a new string request identifier. diff --git a/mcp/shared.go b/mcp/shared.go index ca062214..43b3026f 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -132,7 +133,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ } mh := session.receivingMethodHandler().(MethodHandler) - req := info.newRequest(session, params) + ti, _ := jreq.Extra.(*RequestExtra) + req := info.newRequest(session, params, ti) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, jreq.Method, req) if err != nil { @@ -179,7 +181,7 @@ type methodInfo struct { // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) - newRequest func(Session, Params) Request + newRequest func(Session, Params, *auth.TokenInfo) Request // Run the code when a call to the method is received. // Used on the receive side. handleMethod methodHandler @@ -214,7 +216,7 @@ const ( func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params) Request { + mi.newRequest = func(s Session, p Params, _ *auth.TokenInfo) Request { r := &ClientRequest[P]{Session: s.(*ClientSession)} if p != nil { r.Params = p.(P) @@ -229,19 +231,15 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params) Request { - r := &ServerRequest[P]{Session: s.(*ServerSession)} + mi.newRequest = func(s Session, p Params, ti *auth.TokenInfo) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: RequestExtra{TokenInfo: ti}} if p != nil { r.Params = p.(P) } return r } mi.handleMethod = MethodHandler(func(ctx context.Context, _ string, req Request) (Result, error) { - rf := &ServerRequest[P]{Session: req.GetSession().(*ServerSession)} - if req.GetParams() != nil { - rf.Params = req.GetParams().(P) - } - return d(ctx, rf) + return d(ctx, req.(*ServerRequest[P])) }) return mi } @@ -397,6 +395,13 @@ type ClientRequest[P Params] struct { type ServerRequest[P Params] struct { Session *ServerSession Params P + Extra RequestExtra +} + +// RequestExtra is extra information included in requests, typically from +// the transport layer. +type RequestExtra struct { + TokenInfo *auth.TokenInfo // bearer token info (e.g. from OAuth) if any } func (*ClientRequest[P]) isRequest() {} diff --git a/mcp/streamable.go b/mcp/streamable.go index e3d80bc3..5757be55 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -494,12 +495,13 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // This also requires access to the negotiated version, which would either be // set by the MCP-Protocol-Version header, or would require peeking into the // session. - incoming, _, err := readBatch(body) if err != nil { http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) return } + incoming, _, err := readBatch(body) requests := make(map[jsonrpc.ID]struct{}) + tokenInfo := auth.TokenInfoFromContext(req.Context()) for _, msg := range incoming { if req, ok := msg.(*jsonrpc.Request); ok { // Preemptively check that this is a valid request, so that we can fail @@ -509,6 +511,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } + req.Extra = tokenInfo if req.ID.IsValid() { requests[req.ID] = struct{}{} } @@ -1038,6 +1041,10 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error } } +// testAuth controls whether a fake Authorization header is added to outgoing requests. +// TODO: replace with a better mechanism when client-side auth is in place. +var testAuth = false + // Write implements the [Connection] interface. func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { if err := c.failure(); err != nil { @@ -1055,6 +1062,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") + if testAuth { + req.Header.Set("Authorization", "Bearer foo") + } c.setMCPHeaders(req) resp, err := c.client.Do(req) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 11600fbc..d0f32106 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -1038,3 +1039,51 @@ func TestStreamableStateless(t *testing.T) { // Verify we can make another request without session ID checkRequest(`{"jsonrpc":"2.0","method":"tools/call","id":2,"params":{"name":"greet","arguments":{"name":"World"}}}`) } + +func TestTokenInfo(t *testing.T) { + defer func(b bool) { testAuth = b }(testAuth) + testAuth = true + ctx := context.Background() + + // Create a server with a tool that returns TokenInfo. + tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) { + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.TokenInfo)}}}, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) + + streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + verifier := func(context.Context, string) (*auth.TokenInfo, error) { + return &auth.TokenInfo{ + Scopes: []string{"scope"}, + // Expiration is far, far in the future. + Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + } + handler := auth.RequireBearerToken(verifier, nil)(streamHandler) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, nil) + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) + if err != nil { + t.Fatal(err) + } + if len(res.Content) == 0 { + t.Fatal("missing content") + } + tc, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatal("not TextContent") + } + if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} diff --git a/mcp/tool.go b/mcp/tool.go index 15f17e11..052e987f 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -66,8 +66,9 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool } // TODO(jba): improve copy res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{ - Session: req.Session, - Params: params, + Session: req.Session, + Params: params, + TokenInfo: req.TokenInfo, }) // TODO(rfindley): investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. From 58d9bbe8c30cbfbb0b0eb0fcd80dbbe6b1dc6c6d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Mon, 18 Aug 2025 11:50:45 -0400 Subject: [PATCH 2/4] RequestExtra --- mcp/shared.go | 14 +++++++------- mcp/streamable.go | 2 +- mcp/streamable_test.go | 2 +- mcp/tool.go | 6 +++--- mcp/transport.go | 3 +-- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mcp/shared.go b/mcp/shared.go index 43b3026f..e685c2f0 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -133,8 +133,8 @@ func handleReceive[S Session](ctx context.Context, session S, jreq *jsonrpc.Requ } mh := session.receivingMethodHandler().(MethodHandler) - ti, _ := jreq.Extra.(*RequestExtra) - req := info.newRequest(session, params, ti) + re, _ := jreq.Extra.(*RequestExtra) + req := info.newRequest(session, params, re) // mh might be user code, so ensure that it returns the right values for the jsonrpc2 protocol. res, err := mh(ctx, jreq.Method, req) if err != nil { @@ -181,7 +181,7 @@ type methodInfo struct { // Unmarshal params from the wire into a Params struct. // Used on the receive side. unmarshalParams func(json.RawMessage) (Params, error) - newRequest func(Session, Params, *auth.TokenInfo) Request + newRequest func(Session, Params, *RequestExtra) Request // Run the code when a call to the method is received. // Used on the receive side. handleMethod methodHandler @@ -216,7 +216,7 @@ const ( func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params, _ *auth.TokenInfo) Request { + mi.newRequest = func(s Session, p Params, _ *RequestExtra) Request { r := &ClientRequest[P]{Session: s.(*ClientSession)} if p != nil { r.Params = p.(P) @@ -231,8 +231,8 @@ func newClientMethodInfo[P paramsPtr[T], R Result, T any](d typedClientMethodHan func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHandler[P, R], flags methodFlags) methodInfo { mi := newMethodInfo[P, R](flags) - mi.newRequest = func(s Session, p Params, ti *auth.TokenInfo) Request { - r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: RequestExtra{TokenInfo: ti}} + mi.newRequest = func(s Session, p Params, re *RequestExtra) Request { + r := &ServerRequest[P]{Session: s.(*ServerSession), Extra: re} if p != nil { r.Params = p.(P) } @@ -395,7 +395,7 @@ type ClientRequest[P Params] struct { type ServerRequest[P Params] struct { Session *ServerSession Params P - Extra RequestExtra + Extra *RequestExtra } // RequestExtra is extra information included in requests, typically from diff --git a/mcp/streamable.go b/mcp/streamable.go index 5757be55..365d1b84 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -511,7 +511,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, err.Error(), http.StatusBadRequest) return } - req.Extra = tokenInfo + req.Extra = &RequestExtra{TokenInfo: tokenInfo} if req.ID.IsValid() { requests[req.ID] = struct{}{} } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index d0f32106..dff7cc4a 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1047,7 +1047,7 @@ func TestTokenInfo(t *testing.T) { // Create a server with a tool that returns TokenInfo. tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[struct{}]]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.TokenInfo)}}}, nil + return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) diff --git a/mcp/tool.go b/mcp/tool.go index 052e987f..7173b8a8 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -66,9 +66,9 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool } // TODO(jba): improve copy res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{ - Session: req.Session, - Params: params, - TokenInfo: req.TokenInfo, + Session: req.Session, + Params: params, + Extra: req.Extra, }) // TODO(rfindley): investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. diff --git a/mcp/transport.go b/mcp/transport.go index 6d25de33..1d2da5d2 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -86,8 +86,7 @@ type serverConnection interface { // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. -type StdioTransport struct { -} +type StdioTransport struct{} // Connect implements the [Transport] interface. func (*StdioTransport) Connect(context.Context) (Connection, error) { From c43a4b0127ea379f38fcd6fc5c532c4f2453c8d1 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 20 Aug 2025 10:45:04 -0400 Subject: [PATCH 3/4] fix auth test --- mcp/streamable.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 6911a131..2e0a4cba 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1123,10 +1123,6 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error } } -// testAuth controls whether a fake Authorization header is added to outgoing requests. -// TODO: replace with a better mechanism when client-side auth is in place. -var testAuth = false - // Write implements the [Connection] interface. func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) error { if err := c.failure(); err != nil { @@ -1144,9 +1140,6 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") - if testAuth { - req.Header.Set("Authorization", "Bearer foo") - } c.setMCPHeaders(req) resp, err := c.client.Do(req) @@ -1192,6 +1185,10 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } +// testAuth controls whether a fake Authorization header is added to outgoing requests. +// TODO: replace with a better mechanism when client-side auth is in place. +var testAuth = false + func (c *streamableClientConn) setMCPHeaders(req *http.Request) { c.mu.Lock() defer c.mu.Unlock() @@ -1202,6 +1199,9 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) { if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } + if testAuth { + req.Header.Set("Authorization", "Bearer foo") + } } func (c *streamableClientConn) handleJSON(resp *http.Response) { From 7e982de9494c059b3fc1f9f72b758c19c1fb3681 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 20 Aug 2025 10:47:18 -0400 Subject: [PATCH 4/4] fix staticcheck issue --- mcp/streamable.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mcp/streamable.go b/mcp/streamable.go index 2e0a4cba..c51f3cc4 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -585,6 +585,10 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } incoming, _, err := readBatch(body) + if err != nil { + http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest) + return + } requests := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) for _, msg := range incoming {