CachingClientFactory: lock by client cache key (#716)
Previously, all calls to Get() were serialized using a single lock. That
approach does not scale as the number of Vault clients grows. With this
change, locking is done per ClientCacheKey, thereby reducing the overall
contention for Vault clients.
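
For context, a minimal standalone sketch of the per-cache-key locking pattern this commit introduces is shown below. The type and method names (keyedLocker, lockFor, remove) are illustrative only and are simplified from the clientLock/removeClientLock helpers in the diff that follows.

package keyedlock

import "sync"

// CacheKey identifies a cached Vault client; it stands in for ClientCacheKey.
type CacheKey string

// keyedLocker hands out one RWMutex per cache key so that callers working with
// different clients no longer contend on a single factory-wide lock.
type keyedLocker struct {
	mu    sync.RWMutex               // guards the locks map itself
	locks map[CacheKey]*sync.RWMutex // one lock per cache key
}

func newKeyedLocker() *keyedLocker {
	return &keyedLocker{locks: make(map[CacheKey]*sync.RWMutex)}
}

// lockFor returns the lock for key, creating it on first use. The boolean
// reports whether the lock already existed in the map.
func (k *keyedLocker) lockFor(key CacheKey) (*sync.RWMutex, bool) {
	k.mu.Lock()
	defer k.mu.Unlock()
	l, ok := k.locks[key]
	if !ok {
		l = &sync.RWMutex{}
		k.locks[key] = l
	}
	return l, ok
}

// remove drops the lock for key, e.g. after the corresponding client is
// evicted from the cache.
func (k *keyedLocker) remove(key CacheKey) {
	k.mu.Lock()
	defer k.mu.Unlock()
	delete(k.locks, key)
}

A caller would acquire the per-key lock with lock, _ := locker.lockFor(key); lock.Lock(); defer lock.Unlock() around the per-client work, which mirrors how Get acquires its lock in the diff below.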
benashz committed May 15, 2024
1 parent cd8dee6 commit 3ad6d40
Showing 2 changed files with 144 additions and 2 deletions.
46 changes: 44 additions & 2 deletions internal/vault/client_factory.go
@@ -102,6 +102,15 @@ type cachingClientFactory struct {
mu sync.RWMutex
onceDoWatcher sync.Once
callbackHandlerCancel context.CancelFunc
// clientLocksLock is a lock for the clientLocks map.
clientLocksLock sync.RWMutex
// clientLocks is a map of cache keys to locks that allow for concurrent access
// to the client factory's cache.
clientLocks map[ClientCacheKey]*sync.RWMutex
// encClientLock is a lock for the encryption client. It is used to ensure that
// only one encryption client is created. This is necessary because the
// encryption client is not stored in the cache.
encClientLock sync.RWMutex
}

// Start method for cachingClientFactory starts the lifetime watcher handler.
@@ -215,6 +224,8 @@ func (m *cachingClientFactory) onClientEvict(ctx context.Context, client ctrlcli
logger.Info("Pruned storage", "count", count)
}
}

m.removeClientLock(cacheKey)
}

// Restore will attempt to restore a Client from storage. If storage is not enabled then no restoration will take place.
@@ -278,6 +289,23 @@ func (m *cachingClientFactory) isDisabled() bool {
return m.shutDown
}

func (m *cachingClientFactory) clientLock(cacheKey ClientCacheKey) (*sync.RWMutex, bool) {
m.clientLocksLock.Lock()
defer m.clientLocksLock.Unlock()
lock, ok := m.clientLocks[cacheKey]
if !ok {
lock = &sync.RWMutex{}
m.clientLocks[cacheKey] = lock
}
return lock, ok
}

func (m *cachingClientFactory) removeClientLock(cacheKey ClientCacheKey) {
m.clientLocksLock.Lock()
defer m.clientLocksLock.Unlock()
delete(m.clientLocks, cacheKey)
}

// Get is meant to be called for all resources that require access to Vault.
// It will attempt to fetch a Client from the in-memory cache for the provided Object.
// On a cache miss, an attempt at restoration from storage will be made, if a restoration attempt fails,
@@ -286,8 +314,6 @@ func (m *cachingClientFactory) isDisabled() bool {
//
// Supported types for obj are: VaultDynamicSecret, VaultStaticSecret, VaultPKISecret
func (m *cachingClientFactory) Get(ctx context.Context, client ctrlclient.Client, obj ctrlclient.Object) (Client, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.isDisabled() {
return nil, &ClientFactoryDisabledError{}
}
@@ -314,7 +340,19 @@ func (m *cachingClientFactory) Get(ctx context.Context, client ctrlclient.Client
return nil, errs
}

m.mu.RLock()
defer m.mu.RUnlock()

lock, cachedLock := m.clientLock(cacheKey)
lock.Lock()
defer lock.Unlock()

logger = logger.WithValues("cacheKey", cacheKey)
logger.V(consts.LogLevelDebug).Info("Got lock",
"numLocks", len(m.clientLocks),
"cachedLock", cachedLock,
)

logger.V(consts.LogLevelDebug).Info("Get Client")
ns, err := common.GetVaultNamespace(obj)
if err != nil {
@@ -546,6 +584,9 @@ func (m *cachingClientFactory) cacheClient(ctx context.Context, c Client, persis
// The result is cached in the ClientCache for future needs. This should only ever be needed if the ClientCacheStorage
// has enforceEncryption enabled.
func (m *cachingClientFactory) storageEncryptionClient(ctx context.Context, client ctrlclient.Client) (Client, error) {
m.encClientLock.Lock()
defer m.encClientLock.Unlock()

cached := m.clientCacheKeyEncrypt != ""
if !cached {
m.logger.Info("Setting up Vault Client for storage encryption",
@@ -681,6 +722,7 @@ func NewCachingClientFactory(ctx context.Context, client ctrlclient.Client, cach
ctrlClient: client,
callbackHandlerCh: make(chan Client),
encryptionRequired: config.StorageConfig.EnforceEncryption,
clientLocks: make(map[ClientCacheKey]*sync.RWMutex, config.ClientCacheSize),
logger: zap.New().WithName("clientCacheFactory").WithValues(
"persist", config.Persist,
"enforceEncryption", config.StorageConfig.EnforceEncryption,
100 changes: 100 additions & 0 deletions internal/vault/client_factory_test.go
@@ -7,6 +7,7 @@ import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -74,3 +75,102 @@ func Test_cachingClientFactory_RegisterClientCallbackHandler(t *testing.T) {
})
}
}

func Test_cachingClientFactory_clientLocks(t *testing.T) {
t.Parallel()

tests := []struct {
name string
cacheKey ClientCacheKey
tryLockCount int
wantInLocks bool
clientLocks map[ClientCacheKey]*sync.RWMutex
}{
{
name: "single-new",
cacheKey: ClientCacheKey("single"),
tryLockCount: 1,
wantInLocks: false,
},
{
name: "single-existing",
cacheKey: ClientCacheKey("single-existing"),
clientLocks: map[ClientCacheKey]*sync.RWMutex{
ClientCacheKey("single-existing"): {},
},
tryLockCount: 1,
wantInLocks: true,
},
{
name: "concurrent-new",
cacheKey: ClientCacheKey("concurrent-new"),
tryLockCount: 10,
wantInLocks: false,
},
{
name: "concurrent-existing",
cacheKey: ClientCacheKey("concurrent-existing"),
clientLocks: map[ClientCacheKey]*sync.RWMutex{
ClientCacheKey("concurrent-existing"): {},
},
tryLockCount: 10,
wantInLocks: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Greater(t, tt.tryLockCount, 0, "no test tryLockCount provided")

if tt.clientLocks == nil {
tt.clientLocks = make(map[ClientCacheKey]*sync.RWMutex)
}

m := &cachingClientFactory{
clientLocks: tt.clientLocks,
}

got, inLocks := m.clientLock(tt.cacheKey)
if !tt.wantInLocks {
assert.Equal(t, got, tt.clientLocks[tt.cacheKey])
}
require.Equal(t, tt.wantInLocks, inLocks)

// holdLockDuration is the duration each locker will hold the lock for after it
// is acquired.
holdLockDuration := 2 * time.Millisecond
// ctxTimeout is the total time to wait for all lockers to acquire the lock once.
ctxTimeout := time.Duration(tt.tryLockCount) * (holdLockDuration * 2)
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
go func() {
defer cancel()
time.Sleep(ctxTimeout)
}()

wg := sync.WaitGroup{}
wg.Add(tt.tryLockCount)
for i := 0; i < tt.tryLockCount; i++ {
go func(ctx context.Context) {
defer wg.Done()
lck, _ := m.clientLock(tt.cacheKey)
lck.Lock()
defer lck.Unlock()
assert.Equal(t, got, lck)

lockTimer := time.NewTimer(holdLockDuration)
defer lockTimer.Stop()
select {
case <-lockTimer.C:
return
case <-ctx.Done():
assert.NoError(t, ctx.Err(), "timeout waiting for lock")
return
}
}(ctx)
}
wg.Wait()

assert.NoError(t, ctx.Err(),
"context timeout waiting for all lockers")
})
}
}
