From 5b452ab58a060918bd25ca0e0ff2bc6c56de85ff Mon Sep 17 00:00:00 2001 From: "Marcelo E. Magallon" Date: Wed, 2 Apr 2025 15:16:37 -0600 Subject: [PATCH] Add WithCache options for NewKeyRetriever and NewTokenExchangeClient It might be worth offering the option to set a custom cache, possibly with different policies or with a different caching store. This is almost supported, as cache.Cache is an interface, so it's possible to reimplement it. This could be done using a DefaultKeyRetrieverOption function to pass to NewKeyRetriever that can replace the cache field `c`. The problem is that the field is not exported, so it's not possible to implement the option outside of the package, and the type does not offer a function to set the private field (such a function would make the option unnecessary). The same goes for the TokenExchangeClient. Add those options in this change. It needs to take into consideration the fact that cache.NewLocalCache ends up starting a goroutine to do clean up tasks. While the goroutine should stop because the cache has been replaced, it's better to not start it in the first place. Signed-off-by: Marcelo E. Magallon --- authn/jwks.go | 28 +++++++-- authn/jwks_test.go | 110 +++++++++++++++++++++++++++++++++++ authn/token_exchange.go | 23 +++++++- authn/token_exchange_test.go | 36 +++++++++++- 4 files changed, 187 insertions(+), 10 deletions(-) diff --git a/authn/jwks.go b/authn/jwks.go index 96f28467..589eb295 100644 --- a/authn/jwks.go +++ b/authn/jwks.go @@ -27,6 +27,12 @@ func WithHTTPClientKeyRetrieverOpt(client *http.Client) DefaultKeyRetrieverOptio } } +func WithKeyRetrieverCache(cache cache.Cache) DefaultKeyRetrieverOption { + return func(c *DefaultKeyRetriever) { + c.c = cache + } +} + const ( cacheTTL = 10 * time.Minute cacheCleanupInterval = 10 * time.Minute @@ -34,11 +40,8 @@ const ( func NewKeyRetriever(cfg KeyRetrieverConfig, opt ...DefaultKeyRetrieverOption) *DefaultKeyRetriever { s := &DefaultKeyRetriever{ - cfg: cfg, - c: cache.NewLocalCache(cache.Config{ - Expiry: cacheTTL, - CleanupInterval: cacheCleanupInterval, - }), + cfg: cfg, + c: nil, // See below. client: http.DefaultClient, s: &singleflight.Group{}, } @@ -46,6 +49,21 @@ func NewKeyRetriever(cfg KeyRetrieverConfig, opt ...DefaultKeyRetrieverOption) * for _, o := range opt { o(s) } + + // If the options did not set the cache, create a new local cache. + // + // This has to be done this way because the cache that is created by + // the cache.NewLocalCache function spawns a goroutine that cannot be + // trivially stopped. It is set up to stop when the object is garbage + // collected, but in the general case, the calling code will not have + // control over that. + if s.c == nil { + s.c = cache.NewLocalCache(cache.Config{ + Expiry: cacheTTL, + CleanupInterval: cacheCleanupInterval, + }) + } + return s } diff --git a/authn/jwks_test.go b/authn/jwks_test.go index 40bf8b73..2a431a7c 100644 --- a/authn/jwks_test.go +++ b/authn/jwks_test.go @@ -3,12 +3,16 @@ package authn import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" + "sync" "testing" + "time" "github.com/go-jose/go-jose/v3" + "github.com/grafana/authlib/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -36,6 +40,9 @@ func TestDefaultKeyRetriever_Get(t *testing.T) { SigningKeysURL: server.URL, }) + require.NotNil(t, service) + require.NotNil(t, service.c) + t.Run("should fetched key if not cached", func(t *testing.T) { key, err := service.Get(context.Background(), firstKeyID) require.NoError(t, err) @@ -66,3 +73,106 @@ func TestDefaultKeyRetriever_Get(t *testing.T) { } }) } + +func TestWithKeyRetrieverCache(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(keys()) + })) + + t.Cleanup(func() { + server.CloseClientConnections() + server.Close() + }) + + tc := &testCache{data: make(map[string][]byte)} + + // Create a new retriever with the test cache. + service := NewKeyRetriever(KeyRetrieverConfig{ + SigningKeysURL: server.URL, + }, WithKeyRetrieverCache(tc)) + + require.NotNil(t, service, "service should not be nil") + require.NotNil(t, service.c, "there should be a cache") + require.Equal(t, tc, service.c, "the cache should be the one passed in the options") + + // Validate that the key is not in the cache + data, err := tc.Get(context.Background(), firstKeyID) + require.Error(t, err, "the initial cache should be empty") + require.Nil(t, data, "the initial cache should be empty") + + // The cache is empty, so the implementation should fetch the key. + key, err := service.Get(context.Background(), firstKeyID) + require.NoError(t, err, "getting a key not present in the cache should not return an error") + require.NotNil(t, key, "Get should return a key") + assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested") + + // If the implementation called the cache, the data should be there now. + data, err = tc.Get(context.Background(), firstKeyID) + require.NoError(t, err, "the cache should have the key now") + require.NotNil(t, data, "the cache should have the key now") + + // Decode the data to validate that it matches the key. We know the + // entries in the cache are JSON-encoded keys. + var jwk jose.JSONWebKey + require.NoError(t, json.Unmarshal(data, &jwk), "the data should be valid JSON") + require.Equal(t, firstKeyID, jwk.KeyID, "the key id should match the one requested") + + // Remove the key from the cache; the implementation should still return the key. + err = tc.Delete(context.Background(), firstKeyID) + require.NoError(t, err, "deleting the key from the cache should not return an error") + + key, err = service.Get(context.Background(), firstKeyID) + require.NoError(t, err, "getting a key not present in the cache should not return an error") + require.NotNil(t, key, "Get should return a key") + assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested") + + // Retrieve an invalid key; the implementation should return an error. + key, err = service.Get(context.Background(), "invalid") + require.ErrorIs(t, err, ErrInvalidSigningKey) + require.Nil(t, key) + + // The implementation adds invalid keys to the cache to prevent re-fetching. + data, err = tc.Get(context.Background(), "invalid") + require.NoError(t, err, "the cache should have the invalid key now") + require.NotNil(t, data, "the cache should have the invalid key now") + require.Empty(t, data, "the cache should have the invalid key now") +} + +// testCache implements the Cache interface for testing purposes. +type testCache struct { + mu sync.Mutex + data map[string][]byte +} + +var _ cache.Cache = (*testCache)(nil) + +func (cache *testCache) Get(ctx context.Context, key string) ([]byte, error) { + cache.mu.Lock() + defer cache.mu.Unlock() + + item, ok := cache.data[key] + if !ok { + return nil, errors.New("not found") + } + + return item, nil +} + +func (cache *testCache) Set(ctx context.Context, key string, value []byte, expire time.Duration) error { + cache.mu.Lock() + defer cache.mu.Unlock() + + cache.data[key] = value + + return nil +} + +func (cache *testCache) Delete(ctx context.Context, key string) error { + cache.mu.Lock() + defer cache.mu.Unlock() + + delete(cache.data, key) + + return nil +} diff --git a/authn/token_exchange.go b/authn/token_exchange.go index bbd6faf3..b4e13b6c 100644 --- a/authn/token_exchange.go +++ b/authn/token_exchange.go @@ -34,6 +34,12 @@ func WithHTTPClient(client *http.Client) ExchangeClientOpts { } } +func WithTokenExchangeClientCache(cache cache.Cache) ExchangeClientOpts { + return func(c *TokenExchangeClient) { + c.cache = cache + } +} + func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) (*TokenExchangeClient, error) { if cfg.Token == "" { return nil, fmt.Errorf("%w: missing required token", ErrMissingConfig) @@ -44,9 +50,7 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) } c := &TokenExchangeClient{ - cache: cache.NewLocalCache(cache.Config{ - CleanupInterval: 5 * time.Minute, - }), + cache: nil, // See below. cfg: cfg, singlef: singleflight.Group{}, } @@ -59,6 +63,19 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) c.client = httpclient.New() } + // If the options did not set the cache, create a new local cache. + // + // This has to be done this way because the cache that is created by + // the cache.NewLocalCache function spawns a goroutine that cannot be + // trivially stopped. It is set up to stop when the object is garbage + // collected, but in the general case, the calling code will not have + // control over that. + if c.cache == nil { + c.cache = cache.NewLocalCache(cache.Config{ + CleanupInterval: 5 * time.Minute, + }) + } + return c, nil } diff --git a/authn/token_exchange_test.go b/authn/token_exchange_test.go index 8aed9ae0..c91e765c 100644 --- a/authn/token_exchange_test.go +++ b/authn/token_exchange_test.go @@ -43,11 +43,11 @@ func TestNewTokenExchangeClient(t *testing.T) { func Test_TokenExchangeClient_Exchange(t *testing.T) { expiresIn := 10 * time.Minute - setup := func(srv *httptest.Server) *TokenExchangeClient { + setup := func(srv *httptest.Server, opts ...ExchangeClientOpts) *TokenExchangeClient { c, err := NewTokenExchangeClient(TokenExchangeConfig{ Token: "some-token", TokenExchangeURL: srv.URL, - }) + }, opts...) require.NoError(t, err) return c } @@ -183,6 +183,38 @@ func Test_TokenExchangeClient_Exchange(t *testing.T) { expectedExpiry := time.Now().Add(time.Duration(expiresIn) * time.Second) require.InDelta(t, expectedExpiry.Unix(), claims.Expiry.Time().Unix(), 1) }) + + t.Run("should use an alternate cache if provided", func(t *testing.T) { + testcache := &testCache{data: make(map[string][]byte)} + + var calls int + c := setup(httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + require.Equal(t, r.Header.Get("Authorization"), "Bearer some-token") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data": {"token": "` + signAccessToken(t, expiresIn) + `"}}`)) + bytes.NewBuffer([]byte(`{}`)) + json.NewEncoder(&bytes.Buffer{}) + })), WithTokenExchangeClientCache(testcache)) + + tokenToBeExchanged := signAccessToken(t, expiresIn) + + res1, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged}) + assert.NoError(t, err) + assert.NotNil(t, res1) + require.Equal(t, 1, calls) + require.Len(t, testcache.data, 1) + + // same namespace and audiences should load token from cache + res2, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged}) + assert.NoError(t, err) + assert.NotNil(t, res2) + require.Equal(t, 1, calls) + require.Len(t, testcache.data, 1) + require.Equal(t, res1, res2) + + // This is only testing that the cache is used, so we do not repeat the other cases here. + }) } func signAccessToken(t *testing.T, expiresIn time.Duration) string {