diff --git a/auth/client.go b/auth/client.go new file mode 100644 index 00000000..2dda1351 --- /dev/null +++ b/auth/client.go @@ -0,0 +1,109 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "errors" + "net/http" + "sync" + + "github.com/modelcontextprotocol/go-sdk/internal/oauthex" + "golang.org/x/oauth2" +) + +// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error) + +// OAuthHandlerArgs are arguments to an [OAuthHandler]. +type OAuthHandlerArgs struct { + // The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header. + // Empty if not present or there was an error obtaining it. + ResourceMetadataURL string +} + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. +type HTTPTransport struct { + handler OAuthHandler + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")] + ts, err := t.handler(req.Context(), OAuthHandlerArgs{ + ResourceMetadataURL: extractResourceMetadataURL(authHeaders), + }) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + return t.opts.Base.RoundTrip(req.Clone(req.Context())) +} + +func extractResourceMetadataURL(authHeaders []string) string { + cs, err := oauthex.ParseWWWAuthenticate(authHeaders) + if err != nil { + return "" + } + return oauthex.ResourceMetadataURL(cs) +} diff --git a/auth/client_test.go b/auth/client_test.go new file mode 100644 index 00000000..310fc56e --- /dev/null +++ b/auth/client_test.go @@ -0,0 +1,102 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "golang.org/x/oauth2" +) + +// TestHTTPTransport validates the OAuth HTTPTransport. +func TestHTTPTransport(t *testing.T) { + const testToken = "test-token-123" + fakeTokenSource := oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: testToken, + TokenType: "Bearer", + }) + + // authServer simulates a resource that requires OAuth. + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == fmt.Sprintf("Bearer %s", testToken) { + w.WriteHeader(http.StatusOK) + return + } + + w.Header().Set("WWW-Authenticate", `Bearer resource_metadata="http://metadata.example.com"`) + w.WriteHeader(http.StatusUnauthorized) + })) + defer authServer.Close() + + t.Run("successful auth flow", func(t *testing.T) { + var handlerCalls int + handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + handlerCalls++ + if args.ResourceMetadataURL != "http://metadata.example.com" { + t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") + } + return fakeTokenSource, nil + } + + transport, err := NewHTTPTransport(handler, nil) + if err != nil { + t.Fatalf("NewHTTPTransport() failed: %v", err) + } + client := &http.Client{Transport: transport} + + resp, err := client.Get(authServer.URL) + if err != nil { + t.Fatalf("client.Get() failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK) + } + if handlerCalls != 1 { + t.Errorf("handler was called %d times, want 1", handlerCalls) + } + + // Second request should reuse the token and not call the handler again. + resp2, err := client.Get(authServer.URL) + if err != nil { + t.Fatalf("second client.Get() failed: %v", err) + } + defer resp2.Body.Close() + + if resp2.StatusCode != http.StatusOK { + t.Errorf("second request got status %d, want %d", resp2.StatusCode, http.StatusOK) + } + if handlerCalls != 1 { + t.Errorf("handler should still be called only once, but was %d", handlerCalls) + } + }) + + t.Run("handler returns error", func(t *testing.T) { + handlerErr := errors.New("user rejected auth") + handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + return nil, handlerErr + } + + transport, err := NewHTTPTransport(handler, nil) + if err != nil { + t.Fatalf("NewHTTPTransport() failed: %v", err) + } + client := &http.Client{Transport: transport} + + _, err = client.Get(authServer.URL) + if !errors.Is(err, handlerErr) { + t.Errorf("client.Get() returned error %v, want %v", err, handlerErr) + } + }) +} diff --git a/go.mod b/go.mod index b78c25e7..d5917a16 100644 --- a/go.mod +++ b/go.mod @@ -7,5 +7,6 @@ require ( github.com/google/go-cmp v0.7.0 github.com/google/jsonschema-go v0.3.0 github.com/yosida95/uritemplate/v3 v3.0.2 + golang.org/x/oauth2 v0.30.0 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 2006a674..32ceedfe 100644 --- a/go.sum +++ b/go.sum @@ -2,13 +2,11 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= -github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= -github.com/google/jsonschema-go v0.2.4-0.20250922144851-e08864c65371 h1:e1VCqWtKpTYBOBhPcgGV5whTlMFpTbH5Ghm56wpxBsk= -github.com/google/jsonschema-go v0.2.4-0.20250922144851-e08864c65371/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= diff --git a/internal/oauthex/resource_meta.go b/internal/oauthex/resource_meta.go index 71d52cde..2387b0b5 100644 --- a/internal/oauthex/resource_meta.go +++ b/internal/oauthex/resource_meta.go @@ -146,11 +146,11 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, header http.Hea if len(headers) == 0 { return nil, nil } - cs, err := parseWWWAuthenticate(headers) + cs, err := ParseWWWAuthenticate(headers) if err != nil { return nil, err } - url := resourceMetadataURL(cs) + url := ResourceMetadataURL(cs) if url == "" { return nil, nil } @@ -187,9 +187,9 @@ type challenge struct { Params map[string]string } -// resourceMetadataURL returns a resource metadata URL from the given challenges, +// ResourceMetadataURL returns a resource metadata URL from the given challenges, // or the empty string if there is none. -func resourceMetadataURL(cs []challenge) string { +func ResourceMetadataURL(cs []challenge) string { for _, c := range cs { if u := c.Params["resource_metadata"]; u != "" { return u @@ -198,11 +198,11 @@ func resourceMetadataURL(cs []challenge) string { return "" } -// parseWWWAuthenticate parses a WWW-Authenticate header string. +// ParseWWWAuthenticate parses a WWW-Authenticate header string. // The header format is defined in RFC 9110, Section 11.6.1, and can contain // one or more challenges, separated by commas. // It returns a slice of challenges or an error if one of the headers is malformed. -func parseWWWAuthenticate(headers []string) ([]challenge, error) { +func ParseWWWAuthenticate(headers []string) ([]challenge, error) { // GENERATED BY GEMINI 2.5 (human-tweaked) var challenges []challenge for _, h := range headers {