Skip to content

Commit

Permalink
Add pagination when listing Oauth clients tokens (TykTechnologies#2109)
Browse files Browse the repository at this point in the history
Added optional query parameter called `page` to the endpoint `/oauth/clients/{apiID}/{keyName}/tokens` . To prevent breaking backwards compatibility, tokens 
will only be paginated if they contain a `page` query item

Part of TykTechnologies/tyk-analytics#1101
  • Loading branch information
adelowo authored and buger committed Jun 2, 2019
1 parent 379328b commit 794af25
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 11 deletions.
52 changes: 46 additions & 6 deletions gateway/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ func apiError(msg string) apiStatusMessage {
return apiStatusMessage{"error", msg}
}

// paginationStatus provides more information about a paginated data set
type paginationStatus struct {
PageNum int `json:"page_num"`
PageTotal int `json:"page_total"`
PageSize int `json:"page_size"`
}

type paginatedOAuthClientTokens struct {
Pagination paginationStatus
Tokens []OAuthClientToken
}

func doJSONWrite(w http.ResponseWriter, code int, obj interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
Expand Down Expand Up @@ -426,10 +438,10 @@ func handleGetDetail(sessionKey, apiID string, byHash bool) (interface{}, int) {
}
} else {
log.WithFields(logrus.Fields{
"prefix": "api",
"key": obfuscateKey(sessionKey),
"message": err,
"status": "ok",
"prefix": "api",
"key": obfuscateKey(sessionKey),
"message": err,
"status": "ok",
}).Info("Can't retrieve key quota")
}

Expand Down Expand Up @@ -1652,8 +1664,36 @@ func oAuthClientTokensHandler(w http.ResponseWriter, r *http.Request) {
return
}

// get tokens from redis
// TODO: add pagination
if p := r.URL.Query().Get("page"); p != "" {
page := 1

queryPage, err := strconv.Atoi(p)
if err == nil {
page = queryPage
}

if page <= 0 {
page = 1
}

tokens, totalPages, err := apiSpec.OAuthManager.OsinServer.Storage.GetPaginatedClientTokens(keyName, page)
if err != nil {
doJSONWrite(w, http.StatusInternalServerError, apiError("Get client tokens failed"))
return
}

doJSONWrite(w, http.StatusOK, paginatedOAuthClientTokens{
Pagination: paginationStatus{
PageSize: 100,
PageNum: page,
PageTotal: totalPages,
},
Tokens: tokens,
})

return
}

tokens, err := apiSpec.OAuthManager.OsinServer.Storage.GetClientTokens(keyName)
if err != nil {
doJSONWrite(w, http.StatusInternalServerError, apiError("Get client tokens failed"))
Expand Down
73 changes: 70 additions & 3 deletions gateway/oauth_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"math"
"net/http"
"time"

Expand Down Expand Up @@ -392,13 +393,14 @@ type ExtendedOsinStorageInterface interface {
// Custom getter to handle prefixing issues in Redis
GetClientNoPrefix(id string) (osin.Client, error)

GetClientTokens(id string) ([]OAuthClientToken, error)
GetPaginatedClientTokens(id string, page int) ([]OAuthClientToken, int, error)

GetExtendedClient(id string) (ExtendedOsinClientInterface, error)

// Custom getter to handle prefixing issues in Redis
GetExtendedClientNoPrefix(id string) (ExtendedOsinClientInterface, error)

GetClientTokens(id string) ([]OAuthClientToken, error)

GetClients(filter string, ignorePrefix bool) ([]ExtendedOsinClientInterface, error)

DeleteClient(id string, ignorePrefix bool) error
Expand Down Expand Up @@ -542,6 +544,67 @@ func (r *RedisOsinStorageInterface) GetClients(filter string, ignorePrefix bool)
return theseClients, nil
}

// GetPaginatedClientTokens returns all tokens associated with the given id.
// It returns the tokens, the total number of pages of the tokens after
// pagination and an error if any
func (r *RedisOsinStorageInterface) GetPaginatedClientTokens(id string, page int) ([]OAuthClientToken, int, error) {
key := prefixClientTokens + id

// use current timestamp as a start score so all expired tokens won't be picked
nowTs := time.Now().Unix()
startScore := strconv.FormatInt(nowTs, 10)

log.Info("Getting client tokens sorted list:", key)

tokens, scores, err := r.store.GetSortedSetRange(key, startScore, "+inf")
if err != nil {
return nil, 0, err
}

// clean up expired tokens in sorted set (remove all tokens with score up to current timestamp minus retention)
if config.Global().OauthTokenExpiredRetainPeriod > 0 {
cleanupStartScore := strconv.FormatInt(nowTs-int64(config.Global().OauthTokenExpiredRetainPeriod), 10)
go r.store.RemoveSortedSetRange(key, "-inf", cleanupStartScore)
}

itemsPerPage := 100

if len(tokens) == 0 {
return []OAuthClientToken{}, 0, nil
}

startIdx, endIdx := 0, itemsPerPage
if page > 1 {
startIdx = (page - 1) * itemsPerPage
endIdx += startIdx

// Make sure an "out of range" error never happens
n := len(tokens)
if endIdx > n {
endIdx = n
}

if startIdx > n {
startIdx = n
}
}

totalPages := int(math.Ceil(float64(len(tokens)) / float64(itemsPerPage)))

tokens = tokens[startIdx:endIdx]

// convert sorted set data and scores into reply struct
tokensData := make([]OAuthClientToken, len(tokens))
for i := range tokens {
tokensData[i] = OAuthClientToken{
Token: tokens[i],
Expires: int64(scores[i]), // we store expire timestamp as a score
}
}

return tokensData, totalPages, nil
}

func (r *RedisOsinStorageInterface) GetClientTokens(id string) ([]OAuthClientToken, error) {
key := prefixClientTokens + id

Expand All @@ -562,9 +625,13 @@ func (r *RedisOsinStorageInterface) GetClientTokens(id string) ([]OAuthClientTok
go r.store.RemoveSortedSetRange(key, "-inf", cleanupStartScore)
}

if len(tokens) == 0 {
return []OAuthClientToken{}, nil
}

// convert sorted set data and scores into reply struct
tokensData := make([]OAuthClientToken, len(tokens))
for i := 0; i < len(tokensData); i++ {
for i := range tokens {
tokensData[i] = OAuthClientToken{
Token: tokens[i],
Expires: int64(scores[i]), // we store expire timestamp as a score
Expand Down
179 changes: 177 additions & 2 deletions gateway/oauth_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,180 @@ func getAuthCode(t *testing.T, ts *Test) map[string]string {
return response
}

func TestGetPaginatedClientTokens(t *testing.T) {
globalConf := config.Global()
// set tokens to be expired after 1 second
globalConf.OauthTokenExpire = 1
// cleanup tokens older than 3 seconds
globalConf.OauthTokenExpiredRetainPeriod = 3
config.SetGlobal(globalConf)

defer resetTestConfig()

ts := StartTest()
defer ts.Close()

spec := loadTestOAuthSpec()

clientID := uuid.NewV4().String()
createTestOAuthClient(spec, clientID)

// make eight tokens
tokensID := map[string]bool{}
t.Run("Send eight token requests", func(t *testing.T) {
param := make(url.Values)
param.Set("response_type", "token")
param.Set("redirect_uri", authRedirectUri)
param.Set("client_id", clientID)
param.Set("client_secret", authClientSecret)
param.Set("key_rules", keyRules)

headers := map[string]string{
"Content-Type": "application/x-www-form-urlencoded",
}

for i := 0; i < 110; i++ {
resp, err := ts.Run(t, test.TestCase{
Path: "/APIID/tyk/oauth/authorize-client/",
Data: param.Encode(),
AdminAuth: true,
Headers: headers,
Method: http.MethodPost,
Code: http.StatusOK,
})
if err != nil {
t.Error(err)
}

response := map[string]interface{}{}
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
t.Fatal(err)
}

// save tokens for future check
tokensID[response["access_token"].(string)] = true
}
})

// get list of tokens
t.Run("Get list of tokens without page query", func(t *testing.T) {
resp, err := ts.Run(t, test.TestCase{
// Defaults to fetching all tokens since we don't have
// the page query here
Path: fmt.Sprintf("/tyk/oauth/clients/999999/%s/tokens", clientID),
AdminAuth: true,
Method: http.MethodGet,
Code: http.StatusOK,
})
if err != nil {
t.Error(err)
}

tokensResp := []OAuthClientToken{}
if err := json.NewDecoder(resp.Body).Decode(&tokensResp); err != nil {
t.Fatal(err)
}

// check response
if n := 110; len(tokensResp) != n {
t.Errorf("Wrong number of tokens received. Expected: %d. Got: %d", n, len(tokensResp))
}

for _, token := range tokensResp {
if !tokensID[token.Token] {
t.Errorf("Token %s is not found in expected result. Expecting: %v", token.Token, tokensID)
}
}
})

t.Run("Get list of tokens with a page query param lesser than 0", func(t *testing.T) {
resp, err := ts.Run(t, test.TestCase{
// strconv#Atoi successfully parses a negative integer
// so make sure it is being reset to the first page
Path: fmt.Sprintf("/tyk/oauth/clients/999999/%s/tokens?page=-4", clientID),
AdminAuth: true,
Method: http.MethodGet,
Code: http.StatusOK,
})
if err != nil {
t.Error(err)
}

tokensResp := paginatedOAuthClientTokens{}
if err := json.NewDecoder(resp.Body).Decode(&tokensResp); err != nil {
t.Fatal(err)
}

// check response
if n := 100; len(tokensResp.Tokens) != n {
t.Errorf("Wrong number of tokens received. Expected: %d. Got: %d", n, len(tokensResp.Tokens))
}

for _, token := range tokensResp.Tokens {
if !tokensID[token.Token] {
t.Errorf("Token %s is not found in expected result. Expecting: %v", token.Token, tokensID)
}
}

// Also inspect the pagination data information
if tokensResp.Pagination.PageNum != 1 {
t.Errorf("Paginated data should default to the first page if a negative integer is provided. Expected %d. Got %d", 1, tokensResp.Pagination.PageNum)
}
})

t.Run("Get list of tokens with ?page=2", func(t *testing.T) {
resp, err := ts.Run(t, test.TestCase{
Path: fmt.Sprintf("/tyk/oauth/clients/999999/%s/tokens?page=2", clientID),
AdminAuth: true,
Method: http.MethodGet,
Code: http.StatusOK,
})
if err != nil {
t.Error(err)
}

tokensResp := paginatedOAuthClientTokens{}
if err := json.NewDecoder(resp.Body).Decode(&tokensResp); err != nil {
t.Fatal(err)
}

// check response
if n := 10; len(tokensResp.Tokens) != n {
t.Errorf("Wrong number of tokens received. Expected: %d. Got: %d", n, len(tokensResp.Tokens))
}

for _, token := range tokensResp.Tokens {
if !tokensID[token.Token] {
t.Errorf("Token %s is not found in expected result. Expecting: %v", token.Token, tokensID)
}
}
})

t.Run("Get list of tokens after they expire", func(t *testing.T) {
// sleep to wait until tokens expire
time.Sleep(2 * time.Second)

resp, err := ts.Run(t, test.TestCase{
Path: fmt.Sprintf("/tyk/oauth/clients/999999/%s/tokens", clientID),
AdminAuth: true,
Method: http.MethodGet,
Code: http.StatusOK,
})
if err != nil {
t.Error(err)
}

// check response
tokensResp := []OAuthClientToken{}
if err := json.NewDecoder(resp.Body).Decode(&tokensResp); err != nil {
t.Fatal(err)
}
if len(tokensResp) > 0 {
t.Errorf("Wrong number of tokens received. Expected 0 - all tokens expired. Got: %d", len(tokensResp))
}
})
}

func TestGetClientTokens(t *testing.T) {
t.Run("Without hashing", func(t *testing.T) {
testGetClientTokens(t, false)
Expand Down Expand Up @@ -650,9 +824,10 @@ func testGetClientTokens(t *testing.T, hashed bool) {
}

// check response
if len(tokensResp) != len(tokensID) {
t.Errorf("Wrong number of tokens received. Expected: %d. Got: %d", len(tokensID), len(tokensResp))
if n := len(tokensID); len(tokensResp) != n {
t.Errorf("Wrong number of tokens received. Expected: %d. Got: %d", n, len(tokensResp))
}

for _, token := range tokensResp {
if !tokensID[token.Token] {
t.Errorf("Token %s is not found in expected result. Expecting: %v", token.Token, tokensID)
Expand Down

0 comments on commit 794af25

Please sign in to comment.