diff --git a/client.go b/client.go index 0e45cb3..2bb2d07 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,10 @@ package irc import ( + "errors" + "fmt" "io" + "sync" "time" ) @@ -25,6 +28,14 @@ var clientFilters = map[string]func(*Client, *Message){ reply.Command = "PONG" c.WriteMessage(reply) }, + "PONG": func(c *Client, m *Message) { + if c.incomingPongChan != nil { + select { + case c.incomingPongChan <- m.Trailing(): + default: + } + } + }, "PRIVMSG": func(c *Client, m *Message) { // Clean up CTCP stuff so everyone doesn't have to parse it // manually. @@ -50,6 +61,10 @@ type ClientConfig struct { User string Name string + // Connection settings + PingFrequency time.Duration + PingTimeout time.Duration + // SendLimit is how frequent messages can be sent. If this is zero, // there will be no limit. SendLimit time.Duration @@ -68,18 +83,18 @@ type Client struct { config ClientConfig // Internal state - currentNick string - limitTick *time.Ticker - limiter chan struct{} - tickDone chan struct{} + currentNick string + limiter chan struct{} + incomingPongChan chan string + errChan chan error } // NewClient creates a client given an io stream and a client config. -func NewClient(rwc io.ReadWriter, config ClientConfig) *Client { +func NewClient(rw io.ReadWriter, config ClientConfig) *Client { c := &Client{ - Conn: NewConn(rwc), - config: config, - tickDone: make(chan struct{}), + Conn: NewConn(rw), + config: config, + errChan: make(chan error, 1), } // Replace the writer writeCallback with one of our own @@ -97,52 +112,122 @@ func (c *Client) writeCallback(w *Writer, line string) error { return err } -func (c *Client) maybeStartLimiter() { +func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, exiting chan struct{}) { if c.config.SendLimit == 0 { return } + wg.Add(1) + // If SendBurst is 0, this will be unbuffered, so keep that in mind. c.limiter = make(chan struct{}, c.config.SendBurst) - - c.limitTick = time.NewTicker(c.config.SendLimit) + limitTick := time.NewTicker(c.config.SendLimit) go func() { + defer wg.Done() + var done bool for !done { select { - case <-c.limitTick.C: + case <-limitTick.C: select { case c.limiter <- struct{}{}: default: } - case <-c.tickDone: + case <-exiting: done = true } } - c.limitTick.Stop() + limitTick.Stop() close(c.limiter) c.limiter = nil - c.tickDone <- struct{}{} }() } -func (c *Client) stopLimiter() { - if c.limiter == nil { +func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) { + if c.config.PingFrequency <= 0 { return } - c.tickDone <- struct{}{} - <-c.tickDone + wg.Add(1) + + c.incomingPongChan = make(chan string, 5) + + go func() { + defer wg.Done() + + pingHandlers := make(map[string]chan struct{}) + ticker := time.NewTicker(c.config.PingFrequency) + + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // Each time we get a tick, we send off a ping and start a + // goroutine to handle the pong. + timestamp := time.Now().Unix() + pongChan := make(chan struct{}) + pingHandlers[fmt.Sprintf("%d", timestamp)] = pongChan + wg.Add(1) + go c.handlePing(timestamp, pongChan, wg, exiting) + case data := <-c.incomingPongChan: + // Make sure the pong gets routed to the correct + // goroutine. + c := pingHandlers[data] + delete(pingHandlers, data) + + if c != nil { + c <- struct{}{} + } + case <-exiting: + return + } + } + }() +} + +func (c *Client) handlePing(timestamp int64, pongChan chan struct{}, wg *sync.WaitGroup, exiting chan struct{}) { + defer wg.Done() + + err := c.Writef("PING :%d", timestamp) + if err != nil { + c.sendError(err) + return + } + + timer := time.NewTimer(c.config.PingTimeout) + defer timer.Stop() + + select { + case <-timer.C: + c.sendError(errors.New("Ping Timeout")) + case <-pongChan: + return + case <-exiting: + return + } +} + +func (c *Client) sendError(err error) { + select { + case c.errChan <- err: + default: + } } // Run starts the main loop for this IRC connection. Note that it may break in // strange and unexpected ways if it is called again before the first connection // exits. func (c *Client) Run() error { - c.maybeStartLimiter() - defer c.stopLimiter() + // exiting is used by the main goroutine here to ensure any sub-goroutines + // get closed when exiting. + exiting := make(chan struct{}) + var wg sync.WaitGroup + + c.maybeStartLimiter(&wg, exiting) + c.maybeStartPingLoop(&wg, exiting) c.currentNick = c.config.Nick @@ -156,7 +241,8 @@ func (c *Client) Run() error { for { m, err := c.ReadMessage() if err != nil { - return err + c.sendError(err) + break } if f, ok := clientFilters[m.Command]; ok { @@ -167,6 +253,14 @@ func (c *Client) Run() error { c.config.Handler.Handle(c, m) } } + + // Wait for an error from any goroutine, then signal we're exiting and wait + // for the goroutines to exit. + err := <-c.errChan + close(exiting) + wg.Wait() + + return err } // CurrentNick returns what the nick of the client is known to be at this point diff --git a/client_test.go b/client_test.go index bf7be97..c53bb79 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,8 @@ package irc import ( + "errors" + "fmt" "io" "testing" "time" @@ -29,72 +31,58 @@ func (th *TestHandler) Messages() []*Message { func TestClient(t *testing.T) { t.Parallel() - rwc := newTestReadWriteCloser() config := ClientConfig{ Nick: "test_nick", Pass: "test_pass", User: "test_user", Name: "test_name", } - c := NewClient(rwc, config) - err := c.Run() - assert.Equal(t, io.EOF, err) - - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), }) - rwc.server.WriteString("PING :hello world\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", - "PONG :hello world", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("PING :hello world\r\n"), + ExpectLine("PONG :hello world\r\n"), }) - rwc.server.WriteString(":test_nick NICK :new_test_nick\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + c := runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine(":test_nick NICK :new_test_nick\r\n"), }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - rwc.server.WriteString("001 :new_test_nick\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + c = runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :new_test_nick\r\n"), }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - rwc.server.WriteString("433\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", - "NICK :test_nick_", + c = runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("433\r\n"), + ExpectLine("NICK :test_nick_\r\n"), }) assert.Equal(t, "test_nick_", c.CurrentNick()) - rwc.server.WriteString("437\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", - "NICK :test_nick_", + c = runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("437\r\n"), + ExpectLine("NICK :test_nick_\r\n"), }) assert.Equal(t, "test_nick_", c.CurrentNick()) } @@ -103,7 +91,7 @@ func TestSendLimit(t *testing.T) { t.Parallel() handler := &TestHandler{} - rwc := newTestReadWriteCloser() + config := ClientConfig{ Nick: "test_nick", Pass: "test_pass", @@ -116,42 +104,36 @@ func TestSendLimit(t *testing.T) { SendBurst: 2, } - rwc.server.WriteString("001 :hello_world\r\n") - c := NewClient(rwc, config) - before := time.Now() - err := c.Run() - assert.Equal(t, io.EOF, err) - assert.WithinDuration(t, before, time.Now(), 50*time.Millisecond) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), }) + assert.WithinDuration(t, before, time.Now(), 50*time.Millisecond) // This last test isn't really a test. It's being used to make sure we // hit the branch which handles dropping ticks if the buffered channel is // full. - rwc.server.WriteString("001 :hello world\r\n") handler.delay = 20 * time.Millisecond // Sleep for 20ms when we get the 001 message - c.config.SendLimit = 10 * time.Millisecond - c.config.SendBurst = 0 + config.SendLimit = 10 * time.Millisecond + config.SendBurst = 0 + before = time.Now() - err = c.Run() - assert.Equal(t, io.EOF, err) - assert.WithinDuration(t, before, time.Now(), 60*time.Millisecond) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), }) + assert.WithinDuration(t, before, time.Now(), 60*time.Millisecond) } func TestClientHandler(t *testing.T) { t.Parallel() handler := &TestHandler{} - rwc := newTestReadWriteCloser() config := ClientConfig{ Nick: "test_nick", Pass: "test_pass", @@ -161,17 +143,12 @@ func TestClientHandler(t *testing.T) { Handler: handler, } - rwc.server.WriteString("001 :hello_world\r\n") - c := NewClient(rwc, config) - err := c.Run() - assert.Equal(t, io.EOF, err) - - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), }) - assert.EqualValues(t, []*Message{ { Tags: Tags{}, @@ -182,9 +159,12 @@ func TestClientHandler(t *testing.T) { }, handler.Messages()) // Ensure CTCP messages are parsed - rwc.server.WriteString(":world PRIVMSG :\x01VERSION\x01\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine(":world PRIVMSG :\x01VERSION\x01\r\n"), + }) assert.EqualValues(t, []*Message{ { Tags: Tags{}, @@ -196,9 +176,12 @@ func TestClientHandler(t *testing.T) { // CTCP Regression test for PR#47 // Proper CTCP should start AND end in \x01 - rwc.server.WriteString(":world PRIVMSG :\x01VERSION\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine(":world PRIVMSG :\x01VERSION\r\n"), + }) assert.EqualValues(t, []*Message{ { Tags: Tags{}, @@ -207,7 +190,12 @@ func TestClientHandler(t *testing.T) { Params: []string{"\x01VERSION"}, }, }, handler.Messages()) +} +func TestFromChannel(t *testing.T) { + t.Parallel() + + c := Client{currentNick: "test_nick"} m := MustParseMessage("PRIVMSG test_nick :hello world") assert.False(t, c.FromChannel(m)) @@ -217,3 +205,60 @@ func TestClientHandler(t *testing.T) { m = MustParseMessage("PING") assert.False(t, c.FromChannel(m)) } + +func TestPingLoop(t *testing.T) { + t.Parallel() + + config := ClientConfig{ + Nick: "test_nick", + Pass: "test_pass", + User: "test_user", + Name: "test_name", + + PingFrequency: 20 * time.Millisecond, + PingTimeout: 5 * time.Millisecond, + } + + var lastPing *Message + + // Successful ping + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), + Delay(20 * time.Millisecond), + LineFunc(func(m *Message) { + lastPing = m + }), + SendFunc(func() string { + return fmt.Sprintf("PONG :%s\r\n", lastPing.Trailing()) + }), + Delay(10 * time.Millisecond), + }) + + // Ping timeout + runTest(t, config, errors.New("Ping Timeout"), []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), + Delay(20 * time.Millisecond), + LineFunc(func(m *Message) { + lastPing = m + }), + Delay(20 * time.Millisecond), + }) + + // Exit in the middle of handling a ping + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), + Delay(20 * time.Millisecond), + LineFunc(func(m *Message) { + lastPing = m + }), + }) +} diff --git a/conn.go b/conn.go index 0c9c815..7f9b68e 100644 --- a/conn.go +++ b/conn.go @@ -15,13 +15,10 @@ type Conn struct { // NewConn creates a new Conn func NewConn(rw io.ReadWriter) *Conn { - // Create the client - c := &Conn{ + return &Conn{ NewReader(rw), NewWriter(rw), } - - return c } // Writer is the outgoing side of a connection. @@ -79,7 +76,10 @@ type Reader struct { reader *bufio.Reader } -// NewReader creates an irc.Reader from an io.Reader. +// NewReader creates an irc.Reader from an io.Reader. Note that once a reader is +// passed into this function, you should no longer use it as it is being used +// inside a bufio.Reader so you cannot rely on only the amount of data for a +// Message being read when you call ReadMessage. func NewReader(r io.Reader) *Reader { return &Reader{ nil, diff --git a/conn_test.go b/conn_test.go index f215ecb..7de27eb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -19,13 +19,6 @@ func (ew *errorWriter) Write([]byte) (int, error) { type readWriteCloser struct { io.Reader io.Writer - io.Closer -} - -type nilCloser struct{} - -func (nc *nilCloser) Close() error { - return nil } type testReadWriteCloser struct { @@ -84,7 +77,6 @@ func TestWriteMessageError(t *testing.T) { rw := readWriteCloser{ &bytes.Buffer{}, &errorWriter{}, - &nilCloser{}, } c := NewConn(rw) diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..1f20a26 --- /dev/null +++ b/stream_test.go @@ -0,0 +1,178 @@ +package irc + +import ( + "bytes" + "errors" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestAction is used to execute an action during a stream test. If a +// non-nil error is returned the test will be failed. +type TestAction func(t *testing.T, c *Client, rw *testReadWriter) + +func SendLine(output string) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + // First we send the message + select { + case rw.readChan <- output: + case <-rw.exiting: + assert.Fail(t, "Failed to send") + } + + // Now we wait for the buffer to be emptied + select { + case <-rw.readEmptyChan: + case <-rw.exiting: + assert.Fail(t, "Failed to send whole message") + } + } +} + +func SendFunc(cb func() string) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + SendLine(cb())(t, c, rw) + } +} + +func LineFunc(cb func(m *Message)) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + select { + case line := <-rw.writeChan: + cb(MustParseMessage(line)) + case <-time.After(1 * time.Second): + assert.Fail(t, "LineFunc timeout") + case <-rw.exiting: + } + } +} + +func ExpectLine(input string) TestAction { + return ExpectLineWithTimeout(input, 1*time.Second) +} + +func ExpectLineWithTimeout(input string, timeout time.Duration) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + select { + case line := <-rw.writeChan: + assert.Equal(t, input, line) + case <-time.After(timeout): + assert.Fail(t, "ExpectLine timeout on %s", input) + case <-rw.exiting: + } + } +} + +func Delay(delay time.Duration) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + select { + case <-time.After(delay): + case <-rw.exiting: + } + } +} + +type testReadWriter struct { + actions []TestAction + queuedWriteError error + writeChan chan string + queuedReadError error + readChan chan string + readEmptyChan chan struct{} + exiting chan struct{} + clientDone chan struct{} + serverBuffer bytes.Buffer +} + +func (rw *testReadWriter) Read(buf []byte) (int, error) { + if rw.queuedReadError != nil { + err := rw.queuedReadError + rw.queuedReadError = nil + return 0, err + } + + // If there's data left in the buffer, we want to use that first. + if rw.serverBuffer.Len() > 0 { + s, err := rw.serverBuffer.Read(buf) + if err == io.EOF { + err = nil + } + return s, err + } + + select { + case rw.readEmptyChan <- struct{}{}: + default: + } + + // Read from server. We're either waiting for this whole test to + // finish or for data to come in from the server buffer. We expect + // only one read to be happening at once. + select { + case data := <-rw.readChan: + rw.serverBuffer.WriteString(data) + s, err := rw.serverBuffer.Read(buf) + if err == io.EOF { + err = nil + } + return s, err + case <-rw.exiting: + return 0, io.EOF + } +} + +func (rw *testReadWriter) Write(buf []byte) (int, error) { + if rw.queuedWriteError != nil { + err := rw.queuedWriteError + rw.queuedWriteError = nil + return 0, err + } + + // Write to server. We can cheat with this because we know things + // will be written a line at a time. + select { + case rw.writeChan <- string(buf): + return len(buf), nil + case <-rw.exiting: + return 0, errors.New("Connection closed") + } +} + +func runTest(t *testing.T, cc ClientConfig, expectedErr error, actions []TestAction) *Client { + rw := &testReadWriter{ + actions: actions, + writeChan: make(chan string), + readChan: make(chan string), + readEmptyChan: make(chan struct{}, 1), + exiting: make(chan struct{}), + clientDone: make(chan struct{}), + } + + c := NewClient(rw, cc) + + go func() { + err := c.Run() + assert.Equal(t, expectedErr, err) + close(rw.clientDone) + }() + + // Perform each of the actions + for _, action := range rw.actions { + action(t, c, rw) + } + + // TODO: Make sure there are no more incoming messages + + // Ask everything to shut down and wait for the client to stop. + close(rw.exiting) + select { + case <-rw.clientDone: + case <-time.After(1 * time.Second): + assert.Fail(t, "Timeout in client shutdown") + } + + return c +}