Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions authn/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,43 @@ func WithHTTPClientKeyRetrieverOpt(client *http.Client) DefaultKeyRetrieverOptio
}
}

func WithKeyRetrieverCache(cache cache.Cache) DefaultKeyRetrieverOption {
return func(c *DefaultKeyRetriever) {
c.c = cache
}
}

const (
cacheTTL = 10 * time.Minute
cacheCleanupInterval = 10 * time.Minute
)

func NewKeyRetriever(cfg KeyRetrieverConfig, opt ...DefaultKeyRetrieverOption) *DefaultKeyRetriever {
s := &DefaultKeyRetriever{
cfg: cfg,
c: cache.NewLocalCache(cache.Config{
Expiry: cacheTTL,
CleanupInterval: cacheCleanupInterval,
}),
cfg: cfg,
c: nil, // See below.
client: http.DefaultClient,
s: &singleflight.Group{},
}

for _, o := range opt {
o(s)
}

// If the options did not set the cache, create a new local cache.
//
// This has to be done this way because the cache that is created by
// the cache.NewLocalCache function spawns a goroutine that cannot be
// trivially stopped. It is set up to stop when the object is garbage
// collected, but in the general case, the calling code will not have
// control over that.
if s.c == nil {
s.c = cache.NewLocalCache(cache.Config{
Expiry: cacheTTL,
CleanupInterval: cacheCleanupInterval,
})
}

return s
}

Expand Down
110 changes: 110 additions & 0 deletions authn/jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ package authn
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/go-jose/go-jose/v3"
"github.com/grafana/authlib/cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -36,6 +40,9 @@ func TestDefaultKeyRetriever_Get(t *testing.T) {
SigningKeysURL: server.URL,
})

require.NotNil(t, service)
require.NotNil(t, service.c)

t.Run("should fetched key if not cached", func(t *testing.T) {
key, err := service.Get(context.Background(), firstKeyID)
require.NoError(t, err)
Expand Down Expand Up @@ -66,3 +73,106 @@ func TestDefaultKeyRetriever_Get(t *testing.T) {
}
})
}

func TestWithKeyRetrieverCache(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(keys())
}))

t.Cleanup(func() {
server.CloseClientConnections()
server.Close()
})

tc := &testCache{data: make(map[string][]byte)}

// Create a new retriever with the test cache.
service := NewKeyRetriever(KeyRetrieverConfig{
SigningKeysURL: server.URL,
}, WithKeyRetrieverCache(tc))

require.NotNil(t, service, "service should not be nil")
require.NotNil(t, service.c, "there should be a cache")
require.Equal(t, tc, service.c, "the cache should be the one passed in the options")

// Validate that the key is not in the cache
data, err := tc.Get(context.Background(), firstKeyID)
require.Error(t, err, "the initial cache should be empty")
require.Nil(t, data, "the initial cache should be empty")

// The cache is empty, so the implementation should fetch the key.
key, err := service.Get(context.Background(), firstKeyID)
require.NoError(t, err, "getting a key not present in the cache should not return an error")
require.NotNil(t, key, "Get should return a key")
assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested")

// If the implementation called the cache, the data should be there now.
data, err = tc.Get(context.Background(), firstKeyID)
require.NoError(t, err, "the cache should have the key now")
require.NotNil(t, data, "the cache should have the key now")

// Decode the data to validate that it matches the key. We know the
// entries in the cache are JSON-encoded keys.
var jwk jose.JSONWebKey
require.NoError(t, json.Unmarshal(data, &jwk), "the data should be valid JSON")
require.Equal(t, firstKeyID, jwk.KeyID, "the key id should match the one requested")

// Remove the key from the cache; the implementation should still return the key.
err = tc.Delete(context.Background(), firstKeyID)
require.NoError(t, err, "deleting the key from the cache should not return an error")

key, err = service.Get(context.Background(), firstKeyID)
require.NoError(t, err, "getting a key not present in the cache should not return an error")
require.NotNil(t, key, "Get should return a key")
assert.Equal(t, firstKeyID, key.KeyID, "the key should match the one requested")

// Retrieve an invalid key; the implementation should return an error.
key, err = service.Get(context.Background(), "invalid")
require.ErrorIs(t, err, ErrInvalidSigningKey)
require.Nil(t, key)

// The implementation adds invalid keys to the cache to prevent re-fetching.
data, err = tc.Get(context.Background(), "invalid")
require.NoError(t, err, "the cache should have the invalid key now")
require.NotNil(t, data, "the cache should have the invalid key now")
require.Empty(t, data, "the cache should have the invalid key now")
}

