diff --git a/expirable/expirable_lru.go b/expirable/expirable_lru.go index 89978d6..36a7678 100644 --- a/expirable/expirable_lru.go +++ b/expirable/expirable_lru.go @@ -120,6 +120,10 @@ func (c *LRU[K, V]) Purge() { func (c *LRU[K, V]) Add(key K, value V) (evicted bool) { c.mu.Lock() defer c.mu.Unlock() + return c.addWithLock(key, value) +} + +func (c *LRU[K, V]) addWithLock(key K, value V) (evicted bool) { now := time.Now() // Check for existing item @@ -149,6 +153,11 @@ func (c *LRU[K, V]) Add(key K, value V) (evicted bool) { func (c *LRU[K, V]) Get(key K) (value V, ok bool) { c.mu.Lock() defer c.mu.Unlock() + return c.getWithLock(key) +} + +// Get looks up a key's value from the cache. +func (c *LRU[K, V]) getWithLock(key K) (value V, ok bool) { var ent *internal.Entry[K, V] if ent, ok = c.items[key]; ok { // Expired item check @@ -161,6 +170,33 @@ func (c *LRU[K, V]) Get(key K) (value V, ok bool) { return } +// ConstructorFunc provides a function signature for methods like GetOrAddFunc. +// GetOrAddFunc will invoke this function parameter to generate an item if +// no matching item is found in the LRU. This allows a new item to be +// initialized only if not found in the LRU, in a thread-safe manner. +type ConstructorFunc[V any] func() (V, error) + +// GetOrAddFunc looks up a key's value from the cache. If not present, the +// ConstructorFunc argument is executed to create an entry and add it. +func (c *LRU[K, V]) GetOrAddFunc(key K, fn ConstructorFunc[V]) (value V, added bool, evicted bool, err error) { + c.mu.Lock() + defer c.mu.Unlock() + + if existingValue, exists := c.getWithLock(key); exists { + return existingValue, added, evicted, nil + } else { + if fn != nil { + value, err = fn() + if err != nil { + return value, false, false, err + } + } + added = true + evicted = c.addWithLock(key, value) + return value, added, evicted, nil + } +} + // Contains checks if a key is in the cache, without updating the recent-ness // or deleting it for being stale. func (c *LRU[K, V]) Contains(key K) (ok bool) { diff --git a/expirable/expirable_lru_test.go b/expirable/expirable_lru_test.go index fd3b255..eb5efc0 100644 --- a/expirable/expirable_lru_test.go +++ b/expirable/expirable_lru_test.go @@ -10,6 +10,7 @@ import ( "math/big" "reflect" "sync" + "sync/atomic" "testing" "time" @@ -335,6 +336,36 @@ func TestLRUConcurrency(t *testing.T) { } } +func TestLRUGetOrAddFuncConcurrency(t *testing.T) { + lc := NewLRU[string, string](0, nil, 0) + wg := sync.WaitGroup{} + wg.Add(1000) + var evictedCount int32 + var addedCount int32 + for i := 0; i < 1000; i++ { + go func(i int) { + _, added, evicted, _ := lc.GetOrAddFunc(fmt.Sprintf("key-%d", i/10), func() (string, error) { return fmt.Sprintf("val-%d", i/10), nil }) + if evicted { + atomic.AddInt32(&evictedCount, 1) + } + if added { + atomic.AddInt32(&addedCount, 1) + } + wg.Done() + }(i) + } + wg.Wait() + if lc.Len() != 100 { + t.Fatalf("length differs from expected") + } + if addedCount != 100 { + t.Fatalf("GetOrAddFunc: unexpected added count %d", addedCount) + } + if evictedCount > 0 { + t.Fatalf("GetOrAddFunc: unexpected evicted count %d", evictedCount) + } +} + func TestLRUInvalidateAndEvict(t *testing.T) { var evicted int lc := NewLRU(-1, func(_, _ string) { evicted++ }, 0)