Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 72 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package irc

import (
"io"
"time"
)

// clientFilters are pre-processing which happens for certain message
Expand Down Expand Up @@ -49,6 +50,13 @@ type ClientConfig struct {
User string
Name string

// SendLimit is how frequent messages can be sent. If this is zero,
// there will be no limit.
SendLimit time.Duration

// SendBurst is the number of messages which can be sent in a burst.
SendBurst int

// Handler is used for message dispatching.
Handler Handler
}
Expand All @@ -61,20 +69,81 @@ type Client struct {

// Internal state
currentNick string
limitTick *time.Ticker
limiter chan struct{}
tickDone chan struct{}
}

// NewClient creates a client given an io stream and a client config.
func NewClient(rwc io.ReadWriter, config ClientConfig) *Client {
return &Client{
Conn: NewConn(rwc),
config: config,
c := &Client{
Conn: NewConn(rwc),
config: config,
tickDone: make(chan struct{}),
}

// Replace the writer writeCallback with one of our own
c.Conn.Writer.writeCallback = c.writeCallback

return c
}

func (c *Client) writeCallback(w *Writer, line string) error {
if c.limiter != nil {
<-c.limiter
}

_, err := w.writer.Write([]byte(line + "\r\n"))
return err
}

func (c *Client) maybeStartLimiter() {
if c.config.SendLimit == 0 {
return
}

// If SendBurst is 0, this will be unbuffered, so keep that in mind.
c.limiter = make(chan struct{}, c.config.SendBurst)

c.limitTick = time.NewTicker(c.config.SendLimit)

go func() {
var done bool
for !done {
select {
case <-c.limitTick.C:
select {
case c.limiter <- struct{}{}:
default:
}
case <-c.tickDone:
done = true
}
}

c.limitTick.Stop()
close(c.limiter)
c.limiter = nil
c.tickDone <- struct{}{}
}()
}

func (c *Client) stopLimiter() {
if c.limiter == nil {
return
}

c.tickDone <- struct{}{}
<-c.tickDone
}

// Run starts the main loop for this IRC connection. Note that it may break in
// strange and unexpected ways if it is called again before the first connection
// exits.
func (c *Client) Run() error {
c.maybeStartLimiter()
defer c.stopLimiter()

c.currentNick = c.config.Nick

if c.config.Pass != "" {
Expand Down
53 changes: 53 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@ package irc
import (
"io"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

type TestHandler struct {
messages []*Message
delay time.Duration
}

func (th *TestHandler) Handle(c *Client, m *Message) {
th.messages = append(th.messages, m)
if th.delay > 0 {
time.Sleep(th.delay)
}
}

func (th *TestHandler) Messages() []*Message {
Expand Down Expand Up @@ -94,6 +99,54 @@ func TestClient(t *testing.T) {
assert.Equal(t, "test_nick_", c.CurrentNick())
}

func TestSendLimit(t *testing.T) {
t.Parallel()

handler := &TestHandler{}
rwc := newTestReadWriteCloser()
config := ClientConfig{
Nick: "test_nick",
Pass: "test_pass",
User: "test_user",
Name: "test_name",

Handler: handler,

SendLimit: 10 * time.Millisecond,
SendBurst: 2,
}

rwc.server.WriteString("001 :hello_world\r\n")
c := NewClient(rwc, config)

before := time.Now()
err := c.Run()
assert.Equal(t, io.EOF, err)
assert.WithinDuration(t, before, time.Now(), 50*time.Millisecond)
testLines(t, rwc, []string{
"PASS :test_pass",
"NICK :test_nick",
"USER test_user 0.0.0.0 0.0.0.0 :test_name",
})

// This last test isn't really a test. It's being used to make sure we
// hit the branch which handles dropping ticks if the buffered channel is
// full.
rwc.server.WriteString("001 :hello world\r\n")
handler.delay = 20 * time.Millisecond // Sleep for 20ms when we get the 001 message
c.config.SendLimit = 10 * time.Millisecond
c.config.SendBurst = 0
before = time.Now()
err = c.Run()
assert.Equal(t, io.EOF, err)
assert.WithinDuration(t, before, time.Now(), 60*time.Millisecond)
testLines(t, rwc, []string{
"PASS :test_pass",
"NICK :test_nick",
"USER test_user 0.0.0.0 0.0.0.0 :test_name",
})
}

func TestClientHandler(t *testing.T) {
t.Parallel()

Expand Down
13 changes: 9 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ type Writer struct {
DebugCallback func(line string)

// Internal fields
writer io.Writer
writer io.Writer
writeCallback func(w *Writer, line string) error
}

func defaultWriteCallback(w *Writer, line string) error {
_, err := w.writer.Write([]byte(line + "\r\n"))
return err
}

// NewWriter creates an irc.Writer from an io.Writer.
func NewWriter(w io.Writer) *Writer {
return &Writer{nil, w}
return &Writer{nil, w, defaultWriteCallback}
}

// Write is a simple function which will write the given line to the
Expand All @@ -46,8 +52,7 @@ func (w *Writer) Write(line string) error {
w.DebugCallback(line)
}

_, err := w.writer.Write([]byte(line + "\r\n"))
return err
return w.writeCallback(w, line)
}

// Writef is a wrapper around the connection's Write method and
Expand Down