Skip to content

Commit

Permalink
Caching the OAuth tokens in Application Gateway (#15924)
Browse files Browse the repository at this point in the history
* add clientSecret to cache key

* cache key, first approach

* cache key by SHA on mTLS OAuth

* invalidate mTLS token cache

* fix nil pointer, doc update, secret modify

* delete unnecessary files

* generate client cert with incorrect ca

* bump images

* apply review sugestions

* revert clientid
  • Loading branch information
mvshao committed Nov 2, 2022
1 parent 381302b commit d965167
Show file tree
Hide file tree
Showing 20 changed files with 155 additions and 96 deletions.
Expand Up @@ -175,6 +175,7 @@ func getOAuthWithCertCredentials(secret map[string][]byte, url string) (*authori

return &authorization.OAuthWithCert{
ClientID: string(secret[ClientIDKey]),
ClientSecret: string(secret[ClientSecretKey]),
Certificate: secret[CertificateKey],
PrivateKey: secret[PrivateKeyKey],
URL: url,
Expand Down
Expand Up @@ -64,7 +64,7 @@ func TestExternalAuthStrategy(t *testing.T) {
t.Run("should call Invalidate method on the provided strategy", func(t *testing.T) {
// given
oauthClientMock := &mocks.Client{}
oauthClientMock.On("InvalidateTokenCache", "clientId", "www.example.com/token").Return("token", nil).Once()
oauthClientMock.On("InvalidateTokenCache", "clientId", "clientSecret", "www.example.com/token").Return("token", nil).Once()

oauthStrategy := newOAuthStrategy(oauthClientMock, "clientId", "clientSecret", "www.example.com/token", nil)

Expand Down
@@ -1,7 +1,6 @@
package authorization

import (
"crypto/tls"
"net/http"

"github.com/kyma-project/kyma/components/central-application-gateway/pkg/authorization/oauth"
Expand Down Expand Up @@ -29,9 +28,10 @@ type StrategyFactory interface {
type OAuthClient interface {
// GetToken obtains OAuth token
GetToken(clientID string, clientSecret string, authURL string, headers, queryParameters *map[string][]string, skipTLSVerification bool) (string, apperrors.AppError)
GetTokenMTLS(clientID, authURL string, cert tls.Certificate, headers, queryParameters *map[string][]string, skipTLSVerification bool) (string, apperrors.AppError)
GetTokenMTLS(clientID, authURL string, certificate, privateKey []byte, headers, queryParameters *map[string][]string, skipVerify bool) (string, apperrors.AppError)
// InvalidateTokenCache resets internal token cache
InvalidateTokenCache(clientID string, authURL string)
InvalidateTokenCache(clientID string, clientSecret string, authURL string)
InvalidateTokenCacheMTLS(clientID, authURL string, certificate, privateKey []byte)
}

type authorizationStrategyFactory struct {
Expand All @@ -47,7 +47,7 @@ func (asf authorizationStrategyFactory) create(c *Credentials) Strategy {
if c != nil && c.OAuth != nil {
return newOAuthStrategy(asf.oauthClient, c.OAuth.ClientID, c.OAuth.ClientSecret, c.OAuth.URL, c.OAuth.RequestParameters)
} else if c != nil && c.OAuthWithCert != nil {
oAuthStrategy := newOAuthWithCertStrategy(asf.oauthClient, c.OAuthWithCert.ClientID, c.OAuthWithCert.Certificate, c.OAuthWithCert.PrivateKey, c.OAuthWithCert.URL, c.OAuthWithCert.RequestParameters)
oAuthStrategy := newOAuthWithCertStrategy(asf.oauthClient, c.OAuthWithCert.ClientID, c.OAuthWithCert.ClientSecret, c.OAuthWithCert.Certificate, c.OAuthWithCert.PrivateKey, c.OAuthWithCert.URL, c.OAuthWithCert.RequestParameters)
return &oAuthStrategy
} else if c != nil && c.BasicAuth != nil {
return newBasicAuthStrategy(c.BasicAuth.Username, c.BasicAuth.Password)
Expand Down
Expand Up @@ -2,6 +2,7 @@ package authorization

import (
"crypto/tls"
"github.com/kyma-project/kyma/components/central-application-gateway/pkg/authorization/testconsts"
"net/http"
"testing"

Expand Down Expand Up @@ -147,19 +148,17 @@ func TestStrategyFactory(t *testing.T) {

t.Run("should create oauth with cert strategy", func(t *testing.T) {
// given
pair, err := tls.X509KeyPair(certificate, privateKey)
require.NoError(t, err)

oauthClientMock := &oauthMocks.Client{}
oauthClientMock.On("GetTokenMTLS", "clientId", "www.example.com/token", pair, (*map[string][]string)(nil), (*map[string][]string)(nil), false).Return("token", nil)
oauthClientMock.On("GetTokenMTLS", "clientId", "www.example.com/token", []byte(testconsts.Certificate), []byte(testconsts.PrivateKey), (*map[string][]string)(nil), (*map[string][]string)(nil), false).Return("token", nil)

factory := authorizationStrategyFactory{oauthClient: oauthClientMock}
credentials := &Credentials{
OAuthWithCert: &OAuthWithCert{
ClientID: "clientId",
Certificate: certificate,
PrivateKey: privateKey,
URL: "www.example.com/token",
ClientID: "clientId",
ClientSecret: "clientSecret",
Certificate: certificate,
PrivateKey: privateKey,
URL: "www.example.com/token",
},
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -47,6 +47,7 @@ type CertificateGen struct {
type OAuthWithCert struct {
URL string
ClientID string
ClientSecret string
Certificate []byte
PrivateKey []byte
RequestParameters *RequestParameters
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -2,8 +2,11 @@ package oauth

import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
Expand All @@ -27,8 +30,9 @@ type oauthResponse struct {
//go:generate mockery --name=Client
type Client interface {
GetToken(clientID, clientSecret, authURL string, headers, queryParameters *map[string][]string, skipVerify bool) (string, apperrors.AppError)
GetTokenMTLS(clientID, authURL string, cert tls.Certificate, headers, queryParameters *map[string][]string, skipVerify bool) (string, apperrors.AppError)
InvalidateTokenCache(clientID string, authURL string)
GetTokenMTLS(clientID, authURL string, certificate, privateKey []byte, headers, queryParameters *map[string][]string, skipVerify bool) (string, apperrors.AppError)
InvalidateTokenCache(clientID string, clientSecret string, authURL string)
InvalidateTokenCacheMTLS(clientID, authURL string, certificate, privateKey []byte)
}

type client struct {
Expand All @@ -44,7 +48,7 @@ func NewOauthClient(timeoutDuration int, tokenCache tokencache.TokenCache) Clien
}

func (c *client) GetToken(clientID, clientSecret, authURL string, headers, queryParameters *map[string][]string, skipVerify bool) (string, apperrors.AppError) {
token, found := c.tokenCache.Get(c.makeOAuthTokenCacheKey(clientID, authURL))
token, found := c.tokenCache.Get(c.makeOAuthTokenCacheKey(clientID, clientSecret, authURL))
if found {
return token, nil
}
Expand All @@ -54,34 +58,56 @@ func (c *client) GetToken(clientID, clientSecret, authURL string, headers, query
return "", err
}

c.tokenCache.Add(c.makeOAuthTokenCacheKey(clientID, authURL), tokenResponse.AccessToken, tokenResponse.ExpiresIn)
c.tokenCache.Add(c.makeOAuthTokenCacheKey(clientID, clientSecret, authURL), tokenResponse.AccessToken, tokenResponse.ExpiresIn)

return tokenResponse.AccessToken, nil
}

func (c *client) GetTokenMTLS(clientID, authURL string, cert tls.Certificate, headers, queryParameters *map[string][]string, skipVerify bool) (string, apperrors.AppError) {
token, found := c.tokenCache.Get(c.makeOAuthTokenCacheKey(clientID, authURL))
func (c *client) GetTokenMTLS(clientID, authURL string, certificate, privateKey []byte, headers, queryParameters *map[string][]string, skipVerify bool) (string, apperrors.AppError) {
token, found := c.tokenCache.Get(c.makeMTLSOAuthTokenCacheKey(clientID, authURL, certificate, privateKey))
if found {
return token, nil
}

tokenResponse, err := c.requestTokenMTLS(clientID, authURL, cert, headers, queryParameters, skipVerify)
cert, err := tls.X509KeyPair(certificate, privateKey)
if err != nil {
return "", err
return "", apperrors.Internal("Failed to prepare certificate, %s", err.Error())
}

tokenResponse, requestError := c.requestTokenMTLS(clientID, authURL, cert, headers, queryParameters, skipVerify)
if err != nil {
return "", requestError
}

if tokenResponse == nil {
return "", apperrors.Internal("Failed to fetch token, possible certificate problem")
}

c.tokenCache.Add(c.makeOAuthTokenCacheKey(clientID, authURL), tokenResponse.AccessToken, tokenResponse.ExpiresIn)
c.tokenCache.Add(c.makeMTLSOAuthTokenCacheKey(clientID, authURL, certificate, privateKey), tokenResponse.AccessToken, tokenResponse.ExpiresIn)

return tokenResponse.AccessToken, nil
}

func (c *client) InvalidateTokenCache(clientID, authURL string) {
c.tokenCache.Remove(c.makeOAuthTokenCacheKey(clientID, authURL))
func (c *client) InvalidateTokenCache(clientID, clientSecret, authURL string) {
c.tokenCache.Remove(c.makeOAuthTokenCacheKey(clientID, clientSecret, authURL))
}

func (c *client) InvalidateTokenCacheMTLS(clientID, authURL string, certificate, privateKey []byte) {
c.tokenCache.Remove(c.makeMTLSOAuthTokenCacheKey(clientID, authURL, certificate, privateKey))
}

// to avoid case of single clientID and different endpoints for MTLS and standard oauth
func (c *client) makeOAuthTokenCacheKey(clientID, authURL string) string {
return clientID + authURL
func (c *client) makeOAuthTokenCacheKey(clientID, clientSecret, authURL string) string {
return clientID + clientSecret + authURL
}

func (c *client) makeMTLSOAuthTokenCacheKey(clientID, authURL string, certificate, privateKey []byte) string {
certificateSha := sha256.Sum256(certificate)
keySha := sha256.Sum256(privateKey)

hashedCertificate := hex.EncodeToString(certificateSha[:])
hashedKey := hex.EncodeToString(keySha[:])
return fmt.Sprintf("%v-%v-%v-%v", clientID, hashedCertificate, hashedKey, authURL)
}

func (c *client) requestToken(clientID, clientSecret, authURL string, headers, queryParameters *map[string][]string, skipVerify bool) (*oauthResponse, apperrors.AppError) {
Expand Down

0 comments on commit d965167

Please sign in to comment.