Skip to content

Commit

Permalink
Implement rate counter logs aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
julius-welink committed Jan 7, 2022
1 parent a58fd3e commit 5e5eb63
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 8 deletions.
19 changes: 17 additions & 2 deletions server/rate_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
type rateCounter struct {
limit int64
count int64
blocked uint64
end time.Time
interval time.Duration
mu sync.Mutex
Expand All @@ -37,14 +38,28 @@ func (r *rateCounter) allow() bool {
now := time.Now()

r.mu.Lock()
defer r.mu.Unlock()

if now.After(r.end) {
r.count = 0
r.end = now.Add(r.interval)
} else {
r.count++
}
allow := r.count < r.limit
if !allow {
r.blocked++
}

r.mu.Unlock()

return allow
}

func (r *rateCounter) countBlocked() uint64 {
r.mu.Lock()
blocked := r.blocked
r.blocked = 0
r.mu.Unlock()

return r.count < r.limit
return blocked
}
19 changes: 15 additions & 4 deletions server/rate_counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,25 @@ func TestRateCounter(t *testing.T) {
counter.interval = 100 * time.Millisecond

var i int
for i = 0; i <= 10; i++ {
for i = 0; i < 10; i++ {
if !counter.allow() {
break
t.Errorf("counter should allow (iteration %d)", i)
}
}
for i = 0; i < 5; i++ {
if counter.allow() {
t.Errorf("counter should not allow (iteration %d)", i)
}
}

blocked := counter.countBlocked()
if blocked != 5 {
t.Errorf("Expected blocked = 5, got %d", blocked)
}

if i != 10 {
t.Errorf("Expected i = 10, got %d", i)
blocked = counter.countBlocked()
if blocked != 0 {
t.Errorf("Expected blocked = 0, got %d", blocked)
}

time.Sleep(150 * time.Millisecond)
Expand Down
18 changes: 16 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,22 @@ func NewServer(opts *Options) (*Server, error) {

if opts.TLSRateLimit > 0 {
s.connRateCounter = newRateCounter(opts.tlsConfigOpts.RateLimit)
s.startGoRoutine(func() {
defer s.grWG.Done()
t := time.NewTicker(time.Second)
defer t.Stop()
for {
select {
case <-s.quitCh:
return
case <-t.C:
blocked := s.connRateCounter.countBlocked()
if blocked > 0 {
s.Warnf("Rejected %d connections due to TLS rate limiting", blocked)
}
}
}
})
}

// Trusted root operator keys.
Expand Down Expand Up @@ -2416,8 +2432,6 @@ func (s *Server) createClient(conn net.Conn) *client {
// Check for TLS
if !isClosed && tlsRequired {
if s.connRateCounter != nil && !s.connRateCounter.allow() {
c.Warnf("Rejecting connection due to TLS rate limiting")

c.mu.Unlock()
c.sendErr("Connection throttling is active. Please try again later.")
c.closeConnection(MaxConnectionsExceeded)
Expand Down

0 comments on commit 5e5eb63

Please sign in to comment.