diff --git a/authn/jwks.go b/authn/jwks.go index 96f2846..589eb29 100644 --- a/authn/jwks.go +++ b/authn/jwks.go @@ -27,6 +27,12 @@ func WithHTTPClientKeyRetrieverOpt(client *http.Client) DefaultKeyRetrieverOptio } } +func WithKeyRetrieverCache(cache cache.Cache) DefaultKeyRetrieverOption { + return func(c *DefaultKeyRetriever) { + c.c = cache + } +} + const ( cacheTTL = 10 * time.Minute cacheCleanupInterval = 10 * time.Minute @@ -34,11 +40,8 @@ const ( func NewKeyRetriever(cfg KeyRetrieverConfig, opt ...DefaultKeyRetrieverOption) *DefaultKeyRetriever { s := &DefaultKeyRetriever{ - cfg: cfg, - c: cache.NewLocalCache(cache.Config{ - Expiry: cacheTTL, - CleanupInterval: cacheCleanupInterval, - }), + cfg: cfg, + c: nil, // See below. client: http.DefaultClient, s: &singleflight.Group{}, } @@ -46,6 +49,21 @@ func NewKeyRetriever(cfg KeyRetrieverConfig, opt ...DefaultKeyRetrieverOption) * for _, o := range opt { o(s) } + + // If the options did not set the cache, create a new local cache. + // + // This has to be done this way because the cache that is created by + // the cache.NewLocalCache function spawns a goroutine that cannot be + // trivially stopped. It is set up to stop when the object is garbage + // collected, but in the general case, the calling code will not have + // control over that. + if s.c == nil { + s.c = cache.NewLocalCache(cache.Config{ + Expiry: cacheTTL, + CleanupInterval: cacheCleanupInterval, + }) + } + return s } diff --git a/authn/jwks_test.go b/authn/jwks_test.go index 40bf8b7..2a431a7 100644 --- a/authn/jwks_test.go +++ b/authn/jwks_test.go @@ -3,12 +3,16 @@ package authn import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" + "sync" "testing" + "time" "github.com/go-jose/go-jose/v3" + "github.com/grafana/authlib/cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -36,6 +40,9 @@ func TestDefaultKeyRetriever_Get(t *testing.T) { SigningKeysURL: server.URL, }) + require.NotNil(t, service) + require.NotNil(t, service.c) + t.Run("should fetched key if not cached", func(t *testing.T) { key, err := service.Get(context.Background(), firstKeyID) require.NoError(t, err) @@ -66,3 +73,106 @@ func TestDefaultKeyRetriever_Get(t *testing.T) { } }) } + +func TestWithKeyRetrieverCache(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write(keys()) + })) + + t.Cleanup(func() { + server.CloseClientConnections() + server.Close() + }) + + tc := &testCache{data: make(map[string][]byte)} + + // Create a new retriever with the test cache. + service := NewKeyRetriever(KeyRetrieverConfig{ + SigningKeysURL: server.URL, + }, WithKeyRetrieverCache(tc)) + + require.NotNil(t, service, "service should not be nil") + require.NotNil(t, service.c, "there should be a cache") + require.Equal(t, tc, service.c, "the cache should be the one passed in the options") + + // Validate that the key is not in the cache + data, err := tc.Get(context.Background(), firstKeyID) + require.Error(t, err, "the initial cache should be empty") + require.Nil(t, data, "the initial cache should be empty") + + // The cache is empty, so the implementation should fetch the key. + key, err := service.Get(context.Background(), firstKeyID) + require.NoError(t, err, "getting a key not present in the cache should not return an error") + require.NotNil(t, key, "Get should return a key") + assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested") + + // If the implementation called the cache, the data should be there now. + data, err = tc.Get(context.Background(), firstKeyID) + require.NoError(t, err, "the cache should have the key now") + require.NotNil(t, data, "the cache should have the key now") + + // Decode the data to validate that it matches the key. We know the + // entries in the cache are JSON-encoded keys. + var jwk jose.JSONWebKey + require.NoError(t, json.Unmarshal(data, &jwk), "the data should be valid JSON") + require.Equal(t, firstKeyID, jwk.KeyID, "the key id should match the one requested") + + // Remove the key from the cache; the implementation should still return the key. + err = tc.Delete(context.Background(), firstKeyID) + require.NoError(t, err, "deleting the key from the cache should not return an error") + + key, err = service.Get(context.Background(), firstKeyID) + require.NoError(t, err, "getting a key not present in the cache should not return an error") + require.NotNil(t, key, "Get should return a key") + assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested") + + // Retrieve an invalid key; the implementation should return an error. + key, err = service.Get(context.Background(), "invalid") + require.ErrorIs(t, err, ErrInvalidSigningKey) + require.Nil(t, key) + + // The implementation adds invalid keys to the cache to prevent re-fetching. + data, err = tc.Get(context.Background(), "invalid") + require.NoError(t, err, "the cache should have the invalid key now") + require.NotNil(t, data, "the cache should have the invalid key now") + require.Empty(t, data, "the cache should have the invalid key now") +} + +// testCache implements the Cache interface for testing purposes. +type testCache struct { + mu sync.Mutex + data map[string][]byte +} + +var _ cache.Cache = (*testCache)(nil) + +func (cache *testCache) Get(ctx context.Context, key string) ([]byte, error) { + cache.mu.Lock() + defer cache.mu.Unlock() + + item, ok := cache.data[key] + if !ok { + return nil, errors.New("not found") + } + + return item, nil +} + +func (cache *testCache) Set(ctx context.Context, key string, value []byte, expire time.Duration) error { + cache.mu.Lock() + defer cache.mu.Unlock() + + cache.data[key] = value + + return nil +} + +func (cache *testCache) Delete(ctx context.Context, key string) error { + cache.mu.Lock() + defer cache.mu.Unlock() + + delete(cache.data, key) + + return nil +} diff --git a/authn/token_exchange.go b/authn/token_exchange.go index bbd6faf..b4e13b6 100644 --- a/authn/token_exchange.go +++ b/authn/token_exchange.go @@ -34,6 +34,12 @@ func WithHTTPClient(client *http.Client) ExchangeClientOpts { } } +func WithTokenExchangeClientCache(cache cache.Cache) ExchangeClientOpts { + return func(c *TokenExchangeClient) { + c.cache = cache + } +} + func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) (*TokenExchangeClient, error) { if cfg.Token == "" { return nil, fmt.Errorf("%w: missing required token", ErrMissingConfig) @@ -44,9 +50,7 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) } c := &TokenExchangeClient{ - cache: cache.NewLocalCache(cache.Config{ - CleanupInterval: 5 * time.Minute, - }), + cache: nil, // See below. cfg: cfg, singlef: singleflight.Group{}, } @@ -59,6 +63,19 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) c.client = httpclient.New() } + // If the options did not set the cache, create a new local cache. + // + // This has to be done this way because the cache that is created by + // the cache.NewLocalCache function spawns a goroutine that cannot be + // trivially stopped. It is set up to stop when the object is garbage + // collected, but in the general case, the calling code will not have + // control over that. + if c.cache == nil { + c.cache = cache.NewLocalCache(cache.Config{ + CleanupInterval: 5 * time.Minute, + }) + } + return c, nil } diff --git a/authn/token_exchange_test.go b/authn/token_exchange_test.go index 8aed9ae..c91e765 100644 --- a/authn/token_exchange_test.go +++ b/authn/token_exchange_test.go @@ -43,11 +43,11 @@ func TestNewTokenExchangeClient(t *testing.T) { func Test_TokenExchangeClient_Exchange(t *testing.T) { expiresIn := 10 * time.Minute - setup := func(srv *httptest.Server) *TokenExchangeClient { + setup := func(srv *httptest.Server, opts ...ExchangeClientOpts) *TokenExchangeClient { c, err := NewTokenExchangeClient(TokenExchangeConfig{ Token: "some-token", TokenExchangeURL: srv.URL, - }) + }, opts...) require.NoError(t, err) return c } @@ -183,6 +183,38 @@ func Test_TokenExchangeClient_Exchange(t *testing.T) { expectedExpiry := time.Now().Add(time.Duration(expiresIn) * time.Second) require.InDelta(t, expectedExpiry.Unix(), claims.Expiry.Time().Unix(), 1) }) + + t.Run("should use an alternate cache if provided", func(t *testing.T) { + testcache := &testCache{data: make(map[string][]byte)} + + var calls int + c := setup(httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + require.Equal(t, r.Header.Get("Authorization"), "Bearer some-token") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data": {"token": "` + signAccessToken(t, expiresIn) + `"}}`)) + bytes.NewBuffer([]byte(`{}`)) + json.NewEncoder(&bytes.Buffer{}) + })), WithTokenExchangeClientCache(testcache)) + + tokenToBeExchanged := signAccessToken(t, expiresIn) + + res1, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged}) + assert.NoError(t, err) + assert.NotNil(t, res1) + require.Equal(t, 1, calls) + require.Len(t, testcache.data, 1) + + // same namespace and audiences should load token from cache + res2, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged}) + assert.NoError(t, err) + assert.NotNil(t, res2) + require.Equal(t, 1, calls) + require.Len(t, testcache.data, 1) + require.Equal(t, res1, res2) + + // This is only testing that the cache is used, so we do not repeat the other cases here. + }) } func signAccessToken(t *testing.T, expiresIn time.Duration) string {