diff --git a/server/handlers.go b/server/handlers.go index ef264dfee7..0d38121b1d 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -467,6 +467,8 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe // Only valid for implicit and hybrid flows. idToken string idTokenExpiry time.Time + + accessToken = storage.NewID() ) for _, responseType := range authReq.ResponseTypes { @@ -502,7 +504,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe case responseTypeIDToken: implicitOrHybrid = true var err error - idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce) + idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -513,7 +515,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe if implicitOrHybrid { v := url.Values{} - v.Set("access_token", storage.NewID()) + v.Set("access_token", accessToken) v.Set("token_type", "bearer") v.Set("state", authReq.State) if idToken != "" { @@ -623,7 +625,8 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } - idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce) + accessToken := storage.NewID() + idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -674,7 +677,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } } - s.writeAccessToken(w, idToken, refreshToken, expiry) + s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) } // handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 @@ -787,7 +790,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce) + accessToken := storage.NewID() + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -826,10 +830,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - s.writeAccessToken(w, idToken, rawNewToken, expiry) + s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) } -func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) { +func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { // TODO(ericchiang): figure out an access token story and support the user info // endpoint. For now use a random value so no one depends on the access_token // holding a specific structure. @@ -840,7 +844,7 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken s RefreshToken string `json:"refresh_token,omitempty"` IDToken string `json:"id_token"` }{ - storage.NewID(), + accessToken, "bearer", int(expiry.Sub(s.now()).Seconds()), refreshToken, diff --git a/server/oauth2.go b/server/oauth2.go index adbf1eedab..18c554ec2e 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -1,14 +1,25 @@ package server import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" "encoding/json" + "errors" "fmt" + "hash" + "io" "net/http" "net/url" "strconv" "strings" "time" + jose "gopkg.in/square/go-jose.v2" + "github.com/coreos/dex/connector" "github.com/coreos/dex/storage" ) @@ -125,6 +136,88 @@ func parseScopes(scopes []string) connector.Scopes { return s } +// Determine the signature algorithm for a JWT. +func signatureAlgorithm(jwk *jose.JSONWebKey) (alg jose.SignatureAlgorithm, err error) { + if jwk.Key == nil { + return alg, errors.New("no signing key") + } + switch key := jwk.Key.(type) { + case *rsa.PrivateKey: + // Because OIDC mandates that we support RS256, we always return that + // value. In the future, we might want to make this configurable on a + // per client basis. For example allowing PS256 or ECDSA variants. + // + // See https://github.com/coreos/dex/issues/692 + return jose.RS256, nil + case *ecdsa.PrivateKey: + // We don't actually support ECDSA keys yet, but they're tested for + // in case we want to in the future. + // + // These values are prescribed depending on the ECDSA key type. We + // can't return different values. + switch key.Params() { + case elliptic.P256().Params(): + return jose.ES256, nil + case elliptic.P384().Params(): + return jose.ES384, nil + case elliptic.P521().Params(): + return jose.ES512, nil + default: + return alg, errors.New("unsupported ecdsa curve") + } + default: + return alg, fmt.Errorf("unsupported signing key type %T", key) + } +} + +func signPayload(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []byte) (jws string, err error) { + signingKey := jose.SigningKey{Key: key, Algorithm: alg} + + signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) + if err != nil { + return "", fmt.Errorf("new signier: %v", err) + } + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("signing payload: %v", err) + } + return signature.CompactSerialize() +} + +// The hash algorithm for the at_hash is detemrined by the signing +// algorithm used for the id_token. From the spec: +// +// ...the hash algorithm used is the hash algorithm used in the alg Header +// Parameter of the ID Token's JOSE Header. For instance, if the alg is RS256, +// hash the access_token value with SHA-256 +// +// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken +var hashForSigAlg = map[jose.SignatureAlgorithm]func() hash.Hash{ + jose.RS256: sha256.New, + jose.RS384: sha512.New384, + jose.RS512: sha512.New, + jose.ES256: sha256.New, + jose.ES384: sha512.New384, + jose.ES512: sha512.New, +} + +// Compute an at_hash from a raw access token and a signature algorithm +// +// See: https://openid.net/specs/openid-connect-core-1_0.html#ImplicitIDToken +func accessTokenHash(alg jose.SignatureAlgorithm, accessToken string) (string, error) { + newHash, ok := hashForSigAlg[alg] + if !ok { + return "", fmt.Errorf("unsupported signature algorithm: %s", alg) + } + + hash := newHash() + if _, err := io.WriteString(hash, accessToken); err != nil { + return "", fmt.Errorf("computing hash: %v", err) + } + sum := hash.Sum(nil) + return base64.RawURLEncoding.EncodeToString(sum[:len(sum)/2]), nil +} + type audience []string func (a audience) MarshalJSON() ([]byte, error) { @@ -143,6 +236,8 @@ type idTokenClaims struct { AuthorizingParty string `json:"azp,omitempty"` Nonce string `json:"nonce,omitempty"` + AccessTokenHash string `json:"at_hash,omitempty"` + Email string `json:"email,omitempty"` EmailVerified *bool `json:"email_verified,omitempty"` @@ -151,7 +246,22 @@ type idTokenClaims struct { Name string `json:"name,omitempty"` } -func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce string) (idToken string, expiry time.Time, err error) { +func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken string) (idToken string, expiry time.Time, err error) { + keys, err := s.storage.GetKeys() + if err != nil { + s.logger.Errorf("Failed to get keys: %v", err) + return "", expiry, err + } + + signingKey := keys.SigningKey + if signingKey == nil { + return "", expiry, fmt.Errorf("no key to sign payload with") + } + signingAlg, err := signatureAlgorithm(signingKey) + if err != nil { + return "", expiry, err + } + issuedAt := s.now() expiry = issuedAt.Add(s.idTokensValidFor) @@ -163,6 +273,15 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str IssuedAt: issuedAt.Unix(), } + if accessToken != "" { + atHash, err := accessTokenHash(signingAlg, accessToken) + if err != nil { + s.logger.Errorf("error computing at_hash: %v", err) + return "", expiry, fmt.Errorf("error computing at_hash: %v", err) + } + tok.AccessTokenHash = atHash + } + for _, scope := range scopes { switch { case scope == scopeEmail: @@ -175,6 +294,8 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str default: peerID, ok := parseCrossClientScope(scope) if !ok { + // Ignore unknown scopes. These are already validated during the + // initial auth request. continue } isTrusted, err := s.validateCrossClientTrust(clientID, peerID) @@ -188,9 +309,14 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str tok.Audience = append(tok.Audience, peerID) } } + if len(tok.Audience) == 0 { + // Client didn't ask for cross client audience. Set the current + // client as the audience. tok.Audience = audience{clientID} } else { + // Client asked for cross client audience. The current client + // becomes the authorizing party. tok.AuthorizingParty = clientID } @@ -199,12 +325,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str return "", expiry, fmt.Errorf("could not serialize claims: %v", err) } - keys, err := s.storage.GetKeys() - if err != nil { - s.logger.Errorf("Failed to get keys: %v", err) - return "", expiry, err - } - if idToken, err = keys.Sign(payload); err != nil { + if idToken, err = signPayload(signingKey, signingAlg, payload); err != nil { return "", expiry, fmt.Errorf("failed to sign payload: %v", err) } return idToken, expiry, nil diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 807de8366c..7f9a449da4 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -6,6 +6,8 @@ import ( "net/url" "testing" + jose "gopkg.in/square/go-jose.v2" + "github.com/coreos/dex/storage" ) @@ -148,3 +150,20 @@ func TestParseAuthorizationRequest(t *testing.T) { }() } } + +const ( + // at_hash value and access_token returned by Google. + googleAccessTokenHash = "piwt8oCH-K2D9pXlaS1Y-w" + googleAccessToken = "ya29.CjHSA1l5WUn8xZ6HanHFzzdHdbXm-14rxnC7JHch9eFIsZkQEGoWzaYG4o7k5f6BnPLj" + googleSigningAlg = jose.RS256 +) + +func TestAccessTokenHash(t *testing.T) { + atHash, err := accessTokenHash(googleSigningAlg, googleAccessToken) + if err != nil { + t.Fatal(err) + } + if atHash != googleAccessTokenHash { + t.Errorf("expected %q got %q", googleAccessTokenHash, atHash) + } +} diff --git a/server/server_test.go b/server/server_test.go index d848076fcc..b438279c35 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -26,6 +26,7 @@ import ( "golang.org/x/crypto/bcrypt" "golang.org/x/net/context" "golang.org/x/oauth2" + jose "gopkg.in/square/go-jose.v2" "github.com/coreos/dex/connector" "github.com/coreos/dex/connector/mock" @@ -221,6 +222,38 @@ func TestOAuth2CodeFlow(t *testing.T) { return nil }, }, + { + name: "verify at_hash", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + idToken, err := p.Verifier().Verify(ctx, rawIDToken) + if err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + + var claims struct { + AtHash string `json:"at_hash"` + } + if err := idToken.Claims(&claims); err != nil { + return fmt.Errorf("failed to decode raw claims: %v", err) + } + if claims.AtHash == "" { + return errors.New("no at_hash value in id_token") + } + wantAtHash, err := accessTokenHash(jose.RS256, token.AccessToken) + if err != nil { + return fmt.Errorf("computed expected at hash: %v", err) + } + if wantAtHash != claims.AtHash { + return fmt.Errorf("expected at_hash=%q got=%q", wantAtHash, claims.AtHash) + } + + return nil + }, + }, { name: "refresh token", handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { diff --git a/storage/storage.go b/storage/storage.go index 47f5dcc656..3d27e6f72a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,13 +1,9 @@ package storage import ( - "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" - "crypto/rsa" "encoding/base32" "errors" - "fmt" "io" "strings" "time" @@ -288,38 +284,3 @@ type Keys struct { // For caching purposes, implementations MUST NOT update keys before this time. NextRotation time.Time } - -// Sign creates a JWT using the signing key. -func (k Keys) Sign(payload []byte) (jws string, err error) { - if k.SigningKey == nil { - return "", fmt.Errorf("no key to sign payload with") - } - signingKey := jose.SigningKey{Key: k.SigningKey} - - switch key := k.SigningKey.Key.(type) { - case *rsa.PrivateKey: - // TODO(ericchiang): Allow different cryptographic hashes. - signingKey.Algorithm = jose.RS256 - case *ecdsa.PrivateKey: - switch key.Params() { - case elliptic.P256().Params(): - signingKey.Algorithm = jose.ES256 - case elliptic.P384().Params(): - signingKey.Algorithm = jose.ES384 - case elliptic.P521().Params(): - signingKey.Algorithm = jose.ES512 - default: - return "", errors.New("unsupported ecdsa curve") - } - } - - signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) - if err != nil { - return "", fmt.Errorf("new signier: %v", err) - } - signature, err := signer.Sign(payload) - if err != nil { - return "", fmt.Errorf("signing payload: %v", err) - } - return signature.CompactSerialize() -}