Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 4 additions & 22 deletions auth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,19 @@ package auth

import (
"bytes"
"context"
"errors"
"io"
"net/http"
"sync"

"github.com/modelcontextprotocol/go-sdk/oauthex"
"golang.org/x/oauth2"
)

// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization
// is approved, or an error if not.
type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error)

// OAuthHandlerArgs are arguments to an [OAuthHandler].
type OAuthHandlerArgs struct {
// The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header.
// Empty if not present or there was an error obtaining it.
ResourceMetadataURL string
}
// The handler receives the HTTP request and response that triggered the authentication flow.
// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader].
type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error)

// HTTPTransport is an [http.RoundTripper] that follows the MCP
// OAuth protocol when it encounters a 401 Unauthorized response.
Expand Down Expand Up @@ -112,10 +105,7 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: We hold the lock for the entire OAuth flow. This could be a long
// time. Is there a better way?
if _, ok := t.opts.Base.(*oauth2.Transport); !ok {
authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]
ts, err := t.handler(req.Context(), OAuthHandlerArgs{
ResourceMetadataURL: extractResourceMetadataURL(authHeaders),
})
ts, err := t.handler(req, resp)
if err != nil {
return nil, err
}
Expand All @@ -131,11 +121,3 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {

return t.opts.Base.RoundTrip(req)
}

func extractResourceMetadataURL(authHeaders []string) string {
cs, err := oauthex.ParseWWWAuthenticate(authHeaders)
if err != nil {
return ""
}
return oauthex.ResourceMetadataURL(cs)
}
17 changes: 9 additions & 8 deletions auth/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
package auth

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -65,10 +64,11 @@ func TestHTTPTransport(t *testing.T) {

t.Run("successful auth flow", func(t *testing.T) {
var handlerCalls int
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) {
handlerCalls++
if args.ResourceMetadataURL != "http://metadata.example.com" {
t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com")
// Verify that the response has the expected WWW-Authenticate header
if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` {
t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate"))
}
return fakeTokenSource, nil
}
Expand Down Expand Up @@ -108,9 +108,10 @@ func TestHTTPTransport(t *testing.T) {
})

t.Run("request body is cloned", func(t *testing.T) {
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
if args.ResourceMetadataURL != "http://metadata.example.com" {
t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com")
handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) {
// Verify that the response has the expected WWW-Authenticate header
if res.Header.Get("WWW-Authenticate") != `Bearer resource_metadata="http://metadata.example.com"` {
t.Errorf("handler got WWW-Authenticate header %q", res.Header.Get("WWW-Authenticate"))
}
return fakeTokenSource, nil
}
Expand All @@ -134,7 +135,7 @@ func TestHTTPTransport(t *testing.T) {

t.Run("handler returns error", func(t *testing.T) {
handlerErr := errors.New("user rejected auth")
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {
handler := func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) {
return nil, handlerErr
}

Expand Down