Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refresh token rotation #540

Merged
merged 1 commit into from
Aug 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
187 changes: 115 additions & 72 deletions db/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,47 +91,108 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
}

func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
return r.create(nil, userID, clientID, connectorID, scopes)
}

func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
return r.verify(nil, clientID, token)
}

func (r *refreshTokenRepo) Revoke(userID, token string) error {
tx, err := r.begin()
if err != nil {
return err
}
if clientID == "" {
return "", refresh.ErrorInvalidClientID
defer tx.Rollback()
if err := r.revoke(tx, userID, token); err != nil {
return err
}

// TODO(yifan): Check the number of tokens given to the client-user pair.
tokenPayload, err := r.tokenGenerator.Generate()
return tx.Commit()
}

func (r *refreshTokenRepo) RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error) {
// Verify
userID, connectorID, scopes, err := r.verify(nil, clientID, oldToken)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is verify outside the transaction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because if dex can't verify the token you save opening a transaction.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay sure

if err != nil {
return "", err
}

payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
// Revoke old refresh token
tx, err := r.begin()
if err != nil {
return "", err
}

record := &refreshTokenModel{
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
defer tx.Rollback()
if err := r.revoke(tx, userID, oldToken); err != nil {
return "", err
}

if err := r.executor(nil).Insert(record); err != nil {
// Renew refresh token
newRefreshToken, err = r.create(tx, userID, clientID, connectorID, scopes)
if err != nil {
return "", err
}

return buildToken(record.ID, tokenPayload), nil
return newRefreshToken, tx.Commit()
}

func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
_, err := r.executor(nil).Exec(q, userID, clientID)
return err
}

func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Client, error) {
q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
var clients []clientModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err
}

c := make([]client.Client, len(clients))
for i, client := range clients {
ident, err := client.Client()
if err != nil {
return nil, err
}
c[i] = *ident
// Do not share the secret.
c[i].Credentials.Secret = ""
}

return c, nil
}

func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx)
result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil {
return nil, err
}

if result == nil {
return nil, refresh.ErrorInvalidToken
}

record, ok := result.(*refreshTokenModel)
if !ok {
log.Errorf("expected refreshTokenModel but found %v", reflect.TypeOf(result))
return nil, errors.New("unrecognized model")
}
return record, nil
}

func (r *refreshTokenRepo) verify(tx repo.Transaction, clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token)

if err != nil {
return
}

record, err := r.get(nil, tokenID)
record, err := r.get(tx, tokenID)
if err != nil {
return
}
Expand All @@ -140,6 +201,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID s
return "", "", nil, refresh.ErrorInvalidClientID
}

// Check if the hash of token received is the same stored in database
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return
}
Expand All @@ -152,17 +214,46 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID s
return record.UserID, record.ConnectorID, scopes, nil
}

func (r *refreshTokenRepo) Revoke(userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
func (r *refreshTokenRepo) create(tx repo.Transaction, userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
}
if clientID == "" {
return "", refresh.ErrorInvalidClientID
}

// TODO(yifan): Check the number of tokens given to the client-user pair.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return err
return "", err
}

tx, err := r.begin()
payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
if err != nil {
return "", err
}

record := &refreshTokenModel{
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
}

if err := r.executor(tx).Insert(record); err != nil {
return "", err
}

return buildToken(record.ID, tokenPayload), nil
}

func (r *refreshTokenRepo) revoke(tx repo.Transaction, userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
defer tx.Rollback()

exec := r.executor(tx)
record, err := r.get(tx, tokenID)
if err != nil {
Expand All @@ -185,53 +276,5 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error {
return refresh.ErrorInvalidToken
}

return tx.Commit()
}

func (r *refreshTokenRepo) RevokeTokensForClient(userID, clientID string) error {
q := fmt.Sprintf("DELETE FROM %s WHERE user_id = $1 AND client_id = $2", r.quote(refreshTokenTableName))
_, err := r.executor(nil).Exec(q, userID, clientID)
return err
}

func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Client, error) {
q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
var clients []clientModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err
}

c := make([]client.Client, len(clients))
for i, client := range clients {
ident, err := client.Client()
if err != nil {
return nil, err
}
c[i] = *ident
// Do not share the secret.
c[i].Credentials.Secret = ""
}

return c, nil
}

func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) {
ex := r.executor(tx)
result, err := ex.Get(refreshTokenModel{}, tokenID)
if err != nil {
return nil, err
}

if result == nil {
return nil, refresh.ErrorInvalidToken
}

record, ok := result.(*refreshTokenModel)
if !ok {
log.Errorf("expected refreshTokenModel but found %v", reflect.TypeOf(result))
return nil, errors.New("unrecognized model")
}
return record, nil
return nil
}
3 changes: 3 additions & 0 deletions refresh/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type RefreshTokenRepo interface {
// Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error

// Revoke old refresh token and generates a new one
RenewRefreshToken(clientID, userID, oldToken string) (newRefreshToken string, err error)

// RevokeTokensForClient revokes all tokens issued for the userID for the provided client.
RevokeTokensForClient(userID, clientID string) error

Expand Down
2 changes: 1 addition & 1 deletion server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
return
}
jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
jwt, refreshToken, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
if err != nil {
writeTokenError(w, err, state)
return
Expand Down
42 changes: 24 additions & 18 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ type OIDCServer interface {

ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)

// RefreshToken takes a previously generated refresh token and returns a new ID token
// RefreshToken takes a previously generated refresh token and returns a new ID token and new refresh token
// if the token is valid.
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error)
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error)

KillSession(string) error

Expand Down Expand Up @@ -567,34 +567,34 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
return jwt, refreshToken, nil
}

func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) {
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, string, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the raw Claims ever used or do they immediately call Encode()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raw claims are encoded immediately here. But the interface OIDCServer uses jose.JWT in more places. Do you prefer to change it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is fine. Again, it's something we can clean up later.

ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
if !ok {
log.Errorf("Failed to Authenticate client %s", creds.ID)
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
}

userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err {
case nil:
break
case refresh.ErrorInvalidToken:
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
case refresh.ErrorInvalidClientID:
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
default:
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

if len(scopes) == 0 {
scopes = rtScopes
} else {
if !rtScopes.Contains(scopes) {
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
}
}

Expand All @@ -603,27 +603,27 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
// The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen.
log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

var groups []string
if rtScopes.HasScope(scope.ScopeGroups) {
conn, ok := s.connector(connectorID)
if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

grouper, ok := conn.(connector.GroupsConnector)
if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil {
log.Errorf("failed to get remote identities: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities {
Expand All @@ -635,18 +635,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
}()
if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
}

signer, err := s.KeyManager.Signer()
if err != nil {
log.Errorf("Failed to refresh ID token: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

now := time.Now()
Expand All @@ -666,12 +666,18 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil {
log.Errorf("Failed to generate ID token: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

refreshToken, err := s.RefreshTokenRepo.RenewRefreshToken(creds.ID, userID, token)
if err != nil {
log.Errorf("Failed to generate new refresh token: %v", err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}

log.Infof("New token sent: clientID=%s", creds.ID)

return jwt, nil
return jwt, refreshToken, nil
}

func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
Expand Down