diff --git a/.gitignore b/.gitignore index cd61cc9..4efce2d 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ mock-oauth2-server # env file .env + +# Server binary +server diff --git a/cmd/server/main.go b/cmd/server/main.go index 38cf077..3c69130 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -46,12 +46,12 @@ func main() { // Determine the base URL for OpenID Connect configuration baseURL := host - + // If host flag is not provided, check for IssuerURL in config (from MOCK_ISSUER_URL env var) if baseURL == "" { baseURL = cfg.IssuerURL } - + // If neither host flag nor MOCK_ISSUER_URL env var is provided, use localhost if baseURL == "" { baseURL = fmt.Sprintf("http://localhost:%d", serverPort) @@ -70,14 +70,17 @@ func main() { // Set up routes mux.Handle("/authorize", &handlers.AuthorizeHandler{Store: memoryStore}) - mux.Handle("/token", handlers.NewTokenHandler(memoryStore)) + mux.Handle("/token", handlers.NewTokenHandlerWithIssuer(memoryStore, baseURL)) mux.Handle("/userinfo", &handlers.UserInfoHandler{Store: memoryStore}) mux.Handle("/config", handlers.NewConfigHandler(memoryStore, defaultUser)) mux.Handle("/version", handlers.NewVersionHandler()) - + // Add OpenID Connect Discovery endpoint mux.Handle("/.well-known/openid-configuration", handlers.NewOpenIDConfigHandler(baseURL)) + // Add JWKS endpoint + mux.Handle("/jwks", handlers.NewJWKSHandler()) + // Start the server with the custom ServeMux startServer(serverPort, mux) } diff --git a/go.mod b/go.mod index 3689db8..270b1d7 100644 --- a/go.mod +++ b/go.mod @@ -8,3 +8,5 @@ require ( github.com/google/uuid v1.6.0 golang.org/x/oauth2 v0.28.0 ) + +require github.com/golang-jwt/jwt/v5 v5.3.0 // indirect diff --git a/go.sum b/go.sum index c500842..a995bd8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 39f54fe..24edf1b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -84,10 +84,10 @@ func TestGetConfig(t *testing.T) { if retrievedConfig.Port != 8082 || retrievedConfig.MockUserEmail != "getconfig@example.com" || retrievedConfig.MockUserName != "GetConfig User" || - retrievedConfig.MockTokenExpiry != 1800 || + retrievedConfig.MockTokenExpiry != 1800 || retrievedConfig.IssuerURL != "http://getconfig-mock-oauth2:8082" { - t.Errorf("expected retrievedConfig to match updated config, got Port: %d, MockUserEmail: %s, MockUserName: %s, MockTokenExpiry: %d, IssuerURL: %s", + t.Errorf("expected retrievedConfig to match updated config, got Port: %d, MockUserEmail: %s, MockUserName: %s, MockTokenExpiry: %d, IssuerURL: %s", retrievedConfig.Port, retrievedConfig.MockUserEmail, retrievedConfig.MockUserName, retrievedConfig.MockTokenExpiry, retrievedConfig.IssuerURL) } } diff --git a/internal/handlers/jwks.go b/internal/handlers/jwks.go new file mode 100644 index 0000000..4158df2 --- /dev/null +++ b/internal/handlers/jwks.go @@ -0,0 +1,31 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/chrisw-dev/golang-mock-oauth2-server/internal/jwt" +) + +// JWKSHandler handles requests for JSON Web Key Set +type JWKSHandler struct{} + +// NewJWKSHandler creates a new JWKS handler +func NewJWKSHandler() *JWKSHandler { + return &JWKSHandler{} +} + +// ServeHTTP handles HTTP requests for JWKS +func (h *JWKSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + jwks, err := jwt.GetJWKS() + if err != nil { + http.Error(w, "Error generating JWKS", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(jwks); err != nil { + http.Error(w, "Error encoding JWKS", http.StatusInternalServerError) + return + } +} diff --git a/internal/handlers/jwks_test.go b/internal/handlers/jwks_test.go new file mode 100644 index 0000000..3fa8edc --- /dev/null +++ b/internal/handlers/jwks_test.go @@ -0,0 +1,72 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestJWKSHandler(t *testing.T) { + handler := NewJWKSHandler() + + req := httptest.NewRequest("GET", "/jwks", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + + // Check content type + contentType := rr.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Expected content type application/json, got %s", contentType) + } + + // Parse the response + var jwks map[string]interface{} + err := json.NewDecoder(rr.Body).Decode(&jwks) + if err != nil { + t.Fatalf("Failed to decode JWKS response: %v", err) + } + + // Verify JWKS structure + keys, ok := jwks["keys"].([]interface{}) + if !ok { + t.Fatal("JWKS should have a 'keys' array") + } + + if len(keys) == 0 { + t.Error("JWKS keys array should not be empty") + } + + // Check the first key + key, ok := keys[0].(map[string]interface{}) + if !ok { + t.Fatal("Key should be a map") + } + + requiredFields := []string{"kty", "use", "kid", "alg", "n", "e"} + for _, field := range requiredFields { + if _, exists := key[field]; !exists { + t.Errorf("Key should have field %s", field) + } + } + + // Verify key type is RSA + if key["kty"] != "RSA" { + t.Errorf("Expected kty to be RSA, got %v", key["kty"]) + } + + // Verify algorithm is RS256 + if key["alg"] != "RS256" { + t.Errorf("Expected alg to be RS256, got %v", key["alg"]) + } + + // Verify use is sig + if key["use"] != "sig" { + t.Errorf("Expected use to be sig, got %v", key["use"]) + } +} diff --git a/internal/handlers/token.go b/internal/handlers/token.go index 3d88e4e..22bb1a7 100644 --- a/internal/handlers/token.go +++ b/internal/handlers/token.go @@ -4,21 +4,33 @@ import ( "encoding/json" "log" "net/http" + "strings" "time" + "github.com/chrisw-dev/golang-mock-oauth2-server/internal/jwt" "github.com/chrisw-dev/golang-mock-oauth2-server/internal/models" "github.com/chrisw-dev/golang-mock-oauth2-server/internal/store" ) // TokenHandler handles OAuth2 token exchange requests type TokenHandler struct { - store store.Store + store store.Store + issuerURL string } // NewTokenHandler creates a new TokenHandler with the given store func NewTokenHandler(store store.Store) *TokenHandler { return &TokenHandler{ - store: store, + store: store, + issuerURL: "http://localhost:8080", // default issuer + } +} + +// NewTokenHandlerWithIssuer creates a new TokenHandler with the given store and issuer URL +func NewTokenHandlerWithIssuer(store store.Store, issuerURL string) *TokenHandler { + return &TokenHandler{ + store: store, + issuerURL: issuerURL, } } @@ -69,12 +81,26 @@ func (h *TokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Generate token response + accessToken, err := generateAccessToken(h.issuerURL, clientID, authRequest.Scope) + if err != nil { + log.Printf("Error generating access token: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + idToken, err := generateIDToken(h.issuerURL, clientID) + if err != nil { + log.Printf("Error generating ID token: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + tokenResponse := models.TokenResponse{ - AccessToken: generateAccessToken(clientID), + AccessToken: accessToken, TokenType: "Bearer", ExpiresIn: 3600, RefreshToken: generateRefreshToken(clientID), - IDToken: generateIDToken(clientID), + IDToken: idToken, } // Store the token in the store for future validation @@ -94,8 +120,17 @@ func (h *TokenHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Helper function to generate a mock access token -func generateAccessToken(clientID string) string { - return "mock-access-token-" + clientID + "-" + time.Now().Format("20060102150405") +func generateAccessToken(issuerURL, clientID, scope string) (string, error) { + // Parse scopes from the scope string + scopes := strings.Fields(scope) + if len(scopes) == 0 { + scopes = []string{"openid"} + } + + // Generate a subject ID based on client ID + sub := "user-" + clientID + + return jwt.GenerateAccessToken(issuerURL, clientID, sub, scopes) } // Helper function to generate a mock refresh token @@ -104,6 +139,9 @@ func generateRefreshToken(clientID string) string { } // Helper function to generate a mock ID token -func generateIDToken(clientID string) string { - return "mock-id-token-" + clientID + "-" + time.Now().Format("20060102150405") +func generateIDToken(issuerURL, clientID string) (string, error) { + // Generate a subject ID based on client ID + sub := "user-" + clientID + + return jwt.GenerateIDToken(issuerURL, clientID, sub) } diff --git a/internal/handlers/token_test.go b/internal/handlers/token_test.go index 6939cf1..12d874c 100644 --- a/internal/handlers/token_test.go +++ b/internal/handlers/token_test.go @@ -8,8 +8,10 @@ import ( "strings" "testing" + "github.com/chrisw-dev/golang-mock-oauth2-server/internal/jwt" "github.com/chrisw-dev/golang-mock-oauth2-server/internal/models" "github.com/chrisw-dev/golang-mock-oauth2-server/internal/store" + jwtlib "github.com/golang-jwt/jwt/v5" ) func TestTokenHandler(t *testing.T) { @@ -61,4 +63,45 @@ func TestTokenHandler(t *testing.T) { if response.TokenType != "Bearer" { t.Errorf("Expected token_type to be 'Bearer', got '%s'", response.TokenType) } + + // Verify that ID token is a valid JWT + if response.IDToken == "" { + t.Errorf("Expected ID token to be present") + } + + // Parse ID token to verify it's a valid JWT + parser := jwtlib.NewParser() + idToken, _, err := parser.ParseUnverified(response.IDToken, jwtlib.MapClaims{}) + if err != nil { + t.Errorf("Failed to parse ID token as JWT: %v", err) + } + + // Check that the token has standard JWT headers + if _, ok := idToken.Header["alg"]; !ok { + t.Error("ID token should have 'alg' header") + } + + if _, ok := idToken.Header["kid"]; !ok { + t.Error("ID token should have 'kid' header") + } + + // Verify the ID token using our JWT package + claims, err := jwt.VerifyToken(response.IDToken) + if err != nil { + t.Errorf("Failed to verify ID token: %v", err) + } + + if claims == nil { + t.Error("ID token claims should not be nil") + } + + // Verify that access token is a valid JWT + accessToken, _, err := parser.ParseUnverified(response.AccessToken, jwtlib.MapClaims{}) + if err != nil { + t.Errorf("Failed to parse access token as JWT: %v", err) + } + + if _, ok := accessToken.Header["alg"]; !ok { + t.Error("Access token should have 'alg' header") + } } diff --git a/internal/jwt/jwt.go b/internal/jwt/jwt.go new file mode 100644 index 0000000..3b9153b --- /dev/null +++ b/internal/jwt/jwt.go @@ -0,0 +1,185 @@ +package jwt + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + keyID string + once sync.Once +) + +// InitKeys initializes the RSA key pair for JWT signing +func InitKeys() error { + var err error + once.Do(func() { + // Generate RSA key pair + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + publicKey = &privateKey.PublicKey + keyID = "mock-key-1" + }) + return err +} + +// GenerateIDToken creates a signed JWT ID token +func GenerateIDToken(issuer, clientID, sub string) (string, error) { + if privateKey == nil { + if err := InitKeys(); err != nil { + return "", err + } + } + + now := time.Now() + claims := jwt.MapClaims{ + "iss": issuer, + "sub": sub, + "aud": clientID, + "exp": now.Add(time.Hour).Unix(), + "iat": now.Unix(), + "nonce": generateNonce(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = keyID + + return token.SignedString(privateKey) +} + +// GenerateAccessToken creates a signed JWT access token +func GenerateAccessToken(issuer, clientID, sub string, scopes []string) (string, error) { + if privateKey == nil { + if err := InitKeys(); err != nil { + return "", err + } + } + + now := time.Now() + claims := jwt.MapClaims{ + "iss": issuer, + "sub": sub, + "aud": clientID, + "exp": now.Add(time.Hour).Unix(), + "iat": now.Unix(), + "scope": scopes, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = keyID + + return token.SignedString(privateKey) +} + +// GetJWKS returns the JSON Web Key Set +func GetJWKS() (map[string]interface{}, error) { + if publicKey == nil { + if err := InitKeys(); err != nil { + return nil, err + } + } + + // Encode the public key components + nBytes := publicKey.N.Bytes() + eBytes := big.NewInt(int64(publicKey.E)).Bytes() + + n := base64.RawURLEncoding.EncodeToString(nBytes) + e := base64.RawURLEncoding.EncodeToString(eBytes) + + jwk := map[string]interface{}{ + "kty": "RSA", + "use": "sig", + "kid": keyID, + "alg": "RS256", + "n": n, + "e": e, + } + + jwks := map[string]interface{}{ + "keys": []interface{}{jwk}, + } + + return jwks, nil +} + +// VerifyToken verifies a JWT token and returns the claims +func VerifyToken(tokenString string) (jwt.MapClaims, error) { + if publicKey == nil { + if err := InitKeys(); err != nil { + return nil, err + } + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Validate signing method + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return publicKey, nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + return claims, nil + } + + return nil, fmt.Errorf("invalid token") +} + +// generateNonce generates a random nonce for the token +func generateNonce() string { + b := make([]byte, 16) + rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} + +// GetPublicKey returns the public key (for testing purposes) +func GetPublicKey() (*rsa.PublicKey, error) { + if publicKey == nil { + if err := InitKeys(); err != nil { + return nil, err + } + } + return publicKey, nil +} + +// GetPublicKeyPEM returns the public key in PEM format +func GetPublicKeyPEM() (string, error) { + if publicKey == nil { + if err := InitKeys(); err != nil { + return "", err + } + } + + pubKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(pubKeyBytes), nil +} + +// MarshalJWKS returns the JWKS as a JSON byte array +func MarshalJWKS() ([]byte, error) { + jwks, err := GetJWKS() + if err != nil { + return nil, err + } + return json.Marshal(jwks) +} diff --git a/internal/jwt/jwt_test.go b/internal/jwt/jwt_test.go new file mode 100644 index 0000000..24a7897 --- /dev/null +++ b/internal/jwt/jwt_test.go @@ -0,0 +1,211 @@ +package jwt + +import ( + "testing" + + "github.com/golang-jwt/jwt/v5" +) + +func TestInitKeys(t *testing.T) { + err := InitKeys() + if err != nil { + t.Fatalf("Failed to initialize keys: %v", err) + } + + if privateKey == nil { + t.Error("Private key should not be nil after initialization") + } + + if publicKey == nil { + t.Error("Public key should not be nil after initialization") + } + + if keyID == "" { + t.Error("Key ID should not be empty after initialization") + } +} + +func TestGenerateIDToken(t *testing.T) { + err := InitKeys() + if err != nil { + t.Fatalf("Failed to initialize keys: %v", err) + } + + issuer := "http://localhost:8080" + clientID := "test-client" + sub := "user-123" + + tokenString, err := GenerateIDToken(issuer, clientID, sub) + if err != nil { + t.Fatalf("Failed to generate ID token: %v", err) + } + + if tokenString == "" { + t.Error("Token string should not be empty") + } + + // Verify the token can be parsed + claims, err := VerifyToken(tokenString) + if err != nil { + t.Fatalf("Failed to verify token: %v", err) + } + + // Check claims + if claims["iss"] != issuer { + t.Errorf("Expected issuer %s, got %v", issuer, claims["iss"]) + } + + if claims["sub"] != sub { + t.Errorf("Expected subject %s, got %v", sub, claims["sub"]) + } + + if claims["aud"] != clientID { + t.Errorf("Expected audience %s, got %v", clientID, claims["aud"]) + } +} + +func TestGenerateAccessToken(t *testing.T) { + err := InitKeys() + if err != nil { + t.Fatalf("Failed to initialize keys: %v", err) + } + + issuer := "http://localhost:8080" + clientID := "test-client" + sub := "user-123" + scopes := []string{"openid", "email", "profile"} + + tokenString, err := GenerateAccessToken(issuer, clientID, sub, scopes) + if err != nil { + t.Fatalf("Failed to generate access token: %v", err) + } + + if tokenString == "" { + t.Error("Token string should not be empty") + } + + // Verify the token can be parsed + claims, err := VerifyToken(tokenString) + if err != nil { + t.Fatalf("Failed to verify token: %v", err) + } + + // Check claims + if claims["iss"] != issuer { + t.Errorf("Expected issuer %s, got %v", issuer, claims["iss"]) + } + + if claims["sub"] != sub { + t.Errorf("Expected subject %s, got %v", sub, claims["sub"]) + } +} + +func TestGetJWKS(t *testing.T) { + err := InitKeys() + if err != nil { + t.Fatalf("Failed to initialize keys: %v", err) + } + + jwks, err := GetJWKS() + if err != nil { + t.Fatalf("Failed to get JWKS: %v", err) + } + + if jwks == nil { + t.Error("JWKS should not be nil") + } + + keys, ok := jwks["keys"].([]interface{}) + if !ok { + t.Fatal("JWKS should have a 'keys' array") + } + + if len(keys) == 0 { + t.Error("JWKS keys array should not be empty") + } + + // Check the first key + key, ok := keys[0].(map[string]interface{}) + if !ok { + t.Fatal("Key should be a map") + } + + if key["kty"] != "RSA" { + t.Errorf("Expected kty to be RSA, got %v", key["kty"]) + } + + if key["use"] != "sig" { + t.Errorf("Expected use to be sig, got %v", key["use"]) + } + + if key["alg"] != "RS256" { + t.Errorf("Expected alg to be RS256, got %v", key["alg"]) + } +} + +func TestVerifyToken(t *testing.T) { + err := InitKeys() + if err != nil { + t.Fatalf("Failed to initialize keys: %v", err) + } + + // Generate a valid token + issuer := "http://localhost:8080" + clientID := "test-client" + sub := "user-123" + + tokenString, err := GenerateIDToken(issuer, clientID, sub) + if err != nil { + t.Fatalf("Failed to generate ID token: %v", err) + } + + // Verify the token + claims, err := VerifyToken(tokenString) + if err != nil { + t.Fatalf("Failed to verify token: %v", err) + } + + if claims == nil { + t.Error("Claims should not be nil") + } + + // Test with invalid token + invalidToken := "invalid.token.string" + _, err = VerifyToken(invalidToken) + if err == nil { + t.Error("Expected error for invalid token") + } +} + +func TestTokenFormat(t *testing.T) { + err := InitKeys() + if err != nil { + t.Fatalf("Failed to initialize keys: %v", err) + } + + issuer := "http://localhost:8080" + clientID := "test-client" + sub := "user-123" + + tokenString, err := GenerateIDToken(issuer, clientID, sub) + if err != nil { + t.Fatalf("Failed to generate ID token: %v", err) + } + + // Parse token to check it has 3 parts (header.payload.signature) + parser := jwt.NewParser() + token, _, err := parser.ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + t.Fatalf("Failed to parse token: %v", err) + } + + // Check that the token has the kid header + if kid, ok := token.Header["kid"].(string); !ok || kid == "" { + t.Error("Token should have a kid header") + } + + // Check that the algorithm is RS256 + if alg, ok := token.Header["alg"].(string); !ok || alg != "RS256" { + t.Errorf("Expected algorithm RS256, got %v", token.Header["alg"]) + } +} diff --git a/pkg/oauth/google.go b/pkg/oauth/google.go index 6d2cb74..3a44988 100644 --- a/pkg/oauth/google.go +++ b/pkg/oauth/google.go @@ -4,17 +4,22 @@ import ( "net/url" "time" + "github.com/chrisw-dev/golang-mock-oauth2-server/internal/jwt" "github.com/chrisw-dev/golang-mock-oauth2-server/internal/store" ) // GoogleProvider implements the Provider interface for Google OAuth2 type GoogleProvider struct { - Store *store.MemoryStore + Store *store.MemoryStore + IssuerURL string } // NewGoogleProvider creates a new Google OAuth2 provider instance func NewGoogleProvider(store *store.MemoryStore) *GoogleProvider { - return &GoogleProvider{Store: store} + return &GoogleProvider{ + Store: store, + IssuerURL: "http://localhost:8080", + } } // GenerateAuthURL creates an authorization URL for the OAuth2 flow @@ -43,12 +48,26 @@ func (p *GoogleProvider) ExchangeCodeForToken(code string) (map[string]interface return nil, &Error{Code: "invalid_grant", Description: "Authorization code expired"} } + // Generate proper JWT tokens + sub := "user-" + authRequest.ClientID + scopes := []string{"openid", "email", "profile"} + + accessToken, err := jwt.GenerateAccessToken(p.IssuerURL, authRequest.ClientID, sub, scopes) + if err != nil { + return nil, &Error{Code: "server_error", Description: "Failed to generate access token"} + } + + idToken, err := jwt.GenerateIDToken(p.IssuerURL, authRequest.ClientID, sub) + if err != nil { + return nil, &Error{Code: "server_error", Description: "Failed to generate ID token"} + } + token := map[string]interface{}{ - "access_token": "mock-access-token", + "access_token": accessToken, "token_type": "Bearer", "expires_in": 3600, "refresh_token": "mock-refresh-token", - "id_token": "mock-id-token", + "id_token": idToken, } return token, nil diff --git a/pkg/oauth/google_test.go b/pkg/oauth/google_test.go index 13e8d69..8bb6e5f 100644 --- a/pkg/oauth/google_test.go +++ b/pkg/oauth/google_test.go @@ -1,12 +1,13 @@ package oauth import ( - "reflect" + "strings" "testing" "time" "github.com/chrisw-dev/golang-mock-oauth2-server/internal/models" "github.com/chrisw-dev/golang-mock-oauth2-server/internal/store" + jwtlib "github.com/golang-jwt/jwt/v5" ) func TestGoogleProvider_GenerateAuthURL(t *testing.T) { @@ -33,34 +34,24 @@ func TestGoogleProvider_ExchangeCodeForToken(t *testing.T) { store.StoreAuthCode(code, authRequest) tests := []struct { - name string - code string - expectedError string - expectedResult map[string]interface{} + name string + code string + expectedError string }{ { name: "Valid code", code: "valid-code", expectedError: "", - expectedResult: map[string]interface{}{ - "access_token": "mock-access-token", - "token_type": "Bearer", - "expires_in": 3600, - "refresh_token": "mock-refresh-token", - "id_token": "mock-id-token", - }, }, { - name: "Invalid code", - code: "invalid-code", - expectedError: "invalid_grant: Invalid authorization code", - expectedResult: nil, + name: "Invalid code", + code: "invalid-code", + expectedError: "invalid_grant: Invalid authorization code", }, { - name: "Expired code", - code: "expired-code", - expectedError: "invalid_grant: Authorization code expired", - expectedResult: nil, + name: "Expired code", + code: "expired-code", + expectedError: "invalid_grant: Authorization code expired", }, } @@ -80,8 +71,60 @@ func TestGoogleProvider_ExchangeCodeForToken(t *testing.T) { if err == nil || err.Error() != tt.expectedError { t.Errorf("expected error %s, got %v", tt.expectedError, err) } - } else if !reflect.DeepEqual(result, tt.expectedResult) { - t.Errorf("expected result %+v, got %+v", tt.expectedResult, result) + } else { + // Check that the result contains the expected fields + if result == nil { + t.Error("expected result to be non-nil") + return + } + + // Check token_type + if result["token_type"] != "Bearer" { + t.Errorf("expected token_type to be Bearer, got %v", result["token_type"]) + } + + // Check expires_in + if result["expires_in"] != 3600 { + t.Errorf("expected expires_in to be 3600, got %v", result["expires_in"]) + } + + // Check that access_token is a valid JWT + accessToken, ok := result["access_token"].(string) + if !ok || accessToken == "" { + t.Error("access_token should be a non-empty string") + } else { + // Verify it's a JWT (has 3 parts separated by dots) + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + t.Errorf("access_token should be a JWT with 3 parts, got %d parts", len(parts)) + } + + // Parse to verify it's a valid JWT + parser := jwtlib.NewParser() + _, _, err := parser.ParseUnverified(accessToken, jwtlib.MapClaims{}) + if err != nil { + t.Errorf("access_token should be a valid JWT: %v", err) + } + } + + // Check that id_token is a valid JWT + idToken, ok := result["id_token"].(string) + if !ok || idToken == "" { + t.Error("id_token should be a non-empty string") + } else { + // Verify it's a JWT (has 3 parts separated by dots) + parts := strings.Split(idToken, ".") + if len(parts) != 3 { + t.Errorf("id_token should be a JWT with 3 parts, got %d parts", len(parts)) + } + + // Parse to verify it's a valid JWT + parser := jwtlib.NewParser() + _, _, err := parser.ParseUnverified(idToken, jwtlib.MapClaims{}) + if err != nil { + t.Errorf("id_token should be a valid JWT: %v", err) + } + } } }) } @@ -97,32 +140,19 @@ func TestGoogleProvider_GetUserInfo(t *testing.T) { store.StoreAuthCode("test-client", &models.AuthRequest{ClientID: "test-client"}) tests := []struct { - name string - token string - expectedError string - expectedResult map[string]interface{} + name string + token string + expectedError string }{ { name: "Valid token", token: "valid-token", expectedError: "", - expectedResult: map[string]interface{}{ - "sub": "test-client", - "name": "Test User", - "given_name": "Test", - "family_name": "User", - "email": "test-client@example.com", - "email_verified": true, - "picture": "https://example.com/photo.jpg", - "locale": "", - "hd": "", - }, }, { - name: "Invalid token", - token: "invalid-token", - expectedError: "invalid_token: Invalid access token", - expectedResult: nil, + name: "Invalid token", + token: "invalid-token", + expectedError: "invalid_token: Invalid access token", }, } @@ -134,8 +164,20 @@ func TestGoogleProvider_GetUserInfo(t *testing.T) { if err == nil || err.Error() != tt.expectedError { t.Errorf("expected error %s, got %v", tt.expectedError, err) } - } else if !reflect.DeepEqual(result, tt.expectedResult) { - t.Errorf("expected result %+v, got %+v", tt.expectedResult, result) + } else { + // Check that the result contains the expected user info fields + if result == nil { + t.Error("expected result to be non-nil") + return + } + + // Check that required fields are present + if result["sub"] == "" { + t.Error("expected sub to be non-empty") + } + if result["email"] == "" { + t.Error("expected email to be non-empty") + } } }) }