From db5bdff3363ab3dfa075538ad5d647f0834bf55f Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 7 Apr 2017 16:29:01 -0700 Subject: [PATCH 01/20] Add support for a ping loop and connection timeouts --- client.go | 41 +++++++++++++++++++++++++++++++++++++++-- conn.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 86 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 4e8a515..e61bc8d 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,8 @@ package irc import ( "io" + "sync" + "time" ) // clientFilters are pre-processing which happens for certain message @@ -49,6 +51,9 @@ type ClientConfig struct { User string Name string + // Connection settings + PingFrequency time.Duration + // Handler is used for message dispatching. Handler Handler } @@ -75,6 +80,31 @@ func NewClient(rwc io.ReadWriter, config ClientConfig) *Client { // strange and unexpected ways if it is called again before the first connection // exits. func (c *Client) Run() error { + // exiting is used by the main goroutine here to ensure any sub-goroutines + // get closed when exiting. + exiting := make(chan struct{}) + var wg sync.WaitGroup + + // If PingFrequency isn't the zero value, we need to start a ping goroutine. + if c.config.PingFrequency > 0 { + wg.Add(1) + go func() { + defer wg.Done() + + t := time.NewTicker(c.config.PingFrequency) + defer t.Stop() + + for { + select { + case <-t.C: + c.Writef("PING :%d", time.Now().Unix()) + case <-exiting: + break + } + } + }() + } + c.currentNick = c.config.Nick if c.config.Pass != "" { @@ -84,10 +114,12 @@ func (c *Client) Run() error { c.Writef("NICK :%s", c.config.Nick) c.Writef("USER %s 0.0.0.0 0.0.0.0 :%s", c.config.User, c.config.Name) + var err error + var m *Message for { - m, err := c.ReadMessage() + m, err = c.ReadMessage() if err != nil { - return err + break } if f, ok := clientFilters[m.Command]; ok { @@ -98,6 +130,11 @@ func (c *Client) Run() error { c.config.Handler.Handle(c, m) } } + + close(exiting) + wg.Wait() + + return err } // CurrentNick returns what the nick of the client is known to be at this point diff --git a/conn.go b/conn.go index 7bf9846..fb86178 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,8 @@ import ( "bufio" "fmt" "io" + "net" + "time" ) // Conn represents a simple IRC client. It embeds an irc.Reader and an @@ -31,12 +33,23 @@ type Writer struct { DebugCallback func(line string) // Internal fields - writer io.Writer + writer io.Writer + conn net.Conn + timeout time.Duration } // NewWriter creates an irc.Writer from an io.Writer. func NewWriter(w io.Writer) *Writer { - return &Writer{nil, w} + return &Writer{nil, w, nil, 0} +} + +// NewNetWriter creates an irc.Writer from a net.Conn and a write timeout. +// Note that the read timeout is not for stream activity but how long waiting +// for a message. These should be almost identical in most situations. +func NewNetWriter(conn net.Conn, timeout time.Duration) *Writer { + return &Writer{ + nil, conn, conn, timeout, + } } // Write is a simple function which will write the given line to the @@ -46,6 +59,13 @@ func (w *Writer) Write(line string) error { w.DebugCallback(line) } + if w.conn != nil && w.timeout > 0 { + err := w.conn.SetWriteDeadline(time.Now().Add(w.timeout)) + if err != nil { + return err + } + } + _, err := w.writer.Write([]byte(line + "\r\n")) return err } @@ -71,7 +91,9 @@ type Reader struct { DebugCallback func(string) // Internal fields - reader *bufio.Reader + reader *bufio.Reader + conn net.Conn + timeout time.Duration } // NewReader creates an irc.Reader from an io.Reader. @@ -79,11 +101,33 @@ func NewReader(r io.Reader) *Reader { return &Reader{ nil, bufio.NewReader(r), + nil, + 0, + } +} + +// NewNetReader creates an irc.Reader from a net.Conn and a read timeout. Note +// that the read timeout is not for stream activity but how long waiting for a +// message. These should be almost identical in most situations. +func NewNetReader(c net.Conn, timeout time.Duration) *Reader { + return &Reader{ + nil, + bufio.NewReader(c), + c, + timeout, } } // ReadMessage returns the next message from the stream or an error. func (r *Reader) ReadMessage() (*Message, error) { + // Set the read deadline if we have one + if r.conn != nil && r.timeout > 0 { + err := r.conn.SetReadDeadline(time.Now().Add(r.timeout)) + if err != nil { + return nil, err + } + } + line, err := r.reader.ReadString('\n') if err != nil { return nil, err From f0ee6fd365d3483a3d491e0a1c7ca5c83b608819 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 7 Apr 2017 16:43:54 -0700 Subject: [PATCH 02/20] Add NewNetConn and NewNetClient --- client.go | 10 ++++++++++ conn.go | 11 ++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index e61bc8d..bf7b49d 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package irc import ( "io" + "net" "sync" "time" ) @@ -76,6 +77,15 @@ func NewClient(rwc io.ReadWriter, config ClientConfig) *Client { } } +// NewNetClient creates a client given net.Conn, optional timeouts, and +// a client config. +func NewNetClient(conn net.Conn, readTimeout, writeTimeout time.Duration, config ClientConfig) *Client { + return &Client{ + Conn: NewNetConn(conn, readTimeout, writeTimeout), + config: config, + } +} + // Run starts the main loop for this IRC connection. Note that it may break in // strange and unexpected ways if it is called again before the first connection // exits. diff --git a/conn.go b/conn.go index fb86178..f18348b 100644 --- a/conn.go +++ b/conn.go @@ -17,13 +17,18 @@ type Conn struct { // NewConn creates a new Conn func NewConn(rw io.ReadWriter) *Conn { - // Create the client - c := &Conn{ + return &Conn{ NewReader(rw), NewWriter(rw), } +} - return c +// NewNetConn creates a Conn with optional timeouts +func NewNetConn(conn net.Conn, readTimeout, writeTimeout time.Duration) *Conn { + return &Conn{ + NewNetReader(conn, readTimeout), + NewNetWriter(conn, writeTimeout), + } } // Writer is the outgoing side of a connection. From 34250f2f8a35fe39a668941209fbfaba72102427 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Tue, 11 Apr 2017 16:53:10 -0700 Subject: [PATCH 03/20] Remove NewNet* in favor of SetTimeout --- client.go | 17 ++++++---------- conn.go | 59 ++++++++++++++++++++++--------------------------------- 2 files changed, 29 insertions(+), 47 deletions(-) diff --git a/client.go b/client.go index bf7b49d..f5ff017 100644 --- a/client.go +++ b/client.go @@ -2,7 +2,6 @@ package irc import ( "io" - "net" "sync" "time" ) @@ -54,6 +53,8 @@ type ClientConfig struct { // Connection settings PingFrequency time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration // Handler is used for message dispatching. Handler Handler @@ -71,19 +72,13 @@ type Client struct { // NewClient creates a client given an io stream and a client config. func NewClient(rwc io.ReadWriter, config ClientConfig) *Client { - return &Client{ + c := &Client{ Conn: NewConn(rwc), config: config, } -} - -// NewNetClient creates a client given net.Conn, optional timeouts, and -// a client config. -func NewNetClient(conn net.Conn, readTimeout, writeTimeout time.Duration, config ClientConfig) *Client { - return &Client{ - Conn: NewNetConn(conn, readTimeout, writeTimeout), - config: config, - } + c.Reader.SetTimeout(config.ReadTimeout) + c.Writer.SetTimeout(config.WriteTimeout) + return c } // Run starts the main loop for this IRC connection. Note that it may break in diff --git a/conn.go b/conn.go index f18348b..fec62bf 100644 --- a/conn.go +++ b/conn.go @@ -23,14 +23,6 @@ func NewConn(rw io.ReadWriter) *Conn { } } -// NewNetConn creates a Conn with optional timeouts -func NewNetConn(conn net.Conn, readTimeout, writeTimeout time.Duration) *Conn { - return &Conn{ - NewNetReader(conn, readTimeout), - NewNetWriter(conn, writeTimeout), - } -} - // Writer is the outgoing side of a connection. type Writer struct { // DebugCallback is called for each outgoing message. The name of this may @@ -39,22 +31,19 @@ type Writer struct { // Internal fields writer io.Writer - conn net.Conn timeout time.Duration } // NewWriter creates an irc.Writer from an io.Writer. func NewWriter(w io.Writer) *Writer { - return &Writer{nil, w, nil, 0} + return &Writer{nil, w, 0} } -// NewNetWriter creates an irc.Writer from a net.Conn and a write timeout. -// Note that the read timeout is not for stream activity but how long waiting -// for a message. These should be almost identical in most situations. -func NewNetWriter(conn net.Conn, timeout time.Duration) *Writer { - return &Writer{ - nil, conn, conn, timeout, - } +// SetTimeout allows you to set the write timeout for the next call to Write. +// Note that it is undefined behavior to call this while a call to Write is +// happening. +func (w *Writer) SetTimeout(timeout time.Duration) { + w.timeout = timeout } // Write is a simple function which will write the given line to the @@ -64,8 +53,8 @@ func (w *Writer) Write(line string) error { w.DebugCallback(line) } - if w.conn != nil && w.timeout > 0 { - err := w.conn.SetWriteDeadline(time.Now().Add(w.timeout)) + if c, ok := w.writer.(net.Conn); ok && w.timeout > 0 { + err := c.SetWriteDeadline(time.Now().Add(w.timeout)) if err != nil { return err } @@ -96,38 +85,36 @@ type Reader struct { DebugCallback func(string) // Internal fields - reader *bufio.Reader - conn net.Conn - timeout time.Duration + rawReader io.Reader + reader *bufio.Reader + timeout time.Duration } -// NewReader creates an irc.Reader from an io.Reader. +// NewReader creates an irc.Reader from an io.Reader. Note that once a reader is +// passed into this function, you should no longer use it as it is being used +// inside a bufio.Reader so you cannot rely on only the amount of data for a +// Message being read when you call ReadMessage. func NewReader(r io.Reader) *Reader { return &Reader{ nil, + r, bufio.NewReader(r), - nil, 0, } } -// NewNetReader creates an irc.Reader from a net.Conn and a read timeout. Note -// that the read timeout is not for stream activity but how long waiting for a -// message. These should be almost identical in most situations. -func NewNetReader(c net.Conn, timeout time.Duration) *Reader { - return &Reader{ - nil, - bufio.NewReader(c), - c, - timeout, - } +// SetTimeout allows you to set the read timeout for the next call to +// ReadMessage. Note that it is undefined behavior to call this while +// a call to ReadMessage is happening. +func (r *Reader) SetTimeout(timeout time.Duration) { + r.timeout = timeout } // ReadMessage returns the next message from the stream or an error. func (r *Reader) ReadMessage() (*Message, error) { // Set the read deadline if we have one - if r.conn != nil && r.timeout > 0 { - err := r.conn.SetReadDeadline(time.Now().Add(r.timeout)) + if c, ok := r.rawReader.(net.Conn); ok && r.timeout > 0 { + err := c.SetReadDeadline(time.Now().Add(r.timeout)) if err != nil { return nil, err } From 89f05ca1525a164ba948b620951bb525b2ef471f Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Tue, 11 Apr 2017 17:26:38 -0700 Subject: [PATCH 04/20] Add additional info to SetTimeout --- conn.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index fec62bf..9830851 100644 --- a/conn.go +++ b/conn.go @@ -41,7 +41,8 @@ func NewWriter(w io.Writer) *Writer { // SetTimeout allows you to set the write timeout for the next call to Write. // Note that it is undefined behavior to call this while a call to Write is -// happening. +// happening. Additionally, this is only effective if a net.Conn was passed into +// NewWriter. func (w *Writer) SetTimeout(timeout time.Duration) { w.timeout = timeout } @@ -105,7 +106,8 @@ func NewReader(r io.Reader) *Reader { // SetTimeout allows you to set the read timeout for the next call to // ReadMessage. Note that it is undefined behavior to call this while -// a call to ReadMessage is happening. +// a call to ReadMessage is happening. Additionally, this is only +// effective if a net.Conn is passed into NewReader. func (r *Reader) SetTimeout(timeout time.Duration) { r.timeout = timeout } From 7af173feb3d28de31bf0dc3c00e401864b681363 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Wed, 12 Apr 2017 13:49:23 -0700 Subject: [PATCH 05/20] Improve the PING timeout code --- client.go | 130 ++++++++++++++++++++++++++++++++++++++++++------------ conn.go | 44 ++---------------- 2 files changed, 105 insertions(+), 69 deletions(-) diff --git a/client.go b/client.go index f5ff017..6388537 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,8 @@ package irc import ( + "errors" + "fmt" "io" "sync" "time" @@ -26,6 +28,22 @@ var clientFilters = map[string]func(*Client, *Message){ reply.Command = "PONG" c.WriteMessage(reply) }, + "PONG": func(c *Client, m *Message) { + if c.config.PingFrequency > 0 { + c.sentPingLock.Lock() + defer c.sentPingLock.Unlock() + + // If there haven't been any sent pings, so we can safely ignore + // this pong. + if len(c.sentPings) == 0 { + return + } + + if fmt.Sprintf("%d", c.sentPings[0].Unix()) == m.Trailing() { + c.sentPings = c.sentPings[1:] + } + } + }, "PRIVMSG": func(c *Client, m *Message) { // Clean up CTCP stuff so everyone doesn't have to parse it // manually. @@ -53,8 +71,7 @@ type ClientConfig struct { // Connection settings PingFrequency time.Duration - ReadTimeout time.Duration - WriteTimeout time.Duration + PingTimeout time.Duration // Handler is used for message dispatching. Handler Handler @@ -64,21 +81,89 @@ type ClientConfig struct { // much simpler. type Client struct { *Conn + rwc io.ReadWriteCloser config ClientConfig // Internal state - currentNick string + currentNick string + sentPingLock sync.Mutex + sentPings []time.Time } // NewClient creates a client given an io stream and a client config. -func NewClient(rwc io.ReadWriter, config ClientConfig) *Client { - c := &Client{ +func NewClient(rwc io.ReadWriteCloser, config ClientConfig) *Client { + return &Client{ Conn: NewConn(rwc), + rwc: rwc, config: config, } - c.Reader.SetTimeout(config.ReadTimeout) - c.Writer.SetTimeout(config.WriteTimeout) - return c +} + +func (c *Client) startPingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { + // We're firing off two new goroutines here. + wg.Add(2) + + // PING ticker + go func() { + defer wg.Done() + + t := time.NewTicker(c.config.PingFrequency) + defer t.Stop() + + for { + select { + case <-t.C: + timestamp := time.Now() + + // We need to append before we write so we can guarantee + // this will be in the queue when the PONG gets here. + c.sentPingLock.Lock() + c.sentPings = append(c.sentPings, timestamp) + c.sentPingLock.Unlock() + + err := c.Writef("PING :%d", timestamp.Unix()) + if err != nil { + errChan <- err + c.rwc.Close() + return + } + case <-exiting: + return + } + } + }() + + // PONG checker + go func() { + defer wg.Done() + + var timer *time.Timer + var pingSent bool + + for { + c.sentPingLock.Lock() + pingSent = len(c.sentPings) > 0 + if pingSent { + timer = time.NewTimer(c.config.PingTimeout) + } else { + timer = time.NewTimer(c.config.PingFrequency) + } + c.sentPingLock.Unlock() + + select { + case <-timer.C: + if pingSent { + errChan <- errors.New("PING timeout") + c.rwc.Close() + return + } + case <-exiting: + return + } + + timer.Stop() + } + }() } // Run starts the main loop for this IRC connection. Note that it may break in @@ -88,26 +173,13 @@ func (c *Client) Run() error { // exiting is used by the main goroutine here to ensure any sub-goroutines // get closed when exiting. exiting := make(chan struct{}) + errChan := make(chan error, 3) var wg sync.WaitGroup - // If PingFrequency isn't the zero value, we need to start a ping goroutine. + // If PingFrequency isn't the zero value, we need to start a ping goroutine + // and a pong checker goroutine. if c.config.PingFrequency > 0 { - wg.Add(1) - go func() { - defer wg.Done() - - t := time.NewTicker(c.config.PingFrequency) - defer t.Stop() - - for { - select { - case <-t.C: - c.Writef("PING :%d", time.Now().Unix()) - case <-exiting: - break - } - } - }() + c.startPingLoop(&wg, errChan, exiting) } c.currentNick = c.config.Nick @@ -119,11 +191,10 @@ func (c *Client) Run() error { c.Writef("NICK :%s", c.config.Nick) c.Writef("USER %s 0.0.0.0 0.0.0.0 :%s", c.config.User, c.config.Name) - var err error - var m *Message for { - m, err = c.ReadMessage() + m, err := c.ReadMessage() if err != nil { + errChan <- err break } @@ -136,6 +207,9 @@ func (c *Client) Run() error { } } + // Wait for an error from any goroutine, then signal we're exiting and wait + // for the goroutines to exit. + err := <-errChan close(exiting) wg.Wait() diff --git a/conn.go b/conn.go index 9830851..aa3e920 100644 --- a/conn.go +++ b/conn.go @@ -4,8 +4,6 @@ import ( "bufio" "fmt" "io" - "net" - "time" ) // Conn represents a simple IRC client. It embeds an irc.Reader and an @@ -30,21 +28,12 @@ type Writer struct { DebugCallback func(line string) // Internal fields - writer io.Writer - timeout time.Duration + writer io.Writer } // NewWriter creates an irc.Writer from an io.Writer. func NewWriter(w io.Writer) *Writer { - return &Writer{nil, w, 0} -} - -// SetTimeout allows you to set the write timeout for the next call to Write. -// Note that it is undefined behavior to call this while a call to Write is -// happening. Additionally, this is only effective if a net.Conn was passed into -// NewWriter. -func (w *Writer) SetTimeout(timeout time.Duration) { - w.timeout = timeout + return &Writer{nil, w} } // Write is a simple function which will write the given line to the @@ -54,13 +43,6 @@ func (w *Writer) Write(line string) error { w.DebugCallback(line) } - if c, ok := w.writer.(net.Conn); ok && w.timeout > 0 { - err := c.SetWriteDeadline(time.Now().Add(w.timeout)) - if err != nil { - return err - } - } - _, err := w.writer.Write([]byte(line + "\r\n")) return err } @@ -86,9 +68,7 @@ type Reader struct { DebugCallback func(string) // Internal fields - rawReader io.Reader - reader *bufio.Reader - timeout time.Duration + reader *bufio.Reader } // NewReader creates an irc.Reader from an io.Reader. Note that once a reader is @@ -98,30 +78,12 @@ type Reader struct { func NewReader(r io.Reader) *Reader { return &Reader{ nil, - r, bufio.NewReader(r), - 0, } } -// SetTimeout allows you to set the read timeout for the next call to -// ReadMessage. Note that it is undefined behavior to call this while -// a call to ReadMessage is happening. Additionally, this is only -// effective if a net.Conn is passed into NewReader. -func (r *Reader) SetTimeout(timeout time.Duration) { - r.timeout = timeout -} - // ReadMessage returns the next message from the stream or an error. func (r *Reader) ReadMessage() (*Message, error) { - // Set the read deadline if we have one - if c, ok := r.rawReader.(net.Conn); ok && r.timeout > 0 { - err := c.SetReadDeadline(time.Now().Add(r.timeout)) - if err != nil { - return nil, err - } - } - line, err := r.reader.ReadString('\n') if err != nil { return nil, err From 4c560877044023931674b46ca5c41be1efe2a6ea Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 11:54:15 -0700 Subject: [PATCH 06/20] Change limiter to use better exiting method --- client.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 436f5b2..2d72616 100644 --- a/client.go +++ b/client.go @@ -125,17 +125,20 @@ func (c *Client) writeCallback(w *Writer, line string) error { return err } -func (c *Client) maybeStartLimiter() { +func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { if c.config.SendLimit == 0 { return } + wg.Add(1) + // If SendBurst is 0, this will be unbuffered, so keep that in mind. c.limiter = make(chan struct{}, c.config.SendBurst) - c.limitTick = time.NewTicker(c.config.SendLimit) go func() { + defer wg.Done() + var done bool for !done { select { @@ -144,7 +147,7 @@ func (c *Client) maybeStartLimiter() { case c.limiter <- struct{}{}: default: } - case <-c.tickDone: + case <-exiting: done = true } } @@ -236,9 +239,6 @@ func (c *Client) startPingLoop(wg *sync.WaitGroup, errChan chan error, exiting c // strange and unexpected ways if it is called again before the first connection // exits. func (c *Client) Run() error { - c.maybeStartLimiter() - defer c.stopLimiter() - // exiting is used by the main goroutine here to ensure any sub-goroutines // get closed when exiting. exiting := make(chan struct{}) @@ -250,6 +250,7 @@ func (c *Client) Run() error { if c.config.PingFrequency > 0 { c.startPingLoop(&wg, errChan, exiting) } + c.maybeStartLimiter(&wg, errChan, exiting) c.currentNick = c.config.Nick From 7fd9b37e2ff6568f85dddfffe56a9467aba3c233 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 11:54:42 -0700 Subject: [PATCH 07/20] Clean up ping loop implementation --- client.go | 119 +++++++++++++++++------------------------------------- 1 file changed, 38 insertions(+), 81 deletions(-) diff --git a/client.go b/client.go index 2d72616..adf43a4 100644 --- a/client.go +++ b/client.go @@ -29,19 +29,8 @@ var clientFilters = map[string]func(*Client, *Message){ c.WriteMessage(reply) }, "PONG": func(c *Client, m *Message) { - if c.config.PingFrequency > 0 { - c.sentPingLock.Lock() - defer c.sentPingLock.Unlock() - - // If there haven't been any sent pings, so we can safely ignore - // this pong. - if len(c.sentPings) == 0 { - return - } - - if fmt.Sprintf("%d", c.sentPings[0].Unix()) == m.Trailing() { - c.sentPings = c.sentPings[1:] - } + if c.incomingPongChan != nil { + c.incomingPongChan <- m.Trailing() } }, "PRIVMSG": func(c *Client, m *Message) { @@ -88,24 +77,20 @@ type ClientConfig struct { // much simpler. type Client struct { *Conn - rwc io.ReadWriteCloser config ClientConfig // Internal state - currentNick string - sentPingLock sync.Mutex - sentPings []time.Time - - limitTick *time.Ticker - limiter chan struct{} - tickDone chan struct{} + currentNick string + limitTick *time.Ticker + limiter chan struct{} + tickDone chan struct{} + incomingPongChan chan string } // NewClient creates a client given an io stream and a client config. -func NewClient(rwc io.ReadWriteCloser, config ClientConfig) *Client { +func NewClient(rw io.ReadWriter, config ClientConfig) *Client { c := &Client{ - Conn: NewConn(rwc), - rwc: rwc, + Conn: NewConn(rw), config: config, tickDone: make(chan struct{}), } @@ -155,82 +140,58 @@ func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, errChan chan error, exiti c.limitTick.Stop() close(c.limiter) c.limiter = nil - c.tickDone <- struct{}{} }() } -func (c *Client) stopLimiter() { - if c.limiter == nil { +func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { + if c.config.PingFrequency <= 0 { return } - c.tickDone <- struct{}{} - <-c.tickDone -} - -func (c *Client) startPingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { - // We're firing off two new goroutines here. - wg.Add(2) + wg.Add(1) - // PING ticker + // PONG checker go func() { defer wg.Done() - t := time.NewTicker(c.config.PingFrequency) - defer t.Stop() + var ( + sentPings []time.Time + pingTimeoutChan <-chan time.Time + ticker = time.NewTicker(c.config.PingFrequency) + ) + + defer ticker.Stop() for { + // Reset the pingTimeoutChan if we have any pings we're waiting for + // and it isn't currently set. + if len(sentPings) > 0 && pingTimeoutChan == nil { + pingTimeoutChan = time.After(time.Now().Sub(sentPings[0]) + c.config.PingTimeout) + } + select { - case <-t.C: + case <-ticker.C: timestamp := time.Now() - - // We need to append before we write so we can guarantee - // this will be in the queue when the PONG gets here. - c.sentPingLock.Lock() - c.sentPings = append(c.sentPings, timestamp) - c.sentPingLock.Unlock() - err := c.Writef("PING :%d", timestamp.Unix()) if err != nil { errChan <- err - c.rwc.Close() return } - case <-exiting: + sentPings = append(sentPings, timestamp) + case <-pingTimeoutChan: + errChan <- errors.New("PING timeout") return - } - } - }() - - // PONG checker - go func() { - defer wg.Done() - - var timer *time.Timer - var pingSent bool - - for { - c.sentPingLock.Lock() - pingSent = len(c.sentPings) > 0 - if pingSent { - timer = time.NewTimer(c.config.PingTimeout) - } else { - timer = time.NewTimer(c.config.PingFrequency) - } - c.sentPingLock.Unlock() - - select { - case <-timer.C: - if pingSent { - errChan <- errors.New("PING timeout") - c.rwc.Close() - return + case data := <-c.incomingPongChan: + if len(sentPings) == 0 || data != fmt.Sprintf("%d", sentPings[0].Unix()) { + continue } + + // Drop the first ping and clear the timeout chan + sentPings = sentPings[1:] + pingTimeoutChan = nil case <-exiting: return } - - timer.Stop() } }() } @@ -245,12 +206,8 @@ func (c *Client) Run() error { errChan := make(chan error, 3) var wg sync.WaitGroup - // If PingFrequency isn't the zero value, we need to start a ping goroutine - // and a pong checker goroutine. - if c.config.PingFrequency > 0 { - c.startPingLoop(&wg, errChan, exiting) - } c.maybeStartLimiter(&wg, errChan, exiting) + c.maybeStartPingLoop(&wg, errChan, exiting) c.currentNick = c.config.Nick From 2d6d26fc0087cccebbe0e69dd26b59f6231c10e0 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 11:55:44 -0700 Subject: [PATCH 08/20] Small limiter cleanups --- client.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index adf43a4..5bddc30 100644 --- a/client.go +++ b/client.go @@ -81,18 +81,15 @@ type Client struct { // Internal state currentNick string - limitTick *time.Ticker limiter chan struct{} - tickDone chan struct{} incomingPongChan chan string } // NewClient creates a client given an io stream and a client config. func NewClient(rw io.ReadWriter, config ClientConfig) *Client { c := &Client{ - Conn: NewConn(rw), - config: config, - tickDone: make(chan struct{}), + Conn: NewConn(rw), + config: config, } // Replace the writer writeCallback with one of our own @@ -119,7 +116,7 @@ func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, errChan chan error, exiti // If SendBurst is 0, this will be unbuffered, so keep that in mind. c.limiter = make(chan struct{}, c.config.SendBurst) - c.limitTick = time.NewTicker(c.config.SendLimit) + limitTick := time.NewTicker(c.config.SendLimit) go func() { defer wg.Done() @@ -127,7 +124,7 @@ func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, errChan chan error, exiti var done bool for !done { select { - case <-c.limitTick.C: + case <-limitTick.C: select { case c.limiter <- struct{}{}: default: @@ -137,7 +134,7 @@ func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, errChan chan error, exiti } } - c.limitTick.Stop() + limitTick.Stop() close(c.limiter) c.limiter = nil }() From 7568ec8fac1846f26a8b8bf6b1705dd00af5dc51 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 12:03:07 -0700 Subject: [PATCH 09/20] Small pong channel fix --- client.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client.go b/client.go index 5bddc30..3df0ca2 100644 --- a/client.go +++ b/client.go @@ -147,6 +147,8 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, errChan chan error, exit wg.Add(1) + c.incomingPongChan = make(chan string, 5) + // PONG checker go func() { defer wg.Done() From 62f7109623f44e523d0370bdb03423c8e161d00c Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 12:07:46 -0700 Subject: [PATCH 10/20] Split pingLoop out into separate func --- client.go | 75 +++++++++++++++++++++++++++---------------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/client.go b/client.go index 3df0ca2..1c7128c 100644 --- a/client.go +++ b/client.go @@ -146,53 +146,52 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, errChan chan error, exit } wg.Add(1) - c.incomingPongChan = make(chan string, 5) + go c.pingLoop(wg, errChan, exiting) +} - // PONG checker - go func() { - defer wg.Done() +func (c *Client) pingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { + defer wg.Done() - var ( - sentPings []time.Time - pingTimeoutChan <-chan time.Time - ticker = time.NewTicker(c.config.PingFrequency) - ) + var ( + sentPings []time.Time + pingTimeoutChan <-chan time.Time + ticker = time.NewTicker(c.config.PingFrequency) + ) - defer ticker.Stop() + defer ticker.Stop() - for { - // Reset the pingTimeoutChan if we have any pings we're waiting for - // and it isn't currently set. - if len(sentPings) > 0 && pingTimeoutChan == nil { - pingTimeoutChan = time.After(time.Now().Sub(sentPings[0]) + c.config.PingTimeout) - } - - select { - case <-ticker.C: - timestamp := time.Now() - err := c.Writef("PING :%d", timestamp.Unix()) - if err != nil { - errChan <- err - return - } - sentPings = append(sentPings, timestamp) - case <-pingTimeoutChan: - errChan <- errors.New("PING timeout") - return - case data := <-c.incomingPongChan: - if len(sentPings) == 0 || data != fmt.Sprintf("%d", sentPings[0].Unix()) { - continue - } + for { + // Reset the pingTimeoutChan if we have any pings we're waiting for + // and it isn't currently set. + if len(sentPings) > 0 && pingTimeoutChan == nil { + pingTimeoutChan = time.After(time.Now().Sub(sentPings[0]) + c.config.PingTimeout) + } - // Drop the first ping and clear the timeout chan - sentPings = sentPings[1:] - pingTimeoutChan = nil - case <-exiting: + select { + case <-ticker.C: + timestamp := time.Now() + err := c.Writef("PING :%d", timestamp.Unix()) + if err != nil { + errChan <- err return } + sentPings = append(sentPings, timestamp) + case <-pingTimeoutChan: + errChan <- errors.New("PING timeout") + return + case data := <-c.incomingPongChan: + if len(sentPings) == 0 || data != fmt.Sprintf("%d", sentPings[0].Unix()) { + continue + } + + // Drop the first ping and clear the timeout chan + sentPings = sentPings[1:] + pingTimeoutChan = nil + case <-exiting: + return } - }() + } } // Run starts the main loop for this IRC connection. Note that it may break in From 8aab619dd6911c18fc3ffe9b14ccd6efb78c83f9 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 12:43:13 -0700 Subject: [PATCH 11/20] Bump go-cyclo up to 12 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 494bd20..4d0114d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,7 @@ before_install: - rm ./testcases/*.go script: - - gometalinter --fast ./... -D gas + - gometalinter --fast ./... -D gas --cyclo-over=12 - go test -race -v ./... - go test -covermode=count -coverprofile=profile.cov From 4a006f5d384a0c00b51dcd262891bbbd1896c11d Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 12:50:13 -0700 Subject: [PATCH 12/20] Fix potential issue if more than 5 pings are waiting at once This isn't a perfect fix, but it will at least allow the bot to fail gracefully rather than hang. --- client.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 1c7128c..9323ff7 100644 --- a/client.go +++ b/client.go @@ -30,7 +30,10 @@ var clientFilters = map[string]func(*Client, *Message){ }, "PONG": func(c *Client, m *Message) { if c.incomingPongChan != nil { - c.incomingPongChan <- m.Trailing() + select { + case c.incomingPongChan <- m.Trailing(): + default: + } } }, "PRIVMSG": func(c *Client, m *Message) { From 77b53c1d9b2433dd8a6b7ac92e0f1a0d072d0c41 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Fri, 18 Aug 2017 17:28:17 -0700 Subject: [PATCH 13/20] Get pingLoop below the cyclo requirements --- .travis.yml | 2 +- client.go | 39 +++++++++++++++++++++++---------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/.travis.yml b/.travis.yml index 4d0114d..494bd20 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,7 @@ before_install: - rm ./testcases/*.go script: - - gometalinter --fast ./... -D gas --cyclo-over=12 + - gometalinter --fast ./... -D gas - go test -race -v ./... - go test -covermode=count -coverprofile=profile.cov diff --git a/client.go b/client.go index 9323ff7..bcb296a 100644 --- a/client.go +++ b/client.go @@ -153,22 +153,27 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, errChan chan error, exit go c.pingLoop(wg, errChan, exiting) } +type pingDeadline struct { + Data string + Deadline <-chan time.Time +} + func (c *Client) pingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { defer wg.Done() var ( - sentPings []time.Time - pingTimeoutChan <-chan time.Time + sentPings = map[string]time.Time{} + pingDeadlines []pingDeadline + currentDeadline pingDeadline ticker = time.NewTicker(c.config.PingFrequency) ) defer ticker.Stop() for { - // Reset the pingTimeoutChan if we have any pings we're waiting for - // and it isn't currently set. - if len(sentPings) > 0 && pingTimeoutChan == nil { - pingTimeoutChan = time.After(time.Now().Sub(sentPings[0]) + c.config.PingTimeout) + if len(pingDeadlines) > 0 { + currentDeadline = pingDeadlines[0] + pingDeadlines = pingDeadlines[1:] } select { @@ -179,18 +184,20 @@ func (c *Client) pingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan s errChan <- err return } - sentPings = append(sentPings, timestamp) - case <-pingTimeoutChan: - errChan <- errors.New("PING timeout") + deadline := pingDeadline{ + Data: fmt.Sprintf("%d", timestamp.Unix()), + Deadline: time.After(c.config.PingTimeout), + } + sentPings[deadline.Data] = timestamp + pingDeadlines = append(pingDeadlines, deadline) + case <-currentDeadline.Deadline: + if _, ok := sentPings[currentDeadline.Data]; ok { + errChan <- errors.New("PING timeout") + } + currentDeadline.Deadline = nil return case data := <-c.incomingPongChan: - if len(sentPings) == 0 || data != fmt.Sprintf("%d", sentPings[0].Unix()) { - continue - } - - // Drop the first ping and clear the timeout chan - sentPings = sentPings[1:] - pingTimeoutChan = nil + delete(sentPings, data) case <-exiting: return } From 1f0cac927280f67bc3613fb3608040eccf9ae404 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Mon, 21 Aug 2017 11:42:28 -0700 Subject: [PATCH 14/20] Rewrite pingLoop to fix a number of bugs --- client.go | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index bcb296a..4353f4b 100644 --- a/client.go +++ b/client.go @@ -154,50 +154,56 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, errChan chan error, exit } type pingDeadline struct { - Data string Deadline <-chan time.Time + Data string } func (c *Client) pingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { defer wg.Done() var ( - sentPings = map[string]time.Time{} - pingDeadlines []pingDeadline - currentDeadline pingDeadline + needsPong = map[string]bool{} + pingDeadlines []*pingDeadline + currentDeadline *pingDeadline ticker = time.NewTicker(c.config.PingFrequency) ) defer ticker.Stop() for { - if len(pingDeadlines) > 0 { + // Find the next ping we haven't received yet. + for len(pingDeadlines) > 0 && !needsPong[currentDeadline.Data] { currentDeadline = pingDeadlines[0] pingDeadlines = pingDeadlines[1:] } select { case <-ticker.C: + // Every time the ticker fires off we need to send a ping + // and store that we sent it. timestamp := time.Now() err := c.Writef("PING :%d", timestamp.Unix()) if err != nil { errChan <- err return } - deadline := pingDeadline{ - Data: fmt.Sprintf("%d", timestamp.Unix()), + deadline := &pingDeadline{ Deadline: time.After(c.config.PingTimeout), + Data: fmt.Sprintf("%d", timestamp.Unix()), } - sentPings[deadline.Data] = timestamp + needsPong[deadline.Data] = true pingDeadlines = append(pingDeadlines, deadline) case <-currentDeadline.Deadline: - if _, ok := sentPings[currentDeadline.Data]; ok { - errChan <- errors.New("PING timeout") + // When the deadline comes back, if we still haven't + // gotten a pong, we kill the connection. + if !needsPong[currentDeadline.Data] { + continue } - currentDeadline.Deadline = nil + + errChan <- errors.New("PING timeout") return case data := <-c.incomingPongChan: - delete(sentPings, data) + delete(needsPong, data) case <-exiting: return } From 96372b42948b8f4338d9a78055c56f3c585dcd85 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Mon, 21 Aug 2017 12:24:45 -0700 Subject: [PATCH 15/20] Improve pingloop and error handling --- client.go | 123 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 64 insertions(+), 59 deletions(-) diff --git a/client.go b/client.go index 4353f4b..909b336 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,6 @@ package irc import ( - "errors" "fmt" "io" "sync" @@ -86,13 +85,15 @@ type Client struct { currentNick string limiter chan struct{} incomingPongChan chan string + errChan chan error } // NewClient creates a client given an io stream and a client config. func NewClient(rw io.ReadWriter, config ClientConfig) *Client { c := &Client{ - Conn: NewConn(rw), - config: config, + Conn: NewConn(rw), + config: config, + errChan: make(chan error, 1), } // Replace the writer writeCallback with one of our own @@ -110,7 +111,7 @@ func (c *Client) writeCallback(w *Writer, line string) error { return err } -func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { +func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, exiting chan struct{}) { if c.config.SendLimit == 0 { return } @@ -143,70 +144,75 @@ func (c *Client) maybeStartLimiter(wg *sync.WaitGroup, errChan chan error, exiti }() } -func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { +func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) { if c.config.PingFrequency <= 0 { return } wg.Add(1) + c.incomingPongChan = make(chan string, 5) - go c.pingLoop(wg, errChan, exiting) -} -type pingDeadline struct { - Deadline <-chan time.Time - Data string + go func() { + defer wg.Done() + + pingHandlers := make(map[string]chan struct{}) + ticker := time.NewTicker(c.config.PingFrequency) + + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // Each time we get a tick, we send off a ping and start a + // goroutine to handle the pong. + timestamp := time.Now().Unix() + pongChan := make(chan struct{}) + pingHandlers[fmt.Sprintf("%d", timestamp)] = pongChan + wg.Add(1) + go c.handlePing(timestamp, pongChan, wg, exiting) + case data := <-c.incomingPongChan: + // Make sure the pong gets routed to the correct + // goroutine. + c := pingHandlers[data] + delete(pingHandlers, data) + + if c != nil { + c <- struct{}{} + } + case <-exiting: + return + } + } + }() } -func (c *Client) pingLoop(wg *sync.WaitGroup, errChan chan error, exiting chan struct{}) { +func (c *Client) handlePing(timestamp int64, pongChan chan struct{}, wg *sync.WaitGroup, exiting chan struct{}) { defer wg.Done() - var ( - needsPong = map[string]bool{} - pingDeadlines []*pingDeadline - currentDeadline *pingDeadline - ticker = time.NewTicker(c.config.PingFrequency) - ) + err := c.Writef("PING :%d", timestamp) + if err != nil { + c.sendError(err) + return + } - defer ticker.Stop() + timer := time.NewTimer(c.config.PingTimeout) + defer timer.Stop() - for { - // Find the next ping we haven't received yet. - for len(pingDeadlines) > 0 && !needsPong[currentDeadline.Data] { - currentDeadline = pingDeadlines[0] - pingDeadlines = pingDeadlines[1:] - } - - select { - case <-ticker.C: - // Every time the ticker fires off we need to send a ping - // and store that we sent it. - timestamp := time.Now() - err := c.Writef("PING :%d", timestamp.Unix()) - if err != nil { - errChan <- err - return - } - deadline := &pingDeadline{ - Deadline: time.After(c.config.PingTimeout), - Data: fmt.Sprintf("%d", timestamp.Unix()), - } - needsPong[deadline.Data] = true - pingDeadlines = append(pingDeadlines, deadline) - case <-currentDeadline.Deadline: - // When the deadline comes back, if we still haven't - // gotten a pong, we kill the connection. - if !needsPong[currentDeadline.Data] { - continue - } + select { + case <-timer.C: + c.sendError(err) + case <-pongChan: + return + case <-exiting: + return + } +} - errChan <- errors.New("PING timeout") - return - case data := <-c.incomingPongChan: - delete(needsPong, data) - case <-exiting: - return - } +func (c *Client) sendError(err error) { + select { + case c.errChan <- err: + default: } } @@ -217,11 +223,10 @@ func (c *Client) Run() error { // exiting is used by the main goroutine here to ensure any sub-goroutines // get closed when exiting. exiting := make(chan struct{}) - errChan := make(chan error, 3) var wg sync.WaitGroup - c.maybeStartLimiter(&wg, errChan, exiting) - c.maybeStartPingLoop(&wg, errChan, exiting) + c.maybeStartLimiter(&wg, exiting) + c.maybeStartPingLoop(&wg, exiting) c.currentNick = c.config.Nick @@ -235,7 +240,7 @@ func (c *Client) Run() error { for { m, err := c.ReadMessage() if err != nil { - errChan <- err + c.sendError(err) break } @@ -250,7 +255,7 @@ func (c *Client) Run() error { // Wait for an error from any goroutine, then signal we're exiting and wait // for the goroutines to exit. - err := <-errChan + err := <-c.errChan close(exiting) wg.Wait() From 413e6f5fa8a4cd9781b18ef7125063602dcdace9 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Tue, 22 Aug 2017 14:03:30 -0700 Subject: [PATCH 16/20] Rewrite client testing framework and add pingLoop tests --- client.go | 3 +- client_test.go | 215 ++++++++++++++++++++++++++++++------------------- conn_test.go | 8 -- 3 files changed, 132 insertions(+), 94 deletions(-) diff --git a/client.go b/client.go index 909b336..2bb2d07 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package irc import ( + "errors" "fmt" "io" "sync" @@ -201,7 +202,7 @@ func (c *Client) handlePing(timestamp int64, pongChan chan struct{}, wg *sync.Wa select { case <-timer.C: - c.sendError(err) + c.sendError(errors.New("Ping Timeout")) case <-pongChan: return case <-exiting: diff --git a/client_test.go b/client_test.go index bf7be97..c53bb79 100644 --- a/client_test.go +++ b/client_test.go @@ -1,6 +1,8 @@ package irc import ( + "errors" + "fmt" "io" "testing" "time" @@ -29,72 +31,58 @@ func (th *TestHandler) Messages() []*Message { func TestClient(t *testing.T) { t.Parallel() - rwc := newTestReadWriteCloser() config := ClientConfig{ Nick: "test_nick", Pass: "test_pass", User: "test_user", Name: "test_name", } - c := NewClient(rwc, config) - err := c.Run() - assert.Equal(t, io.EOF, err) - - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), }) - rwc.server.WriteString("PING :hello world\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", - "PONG :hello world", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("PING :hello world\r\n"), + ExpectLine("PONG :hello world\r\n"), }) - rwc.server.WriteString(":test_nick NICK :new_test_nick\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + c := runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine(":test_nick NICK :new_test_nick\r\n"), }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - rwc.server.WriteString("001 :new_test_nick\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + c = runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :new_test_nick\r\n"), }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - rwc.server.WriteString("433\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", - "NICK :test_nick_", + c = runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("433\r\n"), + ExpectLine("NICK :test_nick_\r\n"), }) assert.Equal(t, "test_nick_", c.CurrentNick()) - rwc.server.WriteString("437\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", - "NICK :test_nick_", + c = runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("437\r\n"), + ExpectLine("NICK :test_nick_\r\n"), }) assert.Equal(t, "test_nick_", c.CurrentNick()) } @@ -103,7 +91,7 @@ func TestSendLimit(t *testing.T) { t.Parallel() handler := &TestHandler{} - rwc := newTestReadWriteCloser() + config := ClientConfig{ Nick: "test_nick", Pass: "test_pass", @@ -116,42 +104,36 @@ func TestSendLimit(t *testing.T) { SendBurst: 2, } - rwc.server.WriteString("001 :hello_world\r\n") - c := NewClient(rwc, config) - before := time.Now() - err := c.Run() - assert.Equal(t, io.EOF, err) - assert.WithinDuration(t, before, time.Now(), 50*time.Millisecond) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), }) + assert.WithinDuration(t, before, time.Now(), 50*time.Millisecond) // This last test isn't really a test. It's being used to make sure we // hit the branch which handles dropping ticks if the buffered channel is // full. - rwc.server.WriteString("001 :hello world\r\n") handler.delay = 20 * time.Millisecond // Sleep for 20ms when we get the 001 message - c.config.SendLimit = 10 * time.Millisecond - c.config.SendBurst = 0 + config.SendLimit = 10 * time.Millisecond + config.SendBurst = 0 + before = time.Now() - err = c.Run() - assert.Equal(t, io.EOF, err) - assert.WithinDuration(t, before, time.Now(), 60*time.Millisecond) - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), }) + assert.WithinDuration(t, before, time.Now(), 60*time.Millisecond) } func TestClientHandler(t *testing.T) { t.Parallel() handler := &TestHandler{} - rwc := newTestReadWriteCloser() config := ClientConfig{ Nick: "test_nick", Pass: "test_pass", @@ -161,17 +143,12 @@ func TestClientHandler(t *testing.T) { Handler: handler, } - rwc.server.WriteString("001 :hello_world\r\n") - c := NewClient(rwc, config) - err := c.Run() - assert.Equal(t, io.EOF, err) - - testLines(t, rwc, []string{ - "PASS :test_pass", - "NICK :test_nick", - "USER test_user 0.0.0.0 0.0.0.0 :test_name", + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), }) - assert.EqualValues(t, []*Message{ { Tags: Tags{}, @@ -182,9 +159,12 @@ func TestClientHandler(t *testing.T) { }, handler.Messages()) // Ensure CTCP messages are parsed - rwc.server.WriteString(":world PRIVMSG :\x01VERSION\x01\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine(":world PRIVMSG :\x01VERSION\x01\r\n"), + }) assert.EqualValues(t, []*Message{ { Tags: Tags{}, @@ -196,9 +176,12 @@ func TestClientHandler(t *testing.T) { // CTCP Regression test for PR#47 // Proper CTCP should start AND end in \x01 - rwc.server.WriteString(":world PRIVMSG :\x01VERSION\r\n") - err = c.Run() - assert.Equal(t, io.EOF, err) + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine(":world PRIVMSG :\x01VERSION\r\n"), + }) assert.EqualValues(t, []*Message{ { Tags: Tags{}, @@ -207,7 +190,12 @@ func TestClientHandler(t *testing.T) { Params: []string{"\x01VERSION"}, }, }, handler.Messages()) +} +func TestFromChannel(t *testing.T) { + t.Parallel() + + c := Client{currentNick: "test_nick"} m := MustParseMessage("PRIVMSG test_nick :hello world") assert.False(t, c.FromChannel(m)) @@ -217,3 +205,60 @@ func TestClientHandler(t *testing.T) { m = MustParseMessage("PING") assert.False(t, c.FromChannel(m)) } + +func TestPingLoop(t *testing.T) { + t.Parallel() + + config := ClientConfig{ + Nick: "test_nick", + Pass: "test_pass", + User: "test_user", + Name: "test_name", + + PingFrequency: 20 * time.Millisecond, + PingTimeout: 5 * time.Millisecond, + } + + var lastPing *Message + + // Successful ping + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), + Delay(20 * time.Millisecond), + LineFunc(func(m *Message) { + lastPing = m + }), + SendFunc(func() string { + return fmt.Sprintf("PONG :%s\r\n", lastPing.Trailing()) + }), + Delay(10 * time.Millisecond), + }) + + // Ping timeout + runTest(t, config, errors.New("Ping Timeout"), []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), + Delay(20 * time.Millisecond), + LineFunc(func(m *Message) { + lastPing = m + }), + Delay(20 * time.Millisecond), + }) + + // Exit in the middle of handling a ping + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), + Delay(20 * time.Millisecond), + LineFunc(func(m *Message) { + lastPing = m + }), + }) +} diff --git a/conn_test.go b/conn_test.go index f215ecb..7de27eb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -19,13 +19,6 @@ func (ew *errorWriter) Write([]byte) (int, error) { type readWriteCloser struct { io.Reader io.Writer - io.Closer -} - -type nilCloser struct{} - -func (nc *nilCloser) Close() error { - return nil } type testReadWriteCloser struct { @@ -84,7 +77,6 @@ func TestWriteMessageError(t *testing.T) { rw := readWriteCloser{ &bytes.Buffer{}, &errorWriter{}, - &nilCloser{}, } c := NewConn(rw) From aabda25d5b74740652f12f2c1cdd8945effac5a7 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Tue, 22 Aug 2017 14:04:42 -0700 Subject: [PATCH 17/20] Add missing stream_test.go --- stream_test.go | 178 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 stream_test.go diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..1f20a26 --- /dev/null +++ b/stream_test.go @@ -0,0 +1,178 @@ +package irc + +import ( + "bytes" + "errors" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestAction is used to execute an action during a stream test. If a +// non-nil error is returned the test will be failed. +type TestAction func(t *testing.T, c *Client, rw *testReadWriter) + +func SendLine(output string) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + // First we send the message + select { + case rw.readChan <- output: + case <-rw.exiting: + assert.Fail(t, "Failed to send") + } + + // Now we wait for the buffer to be emptied + select { + case <-rw.readEmptyChan: + case <-rw.exiting: + assert.Fail(t, "Failed to send whole message") + } + } +} + +func SendFunc(cb func() string) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + SendLine(cb())(t, c, rw) + } +} + +func LineFunc(cb func(m *Message)) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + select { + case line := <-rw.writeChan: + cb(MustParseMessage(line)) + case <-time.After(1 * time.Second): + assert.Fail(t, "LineFunc timeout") + case <-rw.exiting: + } + } +} + +func ExpectLine(input string) TestAction { + return ExpectLineWithTimeout(input, 1*time.Second) +} + +func ExpectLineWithTimeout(input string, timeout time.Duration) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + select { + case line := <-rw.writeChan: + assert.Equal(t, input, line) + case <-time.After(timeout): + assert.Fail(t, "ExpectLine timeout on %s", input) + case <-rw.exiting: + } + } +} + +func Delay(delay time.Duration) TestAction { + return func(t *testing.T, c *Client, rw *testReadWriter) { + select { + case <-time.After(delay): + case <-rw.exiting: + } + } +} + +type testReadWriter struct { + actions []TestAction + queuedWriteError error + writeChan chan string + queuedReadError error + readChan chan string + readEmptyChan chan struct{} + exiting chan struct{} + clientDone chan struct{} + serverBuffer bytes.Buffer +} + +func (rw *testReadWriter) Read(buf []byte) (int, error) { + if rw.queuedReadError != nil { + err := rw.queuedReadError + rw.queuedReadError = nil + return 0, err + } + + // If there's data left in the buffer, we want to use that first. + if rw.serverBuffer.Len() > 0 { + s, err := rw.serverBuffer.Read(buf) + if err == io.EOF { + err = nil + } + return s, err + } + + select { + case rw.readEmptyChan <- struct{}{}: + default: + } + + // Read from server. We're either waiting for this whole test to + // finish or for data to come in from the server buffer. We expect + // only one read to be happening at once. + select { + case data := <-rw.readChan: + rw.serverBuffer.WriteString(data) + s, err := rw.serverBuffer.Read(buf) + if err == io.EOF { + err = nil + } + return s, err + case <-rw.exiting: + return 0, io.EOF + } +} + +func (rw *testReadWriter) Write(buf []byte) (int, error) { + if rw.queuedWriteError != nil { + err := rw.queuedWriteError + rw.queuedWriteError = nil + return 0, err + } + + // Write to server. We can cheat with this because we know things + // will be written a line at a time. + select { + case rw.writeChan <- string(buf): + return len(buf), nil + case <-rw.exiting: + return 0, errors.New("Connection closed") + } +} + +func runTest(t *testing.T, cc ClientConfig, expectedErr error, actions []TestAction) *Client { + rw := &testReadWriter{ + actions: actions, + writeChan: make(chan string), + readChan: make(chan string), + readEmptyChan: make(chan struct{}, 1), + exiting: make(chan struct{}), + clientDone: make(chan struct{}), + } + + c := NewClient(rw, cc) + + go func() { + err := c.Run() + assert.Equal(t, expectedErr, err) + close(rw.clientDone) + }() + + // Perform each of the actions + for _, action := range rw.actions { + action(t, c, rw) + } + + // TODO: Make sure there are no more incoming messages + + // Ask everything to shut down and wait for the client to stop. + close(rw.exiting) + select { + case <-rw.clientDone: + case <-time.After(1 * time.Second): + assert.Fail(t, "Timeout in client shutdown") + } + + return c +} From ef996029bc89fee038e07f2da8b2c11003ef1e2f Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Tue, 22 Aug 2017 14:08:59 -0700 Subject: [PATCH 18/20] Aim for that last bit of code coverage --- client_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/client_test.go b/client_test.go index c53bb79..dee1b85 100644 --- a/client_test.go +++ b/client_test.go @@ -261,4 +261,21 @@ func TestPingLoop(t *testing.T) { lastPing = m }), }) + + // This one is just for coverage, so we know we're hitting the + // branch that drops extra pings. + runTest(t, config, io.EOF, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("001 :hello_world\r\n"), + + // It's a buffered channel of 5, so we want to send 6 of them + SendLine("PONG :hello 1\r\n"), + SendLine("PONG :hello 2\r\n"), + SendLine("PONG :hello 3\r\n"), + SendLine("PONG :hello 4\r\n"), + SendLine("PONG :hello 5\r\n"), + SendLine("PONG :hello 6\r\n"), + }) } From d553264fc16ed2042c361eec0ba322d7ae2d9351 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Tue, 22 Aug 2017 14:32:49 -0700 Subject: [PATCH 19/20] Fix small issue with the pongChan --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 2bb2d07..691be2a 100644 --- a/client.go +++ b/client.go @@ -168,7 +168,7 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) { // Each time we get a tick, we send off a ping and start a // goroutine to handle the pong. timestamp := time.Now().Unix() - pongChan := make(chan struct{}) + pongChan := make(chan struct{}, 1) pingHandlers[fmt.Sprintf("%d", timestamp)] = pongChan wg.Add(1) go c.handlePing(timestamp, pongChan, wg, exiting) From fcc8c3748fd50d3efab04f20e4d2e270293be777 Mon Sep 17 00:00:00 2001 From: Kaleb Elwert Date: Tue, 22 Aug 2017 17:26:43 -0700 Subject: [PATCH 20/20] Small test cleanup --- client_test.go | 30 +++++++++++++++--------------- conn_test.go | 5 ----- stream_test.go | 18 +++++++++--------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/client_test.go b/client_test.go index dee1b85..05d4ca8 100644 --- a/client_test.go +++ b/client_test.go @@ -38,13 +38,13 @@ func TestClient(t *testing.T) { Name: "test_name", } - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), }) - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -52,7 +52,7 @@ func TestClient(t *testing.T) { ExpectLine("PONG :hello world\r\n"), }) - c := runTest(t, config, io.EOF, []TestAction{ + c := runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -60,7 +60,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - c = runTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -68,7 +68,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - c = runTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -77,7 +77,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "test_nick_", c.CurrentNick()) - c = runTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -105,7 +105,7 @@ func TestSendLimit(t *testing.T) { } before := time.Now() - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -121,7 +121,7 @@ func TestSendLimit(t *testing.T) { config.SendBurst = 0 before = time.Now() - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -143,7 +143,7 @@ func TestClientHandler(t *testing.T) { Handler: handler, } - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -159,7 +159,7 @@ func TestClientHandler(t *testing.T) { }, handler.Messages()) // Ensure CTCP messages are parsed - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -176,7 +176,7 @@ func TestClientHandler(t *testing.T) { // CTCP Regression test for PR#47 // Proper CTCP should start AND end in \x01 - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -222,7 +222,7 @@ func TestPingLoop(t *testing.T) { var lastPing *Message // Successful ping - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -238,7 +238,7 @@ func TestPingLoop(t *testing.T) { }) // Ping timeout - runTest(t, config, errors.New("Ping Timeout"), []TestAction{ + runClientTest(t, config, errors.New("Ping Timeout"), []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -251,7 +251,7 @@ func TestPingLoop(t *testing.T) { }) // Exit in the middle of handling a ping - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -264,7 +264,7 @@ func TestPingLoop(t *testing.T) { // This one is just for coverage, so we know we're hitting the // branch that drops extra pings. - runTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), diff --git a/conn_test.go b/conn_test.go index 7de27eb..5b9b407 100644 --- a/conn_test.go +++ b/conn_test.go @@ -41,11 +41,6 @@ func (t *testReadWriteCloser) Write(p []byte) (int, error) { return t.client.Write(p) } -// Ensure we can close the thing -func (t *testReadWriteCloser) Close() error { - return nil -} - func testReadMessage(t *testing.T, c *Conn) *Message { m, err := c.ReadMessage() assert.NoError(t, err) diff --git a/stream_test.go b/stream_test.go index 1f20a26..a04bd5a 100644 --- a/stream_test.go +++ b/stream_test.go @@ -12,10 +12,10 @@ import ( // TestAction is used to execute an action during a stream test. If a // non-nil error is returned the test will be failed. -type TestAction func(t *testing.T, c *Client, rw *testReadWriter) +type TestAction func(t *testing.T, rw *testReadWriter) func SendLine(output string) TestAction { - return func(t *testing.T, c *Client, rw *testReadWriter) { + return func(t *testing.T, rw *testReadWriter) { // First we send the message select { case rw.readChan <- output: @@ -33,13 +33,13 @@ func SendLine(output string) TestAction { } func SendFunc(cb func() string) TestAction { - return func(t *testing.T, c *Client, rw *testReadWriter) { - SendLine(cb())(t, c, rw) + return func(t *testing.T, rw *testReadWriter) { + SendLine(cb())(t, rw) } } func LineFunc(cb func(m *Message)) TestAction { - return func(t *testing.T, c *Client, rw *testReadWriter) { + return func(t *testing.T, rw *testReadWriter) { select { case line := <-rw.writeChan: cb(MustParseMessage(line)) @@ -55,7 +55,7 @@ func ExpectLine(input string) TestAction { } func ExpectLineWithTimeout(input string, timeout time.Duration) TestAction { - return func(t *testing.T, c *Client, rw *testReadWriter) { + return func(t *testing.T, rw *testReadWriter) { select { case line := <-rw.writeChan: assert.Equal(t, input, line) @@ -67,7 +67,7 @@ func ExpectLineWithTimeout(input string, timeout time.Duration) TestAction { } func Delay(delay time.Duration) TestAction { - return func(t *testing.T, c *Client, rw *testReadWriter) { + return func(t *testing.T, rw *testReadWriter) { select { case <-time.After(delay): case <-rw.exiting: @@ -141,7 +141,7 @@ func (rw *testReadWriter) Write(buf []byte) (int, error) { } } -func runTest(t *testing.T, cc ClientConfig, expectedErr error, actions []TestAction) *Client { +func runClientTest(t *testing.T, cc ClientConfig, expectedErr error, actions []TestAction) *Client { rw := &testReadWriter{ actions: actions, writeChan: make(chan string), @@ -161,7 +161,7 @@ func runTest(t *testing.T, cc ClientConfig, expectedErr error, actions []TestAct // Perform each of the actions for _, action := range rw.actions { - action(t, c, rw) + action(t, rw) } // TODO: Make sure there are no more incoming messages