Skip to content

Commit

Permalink
Merge pull request #66 from go-irc/exit-on-write-error
Browse files Browse the repository at this point in the history
* Ensure all write errors cause the client to exit
* Replace queuedWriteError with writeErrorChan to fix a race
  • Loading branch information
belak committed Apr 11, 2018
2 parents 9106b7e + 4e3e991 commit 5bf07c6
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 30 deletions.
9 changes: 4 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ func (c *Client) writeCallback(w *Writer, line string) error {
}

_, err := w.writer.Write([]byte(line + "\r\n"))
if err != nil {
c.sendError(err)
}
return err
}

Expand Down Expand Up @@ -260,11 +263,7 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) {
func (c *Client) handlePing(timestamp int64, pongChan chan struct{}, wg *sync.WaitGroup, exiting chan struct{}) {
defer wg.Done()

err := c.Writef("PING :%d", timestamp)
if err != nil {
c.sendError(err)
return
}
c.Writef("PING :%d", timestamp)

timer := time.NewTimer(c.config.PingTimeout)
defer timer.Stop()
Expand Down
11 changes: 11 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,4 +403,15 @@ func TestPingLoop(t *testing.T) {
SendLine("PONG :hello 6\r\n"),
SendLine("PONG :hello 7\r\n"),
})

// Successful ping with write error
runClientTest(t, config, errors.New("test error"), nil, []TestAction{
ExpectLine("PASS :test_pass\r\n"),
ExpectLine("NICK :test_nick\r\n"),
ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"),
// We queue this up a line early because the next write will happen after the delay.
QueueWriteError(errors.New("test error")),
SendLine("001 :hello_world\r\n"),
Delay(25 * time.Millisecond),
})
}
77 changes: 52 additions & 25 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,36 @@ func Delay(delay time.Duration) TestAction {
}
}

func QueueReadError(err error) TestAction {
return func(t *testing.T, rw *testReadWriter) {
select {
case rw.readErrorChan <- err:
default:
assert.Fail(t, "Tried to queue a second read error")
}
}
}

func QueueWriteError(err error) TestAction {
return func(t *testing.T, rw *testReadWriter) {
select {
case rw.writeErrorChan <- err:
default:
assert.Fail(t, "Tried to queue a second write error")
}
}
}

type testReadWriter struct {
actions []TestAction
queuedWriteError error
writeChan chan string
queuedReadError error
readChan chan string
readEmptyChan chan struct{}
exiting chan struct{}
clientDone chan struct{}
serverBuffer bytes.Buffer
actions []TestAction
writeErrorChan chan error
writeChan chan string
readErrorChan chan error
readChan chan string
readEmptyChan chan struct{}
exiting chan struct{}
clientDone chan struct{}
serverBuffer bytes.Buffer
}

func (rw *testReadWriter) maybeBroadcastEmpty() {
Expand All @@ -109,10 +129,11 @@ func (rw *testReadWriter) maybeBroadcastEmpty() {
}

func (rw *testReadWriter) Read(buf []byte) (int, error) {
if rw.queuedReadError != nil {
err := rw.queuedReadError
rw.queuedReadError = nil
// Check for a read error first
select {
case err := <-rw.readErrorChan:
return 0, err
default:
}

// If there's data left in the buffer, we want to use that first.
Expand All @@ -125,10 +146,12 @@ func (rw *testReadWriter) Read(buf []byte) (int, error) {
return s, err
}

// Read from server. We're either waiting for this whole test to
// finish or for data to come in from the server buffer. We expect
// only one read to be happening at once.
// Read from server. We're waiting for this whole test to finish, data to
// come in from the server buffer, or for an error. We expect only one read
// to be happening at once.
select {
case err := <-rw.readErrorChan:
return 0, err
case data := <-rw.readChan:
rw.serverBuffer.WriteString(data)
s, err := rw.serverBuffer.Read(buf)
Expand All @@ -143,10 +166,10 @@ func (rw *testReadWriter) Read(buf []byte) (int, error) {
}

func (rw *testReadWriter) Write(buf []byte) (int, error) {
if rw.queuedWriteError != nil {
err := rw.queuedWriteError
rw.queuedWriteError = nil
select {
case err := <-rw.writeErrorChan:
return 0, err
default:
}

// Write to server. We can cheat with this because we know things
Expand All @@ -161,12 +184,14 @@ func (rw *testReadWriter) Write(buf []byte) (int, error) {

func newTestReadWriter(actions []TestAction) *testReadWriter {
return &testReadWriter{
actions: actions,
writeChan: make(chan string),
readChan: make(chan string),
readEmptyChan: make(chan struct{}, 1),
exiting: make(chan struct{}),
clientDone: make(chan struct{}),
actions: actions,
writeErrorChan: make(chan error, 1),
writeChan: make(chan string),
readErrorChan: make(chan error, 1),
readChan: make(chan string),
readEmptyChan: make(chan struct{}, 1),
exiting: make(chan struct{}),
clientDone: make(chan struct{}),
}
}

Expand Down Expand Up @@ -197,8 +222,10 @@ func runTest(t *testing.T, rw *testReadWriter, actions []TestAction) {

// TODO: Make sure there are no more incoming messages

// Ask everything to shut down and wait for the client to stop.
// Ask everything to shut down
close(rw.exiting)

// Wait for the client to stop
select {
case <-rw.clientDone:
case <-time.After(1 * time.Second):
Expand Down

0 comments on commit 5bf07c6

Please sign in to comment.