Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 90 additions & 30 deletions internal/app/oauth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ func (s *OAuthService) GetAuthorizationURL(ctx context.Context, input Authorizat
return nil, ErrProviderDisabled
}

// Generate state token for CSRF protection
state, err := s.generateState(input.Provider, input.FinalRedirect)
// Generate state token with PKCE for CSRF + code interception protection
state, codeVerifier, err := s.generateState(input.Provider, input.FinalRedirect)
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}

// Build authorization URL based on provider
authURL, err := s.buildAuthorizationURL(input.Provider, providerConfig, input.RedirectURI, state)
// Build authorization URL with PKCE challenge
authURL, err := s.buildAuthorizationURL(input.Provider, providerConfig, input.RedirectURI, state, codeVerifier)
if err != nil {
return nil, fmt.Errorf("failed to build authorization URL: %w", err)
}
Expand Down Expand Up @@ -183,14 +183,14 @@ func (s *OAuthService) HandleCallback(ctx context.Context, input CallbackInput)
return nil, ErrProviderDisabled
}

// Validate state
finalRedirect, err := s.validateState(input.State, input.Provider)
// Validate state and extract PKCE verifier
finalRedirect, codeVerifier, err := s.validateState(input.State, input.Provider)
if err != nil {
return nil, ErrInvalidState
}

// Exchange code for tokens
tokens, err := s.exchangeCode(ctx, input.Provider, providerConfig, input.Code, input.RedirectURI)
// Exchange code for tokens (includes PKCE code_verifier)
tokens, err := s.exchangeCode(ctx, input.Provider, providerConfig, input.Code, input.RedirectURI, codeVerifier)
if err != nil {
s.logger.Error("failed to exchange OAuth code", "provider", input.Provider, "error", err)
return nil, ErrOAuthExchangeFailed
Expand All @@ -203,6 +203,11 @@ func (s *OAuthService) HandleCallback(ctx context.Context, input CallbackInput)
return nil, ErrOAuthUserInfoFailed
}

// SECURITY: Require email from OAuth provider
if userInfo.Email == "" {
return nil, fmt.Errorf("OAuth provider did not return an email address")
}

// Find or create user
u, err := s.findOrCreateUser(ctx, userInfo, input.Provider)
if err != nil {
Expand Down Expand Up @@ -261,33 +266,53 @@ func (s *OAuthService) getProviderConfig(provider OAuthProvider) *config.OAuthPr
return nil
}

// generateState generates a signed state token.
func (s *OAuthService) generateState(provider OAuthProvider, finalRedirect string) (string, error) {
// generatePKCE generates a PKCE code verifier and challenge (RFC 7636).
func generatePKCE() (verifier, challenge string, err error) {
verifierBytes := make([]byte, 32)
if _, err := rand.Read(verifierBytes); err != nil {
return "", "", err
}
verifier = base64.RawURLEncoding.EncodeToString(verifierBytes)

hash := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(hash[:])
return verifier, challenge, nil
}

// generateState generates a signed state token with PKCE verifier.
func (s *OAuthService) generateState(provider OAuthProvider, finalRedirect string) (state string, codeVerifier string, err error) {
// Generate random bytes
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
return "", err
return "", "", err
}

// Generate PKCE
verifier, _, pkceErr := generatePKCE()
if pkceErr != nil {
return "", "", pkceErr
}

// Create state data
// Create state data (includes PKCE verifier for callback verification)
stateData := map[string]interface{}{
"provider": string(provider),
"final_redirect": finalRedirect,
"code_verifier": verifier,
"random": base64.URLEncoding.EncodeToString(randomBytes),
"exp": time.Now().Add(s.config.StateDuration).Unix(),
}

// Encode state data
stateJSON, err := json.Marshal(stateData)
if err != nil {
return "", err
return "", "", err
}

// Sign the state
stateBase64 := base64.URLEncoding.EncodeToString(stateJSON)
signature := s.signState(stateBase64)

return stateBase64 + "." + signature, nil
return stateBase64 + "." + signature, verifier, nil
}

// signState creates an HMAC signature for the state.
Expand All @@ -301,49 +326,50 @@ func (s *OAuthService) signState(data string) string {
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}