// testCache implements the Cache interface for testing purposes.
type testCache struct {
mu sync.Mutex
data map[string][]byte
}

var _ cache.Cache = (*testCache)(nil)

func (cache *testCache) Get(ctx context.Context, key string) ([]byte, error) {
cache.mu.Lock()
defer cache.mu.Unlock()

item, ok := cache.data[key]
if !ok {
return nil, errors.New("not found")
}

return item, nil
}

func (cache *testCache) Set(ctx context.Context, key string, value []byte, expire time.Duration) error {
cache.mu.Lock()
defer cache.mu.Unlock()

cache.data[key] = value

return nil
}

func (cache *testCache) Delete(ctx context.Context, key string) error {
cache.mu.Lock()
defer cache.mu.Unlock()

delete(cache.data, key)

return nil
}
23 changes: 20 additions & 3 deletions authn/token_exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ func WithHTTPClient(client *http.Client) ExchangeClientOpts {
}
}

func WithTokenExchangeClientCache(cache cache.Cache) ExchangeClientOpts {
return func(c *TokenExchangeClient) {
c.cache = cache
}
}

func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts) (*TokenExchangeClient, error) {
if cfg.Token == "" {
return nil, fmt.Errorf("%w: missing required token", ErrMissingConfig)
Expand All @@ -44,9 +50,7 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts)
}

c := &TokenExchangeClient{
cache: cache.NewLocalCache(cache.Config{
CleanupInterval: 5 * time.Minute,
}),
cache: nil, // See below.
cfg: cfg,
singlef: singleflight.Group{},
}
Expand All @@ -59,6 +63,19 @@ func NewTokenExchangeClient(cfg TokenExchangeConfig, opts ...ExchangeClientOpts)
c.client = httpclient.New()
}

// If the options did not set the cache, create a new local cache.
//
// This has to be done this way because the cache that is created by
// the cache.NewLocalCache function spawns a goroutine that cannot be
// trivially stopped. It is set up to stop when the object is garbage
// collected, but in the general case, the calling code will not have
// control over that.
if c.cache == nil {
c.cache = cache.NewLocalCache(cache.Config{
CleanupInterval: 5 * time.Minute,
})
}

return c, nil

}
Expand Down
36 changes: 34 additions & 2 deletions authn/token_exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ func TestNewTokenExchangeClient(t *testing.T) {

func Test_TokenExchangeClient_Exchange(t *testing.T) {
expiresIn := 10 * time.Minute
setup := func(srv *httptest.Server) *TokenExchangeClient {
setup := func(srv *httptest.Server, opts ...ExchangeClientOpts) *TokenExchangeClient {
c, err := NewTokenExchangeClient(TokenExchangeConfig{
Token: "some-token",
TokenExchangeURL: srv.URL,
})
}, opts...)
require.NoError(t, err)
return c
}
Expand Down Expand Up @@ -183,6 +183,38 @@ func Test_TokenExchangeClient_Exchange(t *testing.T) {
expectedExpiry := time.Now().Add(time.Duration(expiresIn) * time.Second)
require.InDelta(t, expectedExpiry.Unix(), claims.Expiry.Time().Unix(), 1)
})

t.Run("should use an alternate cache if provided", func(t *testing.T) {
testcache := &testCache{data: make(map[string][]byte)}

var calls int
c := setup(httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls++
require.Equal(t, r.Header.Get("Authorization"), "Bearer some-token")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data": {"token": "` + signAccessToken(t, expiresIn) + `"}}`))
bytes.NewBuffer([]byte(`{}`))
json.NewEncoder(&bytes.Buffer{})
})), WithTokenExchangeClientCache(testcache))

tokenToBeExchanged := signAccessToken(t, expiresIn)

res1, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged})
assert.NoError(t, err)
assert.NotNil(t, res1)
require.Equal(t, 1, calls)
require.Len(t, testcache.data, 1)

// same namespace and audiences should load token from cache
res2, err := c.Exchange(context.Background(), TokenExchangeRequest{Namespace: "*", Audiences: []string{"some-service"}, SubjectToken: tokenToBeExchanged})
assert.NoError(t, err)
assert.NotNil(t, res2)
require.Equal(t, 1, calls)
require.Len(t, testcache.data, 1)
require.Equal(t, res1, res2)

// This is only testing that the cache is used, so we do not repeat the other cases here.
})
}

func signAccessToken(t *testing.T, expiresIn time.Duration) string {
Expand Down
Loading