Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,10 @@ const (
minMaskHeaderLength = 20 // Minimum header length before we show full "[REDACTED]"
)

// getEnvOrDefault returns the value of the environment variable or the default if not set.
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// contextKey is a custom type for context keys to avoid collisions.
type contextKey string

const reservationTokenKey contextKey = "reservation_token"

var (
webhookSecret = flag.String("webhook-secret", os.Getenv("GITHUB_WEBHOOK_SECRET"), "GitHub webhook secret for signature verification")
Expand All @@ -50,8 +47,12 @@ var (
maxConnsPerIP = flag.Int("max-conns-per-ip", 10, "Maximum WebSocket connections per IP")
maxConnsTotal = flag.Int("max-conns-total", 1000, "Maximum total WebSocket connections")
rateLimit = flag.Int("rate-limit", 100, "Maximum requests per minute per IP")
allowedEvents = flag.String("allowed-events", getEnvOrDefault("ALLOWED_WEBHOOK_EVENTS", "*"),
"Comma-separated list of allowed webhook event types (use '*' for all, default: '*')")
allowedEvents = flag.String("allowed-events", func() string {
if value := os.Getenv("ALLOWED_WEBHOOK_EVENTS"); value != "" {
return value
}
return "*"
}(), "Comma-separated list of allowed webhook event types (use '*' for all, default: '*')")
debugHeaders = flag.Bool("debug-headers", false, "Log request headers for debugging (security warning: may log sensitive data)")
)

Expand Down Expand Up @@ -244,7 +245,7 @@ func main() {
}

// Set reservation token in request context so websocket handler can commit it
r = r.WithContext(context.WithValue(r.Context(), "reservation_token", reservationToken))
r = r.WithContext(context.WithValue(r.Context(), reservationTokenKey, reservationToken))

// Log successful auth and proceed to upgrade
log.Printf("WebSocket UPGRADE: ip=%s duration=%v", ip, time.Since(startTime))
Expand Down
8 changes: 0 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
github.com/codeGROOVE-dev/retry v1.2.0 h1:xYpYPX2PQZmdHwuiQAGGzsBm392xIMl4nfMEFApQnu8=
github.com/codeGROOVE-dev/retry v1.2.0/go.mod h1:8OgefgV1XP7lzX2PdKlCXILsYKuz6b4ZpHa/20iLi8E=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
153 changes: 139 additions & 14 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/codeGROOVE-dev/retry"
"github.com/codeGROOVE-dev/sprinkler/pkg/github"
"golang.org/x/net/websocket"
)

Expand Down Expand Up @@ -45,8 +46,9 @@ type Event struct {
Timestamp time.Time `json:"timestamp"`
Raw map[string]any
Type string `json:"type"`
URL string `json:"url"`
URL string `json:"url"` // PR URL (or repo URL for check events with race condition)
DeliveryID string `json:"delivery_id,omitempty"`
CommitSHA string `json:"commit_sha,omitempty"` // Commit SHA for check events
}

// Config holds the configuration for the client.
Expand Down Expand Up @@ -85,9 +87,16 @@ type Client struct {
ws *websocket.Conn
stopCh chan struct{}
stoppedCh chan struct{}
writeCh chan any // Channel for serializing all writes
stopOnce sync.Once // Ensures Stop() is only executed once
writeCh chan any // Channel for serializing all writes
eventCount int
retries int

// Cache for commit SHA to PR number lookups (for check event race condition)
commitPRCache map[string][]int // key: "owner/repo:sha", value: PR numbers
commitCacheKeys []string // track insertion order for LRU eviction
cacheMu sync.RWMutex
maxCacheSize int
}

// New creates a new robust WebSocket client.
Expand Down Expand Up @@ -118,10 +127,13 @@ func New(config Config) (*Client, error) {
}

return &Client{
config: config,
stopCh: make(chan struct{}),
stoppedCh: make(chan struct{}),
logger: logger,
config: config,
stopCh: make(chan struct{}),
stoppedCh: make(chan struct{}),
logger: logger,
commitPRCache: make(map[string][]int),
commitCacheKeys: make([]string, 0, 512),
maxCacheSize: 512,
}, nil
}

Expand Down Expand Up @@ -220,16 +232,27 @@ func (c *Client) Start(ctx context.Context) error {
}

// Stop gracefully stops the client.
// Safe to call multiple times - only the first call will take effect.
// Also safe to call before Start() or if Start() was never called.
func (c *Client) Stop() {
close(c.stopCh)
c.mu.Lock()
if c.ws != nil {
if closeErr := c.ws.Close(); closeErr != nil {
c.logger.Error("Error closing websocket on shutdown", "error", closeErr)
c.stopOnce.Do(func() {
close(c.stopCh)
c.mu.Lock()
if c.ws != nil {
if closeErr := c.ws.Close(); closeErr != nil {
c.logger.Error("Error closing websocket on shutdown", "error", closeErr)
}
}
}
c.mu.Unlock()
<-c.stoppedCh
c.mu.Unlock()

// Wait for Start() to finish, but with timeout in case Start() was never called
select {
case <-c.stoppedCh:
// Start() completed normally
case <-time.After(100 * time.Millisecond):
// Start() was never called or hasn't started yet - that's ok
}
})
}

// connect establishes a WebSocket connection and handles events.
Expand Down Expand Up @@ -608,18 +631,119 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
event.DeliveryID = deliveryID
}

if commitSHA, ok := response["commit_sha"].(string); ok {
event.CommitSHA = commitSHA
}

