Authentication successful.
+Redirecting in 5 seconds...
+ +diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..dd9ab88 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,192 @@ +package cmd + +import ( + "fmt" + "log" + "net/http" + "net/url" + + "github.com/gptscript-ai/cmd" + "github.com/obot-platform/mcp-oauth-proxy/pkg/proxy" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" + "github.com/spf13/cobra" +) + +var ( + version = "dev" + buildTime = "unknown" +) + +// RootCmd represents the base command when called without any subcommands +type RootCmd struct { + // Database configuration + DatabaseDSN string `name:"database-dsn" env:"DATABASE_DSN" usage:"Database connection string (PostgreSQL or SQLite file path). If empty, uses SQLite at data/oauth_proxy.db"` + + // OAuth Provider configuration + OAuthClientID string `name:"oauth-client-id" env:"OAUTH_CLIENT_ID" usage:"OAuth client ID from your OAuth provider" required:"true"` + OAuthClientSecret string `name:"oauth-client-secret" env:"OAUTH_CLIENT_SECRET" usage:"OAuth client secret from your OAuth provider" required:"true"` + OAuthAuthorizeURL string `name:"oauth-authorize-url" env:"OAUTH_AUTHORIZE_URL" usage:"Authorization endpoint URL from your OAuth provider (e.g., https://accounts.google.com)" required:"true"` + + // Scopes and MCP configuration + ScopesSupported string `name:"scopes-supported" env:"SCOPES_SUPPORTED" usage:"Comma-separated list of supported OAuth scopes (e.g., 'openid,profile,email')" required:"true"` + MCPServerURL string `name:"mcp-server-url" env:"MCP_SERVER_URL" usage:"URL of the MCP server to proxy requests to" required:"true"` + + // Security configuration + EncryptionKey string `name:"encryption-key" env:"ENCRYPTION_KEY" usage:"Base64-encoded 32-byte AES-256 key for encrypting sensitive data (optional)"` + + // Server configuration + Port string `name:"port" env:"PORT" usage:"Port to run the server on" default:"8080"` + Host string `name:"host" env:"HOST" usage:"Host to bind the server to" default:"localhost"` + RoutePrefix string `name:"route-prefix" env:"ROUTE_PREFIX" usage:"Optional prefix for all routes (e.g., '/oauth2')"` + + // Logging + Verbose bool `name:"verbose,v" usage:"Enable verbose logging"` + Version bool `name:"version" usage:"Show version information"` + + Mode string `name:"mode" env:"MODE" usage:"Mode to run the server in" default:"proxy"` +} + +func (c *RootCmd) Run(cobraCmd *cobra.Command, args []string) error { + if c.Version { + fmt.Printf("MCP OAuth Proxy\n") + fmt.Printf("Version: %s\n", version) + fmt.Printf("Built: %s\n", buildTime) + return nil + } + + // Configure logging + if c.Verbose { + log.SetFlags(log.LstdFlags | log.Lshortfile) + log.Println("Verbose logging enabled") + } + + // Convert CLI config to internal config format + config := &types.Config{ + DatabaseDSN: c.DatabaseDSN, + OAuthClientID: c.OAuthClientID, + OAuthClientSecret: c.OAuthClientSecret, + OAuthAuthorizeURL: c.OAuthAuthorizeURL, + ScopesSupported: c.ScopesSupported, + MCPServerURL: c.MCPServerURL, + EncryptionKey: c.EncryptionKey, + Mode: c.Mode, + RoutePrefix: c.RoutePrefix, + } + + // Validate configuration + if err := c.validateConfig(); err != nil { + return fmt.Errorf("configuration validation failed: %w", err) + } + + // Create OAuth proxy + oauthProxy, err := proxy.NewOAuthProxy(config) + if err != nil { + return fmt.Errorf("failed to create OAuth proxy: %w", err) + } + defer func() { + if err := oauthProxy.Close(); err != nil { + log.Printf("Error closing database: %v", err) + } + }() + + // Get HTTP handler + handler := oauthProxy.GetHandler() + + // Start server + address := fmt.Sprintf("%s:%s", c.Host, c.Port) + log.Printf("Starting OAuth proxy server on %s", address) + log.Printf("OAuth Provider: %s", c.OAuthAuthorizeURL) + log.Printf("MCP Server: %s", c.MCPServerURL) + log.Printf("Database: %s", c.getDatabaseType()) + + return http.ListenAndServe(address, handler) +} + +func (c *RootCmd) validateConfig() error { + if c.OAuthClientID == "" { + return fmt.Errorf("oauth-client-id is required") + } + if c.OAuthClientSecret == "" { + return fmt.Errorf("oauth-client-secret is required") + } + if c.OAuthAuthorizeURL == "" { + return fmt.Errorf("oauth-authorize-url is required") + } + if c.ScopesSupported == "" { + return fmt.Errorf("scopes-supported is required") + } + if c.MCPServerURL == "" { + return fmt.Errorf("mcp-server-url is required") + } + if c.Mode == proxy.ModeProxy { + if u, err := url.Parse(c.MCPServerURL); err != nil || u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("invalid MCP server URL: %w", err) + } else if u.Path != "" && u.Path != "/" || u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("MCP server URL must not contain a path, query, or fragment") + } + } + return nil +} + +func (c *RootCmd) getDatabaseType() string { + if c.DatabaseDSN == "" { + return "SQLite (data/oauth_proxy.db)" + } + if len(c.DatabaseDSN) > 10 && (c.DatabaseDSN[:11] == "postgres://" || c.DatabaseDSN[:14] == "postgresql://") { + return "PostgreSQL" + } + return fmt.Sprintf("SQLite (%s)", c.DatabaseDSN) +} + +// Customizer interface implementation for additional command customization +func (c *RootCmd) Customize(cobraCmd *cobra.Command) { + cobraCmd.Use = "mcp-oauth-proxy" + cobraCmd.Short = "OAuth 2.1 proxy server for MCP (Model Context Protocol)" + cobraCmd.Long = `MCP OAuth Proxy is a comprehensive OAuth 2.1 proxy server that provides +OAuth authorization server functionality with PostgreSQL/SQLite storage. + +This proxy supports multiple OAuth providers (Google, Microsoft, GitHub) and +proxies requests to MCP servers with user context headers. + +Examples: + # Start with environment variables + export OAUTH_CLIENT_ID="your-google-client-id" + export OAUTH_CLIENT_SECRET="your-secret" + export OAUTH_AUTHORIZE_URL="https://accounts.google.com" + export SCOPES_SUPPORTED="openid,profile,email" + export MCP_SERVER_URL="http://localhost:3000" + mcp-oauth-proxy + + # Start with CLI flags + mcp-oauth-proxy \ + --oauth-client-id="your-google-client-id" \ + --oauth-client-secret="your-secret" \ + --oauth-authorize-url="https://accounts.google.com" \ + --scopes-supported="openid,profile,email" \ + --mcp-server-url="http://localhost:3000" + + # Use PostgreSQL database + mcp-oauth-proxy \ + --database-dsn="postgres://user:pass@localhost:5432/oauth_db?sslmode=disable" \ + --oauth-client-id="your-client-id" \ + # ... other required flags + +Configuration: + Configuration values are loaded in this order (later values override earlier ones): + 1. Default values + 2. Environment variables + 3. Command line flags + +Database Support: + - PostgreSQL: Full ACID compliance, recommended for production + - SQLite: Zero configuration, perfect for development and small deployments` + + cobraCmd.Version = version +} + +// Execute is the main entry point for the CLI +func Execute() error { + rootCmd := &RootCmd{} + cobraCmd := cmd.Command(rootCmd) + return cobraCmd.Execute() +} diff --git a/go.mod b/go.mod index 3a9644f..f7c580f 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,12 @@ module github.com/obot-platform/mcp-oauth-proxy go 1.25.0 require ( + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/gorilla/handlers v1.5.2 + github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070 + github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.10.0 + golang.org/x/oauth2 v0.30.0 gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.30.1 @@ -13,6 +17,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.3 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.5 // indirect @@ -23,8 +28,8 @@ require ( github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect golang.org/x/crypto v0.41.0 // indirect - golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 76082aa..034505d 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,17 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w= +github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070 h1:xm5ZZFraWFwxyE7TBEncCXArubCDZTwG6s5bpMzqhSY= +github.com/gptscript-ai/cmd v0.0.0-20250530150401-bc71fddf8070/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -30,6 +37,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/main.go b/main.go index 24c58a7..2893544 100644 --- a/main.go +++ b/main.go @@ -1,33 +1,15 @@ package main import ( - "log" - "net/http" + "fmt" + "os" - "github.com/obot-platform/mcp-oauth-proxy/pkg/proxy" + "github.com/obot-platform/mcp-oauth-proxy/cmd" ) func main() { - // Load configuration from environment variables - config, err := proxy.LoadConfigFromEnv() - if err != nil { - log.Fatalf("Failed to load configuration: %v", err) + if err := cmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) } - - proxy, err := proxy.NewOAuthProxy(config) - if err != nil { - log.Fatalf("Failed to create OAuth proxy: %v", err) - } - defer func() { - if err := proxy.Close(); err != nil { - log.Printf("Error closing database: %v", err) - } - }() - - // Get HTTP handler - handler := proxy.GetHandler() - - // Start server - log.Print("Starting OAuth proxy server on localhost:" + config.Port) - log.Fatal(http.ListenAndServe(":"+config.Port, handler)) } diff --git a/main_test.go b/main_test.go index 5407672..cc72f6e 100644 --- a/main_test.go +++ b/main_test.go @@ -5,11 +5,11 @@ import ( "log" "net/http" "net/http/httptest" - "os" "testing" "time" "github.com/obot-platform/mcp-oauth-proxy/pkg/proxy" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -20,49 +20,16 @@ func TestIntegrationFlow(t *testing.T) { t.Skip("Skipping integration tests in short mode") } - // Note: Using standard HTTP handlers instead of Gin - - // Set required environment variables for testing - oldVars := map[string]string{ - "OAUTH_CLIENT_ID": os.Getenv("OAUTH_CLIENT_ID"), - "OAUTH_CLIENT_SECRET": os.Getenv("OAUTH_CLIENT_SECRET"), - "OAUTH_AUTHORIZE_URL": os.Getenv("OAUTH_AUTHORIZE_URL"), - "SCOPES_SUPPORTED": os.Getenv("SCOPES_SUPPORTED"), - "MCP_SERVER_URL": os.Getenv("MCP_SERVER_URL"), - "DATABASE_DSN": os.Getenv("DATABASE_DSN"), - } - - // Set test environment variables - testEnvVars := map[string]string{ - "OAUTH_CLIENT_ID": "test_client_id", - "OAUTH_CLIENT_SECRET": "test_client_secret", - "OAUTH_AUTHORIZE_URL": "https://accounts.google.com", - "SCOPES_SUPPORTED": "openid,profile,email", - "MCP_SERVER_URL": "http://localhost:8081", - "DATABASE_DSN": os.Getenv("TEST_DATABASE_DSN"), // Use test database if available - } - - for key, value := range testEnvVars { - if value != "" { - if err := os.Setenv(key, value); err != nil { - t.Logf("Failed to set %s: %v", key, err) - } - } - } - - // Restore environment variables after test - defer func() { - for key, value := range oldVars { - if value != "" { - _ = os.Setenv(key, value) - } else { - _ = os.Unsetenv(key) - } - } - }() - // Create OAuth proxy - config, err := proxy.LoadConfigFromEnv() + config := &types.Config{ + Mode: proxy.ModeProxy, + MCPServerURL: "http://localhost:8081/", + OAuthClientID: "test_client_id", + OAuthClientSecret: "test_client_secret", + OAuthAuthorizeURL: "https://accounts.google.com", + ScopesSupported: "openid,profile,email", + } + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -144,43 +111,16 @@ func TestIntegrationFlow(t *testing.T) { } func TestOAuthProxyCreation(t *testing.T) { - // Test that we can create an OAuth proxy without errors when all required env vars are set - oldVars := map[string]string{ - "OAUTH_CLIENT_ID": os.Getenv("OAUTH_CLIENT_ID"), - "OAUTH_CLIENT_SECRET": os.Getenv("OAUTH_CLIENT_SECRET"), - "OAUTH_AUTHORIZE_URL": os.Getenv("OAUTH_AUTHORIZE_URL"), - "SCOPES_SUPPORTED": os.Getenv("SCOPES_SUPPORTED"), - "MCP_SERVER_URL": os.Getenv("MCP_SERVER_URL"), - } - - // Set minimal required environment - testEnvVars := map[string]string{ - "OAUTH_CLIENT_ID": "test_client_id", - "OAUTH_CLIENT_SECRET": "test_client_secret", - "OAUTH_AUTHORIZE_URL": "https://accounts.google.com", - "SCOPES_SUPPORTED": "openid,profile,email", - "MCP_SERVER_URL": "http://localhost:8081", - } - - for key, value := range testEnvVars { - if err := os.Setenv(key, value); err != nil { - t.Fatalf("Failed to set %s: %v", key, err) - } - } - - // Restore environment variables after test - defer func() { - for key, value := range oldVars { - if value != "" { - _ = os.Setenv(key, value) - } else { - _ = os.Unsetenv(key) - } - } - }() - // Create OAuth proxy - config, err := proxy.LoadConfigFromEnv() + config := &types.Config{ + Mode: proxy.ModeProxy, + MCPServerURL: "http://localhost:8081/", + OAuthClientID: "test_client_id", + OAuthClientSecret: "test_client_secret", + OAuthAuthorizeURL: "https://accounts.google.com", + ScopesSupported: "openid,profile,email", + } + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -204,36 +144,16 @@ func TestOAuthProxyStart(t *testing.T) { t.Skip("Skipping OAuth proxy start test in short mode") } - // Set minimal required environment - testEnvVars := map[string]string{ - "OAUTH_CLIENT_ID": "test_client_id", - "OAUTH_CLIENT_SECRET": "test_client_secret", - "OAUTH_AUTHORIZE_URL": "https://accounts.google.com", - "SCOPES_SUPPORTED": "openid,profile,email", - "MCP_SERVER_URL": "http://localhost:8081", - } - - oldVars := make(map[string]string) - for key, value := range testEnvVars { - oldVars[key] = os.Getenv(key) - if err := os.Setenv(key, value); err != nil { - t.Fatalf("Failed to set %s: %v", key, err) - } - } - - // Restore environment variables after test - defer func() { - for key, value := range oldVars { - if value != "" { - _ = os.Setenv(key, value) - } else { - _ = os.Unsetenv(key) - } - } - }() - // Create OAuth proxy - config, err := proxy.LoadConfigFromEnv() + config := &types.Config{ + Mode: proxy.ModeProxy, + MCPServerURL: "http://localhost:8081/", + OAuthClientID: "test_client_id", + OAuthClientSecret: "test_client_secret", + OAuthAuthorizeURL: "https://accounts.google.com", + ScopesSupported: "openid,profile,email", + } + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -276,51 +196,16 @@ func TestForwardAuthIntegrationFlow(t *testing.T) { if testing.Short() { t.Skip("Skipping integration tests in short mode") } - - // Set required environment variables for forward auth testing - oldVars := map[string]string{ - "OAUTH_CLIENT_ID": os.Getenv("OAUTH_CLIENT_ID"), - "OAUTH_CLIENT_SECRET": os.Getenv("OAUTH_CLIENT_SECRET"), - "OAUTH_AUTHORIZE_URL": os.Getenv("OAUTH_AUTHORIZE_URL"), - "SCOPES_SUPPORTED": os.Getenv("SCOPES_SUPPORTED"), - "MCP_SERVER_URL": os.Getenv("MCP_SERVER_URL"), - "DATABASE_DSN": os.Getenv("DATABASE_DSN"), - "PROXY_MODE": os.Getenv("PROXY_MODE"), - "PORT": os.Getenv("PORT"), - } - - // Set test environment variables for forward auth mode - testEnvVars := map[string]string{ - "OAUTH_CLIENT_ID": "test_client_id", - "OAUTH_CLIENT_SECRET": "test_client_secret", - "OAUTH_AUTHORIZE_URL": "https://accounts.google.com", - "SCOPES_SUPPORTED": "openid,profile,email", - "PROXY_MODE": "forward_auth", - "PORT": "8082", // Different port to avoid conflicts - "DATABASE_DSN": os.Getenv("TEST_DATABASE_DSN"), // Use test database if available - } - - for key, value := range testEnvVars { - if value != "" { - if err := os.Setenv(key, value); err != nil { - t.Logf("Failed to set %s: %v", key, err) - } - } + config := &types.Config{ + Mode: proxy.ModeForwardAuth, + MCPServerURL: "http://localhost:8081/", + OAuthClientID: "test_client_id", + OAuthClientSecret: "test_client_secret", + OAuthAuthorizeURL: "https://accounts.google.com", + ScopesSupported: "openid,profile,email", + Port: "8082", } - - // Restore environment variables after test - defer func() { - for key, value := range oldVars { - if value != "" { - _ = os.Setenv(key, value) - } else { - _ = os.Unsetenv(key) - } - } - }() - - // Create OAuth proxy in forward auth mode - config, err := proxy.LoadConfigFromEnv() + _, err := proxy.NewOAuthProxy(config) if err != nil { log.Fatalf("Failed to load configuration: %v", err) } @@ -375,7 +260,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) { // Test that forward auth mode requires authorization for protected endpoints t.Run("ForwardAuthRequiresAuth", func(t *testing.T) { testPaths := []string{"/api", "/data", "/protected", "/mcp", "/test"} - + for _, path := range testPaths { t.Run("Path_"+path, func(t *testing.T) { w := httptest.NewRecorder() @@ -399,7 +284,7 @@ func TestForwardAuthIntegrationFlow(t *testing.T) { // Should get unauthorized (no proxying attempt) assert.Equal(t, http.StatusUnauthorized, w.Code) assert.Contains(t, w.Header().Get("WWW-Authenticate"), "Bearer") - + // Should not have any proxy-related error messages assert.NotContains(t, w.Body.String(), "proxy") assert.NotContains(t, w.Body.String(), "502") diff --git a/pkg/db/db.go b/pkg/db/db.go index a5fef73..612e2b6 100644 --- a/pkg/db/db.go +++ b/pkg/db/db.go @@ -295,14 +295,6 @@ func (d *Store) RevokeToken(token string) error { return result.Error } -// UpdateTokenRefreshToken updates the refresh token for an existing token -func (d *Store) UpdateTokenRefreshToken(accessToken, newRefreshToken string) error { - hashedAccessToken := hashToken(accessToken) - hashedNewRefreshToken := hashToken(newRefreshToken) - - return d.db.Model(&types.TokenData{}).Where("access_token = ?", hashedAccessToken).Update("refresh_token", hashedNewRefreshToken).Error -} - // CleanupExpiredTokens removes expired tokens and authorization codes func (d *Store) CleanupExpiredTokens() error { now := time.Now() diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go index a058068..f06d20b 100644 --- a/pkg/db/db_test.go +++ b/pkg/db/db_test.go @@ -24,6 +24,7 @@ func TestDatabaseOperations(t *testing.T) { if dsn == "" { t.Skip("Skipping database tests: TEST_DATABASE_DSN is not set") } + db, err := New(dsn) if err != nil { t.Skipf("Skipping database tests: %v", err) @@ -169,10 +170,16 @@ func testTokenOperations(t *testing.T, db *Store) { grantID, err := generateRandomString(16) require.NoError(t, err) + clientID, err := generateRandomString(16) + require.NoError(t, err) + + userID, err := generateRandomString(16) + require.NoError(t, err) + grant := &types.Grant{ ID: grantID, - ClientID: "test_client_db", - UserID: "test_user_123", + ClientID: clientID, + UserID: userID, Scope: []string{"read", "write", "admin"}, Metadata: map[string]any{"provider": "test", "ip": "127.0.0.1"}, } @@ -190,8 +197,8 @@ func testTokenOperations(t *testing.T, db *Store) { tokenData := &types.TokenData{ AccessToken: accessTokenData, RefreshToken: refreshTokenData, - ClientID: "test_client_db", - UserID: "test_user_123", + ClientID: clientID, + UserID: userID, GrantID: grantID, Scope: "read write admin", ExpiresAt: time.Now().Add(1 * time.Hour), @@ -231,16 +238,6 @@ func testTokenOperations(t *testing.T, db *Store) { assert.True(t, revokedToken.Revoked) assert.NotNil(t, revokedToken.RevokedAt) - // Test updating refresh token - newRefreshTokenData, err := generateRandomString(16) - require.NoError(t, err) - err = db.UpdateTokenRefreshToken(accessTokenData, newRefreshTokenData) - require.NoError(t, err) - - updatedToken, err := db.GetToken(accessTokenData) - require.NoError(t, err) - assert.Equal(t, hashToken(newRefreshTokenData), updatedToken.RefreshToken) - // Test retrieving non-existent token _, err = db.GetToken("non_existent_token") assert.Error(t, err) diff --git a/pkg/encryption/encryption.go b/pkg/encryption/encryption.go index 6f15690..e220f52 100644 --- a/pkg/encryption/encryption.go +++ b/pkg/encryption/encryption.go @@ -143,3 +143,69 @@ func DecryptPropsIfNeeded(encryptionKey []byte, props map[string]any) (map[strin return result, nil } + +// EncryptString encrypts a string using AES-256-GCM +func EncryptString(encryptionKey []byte, plaintext string) (string, error) { + // Create AES cipher + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + // Generate random IV + iv := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(iv); err != nil { + return "", fmt.Errorf("failed to generate IV: %w", err) + } + + // Encrypt the data + ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil) + + // Combine IV and ciphertext, then base64 encode + combined := append(iv, ciphertext...) + return base64.StdEncoding.EncodeToString(combined), nil +} + +// DecryptString decrypts a string using AES-256-GCM +func DecryptString(encryptionKey []byte, encryptedData string) (string, error) { + // Decode base64 data + combined, err := base64.StdEncoding.DecodeString(encryptedData) + if err != nil { + return "", fmt.Errorf("failed to decode encrypted data: %w", err) + } + + // Create AES cipher + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM mode + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + // Extract IV and ciphertext + ivSize := gcm.NonceSize() + if len(combined) < ivSize { + return "", fmt.Errorf("encrypted data too short") + } + + iv := combined[:ivSize] + ciphertext := combined[ivSize:] + + // Decrypt the data + plaintext, err := gcm.Open(nil, iv, ciphertext, nil) + if err != nil { + return "", fmt.Errorf("failed to decrypt data: %w", err) + } + + return string(plaintext), nil +} diff --git a/pkg/mcpui/cookies.go b/pkg/mcpui/cookies.go new file mode 100644 index 0000000..6c31d21 --- /dev/null +++ b/pkg/mcpui/cookies.go @@ -0,0 +1,117 @@ +package mcpui + +import ( + "fmt" + "net/http" + "strings" +) + +const ( + MCPUICookieName = "mcp-ui-code" + MCPUIRefreshCookieName = "mcp-ui-refresh-code" + DefaultCookieMaxAge = 3600 // 1 hour +) + +// CookieManager handles browser cookies for MCP UI authentication +type CookieManager struct { + httpOnly bool + sameSite http.SameSite +} + +// NewCookieManager creates a new cookie manager +func NewCookieManager() *CookieManager { + return &CookieManager{ + httpOnly: true, + sameSite: http.SameSiteStrictMode, + } +} + +// isSecureRequest determines if the request is over HTTPS +func (c *CookieManager) isSecureRequest(r *http.Request) bool { + // Check if request is HTTPS + if r.TLS != nil { + return true + } + + // Check forwarded headers from reverse proxies + if r.Header.Get("X-Forwarded-Proto") == "https" { + return true + } + + if r.Header.Get("X-Forwarded-Ssl") == "on" { + return true + } + + return false +} + +// getDomain extracts the appropriate domain for cookies +func (c *CookieManager) getDomain(r *http.Request) string { + host := r.Host + + // Remove port if present + if colonIndex := strings.Index(host, ":"); colonIndex != -1 { + host = host[:colonIndex] + } + + // For localhost, don't set domain (allows cookies to work on localhost) + if host == "localhost" || host == "127.0.0.1" { + return "" + } + + return host +} + +// SetMCPUICookie sets the MCP UI authentication cookie containing the bearer token +func (c *CookieManager) SetMCPUICookie(w http.ResponseWriter, r *http.Request, bearerToken string) { + // Encode the bearer token for cookie storage + cookie := &http.Cookie{ + Name: MCPUICookieName, + Value: bearerToken, + Path: "/", + Domain: c.getDomain(r), + MaxAge: DefaultCookieMaxAge, + Secure: c.isSecureRequest(r), + HttpOnly: c.httpOnly, + SameSite: c.sameSite, + } + + http.SetCookie(w, cookie) +} + +// SetMCPUIRefreshCookie sets the refresh token cookie +func (c *CookieManager) SetMCPUIRefreshCookie(w http.ResponseWriter, r *http.Request, refreshToken string) { + // Encode the refresh token for cookie storage + cookie := &http.Cookie{ + Name: MCPUIRefreshCookieName, + Value: refreshToken, + Path: "/", + Domain: c.getDomain(r), + MaxAge: DefaultCookieMaxAge * 24 * 30, // 30 days for refresh token + Secure: c.isSecureRequest(r), + HttpOnly: c.httpOnly, + SameSite: c.sameSite, + } + + http.SetCookie(w, cookie) +} + +// GetMCPUICookie retrieves and decodes the MCP UI authentication cookie +func (c *CookieManager) GetMCPUICookie(r *http.Request) (string, error) { + cookie, err := r.Cookie(MCPUICookieName) + if err != nil { + return "", fmt.Errorf("MCP UI cookie not found: %w", err) + } + + return cookie.Value, nil +} + +// GetMCPUIRefreshCookie retrieves and decodes the MCP UI refresh cookie +func (c *CookieManager) GetMCPUIRefreshCookie(r *http.Request) (string, error) { + cookie, err := r.Cookie(MCPUIRefreshCookieName) + if err != nil { + return "", fmt.Errorf("MCP UI refresh cookie not found: %w", err) + } + + return cookie.Value, nil +} diff --git a/pkg/mcpui/jwt.go b/pkg/mcpui/jwt.go new file mode 100644 index 0000000..fc1e3e4 --- /dev/null +++ b/pkg/mcpui/jwt.go @@ -0,0 +1,166 @@ +package mcpui + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// MCPUICodeClaims represents the JWT claims for MCP UI codes +type MCPUICodeClaims struct { + BearerToken string `json:"bearer_token"` + RefreshToken string `json:"refresh_token"` + jwt.RegisteredClaims +} + +// JWTManager handles JWT operations for MCP UI codes with signing and encryption +type JWTManager struct { + signingKey []byte + encryptionKey []byte +} + +// NewJWTManager creates a new JWT manager with the given signing and encryption keys +func NewJWTManager(signingKey []byte, encryptionKey []byte) *JWTManager { + return &JWTManager{ + signingKey: signingKey, + encryptionKey: encryptionKey, + } +} + +// GenerateMCPUICode creates a signed JWT then encrypts it containing the bearer token with 1-minute expiration +func (j *JWTManager) GenerateMCPUICode(bearerToken string, refreshToken string) (string, error) { + // Create claims with 1-minute expiration + claims := MCPUICodeClaims{ + BearerToken: bearerToken, + RefreshToken: refreshToken, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Minute)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }, + } + + // Create and sign the JWT first + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedJWT, err := token.SignedString(j.signingKey) + if err != nil { + return "", fmt.Errorf("failed to sign JWT token: %w", err) + } + + // Encrypt the signed JWT using AES-GCM + encryptedJWT, err := j.encrypt([]byte(signedJWT)) + if err != nil { + return "", fmt.Errorf("failed to encrypt signed JWT: %w", err) + } + + // Base64 encode the encrypted data for URL safety + return base64.URLEncoding.EncodeToString(encryptedJWT), nil +} + +// ValidateMCPUICode validates and extracts the bearer token from the encrypted and signed JWT +func (j *JWTManager) ValidateMCPUICode(tokenString string) (string, string, error) { + // Base64 decode the encrypted data + encryptedJWT, err := base64.URLEncoding.DecodeString(tokenString) + if err != nil { + return "", "", fmt.Errorf("failed to decode token: %w", err) + } + + // Decrypt the JWT + signedJWT, err := j.decrypt(encryptedJWT) + if err != nil { + return "", "", fmt.Errorf("failed to decrypt JWT: %w", err) + } + + // Parse and validate the signed JWT + token, err := jwt.ParseWithClaims(string(signedJWT), &MCPUICodeClaims{}, func(token *jwt.Token) (interface{}, error) { + // Validate the signing method + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return j.signingKey, nil + }) + + if err != nil { + return "", "", fmt.Errorf("failed to parse JWT token: %w", err) + } + + // Extract claims + if claims, ok := token.Claims.(*MCPUICodeClaims); ok && token.Valid { + return claims.BearerToken, claims.RefreshToken, nil + } + + return "", "", fmt.Errorf("invalid token claims") +} + +// encrypt encrypts data using AES-GCM +func (j *JWTManager) encrypt(data []byte) ([]byte, error) { + // Create AES cipher + block, err := aes.NewCipher(j.encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Generate random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %w", err) + } + + // Encrypt data + ciphertext := gcm.Seal(nonce, nonce, data, nil) + return ciphertext, nil +} + +// decrypt decrypts data using AES-GCM +func (j *JWTManager) decrypt(encryptedData []byte) ([]byte, error) { + // Create AES cipher + block, err := aes.NewCipher(j.encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + // Create GCM + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + // Check minimum length + nonceSize := gcm.NonceSize() + if len(encryptedData) < nonceSize { + return nil, fmt.Errorf("encrypted data too short") + } + + // Extract nonce and ciphertext + nonce := encryptedData[:nonceSize] + ciphertext := encryptedData[nonceSize:] + + // Decrypt data + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("failed to decrypt: %w", err) + } + + return plaintext, nil +} + +// EncodeKey encodes a key to base64 for storage +func EncodeKey(key []byte) string { + return base64.StdEncoding.EncodeToString(key) +} + +// DecodeKey decodes a base64-encoded key +func DecodeKey(encodedKey string) ([]byte, error) { + return base64.StdEncoding.DecodeString(encodedKey) +} diff --git a/pkg/mcpui/jwt_test.go b/pkg/mcpui/jwt_test.go new file mode 100644 index 0000000..e2ab65f --- /dev/null +++ b/pkg/mcpui/jwt_test.go @@ -0,0 +1,85 @@ +package mcpui + +import ( + "crypto/rand" + "testing" +) + +func TestJWTManagerEncryption(t *testing.T) { + // Generate a test encryption key + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + + // Create JWT manager (using same key for signing and encryption) + jwtManager := NewJWTManager(key, key) + + // Test bearer token + testBearerToken := "user123:grant456:secret789" + testRefreshToken := "refresh123" + + // Generate encrypted JWT + encryptedJWT, err := jwtManager.GenerateMCPUICode(testBearerToken, testRefreshToken) + if err != nil { + t.Fatalf("Failed to generate MCP UI code: %v", err) + } + + // JWT should not be empty + if encryptedJWT == "" { + t.Fatal("Generated JWT is empty") + } + + // JWT should not contain the bearer token in plaintext + if string(encryptedJWT) == testBearerToken { + t.Fatal("JWT contains plaintext bearer token") + } + + // Validate and decrypt the JWT + extractedBearerToken, extractedRefreshToken, err := jwtManager.ValidateMCPUICode(encryptedJWT) + if err != nil { + t.Fatalf("Failed to validate MCP UI code: %v", err) + } + + // Extracted token should match original + if extractedBearerToken != testBearerToken { + t.Fatalf("Extracted bearer token doesn't match: expected %s, got %s", testBearerToken, extractedBearerToken) + } + + // Extracted refresh token should match original + if extractedRefreshToken != testRefreshToken { + t.Fatalf("Extracted refresh token doesn't match: expected %s, got %s", testRefreshToken, extractedRefreshToken) + } +} + +func TestJWTManagerWrongKey(t *testing.T) { + // Generate two different keys + key1 := make([]byte, 32) + key2 := make([]byte, 32) + if _, err := rand.Read(key1); err != nil { + t.Fatalf("Failed to generate test key1: %v", err) + } + if _, err := rand.Read(key2); err != nil { + t.Fatalf("Failed to generate test key2: %v", err) + } + + // Create JWT managers with different keys + jwtManager1 := NewJWTManager(key1, key1) + jwtManager2 := NewJWTManager(key2, key2) + + // Test bearer token + testBearerToken := "user123:grant456:secret789" + testRefreshToken := "refresh123" + + // Generate JWT with first manager + encryptedJWT, err := jwtManager1.GenerateMCPUICode(testBearerToken, testRefreshToken) + if err != nil { + t.Fatalf("Failed to generate MCP UI code: %v", err) + } + + // Try to validate with second manager (wrong key) - should fail + _, _, err = jwtManager2.ValidateMCPUICode(encryptedJWT) + if err == nil { + t.Fatal("Expected validation to fail with wrong key, but it succeeded") + } +} diff --git a/pkg/mcpui/manager.go b/pkg/mcpui/manager.go new file mode 100644 index 0000000..64ca9fb --- /dev/null +++ b/pkg/mcpui/manager.go @@ -0,0 +1,127 @@ +package mcpui + +import ( + "log" + "net/http" + + "github.com/obot-platform/mcp-oauth-proxy/pkg/providers" + "github.com/obot-platform/mcp-oauth-proxy/pkg/tokens" + "github.com/obot-platform/mcp-oauth-proxy/pkg/types" +) + +// Database interface for MCP UI operations +type Database interface { + GetToken(accessToken string) (*types.TokenData, error) + GetTokenByRefreshToken(refreshToken string) (*types.TokenData, error) +} + +// Manager handles MCP UI authentication flow +type Manager struct { + jwtManager *JWTManager + cookieManager *CookieManager + tokenManager *tokens.TokenManager + providers *providers.Manager + providerName string + clientID string + clientSecret string + encryptionKey []byte + db Database +} + +// NewManager creates a new MCP UI manager +func NewManager(encryptionKey []byte, tokenManager *tokens.TokenManager, providers *providers.Manager, providerName, clientID, clientSecret string, db Database) *Manager { + return &Manager{ + jwtManager: NewJWTManager(encryptionKey, encryptionKey), // Use same key for signing and encryption + cookieManager: NewCookieManager(), + tokenManager: tokenManager, + providers: providers, + providerName: providerName, + clientID: clientID, + clientSecret: clientSecret, + encryptionKey: encryptionKey, + db: db, + } +} + +// HandleMCPUIRequest processes requests with mcp-ui-code parameter +func (m *Manager) HandleMCPUIRequest(w http.ResponseWriter, r *http.Request) (string, bool) { + // Check for mcp-ui-code parameter + mcpUICode := r.URL.Query().Get(MCPUICookieName) + if mcpUICode == "" { + return "", false + } + + // Validate and extract bearer token from JWT + bearerToken, refreshToken, err := m.jwtManager.ValidateMCPUICode(mcpUICode) + if err != nil { + log.Printf("Invalid MCP UI code: %v", err) + // JWT expired or invalid, need to initiate OAuth flow + return "", false + } + + // Set cookies with the bearer token + m.cookieManager.SetMCPUICookie(w, r, bearerToken) + if refreshToken != "" { + m.cookieManager.SetMCPUIRefreshCookie(w, r, refreshToken) + } + + log.Printf("Successfully set MCP UI cookies from JWT") + return bearerToken, true +} + +// CheckCookieAuth checks if request has valid cookie authentication +func (m *Manager) CheckCookieAuth(r *http.Request) (string, bool) { + // Try to get bearer token from cookie + bearerToken, err := m.cookieManager.GetMCPUICookie(r) + if err != nil { + return "", false + } + + // Validate the bearer token + _, err = m.tokenManager.ValidateAccessToken(bearerToken) + if err != nil { + log.Printf("Bearer token from cookie is invalid: %v", err) + // Try to refresh the token + refreshedToken, refreshed := m.tryRefreshToken(r) + if refreshed { + return refreshedToken, true + } + return "", false + } + + return bearerToken, true +} + +// tryRefreshToken attempts to refresh an expired token using the refresh cookie +func (m *Manager) tryRefreshToken(r *http.Request) (string, bool) { + // Get refresh token from cookie + refreshToken, err := m.cookieManager.GetMCPUIRefreshCookie(r) + if err != nil { + log.Printf("No refresh token available: %v", err) + return "", false + } + + // Get provider for token refresh + provider, err := m.providers.GetProvider(m.providerName) + if err != nil { + log.Printf("Failed to get provider for token refresh: %v", err) + return "", false + } + + // Refresh the token + _, err = provider.RefreshToken(r.Context(), refreshToken, m.clientID, m.clientSecret) + if err != nil { + log.Printf("Failed to refresh token: %v", err) + return "", false + } + + // For now, return empty as we need database integration to update grants + // This will be implemented when integrating with the main proxy + log.Printf("Token refresh successful but grant update not implemented in standalone manager") + return "", false +} + +// GenerateMCPUICodeForDownstream creates a JWT for sending to downstream MCP server +func (m *Manager) GenerateMCPUICodeForDownstream(bearerToken, refreshToken string) (string, error) { + return m.jwtManager.GenerateMCPUICode(bearerToken, refreshToken) +} diff --git a/pkg/oauth/authorize/authorize.go b/pkg/oauth/authorize/authorize.go index 098e878..5fc66ac 100644 --- a/pkg/oauth/authorize/authorize.go +++ b/pkg/oauth/authorize/authorize.go @@ -24,15 +24,17 @@ type Handler struct { scopesSupported []string clientID string clientSecret string + routePrefix string } -func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string, clientID, clientSecret string) http.Handler { +func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string, clientID, clientSecret, routePrefix string) http.Handler { return &Handler{ db: db, provider: provider, scopesSupported: scopesSupported, clientID: clientID, clientSecret: clientSecret, + routePrefix: routePrefix, } } @@ -126,6 +128,11 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { "code_challenge_method": authReq.CodeChallengeMethod, } + // Add redirect parameter if present for post-auth redirect + if rd := params.Get("rd"); rd != "" { + authData["rd"] = rd + } + if err := p.db.StoreAuthRequest(stateKey, authData); err != nil { handlerutils.JSON(w, http.StatusInternalServerError, types.OAuthError{ Error: "server_error", @@ -134,7 +141,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r)) + redirectURI := fmt.Sprintf("%s%s/callback", handlerutils.GetBaseURL(r), p.routePrefix) // Generate authorization URL with the provider authURL := p.provider.GetAuthorizationURL( diff --git a/pkg/oauth/callback/callback.go b/pkg/oauth/callback/callback.go index 28f318f..2e5714f 100644 --- a/pkg/oauth/callback/callback.go +++ b/pkg/oauth/callback/callback.go @@ -10,6 +10,7 @@ import ( "github.com/obot-platform/mcp-oauth-proxy/pkg/encryption" "github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils" + "github.com/obot-platform/mcp-oauth-proxy/pkg/mcpui" "github.com/obot-platform/mcp-oauth-proxy/pkg/providers" "github.com/obot-platform/mcp-oauth-proxy/pkg/types" ) @@ -19,6 +20,7 @@ type Store interface { StoreAuthCode(code, grantID, userID string) error GetAuthRequest(key string) (map[string]any, error) DeleteAuthRequest(key string) error + StoreToken(token *types.TokenData) error } type Handler struct { @@ -27,15 +29,24 @@ type Handler struct { encryptionKey []byte clientID string clientSecret string + mcpUIManager MCPUIManager + routePrefix string } -func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret string) http.Handler { +// MCPUIManager interface for generating JWT tokens +type MCPUIManager interface { + GenerateMCPUICodeForDownstream(bearerToken, refreshToken string) (string, error) +} + +func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret, routePrefix string, mcpUIManager MCPUIManager) http.Handler { return &Handler{ db: db, provider: provider, encryptionKey: encryptionKey, clientID: clientID, clientSecret: clientSecret, + mcpUIManager: mcpUIManager, + routePrefix: routePrefix, } } @@ -59,6 +70,34 @@ func getStringFromMap(data map[string]any, key string) string { return "" } +// setMCPUISessionCookies sets secure HttpOnly session cookies for MCP UI authentication +func (p *Handler) setMCPUISessionCookies(w http.ResponseWriter, r *http.Request, accessToken, refreshToken string) { + // Determine if request is secure + secure := r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" + + // Set access token cookie (1 hour) + http.SetCookie(w, &http.Cookie{ + Name: mcpui.MCPUICookieName, + Value: accessToken, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteStrictMode, + MaxAge: 3600, // 1 hour + }) + + // Set refresh token cookie (30 days) + http.SetCookie(w, &http.Cookie{ + Name: mcpui.MCPUIRefreshCookieName, + Value: refreshToken, + Path: "/", + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteStrictMode, + MaxAge: 30 * 24 * 3600, // 30 days + }) +} + func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Handle OAuth callback from external providers code := r.URL.Query().Get("code") @@ -113,7 +152,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() // Get provider credentials - redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r)) + redirectURI := fmt.Sprintf("%s%s/callback", handlerutils.GetBaseURL(r), p.routePrefix) // Exchange code for tokens tokenInfo, err := p.provider.ExchangeCodeForToken(r.Context(), code, p.clientID, p.clientSecret, redirectURI) @@ -220,6 +259,49 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if rdValue := getStringFromMap(authData, "rd"); rdValue != "" { + // This is an MCP UI flow - you should be able to issue session cookies + log.Printf("🔄 Processing MCP UI OAuth callback") + + // Generate internal application tokens (separate from OAuth provider tokens) + accessTokenSecret := encryption.GenerateRandomString(32) + accessToken := fmt.Sprintf("%s:%s:%s", userInfo.ID, grantID, accessTokenSecret) + + mcpUIRefreshTokenSecret := encryption.GenerateRandomString(32) + mcpUIRefreshToken := fmt.Sprintf("%s:%s:%s", userInfo.ID, grantID, mcpUIRefreshTokenSecret) + + // Store internal tokens in database + tokenData := &types.TokenData{ + AccessToken: accessToken, + RefreshToken: mcpUIRefreshToken, + ClientID: authReq.ClientID, + UserID: userInfo.ID, + GrantID: grantID, + Scope: authReq.Scope, + ExpiresAt: time.Now().Add(1 * time.Hour), // 1 hour for access token + CreatedAt: time.Now(), + Revoked: false, + } + + if err := p.db.StoreToken(tokenData); err != nil { + handlerutils.JSON(w, http.StatusInternalServerError, types.OAuthError{ + Error: "server_error", + ErrorDescription: "Failed to store tokens", + }) + return + } + + // Set secure HttpOnly session cookies + p.setMCPUISessionCookies(w, r, accessToken, mcpUIRefreshToken) + + // Redirect to success page with original path as parameter + baseURL := handlerutils.GetBaseURL(r) + successURL := fmt.Sprintf("%s%s/auth/mcp-ui/success?rd=%s", baseURL, p.routePrefix, url.QueryEscape(rdValue)) + + http.Redirect(w, r, successURL, http.StatusFound) + return + } + // Build the redirect URL back to the client redirectURL := authReq.RedirectURI diff --git a/pkg/oauth/success/success.go b/pkg/oauth/success/success.go new file mode 100644 index 0000000..a567022 --- /dev/null +++ b/pkg/oauth/success/success.go @@ -0,0 +1,146 @@ +package success + +import ( + "fmt" + "html/template" + "net/http" + "net/url" + "strings" + + "github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils" +) + +// Handler handles the MCP UI authentication success page +type Handler struct{} + +// NewHandler creates a new success page handler +func NewHandler() http.Handler { + return &Handler{} +} + +// successPageTemplate is the HTML template for the success page +const successPageTemplate = ` + +
+ + +