Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 1 addition & 26 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ var (
leEmail = flag.String("le-email", "", "Contact email for Let's Encrypt notifications")
maxConnsPerIP = flag.Int("max-conns-per-ip", 10, "Maximum WebSocket connections per IP")
maxConnsTotal = flag.Int("max-conns-total", 1000, "Maximum total WebSocket connections")
rateLimit = flag.Int("rate-limit", 100, "Maximum requests per minute per IP")
allowedEvents = flag.String("allowed-events", func() string {
if value := os.Getenv("ALLOWED_WEBHOOK_EVENTS"); value != "" {
return value
Expand Down Expand Up @@ -92,8 +91,7 @@ func main() {
hub := srv.NewHub()
go hub.Run(ctx)

// Create security components
rateLimiter := security.NewRateLimiter(*rateLimit)
// Create connection limiter for WebSocket connections
connLimiter := security.NewConnectionLimiter(*maxConnsPerIP, *maxConnsTotal)

mux := http.NewServeMux()
Expand Down Expand Up @@ -133,16 +131,6 @@ func main() {
return
}

// Rate limiting
if !rateLimiter.Allow(ip) {
log.Printf("Webhook 429: rate limit exceeded ip=%s", ip)
w.WriteHeader(http.StatusTooManyRequests)
if _, err := w.Write([]byte("429 Too Many Requests: Rate limit exceeded\n")); err != nil {
log.Printf("failed to write 429 response: %v", err)
}
return
}

webhookHandler.ServeHTTP(w, r)
log.Printf("Webhook complete: ip=%s duration=%v", ip, time.Since(startTime))
})
Expand Down Expand Up @@ -185,16 +173,6 @@ func main() {
return
}

// Rate limiting check
if !rateLimiter.Allow(ip) {
log.Printf("WebSocket 429: rate limit exceeded ip=%s", ip)
w.WriteHeader(http.StatusTooManyRequests)
if _, err := w.Write([]byte("429 Too Many Requests: Rate limit exceeded\n")); err != nil {
log.Printf("failed to write 429 response: %v", err)
}
return
}

// Pre-validate authentication before WebSocket upgrade
authHeader := r.Header.Get("Authorization")
if !wsHandler.PreValidateAuth(r) {
Expand Down Expand Up @@ -287,9 +265,6 @@ func main() {
// Stop accepting new connections
hub.Stop()

// Stop the rate limiter cleanup routine
rateLimiter.Stop()

// Stop the connection limiter cleanup routine
connLimiter.Stop()

Expand Down
17 changes: 14 additions & 3 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,14 +468,25 @@ func (c *Client) connect(ctx context.Context) error {
// Start ping sender (sends to write channel, not directly to websocket)
pingCtx, cancelPing := context.WithCancel(ctx)
defer cancelPing()
go c.sendPings(pingCtx)
pingDone := make(chan struct{})
go func() {
c.sendPings(pingCtx)
close(pingDone)
}()

// Read events - when this returns, cancel everything
readErr := c.readEvents(ctx, ws)

// Stop write pump and ping sender
cancelWrite()
// Stop ping sender first - this ensures no more writes will be queued
cancelPing()
<-pingDone // Wait for ping sender to fully exit

// Stop write pump
cancelWrite()

// Close write channel to signal writePump to exit cleanly
// Safe to close now because ping sender has exited and won't write anymore
close(c.writeCh)

// Wait for write pump to finish
writeErr := <-writeDone
Expand Down
Loading