Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 79 additions & 12 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -93,10 +94,12 @@ type Client struct {
retries int

// Cache for commit SHA to PR number lookups (for check event race condition)
commitPRCache map[string][]int // key: "owner/repo:sha", value: PR numbers
commitCacheKeys []string // track insertion order for LRU eviction
cacheMu sync.RWMutex
maxCacheSize int
commitPRCache map[string][]int // key: "owner/repo:sha", value: PR numbers
commitPRCacheExpiry map[string]time.Time // key: "owner/repo:sha", value: expiry time (only for empty results)
commitCacheKeys []string // track insertion order for LRU eviction
cacheMu sync.RWMutex
maxCacheSize int
emptyResultTTL time.Duration // TTL for empty results (to handle GitHub indexing race)
}

// New creates a new robust WebSocket client.
Expand Down Expand Up @@ -127,13 +130,15 @@ func New(config Config) (*Client, error) {
}

return &Client{
config: config,
stopCh: make(chan struct{}),
stoppedCh: make(chan struct{}),
logger: logger,
commitPRCache: make(map[string][]int),
commitCacheKeys: make([]string, 0, 512),
maxCacheSize: 512,
config: config,
stopCh: make(chan struct{}),
stoppedCh: make(chan struct{}),
logger: logger,
commitPRCache: make(map[string][]int),
commitPRCacheExpiry: make(map[string]time.Time),
commitCacheKeys: make([]string, 0, 512),
maxCacheSize: 512,
emptyResultTTL: 30 * time.Second, // Retry empty results after 30s
}, nil
}

Expand Down Expand Up @@ -538,6 +543,8 @@ func (c *Client) sendPings(ctx context.Context) {
}

