From df72d5e60aba28939c51eb6e9259c4418009b8ac Mon Sep 17 00:00:00 2001 From: Ruturaj Date: Mon, 27 Oct 2025 11:31:55 +0530 Subject: [PATCH] auth: change OAuthHandler to take http Request and Response Change OAuthHandler signature from func(context.Context, OAuthHandlerArgs) to func(req *http.Request, res *http.Response). - Remove OAuthHandlerArgs struct - Update HTTPTransport to pass req and resp to handler - Update tests to use new signature - Handler can now call oauthex.GetProtectedResourceMetadataFromHeader with proper validation against request URL This change fixes an impedance mismatch between OAuthHandler and the protected resource metadata functions of the oauthex package. The new signature allows handlers to properly validate resource metadata against the request URL, as required by RFC 9728. Fixes #600 --- auth/client.go | 26 ++++---------------------- auth/client_test.go | 17 +++++++++-------- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/auth/client.go b/auth/client.go index 3ed6b695..acadc51b 100644 --- a/auth/client.go +++ b/auth/client.go @@ -8,26 +8,19 @@ package auth import ( "bytes" - "context" "errors" "io" "net/http" "sync" - "github.com/modelcontextprotocol/go-sdk/oauthex" "golang.org/x/oauth2" ) // An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization // is approved, or an error if not. -type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error) - -// OAuthHandlerArgs are arguments to an [OAuthHandler]. -type OAuthHandlerArgs struct { - // The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header. - // Empty if not present or there was an error obtaining it. - ResourceMetadataURL string -} +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) // HTTPTransport is an [http.RoundTripper] that follows the MCP // OAuth protocol when it encounters a 401 Unauthorized response. @@ -112,10 +105,7 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { // TODO: We hold the lock for the entire OAuth flow. This could be a long // time. Is there a better way? if _, ok := t.opts.Base.(*oauth2.Transport); !ok { - authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")] - ts, err := t.handler(req.Context(), OAuthHandlerArgs{ - ResourceMetadataURL: extractResourceMetadataURL(authHeaders), - }) + ts, err := t.handler(req, resp) if err != nil { return nil, err } @@ -131,11 +121,3 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return t.opts.Base.RoundTrip(req) } - -func extractResourceMetadataURL(authHeaders []string) string { - cs, err := oauthex.ParseWWWAuthenticate(authHeaders) - if err != nil { - return "" - } - return oauthex.ResourceMetadataURL(cs) -} diff --git a/auth/client_test.go b/auth/client_test.go index e1b7dc70..4f3172d4 100644 --- a/auth/client_test.go +++ b/auth/client_test.go @@ -7,7 +7,6 @@ package auth import ( - "context" "errors" "fmt" "io" @@ -65,10 +64,11 @@ func TestHTTPTransport(t *testing.T) { t.Run("successful auth flow", func(t *testing.T) { var handlerCalls int - handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) { handlerCalls++ - if args.ResourceMetadataURL != "http://metadata.example.com" { - t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") + // Verify that the response has the expected WWW-Authenticate header + if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` { + t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate")) } return fakeTokenSource, nil } @@ -108,9 +108,10 @@ func TestHTTPTransport(t *testing.T) { }) t.Run("request body is cloned", func(t *testing.T) { - handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { - if args.ResourceMetadataURL != "http://metadata.example.com" { - t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") + handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) { + // Verify that the response has the expected WWW-Authenticate header + if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` { + t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate")) } return fakeTokenSource, nil } @@ -134,7 +135,7 @@ func TestHTTPTransport(t *testing.T) { t.Run("handler returns error", func(t *testing.T) { handlerErr := errors.New("user rejected auth") - handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) { return nil, handlerErr }