Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv
if cfg.RepoAccessTTL != nil {
opts = append(opts, lockdown.WithTTL(*cfg.RepoAccessTTL))
}
repoAccessCache = lockdown.GetInstance(gqlClient, restClient, opts...)
repoAccessCache = lockdown.NewRepoAccessCache(gqlClient, restClient, opts...)
}

return &githubClients{
Expand Down
2 changes: 1 addition & 1 deletion pkg/github/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ func (d *RequestDeps) GetRepoAccessCache(ctx context.Context) (*lockdown.RepoAcc
}

// Create repo access cache
instance := lockdown.GetInstance(gqlClient, restClient, d.RepoAccessOpts...)
instance := lockdown.NewRepoAccessCache(gqlClient, restClient, d.RepoAccessOpts...)
return instance, nil
}

Expand Down
19 changes: 9 additions & 10 deletions pkg/github/issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,15 @@ func (rt *repoAccessMockTransport) RoundTrip(req *http.Request) (*http.Response,
value = repoAccessValue{isPrivate: false}
}

responseBody, err := json.Marshal(map[string]any{
"data": map[string]any{
"viewer": map[string]any{
"login": "test-viewer",
},
"repository": map[string]any{
"isPrivate": value.isPrivate,
},
},
})
data := map[string]any{}
if strings.Contains(payload.Query, "viewer") {
data["viewer"] = map[string]any{"login": "test-viewer"}
}
if strings.Contains(payload.Query, "repository") {
data["repository"] = map[string]any{"isPrivate": value.isPrivate}
}

responseBody, err := json.Marshal(map[string]any{"data": data})
if err != nil {
return nil, err
}
Expand Down
159 changes: 86 additions & 73 deletions pkg/lockdown/lockdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"maps"
"strings"
"sync"
"time"
Expand All @@ -15,39 +16,36 @@ import (

// RepoAccessCache caches repository metadata related to lockdown checks so that
// multiple tools can reuse the same access information safely across goroutines.
// In HTTP mode each request must construct its own instance so viewer-scoped
// lookups run under the requesting user's credentials.
type RepoAccessCache struct {
client *githubv4.Client
restClient *github.Client
mu sync.Mutex
cache *cache2go.CacheTable
ttl time.Duration
logger *slog.Logger
trustedBotLogins map[string]struct{}

viewerMu sync.Mutex
viewerLogin string
}

type repoAccessCacheEntry struct {
isPrivate bool
knownUsers map[string]bool // normalized login -> has push access
viewerLogin string
isPrivate bool
knownUsers map[string]bool // normalized login -> has push access
}

// RepoAccessInfo captures repository metadata needed for lockdown decisions.
type RepoAccessInfo struct {
IsPrivate bool
HasPushAccess bool
Comment thread
kerobbi marked this conversation as resolved.
ViewerLogin string
}

const (
defaultRepoAccessTTL = 20 * time.Minute
defaultRepoAccessCacheKey = "repo-access-cache"
)

var (
instance *RepoAccessCache
instanceMu sync.Mutex
)

// RepoAccessOption configures RepoAccessCache at construction time.
type RepoAccessOption func(*RepoAccessCache)

Expand All @@ -66,8 +64,8 @@ func WithLogger(logger *slog.Logger) RepoAccessOption {
}
}

