Skip to content

Commit

Permalink
CLI-1338: Rotate refresh tokens (#1036)
Browse files Browse the repository at this point in the history
* refresh refresh tokens

* try normal sso login if token refresh fails

* err if refresh_token is missing

* add refresh token rotation to unit tests
  • Loading branch information
brianstrauch committed Oct 11, 2021
1 parent 1afa3f5 commit 06a198d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 39 deletions.
24 changes: 13 additions & 11 deletions internal/pkg/auth/auth_token_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ func NewAuthTokenHandler(logger *log.Logger) AuthTokenHandler {
return &AuthTokenHandlerImpl{logger}
}

// Second string returned is refresh token if the user performs SSO login
func (a *AuthTokenHandlerImpl) GetCCloudTokens(client *ccloud.Client, credentials *Credentials, noBrowser bool) (string, string, error) {
if credentials.IsSSO {
// SSO password is the refresh token, if not present then user must perform SSO login, if present then refresh token automatically obtains a new token
// For an SSO user, the "Password" field may contain a refresh token. If one exists, try to obtain a new token.
if credentials.Password != "" {
token, err := a.refreshCCloudSSOToken(client, credentials.Password)
return token, "", err
} else {
return a.getCCloudSSOToken(client, noBrowser, credentials.Username)
if token, refreshToken, err := a.refreshCCloudSSOToken(client, credentials.Password); err == nil {
return token, refreshToken, nil
}
}
return a.getCCloudSSOToken(client, noBrowser, credentials.Username)
}

client.HttpClient.Timeout = 30 * time.Second
token, err := client.Auth.Login(context.Background(), "", credentials.Username, credentials.Password)
return token, "", err
Expand Down Expand Up @@ -81,16 +81,18 @@ func (a *AuthTokenHandlerImpl) getCCloudUserSSO(client *ccloud.Client, email str
return "", nil
}

func (a *AuthTokenHandlerImpl) refreshCCloudSSOToken(client *ccloud.Client, refreshToken string) (string, error) {
idToken, err := sso.GetNewIDTokenFromRefreshToken(client.BaseURL, refreshToken, a.logger)
func (a *AuthTokenHandlerImpl) refreshCCloudSSOToken(client *ccloud.Client, refreshToken string) (string, string, error) {
idToken, refreshToken, err := sso.RefreshTokens(client.BaseURL, refreshToken, a.logger)
if err != nil {
return "", err
return "", "", err
}

token, err := client.Auth.Login(context.Background(), idToken, "", "")
if err != nil {
return "", err
return "", "", err
}
return token, nil

return token, refreshToken, err
}

func (a *AuthTokenHandlerImpl) GetConfluentToken(mdsClient *mds.APIClient, credentials *Credentials) (string, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/errors/error_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ const (
GenerateRandomSSOProviderErrorMsg = "unable to generate random bytes for SSO provider state"
GenerateRandomCodeVerifierErrorMsg = "unable to generate random bytes for code verifier"
ComputeHashErrorMsg = "unable to compute hash for code challenge"
MissingIDTokenFieldErrorMsg = "oauth token response body did not contain id_token field"
FmtMissingOAuthFieldErrorMsg = `oauth token response body did not contain field "%s"`
ConstructOAuthRequestErrorMsg = "failed to construct oauth token request"
UnmarshalOAuthTokenErrorMsg = "failed to unmarshal response body in oauth token request"

Expand Down
13 changes: 7 additions & 6 deletions internal/pkg/sso/auth_sso_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,16 @@ func Login(authURL string, noBrowser bool, auth0ConnectionName string, logger *l
return state.SSOProviderIDToken, state.SSOProviderRefreshToken, nil
}

func GetNewIDTokenFromRefreshToken(authURL string, refreshToken string, logger *log.Logger) (idToken string, err error) {
func RefreshTokens(authURL string, refreshToken string, logger *log.Logger) (string, string, error) {
state, err := newState(authURL, false, logger)
if err != nil {
return "", err
return "", "", err
}
state.SSOProviderRefreshToken = refreshToken
err = state.refreshOAuthToken()
if err != nil {
return "", err

if err := state.refreshOAuthToken(); err != nil {
return "", "", err
}
return state.SSOProviderIDToken, err

return state.SSOProviderIDToken, state.SSOProviderRefreshToken, nil
}
36 changes: 20 additions & 16 deletions internal/pkg/sso/auth_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,46 +130,50 @@ func (s *authState) generateCodes() error {
return nil
}

// GetOAuthToken exchanges the obtained authorization code for an auth0/ID token from the SSO provider
// getOAuthToken exchanges the obtained authorization code for an auth0/ID token from the SSO provider
func (s *authState) getOAuthToken() error {
payload := strings.NewReader("grant_type=authorization_code" +
"&client_id=" + s.SSOProviderClientID +
"&code_verifier=" + s.CodeVerifier +
"&code=" + s.SSOProviderAuthenticationCode +
"&redirect_uri=" + s.SSOProviderCallbackUrl)

data, err := s.getOAuthTokenResponse(payload)
if err != nil {
return err
}
token, ok := data["id_token"]
if ok {
s.SSOProviderIDToken = token.(string)
} else {
return errors.New(errors.MissingIDTokenFieldErrorMsg)
}
refreshToken, ok := data["refresh_token"]
if ok {
s.SSOProviderRefreshToken = refreshToken.(string)
}
return nil

return s.saveOAuthTokenResponse(data)
}

// GetOAuthToken exchanges the obtained authorization code for an auth0/ID token from the SSO provider
// refreshOAuthToken exchanges the refresh token for an auth0/ID token from the SSO provider
func (s *authState) refreshOAuthToken() error {
payload := strings.NewReader("grant_type=refresh_token" +
"&client_id=" + s.SSOProviderClientID +
"&refresh_token=" + s.SSOProviderRefreshToken +
"&redirect_uri=" + s.SSOProviderCallbackUrl)

data, err := s.getOAuthTokenResponse(payload)
if err != nil {
return err
}
token, ok := data["id_token"]
if ok {

return s.saveOAuthTokenResponse(data)
}

func (s *authState) saveOAuthTokenResponse(data map[string]interface{}) error {
if token, ok := data["id_token"]; ok {
s.SSOProviderIDToken = token.(string)
} else {
return errors.New(errors.MissingIDTokenFieldErrorMsg)
return errors.Errorf(errors.FmtMissingOAuthFieldErrorMsg, "id_token")
}

if token, ok := data["refresh_token"]; ok {
s.SSOProviderRefreshToken = token.(string)
} else {
return errors.Errorf(errors.FmtMissingOAuthFieldErrorMsg, "refresh_token")
}

return nil
}

Expand Down
17 changes: 12 additions & 5 deletions internal/pkg/sso/auth_state_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sso

import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -175,6 +176,8 @@ func TestGetAuthorizationUrl(t *testing.T) {
}

func TestGetOAuthToken(t *testing.T) {
mockRefreshToken := "foo"

state, _ := newState("https://devel.cpdev.cloud", false, log.New())

expectedUri := "/oauth/token"
Expand All @@ -194,7 +197,7 @@ func TestGetOAuthToken(t *testing.T) {
require.Equal(t, expectedPayload, string(body))

// mock response
_, err = rw.Write([]byte(`{"id_token": "` + mockIDToken + `"}`))
_, err = rw.Write([]byte(fmt.Sprintf(`{"id_token": "%s", "refresh_token": "%s"}`, mockIDToken, mockRefreshToken)))
require.NoError(t, err)
}))
defer server.Close()
Expand All @@ -210,13 +213,16 @@ func TestGetOAuthToken(t *testing.T) {
}

func TestRefreshOAuthToken(t *testing.T) {
mockRefreshToken1 := "foo"
mockRefreshToken2 := "bar"

state, _ := newState("https://devel.cpdev.cloud", false, log.New())
mockRefreshToken := "bar"
state.SSOProviderRefreshToken = mockRefreshToken
state.SSOProviderRefreshToken = mockRefreshToken1

expectedUri := "/oauth/token"
expectedPayload := "grant_type=refresh_token" +
"&client_id=" + state.SSOProviderClientID +
"&refresh_token=" + mockRefreshToken +
"&refresh_token=" + state.SSOProviderRefreshToken +
"&redirect_uri=" + state.SSOProviderCallbackUrl

mockIDToken := "foobar"
Expand All @@ -229,7 +235,7 @@ func TestRefreshOAuthToken(t *testing.T) {
require.Equal(t, expectedPayload, string(body))

// mock response
_, err = rw.Write([]byte(`{"id_token": "` + mockIDToken + `"}`))
_, err = rw.Write([]byte(fmt.Sprintf(`{"id_token": "%s", "refresh_token": "%s"}`, mockIDToken, mockRefreshToken2)))
require.NoError(t, err)
}))
defer server.Close()
Expand All @@ -242,4 +248,5 @@ func TestRefreshOAuthToken(t *testing.T) {
require.NoError(t, err)

require.Equal(t, mockIDToken, state.SSOProviderIDToken)
require.Equal(t, mockRefreshToken2, state.SSOProviderRefreshToken)
}

0 comments on commit 06a198d

Please sign in to comment.