From b41b0d0411cc2ca71d2eebb731ce988617c3873f Mon Sep 17 00:00:00 2001 From: Saurabh Davala Date: Wed, 8 Oct 2025 14:05:58 -0700 Subject: [PATCH 1/5] Add robustness improvements for OAuth discovery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make WWW-Authenticate header optional with fallback to well-known endpoints. This enables discovery to work with partially-compliant MCP servers that: - Don't provide WWW-Authenticate headers - Provide unparseable WWW-Authenticate headers - Return non-401 responses that still require OAuth Changes: - Make WWW-Authenticate optional: log warning if missing/unparseable instead of failing - Add nil check before calling FindResourceMetadataURL(challenges) - Continue discovery even if initial response isn't 401 (log warning) - Fallback to /.well-known/oauth-protected-resource when resource_metadata URL not available This aligns with RFC 9728 requirement that servers MUST provide the well-known endpoint, making it a valid fallback when WWW-Authenticate is unavailable. Fixes issue with Neon server and other servers that don't fully implement MCP Authorization Specification Section 4.1. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- discovery.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/discovery.go b/discovery.go index f4332e5..6ea26c2 100644 --- a/discovery.go +++ b/discovery.go @@ -61,21 +61,22 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover // If not 401, OAuth is not required (Authorization is OPTIONAL per MCP spec Section 2.1) if resp.StatusCode != http.StatusUnauthorized { - return &Discovery{ - RequiresOAuth: false, - }, nil + fmt.Printf("warning: status code is not 401: %d\n", resp.StatusCode) } // STEP 2: Parse WWW-Authenticate header (if present) // MCP Spec Section 4.1: "MCP servers MUST use the HTTP header WWW-Authenticate when returning a 401 Unauthorized" wwwAuth := resp.Header.Get("WWW-Authenticate") - if wwwAuth == "" { - return nil, fmt.Errorf("server returned 401 but no WWW-Authenticate header") - } - challenges, err := ParseWWWAuthenticate(wwwAuth) - if err != nil { - return nil, fmt.Errorf("parsing WWW-Authenticate header: %w", err) + var challenges []WWWAuthenticateChallenge + if wwwAuth != "" { + var err error + challenges, err = ParseWWWAuthenticate(wwwAuth) + if err != nil { + // WWW-Authenticate header exists but isn't parseable - log but continue + fmt.Printf("warning: could not parse WWW-Authenticate header: %v\n", err) + challenges = nil + } } // STEP 3: Initialize with intelligent defaults (Inspector pattern) @@ -89,7 +90,11 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover // STEP 4: Try to get resource metadata (OPTIONAL - don't fail if missing) // RFC 9728 Section 5.1: resource_metadata parameter in WWW-Authenticate - resourceMetadataURL := FindResourceMetadataURL(challenges) + resourceMetadataURL := "" + if challenges != nil { + resourceMetadataURL = FindResourceMetadataURL(challenges) + } + if resourceMetadataURL != "" { // Resource metadata URL found - try to fetch it resourceMetadata, resourceMetadataError = fetchOAuthProtectedResourceMetadata(ctx, client, resourceMetadataURL) From 8b4dd4dc8e6b151efe244d1eb3ef4eb75065c8d6 Mon Sep 17 00:00:00 2001 From: Saurabh Davala Date: Wed, 8 Oct 2025 15:45:37 -0700 Subject: [PATCH 2/5] Add structured logging via context pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace fmt.Printf with proper structured logging that integrates with caller's logging infrastructure. Changes: - Add Logger interface with Infof/Warnf/Errorf methods - Support context-based logger injection via WithLogger/LoggerFromContext - Provide WrapLogger helper to adapt any compatible logger - Update DiscoverOAuthRequirements to use logger from context - Fallback to default stderr logger if no logger provided Benefits: - Logs from library now appear with proper component tags ([com.docker.backend.dcr]) - No messy grep patterns needed - logs integrate naturally - Pluggable logging - works with any logger (logrus, zap, slog, etc.) - Backward compatible - uses default logger if none provided This is the Go-idiomatic approach for library logging, following the context pattern used throughout the Go ecosystem. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- discovery.go | 7 +++- log.go | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 log.go diff --git a/discovery.go b/discovery.go index 6ea26c2..e85ca46 100644 --- a/discovery.go +++ b/discovery.go @@ -28,6 +28,9 @@ import ( // 5. Always fetch Authorization Server Metadata (required) // 6. Build discovery result with whatever information is available func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discovery, error) { + // Extract logger from context (or use default) + logger := LoggerFromContext(ctx) + // Create HTTP client with reasonable timeout client := &http.Client{ Timeout: 30 * time.Second, @@ -61,7 +64,7 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover // If not 401, OAuth is not required (Authorization is OPTIONAL per MCP spec Section 2.1) if resp.StatusCode != http.StatusUnauthorized { - fmt.Printf("warning: status code is not 401: %d\n", resp.StatusCode) + logger.Warnf("status code is not 401: %d", resp.StatusCode) } // STEP 2: Parse WWW-Authenticate header (if present) @@ -74,7 +77,7 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover challenges, err = ParseWWWAuthenticate(wwwAuth) if err != nil { // WWW-Authenticate header exists but isn't parseable - log but continue - fmt.Printf("warning: could not parse WWW-Authenticate header: %v\n", err) + logger.Warnf("could not parse WWW-Authenticate header: %v", err) challenges = nil } } diff --git a/log.go b/log.go new file mode 100644 index 0000000..542c9b8 --- /dev/null +++ b/log.go @@ -0,0 +1,107 @@ +package oauth + +import ( + "context" + "fmt" + "log" + "os" +) + +// Logger is a minimal interface for structured logging +// Implementations can provide their own logger (e.g., logrus, zap, slog) +type Logger interface { + Infof(format string, args ...interface{}) + Warnf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + +type contextKey struct{} + +var loggerKey = contextKey{} + +// WithLogger returns a new context with the logger attached +func WithLogger(ctx context.Context, logger Logger) context.Context { + return context.WithValue(ctx, loggerKey, logger) +} + +// LoggerFromContext extracts the logger from context +// Returns a default logger if none is set +func LoggerFromContext(ctx context.Context) Logger { + if logger, ok := ctx.Value(loggerKey).(Logger); ok { + return logger + } + return defaultLogger +} + +// defaultLogger is a simple implementation that logs to stderr with a prefix +var defaultLogger Logger = &stdLogger{ + logger: log.New(os.Stderr, "[oauth-helpers] ", log.LstdFlags), +} + +// stdLogger implements Logger using the standard log package +type stdLogger struct { + logger *log.Logger +} + +func (l *stdLogger) Infof(format string, args ...interface{}) { + l.logger.Printf("INFO: "+format, args...) +} + +func (l *stdLogger) Warnf(format string, args ...interface{}) { + l.logger.Printf("WARN: "+format, args...) +} + +func (l *stdLogger) Errorf(format string, args ...interface{}) { + l.logger.Printf("ERROR: "+format, args...) +} + +// noopLogger is a logger that does nothing (for testing or when logging is disabled) +type noopLogger struct{} + +func (noopLogger) Infof(format string, args ...interface{}) {} +func (noopLogger) Warnf(format string, args ...interface{}) {} +func (noopLogger) Errorf(format string, args ...interface{}) {} + +// NoopLogger returns a logger that does nothing +func NoopLogger() Logger { + return noopLogger{} +} + +// Adapter for Pinata's ComponentLogger to implement our Logger interface +// This allows Pinata to pass its logger directly without wrapping + +// For convenience, add a function to wrap any interface with these methods +func WrapLogger(l interface { + Infof(format string, args ...interface{}) + Warnf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +}) Logger { + return &wrappedLogger{l: l} +} + +type wrappedLogger struct { + l interface { + Infof(format string, args ...interface{}) + Warnf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) + } +} + +func (w *wrappedLogger) Infof(format string, args ...interface{}) { + w.l.Infof(format, args...) +} + +func (w *wrappedLogger) Warnf(format string, args ...interface{}) { + w.l.Warnf(format, args...) +} + +func (w *wrappedLogger) Errorf(format string, args ...interface{}) { + w.l.Errorf(format, args...) +} + +// Helper to create a prefixed logger from fmt package (for quick debugging) +func NewPrefixLogger(prefix string) Logger { + return &stdLogger{ + logger: log.New(os.Stderr, fmt.Sprintf("[%s] ", prefix), log.LstdFlags), + } +} From 2e2cebfe3dd547049233b2ee8747d5a23389efe5 Mon Sep 17 00:00:00 2001 From: Saurabh Davala Date: Thu, 9 Oct 2025 16:40:01 -0700 Subject: [PATCH 3/5] Add ci + tests --- .github/workflows/ci.yml | 35 ++++++++ dcr.go | 1 - dcr_test.go | 95 ++++++++++++++++++++ discovery.go | 47 +++++++--- discovery_test.go | 189 +++++++++++++++++++++++++++++++++++++++ log.go | 101 ++++----------------- testutil.go | 34 +++++++ www_authenticate.go | 17 ---- www_authenticate_test.go | 174 +++++++++++++++++++++++++++++++++++ 9 files changed, 579 insertions(+), 114 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 dcr_test.go create mode 100644 discovery_test.go create mode 100644 testutil.go create mode 100644 www_authenticate_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..35b0d38 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,35 @@ +name: CI + +permissions: + contents: read + +on: + pull_request: + push: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Run tests + run: make test + + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Run linter for Linux + run: make lint-linux diff --git a/dcr.go b/dcr.go index 4932281..6985d48 100644 --- a/dcr.go +++ b/dcr.go @@ -41,7 +41,6 @@ func PerformDCR(ctx context.Context, discovery *Discovery, serverName string) (* // Add requested scopes if provided if len(discovery.Scopes) > 0 { registration.Scope = joinScopes(discovery.Scopes) - } else { } // Marshal the registration request diff --git a/dcr_test.go b/dcr_test.go new file mode 100644 index 0000000..b03f016 --- /dev/null +++ b/dcr_test.go @@ -0,0 +1,95 @@ +package oauth + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +// TestPerformDCR_PublicClient verifies Dynamic Client Registration +// for public clients (no client secret) +func TestPerformDCR_PublicClient(t *testing.T) { + var capturedRequest *DCRRequest + + // Mock registration endpoint + regServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture and verify the request + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &capturedRequest) + + // Return successful registration response + _ = json.NewEncoder(w).Encode(DCRResponse{ + ClientID: "test-client-id-123", + TokenEndpointAuthMethod: "none", + GrantTypes: []string{"authorization_code", "refresh_token"}, + RedirectURIs: []string{"https://mcp.docker.com/oauth/callback"}, + }) + })) + defer regServer.Close() + + // Create discovery with registration endpoint + discovery := &Discovery{ + RegistrationEndpoint: regServer.URL, + AuthorizationEndpoint: "https://auth.example.com/authorize", + TokenEndpoint: "https://auth.example.com/token", + ResourceURL: "https://api.example.com", + Scopes: []string{"read", "write"}, + } + + // Perform DCR + creds, err := PerformDCR(context.Background(), discovery, "test-server") + // Verify no error + if err != nil { + t.Fatalf("DCR failed: %v", err) + } + + // Verify credentials + if creds.ClientID != "test-client-id-123" { + t.Errorf("Expected ClientID=test-client-id-123, got %s", creds.ClientID) + } + if !creds.IsPublic { + t.Error("Expected IsPublic=true for public client") + } + if creds.ServerURL != "https://api.example.com" { + t.Errorf("Expected ServerURL=https://api.example.com, got %s", creds.ServerURL) + } + + // Verify DCR request was correct + if capturedRequest == nil { + t.Fatal("DCR request not captured") + } + if capturedRequest.TokenEndpointAuthMethod != "none" { + t.Errorf("Expected token_endpoint_auth_method=none for public client, got %s", capturedRequest.TokenEndpointAuthMethod) + } + if len(capturedRequest.RedirectURIs) == 0 { + t.Error("Expected redirect_uris to be set") + } + if len(capturedRequest.GrantTypes) == 0 { + t.Error("Expected grant_types to be set") + } +} + +// TestPerformDCR_NoRegistrationEndpoint verifies error handling +// when registration endpoint is not available +func TestPerformDCR_NoRegistrationEndpoint(t *testing.T) { + // Create discovery WITHOUT registration endpoint + discovery := &Discovery{ + AuthorizationEndpoint: "https://auth.example.com/authorize", + TokenEndpoint: "https://auth.example.com/token", + RegistrationEndpoint: "", // Empty - DCR not supported + } + + // Attempt DCR + creds, err := PerformDCR(context.Background(), discovery, "test-server") + + // Verify error occurred + if err == nil { + t.Fatal("Expected error when registration endpoint missing") + } + if creds != nil { + t.Error("Expected nil credentials on error") + } +} diff --git a/discovery.go b/discovery.go index e85ca46..1771c40 100644 --- a/discovery.go +++ b/discovery.go @@ -21,15 +21,20 @@ import ( // - Gracefully handles servers with partial MCP compliance // // ROBUST DISCOVERY FLOW (Inspector-inspired): -// 1. Make request to MCP server to trigger 401 response -// 2. Default authorization server to MCP server domain -// 3. Try to parse WWW-Authenticate header for resource_metadata URL -// 4. If resource metadata available, try to fetch it (optional) -// 5. Always fetch Authorization Server Metadata (required) -// 6. Build discovery result with whatever information is available +// 1. Make initial MCP request (expect 401 if OAuth required) +// 2. Parse WWW-Authenticate header (if present) +// 3. Initialize with intelligent defaults (fallback auth server = MCP domain) +// 4. Fetch resource metadata (from header URL or well-known endpoint fallback) +// 5. Fetch Authorization Server Metadata (REQUIRED) +// 6. Build discovery result with all gathered information +// +// FALLBACK BEHAVIOR: If WWW-Authenticate missing/unparseable, falls back to +// RFC 9728-required /.well-known/oauth-protected-resource endpoint func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discovery, error) { - // Extract logger from context (or use default) - logger := LoggerFromContext(ctx) + // Extract logger from context (or use noop if not provided) + logger := loggerFromContext(ctx) + + logger.Infof("starting OAuth discovery for server: %s", serverURL) // Create HTTP client with reasonable timeout client := &http.Client{ @@ -62,9 +67,12 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover } defer resp.Body.Close() - // If not 401, OAuth is not required (Authorization is OPTIONAL per MCP spec Section 2.1) + logger.Infof("MCP server response: status=%d", resp.StatusCode) + + // If not 401, OAuth might not be required (Authorization is OPTIONAL per MCP spec Section 2.1) + // We log a warning but continue discovery attempt in case server is misconfigured if resp.StatusCode != http.StatusUnauthorized { - logger.Warnf("status code is not 401: %d", resp.StatusCode) + logger.Warnf("expected 401 Unauthorized, got %d - OAuth may not be required", resp.StatusCode) } // STEP 2: Parse WWW-Authenticate header (if present) @@ -73,18 +81,24 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover var challenges []WWWAuthenticateChallenge if wwwAuth != "" { + logger.Infof("WWW-Authenticate header present: %s", wwwAuth) var err error challenges, err = ParseWWWAuthenticate(wwwAuth) if err != nil { // WWW-Authenticate header exists but isn't parseable - log but continue logger.Warnf("could not parse WWW-Authenticate header: %v", err) challenges = nil + } else { + logger.Infof("parsed %d WWW-Authenticate challenge(s)", len(challenges)) } + } else { + logger.Infof("no WWW-Authenticate header present - will try well-known endpoint") } // STEP 3: Initialize with intelligent defaults (Inspector pattern) // Default authorization server to MCP server's domain defaultAuthServerURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host) + logger.Debugf("default authorization server: %s", defaultAuthServerURL) // Initialize discovery with defaults var resourceMetadata *ProtectedResourceMetadata @@ -100,26 +114,36 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover if resourceMetadataURL != "" { // Resource metadata URL found - try to fetch it + logger.Infof("fetching protected resource metadata from: %s", resourceMetadataURL) resourceMetadata, resourceMetadataError = fetchOAuthProtectedResourceMetadata(ctx, client, resourceMetadataURL) if resourceMetadataError == nil && resourceMetadata != nil && resourceMetadata.AuthorizationServer != "" { // Use authorization server from resource metadata if available authServerURL = resourceMetadata.AuthorizationServer + logger.Infof("resource metadata retrieved, auth server: %s", authServerURL) + } else if resourceMetadataError != nil { + logger.Warnf("failed to fetch resource metadata: %v", resourceMetadataError) } } else { // No resource_metadata in WWW-Authenticate - try well-known endpoint wellKnownURL := fmt.Sprintf("%s/.well-known/oauth-protected-resource", defaultAuthServerURL) + logger.Infof("FALLBACK: trying well-known resource metadata endpoint: %s", wellKnownURL) resourceMetadata, resourceMetadataError = fetchOAuthProtectedResourceMetadata(ctx, client, wellKnownURL) if resourceMetadataError == nil && resourceMetadata != nil && resourceMetadata.AuthorizationServer != "" { authServerURL = resourceMetadata.AuthorizationServer + logger.Infof("resource metadata from well-known endpoint, auth server: %s", authServerURL) } } // STEP 5: Fetch Authorization Server Metadata (REQUIRED) // MCP Spec Section 3.1: "Authorization servers MUST provide OAuth 2.0 Authorization Server Metadata (RFC8414)" + logger.Infof("fetching authorization server metadata from: %s", authServerURL) authServerMetadata, err := fetchAuthorizationServerMetadata(ctx, client, authServerURL) if err != nil { + logger.Warnf("failed to fetch authorization server metadata: %v", err) return nil, fmt.Errorf("fetching authorization server metadata from %s: %w", authServerURL, err) } + logger.Infof("auth server metadata retrieved: token_endpoint=%s, registration_endpoint=%s", + authServerMetadata.TokenEndpoint, authServerMetadata.RegistrationEndpoint) // STEP 6: Build discovery result with all available information discovery := &Discovery{ @@ -163,6 +187,9 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover discovery.Scopes = FindRequiredScopes(challenges) } + logger.Infof("discovery complete: auth_server=%s, scopes=%v, pkce=%v", + discovery.AuthorizationServer, discovery.Scopes, discovery.SupportsPKCE) + return discovery, nil } diff --git a/discovery_test.go b/discovery_test.go new file mode 100644 index 0000000..91a81d8 --- /dev/null +++ b/discovery_test.go @@ -0,0 +1,189 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestDiscoveryFallback_NoWWWAuthenticate verifies the critical fallback behavior +// when MCP server doesn't provide WWW-Authenticate header +// +// This tests the fix for servers like Neon that: +// - Return 401 (correct) +// - Don't provide WWW-Authenticate header (MCP spec violation) +// - Do provide /.well-known/oauth-protected-resource endpoint (RFC 9728 compliant) +func TestDiscoveryFallback_NoWWWAuthenticate(t *testing.T) { + // Mock authorization server + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/.well-known/oauth-authorization-server") { + // Use r.Host to construct URLs dynamically + baseURL := "http://" + r.Host + _ = json.NewEncoder(w).Encode(AuthorizationServerMetadata{ + Issuer: baseURL, + AuthorizationEndpoint: baseURL + "/authorize", + TokenEndpoint: baseURL + "/token", + RegistrationEndpoint: baseURL + "/register", + CodeChallengeMethodsSupported: []string{"S256"}, + }) + return + } + })) + defer authServer.Close() + + // Mock MCP server (returns 401 WITHOUT WWW-Authenticate) + mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/mcp" { + // Return 401 WITHOUT WWW-Authenticate header (Neon behavior) + w.WriteHeader(http.StatusUnauthorized) + return + } + if r.URL.Path == "/.well-known/oauth-protected-resource" { + // Provide resource metadata at well-known endpoint + baseURL := "http://" + r.Host + _ = json.NewEncoder(w).Encode(ProtectedResourceMetadata{ + Resource: baseURL, + AuthorizationServer: authServer.URL, + }) + return + } + })) + defer mcpServer.Close() + + // Setup logger to verify fallback triggered + logger := &testLogger{} + ctx := WithLogger(context.Background(), logger) + + // Execute discovery + discovery, err := DiscoverOAuthRequirements(ctx, mcpServer.URL+"/mcp") + // Verify no error + if err != nil { + t.Fatalf("Discovery failed: %v", err) + } + + // Verify fallback was triggered + if !logger.containsInfo("FALLBACK: trying well-known") { + t.Error("Expected fallback to well-known endpoint to be triggered") + } + if !logger.containsInfo("no WWW-Authenticate header present") { + t.Error("Expected warning about missing WWW-Authenticate header") + } + + // Verify discovery succeeded + if !discovery.RequiresOAuth { + t.Error("Expected RequiresOAuth=true") + } + if discovery.TokenEndpoint != authServer.URL+"/token" { + t.Errorf("Expected TokenEndpoint=%s, got %s", authServer.URL+"/token", discovery.TokenEndpoint) + } + if !discovery.SupportsPKCE { + t.Error("Expected SupportsPKCE=true") + } +} + +// TestDiscoveryHappyPath_WithWWWAuthenticate verifies the standard flow +// when server provides proper WWW-Authenticate header +func TestDiscoveryHappyPath_WithWWWAuthenticate(t *testing.T) { + // Mock authorization server + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/.well-known/oauth-authorization-server") { + baseURL := "http://" + r.Host + _ = json.NewEncoder(w).Encode(AuthorizationServerMetadata{ + Issuer: baseURL, + AuthorizationEndpoint: baseURL + "/authorize", + TokenEndpoint: baseURL + "/token", + CodeChallengeMethodsSupported: []string{"S256"}, + }) + return + } + })) + defer authServer.Close() + + // Mock metadata server (separate from MCP server) + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(ProtectedResourceMetadata{ + Resource: "https://api.example.com", + AuthorizationServer: authServer.URL, + Scopes: []string{"read", "write"}, + }) + })) + defer metadataServer.Close() + + // Mock MCP server (returns 401 WITH WWW-Authenticate) + mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/mcp" { + // Return 401 WITH WWW-Authenticate header (standard MCP behavior) + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=\"test\", resource_metadata=\"%s\"", metadataServer.URL)) + w.WriteHeader(http.StatusUnauthorized) + return + } + })) + defer mcpServer.Close() + + // Setup logger + logger := &testLogger{} + ctx := WithLogger(context.Background(), logger) + + // Execute discovery + discovery, err := DiscoverOAuthRequirements(ctx, mcpServer.URL+"/mcp") + // Verify no error + if err != nil { + t.Fatalf("Discovery failed: %v", err) + } + + // Verify WWW-Authenticate was parsed (no fallback) + if logger.containsInfo("FALLBACK") { + t.Error("Should not use fallback when WWW-Authenticate present") + } + if !logger.containsInfo("WWW-Authenticate header present") { + t.Error("Expected WWW-Authenticate header to be detected") + } + + // Verify discovery succeeded + if !discovery.RequiresOAuth { + t.Error("Expected RequiresOAuth=true") + } + if len(discovery.Scopes) != 2 { + t.Errorf("Expected 2 scopes from metadata, got %d", len(discovery.Scopes)) + } +} + +// TestDiscoveryError_AuthServerFails verifies error handling +// when authorization server metadata cannot be fetched +func TestDiscoveryError_AuthServerFails(t *testing.T) { + // Mock MCP server (returns 401, no WWW-Authenticate) + mcpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/mcp" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if r.URL.Path == "/.well-known/oauth-protected-resource" { + // Return resource metadata pointing to non-existent auth server + baseURL := "http://" + r.Host + _ = json.NewEncoder(w).Encode(ProtectedResourceMetadata{ + Resource: baseURL, + AuthorizationServer: "http://localhost:99999", // Invalid/unreachable + }) + return + } + })) + defer mcpServer.Close() + + // Execute discovery (should fail) + discovery, err := DiscoverOAuthRequirements(context.Background(), mcpServer.URL+"/mcp") + + // Verify error occurred + if err == nil { + t.Fatal("Expected error when auth server metadata fetch fails") + } + if discovery != nil { + t.Error("Expected nil discovery on error") + } + if !strings.Contains(err.Error(), "fetching authorization server metadata") { + t.Errorf("Expected auth server error, got: %v", err) + } +} diff --git a/log.go b/log.go index 542c9b8..2235e8c 100644 --- a/log.go +++ b/log.go @@ -1,107 +1,36 @@ package oauth -import ( - "context" - "fmt" - "log" - "os" -) +import "context" -// Logger is a minimal interface for structured logging -// Implementations can provide their own logger (e.g., logrus, zap, slog) +// Logger is an interface for logging during OAuth discovery +// Implementations should log with appropriate formatting and destination type Logger interface { - Infof(format string, args ...interface{}) - Warnf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) + Infof(format string, args ...any) // Informational messages + Warnf(format string, args ...any) // Warnings (non-fatal issues) + Debugf(format string, args ...any) // Debug/verbose details } type contextKey struct{} var loggerKey = contextKey{} -// WithLogger returns a new context with the logger attached +// WithLogger attaches a logger to the context func WithLogger(ctx context.Context, logger Logger) context.Context { return context.WithValue(ctx, loggerKey, logger) } -// LoggerFromContext extracts the logger from context -// Returns a default logger if none is set -func LoggerFromContext(ctx context.Context) Logger { +// loggerFromContext extracts the logger from context +// Returns a noop logger if none is set (for backward compatibility) +func loggerFromContext(ctx context.Context) Logger { if logger, ok := ctx.Value(loggerKey).(Logger); ok { return logger } - return defaultLogger -} - -// defaultLogger is a simple implementation that logs to stderr with a prefix -var defaultLogger Logger = &stdLogger{ - logger: log.New(os.Stderr, "[oauth-helpers] ", log.LstdFlags), -} - -// stdLogger implements Logger using the standard log package -type stdLogger struct { - logger *log.Logger -} - -func (l *stdLogger) Infof(format string, args ...interface{}) { - l.logger.Printf("INFO: "+format, args...) -} - -func (l *stdLogger) Warnf(format string, args ...interface{}) { - l.logger.Printf("WARN: "+format, args...) -} - -func (l *stdLogger) Errorf(format string, args ...interface{}) { - l.logger.Printf("ERROR: "+format, args...) -} - -// noopLogger is a logger that does nothing (for testing or when logging is disabled) -type noopLogger struct{} - -func (noopLogger) Infof(format string, args ...interface{}) {} -func (noopLogger) Warnf(format string, args ...interface{}) {} -func (noopLogger) Errorf(format string, args ...interface{}) {} - -// NoopLogger returns a logger that does nothing -func NoopLogger() Logger { return noopLogger{} } -// Adapter for Pinata's ComponentLogger to implement our Logger interface -// This allows Pinata to pass its logger directly without wrapping - -// For convenience, add a function to wrap any interface with these methods -func WrapLogger(l interface { - Infof(format string, args ...interface{}) - Warnf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) -}) Logger { - return &wrappedLogger{l: l} -} - -type wrappedLogger struct { - l interface { - Infof(format string, args ...interface{}) - Warnf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) - } -} - -func (w *wrappedLogger) Infof(format string, args ...interface{}) { - w.l.Infof(format, args...) -} - -func (w *wrappedLogger) Warnf(format string, args ...interface{}) { - w.l.Warnf(format, args...) -} - -func (w *wrappedLogger) Errorf(format string, args ...interface{}) { - w.l.Errorf(format, args...) -} +// noopLogger does nothing (used when no logger is provided) +type noopLogger struct{} -// Helper to create a prefixed logger from fmt package (for quick debugging) -func NewPrefixLogger(prefix string) Logger { - return &stdLogger{ - logger: log.New(os.Stderr, fmt.Sprintf("[%s] ", prefix), log.LstdFlags), - } -} +func (noopLogger) Infof(_ string, _ ...any) {} +func (noopLogger) Warnf(_ string, _ ...any) {} +func (noopLogger) Debugf(_ string, _ ...any) {} diff --git a/testutil.go b/testutil.go new file mode 100644 index 0000000..c8e049a --- /dev/null +++ b/testutil.go @@ -0,0 +1,34 @@ +package oauth + +import ( + "fmt" + "strings" +) + +// testLogger captures log messages for test verification +type testLogger struct { + infos []string + warns []string + debugs []string +} + +func (l *testLogger) Infof(format string, args ...any) { + l.infos = append(l.infos, fmt.Sprintf(format, args...)) +} + +func (l *testLogger) Warnf(format string, args ...any) { + l.warns = append(l.warns, fmt.Sprintf(format, args...)) +} + +func (l *testLogger) Debugf(format string, args ...any) { + l.debugs = append(l.debugs, fmt.Sprintf(format, args...)) +} + +func (l *testLogger) containsInfo(substr string) bool { + for _, msg := range l.infos { + if strings.Contains(msg, substr) { + return true + } + } + return false +} diff --git a/www_authenticate.go b/www_authenticate.go index a2869e5..707f099 100644 --- a/www_authenticate.go +++ b/www_authenticate.go @@ -191,20 +191,3 @@ func FindRequiredScopes(challenges []WWWAuthenticateChallenge) []string { return scopes } - -// FindRealm extracts the realm parameter from WWW-Authenticate challenges -// -// RFC 7235 COMPLIANCE: -// - Section 2.2: Defines realm parameter format -// - Returns the first realm found across all challenges -func FindRealm(challenges []WWWAuthenticateChallenge) string { - for _, challenge := range challenges { - if challenge.Parameters == nil { - continue - } - if realm, exists := challenge.Parameters["realm"]; exists && realm != "" { - return realm - } - } - return "" -} diff --git a/www_authenticate_test.go b/www_authenticate_test.go new file mode 100644 index 0000000..2153b5e --- /dev/null +++ b/www_authenticate_test.go @@ -0,0 +1,174 @@ +package oauth + +import ( + "testing" +) + +// TestParseWWWAuthenticate_Valid verifies parsing of standard WWW-Authenticate headers +func TestParseWWWAuthenticate_Valid(t *testing.T) { + tests := []struct { + name string + header string + expectSchemes int + expectParams map[string]string + }{ + { + name: "Bearer with resource_metadata", + header: `Bearer realm="example.com", resource_metadata="https://example.com/.well-known/oauth-protected-resource"`, + expectSchemes: 1, + expectParams: map[string]string{ + "realm": "example.com", + "resource_metadata": "https://example.com/.well-known/oauth-protected-resource", + }, + }, + { + name: "Bearer with scope", + header: `Bearer realm="api", scope="read write"`, + expectSchemes: 1, + expectParams: map[string]string{ + "realm": "api", + "scope": "read write", + }, + }, + { + name: "Multiple schemes", + header: `Basic realm="web", Bearer realm="api" scope="read"`, + expectSchemes: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + challenges, err := ParseWWWAuthenticate(tt.header) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + if len(challenges) != tt.expectSchemes { + t.Errorf("Expected %d schemes, got %d", tt.expectSchemes, len(challenges)) + } + + if tt.expectParams != nil && len(challenges) > 0 { + for key, expectedValue := range tt.expectParams { + actualValue, exists := challenges[0].Parameters[key] + if !exists { + t.Errorf("Expected parameter %s not found", key) + } + if actualValue != expectedValue { + t.Errorf("Parameter %s: expected %s, got %s", key, expectedValue, actualValue) + } + } + } + }) + } +} + +// TestParseWWWAuthenticate_Malformed verifies error handling for invalid headers +func TestParseWWWAuthenticate_Malformed(t *testing.T) { + // Empty header should return error + _, err := ParseWWWAuthenticate("") + if err == nil { + t.Error("Expected error for empty header") + } +} + +// TestFindResourceMetadataURL verifies extraction of resource_metadata URL +func TestFindResourceMetadataURL(t *testing.T) { + tests := []struct { + name string + challenges []WWWAuthenticateChallenge + expectURL string + }{ + { + name: "Found in first challenge", + challenges: []WWWAuthenticateChallenge{ + { + Scheme: "Bearer", + Parameters: map[string]string{ + "resource_metadata": "https://example.com/.well-known", + }, + }, + }, + expectURL: "https://example.com/.well-known", + }, + { + name: "No resource_metadata parameter", + challenges: []WWWAuthenticateChallenge{ + { + Scheme: "Bearer", + Parameters: map[string]string{ + "realm": "test", + }, + }, + }, + expectURL: "", + }, + { + name: "Nil challenges", + challenges: nil, + expectURL: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := FindResourceMetadataURL(tt.challenges) + if url != tt.expectURL { + t.Errorf("Expected URL %s, got %s", tt.expectURL, url) + } + }) + } +} + +// TestFindRequiredScopes verifies scope extraction from Bearer challenges +func TestFindRequiredScopes(t *testing.T) { + tests := []struct { + name string + challenges []WWWAuthenticateChallenge + expectScopes []string + }{ + { + name: "Single scope", + challenges: []WWWAuthenticateChallenge{ + { + Scheme: "Bearer", + Parameters: map[string]string{ + "scope": "read", + }, + }, + }, + expectScopes: []string{"read"}, + }, + { + name: "Multiple scopes", + challenges: []WWWAuthenticateChallenge{ + { + Scheme: "Bearer", + Parameters: map[string]string{ + "scope": "read write admin", + }, + }, + }, + expectScopes: []string{"read", "write", "admin"}, + }, + { + name: "No scopes", + challenges: []WWWAuthenticateChallenge{ + { + Scheme: "Bearer", + Parameters: map[string]string{}, + }, + }, + expectScopes: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scopes := FindRequiredScopes(tt.challenges) + if len(scopes) != len(tt.expectScopes) { + t.Errorf("Expected %d scopes, got %d", len(tt.expectScopes), len(scopes)) + } + }) + } +} From 19b29be6ce616f0c71ecf75d4a838cf1ef1168c9 Mon Sep 17 00:00:00 2001 From: Saurabh Davala Date: Thu, 9 Oct 2025 17:02:55 -0700 Subject: [PATCH 4/5] Add template, fix logs --- .github/PULL_REQUEST_TEMPLATE.md | 6 ++++++ SECURITY.md => .github/SECURITY.md | 0 discovery.go | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md rename SECURITY.md => .github/SECURITY.md (100%) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..0d057bd --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,6 @@ +**What I did** + +**Related issue** + + +**(not mandatory) A picture of a cute animal, if possible in relation to what you did** \ No newline at end of file diff --git a/SECURITY.md b/.github/SECURITY.md similarity index 100% rename from SECURITY.md rename to .github/SECURITY.md diff --git a/discovery.go b/discovery.go index 1771c40..1b6f9e4 100644 --- a/discovery.go +++ b/discovery.go @@ -126,7 +126,7 @@ func DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*Discover } else { // No resource_metadata in WWW-Authenticate - try well-known endpoint wellKnownURL := fmt.Sprintf("%s/.well-known/oauth-protected-resource", defaultAuthServerURL) - logger.Infof("FALLBACK: trying well-known resource metadata endpoint: %s", wellKnownURL) + logger.Infof("fallback: trying well-known resource metadata endpoint: %s", wellKnownURL) resourceMetadata, resourceMetadataError = fetchOAuthProtectedResourceMetadata(ctx, client, wellKnownURL) if resourceMetadataError == nil && resourceMetadata != nil && resourceMetadata.AuthorizationServer != "" { authServerURL = resourceMetadata.AuthorizationServer From aff1630225baec30462410b4977c61ff8b915cc0 Mon Sep 17 00:00:00 2001 From: Saurabh Davala Date: Mon, 13 Oct 2025 06:33:23 -0700 Subject: [PATCH 5/5] fix comment + redirect URL --- dcr.go | 8 ++++---- discovery_test.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dcr.go b/dcr.go index 6985d48..ec66d91 100644 --- a/dcr.go +++ b/dcr.go @@ -9,6 +9,8 @@ import ( "net/http" ) +const DefaultRedirectURI = "https://mcp.docker.com/oauth/callback" + // PerformDCR performs Dynamic Client Registration with the authorization server // Returns client credentials for the registered public client // @@ -23,10 +25,8 @@ func PerformDCR(ctx context.Context, discovery *Discovery, serverName string) (* // Build DCR request for PUBLIC client registration := DCRRequest{ - ClientName: fmt.Sprintf("MCP Gateway - %s", serverName), - RedirectURIs: []string{ - "https://mcp.docker.com/oauth/callback", // mcp-oauth proxy callback only - }, + ClientName: fmt.Sprintf("MCP Gateway - %s", serverName), + RedirectURIs: []string{DefaultRedirectURI}, TokenEndpointAuthMethod: "none", // PUBLIC client (no client secret) GrantTypes: []string{"authorization_code", "refresh_token"}, ResponseTypes: []string{"code"}, diff --git a/discovery_test.go b/discovery_test.go index 91a81d8..736d5d0 100644 --- a/discovery_test.go +++ b/discovery_test.go @@ -66,7 +66,7 @@ func TestDiscoveryFallback_NoWWWAuthenticate(t *testing.T) { } // Verify fallback was triggered - if !logger.containsInfo("FALLBACK: trying well-known") { + if !logger.containsInfo("fallback: trying well-known") { t.Error("Expected fallback to well-known endpoint to be triggered") } if !logger.containsInfo("no WWW-Authenticate header present") {