From 6d423a4c7e21f53f5b8c5f15ae109971dfdb2f90 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Sat, 13 Sep 2025 16:52:36 -0700 Subject: [PATCH] enhance: add middleware mode --- pkg/oauth/callback/callback.go | 3 ++ pkg/oauth/validate/validatetoken.go | 58 +++++++++++++++++++---------- pkg/providers/generic.go | 44 +++++++++++----------- pkg/providers/provider.go | 13 +++++-- pkg/proxy/proxy.go | 44 +++++++++++----------- pkg/ratelimit/ratelimiter.go | 8 +++- pkg/types/types.go | 1 + 7 files changed, 103 insertions(+), 68 deletions(-) diff --git a/pkg/oauth/callback/callback.go b/pkg/oauth/callback/callback.go index 2e5714f..504c030 100644 --- a/pkg/oauth/callback/callback.go +++ b/pkg/oauth/callback/callback.go @@ -1,6 +1,7 @@ package callback import ( + "encoding/json" "fmt" "log" "net/http" @@ -198,6 +199,8 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if needsUserInfo { sensitiveProps["email"] = userInfo.Email sensitiveProps["name"] = userInfo.Name + infoJSON, _ := json.Marshal(userInfo) + sensitiveProps["info"] = string(infoJSON) } // Initialize props map diff --git a/pkg/oauth/validate/validatetoken.go b/pkg/oauth/validate/validatetoken.go index ad9bfc5..417c012 100644 --- a/pkg/oauth/validate/validatetoken.go +++ b/pkg/oauth/validate/validatetoken.go @@ -19,15 +19,16 @@ import ( ) type TokenValidator struct { - tokenManager *tokens.TokenManager - encryptionKey []byte - mcpUIManager *mcpui.Manager // Optional MCP UI manager for JWT handling - db TokenStore // Database for refresh operations - provider providers.Provider // OAuth provider for generating auth URLs - clientID string // OAuth client ID - clientSecret string // OAuth client secret - scopesSupported []string // Supported OAuth scopes - routePrefix string + tokenManager *tokens.TokenManager + encryptionKey []byte + mcpUIManager *mcpui.Manager // Optional MCP UI manager for JWT handling + db TokenStore // Database for refresh operations + provider providers.Provider // OAuth provider for generating auth URLs + clientID string // OAuth client ID + clientSecret string // OAuth client secret + scopesSupported []string // Supported OAuth scopes + routePrefix string + requiredAuthPaths []string } // TokenStore interface for database operations needed by validator @@ -39,17 +40,18 @@ type TokenStore interface { StoreAuthRequest(key string, data map[string]any) error } -func NewTokenValidator(tokenManager *tokens.TokenManager, mcpUIManager *mcpui.Manager, encryptionKey []byte, db TokenStore, provider providers.Provider, clientID, clientSecret string, scopesSupported []string, routePrefix string) *TokenValidator { +func NewTokenValidator(tokenManager *tokens.TokenManager, mcpUIManager *mcpui.Manager, encryptionKey []byte, db TokenStore, provider providers.Provider, clientID, clientSecret string, scopesSupported []string, routePrefix string, requiredAuthPaths []string) *TokenValidator { return &TokenValidator{ - mcpUIManager: mcpUIManager, - tokenManager: tokenManager, - encryptionKey: encryptionKey, - db: db, - provider: provider, - clientID: clientID, - clientSecret: clientSecret, - scopesSupported: scopesSupported, - routePrefix: routePrefix, + mcpUIManager: mcpUIManager, + tokenManager: tokenManager, + encryptionKey: encryptionKey, + db: db, + provider: provider, + clientID: clientID, + clientSecret: clientSecret, + scopesSupported: scopesSupported, + routePrefix: routePrefix, + requiredAuthPaths: requiredAuthPaths, } } @@ -181,6 +183,21 @@ func (p *TokenValidator) setCookiesForRefresh(w http.ResponseWriter, r *http.Req func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") + if authHeader == "" && len(p.requiredAuthPaths) > 0 { + matches := false + for _, path := range p.requiredAuthPaths { + if strings.HasPrefix(r.URL.Path, path) { + matches = true + break + } + } + if !matches { + // Not a protected path, skip validation + next.ServeHTTP(w, r) + return + } + } + if authHeader == "" { // Try cookie-based authentication with refresh capability var bearerTokenFromCookie string @@ -351,7 +368,8 @@ func (p *TokenValidator) handleOauthFlow(w http.ResponseWriter, r *http.Request) } func GetTokenInfo(r *http.Request) *tokens.TokenInfo { - return r.Context().Value(tokenInfoKey{}).(*tokens.TokenInfo) + v, _ := r.Context().Value(tokenInfoKey{}).(*tokens.TokenInfo) + return v } type tokenInfoKey struct{} diff --git a/pkg/providers/generic.go b/pkg/providers/generic.go index 7d5f544..e550492 100644 --- a/pkg/providers/generic.go +++ b/pkg/providers/generic.go @@ -175,24 +175,25 @@ func (p *GenericProvider) GetUserInfo(ctx context.Context, accessToken string) ( return nil, fmt.Errorf("failed to decode user info response: %w", err) } - var userInfo *UserInfo - if p.metadata.UserinfoEndpoint == "https://api.github.com/user" { - userInfo = &UserInfo{ - ID: getString(userInfoResp, "login"), - Email: getString(userInfoResp, "email"), - Name: getString(userInfoResp, "name"), - } - } else { - userInfo = &UserInfo{ - ID: getString(userInfoResp, "sub"), - Email: getString(userInfoResp, "email"), - Name: getString(userInfoResp, "name"), - } + userInfo := &UserInfo{ + ID: getString(userInfoResp, "id"), + Sub: getString(userInfoResp, "sub"), + Login: getString(userInfoResp, "login"), + Email: getString(userInfoResp, "email"), + EmailVerified: getBool(userInfoResp, "email_verified"), + Name: getString(userInfoResp, "name"), + Picture: getString(userInfoResp, "picture"), + GivenName: getString(userInfoResp, "given_name"), + FamilyName: getString(userInfoResp, "family_name"), + Locale: getString(userInfoResp, "locale"), } - // If sub is not available, try other common ID fields if userInfo.ID == "" { - userInfo.ID = getString(userInfoResp, "id") + userInfo.ID = userInfo.Sub + } + + if userInfo.ID == "" && p.metadata.UserinfoEndpoint == "https://api.github.com/user" { + userInfo.ID = userInfo.Login } return userInfo, nil @@ -231,12 +232,13 @@ func (p *GenericProvider) GetName() string { return "generic" } +func getBool(m map[string]any, key string) bool { + b, _ := m[key].(bool) + return b +} + // Helper functions func getString(m map[string]any, key string) string { - if val, ok := m[key]; ok { - if str, ok := val.(string); ok { - return str - } - } - return "" + str, _ := m[key].(string) + return str } diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go index 91e54f8..9067f0e 100644 --- a/pkg/providers/provider.go +++ b/pkg/providers/provider.go @@ -8,9 +8,16 @@ import ( // UserInfo represents user information from OAuth provider type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` + ID string `json:"id"` + Sub string `json:"sub"` + Login string `json:"login"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` + Picture string `json:"picture"` + Locale string `json:"locale"` } // TokenInfo represents token information from OAuth provider diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 56fc57a..e027c0d 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -5,13 +5,13 @@ import ( "encoding/base64" "fmt" "log" + "maps" "net/http" "net/http/httputil" "net/url" "os" "strconv" "strings" - "sync" "time" "github.com/gorilla/handlers" @@ -43,7 +43,6 @@ type OAuthProxy struct { provider string encryptionKey []byte resourceName string - lock sync.Mutex config *types.Config ctx context.Context @@ -53,6 +52,7 @@ type OAuthProxy struct { const ( ModeProxy = "proxy" ModeForwardAuth = "forward_auth" + Middleware = "middleware" ) func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) { @@ -206,7 +206,7 @@ func (p *OAuthProxy) Start(ctx context.Context) error { return nil } -func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) { +func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux, next http.Handler) { provider, err := p.providers.GetProvider(p.provider) if err != nil { log.Fatalf("Failed to get provider: %v", err) @@ -216,7 +216,7 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) { tokenHandler := token.NewHandler(p.db) callbackHandler := callback.NewHandler(p.db, provider, p.encryptionKey, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.config.RoutePrefix, p.mcpUIManager) revokeHandler := revoke.NewHandler(p.db) - tokenValidator := validate.NewTokenValidator(p.tokenManager, p.mcpUIManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported, p.config.RoutePrefix) + tokenValidator := validate.NewTokenValidator(p.tokenManager, p.mcpUIManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported, p.config.RoutePrefix, p.config.RequiredAuthPaths) successHandler := success.NewHandler() // Get route prefix from config @@ -239,13 +239,15 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) { mux.HandleFunc("GET "+prefix+"/auth/mcp-ui/success", p.withCORS(p.withRateLimit(successHandler))) // Protect everything else - mux.HandleFunc(prefix+"/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler)))) + mux.HandleFunc(prefix+"/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.mcpProxyHandler(w, r, next) + }))))) } // GetHandler returns an http.Handler for the OAuth proxy func (p *OAuthProxy) GetHandler() http.Handler { mux := http.NewServeMux() - p.SetupRoutes(mux) + p.SetupRoutes(mux, nil) // Wrap with logging middleware loggedHandler := handlers.LoggingHandler(os.Stdout, mux) @@ -335,20 +337,17 @@ func (p *OAuthProxy) protectedResourceMetadataHandler(w http.ResponseWriter, r * handlerutils.JSON(w, http.StatusOK, metadata) } -func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) { +func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request, next http.Handler) { tokenInfo := validate.GetTokenInfo(r) path := r.PathValue("path") // Check if the access token is expired and refresh if needed - if tokenInfo.Props != nil { + if tokenInfo != nil && tokenInfo.Props != nil { if _, ok := tokenInfo.Props["access_token"].(string); ok { // Check if token is expired (with a 5-minute buffer) expiresAt, ok := tokenInfo.Props["expires_at"].(float64) if ok && expiresAt > 0 { if time.Now().Add(5 * time.Minute).After(time.Unix(int64(expiresAt), 0)) { - // when refreshing token, we need to lock the database to avoid race conditions - // otherwise we could get save the old access token into the database when another refresh process is running - p.lock.Lock() log.Printf("Access token is expired or will expire soon, attempting to refresh") // Get the refresh token @@ -359,7 +358,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) { "error": "invalid_token", "error_description": "Access token expired and no refresh token available", }) - p.lock.Unlock() return } @@ -371,7 +369,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) { "error": "server_error", "error_description": "Failed to refresh token", }) - p.lock.Unlock() return } @@ -384,7 +381,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) { "error": "server_error", "error_description": "OAuth credentials not configured", }) - p.lock.Unlock() return } @@ -396,7 +392,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) { "error": "invalid_token", "error_description": "Failed to refresh access token", }) - p.lock.Unlock() return } @@ -407,14 +402,11 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) { "error": "server_error", "error_description": "Failed to update grant with new token", }) - p.lock.Unlock() return } // Update the token info with the new access token for the current request tokenInfo.Props["access_token"] = newTokenInfo.AccessToken - p.lock.Unlock() - log.Printf("Successfully refreshed access token") } } @@ -422,6 +414,8 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) { } switch p.config.Mode { + case Middleware: + next.ServeHTTP(w, r) case ModeForwardAuth: setHeaders(w.Header(), tokenInfo.Props) case ModeProxy: @@ -508,13 +502,17 @@ func (p *OAuthProxy) updateGrant(grantID, userID string, oldTokenInfo *tokens.To return fmt.Errorf("failed to get grant: %w", err) } - // Prepare sensitive props data - sensitiveProps := map[string]any{ - "access_token": newTokenInfo.AccessToken, - "refresh_token": newTokenInfo.RefreshToken, - "expires_at": newTokenInfo.Expiry.Unix(), + sensitiveProps := map[string]any{} + if oldTokenInfo.Props != nil { + // keep all the old props, that include a lot of the user info + maps.Copy(sensitiveProps, oldTokenInfo.Props) } + // Prepare sensitive props data + sensitiveProps["access_token"] = newTokenInfo.AccessToken + sensitiveProps["refresh_token"] = newTokenInfo.RefreshToken + sensitiveProps["expires_at"] = newTokenInfo.Expiry.Unix() + // Add existing user info if available if grant.Props != nil { if email, ok := grant.Props["email"].(string); ok { diff --git a/pkg/ratelimit/ratelimiter.go b/pkg/ratelimit/ratelimiter.go index 360be27..70e94c5 100644 --- a/pkg/ratelimit/ratelimiter.go +++ b/pkg/ratelimit/ratelimiter.go @@ -1,10 +1,14 @@ package ratelimit -import "time" +import ( + "sync" + "time" +) // RateLimiter simple in-memory rate limiter type RateLimiter struct { requests map[string][]time.Time + lock sync.Mutex window time.Duration max int } @@ -18,6 +22,8 @@ func NewRateLimiter(window time.Duration, max int) *RateLimiter { } func (rl *RateLimiter) Allow(key string) bool { + rl.lock.Lock() + defer rl.lock.Unlock() now := time.Now() windowStart := now.Add(-rl.window) diff --git a/pkg/types/types.go b/pkg/types/types.go index 595f1d4..50a9365 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -16,6 +16,7 @@ type Config struct { MCPServerURL string Mode string RoutePrefix string + RequiredAuthPaths []string } // TokenData represents stored token data for OAuth 2.1 compliance