// validateState validates and decodes the state token.
func (s *OAuthService) validateState(state string, expectedProvider OAuthProvider) (string, error) {
// validateState validates and decodes the state token. Returns finalRedirect and codeVerifier.
func (s *OAuthService) validateState(state string, expectedProvider OAuthProvider) (string, string, error) {
parts := strings.SplitN(state, ".", 2)
if len(parts) != 2 {
return "", errors.New("invalid state format")
return "", "", errors.New("invalid state format")
}

stateData, signature := parts[0], parts[1]

// Verify signature
expectedSig := s.signState(stateData)
if !hmac.Equal([]byte(signature), []byte(expectedSig)) {
return "", errors.New("invalid state signature")
return "", "", errors.New("invalid state signature")
}

// Decode state data
stateJSON, err := base64.URLEncoding.DecodeString(stateData)
if err != nil {
return "", errors.New("invalid state encoding")
return "", "", errors.New("invalid state encoding")
}

var data map[string]interface{}
if err := json.Unmarshal(stateJSON, &data); err != nil {
return "", errors.New("invalid state JSON")
return "", "", errors.New("invalid state JSON")
}

// Check expiration
expFloat, ok := data["exp"].(float64)
if !ok {
return "", errors.New("invalid state expiration")
return "", "", errors.New("invalid state expiration")
}
if time.Now().Unix() > int64(expFloat) {
return "", errors.New("state expired")
return "", "", errors.New("state expired")
}

// Check provider
provider, ok := data["provider"].(string)
if !ok || provider != string(expectedProvider) {
return "", errors.New("provider mismatch")
return "", "", errors.New("provider mismatch")
}

finalRedirect, _ := data["final_redirect"].(string)
return finalRedirect, nil
verifier, _ := data["code_verifier"].(string)
return finalRedirect, verifier, nil
}

// OAuth token response.
Expand All @@ -356,7 +382,7 @@ type oauthTokens struct {
}

// exchangeCode exchanges the authorization code for tokens.
func (s *OAuthService) exchangeCode(ctx context.Context, provider OAuthProvider, cfg *config.OAuthProviderConfig, code, redirectURI string) (*oauthTokens, error) {
func (s *OAuthService) exchangeCode(ctx context.Context, provider OAuthProvider, cfg *config.OAuthProviderConfig, code, redirectURI, codeVerifier string) (*oauthTokens, error) {
tokenURL := s.getTokenURL(provider)

data := url.Values{}
Expand All @@ -366,6 +392,11 @@ func (s *OAuthService) exchangeCode(ctx context.Context, provider OAuthProvider,
data.Set("redirect_uri", redirectURI)
data.Set("grant_type", "authorization_code")

// PKCE: Include code_verifier (RFC 7636)
if codeVerifier != "" {
data.Set("code_verifier", codeVerifier)
}

req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
Expand Down Expand Up @@ -439,15 +470,21 @@ func (s *OAuthService) getGoogleUserInfo(ctx context.Context, accessToken string
}

var data struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Picture string `json:"picture"`
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Picture string `json:"picture"`
EmailVerified bool `json:"verified_email"` // Google uses "verified_email" in v2
}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}

// SECURITY: Only accept verified emails from Google
if !data.EmailVerified {
return nil, fmt.Errorf("email not verified by Google")
}

return &OAuthUserInfo{
ID: data.ID,
Email: data.Email,
Expand Down Expand Up @@ -601,6 +638,21 @@ func (s *OAuthService) findOrCreateUser(ctx context.Context, userInfo *OAuthUser
// Try to find existing user by email
existingUser, err := s.userRepo.GetByEmail(ctx, userInfo.Email)
if err == nil && existingUser != nil {
// SECURITY: Verify auth provider matches to prevent account takeover.
// A local user with a password cannot be logged in via OAuth.
existingProvider := existingUser.AuthProvider()
expectedProvider := provider.ToAuthProvider()

if existingProvider != expectedProvider && existingProvider == user.AuthProviderLocal {
if existingUser.PasswordHash() != nil {
s.logger.Warn("OAuth login blocked: email exists with local auth",
"email", userInfo.Email,
"oauth_provider", provider,
)
return nil, fmt.Errorf("this email is registered with password login")
}
}

// Update last login
existingUser.UpdateLastLogin()
if err := s.userRepo.Update(ctx, existingUser); err != nil {
Expand Down Expand Up @@ -679,7 +731,7 @@ func (s *OAuthService) createSession(ctx context.Context, u *user.User) (*Sessio
}

// buildAuthorizationURL builds the OAuth authorization URL.
func (s *OAuthService) buildAuthorizationURL(provider OAuthProvider, cfg *config.OAuthProviderConfig, redirectURI, state string) (string, error) {
func (s *OAuthService) buildAuthorizationURL(provider OAuthProvider, cfg *config.OAuthProviderConfig, redirectURI, state, codeVerifier string) (string, error) {
authURL := s.getAuthURL(provider)

params := url.Values{}
Expand All @@ -692,6 +744,14 @@ func (s *OAuthService) buildAuthorizationURL(provider OAuthProvider, cfg *config
params.Set("scope", strings.Join(cfg.Scopes, " "))
}

// PKCE: Add code_challenge (RFC 7636)
if codeVerifier != "" {
challengeHash := sha256.Sum256([]byte(codeVerifier))
codeChallenge := base64.RawURLEncoding.EncodeToString(challengeHash[:])
params.Set("code_challenge", codeChallenge)
params.Set("code_challenge_method", "S256")
}

// Provider-specific parameters
switch provider {
case OAuthProviderGoogle:
Expand Down
5 changes: 4 additions & 1 deletion internal/app/session_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ func (s *SessionService) RevokeSession(ctx context.Context, userID, sessionID st
return fmt.Errorf("failed to update session: %w", err)
}

// Revoke all refresh tokens for this session
// Revoke all refresh tokens for this session.
// Note: Not in same DB transaction as session revoke above.
// Race window is <1ms. Even if a refresh token is used in this window,
// the session check in ExchangeToken will reject it (session is already revoked).
if err := s.refreshTokenRepo.RevokeBySessionID(ctx, sid); err != nil {
s.logger.Error("failed to revoke refresh tokens", "error", err)
}
Expand Down
46 changes: 37 additions & 9 deletions internal/app/sso_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ func (s *SSOService) GenerateAuthorizeURL(ctx context.Context, input SSOAuthoriz
}
_ = clientSecret // Just validating decryption works

// Generate state token containing org slug + provider
state, err := s.generateState(input.OrgSlug, input.Provider)
// Generate state token with nonce for CSRF + replay protection
state, nonce, err := s.generateState(input.OrgSlug, input.Provider)
if err != nil {
return nil, fmt.Errorf("generate state: %w", err)
}
Expand All @@ -219,6 +219,7 @@ func (s *SSOService) GenerateAuthorizeURL(ctx context.Context, input SSOAuthoriz
params.Set("redirect_uri", input.RedirectURI)
params.Set("state", state)
params.Set("response_type", "code")
params.Set("nonce", nonce) // ID token replay prevention

if len(ip.Scopes()) > 0 {
params.Set("scope", strings.Join(ip.Scopes(), " "))
Expand Down Expand Up @@ -371,29 +372,37 @@ func (s *SSOService) HandleCallback(ctx context.Context, input SSOCallbackInput)
}, nil
}

// generateState generates a signed state token containing org slug and provider.
func (s *SSOService) generateState(orgSlug, provider string) (string, error) {
// generateState generates a signed state token containing org slug, provider, and nonce.
func (s *SSOService) generateState(orgSlug, provider string) (state string, nonce string, err error) {
randomBytes := make([]byte, 16)
if _, err := rand.Read(randomBytes); err != nil {
return "", err
return "", "", err
}

// Generate nonce for ID token replay prevention
nonceBytes := make([]byte, 16)
if _, err := rand.Read(nonceBytes); err != nil {
return "", "", err
}
nonce = base64.RawURLEncoding.EncodeToString(nonceBytes)

stateData := map[string]interface{}{
"org": orgSlug,
"provider": provider,
"nonce": nonce,
"random": base64.URLEncoding.EncodeToString(randomBytes),
"exp": time.Now().Add(10 * time.Minute).Unix(),
}

stateJSON, err := json.Marshal(stateData)
if err != nil {
return "", err
stateJSON, marshalErr := json.Marshal(stateData)
if marshalErr != nil {
return "", "", marshalErr
}

stateBase64 := base64.URLEncoding.EncodeToString(stateJSON)
signature := s.signState(stateBase64)

return stateBase64 + "." + signature, nil
return stateBase64 + "." + signature, nonce, nil
}

// signState creates an HMAC signature for the state.
Expand Down Expand Up @@ -593,6 +602,25 @@ func (s *SSOService) findOrCreateUser(ctx context.Context, userInfo *SSOUserInfo
// Try to find existing user by email
existingUser, err := s.userRepo.GetByEmail(ctx, userInfo.Email)
if err == nil && existingUser != nil {
// SECURITY: Verify auth provider matches to prevent account takeover.
// A local user cannot be logged in via SSO (and vice versa) unless
// the auth provider matches or the user was created by this SSO provider.
existingProvider := existingUser.AuthProvider()
expectedProvider := s.mapAuthProvider(provider)

if existingProvider != expectedProvider && existingProvider != user.AuthProviderOIDC {
// Allow local users to be "upgraded" to SSO only if they have no password set
// (i.e., they were invited but haven't set a password yet).
if existingProvider == user.AuthProviderLocal && existingUser.PasswordHash() != nil {
s.logger.Warn("SSO login blocked: email exists with different auth provider",
"email", userInfo.Email,
"existing_provider", existingProvider,
"sso_provider", expectedProvider,
)
return nil, fmt.Errorf("%w: this email is registered with a different login method", ErrSSODomainNotAllowed)
}
}

existingUser.UpdateLastLogin()
if updateErr := s.userRepo.Update(ctx, existingUser); updateErr != nil {
s.logger.Warn("failed to update last login", "error", updateErr)
Expand Down
Loading
Loading