From 703c068ce426532e3d5e1c181b998d480d268da3 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 2 Aug 2017 16:21:41 -0700 Subject: [PATCH 1/4] Add simple rate limiting to Client --- client.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- conn.go | 22 +++++++++++++++++---- 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 4e8a515..98ebdaf 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package irc import ( "io" + "time" ) // clientFilters are pre-processing which happens for certain message @@ -49,6 +50,13 @@ type ClientConfig struct { User string Name string + // SendLimit is how frequent messages can be sent. If this is zero, + // there will be no limit. + SendLimit time.Duration + + // SendBurst is the number of messages which can be sent in a burst. + SendBurst int + // Handler is used for message dispatching. Handler Handler } @@ -61,20 +69,69 @@ type Client struct { // Internal state currentNick string + limitTick *time.Ticker + limiter chan struct{} } // NewClient creates a client given an io stream and a client config. func NewClient(rwc io.ReadWriter, config ClientConfig) *Client { - return &Client{ + c := &Client{ Conn: NewConn(rwc), config: config, } + + // Replace the writer writeCallback with one of our own + c.Conn.Writer.writeCallback = c.writeCallback + + return c +} + +func (c *Client) writeCallback(w *Writer, line string) error { + if c.limiter != nil { + <-c.limiter + } + + _, err := w.writer.Write([]byte(line + "\r\n")) + return err +} + +func (c *Client) maybeStartLimiter() { + if c.config.SendLimit == 0 { + return + } + + // If SendBurst is 0, this will be unbuffered, so keep that in mind. + c.limiter = make(chan struct{}, c.config.SendBurst) + + c.limitTick = time.NewTicker(c.config.SendLimit) + + go func() { + for range c.limitTick.C { + select { + case c.limiter <- struct{}{}: + default: + } + } + close(c.limiter) + c.limiter = nil + }() +} + +func (c *Client) stopLimiter() { + if c.limiter == nil { + return + } + + c.limitTick.Stop() } // Run starts the main loop for this IRC connection. Note that it may break in // strange and unexpected ways if it is called again before the first connection // exits. func (c *Client) Run() error { + c.maybeStartLimiter() + defer c.stopLimiter() + c.currentNick = c.config.Nick if c.config.Pass != "" { diff --git a/conn.go b/conn.go index 7bf9846..99d05f6 100644 --- a/conn.go +++ b/conn.go @@ -11,6 +11,9 @@ import ( type Conn struct { *Reader *Writer + + // Internal fields + closer io.Closer } // NewConn creates a new Conn @@ -19,6 +22,12 @@ func NewConn(rw io.ReadWriter) *Conn { c := &Conn{ NewReader(rw), NewWriter(rw), + nil, + } + + // If there's a closer available, we want to keep it around + if closer, ok := rw.(io.Closer); ok { + c.closer = closer } return c @@ -31,12 +40,18 @@ type Writer struct { DebugCallback func(line string) // Internal fields - writer io.Writer + writer io.Writer + writeCallback func(w *Writer, line string) error +} + +func defaultWriteCallback(w *Writer, line string) error { + _, err := w.writer.Write([]byte(line + "\r\n")) + return err } // NewWriter creates an irc.Writer from an io.Writer. func NewWriter(w io.Writer) *Writer { - return &Writer{nil, w} + return &Writer{nil, w, defaultWriteCallback} } // Write is a simple function which will write the given line to the @@ -46,8 +61,7 @@ func (w *Writer) Write(line string) error { w.DebugCallback(line) } - _, err := w.writer.Write([]byte(line + "\r\n")) - return err + return w.writeCallback(w, line) } // Writef is a wrapper around the connection's Write method and From a597b4d31dc280ab76a88048bcb465f54e55cfa5 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Thu, 3 Aug 2017 13:10:14 -0700 Subject: [PATCH 2/4] Fix a few rate limiting bugs and add tests --- client.go | 24 ++++++++++++++++++------ client_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 98ebdaf..0e45cb3 100644 --- a/client.go +++ b/client.go @@ -71,13 +71,15 @@ type Client struct { currentNick string limitTick *time.Ticker limiter chan struct{} + tickDone chan struct{} } // NewClient creates a client given an io stream and a client config. func NewClient(rwc io.ReadWriter, config ClientConfig) *Client { c := &Client{ - Conn: NewConn(rwc), - config: config, + Conn: NewConn(rwc), + config: config, + tickDone: make(chan struct{}), } // Replace the writer writeCallback with one of our own @@ -106,14 +108,23 @@ func (c *Client) maybeStartLimiter() { c.limitTick = time.NewTicker(c.config.SendLimit) go func() { - for range c.limitTick.C { + var done bool + for !done { select { - case c.limiter <- struct{}{}: - default: + case <-c.limitTick.C: + select { + case c.limiter <- struct{}{}: + default: + } + case <-c.tickDone: + done = true } } + + c.limitTick.Stop() close(c.limiter) c.limiter = nil + c.tickDone <- struct{}{} }() } @@ -122,7 +133,8 @@ func (c *Client) stopLimiter() { return } - c.limitTick.Stop() + c.tickDone <- struct{}{} + <-c.tickDone } // Run starts the main loop for this IRC connection. Note that it may break in diff --git a/client_test.go b/client_test.go index d1ce79f..23c457d 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,7 @@ package irc import ( "io" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -94,6 +95,43 @@ func TestClient(t *testing.T) { assert.Equal(t, "test_nick_", c.CurrentNick()) } +func TestSendLimit(t *testing.T) { + t.Parallel() + + handler := &TestHandler{} + rwc := newTestReadWriteCloser() + config := ClientConfig{ + Nick: "test_nick", + Pass: "test_pass", + User: "test_user", + Name: "test_name", + + Handler: handler, + + SendLimit: time.Second / 4, + SendBurst: 2, + } + + rwc.server.WriteString("001 :hello_world\r\n") + c := NewClient(rwc, config) + + before := time.Now() + err := c.Run() + assert.Equal(t, io.EOF, err) + assert.WithinDuration(t, before, time.Now(), 2*time.Second) + testLines(t, rwc, []string{ + "PASS :test_pass", + "NICK :test_nick", + "USER test_user 0.0.0.0 0.0.0.0 :test_name", + }) + + rwc.server.WriteString("PING :hello world\r\n") + rwc.server.WriteString("PING :hello world\r\n") + rwc.server.WriteString("PING :hello world\r\n") + rwc.server.WriteString("PING :hello world\r\n") + rwc.server.WriteString("PING :hello world\r\n") +} + func TestClientHandler(t *testing.T) { t.Parallel() From a1493ce583092b70c33c950bccb5235ca7da2795 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Thu, 3 Aug 2017 15:22:05 -0700 Subject: [PATCH 3/4] Speed up and improve SendLimit tests --- client_test.go | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/client_test.go b/client_test.go index 23c457d..bf7be97 100644 --- a/client_test.go +++ b/client_test.go @@ -10,10 +10,14 @@ import ( type TestHandler struct { messages []*Message + delay time.Duration } func (th *TestHandler) Handle(c *Client, m *Message) { th.messages = append(th.messages, m) + if th.delay > 0 { + time.Sleep(th.delay) + } } func (th *TestHandler) Messages() []*Message { @@ -108,7 +112,7 @@ func TestSendLimit(t *testing.T) { Handler: handler, - SendLimit: time.Second / 4, + SendLimit: 10 * time.Millisecond, SendBurst: 2, } @@ -118,18 +122,29 @@ func TestSendLimit(t *testing.T) { before := time.Now() err := c.Run() assert.Equal(t, io.EOF, err) - assert.WithinDuration(t, before, time.Now(), 2*time.Second) + assert.WithinDuration(t, before, time.Now(), 50*time.Millisecond) testLines(t, rwc, []string{ "PASS :test_pass", "NICK :test_nick", "USER test_user 0.0.0.0 0.0.0.0 :test_name", }) - rwc.server.WriteString("PING :hello world\r\n") - rwc.server.WriteString("PING :hello world\r\n") - rwc.server.WriteString("PING :hello world\r\n") - rwc.server.WriteString("PING :hello world\r\n") - rwc.server.WriteString("PING :hello world\r\n") + // This last test isn't really a test. It's being used to make sure we + // hit the branch which handles dropping ticks if the buffered channel is + // full. + rwc.server.WriteString("001 :hello world\r\n") + handler.delay = 20 * time.Millisecond // Sleep for 20ms when we get the 001 message + c.config.SendLimit = 10 * time.Millisecond + c.config.SendBurst = 0 + before = time.Now() + err = c.Run() + assert.Equal(t, io.EOF, err) + assert.WithinDuration(t, before, time.Now(), 60*time.Millisecond) + testLines(t, rwc, []string{ + "PASS :test_pass", + "NICK :test_nick", + "USER test_user 0.0.0.0 0.0.0.0 :test_name", + }) } func TestClientHandler(t *testing.T) { From 512b48b01f2148b0e1f3d9484d6fca536fc69896 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 9 Aug 2017 13:56:56 -0700 Subject: [PATCH 4/4] Remove the unused closer --- conn.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/conn.go b/conn.go index 99d05f6..0c9c815 100644 --- a/conn.go +++ b/conn.go @@ -11,9 +11,6 @@ import ( type Conn struct { *Reader *Writer - - // Internal fields - closer io.Closer } // NewConn creates a new Conn @@ -22,12 +19,6 @@ func NewConn(rw io.ReadWriter) *Conn { c := &Conn{ NewReader(rw), NewWriter(rw), - nil, - } - - // If there's a closer available, we want to keep it around - if closer, ok := rw.(io.Closer); ok { - c.closer = closer } return c