diff --git a/client.go b/client.go index 6694bce..423992a 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "strings" "sync" "time" ) @@ -52,6 +53,49 @@ var clientFilters = map[string]func(*Client, *Message){ c.currentNick = m.Params[0] } }, + "CAP": func(c *Client, m *Message) { + if c.remainingCapResponses <= 0 || len(m.Params) <= 2 { + return + } + + switch m.Params[1] { + case "LS": + for _, key := range strings.Split(m.Trailing(), " ") { + cap := c.caps[key] + cap.Available = true + c.caps[key] = cap + } + c.remainingCapResponses-- + case "ACK": + for _, key := range strings.Split(m.Trailing(), " ") { + cap := c.caps[key] + cap.Enabled = true + c.caps[key] = cap + } + c.remainingCapResponses-- + case "NAK": + // If we got a NAK and this REQ was required, we need to bail + // with an error. + for _, key := range strings.Split(m.Trailing(), " ") { + if c.caps[key].Required { + c.sendError(fmt.Errorf("CAP %s requested but was rejected", key)) + return + } + } + c.remainingCapResponses-- + } + + if c.remainingCapResponses <= 0 { + for key, cap := range c.caps { + if cap.Required && !cap.Enabled { + c.sendError(fmt.Errorf("CAP %s requested but not accepted", key)) + return + } + } + + c.Write("CAP END") + } + }, } // ClientConfig is a structure used to configure a Client. @@ -77,6 +121,20 @@ type ClientConfig struct { Handler Handler } +type cap struct { + // Requested means that this cap was requested by the user + Requested bool + + // Required will be true if this cap is non-optional + Required bool + + // Enabled means that this cap was accepted by the server + Enabled bool + + // Available means that the server supports this cap + Available bool +} + // Client is a wrapper around Conn which is designed to make common operations // much simpler. type Client struct { @@ -84,11 +142,13 @@ type Client struct { config ClientConfig // Internal state - currentNick string - limiter chan struct{} - incomingPongChan chan string - errChan chan error - connected bool + currentNick string + limiter chan struct{} + incomingPongChan chan string + errChan chan error + caps map[string]cap + remainingCapResponses int + connected bool } // NewClient creates a client given an io stream and a client config. @@ -97,6 +157,7 @@ func NewClient(rw io.ReadWriter, config ClientConfig) *Client { Conn: NewConn(rw), config: config, errChan: make(chan error, 1), + caps: make(map[string]cap), } // Replace the writer writeCallback with one of our own @@ -177,6 +238,7 @@ func (c *Client) maybeStartPingLoop(wg *sync.WaitGroup, exiting chan struct{}) { case data := <-c.incomingPongChan: // Make sure the pong gets routed to the correct // goroutine. + c := pingHandlers[data] delete(pingHandlers, data) @@ -219,6 +281,49 @@ func (c *Client) sendError(err error) { } } +// CapRequest allows you to request IRCv3 capabilities from the server during +// the handshake. The behavior is undefined if this is called before the +// handshake completes so it is recommended that this be called before Run. If +// the CAP is marked as required, the client will exit if that CAP could not be +// negotiated during the handshake. +func (c *Client) CapRequest(capName string, required bool) { + cap := c.caps[capName] + cap.Requested = true + cap.Required = cap.Required || required + c.caps[capName] = cap +} + +// CapEnabled allows you to check if a CAP is enabled for this connection. Note +// that it will not be populated until after the CAP handshake is done, so it is +// recommended to wait to check this until after a message like 001. +func (c *Client) CapEnabled(capName string) bool { + return c.caps[capName].Enabled +} + +// CapAvailable allows you to check if a CAP is available on this server. Note +// that it will not be populated until after the CAP handshake is done, so it is +// recommended to wait to check this until after a message like 001. +func (c *Client) CapAvailable(capName string) bool { + return c.caps[capName].Available +} + +func (c *Client) maybeStartCapHandshake() error { + if len(c.caps) <= 0 { + return nil + } + + c.Write("CAP LS") + c.remainingCapResponses = 1 // We count the CAP LS response as a normal response + for key, cap := range c.caps { + if cap.Requested { + c.Writef("CAP REQ :%s", key) + c.remainingCapResponses++ + } + } + + return nil +} + // Run starts the main loop for this IRC connection. Note that it may break in // strange and unexpected ways if it is called again before the first connection // exits. @@ -237,6 +342,10 @@ func (c *Client) Run() error { c.Writef("PASS :%s", c.config.Pass) } + c.maybeStartCapHandshake() + + // This feels wrong because it results in CAP LS, CAP REQ, NICK, USER, CAP + // END, but it works and lets us keep the code a bit simpler. c.Writef("NICK :%s", c.config.Nick) c.Writef("USER %s 0.0.0.0 0.0.0.0 :%s", c.config.User, c.config.Name) diff --git a/client_test.go b/client_test.go index 2f60ec4..8e442c2 100644 --- a/client_test.go +++ b/client_test.go @@ -28,6 +28,133 @@ func (th *TestHandler) Messages() []*Message { return ret } +func TestCapReq(t *testing.T) { + t.Parallel() + + config := ClientConfig{ + Nick: "test_nick", + Pass: "test_pass", + User: "test_user", + Name: "test_name", + } + + c := runClientTest(t, config, io.EOF, func(c *Client) { + assert.False(t, c.CapAvailable("random-thing")) + assert.False(t, c.CapAvailable("multi-prefix")) + c.CapRequest("multi-prefix", true) + }, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("CAP LS\r\n"), + ExpectLine("CAP REQ :multi-prefix\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("CAP * LS :multi-prefix\r\n"), + SendLine("CAP * ACK :multi-prefix\r\n"), + ExpectLine("CAP END\r\n"), + }) + assert.False(t, c.CapEnabled("random-thing")) + assert.True(t, c.CapEnabled("multi-prefix")) + assert.False(t, c.CapAvailable("random-thing")) + assert.True(t, c.CapAvailable("multi-prefix")) + + // Malformed CAP responses are ignored + c = runClientTest(t, config, io.EOF, func(c *Client) { + assert.False(t, c.CapAvailable("random-thing")) + assert.False(t, c.CapAvailable("multi-prefix")) + c.CapRequest("multi-prefix", true) + }, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("CAP LS\r\n"), + ExpectLine("CAP REQ :multi-prefix\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("CAP * LS :multi-prefix\r\n"), + //SendLine("CAP * ACK\r\n"), // Malformed CAP response + SendLine("CAP * ACK :multi-prefix\r\n"), + ExpectLine("CAP END\r\n"), + }) + assert.False(t, c.CapEnabled("random-thing")) + assert.True(t, c.CapEnabled("multi-prefix")) + assert.False(t, c.CapAvailable("random-thing")) + assert.True(t, c.CapAvailable("multi-prefix")) + + // Additional CAP messages after the start are ignored. + c = runClientTest(t, config, io.EOF, func(c *Client) { + assert.False(t, c.CapAvailable("random-thing")) + assert.False(t, c.CapAvailable("multi-prefix")) + c.CapRequest("multi-prefix", true) + }, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("CAP LS\r\n"), + ExpectLine("CAP REQ :multi-prefix\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("CAP * LS :multi-prefix\r\n"), + SendLine("CAP * ACK :multi-prefix\r\n"), + ExpectLine("CAP END\r\n"), + SendLine("CAP * NAK :multi-prefix\r\n"), + }) + assert.False(t, c.CapEnabled("random-thing")) + assert.True(t, c.CapEnabled("multi-prefix")) + assert.False(t, c.CapAvailable("random-thing")) + assert.True(t, c.CapAvailable("multi-prefix")) + + c = runClientTest(t, config, io.EOF, func(c *Client) { + assert.False(t, c.CapAvailable("random-thing")) + assert.False(t, c.CapAvailable("multi-prefix")) + c.CapRequest("multi-prefix", false) + }, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("CAP LS\r\n"), + ExpectLine("CAP REQ :multi-prefix\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("CAP * LS :multi-prefix\r\n"), + SendLine("CAP * NAK :multi-prefix\r\n"), + ExpectLine("CAP END\r\n"), + }) + assert.False(t, c.CapEnabled("random-thing")) + assert.False(t, c.CapEnabled("multi-prefix")) + assert.False(t, c.CapAvailable("random-thing")) + assert.True(t, c.CapAvailable("multi-prefix")) + + c = runClientTest(t, config, errors.New("CAP multi-prefix requested but was rejected"), func(c *Client) { + assert.False(t, c.CapAvailable("random-thing")) + assert.False(t, c.CapAvailable("multi-prefix")) + c.CapRequest("multi-prefix", true) + }, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("CAP LS\r\n"), + ExpectLine("CAP REQ :multi-prefix\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("CAP * LS :multi-prefix\r\n"), + SendLine("CAP * NAK :multi-prefix\r\n"), + }) + assert.False(t, c.CapEnabled("random-thing")) + assert.False(t, c.CapEnabled("multi-prefix")) + assert.False(t, c.CapAvailable("random-thing")) + assert.True(t, c.CapAvailable("multi-prefix")) + + c = runClientTest(t, config, errors.New("CAP multi-prefix requested but not accepted"), func(c *Client) { + assert.False(t, c.CapAvailable("random-thing")) + assert.False(t, c.CapAvailable("multi-prefix")) + c.CapRequest("multi-prefix", true) + }, []TestAction{ + ExpectLine("PASS :test_pass\r\n"), + ExpectLine("CAP LS\r\n"), + ExpectLine("CAP REQ :multi-prefix\r\n"), + ExpectLine("NICK :test_nick\r\n"), + ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), + SendLine("CAP * LS :multi-prefix\r\n"), + SendLine("CAP * ACK :\r\n"), + }) + assert.False(t, c.CapEnabled("random-thing")) + assert.False(t, c.CapEnabled("multi-prefix")) + assert.False(t, c.CapAvailable("random-thing")) + assert.True(t, c.CapAvailable("multi-prefix")) +} + func TestClient(t *testing.T) { t.Parallel() @@ -38,13 +165,13 @@ func TestClient(t *testing.T) { Name: "test_name", } - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), }) - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -52,7 +179,7 @@ func TestClient(t *testing.T) { ExpectLine("PONG :hello world\r\n"), }) - c := runClientTest(t, config, io.EOF, []TestAction{ + c := runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -60,7 +187,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - c = runClientTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -68,7 +195,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "new_test_nick", c.CurrentNick()) - c = runClientTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -77,7 +204,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "test_nick_", c.CurrentNick()) - c = runClientTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -86,7 +213,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "test_nick_", c.CurrentNick()) - c = runClientTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -97,7 +224,7 @@ func TestClient(t *testing.T) { }) assert.Equal(t, "test_nick_", c.CurrentNick()) - c = runClientTest(t, config, io.EOF, []TestAction{ + c = runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -127,7 +254,7 @@ func TestSendLimit(t *testing.T) { } before := time.Now() - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -143,7 +270,7 @@ func TestSendLimit(t *testing.T) { config.SendBurst = 0 before = time.Now() - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -165,7 +292,7 @@ func TestClientHandler(t *testing.T) { Handler: handler, } - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -211,7 +338,7 @@ func TestPingLoop(t *testing.T) { var lastPing *Message // Successful ping - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -227,7 +354,7 @@ func TestPingLoop(t *testing.T) { }) // Ping timeout - runClientTest(t, config, errors.New("Ping Timeout"), []TestAction{ + runClientTest(t, config, errors.New("Ping Timeout"), nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -240,7 +367,7 @@ func TestPingLoop(t *testing.T) { }) // Exit in the middle of handling a ping - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), @@ -253,7 +380,7 @@ func TestPingLoop(t *testing.T) { // This one is just for coverage, so we know we're hitting the // branch that drops extra pings. - runClientTest(t, config, io.EOF, []TestAction{ + runClientTest(t, config, io.EOF, nil, []TestAction{ ExpectLine("PASS :test_pass\r\n"), ExpectLine("NICK :test_nick\r\n"), ExpectLine("USER test_user 0.0.0.0 0.0.0.0 :test_name\r\n"), diff --git a/stream_test.go b/stream_test.go index 71b9a0c..247df41 100644 --- a/stream_test.go +++ b/stream_test.go @@ -15,17 +15,29 @@ import ( type TestAction func(t *testing.T, rw *testReadWriter) func SendLine(output string) TestAction { + return SendLineWithTimeout(output, 1*time.Second) +} + +func SendLineWithTimeout(output string, timeout time.Duration) TestAction { return func(t *testing.T, rw *testReadWriter) { + waitChan := time.After(timeout) + // First we send the message select { case rw.readChan <- output: + case <-waitChan: + assert.Fail(t, "SendLine timeout on %s", output) + return case <-rw.exiting: assert.Fail(t, "Failed to send") + return } // Now we wait for the buffer to be emptied select { case <-rw.readEmptyChan: + case <-waitChan: + assert.Fail(t, "SendLine timeout on %s", output) case <-rw.exiting: assert.Fail(t, "Failed to send whole message") } @@ -152,10 +164,14 @@ func newTestReadWriter(actions []TestAction) *testReadWriter { } } -func runClientTest(t *testing.T, cc ClientConfig, expectedErr error, actions []TestAction) *Client { +func runClientTest(t *testing.T, cc ClientConfig, expectedErr error, setup func(c *Client), actions []TestAction) *Client { rw := newTestReadWriter(actions) c := NewClient(rw, cc) + if setup != nil { + setup(c) + } + go func() { err := c.Run() assert.Equal(t, expectedErr, err)