diff --git a/auth/client.go b/auth/client.go index 3ed6b695..acadc51b 100644 --- a/auth/client.go +++ b/auth/client.go @@ -8,26 +8,19 @@ package auth import ( "bytes" - "context" "errors" "io" "net/http" "sync" - "github.com/modelcontextprotocol/go-sdk/oauthex" "golang.org/x/oauth2" ) // An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization // is approved, or an error if not. -type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error) - -// OAuthHandlerArgs are arguments to an [OAuthHandler]. -type OAuthHandlerArgs struct { - // The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header. - // Empty if not present or there was an error obtaining it. - ResourceMetadataURL string -} +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) // HTTPTransport is an [http.RoundTripper] that follows the MCP // OAuth protocol when it encounters a 401 Unauthorized response. @@ -112,10 +105,7 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { // TODO: We hold the lock for the entire OAuth flow. This could be a long // time. Is there a better way? if _, ok := t.opts.Base.(*oauth2.Transport); !ok { - authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")] - ts, err := t.handler(req.Context(), OAuthHandlerArgs{ - ResourceMetadataURL: extractResourceMetadataURL(authHeaders), - }) + ts, err := t.handler(req, resp) if err != nil { return nil, err } @@ -131,11 +121,3 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return t.opts.Base.RoundTrip(req) } - -func extractResourceMetadataURL(authHeaders []string) string { - cs, err := oauthex.ParseWWWAuthenticate(authHeaders) - if err != nil { - return "" - } - return oauthex.ResourceMetadataURL(cs) -} diff --git a/auth/client_test.go b/auth/client_test.go index e1b7dc70..4f3172d4 100644 --- a/auth/client_test.go +++ b/auth/client_test.go @@ -7,7 +7,6 @@ package auth import ( - "context" "errors" "fmt" "io" @@ -65,10 +64,11 @@ func TestHTTPTransport(t *testing.T) { t.Run("successful auth flow", func(t *testing.T) { var handlerCalls int - handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) { handlerCalls++ - if args.ResourceMetadataURL != "http://metadata.example.com" { - t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") + // Verify that the response has the expected WWW-Authenticate header + if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` { + t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate")) } return fakeTokenSource, nil } @@ -108,9 +108,10 @@ func TestHTTPTransport(t *testing.T) { }) t.Run("request body is cloned", func(t *testing.T) { - handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { - if args.ResourceMetadataURL != "http://metadata.example.com" { - t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") + handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) { + // Verify that the response has the expected WWW-Authenticate header + if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` { + t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate")) } return fakeTokenSource, nil } @@ -134,7 +135,7 @@ func TestHTTPTransport(t *testing.T) { t.Run("handler returns error", func(t *testing.T) { handlerErr := errors.New("user rejected auth") - handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) { return nil, handlerErr }