CachingClientFactory: lock by client cache key #716

Merged: 3 commits, May 15, 2024

46 changes: 44 additions & 2 deletions internal/vault/client_factory.go
@@ -99,6 +99,15 @@ type cachingClientFactory struct {
mu sync.RWMutex
onceDoWatcher sync.Once
callbackHandlerCancel context.CancelFunc
// clientLocksLock is a lock for the clientLocks map.
clientLocksLock sync.RWMutex
// clientLocks is a map of cache keys to locks that allow for concurrent access
// to the client factory's cache.
clientLocks map[ClientCacheKey]*sync.RWMutex
// encClientLock is a lock for the encryption client. It is used to ensure that
// only one encryption client is created. This is necessary because the
// encryption client is not stored in the cache.
encClientLock sync.RWMutex
}

// Start method for cachingClientFactory starts the lifetime watcher handler.
@@ -207,6 +216,8 @@ func (m *cachingClientFactory) onClientEvict(ctx context.Context, client ctrlcli
logger.Info("Pruned storage", "count", count)
}
}

m.removeClientLock(cacheKey)
}

// Restore will attempt to restore a Client from storage. If storage is not enabled then no restoration will take place.
@@ -270,6 +281,23 @@ func (m *cachingClientFactory) isDisabled() bool {
return m.shutDown
}

func (m *cachingClientFactory) clientLock(cacheKey ClientCacheKey) (*sync.RWMutex, bool) {
m.clientLocksLock.Lock()
defer m.clientLocksLock.Unlock()
lock, ok := m.clientLocks[cacheKey]
if !ok {
lock = &sync.RWMutex{}
m.clientLocks[cacheKey] = lock
}
return lock, ok
}

func (m *cachingClientFactory) removeClientLock(cacheKey ClientCacheKey) {
m.clientLocksLock.Lock()
defer m.clientLocksLock.Unlock()
delete(m.clientLocks, cacheKey)
}

// Get is meant to be called for all resources that require access to Vault.
// It will attempt to fetch a Client from the in-memory cache for the provided Object.
// On a cache miss, an attempt at restoration from storage will be made, if a restoration attempt fails,
@@ -278,8 +306,6 @@ func (m *cachingClientFactory) isDisabled() bool {
//
// Supported types for obj are: VaultDynamicSecret, VaultStaticSecret, VaultPKISecret
func (m *cachingClientFactory) Get(ctx context.Context, client ctrlclient.Client, obj ctrlclient.Object) (Client, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.isDisabled() {
return nil, &ClientFactoryDisabledError{}
}
@@ -306,7 +332,19 @@ func (m *cachingClientFactory) Get(ctx context.Context, client ctrlclient.Client
return nil, errs
}

m.mu.RLock()
defer m.mu.RUnlock()

lock, cachedLock := m.clientLock(cacheKey)
lock.Lock()
defer lock.Unlock()

logger = logger.WithValues("cacheKey", cacheKey)
logger.V(consts.LogLevelDebug).Info("Got lock",
"numLocks", len(m.clientLocks),
"cachedLock", cachedLock,
)

logger.V(consts.LogLevelDebug).Info("Get Client")
ns, err := common.GetVaultNamespace(obj)
if err != nil {
@@ -538,6 +576,9 @@ func (m *cachingClientFactory) cacheClient(ctx context.Context, c Client, persis
// The result is cached in the ClientCache for future needs. This should only ever be needed if the ClientCacheStorage
// has enforceEncryption enabled.
func (m *cachingClientFactory) storageEncryptionClient(ctx context.Context, client ctrlclient.Client) (Client, error) {
m.encClientLock.Lock()
defer m.encClientLock.Unlock()

cached := m.clientCacheKeyEncrypt != ""
if !cached {
m.logger.Info("Setting up Vault Client for storage encryption",
@@ -673,6 +714,7 @@ func NewCachingClientFactory(ctx context.Context, client ctrlclient.Client, cach
ctrlClient: client,
callbackHandlerCh: make(chan Client),
encryptionRequired: config.StorageConfig.EnforceEncryption,
clientLocks: make(map[ClientCacheKey]*sync.RWMutex, config.ClientCacheSize),
logger: zap.New().WithName("clientCacheFactory").WithValues(
"persist", config.Persist,
"enforceEncryption", config.StorageConfig.EnforceEncryption,
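
For readers skimming the diff, the core of this change is the lazy, per-cache-key lock. Below is a minimal standalone sketch of that pattern (illustrative names such as lockRegistry and cacheKey, not the operator's actual types): a map of RWMutexes guarded by its own mutex, a getter that creates a lock on first use and reports whether it already existed, and a remover mirroring what onClientEvict does via removeClientLock.

package main

import (
	"fmt"
	"sync"
)

// cacheKey stands in for the operator's ClientCacheKey type.
type cacheKey string

// lockRegistry mirrors the fields added to cachingClientFactory:
// a per-key lock map protected by its own mutex.
type lockRegistry struct {
	mu    sync.RWMutex
	locks map[cacheKey]*sync.RWMutex
}

func newLockRegistry() *lockRegistry {
	return &lockRegistry{locks: make(map[cacheKey]*sync.RWMutex)}
}

// lock returns the lock for key, creating it on first use. The second
// return value reports whether the lock already existed, matching the
// shape of cachingClientFactory.clientLock in the diff above.
func (r *lockRegistry) lock(key cacheKey) (*sync.RWMutex, bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	l, ok := r.locks[key]
	if !ok {
		l = &sync.RWMutex{}
		r.locks[key] = l
	}
	return l, ok
}

// remove drops the lock for key, as onClientEvict does via removeClientLock.
func (r *lockRegistry) remove(key cacheKey) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.locks, key)
}

func main() {
	r := newLockRegistry()

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Work on different keys proceeds concurrently; work on the
			// same key is serialized by that key's lock.
			l, existed := r.lock(cacheKey(fmt.Sprintf("tenant-%d", i%2)))
			l.Lock()
			defer l.Unlock()
			fmt.Printf("goroutine %d holds lock for its key (existed=%v)\n", i, existed)
		}(i)
	}
	wg.Wait()

	r.remove("tenant-0")
	r.remove("tenant-1")
}

The registry mutex is held only while fetching or creating an entry, never while blocking on a per-key lock; that is what lets Get serialize work per cache key without reintroducing a factory-wide bottleneck.
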
100 changes: 100 additions & 0 deletions internal/vault/client_factory_test.go
@@ -7,6 +7,7 @@ import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -74,3 +75,102 @@ func Test_cachingClientFactory_RegisterClientCallbackHandler(t *testing.T) {
})
}
}

func Test_cachingClientFactory_clientLocks(t *testing.T) {
t.Parallel()

tests := []struct {
name string
cacheKey ClientCacheKey
tryLockCount int
wantInLocks bool
clientLocks map[ClientCacheKey]*sync.RWMutex
}{
{
name: "single-new",
cacheKey: ClientCacheKey("single"),
tryLockCount: 1,
wantInLocks: false,
},
{
name: "single-existing",
cacheKey: ClientCacheKey("single-existing"),
clientLocks: map[ClientCacheKey]*sync.RWMutex{
ClientCacheKey("single-existing"): {},
},
tryLockCount: 1,
wantInLocks: true,
},
{
name: "concurrent-new",
cacheKey: ClientCacheKey("concurrent-new"),
tryLockCount: 10,
wantInLocks: false,
},
{
name: "concurrent-existing",
cacheKey: ClientCacheKey("concurrent-existing"),
clientLocks: map[ClientCacheKey]*sync.RWMutex{
ClientCacheKey("concurrent-existing"): {},
},
tryLockCount: 10,
wantInLocks: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Greater(t, tt.tryLockCount, 0, "no test tryLockCount provided")

if tt.clientLocks == nil {
tt.clientLocks = make(map[ClientCacheKey]*sync.RWMutex)
}

m := &cachingClientFactory{
clientLocks: tt.clientLocks,
}

got, inLocks := m.clientLock(tt.cacheKey)
if !tt.wantInLocks {
assert.Equal(t, got, tt.clientLocks[tt.cacheKey])
}
require.Equal(t, tt.wantInLocks, inLocks)

// holdLockDuration is the duration each locker will hold the lock for after it
// is acquired.
holdLockDuration := 2 * time.Millisecond
// ctxTimeout is the total time to wait for all lockers to acquire the lock once.
ctxTimeout := time.Duration(tt.tryLockCount) * (holdLockDuration * 2)
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
go func() {
defer cancel()
time.Sleep(ctxTimeout)
}()

wg := sync.WaitGroup{}
wg.Add(tt.tryLockCount)
for i := 0; i < tt.tryLockCount; i++ {
go func(ctx context.Context) {
defer wg.Done()
lck, _ := m.clientLock(tt.cacheKey)
lck.Lock()
defer lck.Unlock()
assert.Equal(t, got, lck)

lockTimer := time.NewTimer(holdLockDuration)
defer lockTimer.Stop()
select {
case <-lockTimer.C:
return
case <-ctx.Done():
assert.NoError(t, ctx.Err(), "timeout waiting for lock")
return
}
}(ctx)
}
wg.Wait()

assert.NoError(t, ctx.Err(),
"context timeout waiting for all lockers")
})
}
}
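
The test above covers concurrent acquisition only. As a further illustration, a hypothetical follow-up test (not part of this PR; it assumes the same package and imports as client_factory_test.go above) could verify that removeClientLock drops the per-key lock so a later clientLock call returns a fresh one:

func Test_cachingClientFactory_removeClientLock(t *testing.T) {
	t.Parallel()

	m := &cachingClientFactory{
		clientLocks: make(map[ClientCacheKey]*sync.RWMutex),
	}

	key := ClientCacheKey("evicted")
	first, cached := m.clientLock(key)
	require.False(t, cached, "lock should be created on first use")
	require.NotNil(t, first)

	// Simulate the eviction path, which calls removeClientLock from onClientEvict.
	m.removeClientLock(key)

	second, cached := m.clientLock(key)
	require.False(t, cached, "lock should have been dropped on eviction")
	assert.NotSame(t, first, second, "a new lock is created after removal")
}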