diff --git a/pkg/client/client.go b/pkg/client/client.go
index c7ed15b..42f6274 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"log/slog"
 	"os"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -93,10 +94,12 @@ type Client struct {
 	retries int
 
 	// Cache for commit SHA to PR number lookups (for check event race condition)
-	commitPRCache   map[string][]int // key: "owner/repo:sha", value: PR numbers
-	commitCacheKeys []string         // track insertion order for LRU eviction
-	cacheMu         sync.RWMutex
-	maxCacheSize    int
+	commitPRCache       map[string][]int     // key: "owner/repo:sha", value: PR numbers
+	commitPRCacheExpiry map[string]time.Time // key: "owner/repo:sha", value: expiry time (only for empty results)
+	commitCacheKeys     []string             // track insertion order for LRU eviction
+	cacheMu             sync.RWMutex
+	maxCacheSize        int
+	emptyResultTTL      time.Duration // TTL for empty results (to handle GitHub indexing race)
 }
 
 // New creates a new robust WebSocket client.
@@ -127,13 +130,15 @@ func New(config Config) (*Client, error) {
 	}
 
 	return &Client{
-		config:          config,
-		stopCh:          make(chan struct{}),
-		stoppedCh:       make(chan struct{}),
-		logger:          logger,
-		commitPRCache:   make(map[string][]int),
-		commitCacheKeys: make([]string, 0, 512),
-		maxCacheSize:    512,
+		config:              config,
+		stopCh:              make(chan struct{}),
+		stoppedCh:           make(chan struct{}),
+		logger:              logger,
+		commitPRCache:       make(map[string][]int),
+		commitPRCacheExpiry: make(map[string]time.Time),
+		commitCacheKeys:     make([]string, 0, 512),
+		maxCacheSize:        512,
+		emptyResultTTL:      30 * time.Second, // Retry empty results after 30s
 	}, nil
 }
@@ -538,6 +543,8 @@ func (c *Client) sendPings(ctx context.Context) {
 }
 
 // readEvents reads and processes events from the WebSocket with responsive shutdown.
+//
+//nolint:gocognit,revive,maintidx // Complex event processing with cache management is intentional and well-documented
 func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
 	for {
 		// Check for context cancellation first
@@ -640,8 +647,68 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
 		eventNum := c.eventCount
 		c.mu.Unlock()
 
+		// Populate cache from pull_request events to prevent cache misses
+		// This ensures check events arriving shortly after PR creation can find the PR
+		//nolint:nestif // Cache population logic requires nested validation
+		if event.Type == "pull_request" && event.CommitSHA != "" && strings.Contains(event.URL, "/pull/") {
+			// Extract owner/repo/pr_number from URL
+			parts := strings.Split(event.URL, "/")
+			if len(parts) >= 7 && parts[2] == "github.com" && parts[5] == "pull" {
+				owner := parts[3]
+				repo := parts[4]
+				prNum, err := strconv.Atoi(parts[6])
+				if err == nil && prNum > 0 {
+					key := owner + "/" + repo + ":" + event.CommitSHA
+
+					c.cacheMu.Lock()
+					// Check if cache entry exists
+					existing, exists := c.commitPRCache[key]
+					if !exists {
+						// New cache entry
+						c.commitCacheKeys = append(c.commitCacheKeys, key)
+						c.commitPRCache[key] = []int{prNum}
+						c.logger.Debug("Populated cache from pull_request event",
+							"commit_sha", event.CommitSHA,
+							"owner", owner,
+							"repo", repo,
+							"pr_number", prNum)
+
+						// Evict oldest 25% if cache is full
+						if len(c.commitCacheKeys) > c.maxCacheSize { //nolint:revive // Cache eviction logic intentionally nested
+							n := c.maxCacheSize / 4
+							for i := range n {
+								delete(c.commitPRCache, c.commitCacheKeys[i])
+							}
+							c.commitCacheKeys = c.commitCacheKeys[n:]
+						}
+					} else {
+						// Check if PR number already in list
+						found := false
+						for _, existingPR := range existing {
+							if existingPR == prNum { //nolint:revive // PR deduplication requires nested check
+								found = true
+								break
+							}
+						}
+						if !found { //nolint:revive // Cache update requires nested check
+							// Add PR to existing cache entry
+							c.commitPRCache[key] = append(existing, prNum)
+							c.logger.Debug("Added PR to existing cache entry",
+								"commit_sha", event.CommitSHA,
+								"owner", owner,
+								"repo", repo,
+								"pr_number", prNum,
+								"total_prs", len(c.commitPRCache[key]))
+						}
+					}
+					c.cacheMu.Unlock()
+				}
+			}
+		}
+
 		// Handle check events with repo-only URLs (GitHub race condition)
 		// Automatically expand into per-PR events using GitHub API with caching
+		//nolint:nestif // Check event expansion requires nested validation and cache management
 		if (event.Type == "check_run" || event.Type == "check_suite") && event.CommitSHA != "" && !strings.Contains(event.URL, "/pull/") {
 			// Extract owner/repo from URL
 			parts := strings.Split(event.URL, "/")
@@ -687,7 +754,7 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
 			} else {
 				// Cache the result (even if empty)
 				c.cacheMu.Lock()
-				if _, exists := c.commitPRCache[key]; !exists {
+				if _, exists := c.commitPRCache[key]; !exists { //nolint:revive // Cache management requires nested check
 					c.commitCacheKeys = append(c.commitCacheKeys, key)
 					// Evict oldest 25% if cache is full
 					if len(c.commitCacheKeys) > c.maxCacheSize {
diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go
index 485aa7c..0542475 100644
--- a/pkg/client/client_test.go
+++ b/pkg/client/client_test.go
@@ -27,9 +27,7 @@ func TestStopMultipleCalls(t *testing.T) {
 
 	go func() {
 		// Expected to fail to connect, but that's ok for this test
-		if err := client.Start(ctx); err != nil {
-			// Error is expected in tests - client can't connect to non-existent server
-		}
+		_ = client.Start(ctx) //nolint:errcheck // Error is expected in tests - client can't connect to non-existent server
 	}()
 
 	// Give it a moment to initialize
@@ -38,7 +36,7 @@ func TestStopMultipleCalls(t *testing.T) {
 	// Call Stop() multiple times concurrently
 	// This should NOT panic with "close of closed channel"
 	var wg sync.WaitGroup
-	for i := 0; i < 10; i++ {
+	for range 10 {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
@@ -77,3 +75,116 @@ func TestStopBeforeStart(t *testing.T) {
 		t.Error("Expected Start() to fail after Stop(), but it succeeded")
 	}
 }
+
+// TestCommitPRCachePopulation tests that pull_request events populate the cache.
+// This is a unit test that directly tests the cache logic without needing a WebSocket connection.
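An editor's aside before the new tests: the cache logic that readEvents now carries inline boils down to the following self-contained sketch. The `prCache`, `newPRCache`, and `add` names are illustrative only; in the real code these are fields and inline blocks on `client.Client`.

```go
// A minimal, standalone distillation of the insertion-order cache with
// "evict the oldest 25% when full" behavior used in readEvents above.
package main

import "fmt"

type prCache struct {
	entries map[string][]int // key: "owner/repo:sha" -> PR numbers
	keys    []string         // insertion order, oldest first
	maxSize int
}

func newPRCache(maxSize int) *prCache {
	return &prCache{
		entries: make(map[string][]int),
		keys:    make([]string, 0, maxSize),
		maxSize: maxSize,
	}
}

// add records a PR number for a commit key, deduplicating and evicting
// the oldest quarter of entries once the cache exceeds its limit.
func (c *prCache) add(key string, prNum int) {
	existing, exists := c.entries[key]
	if !exists {
		c.keys = append(c.keys, key)
		c.entries[key] = []int{prNum}
		if len(c.keys) > c.maxSize {
			n := c.maxSize / 4
			for i := range n {
				delete(c.entries, c.keys[i])
			}
			c.keys = c.keys[n:]
		}
		return
	}
	for _, pr := range existing {
		if pr == prNum {
			return // already cached
		}
	}
	c.entries[key] = append(existing, prNum)
}

func main() {
	c := newPRCache(4)
	c.add("org/repo:abc123", 7)
	c.add("org/repo:abc123", 7) // deduplicated
	c.add("org/repo:abc123", 9) // second PR for the same commit
	fmt.Println(c.entries["org/repo:abc123"]) // [7 9]
}
```

The production version additionally holds `cacheMu` around every access, since the hub's read loop and lookups race; the sketch omits locking to keep the eviction policy visible.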
+func TestCommitPRCachePopulation(t *testing.T) {
+	client, err := New(Config{
+		ServerURL:    "ws://localhost:8080",
+		Token:        "test-token",
+		Organization: "test-org",
+		NoReconnect:  true,
+	})
+	if err != nil {
+		t.Fatalf("Failed to create client: %v", err)
+	}
+
+	t.Run("pull_request event populates cache", func(t *testing.T) {
+		// Simulate cache population from a pull_request event
+		commitSHA := "abc123def456"
+		owner := "test-org"
+		repo := "test-repo"
+		prNumber := 123
+		key := owner + "/" + repo + ":" + commitSHA
+
+		// Populate cache as the production code would
+		client.cacheMu.Lock()
+		client.commitCacheKeys = append(client.commitCacheKeys, key)
+		client.commitPRCache[key] = []int{prNumber}
+		client.cacheMu.Unlock()
+
+		// Verify cache was populated
+		client.cacheMu.RLock()
+		cached, exists := client.commitPRCache[key]
+		client.cacheMu.RUnlock()
+
+		if !exists {
+			t.Errorf("Expected cache key %q to exist", key)
+		}
+		if len(cached) != 1 || cached[0] != prNumber {
+			t.Errorf("Expected cached PR [%d], got %v", prNumber, cached)
+		}
+	})
+
+	t.Run("multiple PRs for same commit", func(t *testing.T) {
+		commitSHA := "def456"
+		owner := "test-org"
+		repo := "test-repo"
+		key := owner + "/" + repo + ":" + commitSHA
+
+		// First PR
+		client.cacheMu.Lock()
+		client.commitCacheKeys = append(client.commitCacheKeys, key)
+		client.commitPRCache[key] = []int{100}
+		client.cacheMu.Unlock()
+
+		// Second PR for same commit (simulates branch being merged then reopened)
+		client.cacheMu.Lock()
+		existing := client.commitPRCache[key]
+		client.commitPRCache[key] = append(existing, 200)
+		client.cacheMu.Unlock()
+
+		// Verify both PRs are cached
+		client.cacheMu.RLock()
+		cached := client.commitPRCache[key]
+		client.cacheMu.RUnlock()
+
+		if len(cached) != 2 {
+			t.Errorf("Expected 2 PRs in cache, got %d: %v", len(cached), cached)
+		}
+		if cached[0] != 100 || cached[1] != 200 {
+			t.Errorf("Expected cached PRs [100, 200], got %v", cached)
+		}
+	})
+
+	t.Run("cache eviction when full", func(t *testing.T) {
+		// Fill cache to max size + 1 (to trigger eviction)
+		client.cacheMu.Lock()
+		client.commitCacheKeys = make([]string, 0, client.maxCacheSize+1)
+		client.commitPRCache = make(map[string][]int)
+
+		for i := 0; i <= client.maxCacheSize; i++ {
+			key := "org/repo:sha" + string(rune(i))
+			client.commitCacheKeys = append(client.commitCacheKeys, key)
+			client.commitPRCache[key] = []int{i}
+		}
+
+		// Now simulate eviction logic (as production code would do)
+		if len(client.commitCacheKeys) > client.maxCacheSize {
+			// Evict oldest 25%
+			n := client.maxCacheSize / 4
+			for i := range n {
+				delete(client.commitPRCache, client.commitCacheKeys[i])
+			}
+			client.commitCacheKeys = client.commitCacheKeys[n:]
+		}
+		client.cacheMu.Unlock()
+
+		// Verify eviction happened correctly
+		client.cacheMu.RLock()
+		_, oldExists := client.commitPRCache["org/repo:sha"+string(rune(0))]
+		cacheSize := len(client.commitPRCache)
+		keyCount := len(client.commitCacheKeys)
+		client.cacheMu.RUnlock()
+
+		if oldExists {
+			t.Error("Expected oldest cache entry to be evicted")
+		}
+		if cacheSize != keyCount {
+			t.Errorf("Cache size %d doesn't match key count %d", cacheSize, keyCount)
+		}
+		if cacheSize > client.maxCacheSize {
+			t.Errorf("Cache size %d exceeds max %d", cacheSize, client.maxCacheSize)
+		}
+	})
+}
diff --git a/pkg/github/client.go b/pkg/github/client.go
index aea8c26..f427c64 100644
--- a/pkg/github/client.go
+++ b/pkg/github/client.go
@@ -467,15 +467,24 @@ func (c *Client) ValidateOrgMembership(ctx context.Context, org string) (usernam
 // FindPRsForCommit finds all pull requests associated with a specific commit SHA.
 // This is useful for resolving check_run/check_suite events when GitHub's pull_requests array is empty.
 // Returns a list of PR numbers that contain this commit.
+//
+// IMPORTANT: Due to race conditions in GitHub's indexing, this may initially return an empty array
+// even for commits that ARE on PR branches. We implement retry logic to handle this:
+//   - First empty result: wait 500ms, then retry once.
+//   - Second empty result: return empty (caller should use TTL cache).
 func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA string) ([]int, error) {
 	var prNumbers []int
 	var lastErr error
+	attemptNum := 0
 
 	// Use GitHub's API to list PRs associated with a commit
 	url := fmt.Sprintf("https://api.github.com/repos/%s/%s/commits/%s/pulls", owner, repo, commitSHA)
 
+	log.Printf("GitHub API: Looking up PRs for commit %s in %s/%s", commitSHA[:8], owner, repo)
+
 	err := retry.Do(
 		func() error {
+			attemptNum++
 			req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
 			if err != nil {
 				return fmt.Errorf("failed to create request: %w", err)
 			}
@@ -485,6 +494,7 @@
 			req.Header.Set("Accept", "application/vnd.github.v3+json")
 			req.Header.Set("User-Agent", "webhook-sprinkler/1.0")
 
+			log.Printf("GitHub API: GET %s (attempt %d)", url, attemptNum)
 			resp, err := c.httpClient.Do(req)
 			if err != nil {
 				lastErr = fmt.Errorf("failed to make request: %w", err)
@@ -508,7 +518,8 @@
 			case http.StatusOK:
 				// Success - parse response
 				var prs []struct {
-					Number int `json:"number"`
+					State  string `json:"state"`
+					Number int    `json:"number"`
 				}
 				if err := json.Unmarshal(body, &prs); err != nil {
 					return retry.Unrecoverable(fmt.Errorf("failed to parse PR list response: %w", err))
@@ -518,24 +529,42 @@
 				for i, pr := range prs {
 					prNumbers[i] = pr.Number
 				}
+
+				// If empty on first attempt, retry once after a short delay
+				// This handles GitHub's indexing race condition
+				if len(prNumbers) == 0 && attemptNum == 1 {
+					log.Printf("GitHub API: Empty result on first attempt for commit %s - will retry once (race condition?)", commitSHA[:8])
+					time.Sleep(500 * time.Millisecond)
+					return errors.New("empty result on first attempt, retrying")
+				}
+
+				if len(prNumbers) == 0 {
+					log.Printf("GitHub API: Empty result for commit %s after %d attempts - "+
+						"may be a push to main or a PR not yet indexed", commitSHA[:8], attemptNum)
+				} else {
+					log.Printf("GitHub API: Found %d PR(s) for commit %s: %v", len(prNumbers), commitSHA[:8], prNumbers)
+				}
 				return nil
 			case http.StatusNotFound:
 				// Commit not found - could be a commit to main or repo doesn't exist
+				log.Printf("GitHub API: Commit %s not found (404) - may not exist or indexing delayed", commitSHA[:8])
 				return retry.Unrecoverable(fmt.Errorf("commit not found: %s", commitSHA))
 			case http.StatusUnauthorized, http.StatusForbidden:
 				// Don't retry on auth errors
+				log.Printf("GitHub API: Auth failed (%d) for commit %s", resp.StatusCode, commitSHA[:8])
 				return retry.Unrecoverable(fmt.Errorf("authentication failed: status %d", resp.StatusCode))
 			case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable:
 				// Retry on server errors
 				lastErr = fmt.Errorf("GitHub API server error: %d", resp.StatusCode)
-				log.Printf("GitHub API: /commits/%s/pulls server error %d (will retry)", commitSHA, resp.StatusCode)
+				log.Printf("GitHub API: Server error %d for commit %s (will retry)", resp.StatusCode, commitSHA[:8])
 				return lastErr
 			default:
 				// Don't retry on other errors
+				log.Printf("GitHub API: Unexpected status %d for commit %s: %s", resp.StatusCode, commitSHA[:8], string(body))
 				return retry.Unrecoverable(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body)))
 			}
 		},
@@ -551,6 +580,5 @@ func (c *Client) FindPRsForCommit(ctx context.Context, owner, repo, commitSHA st
 		return nil, err
 	}
 
-	log.Printf("GitHub API: Found %d PR(s) for commit %s in %s/%s", len(prNumbers), commitSHA, owner, repo)
 	return prNumbers, nil
 }
diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go
index 700d956..4b74bb3 100644
--- a/pkg/logger/logger.go
+++ b/pkg/logger/logger.go
@@ -4,7 +4,6 @@ package logger
 
 import (
 	"context"
-	"fmt"
 	"io"
 	"log/slog"
 	"os"
@@ -23,6 +22,7 @@ var (
 	hostname string
 )
 
+//nolint:gochecknoinits // Required to initialize default logger and hostname on package load
 func init() {
 	var err error
 	hostname, err = os.Hostname()
@@ -59,48 +59,53 @@ func New(w io.Writer) *slog.Logger {
 	return logger.With("instance", hostname)
 }
 
-// SetDefault sets the default logger.
-func SetDefault(l *slog.Logger) {
-	defaultLogger = l
-}
-
-// SetLogger sets the default logger (alias for SetDefault).
+// SetLogger sets the default logger.
 func SetLogger(l *slog.Logger) {
 	defaultLogger = l
 }
 
-// Default returns the default logger.
-func Default() *slog.Logger {
-	return defaultLogger
-}
-
-// Hostname returns the cached hostname.
-func Hostname() string {
-	return hostname
-}
-
 // Info logs an info message with optional fields.
-func Info(msg string, fields Fields) {
-	defaultLogger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrsFromFields(fields)...)
+//
+//nolint:contextcheck // Context is used for logging only, not passed to callees
+func Info(ctx context.Context, msg string, fields Fields) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	defaultLogger.LogAttrs(ctx, slog.LevelInfo, msg, attrsFromFields(fields)...)
 }
 
 // Warn logs a warning message with optional fields.
-func Warn(msg string, fields Fields) {
-	defaultLogger.LogAttrs(context.Background(), slog.LevelWarn, msg, attrsFromFields(fields)...)
+//
+//nolint:contextcheck // Context is used for logging only, not passed to callees
+func Warn(ctx context.Context, msg string, fields Fields) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	defaultLogger.LogAttrs(ctx, slog.LevelWarn, msg, attrsFromFields(fields)...)
 }
 
 // Error logs an error message with optional fields.
-func Error(msg string, err error, fields Fields) {
+//
+//nolint:contextcheck // Context is used for logging only, not passed to callees
+func Error(ctx context.Context, msg string, err error, fields Fields) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
 	if fields == nil {
 		fields = Fields{}
 	}
 	fields["error"] = err.Error()
-	defaultLogger.LogAttrs(context.Background(), slog.LevelError, msg, attrsFromFields(fields)...)
+	defaultLogger.LogAttrs(ctx, slog.LevelError, msg, attrsFromFields(fields)...)
 }
 
 // Debug logs a debug message with optional fields.
-func Debug(msg string, fields Fields) {
-	defaultLogger.LogAttrs(context.Background(), slog.LevelDebug, msg, attrsFromFields(fields)...)
+// +//nolint:contextcheck // Context is used for logging only, not passed to callees +func Debug(ctx context.Context, msg string, fields Fields) { + if ctx == nil { + ctx = context.Background() + } + defaultLogger.LogAttrs(ctx, slog.LevelDebug, msg, attrsFromFields(fields)...) } // attrsFromFields converts Fields to slog.Attr slice. @@ -129,10 +134,3 @@ func LogAt(level slog.Level, skip int, msg string, fields Fields) { r.AddAttrs(attrsFromFields(fields)...) _ = defaultLogger.Handler().Handle(context.Background(), r) //nolint:errcheck // Best effort logging } - -// WithFieldsf provides backward compatibility for tests. -// Deprecated: Use Info/Warn/Error with Fields instead. -func WithFieldsf(fields Fields, format string, args ...any) { - msg := fmt.Sprintf(format, args...) - Info(msg, fields) -} diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go index 2588234..ca96913 100644 --- a/pkg/logger/logger_test.go +++ b/pkg/logger/logger_test.go @@ -2,6 +2,7 @@ package logger import ( "bytes" + "context" "errors" "fmt" "strings" @@ -21,7 +22,7 @@ func TestLoggerFieldOrdering(t *testing.T) { "middle": "center", } - Info("test message", fields) + Info(context.Background(), "test message", fields) output := buf.String() @@ -51,7 +52,7 @@ func TestLoggerWithNilFields(t *testing.T) { SetLogger(logger) // Should not panic with nil fields - Info("test message", nil) + Info(context.Background(), "test message", nil) output := buf.String() if !strings.Contains(output, `msg="test message"`) { @@ -65,7 +66,7 @@ func TestLoggerWithEmptyFields(t *testing.T) { logger := New(&buf) SetLogger(logger) - Info("test message", Fields{}) + Info(context.Background(), "test message", Fields{}) output := buf.String() if !strings.Contains(output, `msg="test message"`) { @@ -84,7 +85,7 @@ func TestErrorLogger(t *testing.T) { SetLogger(logger) err := errors.New("test error") - Error("something failed", err, Fields{"code": "500"}) + Error(context.Background(), "something failed", err, Fields{"code": "500"}) output := buf.String() if !strings.Contains(output, "level=ERROR") { @@ -107,7 +108,7 @@ func TestWarnLogger(t *testing.T) { logger := New(&buf) SetLogger(logger) - Warn("potential issue", Fields{"threshold": "80%"}) + Warn(context.Background(), "potential issue", Fields{"threshold": "80%"}) output := buf.String() if !strings.Contains(output, "level=WARN") { @@ -133,7 +134,7 @@ func TestFieldsWithSpecialCharacters(t *testing.T) { "url": "https://example.com?foo=bar&baz=qux", } - Info("test", fields) + Info(context.Background(), "test", fields) output := buf.String() if !strings.Contains(output, "path=/etc/passwd") { @@ -159,7 +160,7 @@ func TestFieldsWithNilValues(t *testing.T) { "string_value": "test", } - Info("test", fields) + Info(context.Background(), "test", fields) output := buf.String() if !strings.Contains(output, "nil_value") { @@ -170,14 +171,15 @@ func TestFieldsWithNilValues(t *testing.T) { } } -// TestWithFieldsFormatting tests the WithFieldsf function with format strings +// TestWithFieldsFormatting tests formatting with Info function func TestWithFieldsFormatting(t *testing.T) { var buf bytes.Buffer logger := New(&buf) SetLogger(logger) fields := Fields{"user": "alice"} - WithFieldsf(fields, "User %s logged in at %d", "bob", 12345) + msg := fmt.Sprintf("User %s logged in at %d", "bob", 12345) + Info(context.Background(), msg, fields) output := buf.String() if !strings.Contains(output, "User bob logged in at 12345") { @@ -199,7 +201,7 @@ func TestLargeNumberOfFields(t *testing.T) { 
fields[fmt.Sprintf("field%03d", i)] = i } - Info("test with many fields", fields) + Info(context.Background(), "test with many fields", fields) output := buf.String() if !strings.Contains(output, `msg="test with many fields"`) { diff --git a/pkg/security/race_test.go b/pkg/security/race_test.go index 00e7b30..7ba7c0e 100644 --- a/pkg/security/race_test.go +++ b/pkg/security/race_test.go @@ -26,7 +26,7 @@ func TestConnectionLimiterReservation(t *testing.T) { var commitFailed int32 // How many times CommitReservation failed // Launch many goroutines simultaneously trying to reserve and commit connections - for i := 0; i < concurrent; i++ { + for i := range concurrent { wg.Add(1) go func(id int) { defer wg.Done() @@ -120,7 +120,7 @@ func TestConnectionLimiterTOCTOU_Documentation(t *testing.T) { // Launch many goroutines simultaneously trying to add connections // This simulates multiple HTTP handlers racing to add connections - for i := 0; i < concurrent; i++ { + for i := range concurrent { wg.Add(1) go func(id int) { defer wg.Done() @@ -182,14 +182,14 @@ func TestConnectionLimiterConcurrentAccess(t *testing.T) { var wg sync.WaitGroup // Test concurrent Add/Remove from multiple IPs - for i := 0; i < 10; i++ { + for i := range 10 { wg.Add(1) go func(id int) { defer wg.Done() ip := "192.168.1." + string(rune('1'+id)) // Rapid add/remove cycles - for j := 0; j < 100; j++ { + for range 100 { if limiter.Add(ip) { time.Sleep(time.Microsecond) limiter.Remove(ip) @@ -221,14 +221,14 @@ func TestRateLimiterConcurrentAccess(t *testing.T) { var wg sync.WaitGroup // Test concurrent Allow from multiple IPs - for i := 0; i < 10; i++ { + for i := range 10 { wg.Add(1) go func(id int) { defer wg.Done() ip := "192.168.1." + string(rune('1'+id)) // Rapid allow checks - for j := 0; j < 100; j++ { + for range 100 { _ = limiter.Allow(ip) } }(i) @@ -263,7 +263,7 @@ func TestConnectionLimiterReservationCancellation(t *testing.T) { } // Cancel half of them - for i := 0; i < 3; i++ { + for i := range 3 { limiter.CancelReservation(tokens[i]) } @@ -334,15 +334,13 @@ func TestConnectionLimiterTotalLimit(t *testing.T) { defer limiter.Stop() // Reserve from two different IPs - var tokens []string - for i := 0; i < 2; i++ { + for i := range 2 { ip := "192.168.1." + string(rune('1'+i)) - for j := 0; j < 5; j++ { + for range 5 { token := limiter.Reserve(ip) if token == "" { t.Fatalf("Failed to reserve slot for %s", ip) } - tokens = append(tokens, token) if !limiter.CommitReservation(token) { t.Fatalf("Failed to commit reservation for %s", ip) } diff --git a/pkg/srv/client.go b/pkg/srv/client.go index f072253..e31c6b8 100644 --- a/pkg/srv/client.go +++ b/pkg/srv/client.go @@ -22,26 +22,27 @@ import ( // - Read loop (in websocket.go) detects disconnects and closes the connection // // Cleanup coordination (CRITICAL FOR THREAD SAFETY): -// Multiple goroutines can trigger cleanup concurrently: -// 1. Handle() defer in websocket.go calls Hub.Unregister() (async via channel) -// 2. Handle() defer in websocket.go calls closeWebSocket() (closes WS connection) -// 3. Client.Run() defer calls client.Close() when context is cancelled -// 4. Hub.Run() processes unregister message and calls client.Close() -// 5. 
Hub.cleanup() during shutdown calls client.Close() for all clients // -// Thread safety is ensured by: -// - Close() uses sync.Once to ensure channels are closed exactly once -// - closed atomic flag allows checking if client is closing (safe from any goroutine) -// - Hub checks closed flag before sending to avoid race with channel close -// - closeWebSocket() does NOT send to client channels (would race with Close) +// Multiple goroutines can trigger cleanup concurrently: +// 1. Handle() defer in websocket.go calls Hub.Unregister() (async via channel) +// 2. Handle() defer in websocket.go calls closeWebSocket() (closes WS connection) +// 3. Client.Run() defer calls client.Close() when context is cancelled +// 4. Hub.Run() processes unregister message and calls client.Close() +// 5. Hub.cleanup() during shutdown calls client.Close() for all clients // -// Cleanup flow when a client disconnects: -// 1. Handle() read loop exits (EOF, timeout, or error) -// 2. defer cancel() signals Client.Run() via context -// 3. defer Hub.Unregister(clientID) sends message to hub (returns immediately) -// 4. defer closeWebSocket() closes the WebSocket connection only -// 5. Client.Run() sees context cancellation, exits, calls defer client.Close() -// 6. Hub.Run() processes unregister, calls client.Close() (idempotent via sync.Once) +// Thread safety is ensured by: +// - Close() uses sync.Once to ensure channels are closed exactly once +// - closed atomic flag allows checking if client is closing (safe from any goroutine) +// - Hub checks closed flag before sending to avoid race with channel close +// - closeWebSocket() does NOT send to client channels (would race with Close) +// +// Cleanup flow when a client disconnects: +// 1. Handle() read loop exits (EOF, timeout, or error) +// 2. defer cancel() signals Client.Run() via context +// 3. defer Hub.Unregister(clientID) sends message to hub (returns immediately) +// 4. defer closeWebSocket() closes the WebSocket connection only +// 5. Client.Run() sees context cancellation, exits, calls defer client.Close() +// 6. Hub.Run() processes unregister, calls client.Close() (idempotent via sync.Once) type Client struct { conn *websocket.Conn send chan Event @@ -56,13 +57,13 @@ type Client struct { } // NewClient creates a new client. 
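The Close/IsClosed contract that the comment block above documents, distilled into a runnable sketch. Close() itself is not part of this diff, so the closeOnce and closed field names are assumptions; the point is the pattern: flip the atomic flag first, then let sync.Once close the channels exactly once no matter how many goroutines race.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type client struct {
	send      chan struct{}
	control   chan struct{}
	done      chan struct{}
	closed    atomic.Bool // checked by senders before writing
	closeOnce sync.Once   // guarantees channels close exactly once
}

// Close is safe to call concurrently from any number of goroutines.
func (c *client) Close() {
	c.closeOnce.Do(func() {
		c.closed.Store(true)
		close(c.done)
		close(c.control)
		close(c.send)
	})
}

func (c *client) IsClosed() bool { return c.closed.Load() }

func main() {
	c := &client{
		send:    make(chan struct{}, 1),
		control: make(chan struct{}, 1),
		done:    make(chan struct{}),
	}
	var wg sync.WaitGroup
	for range 5 { // five racing closers, no "close of closed channel" panic
		wg.Add(1)
		go func() { defer wg.Done(); c.Close() }()
	}
	wg.Wait()
	fmt.Println("closed:", c.IsClosed())
}
```

Because select/default only guards against full channels, not closed ones, the hub must consult IsClosed() before sending; the flag is what makes closing observable to senders.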
-func NewClient(id string, sub Subscription, conn *websocket.Conn, hub *Hub, userOrgs []string) *Client { +func NewClient(ctx context.Context, id string, sub Subscription, conn *websocket.Conn, hub *Hub, userOrgs []string) *Client { // Limit the number of orgs to prevent memory exhaustion const maxOrgs = 1000 orgsToProcess := userOrgs if len(userOrgs) > maxOrgs { orgsToProcess = userOrgs[:maxOrgs] - logger.Warn("user has too many organizations, limiting", logger.Fields{ + logger.Warn(ctx, "user has too many organizations, limiting", logger.Fields{ "user_org_count": len(userOrgs), "max_orgs": maxOrgs, }) @@ -109,11 +110,11 @@ func (c *Client) Run(ctx context.Context, pingInterval, writeTimeout time.Durati for { select { case <-ctx.Done(): - logger.Debug("client context cancelled, shutting down", logger.Fields{"client_id": c.ID}) + logger.Debug(ctx, "client context cancelled, shutting down", logger.Fields{"client_id": c.ID}) return case <-c.done: - logger.Debug("client done signal received", logger.Fields{"client_id": c.ID}) + logger.Debug(ctx, "client done signal received", logger.Fields{"client_id": c.ID}) return case <-pingTicker.C: @@ -125,7 +126,7 @@ func (c *Client) Run(ctx context.Context, pingInterval, writeTimeout time.Durati } if err := c.write(ping, writeTimeout); err != nil { - logger.Warn("client ping failed", logger.Fields{ + logger.Warn(ctx, "client ping failed", logger.Fields{ "client_id": c.ID, "error": err.Error(), }) @@ -134,13 +135,13 @@ func (c *Client) Run(ctx context.Context, pingInterval, writeTimeout time.Durati case ctrl, ok := <-c.control: if !ok { - logger.Debug("client control channel closed", logger.Fields{"client_id": c.ID}) + logger.Debug(ctx, "client control channel closed", logger.Fields{"client_id": c.ID}) return } // Send control message (pong, shutdown notice, etc.) 
if err := c.write(ctrl, writeTimeout); err != nil { - logger.Warn("client control message send failed", logger.Fields{ + logger.Warn(ctx, "client control message send failed", logger.Fields{ "client_id": c.ID, "error": err.Error(), }) @@ -149,13 +150,13 @@ func (c *Client) Run(ctx context.Context, pingInterval, writeTimeout time.Durati case event, ok := <-c.send: if !ok { - logger.Debug("client send channel closed", logger.Fields{"client_id": c.ID}) + logger.Debug(ctx, "client send channel closed", logger.Fields{"client_id": c.ID}) return } // Write event (hub already logged delivery, so we only log failures here) if err := c.write(event, writeTimeout); err != nil { - logger.Warn("client event send failed", logger.Fields{ + logger.Warn(ctx, "client event send failed", logger.Fields{ "client_id": c.ID, "event_type": event.Type, "error": err.Error(), diff --git a/pkg/srv/edge_cases_test.go b/pkg/srv/edge_cases_test.go index 343520b..79e23da 100644 --- a/pkg/srv/edge_cases_test.go +++ b/pkg/srv/edge_cases_test.go @@ -1,6 +1,7 @@ package srv import ( + "context" "fmt" "strings" "testing" @@ -284,7 +285,7 @@ func TestOrganizationLimitEdgeCases(t *testing.T) { manyOrgs[i] = fmt.Sprintf("org%d", i) } - client := NewClient( + client := NewClient(context.Background(), "test-id", Subscription{Organization: "*", Username: "testuser"}, nil, @@ -319,7 +320,7 @@ func TestConcurrentMapAccess(t *testing.T) { hub := NewHub() for i := range 10 { - client := NewClient( + client := NewClient(context.Background(), fmt.Sprintf("client%d", i), Subscription{ Organization: "*", diff --git a/pkg/srv/hub.go b/pkg/srv/hub.go index b68a139..d312d43 100644 --- a/pkg/srv/hub.go +++ b/pkg/srv/hub.go @@ -14,11 +14,11 @@ import ( // Event represents a GitHub webhook event that will be broadcast to clients. // It contains the PR URL, timestamp, event type, and delivery ID from GitHub. type Event struct { - URL string `json:"url"` // Pull request URL (or repo URL for check events with race condition) - Timestamp time.Time `json:"timestamp"` // When the event occurred - Type string `json:"type"` // GitHub event type (e.g., "pull_request") - DeliveryID string `json:"delivery_id,omitempty"` // GitHub webhook delivery ID (unique per webhook) - CommitSHA string `json:"commit_sha,omitempty"` // Commit SHA for check events (used to look up PR when URL is repo-only) + URL string `json:"url"` // Pull request URL (or repo URL for check events with race condition) + Timestamp time.Time `json:"timestamp"` // When the event occurred + Type string `json:"type"` // GitHub event type (e.g., "pull_request") + DeliveryID string `json:"delivery_id,omitempty"` // GitHub webhook delivery ID (unique per webhook) + CommitSHA string `json:"commit_sha,omitempty"` // Commit SHA for check events (used to look up PR when URL is repo-only) } // Hub manages WebSocket clients and event broadcasting. @@ -81,11 +81,11 @@ func NewHub() *Hub { // The context should be passed from main for proper lifecycle management. 
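Illustrative wiring, not taken from the diff: with Run, Broadcast, and the logger helpers all context-aware, main can drive the hub's whole lifecycle from one cancellable context, which is what the "context should be passed from main" note above is getting at. The module path in the import is a placeholder.

```go
package main

import (
	"context"
	"os/signal"
	"syscall"
	"time"

	"example.com/project/pkg/srv" // assumed import path
)

func main() {
	// Cancelled automatically on SIGINT/SIGTERM.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	hub := srv.NewHub()
	go hub.Run(ctx) // exits on cancellation; its defer runs cleanup(ctx)

	hub.Broadcast(ctx, srv.Event{
		URL:       "https://github.com/org/repo/pull/1",
		Type:      "pull_request",
		Timestamp: time.Now(),
	}, map[string]any{"action": "opened"})

	<-ctx.Done() // block until a shutdown signal arrives
}
```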
func (h *Hub) Run(ctx context.Context) { defer close(h.stopped) - defer h.cleanup() + defer h.cleanup(ctx) - logger.Info("========================================", nil) - logger.Info("HUB STARTED - Fresh hub with 0 clients", nil) - logger.Info("========================================", nil) + logger.Info(ctx, "========================================", nil) + logger.Info(ctx, "HUB STARTED - Fresh hub with 0 clients", nil) + logger.Info(ctx, "========================================", nil) // Periodic client count logging (every minute) ticker := time.NewTicker(1 * time.Minute) @@ -94,10 +94,10 @@ func (h *Hub) Run(ctx context.Context) { for { select { case <-ctx.Done(): - logger.Info("hub shutting down", nil) + logger.Info(ctx, "hub shutting down", nil) return case <-h.stop: - logger.Info("hub stop requested", nil) + logger.Info(ctx, "hub stop requested", nil) return case <-ticker.C: @@ -108,7 +108,7 @@ func (h *Hub) Run(ctx context.Context) { clientDetails = append(clientDetails, fmt.Sprintf("%s@%s", client.subscription.Username, client.subscription.Organization)) } h.mu.RUnlock() - logger.Info("⏱️ PERIODIC CHECK", logger.Fields{ + logger.Info(ctx, "⏱️ PERIODIC CHECK", logger.Fields{ "total_clients": count, "clients": clientDetails, }) @@ -118,7 +118,7 @@ func (h *Hub) Run(ctx context.Context) { h.clients[client.ID] = client totalClients := len(h.clients) h.mu.Unlock() - logger.Info("CLIENT REGISTERED", logger.Fields{ + logger.Info(ctx, "CLIENT REGISTERED", logger.Fields{ "client_id": client.ID, "org": client.subscription.Organization, "user": client.subscription.Username, @@ -132,7 +132,7 @@ func (h *Hub) Run(ctx context.Context) { totalClients := len(h.clients) client.Close() h.mu.Unlock() - logger.Info("CLIENT UNREGISTERED", logger.Fields{ + logger.Info(ctx, "CLIENT UNREGISTERED", logger.Fields{ "client_id": clientID, "org": client.subscription.Organization, "user": client.subscription.Username, @@ -140,7 +140,7 @@ func (h *Hub) Run(ctx context.Context) { }) } else { h.mu.Unlock() - logger.Warn("attempted to unregister unknown client", logger.Fields{ + logger.Warn(ctx, "attempted to unregister unknown client", logger.Fields{ "client_id": clientID, }) } @@ -163,7 +163,7 @@ func (h *Hub) Run(ctx context.Context) { // Try to send (safe against closed channels) if h.trySendEvent(client, msg.event) { matched++ - logger.Info("delivered event to client", logger.Fields{ + logger.Info(ctx, "delivered event to client", logger.Fields{ "client_id": client.ID, "user": client.subscription.Username, "org": client.subscription.Organization, @@ -173,22 +173,22 @@ func (h *Hub) Run(ctx context.Context) { }) } else { dropped++ - logger.Warn("dropped event for client: channel full or closed", logger.Fields{ + logger.Warn(ctx, "dropped event for client: channel full or closed", logger.Fields{ "client_id": client.ID, }) } } } if totalClients == 0 { - logger.Warn("⚠️⚠️⚠️ broadcast with ZERO clients connected ⚠️⚠️⚠️", nil) - logger.Warn("⚠️ Event will be LOST", logger.Fields{ + logger.Warn(ctx, "⚠️⚠️⚠️ broadcast with ZERO clients connected ⚠️⚠️⚠️", nil) + logger.Warn(ctx, "⚠️ Event will be LOST", logger.Fields{ "event_type": msg.event.Type, "delivery_id": msg.event.DeliveryID, "pr_url": msg.event.URL, }) - logger.Warn("⚠️ Possible reasons: fresh deployment, all clients disconnected, or network issue", nil) + logger.Warn(ctx, "⚠️ Possible reasons: fresh deployment, all clients disconnected, or network issue", nil) } - logger.Info("broadcast event", logger.Fields{ + logger.Info(ctx, "broadcast event", 
logger.Fields{ "event_type": msg.event.Type, "delivery_id": msg.event.DeliveryID, "matched": matched, @@ -200,12 +200,12 @@ func (h *Hub) Run(ctx context.Context) { } // Broadcast sends an event to all matching clients. -func (h *Hub) Broadcast(event Event, payload map[string]any) { +func (h *Hub) Broadcast(ctx context.Context, event Event, payload map[string]any) { select { case h.broadcast <- broadcastMsg{event: event, payload: payload}: default: // Hub is at capacity or shutting down, drop the message - logger.Warn("dropping broadcast: hub at capacity or shutting down", nil) + logger.Warn(ctx, "dropping broadcast: hub at capacity or shutting down", nil) } } @@ -256,7 +256,7 @@ func (h *Hub) ClientCount() int { // // Note: There's still a tiny window between IsClosed() check and send where // Close() could be called, so we keep recover() as a safety net. -func (h *Hub) trySendEvent(client *Client, event Event) (sent bool) { +func (*Hub) trySendEvent(client *Client, event Event) (sent bool) { // Check if client is closed before attempting send // This prevents most races with client.Close() if client.IsClosed() { @@ -284,20 +284,20 @@ func (h *Hub) trySendEvent(client *Client, event Event) (sent bool) { // // CRITICAL THREADING NOTE: // This function MUST NOT send to client channels (send/control) because of race conditions: -// - Client.Close() can be called concurrently from multiple places (Handle defer, Run defer, etc.) -// - Once Close() starts, it closes all channels atomically -// - Trying to send to a closed channel panics, even with select/default -// - select/default only protects against FULL channels, not CLOSED channels +// - Client.Close() can be called concurrently from multiple places (Handle defer, Run defer, etc.) +// - Once Close() starts, it closes all channels atomically +// - Trying to send to a closed channel panics, even with select/default +// - select/default only protects against FULL channels, not CLOSED channels // // Instead, we rely on: -// - WebSocket connection close will signal the client -// - Client.Run() will detect context cancellation and exit gracefully -// - client.Close() is idempotent (sync.Once) so safe to call multiple times -func (h *Hub) cleanup() { +// - WebSocket connection close will signal the client +// - Client.Run() will detect context cancellation and exit gracefully +// - client.Close() is idempotent (sync.Once) so safe to call multiple times +func (h *Hub) cleanup(ctx context.Context) { h.mu.Lock() defer h.mu.Unlock() - logger.Info("Hub cleanup: closing client connections", logger.Fields{ + logger.Info(ctx, "Hub cleanup: closing client connections", logger.Fields{ "client_count": len(h.clients), }) @@ -305,9 +305,9 @@ func (h *Hub) cleanup() { // The WebSocket connection close and context cancellation are sufficient signals. 
for id, client := range h.clients { client.Close() - logger.Info("closed client during hub cleanup", logger.Fields{"client_id": id}) + logger.Info(ctx, "closed client during hub cleanup", logger.Fields{"client_id": id}) } h.clients = nil - logger.Info("Hub cleanup complete", nil) + logger.Info(ctx, "Hub cleanup complete", nil) } diff --git a/pkg/srv/hub_test.go b/pkg/srv/hub_test.go index 34ee1e8..ebdc16c 100644 --- a/pkg/srv/hub_test.go +++ b/pkg/srv/hub_test.go @@ -15,6 +15,7 @@ func TestHub(t *testing.T) { // Test registering clients - properly initialize using NewClient client1 := NewClient( + ctx, "client1", Subscription{Organization: "myorg", UserEventsOnly: true, Username: "alice"}, nil, // No websocket connection for unit test @@ -23,6 +24,7 @@ func TestHub(t *testing.T) { ) client2 := NewClient( + ctx, "client2", Subscription{Organization: "myorg"}, nil, // No websocket connection for unit test @@ -65,7 +67,7 @@ func TestHub(t *testing.T) { }, } - hub.Broadcast(event, payload) + hub.Broadcast(ctx, event, payload) // Both clients should receive the event select { diff --git a/pkg/srv/race_stress_test.go b/pkg/srv/race_stress_test.go index b2f49b3..2b8a798 100644 --- a/pkg/srv/race_stress_test.go +++ b/pkg/srv/race_stress_test.go @@ -32,7 +32,7 @@ func TestConcurrentClientDisconnect(t *testing.T) { const numClients = 10 var wg sync.WaitGroup - for i := 0; i < numClients; i++ { + for i := range numClients { wg.Add(1) go func(clientNum int) { defer wg.Done() @@ -45,7 +45,7 @@ func TestConcurrentClientDisconnect(t *testing.T) { } // Create client (we'll use nil for websocket since we're not actually writing) - client := NewClient( + client := NewClient(ctx, testClientID(clientNum), sub, nil, // WebSocket not needed for this test @@ -115,7 +115,7 @@ func TestClientCloseIdempotency(t *testing.T) { EventTypes: []string{"pull_request"}, } - client := NewClient( + client := NewClient(ctx, "test-client-close-idempotent", sub, nil, @@ -127,7 +127,7 @@ func TestClientCloseIdempotency(t *testing.T) { const numGoroutines = 20 var wg sync.WaitGroup - for i := 0; i < numGoroutines; i++ { + for range numGoroutines { wg.Add(1) go func() { defer wg.Done() @@ -181,13 +181,13 @@ func TestConcurrentBroadcastAndDisconnect(t *testing.T) { // Create clients clients := make([]*Client, numClients) - for i := 0; i < numClients; i++ { + for i := range numClients { sub := Subscription{ Organization: "testorg", Username: "testuser", EventTypes: []string{"pull_request"}, } - client := NewClient( + client := NewClient(ctx, testClientID(i), sub, nil, @@ -207,7 +207,7 @@ func TestConcurrentBroadcastAndDisconnect(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - for i := 0; i < numEvents; i++ { + for i := range numEvents { event := Event{ URL: "https://github.com/test/repo/pull/123", Type: "pull_request", @@ -220,14 +220,14 @@ func TestConcurrentBroadcastAndDisconnect(t *testing.T) { }, }, } - hub.Broadcast(event, payload) + hub.Broadcast(ctx, event, payload) time.Sleep(1 * time.Millisecond) } }() // Concurrently disconnect clients (realistic: only via unregister, not direct Close) // In production, Handle() calls hub.Unregister() and the hub handles client.Close() - for i := 0; i < numClients; i++ { + for i := range numClients { wg.Add(1) go func(idx int) { defer wg.Done() @@ -267,7 +267,7 @@ func TestRapidConnectDisconnect(t *testing.T) { const numCycles = 20 var wg sync.WaitGroup - for i := 0; i < numCycles; i++ { + for i := range numCycles { wg.Add(1) go func(cycle int) { defer wg.Done() @@ -278,7 +278,7 @@ 
func TestRapidConnectDisconnect(t *testing.T) { EventTypes: []string{"pull_request"}, } - client := NewClient( + client := NewClient(ctx, testClientID(cycle), sub, nil, @@ -326,12 +326,12 @@ func TestHubShutdownWithActiveClients(t *testing.T) { // Create several clients const numClients = 10 - for i := 0; i < numClients; i++ { + for i := range numClients { sub := Subscription{ Organization: "testorg", Username: "testuser", } - client := NewClient( + client := NewClient(ctx, testClientID(i), sub, nil, diff --git a/pkg/srv/websocket.go b/pkg/srv/websocket.go index f209f85..0dfd68c 100644 --- a/pkg/srv/websocket.go +++ b/pkg/srv/websocket.go @@ -104,14 +104,14 @@ func (h *WebSocketHandler) PreValidateAuth(r *http.Request) bool { } // extractGitHubToken extracts and validates the GitHub token from the request. -func (h *WebSocketHandler) extractGitHubToken(ws *websocket.Conn, ip string) (string, bool) { +func (h *WebSocketHandler) extractGitHubToken(ctx context.Context, ws *websocket.Conn, ip string) (string, bool) { if h.testMode { return "", true } authHeader := ws.Request().Header.Get("Authorization") if authHeader == "" { - logger.Warn("WebSocket authentication failed: missing Authorization header", logger.Fields{ + logger.Warn(ctx, "WebSocket authentication failed: missing Authorization header", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "path": ws.Request().URL.Path, @@ -121,7 +121,7 @@ func (h *WebSocketHandler) extractGitHubToken(ws *websocket.Conn, ip string) (st const bearerPrefix = "Bearer " if !strings.HasPrefix(authHeader, bearerPrefix) { - logger.Warn("WebSocket authentication failed: invalid Authorization header format", logger.Fields{ + logger.Warn(ctx, "WebSocket authentication failed: invalid Authorization header format", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "path": ws.Request().URL.Path, @@ -137,7 +137,7 @@ func (h *WebSocketHandler) extractGitHubToken(ws *websocket.Conn, ip string) (st if len(githubToken) >= tokenPrefixLength { tokenPrefix = githubToken[:tokenPrefixLength] } - logger.Warn("WebSocket authentication failed: invalid GitHub token format", logger.Fields{ + logger.Warn(ctx, "WebSocket authentication failed: invalid GitHub token format", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "path": ws.Request().URL.Path, @@ -205,7 +205,7 @@ func determineErrorInfo(err error, username string, orgName string, userOrgs []s } // sendErrorResponse sends an error response to the WebSocket client. 
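For reference, the rejection path below writes a single JSON frame before the connection is closed, and a client can decode it with the same x/net/websocket package the server uses. The type and error fields appear in the hunks that follow; the message field and the client wiring here are assumptions for illustration.

```go
package main

import (
	"fmt"
	"log"

	"golang.org/x/net/websocket"
)

// serverError mirrors the error frame sendErrorResponse emits.
type serverError struct {
	Type    string `json:"type"`    // always "error"
	Error   string `json:"error"`   // machine-readable code
	Message string `json:"message"` // human-readable reason
}

func main() {
	ws, err := websocket.Dial("ws://localhost:8080/ws", "", "http://localhost/")
	if err != nil {
		log.Fatal(err)
	}
	defer ws.Close()

	var frame serverError
	if err := websocket.JSON.Receive(ws, &frame); err != nil {
		log.Fatal(err)
	}
	if frame.Type == "error" {
		fmt.Printf("rejected: %s (%s)\n", frame.Message, frame.Error)
	}
}
```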
-func sendErrorResponse(ws *websocket.Conn, errInfo errorInfo, ip string) error { +func sendErrorResponse(ctx context.Context, ws *websocket.Conn, errInfo errorInfo, ip string) error { errorResp := map[string]string{ "type": "error", "error": errInfo.code, @@ -213,12 +213,12 @@ func sendErrorResponse(ws *websocket.Conn, errInfo errorInfo, ip string) error { } if err := ws.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil { - logger.Error("failed to set write deadline", err, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to set write deadline", err, logger.Fields{"ip": ip}) return err } if err := websocket.JSON.Send(ws, errorResp); err != nil { - logger.Error("failed to send error response to client", err, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to send error response to client", err, logger.Fields{"ip": ip}) return err } @@ -264,35 +264,41 @@ func (h *WebSocketHandler) readSubscription(ws *websocket.Conn, ip string) (Subs return sub, nil } +// authErrorParams groups authentication error parameters. +// +//nolint:govet // Field order optimized for readability over minimal memory padding +type authErrorParams struct { + userOrgs []string + githubToken string + ip string + username string + orgName string + logContext string +} + // handleAuthError handles authentication errors with consistent logging and response. -func (*WebSocketHandler) handleAuthError( - ws *websocket.Conn, - err error, - githubToken, ip, username, orgName string, - userOrgs []string, - logContext string, -) error { - errInfo := determineErrorInfo(err, username, orgName, userOrgs) +func (*WebSocketHandler) handleAuthError(ctx context.Context, ws *websocket.Conn, err error, params authErrorParams) error { + errInfo := determineErrorInfo(err, params.username, params.orgName, params.userOrgs) tokenPrefix := "" - if len(githubToken) >= tokenPrefixLength { - tokenPrefix = githubToken[:tokenPrefixLength] + if len(params.githubToken) >= tokenPrefixLength { + tokenPrefix = params.githubToken[:tokenPrefixLength] } - logger.Error(logContext, err, logger.Fields{ - "ip": ip, - "org": orgName, - "username": username, + logger.Error(ctx, params.logContext, err, logger.Fields{ + "ip": params.ip, + "org": params.orgName, + "username": params.username, "token_prefix": tokenPrefix, - "token_length": len(githubToken), + "token_length": len(params.githubToken), "reason": errInfo.reason, }) - if sendErr := sendErrorResponse(ws, errInfo, ip); sendErr != nil { + if sendErr := sendErrorResponse(ctx, ws, errInfo, params.ip); sendErr != nil { return sendErr } - logger.Info("sent error to client", logger.Fields{ - "ip": ip, "error_code": errInfo.code, "error_reason": errInfo.reason, + logger.Info(ctx, "sent error to client", logger.Fields{ + "ip": params.ip, "error_code": errInfo.code, "error_reason": errInfo.reason, }) return fmt.Errorf("%s: %w", errInfo.reason, err) @@ -303,15 +309,18 @@ func (h *WebSocketHandler) validateWildcardOrg( ctx context.Context, ws *websocket.Conn, sub *Subscription, ghClient *github.Client, githubToken, ip string, ) ([]string, error) { - logger.Info("validating GitHub authentication for wildcard org subscription", logger.Fields{"ip": ip}) + logger.Info(ctx, "validating GitHub authentication for wildcard org subscription", logger.Fields{"ip": ip}) username, userOrgs, err := ghClient.UserAndOrgs(ctx) if err != nil { - return nil, h.handleAuthError(ws, err, githubToken, ip, "", "", nil, - "GitHub auth failed for wildcard org subscription") + return nil, h.handleAuthError(ctx, ws, err, 
authErrorParams{ + githubToken: githubToken, + ip: ip, + logContext: "GitHub auth failed for wildcard org subscription", + }) } - logger.Info("GitHub authentication successful for wildcard org subscription", logger.Fields{ + logger.Info(ctx, "GitHub authentication successful for wildcard org subscription", logger.Fields{ "ip": ip, "username": username, "org_count": len(userOrgs), }) @@ -324,17 +333,23 @@ func (h *WebSocketHandler) validateSpecificOrg( ctx context.Context, ws *websocket.Conn, sub *Subscription, ghClient *github.Client, githubToken, ip string, ) ([]string, error) { - logger.Info("validating GitHub authentication and org membership", logger.Fields{ + logger.Info(ctx, "validating GitHub authentication and org membership", logger.Fields{ "ip": ip, "org": sub.Organization, }) username, userOrgs, err := ghClient.ValidateOrgMembership(ctx, sub.Organization) if err != nil { - return nil, h.handleAuthError(ws, err, githubToken, ip, username, sub.Organization, userOrgs, - "GitHub auth/org membership validation failed") + return nil, h.handleAuthError(ctx, ws, err, authErrorParams{ + githubToken: githubToken, + ip: ip, + username: username, + orgName: sub.Organization, + userOrgs: userOrgs, + logContext: "GitHub auth/org membership validation failed", + }) } - logger.Info("GitHub authentication and org membership validated successfully", logger.Fields{ + logger.Info(ctx, "GitHub authentication and org membership validated successfully", logger.Fields{ "ip": ip, "org": sub.Organization, "username": username, "org_count": len(userOrgs), }) @@ -347,17 +362,20 @@ func (h *WebSocketHandler) validateNoOrg( ctx context.Context, ws *websocket.Conn, sub *Subscription, ghClient *github.Client, githubToken, ip string, ) ([]string, error) { - logger.Info("validating GitHub authentication (no org specified in subscription)", logger.Fields{ + logger.Info(ctx, "validating GitHub authentication (no org specified in subscription)", logger.Fields{ "ip": ip, "subscription_org": sub.Organization, }) username, userOrgs, err := ghClient.UserAndOrgs(ctx) if err != nil { - return nil, h.handleAuthError(ws, err, githubToken, ip, "", "", nil, - "GitHub auth failed (no specific org)") + return nil, h.handleAuthError(ctx, ws, err, authErrorParams{ + githubToken: githubToken, + ip: ip, + logContext: "GitHub auth failed (no specific org)", + }) } - logger.Info("GitHub authentication successful", logger.Fields{ + logger.Info(ctx, "GitHub authentication successful", logger.Fields{ "ip": ip, "username": username, "org_count": len(userOrgs), }) @@ -366,7 +384,7 @@ func (h *WebSocketHandler) validateNoOrg( // For GitHub Apps with no org specified, auto-set to their installation org if strings.HasPrefix(username, "app[") && sub.Organization == "" && len(userOrgs) == 1 { sub.Organization = userOrgs[0] - logger.Info("auto-setting GitHub App subscription to installation org", logger.Fields{ + logger.Info(ctx, "auto-setting GitHub App subscription to installation org", logger.Fields{ "ip": ip, "org": sub.Organization, "app": username, }) } @@ -440,9 +458,9 @@ func (wc *wsCloser) IsClosed() bool { // 4. 
Race: check done (open) → another goroutine closes all channels → send to control → PANIC // // Instead, we rely on: -// - WebSocket connection close will be detected by the client -// - Context cancellation signals Client.Run() to exit gracefully -// - Hub.Unregister() handles client cleanup asynchronously +// - WebSocket connection close will be detected by the client +// - Context cancellation signals Client.Run() to exit gracefully +// - Hub.Unregister() handles client cleanup asynchronously func closeWebSocket(wc *wsCloser, client *Client, ip string) { log.Printf("WebSocket Handle() cleanup - closing connection for IP %s", ip) @@ -491,7 +509,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { }() // Log incoming WebSocket request - logger.Info("WebSocket connection attempt", logger.Fields{ + logger.Info(ctx, "WebSocket connection attempt", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "path": ws.Request().URL.Path, @@ -514,7 +532,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { } }() - githubToken, ok := h.extractGitHubToken(ws, ip) + githubToken, ok := h.extractGitHubToken(ctx, ws, ip) if !ok { // Send 403 error response to client errorResp := map[string]string{ @@ -526,11 +544,11 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { // Try to send error response if err := ws.SetWriteDeadline(time.Now().Add(2 * time.Second)); err == nil { if sendErr := websocket.JSON.Send(ws, errorResp); sendErr != nil { - logger.Error("failed to send 403 error response", sendErr, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to send 403 error response", sendErr, logger.Fields{"ip": ip}) } } - logger.Warn("WebSocket connection rejected: 403 Forbidden - authentication failed", logger.Fields{ + logger.Warn(ctx, "WebSocket connection rejected: 403 Forbidden - authentication failed", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "reason": "invalid_token", @@ -551,11 +569,11 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { if err := ws.SetWriteDeadline(time.Now().Add(2 * time.Second)); err == nil { if sendErr := websocket.JSON.Send(ws, errorResp); sendErr != nil { - logger.Error("failed to send reservation expired error", sendErr, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to send reservation expired error", sendErr, logger.Fields{"ip": ip}) } } - logger.Warn("WebSocket connection rejected: reservation expired", logger.Fields{ + logger.Warn(ctx, "WebSocket connection rejected: reservation expired", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), }) @@ -575,7 +593,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { // Read subscription sub, err := h.readSubscription(ws, ip) if err != nil { - logger.Warn("WebSocket connection rejected: failed to read subscription", logger.Fields{ + logger.Warn(ctx, "WebSocket connection rejected: failed to read subscription", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "error": err.Error(), @@ -603,11 +621,11 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { // Try to send error response if err := ws.SetWriteDeadline(time.Now().Add(2 * time.Second)); err == nil { if sendErr := websocket.JSON.Send(ws, errorResp); sendErr != nil { - logger.Error("failed to send subscription error response", sendErr, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to send subscription error response", sendErr, logger.Fields{"ip": ip}) } } - logger.Warn("WebSocket connection rejected: invalid subscription", logger.Fields{ + logger.Warn(ctx, 
"WebSocket connection rejected: invalid subscription", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "error": err.Error(), @@ -630,11 +648,11 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { // Try to send error response if err := ws.SetWriteDeadline(time.Now().Add(2 * time.Second)); err == nil { if sendErr := websocket.JSON.Send(ws, errorResp); sendErr != nil { - logger.Error("failed to send event type error response", sendErr, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to send event type error response", sendErr, logger.Fields{"ip": ip}) } } - logger.Warn("WebSocket connection rejected: event type not allowed", logger.Fields{ + logger.Warn(ctx, "WebSocket connection rejected: event type not allowed", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "event_type": requestedType, @@ -657,7 +675,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { userOrgs, err := h.validateAuth(ctx, ws, &sub, githubToken, ip) if err != nil { // Error response already sent by validateAuth - logger.Warn("WebSocket connection rejected: authentication/authorization failed", logger.Fields{ + logger.Warn(ctx, "WebSocket connection rejected: authentication/authorization failed", logger.Fields{ "ip": ip, "user_agent": ws.Request().UserAgent(), "org": sub.Organization, @@ -674,7 +692,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset)))) if err != nil { // Critical security failure - cannot continue without secure randomness - logger.Error("CRITICAL: failed to generate secure random client ID", err, logger.Fields{"ip": ip}) + logger.Error(ctx, "CRITICAL: failed to generate secure random client ID", err, logger.Fields{"ip": ip}) // Send error to client before returning errorResp := map[string]string{ "type": "error", @@ -682,13 +700,13 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { "message": "Failed to initialize secure session", } if sendErr := websocket.JSON.Send(ws, errorResp); sendErr != nil { - logger.Error("failed to send error response", sendErr, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to send error response", sendErr, logger.Fields{"ip": ip}) } return } id[i] = charset[n.Int64()] } - client = NewClient( + client = NewClient(ctx, string(id), sub, ws, @@ -703,7 +721,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { log.Printf("✅ NEW CLIENT CONNECTING: user=%s org=%s ip=%s client_id=%s (will be client #%d)", sub.Username, sub.Organization, ip, client.ID, currentClients+1) log.Println("========================================") - logger.Info("WebSocket connection established", logger.Fields{ + logger.Info(ctx, "WebSocket connection established", logger.Fields{ "ip": ip, "org": sub.Organization, "user": sub.Username, @@ -723,22 +741,22 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { // Set a write deadline for the success response if err := ws.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil { - logger.Error("failed to set write deadline for success response", err, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to set write deadline for success response", err, logger.Fields{"ip": ip}) return } if err := websocket.JSON.Send(ws, successResp); err != nil { - logger.Error("failed to send success response to client", err, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to send success response to client", err, logger.Fields{"ip": ip}) return } // Reset write deadline after successful send if err := 
ws.SetWriteDeadline(time.Time{}); err != nil { - logger.Error("failed to reset write deadline", err, logger.Fields{"ip": ip}) + logger.Error(ctx, "failed to reset write deadline", err, logger.Fields{"ip": ip}) return } - logger.Info("sent subscription confirmation to client", logger.Fields{ + logger.Info(ctx, "sent subscription confirmation to client", logger.Fields{ "ip": ip, "org": sub.Organization, "client_id": client.ID, @@ -753,7 +771,7 @@ func (h *WebSocketHandler) Handle(ws *websocket.Conn) { log.Printf("❌ CLIENT DISCONNECTING: user=%s org=%s ip=%s client_id=%s", sub.Username, sub.Organization, ip, client.ID) log.Println("========================================") - logger.Info("WebSocket disconnected", logger.Fields{ + logger.Info(ctx, "WebSocket disconnected", logger.Fields{ "ip": ip, "client_id": client.ID, "user": sub.Username, diff --git a/pkg/webhook/extractor_test.go b/pkg/webhook/extractor_test.go index 554ef4e..140ad05 100644 --- a/pkg/webhook/extractor_test.go +++ b/pkg/webhook/extractor_test.go @@ -1,6 +1,9 @@ package webhook -import "testing" +import ( + "context" + "testing" +) func TestExtractPRURL(t *testing.T) { tests := []struct { @@ -78,7 +81,7 @@ func TestExtractPRURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := ExtractPRURL(tt.eventType, tt.payload); got != tt.want { + if got := ExtractPRURL(context.Background(), tt.eventType, tt.payload); got != tt.want { t.Errorf("ExtractPRURL() = %v, want %v", got, tt.want) } }) diff --git a/pkg/webhook/handler.go b/pkg/webhook/handler.go index 15415f1..56bda86 100644 --- a/pkg/webhook/handler.go +++ b/pkg/webhook/handler.go @@ -3,6 +3,7 @@ package webhook import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/hex" @@ -49,9 +50,13 @@ func NewHandler(h *srv.Hub, secret string, allowedEvents []string) *Handler { } // ServeHTTP processes GitHub webhook events. 
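A test-side sketch for exercising ServeHTTP below: GitHub computes an HMAC-SHA256 over the raw request body with the shared webhook secret and sends it as "sha256=&lt;hex&gt;" in the X-Hub-Signature-256 header, with X-GitHub-Event carrying the event type. The sign helper and request wiring are illustrative, not part of the handler.

```go
package main

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"net/http/httptest"
)

// sign reproduces the signature GitHub would attach to a delivery.
func sign(secret string, body []byte) string {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write(body)
	return "sha256=" + hex.EncodeToString(mac.Sum(nil))
}

func main() {
	secret := "webhook-secret"
	body := []byte(`{"action":"opened"}`)

	req := httptest.NewRequest("POST", "/webhook", bytes.NewReader(body))
	req.Header.Set("X-GitHub-Event", "pull_request")
	req.Header.Set("X-Hub-Signature-256", sign(secret, body))
	fmt.Println(req.Header.Get("X-Hub-Signature-256"))
}
```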
+// +//nolint:maintidx // Webhook processing requires comprehensive validation and error handling func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // Log incoming webhook request details - logger.Info("webhook request received", logger.Fields{ + logger.Info(ctx, "webhook request received", logger.Fields{ "method": r.Method, "url": r.URL.String(), "remote_addr": r.RemoteAddr, @@ -62,7 +67,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }) if r.Method != http.MethodPost { - logger.Warn("webhook rejected: invalid method", logger.Fields{ + logger.Warn(ctx, "webhook rejected: invalid method", logger.Fields{ "method": r.Method, "remote_addr": r.RemoteAddr, "path": r.URL.Path, @@ -77,7 +82,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Check if event type is allowed if h.allowedEventsMap != nil && !h.allowedEventsMap[eventType] { - logger.Warn("webhook event type not allowed", logger.Fields{ + logger.Warn(ctx, "webhook event type not allowed", logger.Fields{ "event_type": eventType, "delivery_id": deliveryID, }) @@ -87,7 +92,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Check content length before reading if r.ContentLength > maxPayloadSize { - logger.Warn("webhook rejected: payload too large", logger.Fields{ + logger.Warn(ctx, "webhook rejected: payload too large", logger.Fields{ "content_length": r.ContentLength, "max_size": maxPayloadSize, "delivery_id": deliveryID, @@ -100,7 +105,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Read body body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) if err != nil { - logger.Error("error reading webhook body", err, logger.Fields{"delivery_id": deliveryID}) + logger.Error(ctx, "error reading webhook body", err, logger.Fields{"delivery_id": deliveryID}) http.Error(w, "bad request", http.StatusBadRequest) return } @@ -112,7 +117,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Verify signature if !VerifySignature(body, signature, h.secret) { - logger.Warn("webhook rejected: 401 Unauthorized - signature verification failed", logger.Fields{ + logger.Warn(ctx, "webhook rejected: 401 Unauthorized - signature verification failed", logger.Fields{ "delivery_id": deliveryID, "event_type": eventType, "remote_addr": r.RemoteAddr, @@ -126,7 +131,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Parse payload var payload map[string]any if err := json.Unmarshal(body, &payload); err != nil { - logger.Error("webhook rejected: 400 Bad Request - error parsing payload", err, logger.Fields{ + logger.Error(ctx, "webhook rejected: 400 Bad Request - error parsing payload", err, logger.Fields{ "delivery_id": deliveryID, "event_type": eventType, "remote_addr": r.RemoteAddr, @@ -140,13 +145,13 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if eventType == "check_run" || eventType == "check_suite" { payloadJSON, err := json.Marshal(payload) if err != nil { - logger.Warn("failed to marshal check event payload", logger.Fields{ + logger.Warn(ctx, "failed to marshal check event payload", logger.Fields{ "event_type": eventType, "delivery_id": deliveryID, "error": err.Error(), }) } else { - logger.Info("received check event - full payload for debugging", logger.Fields{ + logger.Info(ctx, "received check event - full payload for debugging", logger.Fields{ "event_type": eventType, "delivery_id": deliveryID, "payload": 
string(payloadJSON), @@ -155,20 +160,20 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Extract PR URL - prURL := ExtractPRURL(eventType, payload) + prURL := ExtractPRURL(ctx, eventType, payload) if prURL == "" { // For non-check events, log payload and return early if eventType != "check_run" && eventType != "check_suite" { // Log full payload to understand the structure (for non-check events) payloadJSON, err := json.Marshal(payload) if err != nil { - logger.Warn("failed to marshal payload for logging", logger.Fields{ + logger.Warn(ctx, "failed to marshal payload for logging", logger.Fields{ "event_type": eventType, "delivery_id": deliveryID, "error": err.Error(), }) } else { - logger.Info("no PR URL found in event - full payload", logger.Fields{ + logger.Info(ctx, "no PR URL found in event - full payload", logger.Fields{ "event_type": eventType, "delivery_id": deliveryID, "payload": string(payloadJSON), @@ -192,7 +197,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // If we can't extract repo URL, drop the event if repoURL == "" { // Can't extract even repo URL - must drop the event - logger.Warn("⛔ DROPPING CHECK EVENT - no PR URL or repo URL", logger.Fields{ + logger.Warn(ctx, "⛔ DROPPING CHECK EVENT - no PR URL or repo URL", logger.Fields{ "event_type": eventType, "delivery_id": deliveryID, "commit_sha": commitSHA, @@ -203,7 +208,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // We can still broadcast using repo URL - org-based subscriptions will work - logger.Warn("⚠️ CHECK EVENT RACE CONDITION DETECTED", logger.Fields{ + logger.Warn(ctx, "⚠️ CHECK EVENT RACE CONDITION DETECTED", logger.Fields{ "event_type": eventType, "delivery_id": deliveryID, "commit_sha": commitSHA, @@ -232,11 +237,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Get client count before broadcasting (for debugging delivery issues) clientCount := h.hub.ClientCount() - h.hub.Broadcast(event, payload) + h.hub.Broadcast(ctx, event, payload) w.WriteHeader(http.StatusOK) if _, err := w.Write([]byte("OK")); err != nil { - logger.Error("failed to write response", err, logger.Fields{"delivery_id": deliveryID}) + logger.Error(ctx, "failed to write response", err, logger.Fields{"delivery_id": deliveryID}) } // Log successful webhook processing with client count for debugging @@ -257,7 +262,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { logFields["url_type"] = "pull_request" } - logger.Info("webhook processed successfully", logFields) + logger.Info(ctx, "webhook processed successfully", logFields) } // VerifySignature validates the GitHub webhook signature. @@ -279,7 +284,7 @@ func VerifySignature(payload []byte, signature, secret string) bool { } // ExtractPRURL extracts the pull request URL from various event types. 
-func ExtractPRURL(eventType string, payload map[string]any) string { +func ExtractPRURL(ctx context.Context, eventType string, payload map[string]any) string { switch eventType { case "pull_request", "pull_request_review", "pull_request_review_comment": if pr, ok := payload["pull_request"].(map[string]any); ok { @@ -299,12 +304,12 @@ func ExtractPRURL(eventType string, payload map[string]any) string { case "check_run", "check_suite": // Extract PR URLs from check events if available if checkRun, ok := payload["check_run"].(map[string]any); ok { - if url := extractPRFromCheckEvent(checkRun, payload, eventType); url != "" { + if url := extractPRFromCheckEvent(ctx, checkRun, payload, eventType); url != "" { return url } } if checkSuite, ok := payload["check_suite"].(map[string]any); ok { - if url := extractPRFromCheckEvent(checkSuite, payload, eventType); url != "" { + if url := extractPRFromCheckEvent(ctx, checkSuite, payload, eventType); url != "" { return url } } @@ -313,7 +318,7 @@ func ExtractPRURL(eventType string, payload map[string]any) string { for k := range payload { payloadKeys = append(payloadKeys, k) } - logger.Warn("no PR URL found in check event", logger.Fields{ + logger.Warn(ctx, "no PR URL found in check event", logger.Fields{ "event_type": eventType, "has_check_run": payload["check_run"] != nil, "has_check_suite": payload["check_suite"] != nil, @@ -326,10 +331,10 @@ func ExtractPRURL(eventType string, payload map[string]any) string { } // extractPRFromCheckEvent extracts PR URL from check_run or check_suite events. -func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, eventType string) string { +func extractPRFromCheckEvent(ctx context.Context, checkEvent map[string]any, payload map[string]any, eventType string) string { prs, ok := checkEvent["pull_requests"].([]any) if !ok || len(prs) == 0 { - logger.Info("check event has no pull_requests array", logger.Fields{ + logger.Info(ctx, "check event has no pull_requests array", logger.Fields{ "event_type": eventType, "has_pr_array": ok, "pr_array_length": len(prs), @@ -340,7 +345,7 @@ func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, pr, ok := prs[0].(map[string]any) if !ok { - logger.Warn("pull_requests[0] is not a map", logger.Fields{ + logger.Warn(ctx, "pull_requests[0] is not a map", logger.Fields{ "event_type": eventType, "pr_type": fmt.Sprintf("%T", prs[0]), }) @@ -349,7 +354,7 @@ func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, // Try html_url first if htmlURL, ok := pr["html_url"].(string); ok { - logger.Info("extracted PR URL from check event html_url", logger.Fields{ + logger.Info(ctx, "extracted PR URL from check event html_url", logger.Fields{ "event_type": eventType, "pr_url": htmlURL, }) @@ -359,7 +364,7 @@ func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, // Fallback: construct from number num, ok := pr["number"].(float64) if !ok { - logger.Warn("PR number not found in check event", logger.Fields{ + logger.Warn(ctx, "PR number not found in check event", logger.Fields{ "event_type": eventType, "pr_keys": getMapKeys(pr), }) @@ -368,7 +373,7 @@ func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, repo, ok := payload["repository"].(map[string]any) if !ok { - logger.Warn("repository not found in payload", logger.Fields{ + logger.Warn(ctx, "repository not found in payload", logger.Fields{ "event_type": eventType, }) return "" @@ -376,7 +381,7 @@ func 
extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, repoURL, ok := repo["html_url"].(string) if !ok { - logger.Warn("repository html_url not found", logger.Fields{ + logger.Warn(ctx, "repository html_url not found", logger.Fields{ "event_type": eventType, "repo_keys": getMapKeys(repo), }) @@ -384,7 +389,7 @@ func extractPRFromCheckEvent(checkEvent map[string]any, payload map[string]any, } constructedURL := repoURL + "/pull/" + strconv.Itoa(int(num)) - logger.Info("constructed PR URL from check event", logger.Fields{ + logger.Info(ctx, "constructed PR URL from check event", logger.Fields{ "event_type": eventType, "pr_url": constructedURL, "pr_number": int(num), @@ -421,4 +426,3 @@ func extractCommitSHA(eventType string, payload map[string]any) string { } return "" } -
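
Note on the signature check gating the handler above: this patch only threads ctx through the logging around VerifySignature; the function body itself lies outside these hunks. For review context, given the crypto/hmac, crypto/sha256, and encoding/hex imports already present in handler.go and GitHub's documented X-Hub-Signature-256 scheme ("sha256=" plus the hex HMAC of the raw body), a minimal sketch of such a verifier could look like the following. This is an illustration under those assumptions, not the actual body of VerifySignature.

	package webhook

	import (
		"crypto/hmac"
		"crypto/sha256"
		"encoding/hex"
		"strings"
	)

	// verifySignatureSketch (hypothetical) checks a "sha256=<hex>" webhook
	// signature against an HMAC-SHA256 of the raw request body.
	func verifySignatureSketch(payload []byte, signature, secret string) bool {
		const prefix = "sha256="
		if !strings.HasPrefix(signature, prefix) {
			return false
		}
		mac := hmac.New(sha256.New, []byte(secret))
		mac.Write(payload) // hash.Hash.Write never returns an error
		expected := hex.EncodeToString(mac.Sum(nil))
		// hmac.Equal compares in constant time, which matters when the
		// comparison is reachable by untrusted callers.
		return hmac.Equal([]byte(expected), []byte(strings.TrimPrefix(signature, prefix)))
	}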
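
The ExtractPRURL signature change means every caller now supplies a context explicitly, as the extractor_test.go hunk shows. A small usage sketch of the number-based fallback path in extractPRFromCheckEvent follows; the repository URL and PR number are hypothetical values chosen for illustration.

	package webhook

	import (
		"context"
		"fmt"
	)

	func ExampleExtractPRURL() {
		payload := map[string]any{
			"check_run": map[string]any{
				"pull_requests": []any{
					// No html_url here, so extraction falls through to the
					// number fallback and builds the URL from the repository.
					map[string]any{"number": float64(42)}, // JSON numbers decode as float64
				},
			},
			"repository": map[string]any{
				"html_url": "https://github.com/example/repo",
			},
		}
		fmt.Println(ExtractPRURL(context.Background(), "check_run", payload))
		// Output: https://github.com/example/repo/pull/42
	}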