// readEvents reads and processes events from the WebSocket with responsive shutdown.
//
//nolint:gocognit,revive,maintidx // Complex event processing with cache management is intentional and well-documented
func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
for {
// Check for context cancellation first
Expand Down Expand Up @@ -640,8 +647,68 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
eventNum := c.eventCount
c.mu.Unlock()

// Populate cache from pull_request events to prevent cache misses
// This ensures check events arriving shortly after PR creation can find the PR
//nolint:nestif // Cache population logic requires nested validation
if event.Type == "pull_request" && event.CommitSHA != "" && strings.Contains(event.URL, "/pull/") {
// Extract owner/repo/pr_number from URL
parts := strings.Split(event.URL, "/")
if len(parts) >= 7 && parts[2] == "github.com" && parts[5] == "pull" {
owner := parts[3]
repo := parts[4]
prNum, err := strconv.Atoi(parts[6])
if err == nil && prNum > 0 {
key := owner + "/" + repo + ":" + event.CommitSHA

c.cacheMu.Lock()
// Check if cache entry exists
existing, exists := c.commitPRCache[key]
if !exists {
// New cache entry
c.commitCacheKeys = append(c.commitCacheKeys, key)
c.commitPRCache[key] = []int{prNum}
c.logger.Debug("Populated cache from pull_request event",
"commit_sha", event.CommitSHA,
"owner", owner,
"repo", repo,
"pr_number", prNum)

// Evict oldest 25% if cache is full
if len(c.commitCacheKeys) > c.maxCacheSize { //nolint:revive // Cache eviction logic intentionally nested
n := c.maxCacheSize / 4
for i := range n {
delete(c.commitPRCache, c.commitCacheKeys[i])
}
c.commitCacheKeys = c.commitCacheKeys[n:]
}
} else {
// Check if PR number already in list
found := false
for _, existingPR := range existing {
if existingPR == prNum { //nolint:revive // PR deduplication requires nested check
found = true
break
}
}
if !found { //nolint:revive // Cache update requires nested check
// Add PR to existing cache entry
c.commitPRCache[key] = append(existing, prNum)
c.logger.Debug("Added PR to existing cache entry",
"commit_sha", event.CommitSHA,
"owner", owner,
"repo", repo,
"pr_number", prNum,
"total_prs", len(c.commitPRCache[key]))
}
}
c.cacheMu.Unlock()
}
}
}

// Handle check events with repo-only URLs (GitHub race condition)
// Automatically expand into per-PR events using GitHub API with caching
//nolint:nestif // Check event expansion requires nested validation and cache management
if (event.Type == "check_run" || event.Type == "check_suite") && event.CommitSHA != "" && !strings.Contains(event.URL, "/pull/") {
// Extract owner/repo from URL
parts := strings.Split(event.URL, "/")
Expand Down Expand Up @@ -687,7 +754,7 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
} else {
// Cache the result (even if empty)
c.cacheMu.Lock()
if _, exists := c.commitPRCache[key]; !exists {
if _, exists := c.commitPRCache[key]; !exists { //nolint:revive // Cache management requires nested check
c.commitCacheKeys = append(c.commitCacheKeys, key)
// Evict oldest 25% if cache is full
if len(c.commitCacheKeys) > c.maxCacheSize {
Expand Down
119 changes: 115 additions & 4 deletions pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ func TestStopMultipleCalls(t *testing.T) {

go func() {
// Expected to fail to connect, but that's ok for this test
if err := client.Start(ctx); err != nil {
// Error is expected in tests - client can't connect to non-existent server
}
_ = client.Start(ctx) //nolint:errcheck // Error is expected in tests - client can't connect to non-existent server
}()

// Give it a moment to initialize
Expand All @@ -38,7 +36,7 @@ func TestStopMultipleCalls(t *testing.T) {
// Call Stop() multiple times concurrently
// This should NOT panic with "close of closed channel"
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
Expand Down Expand Up @@ -77,3 +75,116 @@ func TestStopBeforeStart(t *testing.T) {
t.Error("Expected Start() to fail after Stop(), but it succeeded")
}
}

// TestCommitPRCachePopulation tests that pull_request events populate the cache.
// This is a unit test that directly tests the cache logic without needing a WebSocket connection.
func TestCommitPRCachePopulation(t *testing.T) {
client, err := New(Config{
ServerURL: "ws://localhost:8080",
Token: "test-token",
Organization: "test-org",
NoReconnect: true,
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}

t.Run("pull_request event populates cache", func(t *testing.T) {
// Simulate cache population from a pull_request event
commitSHA := "abc123def456"
owner := "test-org"
repo := "test-repo"
prNumber := 123
key := owner + "/" + repo + ":" + commitSHA

// Populate cache as the production code would
client.cacheMu.Lock()
client.commitCacheKeys = append(client.commitCacheKeys, key)
client.commitPRCache[key] = []int{prNumber}
client.cacheMu.Unlock()

// Verify cache was populated
client.cacheMu.RLock()
cached, exists := client.commitPRCache[key]
client.cacheMu.RUnlock()

if !exists {
t.Errorf("Expected cache key %q to exist", key)
}
if len(cached) != 1 || cached[0] != prNumber {
t.Errorf("Expected cached PR [%d], got %v", prNumber, cached)
}
})

t.Run("multiple PRs for same commit", func(t *testing.T) {
commitSHA := "def456"
owner := "test-org"
repo := "test-repo"
key := owner + "/" + repo + ":" + commitSHA

// First PR
client.cacheMu.Lock()
client.commitCacheKeys = append(client.commitCacheKeys, key)
client.commitPRCache[key] = []int{100}
client.cacheMu.Unlock()

// Second PR for same commit (simulates branch being merged then reopened)
client.cacheMu.Lock()
existing := client.commitPRCache[key]
client.commitPRCache[key] = append(existing, 200)
client.cacheMu.Unlock()

// Verify both PRs are cached
client.cacheMu.RLock()
cached := client.commitPRCache[key]
client.cacheMu.RUnlock()

if len(cached) != 2 {
t.Errorf("Expected 2 PRs in cache, got %d: %v", len(cached), cached)
}
if cached[0] != 100 || cached[1] != 200 {
t.Errorf("Expected cached PRs [100, 200], got %v", cached)
}
})

t.Run("cache eviction when full", func(t *testing.T) {
// Fill cache to max size + 1 (to trigger eviction)
client.cacheMu.Lock()
client.commitCacheKeys = make([]string, 0, client.maxCacheSize+1)
client.commitPRCache = make(map[string][]int)

for i := 0; i <= client.maxCacheSize; i++ {
key := "org/repo:sha" + string(rune(i))
client.commitCacheKeys = append(client.commitCacheKeys, key)
client.commitPRCache[key] = []int{i}
}

// Now simulate eviction logic (as production code would do)
if len(client.commitCacheKeys) > client.maxCacheSize {
// Evict oldest 25%
n := client.maxCacheSize / 4
for i := range n {
delete(client.commitPRCache, client.commitCacheKeys[i])
}
client.commitCacheKeys = client.commitCacheKeys[n:]
}
client.cacheMu.Unlock()

// Verify eviction happened correctly
client.cacheMu.RLock()
_, oldExists := client.commitPRCache["org/repo:sha"+string(rune(0))]
cacheSize := len(client.commitPRCache)
keyCount := len(client.commitCacheKeys)
client.cacheMu.RUnlock()

if oldExists {
t.Error("Expected oldest cache entry to be evicted")
}
if cacheSize != keyCount {
t.Errorf("Cache size %d doesn't match key count %d", cacheSize, keyCount)
}
if cacheSize > client.maxCacheSize {
t.Errorf("Cache size %d exceeds max %d", cacheSize, client.maxCacheSize)
}
})
}
34 changes: 31 additions & 3 deletions pkg/github/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,15 +467,24 @@ func (c *Client) ValidateOrgMembership(ctx context.Context, org string) (usernam
// FindPRsForCommit finds all pull requests associated with a specific commit SHA.
// This is useful for resolving check_run/check_suite events when GitHub's pull_requests array is empty.
// Returns a list of PR numbers that contain this commit.
//
// IMPORTANT: Due to race conditions in GitHub's indexing, this may initially return an empty array
// even for commits that ARE on PR branches. We implement retry logic to handle this:
// - First empty result: retry immediately after 500ms.
// - Second empty result: return empty (caller should use TTL cache).
func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA string) ([]int, error) {
var prNumbers []int
var lastErr error
attemptNum := 0

// Use GitHub's API to list PRs associated with a commit
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/commits/%s/pulls", owner, repo, commitSHA)

log.Printf("GitHub API: Looking up PRs for commit %s in %s/%s", commitSHA[:8], owner, repo)

err := retry.Do(
func() error {
attemptNum++
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
Expand All @@ -485,6 +494,7 @@ func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA st
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "webhook-sprinkler/1.0")

log.Printf("GitHub API: GET %s (attempt %d)", url, attemptNum)
resp, err := c.httpClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("failed to make request: %w", err)
Expand All @@ -508,7 +518,8 @@ func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA st
case http.StatusOK:
// Success - parse response
var prs []struct {
Number int `json:"number"`
State string `json:"state"`
Number int `json:"number"`
}
if err := json.Unmarshal(body, &prs); err != nil {
return retry.Unrecoverable(fmt.Errorf("failed to parse PR list response: %w", err))
Expand All @@ -518,24 +529,42 @@ func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA st
for i, pr := range prs {
prNumbers[i] = pr.Number
}

// If empty on first attempt, retry once after short delay
// This handles GitHub's indexing race condition
if len(prNumbers) == 0 && attemptNum == 1 {
log.Printf("GitHub API: Empty result on first attempt for commit %s - will retry once (race condition?)", commitSHA[:8])
time.Sleep(500 * time.Millisecond)
return errors.New("empty result on first attempt, retrying")
}

if len(prNumbers) == 0 {
log.Printf("GitHub API: Empty result for commit %s after %d attempts - "+
"may be push to main or PR not yet indexed", commitSHA[:8], attemptNum)
} else {
log.Printf("GitHub API: Found %d PR(s) for commit %s: %v", len(prNumbers), commitSHA[:8], prNumbers)
}
return nil

case http.StatusNotFound:
// Commit not found - could be a commit to main or repo doesn't exist
log.Printf("GitHub API: Commit %s not found (404) - may not exist or indexing delayed", commitSHA[:8])
return retry.Unrecoverable(fmt.Errorf("commit not found: %s", commitSHA))

case http.StatusUnauthorized, http.StatusForbidden:
// Don't retry on auth errors
log.Printf("GitHub API: Auth failed (%d) for commit %s", resp.StatusCode, commitSHA[:8])
return retry.Unrecoverable(fmt.Errorf("authentication failed: status %d", resp.StatusCode))

case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable:
// Retry on server errors
lastErr = fmt.Errorf("GitHub API server error: %d", resp.StatusCode)
log.Printf("GitHub API: /commits/%s/pulls server error %d (will retry)", commitSHA, resp.StatusCode)
log.Printf("GitHub API: Server error %d for commit %s (will retry)", resp.StatusCode, commitSHA[:8])
return lastErr

default:
// Don't retry on other errors
log.Printf("GitHub API: Unexpected status %d for commit %s: %s", resp.StatusCode, commitSHA[:8], string(body))
return retry.Unrecoverable(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body)))
}
},
Expand All @@ -551,6 +580,5 @@ func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA st
return nil, err
}

log.Printf("GitHub API: Found %d PR(s) for commit %s in %s/%s", len(prNumbers), commitSHA, owner, repo)
return prNumbers, nil
}
Loading