Skip to content

Commit

Permalink
Add simple rate limiting to Client (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
belak committed Aug 18, 2017
1 parent f23d783 commit acbcebe
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 7 deletions.
75 changes: 72 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package irc

import (
"io"
"time"
)

// clientFilters are pre-processing which happens for certain message
Expand Down Expand Up @@ -49,6 +50,13 @@ type ClientConfig struct {
User string
Name string

// SendLimit is how frequent messages can be sent. If this is zero,
// there will be no limit.
SendLimit time.Duration

// SendBurst is the number of messages which can be sent in a burst.
SendBurst int

// Handler is used for message dispatching.
Handler Handler
}
Expand All @@ -61,20 +69,81 @@ type Client struct {

// Internal state
currentNick string
limitTick *time.Ticker
limiter chan struct{}
tickDone chan struct{}
}

// NewClient creates a client given an io stream and a client config.
func NewClient(rwc io.ReadWriter, config ClientConfig) *Client {
return &Client{
Conn: NewConn(rwc),
config: config,
c := &Client{
Conn: NewConn(rwc),
config: config,
tickDone: make(chan struct{}),
}

// Replace the writer writeCallback with one of our own
c.Conn.Writer.writeCallback = c.writeCallback

return c
}

func (c *Client) writeCallback(w *Writer, line string) error {
if c.limiter != nil {
<-c.limiter
}

_, err := w.writer.Write([]byte(line + "\r\n"))
return err
}

func (c *Client) maybeStartLimiter() {
if c.config.SendLimit == 0 {
return
}

// If SendBurst is 0, this will be unbuffered, so keep that in mind.
c.limiter = make(chan struct{}, c.config.SendBurst)

c.limitTick = time.NewTicker(c.config.SendLimit)

go func() {
var done bool
for !done {
select {
case <-c.limitTick.C:
select {
case c.limiter <- struct{}{}:
default:
}
case <-c.tickDone:
done = true
}
}

c.limitTick.Stop()
close(c.limiter)
c.limiter = nil
c.tickDone <- struct{}{}
}()
}

func (c *Client) stopLimiter() {
if c.limiter == nil {
return
}

c.tickDone <- struct{}{}
<-c.tickDone
}

// Run starts the main loop for this IRC connection. Note that it may break in
// strange and unexpected ways if it is called again before the first connection
// exits.
func (c *Client) Run() error {
c.maybeStartLimiter()
defer c.stopLimiter()

c.currentNick = c.config.Nick

if c.config.Pass != "" {
Expand Down
53 changes: 53 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@ package irc
import (
"io"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

type TestHandler struct {
messages []*Message
delay time.Duration
}

func (th *TestHandler) Handle(c *Client, m *Message) {
th.messages = append(th.messages, m)
if th.delay > 0 {
time.Sleep(th.delay)
}
}

func (th *TestHandler) Messages() []*Message {
Expand Down Expand Up @@ -94,6 +99,54 @@ func TestClient(t *testing.T) {
assert.Equal(t, "test_nick_", c.CurrentNick())
}

func TestSendLimit(t *testing.T) {
t.Parallel()

handler := &TestHandler{}
rwc := newTestReadWriteCloser()
config := ClientConfig{
Nick: "test_nick",
Pass: "test_pass",
User: "test_user",
Name: "test_name",

Handler: handler,

SendLimit: 10 * time.Millisecond,
SendBurst: 2,
}

rwc.server.WriteString("001 :hello_world\r\n")
c := NewClient(rwc, config)

before := time.Now()
err := c.Run()
assert.Equal(t, io.EOF, err)
assert.WithinDuration(t, before, time.Now(), 50*time.Millisecond)
testLines(t, rwc, []string{
"PASS :test_pass",
"NICK :test_nick",
"USER test_user 0.0.0.0 0.0.0.0 :test_name",
})

// This last test isn't really a test. It's being used to make sure we
// hit the branch which handles dropping ticks if the buffered channel is
// full.
rwc.server.WriteString("001 :hello world\r\n")
handler.delay = 20 * time.Millisecond // Sleep for 20ms when we get the 001 message
c.config.SendLimit = 10 * time.Millisecond
c.config.SendBurst = 0
before = time.Now()
err = c.Run()
assert.Equal(t, io.EOF, err)
assert.WithinDuration(t, before, time.Now(), 60*time.Millisecond)
testLines(t, rwc, []string{
"PASS :test_pass",
"NICK :test_nick",
"USER test_user 0.0.0.0 0.0.0.0 :test_name",
})
}

func TestClientHandler(t *testing.T) {
t.Parallel()

Expand Down
13 changes: 9 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ type Writer struct {
DebugCallback func(line string)

// Internal fields
writer io.Writer
writer io.Writer
writeCallback func(w *Writer, line string) error
}

func defaultWriteCallback(w *Writer, line string) error {
_, err := w.writer.Write([]byte(line + "\r\n"))
return err
}

// NewWriter creates an irc.Writer from an io.Writer.
func NewWriter(w io.Writer) *Writer {
return &Writer{nil, w}
return &Writer{nil, w, defaultWriteCallback}
}

// Write is a simple function which will write the given line to the
Expand All @@ -46,8 +52,7 @@ func (w *Writer) Write(line string) error {
w.DebugCallback(line)
}

_, err := w.writer.Write([]byte(line + "\r\n"))
return err
return w.writeCallback(w, line)
}

// Writef is a wrapper around the connection's Write method and
Expand Down

0 comments on commit acbcebe

Please sign in to comment.