diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 15b1efc10..970d230ab 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -62,7 +62,7 @@ type MCPServerConfig struct { const stdioServerLogPrefix = "stdioserver" -func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { +func NewMCPServer(cfg MCPServerConfig, logger *slog.Logger) (*server.MCPServer, error) { apiHost, err := parseAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) @@ -88,6 +88,9 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { if cfg.RepoAccessTTL != nil { repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL)) } + + repoAccessLogger := logger.With("component", "lockdown") + repoAccessOpts = append(repoAccessOpts, lockdown.WithLogger(repoAccessLogger)) var repoAccessCache *lockdown.RepoAccessCache if cfg.LockdownMode { repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...) @@ -273,7 +276,7 @@ func RunStdioServer(cfg StdioServerConfig) error { ContentWindowSize: cfg.ContentWindowSize, LockdownMode: cfg.LockdownMode, RepoAccessTTL: cfg.RepoAccessCacheTTL, - }) + }, logger) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go index 4c3500440..80eca07f8 100644 --- a/pkg/lockdown/lockdown.go +++ b/pkg/lockdown/lockdown.go @@ -15,11 +15,12 @@ import ( // RepoAccessCache caches repository metadata related to lockdown checks so that // multiple tools can reuse the same access information safely across goroutines. type RepoAccessCache struct { - client *githubv4.Client - mu sync.Mutex - cache *cache2go.CacheTable - ttl time.Duration - logger *slog.Logger + client *githubv4.Client + mu sync.Mutex + cache *cache2go.CacheTable + ttl time.Duration + logger *slog.Logger + trustedBotLogins map[string]struct{} } type repoAccessCacheEntry struct { @@ -85,6 +86,9 @@ func GetInstance(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessC client: client, cache: cache2go.Cache(defaultRepoAccessCacheKey), ttl: defaultRepoAccessTTL, + trustedBotLogins: map[string]struct{}{ + "copilot": {}, + }, } for _, opt := range opts { if opt != nil { @@ -109,13 +113,22 @@ type CacheStats struct { Evictions int64 } +// IsSafeContent determines if the specified user can safely access the requested repository content. +// Safe access applies when any of the following is true: +// - the content was created by a trusted bot; +// - the author currently has push access to the repository; +// - the repository is private; +// - the content was created by the viewer. func (c *RepoAccessCache) IsSafeContent(ctx context.Context, username, owner, repo string) (bool, error) { repoInfo, err := c.getRepoAccessInfo(ctx, username, owner, repo) if err != nil { - c.logDebug("error checking repo access info for content filtering", "owner", owner, "repo", repo, "user", username, "error", err) return false, err } - if repoInfo.IsPrivate || repoInfo.ViewerLogin == username { + + c.logDebug(ctx, fmt.Sprintf("evaluated repo access for user %s to %s/%s for content filtering, result: hasPushAccess=%t, isPrivate=%t", + username, owner, repo, repoInfo.HasPushAccess, repoInfo.IsPrivate)) + + if c.isTrustedBot(username) || repoInfo.IsPrivate || repoInfo.ViewerLogin == strings.ToLower(username) { return true, nil } return repoInfo.HasPushAccess, nil @@ -136,22 +149,26 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner if err == nil { entry := cacheItem.Data().(*repoAccessCacheEntry) if cachedHasPush, known := entry.knownUsers[userKey]; known { - c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username) + c.logDebug(ctx, fmt.Sprintf("repo access cache hit for user %s to %s/%s", username, owner, repo)) return RepoAccessInfo{ IsPrivate: entry.isPrivate, HasPushAccess: cachedHasPush, ViewerLogin: entry.viewerLogin, }, nil } - c.logDebug("known users cache miss", "owner", owner, "repo", repo, "user", username) + + c.logDebug(ctx, "known users cache miss, fetching from graphql API") + info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) if queryErr != nil { return RepoAccessInfo{}, queryErr } + entry.knownUsers[userKey] = info.HasPushAccess entry.viewerLogin = info.ViewerLogin entry.isPrivate = info.IsPrivate c.cache.Add(key, c.ttl, entry) + return RepoAccessInfo{ IsPrivate: entry.isPrivate, HasPushAccess: entry.knownUsers[userKey], @@ -159,7 +176,7 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner }, nil } - c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username) + c.logDebug(ctx, fmt.Sprintf("repo access cache miss for user %s to %s/%s", username, owner, repo)) info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo) if queryErr != nil { @@ -223,6 +240,9 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own } } + c.logDebug(ctx, fmt.Sprintf("queried repo access info for user %s to %s/%s: isPrivate=%t, hasPushAccess=%t, viewerLogin=%s", + username, owner, repo, bool(query.Repository.IsPrivate), hasPush, query.Viewer.Login)) + return RepoAccessInfo{ IsPrivate: bool(query.Repository.IsPrivate), HasPushAccess: hasPush, @@ -230,12 +250,25 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own }, nil } -func cacheKey(owner, repo string) string { - return fmt.Sprintf("%s/%s", strings.ToLower(owner), strings.ToLower(repo)) +func (c *RepoAccessCache) log(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { + if c == nil || c.logger == nil { + return + } + if !c.logger.Enabled(ctx, level) { + return + } + c.logger.LogAttrs(ctx, level, msg, attrs...) } -func (c *RepoAccessCache) logDebug(msg string, args ...any) { - if c != nil && c.logger != nil { - c.logger.Debug(msg, args...) - } +func (c *RepoAccessCache) logDebug(ctx context.Context, msg string, attrs ...slog.Attr) { + c.log(ctx, slog.LevelDebug, msg, attrs...) +} + +func (c *RepoAccessCache) isTrustedBot(username string) bool { + _, ok := c.trustedBotLogins[strings.ToLower(username)] + return ok +} + +func cacheKey(owner, repo string) string { + return fmt.Sprintf("%s/%s", strings.ToLower(owner), strings.ToLower(repo)) }