From fb513d23850b615b48a8fdd3e25829a68dfa7066 Mon Sep 17 00:00:00 2001 From: Thomas Stromberg Date: Fri, 24 Oct 2025 15:38:03 -0400 Subject: [PATCH 1/2] fix panic --- cmd/server/main.go | 21 +++++++++++---------- pkg/client/client.go | 30 +++++++++++++++++++++--------- pkg/security/connlimiter.go | 4 ++-- pkg/security/race_test.go | 6 +++--- pkg/srv/websocket.go | 23 +++++++---------------- pkg/webhook/handler.go | 32 +++++++++++--------------------- 6 files changed, 55 insertions(+), 61 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 851def5..220cfb2 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -32,13 +32,10 @@ const ( minMaskHeaderLength = 20 // Minimum header length before we show full "[REDACTED]" ) -// getEnvOrDefault returns the value of the environment variable or the default if not set. -func getEnvOrDefault(key, defaultValue string) string { - if value := os.Getenv(key); value != "" { - return value - } - return defaultValue -} +// contextKey is a custom type for context keys to avoid collisions. +type contextKey string + +const reservationTokenKey contextKey = "reservation_token" var ( webhookSecret = flag.String("webhook-secret", os.Getenv("GITHUB_WEBHOOK_SECRET"), "GitHub webhook secret for signature verification") @@ -50,8 +47,12 @@ var ( maxConnsPerIP = flag.Int("max-conns-per-ip", 10, "Maximum WebSocket connections per IP") maxConnsTotal = flag.Int("max-conns-total", 1000, "Maximum total WebSocket connections") rateLimit = flag.Int("rate-limit", 100, "Maximum requests per minute per IP") - allowedEvents = flag.String("allowed-events", getEnvOrDefault("ALLOWED_WEBHOOK_EVENTS", "*"), - "Comma-separated list of allowed webhook event types (use '*' for all, default: '*')") + allowedEvents = flag.String("allowed-events", func() string { + if value := os.Getenv("ALLOWED_WEBHOOK_EVENTS"); value != "" { + return value + } + return "*" + }(), "Comma-separated list of allowed webhook event types (use '*' for all, default: '*')") debugHeaders = flag.Bool("debug-headers", false, "Log request headers for debugging (security warning: may log sensitive data)") ) @@ -244,7 +245,7 @@ func main() { } // Set reservation token in request context so websocket handler can commit it - r = r.WithContext(context.WithValue(r.Context(), "reservation_token", reservationToken)) + r = r.WithContext(context.WithValue(r.Context(), reservationTokenKey, reservationToken)) // Log successful auth and proceed to upgrade log.Printf("WebSocket UPGRADE: ip=%s duration=%v", ip, time.Since(startTime)) diff --git a/pkg/client/client.go b/pkg/client/client.go index 629d385..9ba29c3 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -85,7 +85,8 @@ type Client struct { ws *websocket.Conn stopCh chan struct{} stoppedCh chan struct{} - writeCh chan any // Channel for serializing all writes + stopOnce sync.Once // Ensures Stop() is only executed once + writeCh chan any // Channel for serializing all writes eventCount int retries int } @@ -220,16 +221,27 @@ func (c *Client) Start(ctx context.Context) error { } // Stop gracefully stops the client. +// Safe to call multiple times - only the first call will take effect. +// Also safe to call before Start() or if Start() was never called. func (c *Client) Stop() { - close(c.stopCh) - c.mu.Lock() - if c.ws != nil { - if closeErr := c.ws.Close(); closeErr != nil { - c.logger.Error("Error closing websocket on shutdown", "error", closeErr) + c.stopOnce.Do(func() { + close(c.stopCh) + c.mu.Lock() + if c.ws != nil { + if closeErr := c.ws.Close(); closeErr != nil { + c.logger.Error("Error closing websocket on shutdown", "error", closeErr) + } } - } - c.mu.Unlock() - <-c.stoppedCh + c.mu.Unlock() + + // Wait for Start() to finish, but with timeout in case Start() was never called + select { + case <-c.stoppedCh: + // Start() completed normally + case <-time.After(100 * time.Millisecond): + // Start() was never called or hasn't started yet - that's ok + } + }) } // connect establishes a WebSocket connection and handles events. diff --git a/pkg/security/connlimiter.go b/pkg/security/connlimiter.go index 787ddff..7f3bb87 100644 --- a/pkg/security/connlimiter.go +++ b/pkg/security/connlimiter.go @@ -25,8 +25,8 @@ type connectionInfo struct { // reservation represents a reserved connection slot. type reservation struct { - ip string createdAt time.Time + ip string } // ConnectionLimiter tracks connections per IP and total. @@ -185,7 +185,7 @@ func (cl *ConnectionLimiter) CancelReservation(token string) { } // CanAdd checks if a connection can be added for the given IP without actually adding it. -// DEPRECATED: This method has a TOCTOU race condition. Use Reserve() instead, which +// Deprecated: This method has a TOCTOU race condition. Use Reserve() instead, which // atomically checks and reserves a slot, preventing the race. // // This method is kept for backward compatibility and testing only. diff --git a/pkg/security/race_test.go b/pkg/security/race_test.go index d29685d..00e7b30 100644 --- a/pkg/security/race_test.go +++ b/pkg/security/race_test.go @@ -114,9 +114,9 @@ func TestConnectionLimiterTOCTOU_Documentation(t *testing.T) { ip := "192.168.1.1" var wg sync.WaitGroup - var canAddSuccess int32 // How many times CanAdd returned true - var addSuccess int32 // How many times Add actually succeeded - var addFailed int32 // How many times Add failed despite CanAdd=true + var canAddSuccess int32 // How many times CanAdd returned true + var addSuccess int32 // How many times Add actually succeeded + var addFailed int32 // How many times Add failed despite CanAdd=true // Launch many goroutines simultaneously trying to add connections // This simulates multiple HTTP handlers racing to add connections diff --git a/pkg/srv/websocket.go b/pkg/srv/websocket.go index 565527e..0f77b36 100644 --- a/pkg/srv/websocket.go +++ b/pkg/srv/websocket.go @@ -150,14 +150,6 @@ func (h *WebSocketHandler) extractGitHubToken(ws *websocket.Conn, ip string) (st return githubToken, true } -// tokenDebugInfo extracts token prefix for debug logging. -func tokenDebugInfo(token string) string { - if len(token) >= tokenPrefixLength { - return token[:tokenPrefixLength] - } - return "" -} - // errorInfo holds error response details. type errorInfo struct { code string @@ -281,7 +273,10 @@ func (*WebSocketHandler) handleAuthError( logContext string, ) error { errInfo := determineErrorInfo(err, username, orgName, userOrgs) - tokenPrefix := tokenDebugInfo(githubToken) + tokenPrefix := "" + if len(githubToken) >= tokenPrefixLength { + tokenPrefix = githubToken[:tokenPrefixLength] + } logger.Error(logContext, err, logger.Fields{ "ip": ip, @@ -411,11 +406,6 @@ type wsCloser struct { mu sync.Mutex } -// newWSCloser creates a new WebSocket closer wrapper. -func newWSCloser(ws *websocket.Conn) *wsCloser { - return &wsCloser{ws: ws} -} - // Close closes the WebSocket connection exactly once. func (wc *wsCloser) Close() error { var err error @@ -491,7 +481,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { log.Printf("WebSocket Handle() got IP: %s", ip) // Wrap WebSocket with sync.Once closer to prevent double-close - wc := newWSCloser(ws) + wc := &wsCloser{ws: ws} // Ensure WebSocket is properly closed (client will be set later if connection succeeds) var client *Client @@ -508,7 +498,8 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { }) // Get reservation token from context (set by main.go before upgrade) - reservationToken, _ := ws.Request().Context().Value("reservation_token").(string) + // Context key is a string type for package boundary crossing + reservationToken, _ := ws.Request().Context().Value("reservation_token").(string) //nolint:errcheck // Type assertion intentionally unchecked - empty string is valid default if reservationToken == "" { // No reservation token - this should not happen in production // (main.go always sets it), but handle gracefully for tests diff --git a/pkg/webhook/handler.go b/pkg/webhook/handler.go index 8009f15..1860a4b 100644 --- a/pkg/webhook/handler.go +++ b/pkg/webhook/handler.go @@ -182,7 +182,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // GitHub webhooks can fire before the pull_requests array is populated commitSHA := extractCommitSHA(eventType, payload) // Extract repo URL as fallback for org-based matching - repoURL := extractRepoURL(payload) + repoURL := "" + if repo, ok := payload["repository"].(map[string]any); ok { + if htmlURL, ok := repo["html_url"].(string); ok { + repoURL = htmlURL + } + } // If we can't extract repo URL, drop the event if repoURL == "" { @@ -299,11 +304,15 @@ func ExtractPRURL(eventType string, payload map[string]any) string { } } // Log when we can't extract PR URL from check event + payloadKeys := make([]string, 0, len(payload)) + for k := range payload { + payloadKeys = append(payloadKeys, k) + } logger.Warn("no PR URL found in check event", logger.Fields{ "event_type": eventType, "has_check_run": payload["check_run"] != nil, "has_check_suite": payload["check_suite"] != nil, - "payload_keys": getPayloadKeys(payload), + "payload_keys": payloadKeys, }) default: // For other event types, no PR URL can be extracted @@ -378,15 +387,6 @@ func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, return constructedURL } -// getPayloadKeys returns the keys from a payload map for logging. -func getPayloadKeys(payload map[string]any) []string { - keys := make([]string, 0, len(payload)) - for k := range payload { - keys = append(keys, k) - } - return keys -} - // getMapKeys returns the keys from a map for logging. func getMapKeys(m map[string]any) []string { keys := make([]string, 0, len(m)) @@ -417,13 +417,3 @@ func extractCommitSHA(eventType string, payload map[string]any) string { return "" } -// extractRepoURL extracts the repository HTML URL from the payload. -// This is used as a fallback when PR URL cannot be extracted (e.g., check event race condition). -func extractRepoURL(payload map[string]any) string { - if repo, ok := payload["repository"].(map[string]any); ok { - if htmlURL, ok := repo["html_url"].(string); ok { - return htmlURL - } - } - return "" -} From f65df9fbea6e7a4e0bdf7282175400c7891e3688 Mon Sep 17 00:00:00 2001 From: Thomas Stromberg Date: Mon, 27 Oct 2025 08:26:11 -0400 Subject: [PATCH 2/2] transparent test->PR lookups --- go.sum | 8 --- pkg/client/client.go | 123 ++++++++++++++++++++++++++++++++++++-- pkg/client/client_test.go | 79 ++++++++++++++++++++++++ pkg/github/client.go | 91 ++++++++++++++++++++++++++++ pkg/srv/hub.go | 9 +-- pkg/webhook/handler.go | 5 ++ 6 files changed, 298 insertions(+), 17 deletions(-) create mode 100644 pkg/client/client_test.go diff --git a/go.sum b/go.sum index a4b57c7..86e9c01 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,8 @@ github.com/codeGROOVE-dev/retry v1.2.0 h1:xYpYPX2PQZmdHwuiQAGGzsBm392xIMl4nfMEFApQnu8= github.com/codeGROOVE-dev/retry v1.2.0/go.mod h1:8OgefgV1XP7lzX2PdKlCXILsYKuz6b4ZpHa/20iLi8E= -golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= -golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= -golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= -golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM= -golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= -golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= -golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= diff --git a/pkg/client/client.go b/pkg/client/client.go index 9ba29c3..c7ed15b 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -11,6 +11,7 @@ import ( "time" "github.com/codeGROOVE-dev/retry" + "github.com/codeGROOVE-dev/sprinkler/pkg/github" "golang.org/x/net/websocket" ) @@ -45,8 +46,9 @@ type Event struct { Timestamp time.Time `json:"timestamp"` Raw map[string]any Type string `json:"type"` - URL string `json:"url"` + URL string `json:"url"` // PR URL (or repo URL for check events with race condition) DeliveryID string `json:"delivery_id,omitempty"` + CommitSHA string `json:"commit_sha,omitempty"` // Commit SHA for check events } // Config holds the configuration for the client. @@ -89,6 +91,12 @@ type Client struct { writeCh chan any // Channel for serializing all writes eventCount int retries int + + // Cache for commit SHA to PR number lookups (for check event race condition) + commitPRCache map[string][]int // key: "owner/repo:sha", value: PR numbers + commitCacheKeys []string // track insertion order for LRU eviction + cacheMu sync.RWMutex + maxCacheSize int } // New creates a new robust WebSocket client. @@ -119,10 +127,13 @@ func New(config Config) (*Client, error) { } return &Client{ - config: config, - stopCh: make(chan struct{}), - stoppedCh: make(chan struct{}), - logger: logger, + config: config, + stopCh: make(chan struct{}), + stoppedCh: make(chan struct{}), + logger: logger, + commitPRCache: make(map[string][]int), + commitCacheKeys: make([]string, 0, 512), + maxCacheSize: 512, }, nil } @@ -620,11 +631,111 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error { event.DeliveryID = deliveryID } + if commitSHA, ok := response["commit_sha"].(string); ok { + event.CommitSHA = commitSHA + } + c.mu.Lock() c.eventCount++ eventNum := c.eventCount c.mu.Unlock() + // Handle check events with repo-only URLs (GitHub race condition) + // Automatically expand into per-PR events using GitHub API with caching + if (event.Type == "check_run" || event.Type == "check_suite") && event.CommitSHA != "" && !strings.Contains(event.URL, "/pull/") { + // Extract owner/repo from URL + parts := strings.Split(event.URL, "/") + if len(parts) >= 5 && parts[2] == "github.com" { + owner := parts[3] + repo := parts[4] + key := owner + "/" + repo + ":" + event.CommitSHA + + // Check cache first + c.cacheMu.RLock() + cached, ok := c.commitPRCache[key] + c.cacheMu.RUnlock() + + var prs []int + if ok { + // Cache hit - return copy to prevent external modifications + prs = make([]int, len(cached)) + copy(prs, cached) + c.logger.Info("Check event with repo URL - using cached PR lookup", + "commit_sha", event.CommitSHA, + "repo_url", event.URL, + "type", event.Type, + "pr_count", len(prs), + "cache_hit", true) + } else { + // Cache miss - look up via GitHub API + c.logger.Info("Check event with repo URL - looking up PRs via GitHub API", + "commit_sha", event.CommitSHA, + "repo_url", event.URL, + "type", event.Type, + "cache_hit", false) + + gh := github.NewClient(c.config.Token) + var err error + prs, err = gh.FindPRsForCommit(ctx, owner, repo, event.CommitSHA) + if err != nil { + c.logger.Warn("Failed to look up PRs for commit", + "commit_sha", event.CommitSHA, + "owner", owner, + "repo", repo, + "error", err) + // Don't cache errors - try again next time + } else { + // Cache the result (even if empty) + c.cacheMu.Lock() + if _, exists := c.commitPRCache[key]; !exists { + c.commitCacheKeys = append(c.commitCacheKeys, key) + // Evict oldest 25% if cache is full + if len(c.commitCacheKeys) > c.maxCacheSize { + n := c.maxCacheSize / 4 + for i := range n { + delete(c.commitPRCache, c.commitCacheKeys[i]) + } + c.commitCacheKeys = c.commitCacheKeys[n:] + } + } + // Store copy to prevent external modifications + cached := make([]int, len(prs)) + copy(cached, prs) + c.commitPRCache[key] = cached + c.cacheMu.Unlock() + + c.logger.Info("Cached PR lookup result", + "commit_sha", event.CommitSHA, + "pr_count", len(prs)) + } + } + + // Emit events for each PR found + if len(prs) > 0 { + for _, n := range prs { + e := event // Copy the event + e.URL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", owner, repo, n) + + if c.config.OnEvent != nil { + c.logger.Info("Event received (expanded from commit)", + "timestamp", e.Timestamp.Format("15:04:05"), + "event_number", eventNum, + "type", e.Type, + "url", e.URL, + "commit_sha", e.CommitSHA, + "delivery_id", e.DeliveryID) + c.config.OnEvent(e) + } + } + continue // Skip the normal event handling since we expanded it + } + c.logger.Info("No PRs found for commit - may be push to main", + "commit_sha", event.CommitSHA, + "owner", owner, + "repo", repo) + } + } + // Log event if c.config.Verbose { c.logger.Info("Event received", @@ -632,6 +743,7 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error { "timestamp", event.Timestamp.Format("15:04:05"), "type", event.Type, "url", event.URL, + "commit_sha", event.CommitSHA, "delivery_id", event.DeliveryID, "raw", event.Raw) } else { @@ -641,6 +753,7 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error { "event_number", eventNum, "type", event.Type, "url", event.URL, + "commit_sha", event.CommitSHA, "delivery_id", event.DeliveryID) } else { c.logger.Info("Event received", diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go new file mode 100644 index 0000000..485aa7c --- /dev/null +++ b/pkg/client/client_test.go @@ -0,0 +1,79 @@ +package client + +import ( + "context" + "sync" + "testing" + "time" +) + +// TestStopMultipleCalls verifies that calling Stop() multiple times is safe +// and doesn't panic with "close of closed channel". +func TestStopMultipleCalls(t *testing.T) { + // Create a client with minimal config + client, err := New(Config{ + ServerURL: "ws://localhost:8080", + Token: "test-token", + Organization: "test-org", + NoReconnect: true, // Disable reconnect to make test faster + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Start the client in a goroutine + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + // Expected to fail to connect, but that's ok for this test + if err := client.Start(ctx); err != nil { + // Error is expected in tests - client can't connect to non-existent server + } + }() + + // Give it a moment to initialize + time.Sleep(10 * time.Millisecond) + + // Call Stop() multiple times concurrently + // This should NOT panic with "close of closed channel" + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + client.Stop() // Should be safe to call multiple times + }() + } + + // Wait for all Stop() calls to complete + wg.Wait() + + // If we get here without a panic, the test passes +} + +// TestStopBeforeStart verifies that calling Stop() before Start() is safe. +func TestStopBeforeStart(t *testing.T) { + client, err := New(Config{ + ServerURL: "ws://localhost:8080", + Token: "test-token", + Organization: "test-org", + NoReconnect: true, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Call Stop() before Start() + client.Stop() + + // Now try to start - should exit cleanly + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err = client.Start(ctx) + // We expect either context.DeadlineExceeded or "stop requested" + if err == nil { + t.Error("Expected Start() to fail after Stop(), but it succeeded") + } +} diff --git a/pkg/github/client.go b/pkg/github/client.go index 9644190..aea8c26 100644 --- a/pkg/github/client.go +++ b/pkg/github/client.go @@ -463,3 +463,94 @@ func (c *Client) ValidateOrgMembership(ctx context.Context, org string) (usernam log.Printf("GitHub API: User is member of %d organizations: %v", len(orgNames), orgNames) return username, orgNames, errors.New("user is not a member of the requested organization") } + +// FindPRsForCommit finds all pull requests associated with a specific commit SHA. +// This is useful for resolving check_run/check_suite events when GitHub's pull_requests array is empty. +// Returns a list of PR numbers that contain this commit. +func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA string) ([]int, error) { + var prNumbers []int + var lastErr error + + // Use GitHub's API to list PRs associated with a commit + url := fmt.Sprintf("https://api.github.com/repos/%s/%s/commits/%s/pulls", owner, repo, commitSHA) + + err := retry.Do( + func() error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "webhook-sprinkler/1.0") + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("failed to make request: %w", err) + log.Printf("GitHub API request failed (will retry): %v", err) + return err // Retry on network errors + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("failed to close response body: %v", err) + } + }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit + if err != nil { + lastErr = fmt.Errorf("failed to read response: %w", err) + return err // Retry on read errors + } + + // Handle status codes + switch resp.StatusCode { + case http.StatusOK: + // Success - parse response + var prs []struct { + Number int `json:"number"` + } + if err := json.Unmarshal(body, &prs); err != nil { + return retry.Unrecoverable(fmt.Errorf("failed to parse PR list response: %w", err)) + } + + prNumbers = make([]int, len(prs)) + for i, pr := range prs { + prNumbers[i] = pr.Number + } + return nil + + case http.StatusNotFound: + // Commit not found - could be a commit to main or repo doesn't exist + return retry.Unrecoverable(fmt.Errorf("commit not found: %s", commitSHA)) + + case http.StatusUnauthorized, http.StatusForbidden: + // Don't retry on auth errors + return retry.Unrecoverable(fmt.Errorf("authentication failed: status %d", resp.StatusCode)) + + case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable: + // Retry on server errors + lastErr = fmt.Errorf("GitHub API server error: %d", resp.StatusCode) + log.Printf("GitHub API: /commits/%s/pulls server error %d (will retry)", commitSHA, resp.StatusCode) + return lastErr + + default: + // Don't retry on other errors + return retry.Unrecoverable(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body))) + } + }, + retry.Attempts(3), + retry.DelayType(retry.FullJitterBackoffDelay), + retry.MaxDelay(2*time.Minute), + retry.Context(ctx), + ) + if err != nil { + if lastErr != nil { + return nil, lastErr + } + return nil, err + } + + log.Printf("GitHub API: Found %d PR(s) for commit %s in %s/%s", len(prNumbers), commitSHA, owner, repo) + return prNumbers, nil +} diff --git a/pkg/srv/hub.go b/pkg/srv/hub.go index f4b251b..775776f 100644 --- a/pkg/srv/hub.go +++ b/pkg/srv/hub.go @@ -14,10 +14,11 @@ import ( // Event represents a GitHub webhook event that will be broadcast to clients. // It contains the PR URL, timestamp, event type, and delivery ID from GitHub. type Event struct { - URL string `json:"url"` // Pull request URL - Timestamp time.Time `json:"timestamp"` // When the event occurred - Type string `json:"type"` // GitHub event type (e.g., "pull_request") - DeliveryID string `json:"delivery_id,omitempty"` // GitHub webhook delivery ID (unique per webhook) + URL string `json:"url"` // Pull request URL (or repo URL for check events with race condition) + Timestamp time.Time `json:"timestamp"` // When the event occurred + Type string `json:"type"` // GitHub event type (e.g., "pull_request") + DeliveryID string `json:"delivery_id,omitempty"` // GitHub webhook delivery ID (unique per webhook) + CommitSHA string `json:"commit_sha,omitempty"` // Commit SHA for check events (used to look up PR when URL is repo-only) } // Hub manages WebSocket clients and event broadcasting. diff --git a/pkg/webhook/handler.go b/pkg/webhook/handler.go index 1860a4b..15415f1 100644 --- a/pkg/webhook/handler.go +++ b/pkg/webhook/handler.go @@ -224,6 +224,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { DeliveryID: deliveryID, } + // For check events, include commit SHA to allow PR lookup when URL is repo-only (race condition) + if eventType == "check_run" || eventType == "check_suite" { + event.CommitSHA = extractCommitSHA(eventType, payload) + } + // Get client count before broadcasting (for debugging delivery issues) clientCount := h.hub.ClientCount()