// WithCacheName overrides the cache table name used for storing entries. This option is intended for tests
// that need isolated cache instances.
// WithCacheName overrides the cache table name used for storing entries.
// Use this to isolate cache entries between tenants or in tests.
func WithCacheName(name string) RepoAccessOption {
return func(c *RepoAccessCache) {
if name != "" {
Expand All @@ -76,25 +74,8 @@ func WithCacheName(name string) RepoAccessOption {
}
}

// GetInstance returns the singleton instance of RepoAccessCache.
// It initializes the instance on first call with the provided client and options.
// Subsequent calls ignore the client and options parameters and return the existing instance.
// This is the preferred way to access the cache in production code.
func GetInstance(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache {
instanceMu.Lock()
defer instanceMu.Unlock()
if instance == nil {
instance = newRepoAccessCache(client, restClient, opts...)
}
return instance
}

// NewRepoAccessCache creates a standalone cache instance, used for tests.
// NewRepoAccessCache creates a RepoAccessCache bound to the supplied clients.
func NewRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache {
Comment thread
kerobbi marked this conversation as resolved.
return newRepoAccessCache(client, restClient, opts...)
}

func newRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts ...RepoAccessOption) *RepoAccessCache {
c := &RepoAccessCache{
client: client,
restClient: restClient,
Expand All @@ -113,13 +94,6 @@ func newRepoAccessCache(client *githubv4.Client, restClient *github.Client, opts
return c
}

// SetLogger updates the logger used for cache diagnostics.
func (c *RepoAccessCache) SetLogger(logger *slog.Logger) {
c.mu.Lock()
c.logger = logger
c.mu.Unlock()
}

// CacheStats summarizes cache activity counters.
Comment thread
kerobbi marked this conversation as resolved.
type CacheStats struct {
Hits int64
Expand Down Expand Up @@ -150,10 +124,55 @@ func (c *RepoAccessCache) IsSafeContent(ctx context.Context, username, owner, re
c.logDebug(ctx, fmt.Sprintf("evaluated repo access for user %s to %s/%s for content filtering, result: hasPushAccess=%t, isPrivate=%t",
username, owner, repo, repoInfo.HasPushAccess, repoInfo.IsPrivate))

if repoInfo.IsPrivate || repoInfo.ViewerLogin == strings.ToLower(username) {
if repoInfo.IsPrivate {
return true, nil
}
if repoInfo.HasPushAccess {
return true, nil
}
return repoInfo.HasPushAccess, nil

viewerLogin, err := c.viewerLoginFor(ctx)
if err != nil {
return false, err
}
return viewerLogin == strings.ToLower(username), nil
}

func (c *RepoAccessCache) viewerLoginFor(ctx context.Context) (string, error) {
c.viewerMu.Lock()
defer c.viewerMu.Unlock()
if c.viewerLogin != "" {
return c.viewerLogin, nil
}
if c.client == nil {
return "", fmt.Errorf("nil GraphQL client")
}
var query struct {
Viewer struct {
Login githubv4.String
}
}
if err := c.client.Query(ctx, &query, nil); err != nil {
return "", fmt.Errorf("failed to query viewer login: %w", err)
}
login := strings.ToLower(string(query.Viewer.Login))
if login == "" {
return "", fmt.Errorf("viewer login returned empty")
}
c.viewerLogin = login
return c.viewerLogin, nil
}

// setViewerLogin seeds the cached viewer login from a piggy-backed query response.
func (c *RepoAccessCache) setViewerLogin(login string) {
if login == "" {
return
}
c.viewerMu.Lock()
defer c.viewerMu.Unlock()
if c.viewerLogin == "" {
c.viewerLogin = strings.ToLower(login)
}
}

func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) {
Expand All @@ -163,19 +182,16 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner

key := cacheKey(owner, repo)
userKey := strings.ToLower(username)
c.mu.Lock()
defer c.mu.Unlock()

// Try to get entry from cache - this will keep the item alive if it exists
cacheItem, err := c.cache.Value(key)
if err == nil {
// Entries are immutable once added: the cache table is shared across instances,
// so we publish a fresh entry with a cloned knownUsers map on every miss.
if cacheItem, err := c.cache.Value(key); err == nil {
entry := cacheItem.Data().(*repoAccessCacheEntry)
if cachedHasPush, known := entry.knownUsers[userKey]; known {
c.logDebug(ctx, fmt.Sprintf("repo access cache hit for user %s to %s/%s", username, owner, repo))
return RepoAccessInfo{
IsPrivate: entry.isPrivate,
HasPushAccess: cachedHasPush,
ViewerLogin: entry.viewerLogin,
}, nil
}

Expand All @@ -186,41 +202,48 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner
return RepoAccessInfo{}, pushErr
}

entry.knownUsers[userKey] = hasPush
c.cache.Add(key, c.ttl, entry)
users := make(map[string]bool, len(entry.knownUsers)+1)
maps.Copy(users, entry.knownUsers)
users[userKey] = hasPush
c.cache.Add(key, c.ttl, &repoAccessCacheEntry{
isPrivate: entry.isPrivate,
knownUsers: users,
})

return RepoAccessInfo{
IsPrivate: entry.isPrivate,
HasPushAccess: hasPush,
ViewerLogin: entry.viewerLogin,
}, nil
}

c.logDebug(ctx, fmt.Sprintf("repo access cache miss for user %s to %s/%s", username, owner, repo))

info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo)
isPrivate, viewerLogin, queryErr := c.queryRepoAccessInfo(ctx, owner, repo)
if queryErr != nil {
return RepoAccessInfo{}, queryErr
}
c.setViewerLogin(viewerLogin)

// Create new entry
entry := &repoAccessCacheEntry{
knownUsers: map[string]bool{userKey: info.HasPushAccess},
isPrivate: info.IsPrivate,
viewerLogin: info.ViewerLogin,
hasPush, pushErr := c.checkPushAccess(ctx, username, owner, repo)
if pushErr != nil {
return RepoAccessInfo{}, pushErr
}
c.cache.Add(key, c.ttl, entry)

c.cache.Add(key, c.ttl, &repoAccessCacheEntry{
knownUsers: map[string]bool{userKey: hasPush},
isPrivate: isPrivate,
})

return RepoAccessInfo{
IsPrivate: entry.isPrivate,
HasPushAccess: entry.knownUsers[userKey],
ViewerLogin: entry.viewerLogin,
IsPrivate: isPrivate,
HasPushAccess: hasPush,
}, nil
}

func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, owner, repo string) (RepoAccessInfo, error) {
// queryRepoAccessInfo fetches repository visibility and the viewer login in a single GraphQL round-trip.
func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, owner, repo string) (bool, string, error) {
if c.client == nil {
return RepoAccessInfo{}, fmt.Errorf("nil GraphQL client")
return false, "", fmt.Errorf("nil GraphQL client")
}

var query struct {
Expand All @@ -238,22 +261,12 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own
}

if err := c.client.Query(ctx, &query, variables); err != nil {
return RepoAccessInfo{}, fmt.Errorf("failed to query repository metadata: %w", err)
}

hasPush, err := c.checkPushAccess(ctx, username, owner, repo)
if err != nil {
return RepoAccessInfo{}, err
return false, "", fmt.Errorf("failed to query repository metadata: %w", err)
}

c.logDebug(ctx, fmt.Sprintf("queried repo access info for user %s to %s/%s: isPrivate=%t, hasPushAccess=%t, viewerLogin=%s",
username, owner, repo, bool(query.Repository.IsPrivate), hasPush, query.Viewer.Login))
c.logDebug(ctx, fmt.Sprintf("queried repo access info for %s/%s: isPrivate=%t", owner, repo, bool(query.Repository.IsPrivate)))

return RepoAccessInfo{
IsPrivate: bool(query.Repository.IsPrivate),
HasPushAccess: hasPush,
ViewerLogin: string(query.Viewer.Login),
}, nil
return bool(query.Repository.IsPrivate), string(query.Viewer.Login), nil
}

// checkPushAccess checks if the user has push access to the repository via the REST permission endpoint.
Expand Down
Loading
Loading