c.mu.Lock()
c.eventCount++
eventNum := c.eventCount
c.mu.Unlock()

// Handle check events with repo-only URLs (GitHub race condition)
// Automatically expand into per-PR events using GitHub API with caching
if (event.Type == "check_run" || event.Type == "check_suite") && event.CommitSHA != "" && !strings.Contains(event.URL, "/pull/") {
// Extract owner/repo from URL
parts := strings.Split(event.URL, "/")
if len(parts) >= 5 && parts[2] == "github.com" {
owner := parts[3]
repo := parts[4]
key := owner + "/" + repo + ":" + event.CommitSHA

// Check cache first
c.cacheMu.RLock()
cached, ok := c.commitPRCache[key]
c.cacheMu.RUnlock()

var prs []int
if ok {
// Cache hit - return copy to prevent external modifications
prs = make([]int, len(cached))
copy(prs, cached)
c.logger.Info("Check event with repo URL - using cached PR lookup",
"commit_sha", event.CommitSHA,
"repo_url", event.URL,
"type", event.Type,
"pr_count", len(prs),
"cache_hit", true)
} else {
// Cache miss - look up via GitHub API
c.logger.Info("Check event with repo URL - looking up PRs via GitHub API",
"commit_sha", event.CommitSHA,
"repo_url", event.URL,
"type", event.Type,
"cache_hit", false)

gh := github.NewClient(c.config.Token)
var err error
prs, err = gh.FindPRsForCommit(ctx, owner, repo, event.CommitSHA)
if err != nil {
c.logger.Warn("Failed to look up PRs for commit",
"commit_sha", event.CommitSHA,
"owner", owner,
"repo", repo,
"error", err)
// Don't cache errors - try again next time
} else {
// Cache the result (even if empty)
c.cacheMu.Lock()
if _, exists := c.commitPRCache[key]; !exists {
c.commitCacheKeys = append(c.commitCacheKeys, key)
// Evict oldest 25% if cache is full
if len(c.commitCacheKeys) > c.maxCacheSize {
n := c.maxCacheSize / 4
for i := range n {
delete(c.commitPRCache, c.commitCacheKeys[i])
}
c.commitCacheKeys = c.commitCacheKeys[n:]
}
}
// Store copy to prevent external modifications
cached := make([]int, len(prs))
copy(cached, prs)
c.commitPRCache[key] = cached
c.cacheMu.Unlock()

c.logger.Info("Cached PR lookup result",
"commit_sha", event.CommitSHA,
"pr_count", len(prs))
}
}

// Emit events for each PR found
if len(prs) > 0 {
for _, n := range prs {
e := event // Copy the event
e.URL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", owner, repo, n)

if c.config.OnEvent != nil {
c.logger.Info("Event received (expanded from commit)",
"timestamp", e.Timestamp.Format("15:04:05"),
"event_number", eventNum,
"type", e.Type,
"url", e.URL,
"commit_sha", e.CommitSHA,
"delivery_id", e.DeliveryID)
c.config.OnEvent(e)
}
}
continue // Skip the normal event handling since we expanded it
}
c.logger.Info("No PRs found for commit - may be push to main",
"commit_sha", event.CommitSHA,
"owner", owner,
"repo", repo)
}
}

// Log event
if c.config.Verbose {
c.logger.Info("Event received",
"event_number", eventNum,
"timestamp", event.Timestamp.Format("15:04:05"),
"type", event.Type,
"url", event.URL,
"commit_sha", event.CommitSHA,
"delivery_id", event.DeliveryID,
"raw", event.Raw)
} else {
Expand All @@ -629,6 +753,7 @@ func (c *Client) readEvents(ctx context.Context, ws *websocket.Conn) error {
"event_number", eventNum,
"type", event.Type,
"url", event.URL,
"commit_sha", event.CommitSHA,
"delivery_id", event.DeliveryID)
} else {
c.logger.Info("Event received",
Expand Down
79 changes: 79 additions & 0 deletions pkg/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package client

import (
"context"
"sync"
"testing"
"time"
)

// TestStopMultipleCalls verifies that calling Stop() multiple times is safe
// and doesn't panic with "close of closed channel".
func TestStopMultipleCalls(t *testing.T) {
// Create a client with minimal config
client, err := New(Config{
ServerURL: "ws://localhost:8080",
Token: "test-token",
Organization: "test-org",
NoReconnect: true, // Disable reconnect to make test faster
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}

// Start the client in a goroutine
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go func() {
// Expected to fail to connect, but that's ok for this test
if err := client.Start(ctx); err != nil {
// Error is expected in tests - client can't connect to non-existent server
}
}()

// Give it a moment to initialize
time.Sleep(10 * time.Millisecond)

// Call Stop() multiple times concurrently
// This should NOT panic with "close of closed channel"
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
client.Stop() // Should be safe to call multiple times
}()
}

// Wait for all Stop() calls to complete
wg.Wait()

// If we get here without a panic, the test passes
}

// TestStopBeforeStart verifies that calling Stop() before Start() is safe.
func TestStopBeforeStart(t *testing.T) {
client, err := New(Config{
ServerURL: "ws://localhost:8080",
Token: "test-token",
Organization: "test-org",
NoReconnect: true,
})
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}

// Call Stop() before Start()
client.Stop()

// Now try to start - should exit cleanly
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

err = client.Start(ctx)
// We expect either context.DeadlineExceeded or "stop requested"
if err == nil {
t.Error("Expected Start() to fail after Stop(), but it succeeded")
}
}
Loading