Skip to content

Commit

Permalink
Code review changes to simplify concurrency stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtoacart committed Dec 3, 2014
1 parent 732c7fa commit d4cc22e
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 90 deletions.
59 changes: 35 additions & 24 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ var (
maxReconnectDelay = 5 * time.Second
reconnectDelayInterval = 100 * time.Millisecond

notConnectedError = fmt.Errorf("Client not yet connected")
closedError = fmt.Errorf("Client closed")
closedError = fmt.Errorf("Client closed")
)

// Client is a client of a waddell server
type Client struct {
type ClientConfig struct {
// Dial is a function that dials the waddell server
Dial DialFunc

Expand All @@ -40,39 +38,44 @@ type Client struct {
// PeerId is assigned to this client (i.e. on each successful connection to
// the waddell server).
OnId func(id PeerId)
}

// Client is a client of a waddell server
type Client struct {
*ClientConfig

connInfoChs chan chan *connInfo
connErrCh chan error
topicsOut map[TopicId]*topic
topicsOutMutex sync.Mutex
topicsIn map[TopicId]chan *MessageIn
topicsInMutex sync.Mutex
connected int32
currentId PeerId
currentIdMutex sync.RWMutex
closed int32
}

// DialFunc is a function for dialing a waddell server.
type DialFunc func() (net.Conn, error)

// Connect starts the waddell client and establishes an initial connection to
// the waddell server, returning the initial PeerId.
// NewClient creates a waddell client, including establishing an initial
// connection to the waddell server, returning the client and the initial
// PeerId.
//
// Note - if the client automatically reconnects, its peer ID will change. You
// can obtain the new id through providing an OnId callback to the client.
//
// Note - whether or not auto reconnecting is enabled, this method doesn't
// return until a connection has been established or we've failed trying.
func (c *Client) Connect() (PeerId, error) {
alreadyConnected := !atomic.CompareAndSwapInt32(&c.connected, 0, 1)
if alreadyConnected {
return PeerId{}, fmt.Errorf("Client already connecting or connected")
func NewClient(cfg *ClientConfig) (*Client, error) {
c := &Client{
ClientConfig: cfg,
}

var err error
if c.ServerCert != "" {
c.Dial, err = secured(c.Dial, c.ServerCert)
if err != nil {
return PeerId{}, err
return nil, err
}
}

Expand All @@ -83,15 +86,26 @@ func (c *Client) Connect() (PeerId, error) {
go c.stayConnected()
go c.processInbound()
info := c.getConnInfo()
return info.id, info.err
return c, info.err
}

// CurrentId returns the current id (from most recent connection to waddell).
// To be notified about changes to the id, use the OnId handler.
func (c *Client) CurrentId() PeerId {
c.currentIdMutex.RLock()
defer c.currentIdMutex.RUnlock()
return c.currentId
}

func (c *Client) setCurrentId(id PeerId) {
c.currentIdMutex.Lock()
c.currentId = id
c.currentIdMutex.Unlock()
}

// SendKeepAlive sends a keep alive message to the server to keep the underlying
// connection open.
func (c *Client) SendKeepAlive() error {
if !c.hasConnected() {
return notConnectedError
}
if c.isClosed() {
return closedError
}
Expand All @@ -113,12 +127,13 @@ func (c *Client) SendKeepAlive() error {
// channels after they're closed will result in a panic. So, don't call Close()
// until you're actually 100% finished using this client.
func (c *Client) Close() error {
if !c.hasConnected() {
return notConnectedError
if c == nil {
return nil
}

justClosed := atomic.CompareAndSwapInt32(&c.closed, 0, 1)
if !justClosed {
return closedError
return nil
}

var err error
Expand Down Expand Up @@ -162,10 +177,6 @@ func secured(dial DialFunc, cert string) (DialFunc, error) {
}, nil
}

func (c *Client) hasConnected() bool {
return c.connected == 1
}

func (c *Client) isClosed() bool {
return c.closed == 1
}
17 changes: 7 additions & 10 deletions clientmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,42 +27,39 @@ type ClientMgr struct {
OnId func(addr string, id PeerId)

clients map[string]*Client
ids map[string]PeerId
clientsMutex sync.Mutex
}

// ClientTo obtains the one (and only) client to the given addr, creating a new
// one if necessary. This method is safe to call from multiple goroutines.
func (m *ClientMgr) ClientTo(addr string) (*Client, PeerId, error) {
func (m *ClientMgr) ClientTo(addr string) (*Client, error) {
m.clientsMutex.Lock()
defer m.clientsMutex.Unlock()
if m.clients == nil {
m.clients = make(map[string]*Client)
m.ids = make(map[string]PeerId)
}
client := m.clients[addr]
var err error
if client == nil {
client = &Client{
cfg := &ClientConfig{
Dial: func() (net.Conn, error) {
return m.Dial(addr)
},
ServerCert: m.ServerCert,
ReconnectAttempts: m.ReconnectAttempts,
}
if m.OnId != nil {
client.OnId = func(id PeerId) {
cfg.OnId = func(id PeerId) {
m.OnId(addr, id)
}
}
id, err := client.Connect()
client, err = NewClient(cfg)
if err != nil {
return nil, id, err
return nil, err
}
m.clients[addr] = client
m.ids[addr] = id
}
id := m.ids[addr]
return client, id, nil
return client, nil
}

// Close closes this ClientMgr and all managed clients.
Expand Down
1 change: 1 addition & 0 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func (c *Client) connectOnce() (*connInfo, error) {
if c.OnId != nil {
go c.OnId(info.id)
}
c.setCurrentId(info.id)
return info, nil
}

Expand Down
6 changes: 0 additions & 6 deletions topic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ import (
// Out returns the (one and only) channel for writing to the topic identified by
// the given id.
func (c *Client) Out(id TopicId) chan<- *MessageOut {
if !c.hasConnected() {
panic(notConnectedError.Error())
}
if c.isClosed() {
panic("Attempted to obtain out topic on closed client")
}
Expand All @@ -32,9 +29,6 @@ func (c *Client) Out(id TopicId) chan<- *MessageOut {
// In returns the (one and only) channel for receiving from the topic identified
// by the given id.
func (c *Client) In(id TopicId) <-chan *MessageIn {
if !c.hasConnected() {
panic(notConnectedError.Error())
}
if c.isClosed() {
panic("Attempted to obtain in topic on closed client")
}
Expand Down
75 changes: 25 additions & 50 deletions waddell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,51 +66,35 @@ func TestTopicIdTruncation(t *testing.T) {
}

func TestBadDialerWithNoReconnect(t *testing.T) {
client := &Client{
cfg := &ClientConfig{
ReconnectAttempts: 0,
Dial: func() (net.Conn, error) {
return nil, fmt.Errorf("I won't dial, no way!")
},
}
defer client.Close()
_, err := client.Connect()
client, err := NewClient(cfg)
assert.Error(t, err, "Connecting with no reconnect attempts should have failed")
if err == nil {
client.Close()
}
}

func TestBadDialerWithMultipleReconnect(t *testing.T) {
client := &Client{
cfg := &ClientConfig{
ReconnectAttempts: 2,
Dial: func() (net.Conn, error) {
return nil, fmt.Errorf("I won't dial, no way!")
},
}
defer client.Close()
start := time.Now()
_, err := client.Connect()
client, err := NewClient(cfg)
defer client.Close()
delta := time.Now().Sub(start)
assert.Error(t, err, "Connecting with 2 reconnect attempts should have failed")
expectedDelta := reconnectDelayInterval * 3
assert.True(t, delta >= expectedDelta, fmt.Sprintf("Reconnecting didn't wait long enough. Should have waited %s, only waited %s", expectedDelta, delta))
}

func TestCloseFailing(t *testing.T) {
client := &Client{
ReconnectAttempts: 100,
Dial: func() (net.Conn, error) {
return nil, fmt.Errorf("I won't dial, no way!")
},
}
go client.Connect()
time.Sleep(100 * time.Millisecond)
client.Close()
}

func TestCloseUnconnected(t *testing.T) {
client := &Client{}
err := client.Close()
assert.Error(t, err, "Closing unconnected client should fail")
}

func TestPeersPlainText(t *testing.T) {
doTestPeers(t, false)
}
Expand All @@ -119,15 +103,10 @@ func TestPeersTLS(t *testing.T) {
doTestPeers(t, true)
}

type clientWithId struct {
id PeerId
client *Client
}

func doTestPeers(t *testing.T, useTLS bool) {
socketsAtStart := countTCPFiles()
closeActions := make([]func(), 0)
peers := make([]*clientWithId, 0, NumPeers)
peers := make([]*Client, 0, NumPeers)
defer func() {
for _, action := range closeActions {
action()
Expand All @@ -136,14 +115,11 @@ func doTestPeers(t *testing.T, useTLS bool) {
assert.Equal(t, socketsAtStart, socketsAtEnd, "All file descriptors should have been closed")

// Make sure we can't do stuff with closed client
client := peers[0].client
client := peers[0]

err := client.SendKeepAlive()
assert.Error(t, err, "Attempting to SendKeepAlive on closed client should fail")

err = client.Close()
assert.Error(t, err, "Attempting to close already closed client should fail")

// Make sure we can't obtain in or out topics after closing clients
defer func() {
defer func() {
Expand Down Expand Up @@ -200,25 +176,23 @@ func doTestPeers(t *testing.T, useTLS bool) {
}

idCallbackTriggered := int32(0)
connect := func() *clientWithId {
client := &Client{
connect := func() *Client {
cfg := &ClientConfig{
Dial: dial,
ServerCert: cert,
ReconnectAttempts: 1,
OnId: func(id PeerId) {
atomic.AddInt32(&idCallbackTriggered, 1)
},
}
id, err := client.Connect()
client, err := NewClient(cfg)
if err != nil {
log.Fatalf("Unable to connect client: %s", err)
}
_, err = client.Connect()
assert.Error(t, err, "Extra connect call should resulst in error")
return &clientWithId{id, client}
return client
}

peersCh := make(chan *clientWithId, NumPeers)
peersCh := make(chan *Client, NumPeers)
// Connect clients
for i := 0; i < NumPeers; i++ {
go func() {
Expand All @@ -232,7 +206,7 @@ func doTestPeers(t *testing.T, useTLS bool) {
peer := <-peersCh
peers = append(peers, peer)
closeActions = append(closeActions, func() {
err := peer.client.Close()
err := peer.Close()
assert.NoError(t, err, "Closing client shouldn't result in error")
})
}
Expand Down Expand Up @@ -265,11 +239,12 @@ func doTestPeers(t *testing.T, useTLS bool) {
assert.Nil(t, errs, "Closing clientMgr shouldn't result in errors")
})

badPeer, badPeerId, err := clientMgr.ClientTo(serverAddr)
badPeer, err := clientMgr.ClientTo(serverAddr)
if err != nil {
log.Fatalf("Unable to connect bad peer: %s", err)
}
cbMutex.Lock()
badPeerId := badPeer.CurrentId()
assert.Equal(t, serverAddr, cbAddr, "IdCallback should have recorded the server's addr")
assert.Equal(t, badPeerId, cbId, "IdCallback should have recorded the correct id")
cbMutex.Unlock()
Expand All @@ -294,13 +269,13 @@ func doTestPeers(t *testing.T, useTLS bool) {
// Write to each reader
for j := 0; j < NumPeers; j += 2 {
recip := peers[j]
peer.client.Out(TestTopic) <- Message(recip.id, []byte(Hello[:2]), []byte(Hello[2:]))
peer.Out(TestTopic) <- Message(recip.CurrentId(), []byte(Hello[:2]), []byte(Hello[2:]))
if err != nil {
log.Fatalf("Unable to write hello: %s", err)
} else {
resp := <-peer.client.In(TestTopic)
assert.Equal(t, fmt.Sprintf(HelloYourself, peer.id), string(resp.Body), "Response should match expected.")
assert.Equal(t, recip.id, resp.From, "Peer on response should match expected")
resp := <-peer.In(TestTopic)
assert.Equal(t, fmt.Sprintf(HelloYourself, peer.CurrentId()), string(resp.Body), "Response should match expected.")
assert.Equal(t, recip.CurrentId(), resp.From, "Peer on response should match expected")
}
}
}()
Expand All @@ -311,13 +286,13 @@ func doTestPeers(t *testing.T, useTLS bool) {

// Read from all readers
for j := 1; j < NumPeers; j += 2 {
err := peer.client.SendKeepAlive()
err := peer.SendKeepAlive()
if err != nil {
log.Fatalf("Unable to send KeepAlive: %s", err)
}
msg := <-peer.client.In(TestTopic)
msg := <-peer.In(TestTopic)
assert.Equal(t, Hello, string(msg.Body), "Hello message should match expected")
peer.client.Out(TestTopic) <- Message(msg.From, []byte(fmt.Sprintf(HelloYourself, msg.From)))
peer.Out(TestTopic) <- Message(msg.From, []byte(fmt.Sprintf(HelloYourself, msg.From)))
if err != nil {
log.Fatalf("Unable to write response to HELLO message: %s", err)
}
Expand Down

0 comments on commit d4cc22e

Please sign in to comment.