diff --git a/src/jetstream/auth.go b/src/jetstream/auth.go index 24c3eb9585..34d864b8d0 100644 --- a/src/jetstream/auth.go +++ b/src/jetstream/auth.go @@ -384,6 +384,7 @@ func (p *portalProxy) DoLoginToCNSIwithConsoleUAAtoken(c echo.Context, theCNSIre } if uaaUrl.String() == p.GetConfig().ConsoleConfig.UAAEndpoint.String() { // CNSI UAA server matches Console UAA server + uaaToken.LinkedGUID = uaaToken.TokenGUID err = p.setCNSITokenRecord(theCNSIrecord.GUID, u.UserGUID, uaaToken) return err } else { @@ -688,26 +689,6 @@ func (p *portalProxy) InitEndpointTokenRecord(expiry int64, authTok string, refr return tokenRecord } -func (p *portalProxy) removed_saveCNSIToken(cnsiID string, u interfaces.JWTUserTokenInfo, authTok string, refreshTok string, disconnect bool) (interfaces.TokenRecord, error) { - log.Debug("saveCNSIToken") - - tokenRecord := interfaces.TokenRecord{ - AuthToken: authTok, - RefreshToken: refreshTok, - TokenExpiry: u.TokenExpiry, - Disconnected: disconnect, - AuthType: interfaces.AuthTypeOAuth2, - } - - err := p.setCNSITokenRecord(cnsiID, u.UserGUID, tokenRecord) - if err != nil { - log.Errorf("%v", err) - return interfaces.TokenRecord{}, err - } - - return tokenRecord, nil -} - func (p *portalProxy) deleteCNSIToken(cnsiID string, userGUID string) error { log.Debug("deleteCNSIToken") diff --git a/src/jetstream/auth_test.go b/src/jetstream/auth_test.go index d89844750b..22928de41a 100644 --- a/src/jetstream/auth_test.go +++ b/src/jetstream/auth_test.go @@ -19,7 +19,7 @@ import ( ) const ( - findUAATokenSql = `SELECT auth_token, refresh_token, token_expiry, auth_type, meta_data FROM tokens .*` + findUAATokenSql = `SELECT token_guid, auth_token, refresh_token, token_expiry, auth_type, meta_data FROM tokens .*` ) func TestLoginToUAA(t *testing.T) { @@ -549,9 +549,10 @@ func TestVerifySession(t *testing.T) { t.Error(errors.New("Unable to mock/stub user in session object.")) } + mockTokenGUID := "mock-token-guid" encryptedUAAToken, _ := crypto.EncryptToken(pp.Config.EncryptionKeyInBytes, mockUAAToken) - expectedTokensRow := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "auth_type", "meta_data"}). - AddRow(encryptedUAAToken, encryptedUAAToken, mockTokenExpiry, "oauth", "") + expectedTokensRow := sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "auth_type", "meta_data"}). + AddRow(mockTokenGUID, encryptedUAAToken, encryptedUAAToken, mockTokenExpiry, "oauth", "") mock.ExpectQuery(selectAnyFromTokens). WithArgs(mockUserGUID). @@ -561,8 +562,8 @@ func TestVerifySession(t *testing.T) { AddRow(mockProxyVersion) mock.ExpectQuery(getDbVersion).WillReturnRows(expectVersionRow) - rs := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "auth_type", "meta_data"}). - AddRow(encryptedUAAToken, encryptedUAAToken, mockTokenExpiry, "oauth", "") + rs := sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "auth_type", "meta_data"}). + AddRow(mockTokenGUID, encryptedUAAToken, encryptedUAAToken, mockTokenExpiry, "oauth", "") mock.ExpectQuery(findUAATokenSql). WillReturnRows(rs) diff --git a/src/jetstream/cnsi.go b/src/jetstream/cnsi.go index 8c90471b10..3eb770e376 100644 --- a/src/jetstream/cnsi.go +++ b/src/jetstream/cnsi.go @@ -350,7 +350,7 @@ func (p *portalProxy) GetCNSITokenRecord(cnsiGUID string, userGUID string) (inte } func (p *portalProxy) GetCNSITokenRecordWithDisconnected(cnsiGUID string, userGUID string) (interfaces.TokenRecord, bool) { - log.Debug("GetCNSITokenRecord") + log.Debug("GetCNSITokenRecordWithDisconnected") tokenRepo, err := tokens.NewPgsqlTokenRepository(p.DatabaseConnectionPool) if err != nil { return interfaces.TokenRecord{}, false @@ -380,6 +380,25 @@ func (p *portalProxy) ListEndpointsByUser(userGUID string) ([]*interfaces.Connec return cnsiList, nil } +// Uopdate the Access Token, Refresh Token and Token Expiry for a token +func (p *portalProxy) updateTokenAuth(userGUID string, t interfaces.TokenRecord) error { + log.Debug("updateTokenAuth") + tokenRepo, err := tokens.NewPgsqlTokenRepository(p.DatabaseConnectionPool) + if err != nil { + log.Errorf(dbReferenceError, err) + return fmt.Errorf(dbReferenceError, err) + } + + err = tokenRepo.UpdateTokenAuth(userGUID, t, p.Config.EncryptionKeyInBytes) + if err != nil { + msg := "Unable to update Token: %v" + log.Errorf(msg, err) + return fmt.Errorf(msg, err) + } + + return nil +} + func (p *portalProxy) setCNSITokenRecord(cnsiGUID string, userGUID string, t interfaces.TokenRecord) error { log.Debug("setCNSITokenRecord") tokenRepo, err := tokens.NewPgsqlTokenRepository(p.DatabaseConnectionPool) diff --git a/src/jetstream/datastore/20180824092600_LinkedTokens.go b/src/jetstream/datastore/20180824092600_LinkedTokens.go new file mode 100644 index 0000000000..957aa4b4d0 --- /dev/null +++ b/src/jetstream/datastore/20180824092600_LinkedTokens.go @@ -0,0 +1,42 @@ +package datastore + +import ( + "database/sql" + + "bitbucket.org/liamstask/goose/lib/goose" +) + +func init() { + RegisterMigration(20180813110300, "LinkedTokens", func(txn *sql.Tx, conf *goose.DBConf) error { + + addTokenID := "ALTER TABLE tokens ADD token_guid VARCHAR(36) DEFAULT 'default-token'" + _, err := txn.Exec(addTokenID) + if err != nil { + return err + } + + addLinkedTokens := "ALTER TABLE tokens ADD linked_token VARCHAR(36)" + _, err = txn.Exec(addLinkedTokens) + if err != nil { + return err + } + + // Ensure any existing tokens have an ID + + // For UAA tokens, use the user id + ensureUAATokenID := "UPDATE tokens SET token_guid=user_guid WHERE token_guid IS NULL AND token_type='uaa'" + _, err = txn.Exec(ensureUAATokenID) + if err != nil { + return err + } + + // For CNSI tokens, use the cnsi guid + ensureCNSITokenID := "UPDATE tokens SET token_guid=cnsi_guid WHERE token_guid IS NULL" + _, err = txn.Exec(ensureCNSITokenID) + if err != nil { + return err + } + + return nil + }) +} diff --git a/src/jetstream/mock_server_test.go b/src/jetstream/mock_server_test.go index 5502316a44..f052f13a1b 100644 --- a/src/jetstream/mock_server_test.go +++ b/src/jetstream/mock_server_test.go @@ -65,6 +65,7 @@ const mockCFGUID = "some-cf-guid-1234" const mockCEGUID = "some-hce-guid-1234" const mockUserGUID = "asd-gjfg-bob" const mockAdminGUID = tokens.SystemSharedUserGuid +const mockTokenGUID = "mock-token-guid" const mockURLString = "http://localhost:9999/some/fake/url/" @@ -170,15 +171,15 @@ func expectCFAndCERows() sqlmock.Rows { } func expectTokenRow() sqlmock.Rows { - return sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid"}). - AddRow(mockUAAToken, mockUAAToken, mockTokenExpiry, false, "OAuth2", "", mockUserGUID) + return sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid", "linked_token"}). + AddRow(mockTokenGUID, mockUAAToken, mockUAAToken, mockTokenExpiry, false, "OAuth2", "", mockUserGUID, nil) } func expectEncryptedTokenRow(mockEncryptionKey []byte) sqlmock.Rows { encryptedUaaToken, _ := crypto.EncryptToken(mockEncryptionKey, mockUAAToken) - return sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid"}). - AddRow(encryptedUaaToken, encryptedUaaToken, mockTokenExpiry, false, "OAuth2", "", mockUserGUID) + return sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid", "linked_token"}). + AddRow(mockTokenGUID, encryptedUaaToken, encryptedUaaToken, mockTokenExpiry, false, "OAuth2", "", mockUserGUID, nil) } func setupHTTPTest(req *http.Request) (*httptest.ResponseRecorder, *echo.Echo, echo.Context, *portalProxy, *sql.DB, sqlmock.Sqlmock) { diff --git a/src/jetstream/oauth_requests.go b/src/jetstream/oauth_requests.go index 5ae8f4df3c..1ffbf4b900 100644 --- a/src/jetstream/oauth_requests.go +++ b/src/jetstream/oauth_requests.go @@ -89,9 +89,10 @@ func (p *portalProxy) RefreshOAuthToken(skipSSLValidation bool, cnsiGUID, userGU u.UserGUID = userGUID tokenRecord := p.InitEndpointTokenRecord(u.TokenExpiry, uaaRes.AccessToken, uaaRes.RefreshToken, userToken.Disconnected) - err = p.setCNSITokenRecord(cnsiGUID, userGUID, tokenRecord) + tokenRecord.TokenGUID = userToken.TokenGUID + err = p.updateTokenAuth(userGUID, tokenRecord) if err != nil { - return t, fmt.Errorf("Couldn't save new token: %v", err) + return t, fmt.Errorf("Couldn't update token: %v", err) } return tokenRecord, nil diff --git a/src/jetstream/oauth_requests_test.go b/src/jetstream/oauth_requests_test.go index 70900e641d..2a006f653b 100644 --- a/src/jetstream/oauth_requests_test.go +++ b/src/jetstream/oauth_requests_test.go @@ -82,6 +82,8 @@ func TestDoOauthFlowRequestWithValidToken(t *testing.T) { TokenExpiry: tokenExpiration, } + mockTokenGUID := "mock-token-guid" + // set up the database expectation for pp.setCNSITokenRecord mock.ExpectQuery(selectAnyFromTokens). WithArgs(mockCNSIGUID, mockUserGUID). @@ -97,8 +99,8 @@ func TestDoOauthFlowRequestWithValidToken(t *testing.T) { // p.getCNSIRequestRecords(cnsiRequest) -> // p.getCNSITokenRecord(r.GUID, r.UserGUID) -> // tokenRepo.FindCNSIToken(cnsiGUID, userGUID) - expectedCNSITokenRow := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid"}). - AddRow(encryptedToken, encryptedToken, tokenExpiration, false, "OAuth2", "", mockUserGUID) + expectedCNSITokenRow := sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid", "linked_token"}). + AddRow(mockTokenGUID, encryptedToken, encryptedToken, tokenExpiration, false, "OAuth2", "", mockUserGUID, nil) mock.ExpectQuery(selectAnyFromTokens). WithArgs(mockCNSIGUID, mockUserGUID, mockAdminGUID). WillReturnRows(expectedCNSITokenRow) @@ -204,6 +206,8 @@ func TestDoOauthFlowRequestWithExpiredToken(t *testing.T) { TokenExpiry: tokenExpiration, } + mockTokenGUID := "mock-token-guid" + _, _, _, pp, db, mock := setupHTTPTest(req) defer db.Close() encryptedUAAToken, _ := crypto.EncryptToken(pp.Config.EncryptionKeyInBytes, mockUAAToken) @@ -227,8 +231,8 @@ func TestDoOauthFlowRequestWithExpiredToken(t *testing.T) { // p.getCNSIRequestRecords(cnsiRequest) -> // p.getCNSITokenRecord(r.GUID, r.UserGUID) -> // tokenRepo.FindCNSIToken(cnsiGUID, userGUID) - expectedCNSITokenRow := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid"}). - AddRow(encryptedUAAToken, encryptedUAAToken, tokenExpiration, false, "OAuth2", "", mockUserGUID) + expectedCNSITokenRow := sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid", "linked_token"}). + AddRow(mockTokenGUID, encryptedUAAToken, encryptedUAAToken, tokenExpiration, false, "OAuth2", "", mockUserGUID, nil) mock.ExpectQuery(selectAnyFromTokens). WithArgs(mockCNSIGUID, mockUserGUID, mockAdminGUID). WillReturnRows(expectedCNSITokenRow) @@ -240,19 +244,14 @@ func TestDoOauthFlowRequestWithExpiredToken(t *testing.T) { WithArgs(mockCNSIGUID). WillReturnRows(expectedCNSIRecordRow) - expectedCNSITokenRecordRow := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid"}). - AddRow(encryptedUAAToken, encryptedUAAToken, tokenExpiration, false, "OAuth2", "", mockUserGUID) + expectedCNSITokenRecordRow := sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid", "linked_token"}). + AddRow(mockTokenGUID, encryptedUAAToken, encryptedUAAToken, tokenExpiration, false, "OAuth2", "", mockUserGUID, nil) mock.ExpectQuery(selectAnyFromTokens). WithArgs(mockCNSIGUID, mockUserGUID, mockAdminGUID). WillReturnRows(expectedCNSITokenRecordRow) - mock.ExpectQuery(selectAnyFromTokens). - WithArgs(mockCNSIGUID, mockUserGUID). - WillReturnRows(sqlmock.NewRows([]string{"COUNT(*)"}).AddRow("0")) - - // Expect the INSERT - mock.ExpectExec(insertIntoTokens). - //WithArgs(mockCNSIGUID, mockUserGUID, "cnsi", encryptedUAAToken, encryptedUAAToken, mockTokenRecord.TokenExpiry). + // A token refresh attempt will be made - which is just an update + mock.ExpectExec(updateTokens). WillReturnResult(sqlmock.NewResult(1, 1)) // @@ -370,8 +369,8 @@ func TestDoOauthFlowRequestWithFailedRefreshMethod(t *testing.T) { // p.getCNSIRequestRecords(cnsiRequest) -> // p.getCNSITokenRecord(r.GUID, r.UserGUID) -> // tokenRepo.FindCNSIToken(cnsiGUID, userGUID) - expectedCNSITokenRow := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid"}). - AddRow(encryptedUAAToken, encryptedUAAToken, tokenExpiration, false, "OAuth2", "", mockUserGUID) + expectedCNSITokenRow := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid", "linked_token"}). + AddRow(encryptedUAAToken, encryptedUAAToken, tokenExpiration, false, "OAuth2", "", mockUserGUID, nil) mock.ExpectQuery(selectAnyFromTokens). WithArgs(mockCNSIGUID, mockUserGUID, mockAdminGUID). WillReturnRows(expectedCNSITokenRow) diff --git a/src/jetstream/passthrough.go b/src/jetstream/passthrough.go index a031a6722e..8e80ad1064 100644 --- a/src/jetstream/passthrough.go +++ b/src/jetstream/passthrough.go @@ -234,7 +234,7 @@ func (p *portalProxy) ProxyRequest(c echo.Context, uri *url.URL) (map[string]*in if shouldPassthrough { if len(cnsiList) > 1 { - err := errors.New("Requested passthrough to multiple CNSIs. Only single CNSI passthroughs are supported.") + err := errors.New("Requested passthrough to multiple CNSIs. Only single CNSI passthroughs are supported") return nil, echo.NewHTTPError(http.StatusBadRequest, err.Error()) } } diff --git a/src/jetstream/repository/interfaces/structs.go b/src/jetstream/repository/interfaces/structs.go index e764c16f39..6bd7586948 100644 --- a/src/jetstream/repository/interfaces/structs.go +++ b/src/jetstream/repository/interfaces/structs.go @@ -65,8 +65,9 @@ type EndpointTokenRecord struct { LoggingEndpoint string } -//TODO this could be moved back to tokens subpackage, and extensions could import it? +// TokenRecord repsrents and endpoint or uaa token type TokenRecord struct { + TokenGUID string AuthToken string RefreshToken string TokenExpiry int64 @@ -74,6 +75,7 @@ type TokenRecord struct { AuthType string Metadata string SystemShared bool + LinkedGUID string // Indicates the GUID of the token that this token is linked to (if any) } type CFInfo struct { diff --git a/src/jetstream/repository/tokens/pgsql_tokens.go b/src/jetstream/repository/tokens/pgsql_tokens.go index 88547a8d1a..1487e4139e 100644 --- a/src/jetstream/repository/tokens/pgsql_tokens.go +++ b/src/jetstream/repository/tokens/pgsql_tokens.go @@ -8,10 +8,11 @@ import ( "github.com/cloudfoundry-incubator/stratos/src/jetstream/datastore" "github.com/cloudfoundry-incubator/stratos/src/jetstream/repository/crypto" "github.com/cloudfoundry-incubator/stratos/src/jetstream/repository/interfaces" + uuid "github.com/satori/go.uuid" log "github.com/sirupsen/logrus" ) -var findAuthToken = `SELECT auth_token, refresh_token, token_expiry, auth_type, meta_data +var findAuthToken = `SELECT token_guid, auth_token, refresh_token, token_expiry, auth_type, meta_data FROM tokens WHERE token_type = 'uaa' AND user_guid = $1` @@ -19,18 +20,26 @@ var countAuthTokens = `SELECT COUNT(*) FROM tokens WHERE token_type = 'uaa' AND user_guid = $1` -var insertAuthToken = `INSERT INTO tokens (user_guid, token_type, auth_token, refresh_token, token_expiry) - VALUES ($1, $2, $3, $4, $5)` +var insertAuthToken = `INSERT INTO tokens (token_guid, user_guid, token_type, auth_token, refresh_token, token_expiry) + VALUES ($1, $2, $3, $4, $5, $6)` var updateAuthToken = `UPDATE tokens SET auth_token = $1, refresh_token = $2, token_expiry = $3 WHERE user_guid = $4 AND token_type = $5` -var findCNSIToken = `SELECT auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid +var getToken = `SELECT token_guid, auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid, linked_token + FROM tokens + WHERE user_guid = $1 AND token_guid = $2` + +var getTokenConnected = `SELECT token_guid, auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid, linked_token + FROM tokens + WHERE user_guid = $1 AND token_guid = $2 AND disconnected = '0'` + +var findCNSIToken = `SELECT token_guid, auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid, linked_token FROM tokens WHERE cnsi_guid = $1 AND (user_guid = $2 OR user_guid = $3) AND token_type = 'cnsi'` -var findCNSITokenConnected = `SELECT auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid +var findCNSITokenConnected = `SELECT token_guid, auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid, linked_token FROM tokens WHERE cnsi_guid = $1 AND (user_guid = $2 OR user_guid = $3) AND token_type = 'cnsi' AND disconnected = '0'` @@ -38,18 +47,18 @@ var countCNSITokens = `SELECT COUNT(*) FROM tokens WHERE cnsi_guid=$1 AND user_guid = $2 AND token_type = 'cnsi'` -var insertCNSIToken = `INSERT INTO tokens (cnsi_guid, user_guid, token_type, auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)` +var insertCNSIToken = `INSERT INTO tokens (token_guid, cnsi_guid, user_guid, token_type, auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, linked_token) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)` var updateCNSIToken = `UPDATE tokens - SET auth_token = $1, refresh_token = $2, token_expiry = $3, disconnected = $4, meta_data = $5 - WHERE cnsi_guid = $6 AND user_guid = $7 AND token_type = $8 AND auth_type = $9` + SET auth_token = $1, refresh_token = $2, token_expiry = $3, disconnected = $4, meta_data = $5, linked_token = $6 + WHERE cnsi_guid = $7 AND user_guid = $8 AND token_type = $9 AND auth_type = $10` var deleteCNSIToken = `DELETE FROM tokens WHERE token_type = 'cnsi' AND cnsi_guid = $1 AND user_guid = $2` var deleteCNSITokens = `DELETE FROM tokens WHERE token_type = 'cnsi' AND cnsi_guid = $1` -// TODO (wchrisjohnson) We need to adjust several calls ^ to accept a list of items (guids) as input +var updateToken = `UPDATE tokens SET auth_token = $1, refresh_token = $2, token_expiry = $3 WHERE token_guid = $7 AND user_guid = $8` // PgsqlTokenRepository is a PostgreSQL-backed token repository type PgsqlTokenRepository struct { @@ -118,7 +127,8 @@ func (p *PgsqlTokenRepository) SaveAuthToken(userGUID string, tr interfaces.Toke case 0: log.Debug("Performing INSERT of encrypted tokens") - if _, err := p.db.Exec(insertAuthToken, userGUID, "uaa", ciphertextAuthToken, + tokenGUID := uuid.NewV4().String() + if _, err := p.db.Exec(insertAuthToken, tokenGUID, userGUID, "uaa", ciphertextAuthToken, ciphertextRefreshToken, tr.TokenExpiry); err != nil { msg := "Unable to INSERT UAA token: %v" log.Debugf(msg, err) @@ -154,6 +164,7 @@ func (p *PgsqlTokenRepository) FindAuthToken(userGUID string, encryptionKey []by // temp vars to retrieve db data var ( + tokenGUID sql.NullString ciphertextAuthToken []byte ciphertextRefreshToken []byte tokenExpiry sql.NullInt64 @@ -162,7 +173,7 @@ func (p *PgsqlTokenRepository) FindAuthToken(userGUID string, encryptionKey []by ) // Get the UAA record from the db - err := p.db.QueryRow(findAuthToken, userGUID).Scan(&ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &authType, &metadata) + err := p.db.QueryRow(findAuthToken, userGUID).Scan(&tokenGUID, &ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &authType, &metadata) if err != nil { msg := "Unable to Find UAA token: %v" log.Debugf(msg, err) @@ -183,6 +194,9 @@ func (p *PgsqlTokenRepository) FindAuthToken(userGUID string, encryptionKey []by // Build a new TokenRecord based on the decrypted tokens tr := new(interfaces.TokenRecord) + if tokenGUID.Valid { + tr.TokenGUID = tokenGUID.String + } tr.AuthToken = plaintextAuthToken tr.RefreshToken = plaintextRefreshToken if tokenExpiry.Valid { @@ -192,7 +206,6 @@ func (p *PgsqlTokenRepository) FindAuthToken(userGUID string, encryptionKey []by if metadata.Valid { tr.Metadata = metadata.String } - return *tr, nil } @@ -217,13 +230,28 @@ func (p *PgsqlTokenRepository) SaveCNSIToken(cnsiGUID string, userGUID string, t return errors.New(msg) } + var ciphertextAuthToken, ciphertextRefreshToken []byte + var err error + + var linkedToken sql.NullString + + // Linked token? + if tr.LinkedGUID == "" { + linkedToken = sql.NullString{} + } else { + tr.AuthToken = "LINKED TOKEN" + tr.RefreshToken = "LINKED TOKEN" + linkedToken = sql.NullString{ + String: tr.LinkedGUID, + Valid: true, + } + } + log.Debug("Encrypting Auth Token") - ciphertextAuthToken, err := crypto.EncryptToken(encryptionKey, tr.AuthToken) + ciphertextAuthToken, err = crypto.EncryptToken(encryptionKey, tr.AuthToken) if err != nil { return err } - - var ciphertextRefreshToken []byte if tr.RefreshToken != "" { log.Debug("Encrypting Refresh Token") ciphertextRefreshToken, err = crypto.EncryptToken(encryptionKey, tr.RefreshToken) @@ -241,9 +269,9 @@ func (p *PgsqlTokenRepository) SaveCNSIToken(cnsiGUID string, userGUID string, t switch count { case 0: - - if _, insertErr := p.db.Exec(insertCNSIToken, cnsiGUID, userGUID, "cnsi", ciphertextAuthToken, - ciphertextRefreshToken, tr.TokenExpiry, tr.Disconnected, tr.AuthType, tr.Metadata); insertErr != nil { + tokenGUID := uuid.NewV4().String() + if _, insertErr := p.db.Exec(insertCNSIToken, tokenGUID, cnsiGUID, userGUID, "cnsi", ciphertextAuthToken, + ciphertextRefreshToken, tr.TokenExpiry, tr.Disconnected, tr.AuthType, tr.Metadata, linkedToken); insertErr != nil { msg := "Unable to INSERT CNSI token: %v" log.Debugf(msg, insertErr) @@ -256,7 +284,7 @@ func (p *PgsqlTokenRepository) SaveCNSIToken(cnsiGUID string, userGUID string, t log.Debug("Existing CNSI token found - attempting update.") result, err := p.db.Exec(updateCNSIToken, ciphertextAuthToken, ciphertextRefreshToken, tr.TokenExpiry, - tr.Disconnected, tr.Metadata, cnsiGUID, userGUID, "cnsi", tr.AuthType) + tr.Disconnected, tr.Metadata, linkedToken, cnsiGUID, userGUID, "cnsi", tr.AuthType) if err != nil { msg := "Unable to UPDATE CNSI token: %v" log.Debugf(msg, err) @@ -308,6 +336,7 @@ func (p *PgsqlTokenRepository) findCNSIToken(cnsiGUID string, userGUID string, e // temp vars to retrieve db data var ( + tokenGUID sql.NullString ciphertextAuthToken []byte ciphertextRefreshToken []byte tokenExpiry sql.NullInt64 @@ -315,13 +344,14 @@ func (p *PgsqlTokenRepository) findCNSIToken(cnsiGUID string, userGUID string, e authType string metadata sql.NullString tokenUserGUID sql.NullString + linkedTokenGUID sql.NullString ) var err error if includeDisconnected { - err = p.db.QueryRow(findCNSIToken, cnsiGUID, userGUID, SystemSharedUserGuid).Scan(&ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &disconnected, &authType, &metadata, &tokenUserGUID) + err = p.db.QueryRow(findCNSIToken, cnsiGUID, userGUID, SystemSharedUserGuid).Scan(&tokenGUID, &ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &disconnected, &authType, &metadata, &tokenUserGUID, &linkedTokenGUID) } else { - err = p.db.QueryRow(findCNSITokenConnected, cnsiGUID, userGUID, SystemSharedUserGuid).Scan(&ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &disconnected, &authType, &metadata, &tokenUserGUID) + err = p.db.QueryRow(findCNSITokenConnected, cnsiGUID, userGUID, SystemSharedUserGuid).Scan(&tokenGUID, &ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &disconnected, &authType, &metadata, &tokenUserGUID, &linkedTokenGUID) } if err != nil { @@ -334,6 +364,26 @@ func (p *PgsqlTokenRepository) findCNSIToken(cnsiGUID string, userGUID string, e return interfaces.TokenRecord{}, fmt.Errorf(msg, err) } + // If this token is linked - fetch that token and use it instead + // Currently we don't recurse - we only support one level of linked token - you can't link to another linked token + if linkedTokenGUID.Valid { + if includeDisconnected { + err = p.db.QueryRow(getToken, userGUID, linkedTokenGUID.String).Scan(&tokenGUID, &ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &disconnected, &authType, &metadata, &tokenUserGUID, &linkedTokenGUID) + } else { + err = p.db.QueryRow(getTokenConnected, userGUID, linkedTokenGUID.String).Scan(&tokenGUID, &ciphertextAuthToken, &ciphertextRefreshToken, &tokenExpiry, &disconnected, &authType, &metadata, &tokenUserGUID, &linkedTokenGUID) + } + + if err != nil { + msg := "Unable to Find CNSI token: %v" + if err == sql.ErrNoRows { + log.Debugf(msg, err) + } else { + log.Errorf(msg, err) + } + return interfaces.TokenRecord{}, fmt.Errorf(msg, err) + } + } + log.Debug("Decrypting Auth Token") plaintextAuthToken, err := crypto.DecryptToken(encryptionKey, ciphertextAuthToken) if err != nil { @@ -348,6 +398,9 @@ func (p *PgsqlTokenRepository) findCNSIToken(cnsiGUID string, userGUID string, e // Build a new TokenRecord based on the decrypted tokens tr := new(interfaces.TokenRecord) + if tokenGUID.Valid { + tr.TokenGUID = tokenGUID.String + } tr.AuthToken = plaintextAuthToken tr.RefreshToken = plaintextRefreshToken if tokenExpiry.Valid { @@ -361,6 +414,9 @@ func (p *PgsqlTokenRepository) findCNSIToken(cnsiGUID string, userGUID string, e if tokenUserGUID.Valid { tr.SystemShared = tokenUserGUID.String == SystemSharedUserGuid } + if linkedTokenGUID.Valid { + tr.LinkedGUID = linkedTokenGUID.String + } return *tr, nil } @@ -407,3 +463,82 @@ func (p *PgsqlTokenRepository) DeleteCNSITokens(cnsiGUID string) error { return nil } + +// UpdateTokenAuth - Update a token's auth data +func (p *PgsqlTokenRepository) UpdateTokenAuth(userGUID string, tr interfaces.TokenRecord, encryptionKey []byte) error { + log.Debug("UpdateTokenAuth") + + if userGUID == "" { + msg := "Unable to save Token without a valid User GUID." + log.Debug(msg) + return errors.New(msg) + } + + if tr.AuthToken == "" { + msg := "Unable to save Token without a valid Auth Token." + log.Debug(msg) + return errors.New(msg) + } + + if tr.RefreshToken == "" { + msg := "Unable to save Token without a valid Refresh Token." + log.Debug(msg) + return errors.New(msg) + } + + var ciphertextAuthToken, ciphertextRefreshToken []byte + var err error + + var tokenGUID string + + // Linked token? if so, update the linked token + if tr.LinkedGUID == "" { + tokenGUID = tr.TokenGUID + } else { + tokenGUID = tr.LinkedGUID + } + + if tr.RefreshToken == "" { + msg := "Unable to save Token without a valid Token GUID" + return errors.New(msg) + } + + log.Infof("Updating token %s", tokenGUID) + + log.Debug("Encrypting Auth Token") + ciphertextAuthToken, err = crypto.EncryptToken(encryptionKey, tr.AuthToken) + if err != nil { + return err + } + if tr.RefreshToken != "" { + log.Debug("Encrypting Refresh Token") + ciphertextRefreshToken, err = crypto.EncryptToken(encryptionKey, tr.RefreshToken) + if err != nil { + return err + } + } + + result, err := p.db.Exec(updateToken, ciphertextAuthToken, ciphertextRefreshToken, tr.TokenExpiry, tokenGUID, userGUID) + if err != nil { + msg := "Unable to UPDATE token: %v" + log.Debugf(msg, err) + return fmt.Errorf(msg, err) + } + + rowsUpdates, err := result.RowsAffected() + if err != nil { + return errors.New("Unable to UPDATE token: could not determine number of rows that were updated") + } + + if rowsUpdates < 1 { + return errors.New("Unable to UPDATE token: no rows were updated") + } + + if rowsUpdates > 1 { + log.Warn("UPDATE token: More than 1 row was updated (expected only 1)") + } + + log.Debug("Token UPDATE complete") + + return nil +} diff --git a/src/jetstream/repository/tokens/pgsql_tokens_test.go b/src/jetstream/repository/tokens/pgsql_tokens_test.go index c7591d6dbf..73dd1ae5a2 100644 --- a/src/jetstream/repository/tokens/pgsql_tokens_test.go +++ b/src/jetstream/repository/tokens/pgsql_tokens_test.go @@ -15,13 +15,14 @@ import ( const ( mockUAAToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6ImxlZ2FjeS10b2tlbi1rZXkiLCJ0eXAiOiJKV1QifQ.eyJqdGkiOiI2ZGIyYTI5NGYyYWE0OGNlYjI1NDgzMDk4ZDNjY2Q3YyIsInN1YiI6Ijg4YmNlYWE1LWJkY2UtNDdiOC04MmYzLTRhZmMxNGYyNjZmOSIsInNjb3BlIjpbIm9wZW5pZCIsInNjaW0ucmVhZCIsImNsb3VkX2NvbnRyb2xsZXIuYWRtaW4iLCJ1YWEudXNlciIsImNsb3VkX2NvbnRyb2xsZXIucmVhZCIsInBhc3N3b3JkLndyaXRlIiwicm91dGluZy5yb3V0ZXJfZ3JvdXBzLnJlYWQiLCJjbG91ZF9jb250cm9sbGVyLndyaXRlIiwiZG9wcGxlci5maXJlaG9zZSIsInNjaW0ud3JpdGUiXSwiY2xpZW50X2lkIjoiY2YiLCJjaWQiOiJjZiIsImF6cCI6ImNmIiwiZ3JhbnRfdHlwZSI6InBhc3N3b3JkIiwidXNlcl9pZCI6Ijg4YmNlYWE1LWJkY2UtNDdiOC04MmYzLTRhZmMxNGYyNjZmOSIsIm9yaWdpbiI6InVhYSIsInVzZXJfbmFtZSI6ImFkbWluIiwiZW1haWwiOiJhZG1pbiIsImF1dGhfdGltZSI6MTQ2Nzc2OTgxNiwicmV2X3NpZyI6IjE0MGUwMjZiIiwiaWF0IjoxNDY3NzY5ODE2LCJleHAiOjE0Njc3NzA0MTYsImlzcyI6Imh0dHBzOi8vdWFhLmV4YW1wbGUuY29tL29hdXRoL3Rva2VuIiwiemlkIjoidWFhIiwiYXVkIjpbImNmIiwib3BlbmlkIiwic2NpbSIsImNsb3VkX2NvbnRyb2xsZXIiLCJ1YWEiLCJwYXNzd29yZCIsInJvdXRpbmcucm91dGVyX2dyb3VwcyIsImRvcHBsZXIiXX0.q2u0JX42Qiwr0ZsBU5Y6bF74_0URWmmBYTLf8l7of_6huFoMkyqvirEYcbYbATt6Hz2zcN6xlXcInALxQ6nt6Jk01kZHRNYfuu6QziLHHw2o_dJWk9iipiermUze7BvSGtU_JXx45BSBNVFxvRxG9Yv54Lwa9FvyhMSmK3CI5S8NtVDchzrsH3sMsIjlTAb-L7sch-OOQ7ncWH1JoGMtw8sTbiaHvfNJQclSq8Ro11NUtRHiWeGFFxYIerzKO-TrSpDojFJrYVuK1m0YPmBDa_dY3cneRuppagRIn8oI0VFHF8BckrIqNCHvOMoVz6uzHebo9LK7H5z5SluxJ2vYUgPiHE_Tyo-7gELnNSy8qL4Bk9yTxNseeGiq13TSTGOtNnbrv1eq4ZeW7eafseLceKIZH2QZlXVzwd_aWbuKRv9ApDwy4AcSbpM0XtU89IjUEDoOf3IDWV2YZTZkEaXZ52Mhztb1O_IVpHyyks88P67RoANFt83MnCai9U3stCX45LEsg9oz2djrVnfHDzRNQVlg9hKJYbxsa2R5tpnftjhz-hfpsoPRxBkJDKM2islyd-gLqHtsERiZEoifu93VRE0Jvk6vaCNdStw7y4mq73Co6ykNUYA78SlT9lCwDJRQHTJiDWg33EeKpXne8joZbElwrKNcv93X1qxxvmp1wXQ bearer eyJhbGciOiJSUzI1NiIsImtpZCI6ImxlZ2FjeS10b2tlbi1rZXkiLCJ0eXAiOiJKV1QifQ.eyJqdGkiOiI2ZGIyYTI5NGYyYWE0OGNlYjI1NDgzMDk4ZDNjY2Q3Yy1yIiwic3ViIjoiODhiY2VhYTUtYmRjZS00N2I4LTgyZjMtNGFmYzE0ZjI2NmY5Iiwic2NvcGUiOlsib3BlbmlkIiwic2NpbS5yZWFkIiwiY2xvdWRfY29udHJvbGxlci5hZG1pbiIsInVhYS51c2VyIiwiY2xvdWRfY29udHJvbGxlci5yZWFkIiwicGFzc3dvcmQud3JpdGUiLCJyb3V0aW5nLnJvdXRlcl9ncm91cHMucmVhZCIsImNsb3VkX2NvbnRyb2xsZXIud3JpdGUiLCJkb3BwbGVyLmZpcmVob3NlIiwic2NpbS53cml0ZSJdLCJpYXQiOjE0Njc3Njk4MTYsImV4cCI6MTQ3MDM2MTgxNiwiY2lkIjoiY2YiLCJjbGllbnRfaWQiOiJjZiIsImlzcyI6Imh0dHBzOi8vdWFhLmV4YW1wbGUuY29tL29hdXRoL3Rva2VuIiwiemlkIjoidWFhIiwiZ3JhbnRfdHlwZSI6InBhc3N3b3JkIiwidXNlcl9uYW1lIjoiYWRtaW4iLCJvcmlnaW4iOiJ1YWEiLCJ1c2VyX2lkIjoiODhiY2VhYTUtYmRjZS00N2I4LTgyZjMtNGFmYzE0ZjI2NmY5IiwicmV2X3NpZyI6IjE0MGUwMjZiIiwiYXVkIjpbImNmIiwib3BlbmlkIiwic2NpbSIsImNsb3VkX2NvbnRyb2xsZXIiLCJ1YWEiLCJwYXNzd29yZCIsInJvdXRpbmcucm91dGVyX2dyb3VwcyIsImRvcHBsZXIiXX0.K5M_isGkEBAN_MaXqkVvJfHG86rGIUkDgsHaFnoKOA1x5FNC4APDvhImWJZ8zbFHhXT3PYHTyeSf_HQaFDFUHFvGZUhSSry2ID4kdU5kRyZ-y3ydkv2mq32BlUQBSC9ap0r5vFTv7BY1yf2EcDaKGe4v4ODMhTm2SIkdTyk2ZcLXHIucS0xgSZdjgxNqh3pnKtmcFkw72-CyREW4_2Nbvn_7U2UNUCb2SeAuWmYaNAOkuGveB8jAhg9ftTrxn5GNtNe1sdVycm51X1O0dGPt_rLbwkRDCdNpm0La_xzLqZEl60_YUqwo33eOChFgqXB5y_0Pzs9gD__uExrIXYIgMsltFELXryyRUDKTTHZEEw1bnLTbQfF-GAnS0E0CaTU_kcDVqDYcqfh0TCcr7nGCEozExMPm3J0OGUSP3FQAD5mDICsKKcSIi_kIjggkJ87tuNAY6QOW1WzBoRizXJVS4jb3QOnrii2LmH786qBYJMX0nH__JRYEU-HWLi_OGXVTo03Pe9QcB8qJvbu2DGRfQdBfjhvgt2AItY4voJnZcjwT29q144C5wvJ2_W8cUzNY-Xw_tN_fWK4LWCu6KRNLVLO2MNbl0aOfkvb1U5NZJUpUUC2jG3cZM2c8232YNFKVjdjbf-Mlx17OxOYQ5XtG5BiSEj7BA6s5hWftUXEUchg` mockCNSIToken = mockUAAToken + mockTokenGUID = "mock-token-guid" mockUserGuid = "foo-bar" mockCNSIGuid = "foo-bar" countTokensSql = `SELECT COUNT` insertTokenSql = `INSERT INTO tokens` updateUAATokenSql = `UPDATE tokens` - findTokenSql = `SELECT auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid FROM tokens .*` - findUAATokenSql = `SELECT auth_token, refresh_token, token_expiry, auth_type, meta_data FROM tokens WHERE token_type = 'uaa' AND .*` + findTokenSql = `SELECT token_guid, auth_token, refresh_token, token_expiry, disconnected, auth_type, meta_data, user_guid, linked_token FROM tokens .*` + findUAATokenSql = `SELECT token_guid, auth_token, refresh_token, token_expiry, auth_type, meta_data FROM tokens WHERE token_type = 'uaa' AND .*` deleteFromTokensSql = `DELETE FROM tokens` ) @@ -107,7 +108,7 @@ func TestSaveUAATokens(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"0"})) mock.ExpectExec(insertTokenSql). - WithArgs(mockUserGuid, "uaa", sqlmock.AnyArg(), sqlmock.AnyArg(), tokenRecord.TokenExpiry). + WithArgs(sqlmock.AnyArg(), mockUserGuid, "uaa", sqlmock.AnyArg(), sqlmock.AnyArg(), tokenRecord.TokenExpiry). WillReturnResult(sqlmock.NewResult(1, 1)) err := repository.SaveAuthToken(mockUserGuid, tokenRecord, mockEncryptionKey) @@ -217,7 +218,7 @@ func TestSaveCNSITokens(t *testing.T) { WillReturnRows(rs) mock.ExpectExec(insertTokenSql). - WithArgs(mockCNSIGuid, mockUserGuid, "cnsi", sqlmock.AnyArg(), sqlmock.AnyArg(), tokenRecord.TokenExpiry, false, "", ""). + WithArgs(sqlmock.AnyArg(), mockCNSIGuid, mockUserGuid, "cnsi", sqlmock.AnyArg(), sqlmock.AnyArg(), tokenRecord.TokenExpiry, false, "", "", sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) err := repository.SaveCNSIToken(mockCNSIGuid, mockUserGuid, tokenRecord, mockEncryptionKey) @@ -292,8 +293,8 @@ func TestFindUAATokens(t *testing.T) { Convey("Success case", func() { - rs := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "auth_type", "meta_data"}). - AddRow(mockUAAToken, mockUAAToken, mockTokenExpiry, "oauth", "") + rs := sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "auth_type", "meta_data"}). + AddRow(mockTokenGUID, mockUAAToken, mockUAAToken, mockTokenExpiry, "oauth", "") mock.ExpectQuery(findUAATokenSql). WillReturnRows(rs) @@ -355,8 +356,8 @@ func TestFindCNSITokens(t *testing.T) { }) Convey("Success case", func() { - rs := sqlmock.NewRows([]string{"auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid"}). - AddRow(mockUAAToken, mockUAAToken, mockTokenExpiry, false, "oauth", "", mockUserGuid) + rs := sqlmock.NewRows([]string{"token_guid", "auth_token", "refresh_token", "token_expiry", "disconnected", "auth_type", "meta_data", "user_guid", "linked_token"}). + AddRow(mockTokenGUID, mockUAAToken, mockUAAToken, mockTokenExpiry, false, "oauth", "", mockUserGuid, nil) mock.ExpectQuery(findTokenSql). WillReturnRows(rs) diff --git a/src/jetstream/repository/tokens/tokens.go b/src/jetstream/repository/tokens/tokens.go index 0d7b0b62ef..05384df430 100644 --- a/src/jetstream/repository/tokens/tokens.go +++ b/src/jetstream/repository/tokens/tokens.go @@ -21,4 +21,7 @@ type Repository interface { DeleteCNSIToken(cnsiGUID string, userGUID string) error DeleteCNSITokens(cnsiGUID string) error SaveCNSIToken(cnsiGUID string, userGUID string, tokenRecord interfaces.TokenRecord, encryptionKey []byte) error + + // Update a token's auth data + UpdateTokenAuth(userGUID string, tokenRecord interfaces.TokenRecord, encryptionKey []byte) error } diff --git a/src/jetstream/setup_console.go b/src/jetstream/setup_console.go index 9083117add..9e597228ae 100644 --- a/src/jetstream/setup_console.go +++ b/src/jetstream/setup_console.go @@ -136,7 +136,7 @@ func (p *portalProxy) setupConsoleUpdate(c echo.Context) error { "Console configuration data storage failed due to %s", err) } c.NoContent(http.StatusOK) - log.Infof("Console has been setup with the following settings: %+v", consoleConfig) + log.Infof("Updated Stratos setup") return nil } @@ -202,7 +202,6 @@ func (p *portalProxy) SaveConsoleConfig(consoleConfig *interfaces.ConsoleConfig, consoleRepo = consoleRepoInterface.(console_config.Repository) } - log.Infof("Console has been setup with the following settings: %+v", consoleConfig) err := consoleRepo.SaveConsoleConfig(consoleConfig) if err != nil { log.Printf("Failed to store Console Config: %+v", err) @@ -214,6 +213,8 @@ func (p *portalProxy) SaveConsoleConfig(consoleConfig *interfaces.ConsoleConfig, log.Printf("Failed to store Console Config: %+v", err) return fmt.Errorf("Failed to store Console Config: %+v", err) } + + log.Info("Stratos setup has been stored") return nil }