-
Notifications
You must be signed in to change notification settings - Fork 233
mcp: add client-side OAuth flow #544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bffb313
ddb40ce
c83604c
5e8bfda
a60c60b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// Copyright 2025 The Go MCP SDK Authors. All rights reserved. | ||
// Use of this source code is governed by an MIT-style | ||
// license that can be found in the LICENSE file. | ||
|
||
//go:build mcp_go_client_oauth | ||
|
||
package auth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net/http" | ||
"sync" | ||
|
||
"github.com/modelcontextprotocol/go-sdk/internal/oauthex" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization | ||
// is approved, or an error if not. | ||
type OAuthHandler func(context.Context, OAuthHandlerArgs) (oauth2.TokenSource, error) | ||
|
||
// OAuthHandlerArgs are arguments to an [OAuthHandler]. | ||
type OAuthHandlerArgs struct { | ||
// The URL to fetch protected resource metadata, extracted from the WWW-Authenticate header. | ||
// Empty if not present or there was an error obtaining it. | ||
ResourceMetadataURL string | ||
} | ||
|
||
// HTTPTransport is an [http.RoundTripper] that follows the MCP | ||
// OAuth protocol when it encounters a 401 Unauthorized response. | ||
type HTTPTransport struct { | ||
handler OAuthHandler | ||
mu sync.Mutex // protects opts.Base | ||
opts HTTPTransportOptions | ||
} | ||
|
||
// NewHTTPTransport returns a new [*HTTPTransport]. | ||
// The handler is invoked when an HTTP request results in a 401 Unauthorized status. | ||
// It is called only once per transport. Once a TokenSource is obtained, it is used | ||
// for the lifetime of the transport; subsequent 401s are not processed. | ||
func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { | ||
if handler == nil { | ||
return nil, errors.New("handler cannot be nil") | ||
} | ||
t := &HTTPTransport{ | ||
handler: handler, | ||
} | ||
if opts != nil { | ||
t.opts = *opts | ||
} | ||
if t.opts.Base == nil { | ||
t.opts.Base = http.DefaultTransport | ||
} | ||
return t, nil | ||
} | ||
|
||
// HTTPTransportOptions are options to [NewHTTPTransport]. | ||
type HTTPTransportOptions struct { | ||
// Base is the [http.RoundTripper] to use. | ||
// If nil, [http.DefaultTransport] is used. | ||
Base http.RoundTripper | ||
} | ||
|
||
func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||
t.mu.Lock() | ||
base := t.opts.Base | ||
t.mu.Unlock() | ||
|
||
resp, err := base.RoundTrip(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if resp.StatusCode != http.StatusUnauthorized { | ||
return resp, nil | ||
} | ||
if _, ok := base.(*oauth2.Transport); ok { | ||
// We failed to authorize even with a token source; give up. | ||
return resp, nil | ||
} | ||
|
||
resp.Body.Close() | ||
// Try to authorize. | ||
t.mu.Lock() | ||
defer t.mu.Unlock() | ||
// If we don't have a token source, get one by following the OAuth flow. | ||
// (We may have obtained one while t.mu was not held above.) | ||
// TODO: We hold the lock for the entire OAuth flow. This could be a long | ||
// time. Is there a better way? | ||
if _, ok := t.opts.Base.(*oauth2.Transport); !ok { | ||
authHeaders := resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")] | ||
ts, err := t.handler(req.Context(), OAuthHandlerArgs{ | ||
ResourceMetadataURL: extractResourceMetadataURL(authHeaders), | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} | ||
} | ||
return t.opts.Base.RoundTrip(req.Clone(req.Context())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In cagent we use the same trick for OAuth. I just spent the morning debugging and fixing our code for streamable mcp servers and thought I'd tell you what was wrong, might be useful for you. This second request that comes after the oauth flow has a half-broken request, the body of the request was already read by the first call on line 70, so you would get a 400 Bad Request back from the server. You need to hold on the original body and put it back here before sending a second RoundTrip. |
||
} | ||
|
||
func extractResourceMetadataURL(authHeaders []string) string { | ||
cs, err := oauthex.ParseWWWAuthenticate(authHeaders) | ||
if err != nil { | ||
return "" | ||
} | ||
return oauthex.ResourceMetadataURL(cs) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// Copyright 2025 The Go MCP SDK Authors. All rights reserved. | ||
// Use of this source code is governed by an MIT-style | ||
// license that can be found in the LICENSE file. | ||
|
||
//go:build mcp_go_client_oauth | ||
|
||
package auth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"golang.org/x/oauth2" | ||
) | ||
|
||
// TestHTTPTransport validates the OAuth HTTPTransport. | ||
func TestHTTPTransport(t *testing.T) { | ||
const testToken = "test-token-123" | ||
fakeTokenSource := oauth2.StaticTokenSource(&oauth2.Token{ | ||
AccessToken: testToken, | ||
TokenType: "Bearer", | ||
}) | ||
|
||
// authServer simulates a resource that requires OAuth. | ||
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
authHeader := r.Header.Get("Authorization") | ||
if authHeader == fmt.Sprintf("Bearer %s", testToken) { | ||
w.WriteHeader(http.StatusOK) | ||
return | ||
} | ||
|
||
w.Header().Set("WWW-Authenticate", `Bearer resource_metadata="http://metadata.example.com"`) | ||
w.WriteHeader(http.StatusUnauthorized) | ||
})) | ||
defer authServer.Close() | ||
|
||
t.Run("successful auth flow", func(t *testing.T) { | ||
var handlerCalls int | ||
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { | ||
handlerCalls++ | ||
if args.ResourceMetadataURL != "http://metadata.example.com" { | ||
t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") | ||
} | ||
return fakeTokenSource, nil | ||
} | ||
|
||
transport, err := NewHTTPTransport(handler, nil) | ||
if err != nil { | ||
t.Fatalf("NewHTTPTransport() failed: %v", err) | ||
} | ||
client := &http.Client{Transport: transport} | ||
|
||
resp, err := client.Get(authServer.URL) | ||
if err != nil { | ||
t.Fatalf("client.Get() failed: %v", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK) | ||
} | ||
if handlerCalls != 1 { | ||
t.Errorf("handler was called %d times, want 1", handlerCalls) | ||
} | ||
|
||
// Second request should reuse the token and not call the handler again. | ||
resp2, err := client.Get(authServer.URL) | ||
if err != nil { | ||
t.Fatalf("second client.Get() failed: %v", err) | ||
} | ||
defer resp2.Body.Close() | ||
|
||
if resp2.StatusCode != http.StatusOK { | ||
t.Errorf("second request got status %d, want %d", resp2.StatusCode, http.StatusOK) | ||
} | ||
if handlerCalls != 1 { | ||
t.Errorf("handler should still be called only once, but was %d", handlerCalls) | ||
} | ||
}) | ||
|
||
t.Run("handler returns error", func(t *testing.T) { | ||
handlerErr := errors.New("user rejected auth") | ||
handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { | ||
return nil, handlerErr | ||
} | ||
|
||
transport, err := NewHTTPTransport(handler, nil) | ||
if err != nil { | ||
t.Fatalf("NewHTTPTransport() failed: %v", err) | ||
} | ||
client := &http.Client{Transport: transport} | ||
|
||
_, err = client.Get(authServer.URL) | ||
if !errors.Is(err, handlerErr) { | ||
t.Errorf("client.Get() returned error %v, want %v", err, handlerErr) | ||
} | ||
}) | ||
} |
Uh oh!
There was an error while loading. Please reload this page.