diff --git a/cmd/server/main.go b/cmd/server/main.go index 220cfb2..a5f8227 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -46,7 +46,6 @@ var ( leEmail = flag.String("le-email", "", "Contact email for Let's Encrypt notifications") maxConnsPerIP = flag.Int("max-conns-per-ip", 10, "Maximum WebSocket connections per IP") maxConnsTotal = flag.Int("max-conns-total", 1000, "Maximum total WebSocket connections") - rateLimit = flag.Int("rate-limit", 100, "Maximum requests per minute per IP") allowedEvents = flag.String("allowed-events", func() string { if value := os.Getenv("ALLOWED_WEBHOOK_EVENTS"); value != "" { return value @@ -92,8 +91,7 @@ func main() { hub := srv.NewHub() go hub.Run(ctx) - // Create security components - rateLimiter := security.NewRateLimiter(*rateLimit) + // Create connection limiter for WebSocket connections connLimiter := security.NewConnectionLimiter(*maxConnsPerIP, *maxConnsTotal) mux := http.NewServeMux() @@ -133,16 +131,6 @@ func main() { return } - // Rate limiting - if !rateLimiter.Allow(ip) { - log.Printf("Webhook 429: rate limit exceeded ip=%s", ip) - w.WriteHeader(http.StatusTooManyRequests) - if _, err := w.Write([]byte("429 Too Many Requests: Rate limit exceeded\n")); err != nil { - log.Printf("failed to write 429 response: %v", err) - } - return - } - webhookHandler.ServeHTTP(w, r) log.Printf("Webhook complete: ip=%s duration=%v", ip, time.Since(startTime)) }) @@ -185,16 +173,6 @@ func main() { return } - // Rate limiting check - if !rateLimiter.Allow(ip) { - log.Printf("WebSocket 429: rate limit exceeded ip=%s", ip) - w.WriteHeader(http.StatusTooManyRequests) - if _, err := w.Write([]byte("429 Too Many Requests: Rate limit exceeded\n")); err != nil { - log.Printf("failed to write 429 response: %v", err) - } - return - } - // Pre-validate authentication before WebSocket upgrade authHeader := r.Header.Get("Authorization") if !wsHandler.PreValidateAuth(r) { @@ -287,9 +265,6 @@ func main() { // Stop accepting new connections hub.Stop() - // Stop the rate limiter cleanup routine - rateLimiter.Stop() - // Stop the connection limiter cleanup routine connLimiter.Stop() diff --git a/pkg/client/client.go b/pkg/client/client.go index 42f6274..7ee381d 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -468,14 +468,25 @@ func (c *Client) connect(ctx context.Context) error { // Start ping sender (sends to write channel, not directly to websocket) pingCtx, cancelPing := context.WithCancel(ctx) defer cancelPing() - go c.sendPings(pingCtx) + pingDone := make(chan struct{}) + go func() { + c.sendPings(pingCtx) + close(pingDone) + }() // Read events - when this returns, cancel everything readErr := c.readEvents(ctx, ws) - // Stop write pump and ping sender - cancelWrite() + // Stop ping sender first - this ensures no more writes will be queued cancelPing() + <-pingDone // Wait for ping sender to fully exit + + // Stop write pump + cancelWrite() + + // Close write channel to signal writePump to exit cleanly + // Safe to close now because ping sender has exited and won't write anymore + close(c.writeCh) // Wait for write pump to finish writeErr := <-writeDone diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index 0542475..df7cd92 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -2,9 +2,14 @@ package client import ( "context" + "io" + "net/http/httptest" + "strings" "sync" "testing" "time" + + "golang.org/x/net/websocket" ) // TestStopMultipleCalls verifies that calling Stop() multiple times is safe @@ -188,3 +193,568 @@ func TestCommitPRCachePopulation(t *testing.T) { } }) } + +// mockWebSocketServer creates a test WebSocket server with configurable behavior. +type mockWebSocketServer struct { + server *httptest.Server + url string + onConnection func(*websocket.Conn) + acceptAuth bool + sendEvents []map[string]any + sendPings bool + closeDelay time.Duration + rejectWithCode int +} + +func newMockServer(t *testing.T, acceptAuth bool) *mockWebSocketServer { + t.Helper() + m := &mockWebSocketServer{ + acceptAuth: acceptAuth, + } + + handler := websocket.Handler(func(ws *websocket.Conn) { + if m.onConnection != nil { + m.onConnection(ws) + return + } + + // Default behavior: read subscription, confirm, send events, handle pings + var sub map[string]any + if err := websocket.JSON.Receive(ws, &sub); err != nil { + t.Logf("Failed to read subscription: %v", err) + return + } + + // Send subscription confirmation + confirmation := map[string]any{ + "type": "subscription_confirmed", + "organization": sub["organization"], + } + if err := websocket.JSON.Send(ws, confirmation); err != nil { + t.Logf("Failed to send confirmation: %v", err) + return + } + + // Send events if configured + for _, event := range m.sendEvents { + if err := websocket.JSON.Send(ws, event); err != nil { + t.Logf("Failed to send event: %v", err) + return + } + } + + // Handle pings/pongs + for { + var msg map[string]any + if err := websocket.JSON.Receive(ws, &msg); err != nil { + if err == io.EOF { + return + } + t.Logf("Read error: %v", err) + return + } + + if msgType, ok := msg["type"].(string); ok { + if msgType == "ping" { + pong := map[string]any{"type": "pong"} + if seq, ok := msg["seq"]; ok { + pong["seq"] = seq + } + if err := websocket.JSON.Send(ws, pong); err != nil { + return + } + } + } + } + }) + + m.server = httptest.NewServer(handler) + m.url = "ws" + strings.TrimPrefix(m.server.URL, "http") + return m +} + +func (m *mockWebSocketServer) Close() { + m.server.Close() +} + +// TestClientConnectAndReceiveEvents tests the full connection lifecycle. +func TestClientConnectAndReceiveEvents(t *testing.T) { + // Create mock server that sends test events + srv := newMockServer(t, true) + defer srv.Close() + + srv.sendEvents = []map[string]any{ + { + "type": "pull_request", + "url": "https://github.com/test/repo/pull/1", + "timestamp": time.Now().Format(time.RFC3339), + }, + { + "type": "check_run", + "url": "https://github.com/test/repo/pull/1", + "timestamp": time.Now().Format(time.RFC3339), + }, + } + + // Create client + var receivedEvents []Event + var mu sync.Mutex + + client, err := New(Config{ + ServerURL: srv.url, + Token: "test-token", + Organization: "test-org", + NoReconnect: true, + OnEvent: func(e Event) { + mu.Lock() + receivedEvents = append(receivedEvents, e) + mu.Unlock() + }, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + // Start client with timeout + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- client.Start(ctx) + }() + + // Wait a bit for events to be received + time.Sleep(500 * time.Millisecond) + client.Stop() + + // Check received events + mu.Lock() + eventCount := len(receivedEvents) + mu.Unlock() + + if eventCount != 2 { + t.Errorf("Expected 2 events, got %d", eventCount) + } +} + +// TestClientPingPong tests that pings are sent and pongs are received. +func TestClientPingPong(t *testing.T) { + pingReceived := make(chan bool, 10) + + srv := newMockServer(t, true) + defer srv.Close() + + // Custom connection handler that tracks pings + srv.onConnection = func(ws *websocket.Conn) { + // Read subscription + var sub map[string]any + if err := websocket.JSON.Receive(ws, &sub); err != nil { + return + } + + // Send confirmation + confirmation := map[string]any{"type": "subscription_confirmed"} + if err := websocket.JSON.Send(ws, confirmation); err != nil { + return + } + + // Listen for pings from client + for { + var msg map[string]any + if err := websocket.JSON.Receive(ws, &msg); err != nil { + return + } + + if msgType, ok := msg["type"].(string); ok && msgType == "ping" { + pingReceived <- true + + // Send pong response + pong := map[string]any{"type": "pong"} + if err := websocket.JSON.Send(ws, pong); err != nil { + return + } + } + } + } + + client, err := New(Config{ + ServerURL: srv.url, + Token: "test-token", + Organization: "test-org", + PingInterval: 100 * time.Millisecond, // Fast pings for testing + NoReconnect: true, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go func() { + _ = client.Start(ctx) //nolint:errcheck // Expected to timeout + }() + + // Wait for at least 2 pings + select { + case <-pingReceived: + // First ping received + case <-time.After(1 * time.Second): + t.Fatal("No ping received within 1 second") + } + + select { + case <-pingReceived: + // Second ping received - success! + case <-time.After(1 * time.Second): + t.Fatal("Second ping not received within 1 second") + } + + client.Stop() +} + +// TestClientReconnection tests that the client reconnects on disconnect. +func TestClientReconnection(t *testing.T) { + connectionCount := 0 + var mu sync.Mutex + + srv := newMockServer(t, true) + defer srv.Close() + + srv.onConnection = func(ws *websocket.Conn) { + mu.Lock() + connectionCount++ + count := connectionCount + mu.Unlock() + + // Read subscription + var sub map[string]any + if err := websocket.JSON.Receive(ws, &sub); err != nil { + return + } + + // Send confirmation + confirmation := map[string]any{"type": "subscription_confirmed"} + if err := websocket.JSON.Send(ws, confirmation); err != nil { + return + } + + // First connection: close immediately to trigger reconnection + if count == 1 { + ws.Close() + return + } + + // Second connection: stay alive + for { + var msg map[string]any + if err := websocket.JSON.Receive(ws, &msg); err != nil { + return + } + } + } + + client, err := New(Config{ + ServerURL: srv.url, + Token: "test-token", + Organization: "test-org", + MaxBackoff: 100 * time.Millisecond, // Fast reconnection for testing + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + go func() { + _ = client.Start(ctx) //nolint:errcheck // Expected to timeout + }() + + // Wait for reconnection + time.Sleep(1 * time.Second) + + mu.Lock() + count := connectionCount + mu.Unlock() + + if count < 2 { + t.Errorf("Expected at least 2 connections (reconnection), got %d", count) + } + + client.Stop() +} + +// TestClientAuthenticationError tests that auth errors don't trigger reconnection. +func TestClientAuthenticationError(t *testing.T) { + srv := newMockServer(t, false) + defer srv.Close() + + srv.onConnection = func(ws *websocket.Conn) { + // Read subscription + var sub map[string]any + if err := websocket.JSON.Receive(ws, &sub); err != nil { + return + } + + // Send auth error + errMsg := map[string]any{ + "type": "error", + "error": "access_denied", + "message": "Not authorized", + } + if err := websocket.JSON.Send(ws, errMsg); err != nil { + return + } + } + + client, err := New(Config{ + ServerURL: srv.url, + Token: "bad-token", + Organization: "test-org", + MaxRetries: 3, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = client.Start(ctx) + if err == nil { + t.Fatal("Expected authentication error, got nil") + } + + if !strings.Contains(err.Error(), "Authentication") && !strings.Contains(err.Error(), "authorization") { + t.Errorf("Expected authentication error, got: %v", err) + } +} + +// TestClientServerPings tests that the client responds to server pings. +func TestClientServerPings(t *testing.T) { + pongReceived := make(chan bool, 10) + + srv := newMockServer(t, true) + defer srv.Close() + + srv.onConnection = func(ws *websocket.Conn) { + // Read subscription + var sub map[string]any + if err := websocket.JSON.Receive(ws, &sub); err != nil { + return + } + + // Send confirmation + confirmation := map[string]any{"type": "subscription_confirmed"} + if err := websocket.JSON.Send(ws, confirmation); err != nil { + return + } + + // Send pings to client + go func() { + for i := 0; i < 3; i++ { + ping := map[string]any{"type": "ping", "seq": i} + if err := websocket.JSON.Send(ws, ping); err != nil { + return + } + time.Sleep(100 * time.Millisecond) + } + }() + + // Listen for pongs + for { + var msg map[string]any + if err := websocket.JSON.Receive(ws, &msg); err != nil { + return + } + + if msgType, ok := msg["type"].(string); ok && msgType == "pong" { + pongReceived <- true + } + } + } + + client, err := New(Config{ + ServerURL: srv.url, + Token: "test-token", + Organization: "test-org", + NoReconnect: true, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go func() { + _ = client.Start(ctx) //nolint:errcheck // Expected to timeout + }() + + // Wait for pongs + pongsReceived := 0 + timeout := time.After(1 * time.Second) + + for pongsReceived < 2 { + select { + case <-pongReceived: + pongsReceived++ + case <-timeout: + t.Fatalf("Only received %d pongs, expected at least 2", pongsReceived) + } + } + + client.Stop() +} + +// TestClientEventWithCommitSHA tests event handling with commit SHA. +func TestClientEventWithCommitSHA(t *testing.T) { + srv := newMockServer(t, true) + defer srv.Close() + + srv.sendEvents = []map[string]any{ + { + "type": "pull_request", + "url": "https://github.com/test/repo/pull/123", + "commit_sha": "abc123", + "timestamp": time.Now().Format(time.RFC3339), + }, + } + + var receivedEvent Event + eventReceived := make(chan bool, 1) + + client, err := New(Config{ + ServerURL: srv.url, + Token: "test-token", + Organization: "test-org", + NoReconnect: true, + OnEvent: func(e Event) { + receivedEvent = e + eventReceived <- true + }, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go func() { + _ = client.Start(ctx) //nolint:errcheck // Expected to timeout + }() + + // Wait for event + select { + case <-eventReceived: + // Success + case <-time.After(1 * time.Second): + t.Fatal("Event not received") + } + + if receivedEvent.CommitSHA != "abc123" { + t.Errorf("Expected commit SHA 'abc123', got %q", receivedEvent.CommitSHA) + } + if receivedEvent.Type != "pull_request" { + t.Errorf("Expected type 'pull_request', got %q", receivedEvent.Type) + } + + client.Stop() +} + +// TestClientWriteChannelBlocking tests that write channel doesn't block indefinitely. +func TestClientWriteChannelBlocking(t *testing.T) { + srv := newMockServer(t, true) + defer srv.Close() + + srv.onConnection = func(ws *websocket.Conn) { + // Read subscription + var sub map[string]any + if err := websocket.JSON.Receive(ws, &sub); err != nil { + return + } + + // Send confirmation + confirmation := map[string]any{"type": "subscription_confirmed"} + if err := websocket.JSON.Send(ws, confirmation); err != nil { + return + } + + // Don't read anything else - this will cause write buffer to potentially fill + time.Sleep(5 * time.Second) + } + + client, err := New(Config{ + ServerURL: srv.url, + Token: "test-token", + Organization: "test-org", + PingInterval: 10 * time.Millisecond, // Very fast pings to fill buffer + NoReconnect: true, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + err = client.Start(ctx) + // Should timeout gracefully, not deadlock + if err != context.DeadlineExceeded { + t.Logf("Expected deadline exceeded, got: %v", err) + } + + client.Stop() +} + +// TestClientCachePopulationFromPullRequestEvent tests the cache population logic. +func TestClientCachePopulationFromPullRequestEvent(t *testing.T) { + srv := newMockServer(t, true) + defer srv.Close() + + // Send a pull_request event with commit SHA + srv.sendEvents = []map[string]any{ + { + "type": "pull_request", + "url": "https://github.com/owner/repo/pull/456", + "commit_sha": "def789", + "timestamp": time.Now().Format(time.RFC3339), + }, + } + + client, err := New(Config{ + ServerURL: srv.url, + Token: "test-token", + Organization: "test-org", + NoReconnect: true, + }) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go func() { + _ = client.Start(ctx) //nolint:errcheck // Expected to timeout + }() + + // Wait for event processing + time.Sleep(500 * time.Millisecond) + client.Stop() + + // Check that cache was populated + client.cacheMu.RLock() + cached, exists := client.commitPRCache["owner/repo:def789"] + client.cacheMu.RUnlock() + + if !exists { + t.Error("Expected cache to be populated from pull_request event") + } + if len(cached) != 1 || cached[0] != 456 { + t.Errorf("Expected cached PR [456], got %v", cached) + } +} diff --git a/pkg/security/race_test.go b/pkg/security/race_test.go index 7ba7c0e..80bbece 100644 --- a/pkg/security/race_test.go +++ b/pkg/security/race_test.go @@ -213,33 +213,6 @@ func TestConnectionLimiterConcurrentAccess(t *testing.T) { } } -// TestRateLimiterConcurrentAccess verifies rate limiter thread safety. -func TestRateLimiterConcurrentAccess(t *testing.T) { - limiter := NewRateLimiter(100) - defer limiter.Stop() - - var wg sync.WaitGroup - - // Test concurrent Allow from multiple IPs - for i := range 10 { - wg.Add(1) - go func(id int) { - defer wg.Done() - ip := "192.168.1." + string(rune('1'+id)) - - // Rapid allow checks - for range 100 { - _ = limiter.Allow(ip) - } - }(i) - } - - wg.Wait() - - // No assertions needed - if we didn't panic, we're good - t.Log("✓ Rate limiter handled concurrent access without panics") -} - // TestConnectionLimiterReservationCancellation tests the cancellation path. func TestConnectionLimiterReservationCancellation(t *testing.T) { limiter := NewConnectionLimiter(10, 1000) diff --git a/pkg/security/ratelimiter.go b/pkg/security/ratelimiter.go deleted file mode 100644 index da646f7..0000000 --- a/pkg/security/ratelimiter.go +++ /dev/null @@ -1,127 +0,0 @@ -package security - -import ( - "sync" - "time" -) - -const ( - maxBuckets = 10000 // Limit to 10k unique IPs to prevent memory exhaustion -) - -// RateLimiter implements a simple token bucket rate limiter. -type RateLimiter struct { - buckets map[string]*bucket - stopCh chan struct{} - cleanupWG sync.WaitGroup - maxTokens int - maxBuckets int - mu sync.Mutex -} - -type bucket struct { - resetTime time.Time - count int -} - -// NewRateLimiter creates a new rate limiter. -func NewRateLimiter(maxTokens int) *RateLimiter { - rl := &RateLimiter{ - buckets: make(map[string]*bucket), - maxTokens: maxTokens, - maxBuckets: maxBuckets, - stopCh: make(chan struct{}), - } - - // Start cleanup goroutine - rl.cleanupWG.Add(1) - go rl.cleanupRoutine() - - return rl -} - -// Allow checks if a request from the given IP is allowed. -func (rl *RateLimiter) Allow(ip string) bool { - rl.mu.Lock() - defer rl.mu.Unlock() - - now := time.Now() - b, exists := rl.buckets[ip] - - // Create new bucket or reset if expired - if !exists || now.After(b.resetTime) { - // Check if we've reached the max buckets limit - if !exists && len(rl.buckets) >= rl.maxBuckets { - // Find and remove the oldest bucket to make room - rl.evictOldest() - } - - rl.buckets[ip] = &bucket{ - count: 1, - resetTime: now.Add(time.Minute), - } - return true - } - - // Check limit - if b.count >= rl.maxTokens { - return false - } - - b.count++ - return true -} - -// cleanupRoutine periodically removes expired buckets to prevent memory leaks. -func (rl *RateLimiter) cleanupRoutine() { - defer rl.cleanupWG.Done() - - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - rl.cleanup() - case <-rl.stopCh: - return - } - } -} - -// cleanup removes expired buckets. -func (rl *RateLimiter) cleanup() { - rl.mu.Lock() - defer rl.mu.Unlock() - - now := time.Now() - for ip, b := range rl.buckets { - if now.After(b.resetTime) { - delete(rl.buckets, ip) - } - } -} - -// evictOldest removes the oldest bucket (called with lock held). -func (rl *RateLimiter) evictOldest() { - var oldestIP string - var oldestTime time.Time - - // Find the oldest bucket - for ip, b := range rl.buckets { - if oldestIP == "" || b.resetTime.Before(oldestTime) { - oldestIP = ip - oldestTime = b.resetTime - } - } - - if oldestIP != "" { - delete(rl.buckets, oldestIP) - } -} - -// Stop gracefully stops the rate limiter. -func (rl *RateLimiter) Stop() { - close(rl.stopCh) - rl.cleanupWG.Wait() -} diff --git a/pkg/security/security_test.go b/pkg/security/security_test.go index 37e609b..c0d3f45 100644 --- a/pkg/security/security_test.go +++ b/pkg/security/security_test.go @@ -6,27 +6,6 @@ import ( "testing" ) -func TestRateLimiter(t *testing.T) { - rl := NewRateLimiter(5) - - ip := "192.168.1.1" - - // Should allow first 5 requests - for i := range 5 { - if !rl.Allow(ip) { - t.Errorf("request %d should be allowed", i+1) - } - } - - // 6th request should be denied - if rl.Allow(ip) { - t.Error("6th request should be denied") - } - - // Note: Our simplified rate limiter resets every minute, - // not every second, so we can't test the reset behavior easily -} - func TestConnectionLimiter(t *testing.T) { cl := NewConnectionLimiter(2, 5) diff --git a/pkg/srv/websocket_test.go b/pkg/srv/websocket_test.go new file mode 100644 index 0000000..db9fd9b --- /dev/null +++ b/pkg/srv/websocket_test.go @@ -0,0 +1,355 @@ +package srv + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" + + "github.com/codeGROOVE-dev/sprinkler/pkg/security" +) + +// TestValidateTokenFormat tests token format validation. +func TestValidateTokenFormat(t *testing.T) { + tests := []struct { + name string + token string + want bool + }{ + { + name: "valid ghp token", + token: "ghp_" + strings.Repeat("a", 36), + want: true, + }, + { + name: "valid gho token", + token: "gho_" + strings.Repeat("b", 36), + want: true, + }, + { + name: "valid ghs token", + token: "ghs_" + strings.Repeat("c", 36), + want: true, + }, + { + name: "valid 40-char classic token", + token: strings.Repeat("d", 40), + want: true, + }, + { + name: "valid fine-grained PAT", + token: "github_pat_" + strings.Repeat("e", 36), + want: true, + }, + { + name: "too short", + token: "ghp_short", + want: false, + }, + { + name: "empty token", + token: "", + want: false, + }, + { + name: "invalid characters", + token: "ghp_" + strings.Repeat("!", 36), + want: false, + }, + { + name: "wrong prefix length", + token: "ghp_" + strings.Repeat("a", 35), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := validateTokenFormat(tt.token) + if got != tt.want { + t.Errorf("validateTokenFormat(%q) = %v, want %v", tt.token, got, tt.want) + } + }) + } +} + +// TestPreValidateAuth tests the PreValidateAuth method. +func TestPreValidateAuth(t *testing.T) { + ctx := context.Background() + hub := NewHub() + go hub.Run(ctx) + defer hub.Stop() + + connLimiter := security.NewConnectionLimiter(10, 50) + defer connLimiter.Stop() + + handler := NewWebSocketHandler(hub, connLimiter, nil) + + tests := []struct { + name string + authHeader string + want bool + }{ + { + name: "valid token", + authHeader: "Bearer ghp_" + strings.Repeat("a", 36), + want: true, + }, + { + name: "missing authorization header", + authHeader: "", + want: false, + }, + { + name: "missing bearer prefix", + authHeader: "ghp_" + strings.Repeat("a", 36), + want: false, + }, + { + name: "invalid token format", + authHeader: "Bearer invalid", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + + got := handler.PreValidateAuth(req) + if got != tt.want { + t.Errorf("PreValidateAuth() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestPreValidateAuthTestMode tests that test mode skips validation. +func TestPreValidateAuthTestMode(t *testing.T) { + ctx := context.Background() + hub := NewHub() + go hub.Run(ctx) + defer hub.Stop() + + connLimiter := security.NewConnectionLimiter(10, 50) + defer connLimiter.Stop() + + handler := NewWebSocketHandlerForTest(hub, connLimiter, nil) + + // Even with no auth header, test mode should return true + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + got := handler.PreValidateAuth(req) + if !got { + t.Error("PreValidateAuth() in test mode should return true") + } +} + +// TestWebSocketHandlerWithMockConnection tests the full WebSocket handler lifecycle. +func TestWebSocketHandlerWithMockConnection(t *testing.T) { + ctx := context.Background() + hub := NewHub() + go hub.Run(ctx) + defer hub.Stop() + + connLimiter := security.NewConnectionLimiter(10, 50) + defer connLimiter.Stop() + + // Use test mode to skip GitHub auth + handler := NewWebSocketHandlerForTest(hub, connLimiter, []string{"pull_request", "check_run"}) + + // Create test server + server := httptest.NewServer(websocket.Handler(handler.Handle)) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Connect client + ws, err := websocket.Dial(wsURL, "", "http://localhost/") + if err != nil { + t.Fatalf("Failed to dial WebSocket: %v", err) + } + defer ws.Close() + + // Send subscription request + sub := map[string]any{ + "organization": "test-org", + "event_types": []string{"pull_request"}, + } + + if err := websocket.JSON.Send(ws, sub); err != nil { + t.Fatalf("Failed to send subscription: %v", err) + } + + // Read subscription confirmation + var response map[string]any + if err := websocket.JSON.Receive(ws, &response); err != nil { + t.Fatalf("Failed to receive confirmation: %v", err) + } + + responseType, ok := response["type"].(string) + if !ok || responseType != "subscription_confirmed" { + t.Errorf("Expected subscription_confirmed, got %v", response) + } +} + +// TestWebSocketHandlerEventFiltering tests that only allowed events are accepted. +func TestWebSocketHandlerEventFiltering(t *testing.T) { + ctx := context.Background() + hub := NewHub() + go hub.Run(ctx) + defer hub.Stop() + + connLimiter := security.NewConnectionLimiter(10, 50) + defer connLimiter.Stop() + + // Only allow pull_request events + handler := NewWebSocketHandlerForTest(hub, connLimiter, []string{"pull_request"}) + + // Verify the allowedEventsMap was built correctly + if !handler.allowedEventsMap["pull_request"] { + t.Error("Expected pull_request to be in allowedEventsMap") + } + if handler.allowedEventsMap["check_run"] { + t.Error("Expected check_run to NOT be in allowedEventsMap") + } +} + +// TestWSCloser tests the wsCloser to prevent double-close. +func TestWSCloser(t *testing.T) { + // Create a mock WebSocket connection + server := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + // Keep connection open + time.Sleep(100 * time.Millisecond) + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + ws, err := websocket.Dial(wsURL, "", "http://localhost/") + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + + wc := &wsCloser{ws: ws} + + // Close once + err1 := wc.Close() + if err1 != nil && !strings.Contains(err1.Error(), "use of closed") { + t.Errorf("First close error: %v", err1) + } + + // Verify closed status + if !wc.IsClosed() { + t.Error("Expected IsClosed() to return true after Close()") + } + + // Close again - should be safe (no panic) + err2 := wc.Close() + if err2 != nil && !strings.Contains(err2.Error(), "use of closed") { + t.Errorf("Second close error: %v", err2) + } + + // Multiple concurrent closes should be safe + for i := 0; i < 10; i++ { + go func() { + _ = wc.Close() // Should not panic + }() + } + + time.Sleep(10 * time.Millisecond) +} + +// TestExtractGitHubTokenTestMode tests token extraction in test mode. +func TestExtractGitHubTokenTestMode(t *testing.T) { + ctx := context.Background() + hub := NewHub() + go hub.Run(ctx) + defer hub.Stop() + + connLimiter := security.NewConnectionLimiter(10, 50) + defer connLimiter.Stop() + + handler := NewWebSocketHandlerForTest(hub, connLimiter, nil) + + // Create test server + server := httptest.NewServer(websocket.Handler(handler.Handle)) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Connect without any auth header (test mode should allow) + ws, err := websocket.Dial(wsURL, "", "http://localhost/") + if err != nil { + t.Fatalf("Failed to dial in test mode: %v", err) + } + defer ws.Close() + + // Send subscription - should work in test mode + sub := map[string]any{ + "organization": "test-org", + } + + if err := websocket.JSON.Send(ws, sub); err != nil { + t.Fatalf("Failed to send subscription in test mode: %v", err) + } + + // Should get confirmation + var response map[string]any + if err := websocket.JSON.Receive(ws, &response); err != nil { + t.Fatalf("Failed to receive confirmation in test mode: %v", err) + } + + if response["type"] != "subscription_confirmed" { + t.Errorf("Expected confirmation in test mode, got %v", response) + } +} + +// TestNewWebSocketHandler tests handler creation with and without allowed events. +func TestNewWebSocketHandler(t *testing.T) { + ctx := context.Background() + hub := NewHub() + go hub.Run(ctx) + defer hub.Stop() + + connLimiter := security.NewConnectionLimiter(10, 50) + defer connLimiter.Stop() + + t.Run("with allowed events", func(t *testing.T) { + handler := NewWebSocketHandler(hub, connLimiter, []string{"pull_request", "check_run"}) + if handler == nil { + t.Fatal("Expected non-nil handler") + } + if len(handler.allowedEvents) != 2 { + t.Errorf("Expected 2 allowed events, got %d", len(handler.allowedEvents)) + } + if len(handler.allowedEventsMap) != 2 { + t.Errorf("Expected 2 entries in allowedEventsMap, got %d", len(handler.allowedEventsMap)) + } + }) + + t.Run("without allowed events", func(t *testing.T) { + handler := NewWebSocketHandler(hub, connLimiter, nil) + if handler == nil { + t.Fatal("Expected non-nil handler") + } + if handler.allowedEventsMap != nil { + t.Error("Expected nil allowedEventsMap when no events specified") + } + }) + + t.Run("test mode", func(t *testing.T) { + handler := NewWebSocketHandlerForTest(hub, connLimiter, []string{"pull_request"}) + if handler == nil { + t.Fatal("Expected non-nil handler") + } + if !handler.testMode { + t.Error("Expected testMode to be true") + } + }) +} diff --git a/pkg/webhook/handler.go b/pkg/webhook/handler.go index 56bda86..35d3232 100644 --- a/pkg/webhook/handler.go +++ b/pkg/webhook/handler.go @@ -229,8 +229,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { DeliveryID: deliveryID, } - // For check events, include commit SHA to allow PR lookup when URL is repo-only (race condition) - if eventType == "check_run" || eventType == "check_suite" { + // Extract commit SHA for cache population and PR lookup + // - pull_request: allows client-side cache population for commit->PR mapping + // - check events: enables PR lookup when URL is repo-only (GitHub race condition) + if eventType == "check_run" || eventType == "check_suite" || eventType == "pull_request" { event.CommitSHA = extractCommitSHA(eventType, payload) } @@ -406,7 +408,7 @@ func getMapKeys(m map[string]any) []string { return keys } -// extractCommitSHA extracts the commit SHA from check_run or check_suite events. +// extractCommitSHA extracts the commit SHA from pull_request, check_run, or check_suite events. func extractCommitSHA(eventType string, payload map[string]any) string { switch eventType { case "check_run": @@ -421,8 +423,16 @@ func extractCommitSHA(eventType string, payload map[string]any) string { return headSHA } } + case "pull_request": + if pr, ok := payload["pull_request"].(map[string]any); ok { + if head, ok := pr["head"].(map[string]any); ok { + if sha, ok := head["sha"].(string); ok { + return sha + } + } + } default: - // Not a check event, no SHA to extract + // Not a supported event type for SHA extraction } return "" } diff --git a/pkg/webhook/handler_test.go b/pkg/webhook/handler_test.go index b2228b3..1cc6e73 100644 --- a/pkg/webhook/handler_test.go +++ b/pkg/webhook/handler_test.go @@ -111,3 +111,307 @@ func TestWebhookHandler(t *testing.T) { t.Errorf("expected status %d for check_suite, got %d", http.StatusOK, w.Code) } } + +// TestWebhookHandlerEventFiltering tests event type filtering. +func TestWebhookHandlerEventFiltering(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := srv.NewHub() + go h.Run(ctx) + defer h.Stop() + + secret := "testsecret" + // Only allow pull_request events + handler := NewHandler(h, secret, []string{"pull_request"}) + + // Test allowed event + payload := map[string]any{ + "action": "opened", + "pull_request": map[string]any{ + "html_url": "https://gitsrv.com/user/repo/pull/1", + }, + } + + body, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-GitHub-Event", "pull_request") //nolint:canonicalheader // GitHub webhook header + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(body) + signature := "sha256=" + hex.EncodeToString(mac.Sum(nil)) + req.Header.Set("X-Hub-Signature-256", signature) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("allowed event: expected status %d, got %d", http.StatusOK, w.Code) + } + + // Test disallowed event (check_run) + req = httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-GitHub-Event", "check_run") //nolint:canonicalheader // GitHub webhook header + req.Header.Set("X-Hub-Signature-256", signature) + + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("disallowed event: expected status %d (silent accept), got %d", http.StatusOK, w.Code) + } +} + +// TestWebhookHandlerPayloadTooLarge tests max payload size enforcement. +func TestWebhookHandlerPayloadTooLarge(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := srv.NewHub() + go h.Run(ctx) + defer h.Stop() + + secret := "testsecret" + handler := NewHandler(h, secret, nil) + + // Create payload larger than maxPayloadSize (1MB) + largePayload := make([]byte, maxPayloadSize+1) + for i := range largePayload { + largePayload[i] = 'a' + } + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(largePayload)) + req.Header.Set("X-GitHub-Event", "pull_request") //nolint:canonicalheader // GitHub webhook header + req.ContentLength = int64(len(largePayload)) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("expected status %d, got %d", http.StatusRequestEntityTooLarge, w.Code) + } +} + +// TestWebhookHandlerMissingSignature tests missing signature handling. +func TestWebhookHandlerMissingSignature(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := srv.NewHub() + go h.Run(ctx) + defer h.Stop() + + secret := "testsecret" + handler := NewHandler(h, secret, nil) + + payload := map[string]any{"action": "opened"} + body, _ := json.Marshal(payload) + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-GitHub-Event", "pull_request") //nolint:canonicalheader // GitHub webhook header + // No signature header + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("expected status %d, got %d", http.StatusUnauthorized, w.Code) + } +} + +// TestWebhookHandlerInvalidJSON tests invalid JSON payload handling. +func TestWebhookHandlerInvalidJSON(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := srv.NewHub() + go h.Run(ctx) + defer h.Stop() + + secret := "testsecret" + handler := NewHandler(h, secret, nil) + + invalidJSON := []byte("{invalid json") + + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(invalidJSON)) + req.Header.Set("X-GitHub-Event", "pull_request") //nolint:canonicalheader // GitHub webhook header + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(invalidJSON) + signature := "sha256=" + hex.EncodeToString(mac.Sum(nil)) + req.Header.Set("X-Hub-Signature-256", signature) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } +} + +// TestWebhookHandlerCheckRunWithCommit tests check_run event with commit SHA. +func TestWebhookHandlerCheckRunWithCommit(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := srv.NewHub() + go h.Run(ctx) + defer h.Stop() + + secret := "testsecret" + handler := NewHandler(h, secret, nil) + + // check_run with head_sha + payload := map[string]any{ + "action": "completed", + "check_run": map[string]any{ + "head_sha": "abc123def456", + "pull_requests": []any{ + map[string]any{ + "number": float64(42), + }, + }, + }, + "repository": map[string]any{ + "html_url": "https://github.com/owner/repo", + }, + } + + body, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(body)) + req.Header.Set("X-GitHub-Event", "check_run") //nolint:canonicalheader // GitHub webhook header + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(body) + signature := "sha256=" + hex.EncodeToString(mac.Sum(nil)) + req.Header.Set("X-Hub-Signature-256", signature) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +// TestExtractCommitSHA tests commit SHA extraction. +func TestExtractCommitSHA(t *testing.T) { + tests := []struct { + name string + eventType string + payload map[string]any + expected string + }{ + { + name: "check_run with head_sha", + eventType: "check_run", + payload: map[string]any{ + "check_run": map[string]any{ + "head_sha": "abc123", + }, + }, + expected: "abc123", + }, + { + name: "check_suite with head_sha", + eventType: "check_suite", + payload: map[string]any{ + "check_suite": map[string]any{ + "head_sha": "def456", + }, + }, + expected: "def456", + }, + { + name: "no SHA", + eventType: "check_run", + payload: map[string]any{}, + expected: "", + }, + { + name: "check_run with invalid type", + eventType: "check_run", + payload: map[string]any{ + "check_run": map[string]any{ + "head_sha": 12345, // not a string + }, + }, + expected: "", + }, + { + name: "wrong event type", + eventType: "issues", + payload: map[string]any{ + "check_run": map[string]any{ + "head_sha": "shouldnotextract", + }, + }, + expected: "", + }, + { + name: "pull_request with head.sha", + eventType: "pull_request", + payload: map[string]any{ + "pull_request": map[string]any{ + "head": map[string]any{ + "sha": "pr_commit_123", + }, + }, + }, + expected: "pr_commit_123", + }, + { + name: "pull_request without head", + eventType: "pull_request", + payload: map[string]any{ + "pull_request": map[string]any{ + "number": 42, + }, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractCommitSHA(tt.eventType, tt.payload) + if result != tt.expected { + t.Errorf("extractCommitSHA() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestGetMapKeys tests the getMapKeys utility function. +func TestGetMapKeys(t *testing.T) { + tests := []struct { + name string + input map[string]any + expected int // Just check length since order is undefined + }{ + { + name: "empty map", + input: map[string]any{}, + expected: 0, + }, + { + name: "single key", + input: map[string]any{ + "key1": "value1", + }, + expected: 1, + }, + { + name: "multiple keys", + input: map[string]any{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + }, + expected: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getMapKeys(tt.input) + if len(result) != tt.expected { + t.Errorf("getMapKeys() returned %d keys, want %d", len(result), tt.expected) + } + }) + } +}