From 067f241ac8fb6f940bc49f51a50b150f6e65f58a Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 12:47:49 -0400 Subject: [PATCH 1/2] add HTTP Request to TokenVerifier The request might be needed to verify the token. Fixes #403. --- auth/auth.go | 12 ++++++++---- auth/auth_test.go | 8 +++++--- examples/server/auth-middleware/main.go | 5 ++--- mcp/streamable_test.go | 2 +- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 14ad28c7..34c58278 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -26,7 +26,8 @@ var ErrInvalidToken = errors.New("invalid token") // A TokenVerifier checks the validity of a bearer token, and extracts information // from it. If verification fails, it should return an error that unwraps to ErrInvalidToken. -type TokenVerifier func(ctx context.Context, token string) (*TokenInfo, error) +// The HTTP request is provided in case verifying the token involves checking it. +type TokenVerifier func(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) // RequireBearerTokenOptions are options for [RequireBearerToken]. type RequireBearerTokenOptions struct { @@ -52,6 +53,8 @@ func TokenInfoFromContext(ctx context.Context) *TokenInfo { // If verification succeeds, the [TokenInfo] is added to the request's context and the request proceeds. // If verification fails, the request fails with a 401 Unauthenticated, and the WWW-Authenticate header // is populated to enable [protected resource metadata]. +// + // // [protected resource metadata]: https://datatracker.ietf.org/doc/rfc9728 func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) func(http.Handler) http.Handler { @@ -59,7 +62,7 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenInfo, errmsg, code := verify(r.Context(), verifier, opts, r.Header.Get("Authorization")) + tokenInfo, errmsg, code := verify(r, verifier, opts) if code != 0 { if code == http.StatusUnauthorized || code == http.StatusForbidden { if opts != nil && opts.ResourceMetadataURL != "" { @@ -75,15 +78,16 @@ func RequireBearerToken(verifier TokenVerifier, opts *RequireBearerTokenOptions) } } -func verify(ctx context.Context, verifier TokenVerifier, opts *RequireBearerTokenOptions, authHeader string) (_ *TokenInfo, errmsg string, code int) { +func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenOptions) (_ *TokenInfo, errmsg string, code int) { // Extract bearer token. + authHeader := req.Header.Get("Authorization") fields := strings.Fields(authHeader) if len(fields) != 2 || strings.ToLower(fields[0]) != "bearer" { return nil, "no bearer token", http.StatusUnauthorized } // Verify the token and get information from it. - tokenInfo, err := verifier(ctx, fields[1]) + tokenInfo, err := verifier(req.Context(), fields[1], req) if err != nil { if errors.Is(err, ErrInvalidToken) { return nil, err.Error(), http.StatusUnauthorized diff --git a/auth/auth_test.go b/auth/auth_test.go index 715b9bba..310e3ab0 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -7,13 +7,13 @@ package auth import ( "context" "errors" + "net/http" "testing" "time" ) func TestVerify(t *testing.T) { - ctx := context.Background() - verifier := func(_ context.Context, token string) (*TokenInfo, error) { + verifier := func(_ context.Context, token string, _ *http.Request) (*TokenInfo, error) { switch token { case "valid": return &TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil @@ -61,7 +61,9 @@ func TestVerify(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - _, gotMsg, gotCode := verify(ctx, verifier, tt.opts, tt.header) + _, gotMsg, gotCode := verify(&http.Request{ + Header: http.Header{"Authorization": {tt.header}}, + }, verifier, tt.opts) if gotMsg != tt.wantMsg || gotCode != tt.wantCode { t.Errorf("got (%q, %d), want (%q, %d)", gotMsg, gotCode, tt.wantMsg, tt.wantCode) } diff --git a/examples/server/auth-middleware/main.go b/examples/server/auth-middleware/main.go index f472b760..dd1271eb 100644 --- a/examples/server/auth-middleware/main.go +++ b/examples/server/auth-middleware/main.go @@ -83,7 +83,7 @@ func generateToken(userID string, scopes []string, expiresIn time.Duration) (str // verifyJWT verifies JWT tokens and returns TokenInfo for the auth middleware. // This function implements the TokenVerifier interface required by auth.RequireBearerToken. -func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) { +func verifyJWT(ctx context.Context, tokenString string, _ *http.Request) (*auth.TokenInfo, error) { // Parse and validate the JWT token. token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) { // Verify the signing method is HMAC. @@ -92,7 +92,6 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) } return jwtSecret, nil }) - if err != nil { // Return standard error for invalid tokens. return nil, fmt.Errorf("%w: %v", auth.ErrInvalidToken, err) @@ -111,7 +110,7 @@ func verifyJWT(ctx context.Context, tokenString string) (*auth.TokenInfo, error) // verifyAPIKey verifies API keys and returns TokenInfo for the auth middleware. // This function implements the TokenVerifier interface required by auth.RequireBearerToken. -func verifyAPIKey(ctx context.Context, apiKey string) (*auth.TokenInfo, error) { +func verifyAPIKey(ctx context.Context, apiKey string, _ *http.Request) (*auth.TokenInfo, error) { // Look up the API key in our storage. key, exists := apiKeys[apiKey] if !exists { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 49c2e87f..2c90425d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1270,7 +1270,7 @@ func TestTokenInfo(t *testing.T) { AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - verifier := func(context.Context, string) (*auth.TokenInfo, error) { + verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) { return &auth.TokenInfo{ Scopes: []string{"scope"}, // Expiration is far, far in the future. From ed56ef04aa590afc08fc393ee9131cf2ce102bff Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Wed, 3 Sep 2025 13:06:53 -0400 Subject: [PATCH 2/2] remove comment --- auth/auth.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 214ee789..7cc0074a 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -98,8 +98,6 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO if errors.Is(err, ErrOAuth) { return nil, err.Error(), http.StatusBadRequest } - // Investigate how that works. - // See typescript-sdk/src/server/auth/middleware/bearerAuth.ts. return nil, err.Error(), http.StatusInternalServerError }