diff --git a/client.go b/client.go index a8a835d..c5d504d 100644 --- a/client.go +++ b/client.go @@ -173,6 +173,9 @@ func (c *Client) writeCallback(w *Writer, line string) error { } _, err := w.writer.Write([]byte(line + "\r\n")) + if err != nil { + c.sendError(err) + } return err } @@ -260,11 +263,7 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) { func (c *Client) handlePing(timestamp int64, pongChan chan struct{}, wg *sync.WaitGroup, exiting chan struct{}) { defer wg.Done() - err := c.Writef("PING :%d", timestamp) - if err != nil { - c.sendError(err) - return - } + c.Writef("PING :%d", timestamp) timer := time.NewTimer(c.config.PingTimeout) defer timer.Stop() diff --git a/client_test.go b/client_test.go index 2ed3c96..68e6d7d 100644 --- a/client_test.go +++ b/client_test.go @@ -403,4 +403,15 @@ func TestPingLoop(t *testing.T) { SendLine("PONG :hello 6\r\n"), SendLine("PONG :hello 7\r\n"), }) + + // Successful ping with write error + runClientTest(t, config, errors.New("test error"), nil, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + // We queue this up a line early because the next write will happen after the delay. + QueueWriteError(errors.New("test error")), + SendLine("001 :hello_world\r\n"), + Delay(25 * time.Millisecond), + }) } diff --git a/stream_test.go b/stream_test.go index 7503372..49ecd25 100644 --- a/stream_test.go +++ b/stream_test.go @@ -87,11 +87,34 @@ func Delay(delay time.Duration) TestAction { } } +func QueueReadError(err error) TestAction { + return func(t *testing.T, rw *testReadWriter) { + select { + case rw.readErrorChan <- err: + default: + assert.Fail(t, "Tried to queue a second read error") + } + } +} + +func QueueWriteError(err error) TestAction { + return func(t *testing.T, rw *testReadWriter) { + assert.Nil(t, rw.queuedWriteError, "Tried to queue a second write error") + rw.queuedWriteError = err + } +} + type testReadWriter struct { - actions []TestAction + actions []TestAction + + // It's worth noting that there's queuedWriteError and readErrorChan. We + // don't actually need a channel for the write errors because it's more + // deterministic when that's called. However because reads happen in a + // readLoop goroutine, this needs to be possible to trigger in the middle of + // a read. queuedWriteError error writeChan chan string - queuedReadError error + readErrorChan chan error readChan chan string readEmptyChan chan struct{} exiting chan struct{} @@ -109,10 +132,11 @@ func (rw *testReadWriter) maybeBroadcastEmpty() { } func (rw *testReadWriter) Read(buf []byte) (int, error) { - if rw.queuedReadError != nil { - err := rw.queuedReadError - rw.queuedReadError = nil + // Check for a read error first + select { + case err := <-rw.readErrorChan: return 0, err + default: } // If there's data left in the buffer, we want to use that first. @@ -125,10 +149,12 @@ func (rw *testReadWriter) Read(buf []byte) (int, error) { return s, err } - // Read from server. We're either waiting for this whole test to - // finish or for data to come in from the server buffer. We expect - // only one read to be happening at once. + // Read from server. We're waiting for this whole test to finish, data to + // come in from the server buffer, or for an error. We expect only one read + // to be happening at once. select { + case err := <-rw.readErrorChan: + return 0, err case data := <-rw.readChan: rw.serverBuffer.WriteString(data) s, err := rw.serverBuffer.Read(buf) @@ -163,6 +189,7 @@ func newTestReadWriter(actions []TestAction) *testReadWriter { return &testReadWriter{ actions: actions, writeChan: make(chan string), + readErrorChan: make(chan error, 1), readChan: make(chan string), readEmptyChan: make(chan struct{}, 1), exiting: make(chan struct{}), @@ -197,8 +224,10 @@ func runTest(t *testing.T, rw *testReadWriter, actions []TestAction) { // TODO: Make sure there are no more incoming messages - // Ask everything to shut down and wait for the client to stop. + // Ask everything to shut down close(rw.exiting) + + // Wait for the client to stop select { case <-rw.clientDone: case <-time.After(1 * time.Second):