diff --git a/go.mod b/go.mod index 02b9ad252..8d5b1b274 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/mark3labs/mcp-go v0.36.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/migueleliasweb/go-github-mock v1.3.0 + github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 github.com/spf13/cobra v1.10.1 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 @@ -37,7 +38,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 - github.com/google/go-querystring v1.1.0 + github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect diff --git a/go.sum b/go.sum index 1ac8b7606..0ff7b51fa 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,8 @@ github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwX github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= +github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021 h1:31Y+Yu373ymebRdJN1cWLLooHH8xAr0MhKTEJGV/87g= +github.com/muesli/cache2go v0.0.0-20221011235721-518229cd8021/go.mod h1:WERUkUryfUWlrHnFSO/BEUZ+7Ns8aZy7iVOGewxKzcc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index ddecca16d..e41ba74b7 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -6,17 +6,21 @@ import ( "log/slog" "strings" "sync" + "sync/atomic" "time" + "github.com/muesli/cache2go" "github.com/shurcooL/githubv4" ) +var cacheNameCounter atomic.Uint64 + // RepoAccessCache caches repository metadata related to lockdown checks so that // multiple tools can reuse the same access information safely across goroutines. type RepoAccessCache struct { client *githubv4.Client mu sync.Mutex - cache map[string]*repoAccessCacheEntry + cache *cache2go.CacheTable ttl time.Duration logger *slog.Logger } @@ -25,7 +29,6 @@ type repoAccessCacheEntry struct { isPrivate bool knownUsers map[string]bool // normalized login -> has push access ready bool - timer *time.Timer } const defaultRepoAccessTTL = 5 * time.Minute @@ -51,9 +54,11 @@ func WithLogger(logger *slog.Logger) RepoAccessOption { // NewRepoAccessCache returns a cache bound to the provided GitHub GraphQL // client. The cache is safe for concurrent use. func NewRepoAccessCache(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessCache { + // Use a unique cache name for each instance to avoid sharing state between tests + cacheName := fmt.Sprintf("repoAccess-%d", cacheNameCounter.Add(1)) c := &RepoAccessCache{ client: client, - cache: make(map[string]*repoAccessCacheEntry), + cache: cache2go.Cache(cacheName), ttl: defaultRepoAccessTTL, } for _, opt := range opts { @@ -72,8 +77,19 @@ func (c *RepoAccessCache) SetTTL(ttl time.Duration) { defer c.mu.Unlock() c.ttl = ttl c.logInfo("repo access cache TTL updated", "ttl", ttl) - for key, entry := range c.cache { - entry.scheduleExpiry(c, key) + + // Collect all current entries + entries := make(map[interface{}]*repoAccessCacheEntry) + c.cache.Foreach(func(key interface{}, item *cache2go.CacheItem) { + entries[key] = item.Data().(*repoAccessCacheEntry) + }) + + // Flush the cache + c.cache.Flush() + + // Re-add all entries with the new TTL + for key, entry := range entries { + c.cache.Add(key, ttl, entry) } } @@ -103,69 +119,46 @@ func (c *RepoAccessCache) GetRepoAccessInfo(ctx context.Context, username, owner userKey := strings.ToLower(username) c.mu.Lock() defer c.mu.Unlock() - entry := c.ensureEntry(key) - if entry.ready { - if cachedHasPush, known := entry.knownUsers[userKey]; known { - entry.scheduleExpiry(c, key) - c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) - cachedPrivate := entry.isPrivate - return cachedPrivate, cachedHasPush, nil + + // Try to get entry from cache - this will keep the item alive if it exists + cacheItem, err := c.cache.Value(key) + if err == nil { + entry := cacheItem.Data().(*repoAccessCacheEntry) + if entry.ready { + if cachedHasPush, known := entry.knownUsers[userKey]; known { + c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) + return entry.isPrivate, cachedHasPush, nil + } } + // Entry exists but user not in knownUsers, need to query } c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username) - isPrivate, hasPush, err := c.queryRepoAccessInfo(ctx, username, owner, repo) - if err != nil { - return false, false, err + isPrivate, hasPush, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) + if queryErr != nil { + return false, false, queryErr } - entry = c.ensureEntry(key) - entry.ready = true - entry.isPrivate = isPrivate - entry.knownUsers[userKey] = hasPush - entry.scheduleExpiry(c, key) - - return isPrivate, hasPush, nil -} - -func (c *RepoAccessCache) ensureEntry(key string) *repoAccessCacheEntry { - if c.cache == nil { - c.cache = make(map[string]*repoAccessCacheEntry) - } - entry, ok := c.cache[key] - if !ok { + // Get or create entry - don't use Value() here to avoid keeping alive unnecessarily + var entry *repoAccessCacheEntry + if err == nil && cacheItem != nil { + // Entry already existed, just update it + entry = cacheItem.Data().(*repoAccessCacheEntry) + } else { + // Create new entry entry = &repoAccessCacheEntry{ knownUsers: make(map[string]bool), } - c.cache[key] = entry } - return entry -} - -func (entry *repoAccessCacheEntry) scheduleExpiry(c *RepoAccessCache, key string) { - if entry.timer != nil { - entry.timer.Stop() - entry.timer = nil - } - - dur := c.ttl - if dur <= 0 { - return - } - - owner, repo := splitKey(key) - entry.timer = time.AfterFunc(dur, func() { - c.mu.Lock() - defer c.mu.Unlock() - - current, ok := c.cache[key] - if !ok || current != entry { - return - } + + entry.ready = true + entry.isPrivate = isPrivate + entry.knownUsers[userKey] = hasPush + + // Add or update the entry in cache with TTL + c.cache.Add(key, c.ttl, entry) - delete(c.cache, key) - c.logDebug("repo access cache entry evicted", "owner", owner, "repo", repo) - }) + return isPrivate, hasPush, nil } func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (bool, bool, error) {