Permalink
Browse files

Add dialer to abstract reconnect logic

  • Loading branch information...
1 parent de670de commit 2c5dc360153442a388aa08cbc8862d9d9fb2d640 @pietern pietern committed Jun 2, 2012
Showing with 144 additions and 2 deletions.
  1. +7 −1 client.go
  2. +1 −1 client_test.go
  3. +82 −0 dial.go
  4. +54 −0 dial_test.go
View
@@ -316,9 +316,15 @@ func (t *Client) runConnection(n net.Conn) error {
return e
}
-func (t *Client) Run(n net.Conn, h Handshaker) error {
+func (t *Client) Run(d Dialer, h Handshaker) error {
+ var n net.Conn
var e error
+ n, e = d.Dial()
+ if e != nil {
+ return e
+ }
+
n, e = h.Handshake(n)
if e != nil {
return e
View
@@ -32,7 +32,7 @@ func (tc *testClient) Setup(t *testing.T) {
tc.Add(1)
go func() {
- tc.ec <- tc.c.Run(tc.nc, EmptyHandshake)
+ tc.ec <- tc.c.Run(DumbDialer{tc.nc}, EmptyHandshake)
tc.Done()
}()
}
View
@@ -0,0 +1,82 @@
+package nats
+
+import (
+ "net"
+ "time"
+)
+
+type Dialer interface {
+ Dial() (net.Conn, error)
+}
+
+type DumbDialer struct {
+ Conn net.Conn
+}
+
+func (d DumbDialer) Dial() (net.Conn, error) {
+ return d.Conn, nil
+}
+
+type RetryingDialer struct {
+ // The dialer
+ f func(addr string) (net.Conn, error)
+
+ // The sleeper
+ s func(i uint)
+
+ // Address to connect to
+ Addr string
+
+ // Maximum number of connection attempts
+ MaxAttempts uint
+}
+
+func (d RetryingDialer) Dial() (net.Conn, error) {
+ var i uint
+ var n net.Conn
+ var e error
+
+ for ; ; i++ {
+ if d.MaxAttempts > 0 && i >= d.MaxAttempts {
+ break
+ }
+
+ n, e = d.f(d.Addr)
+ if n != nil {
+ return n, nil
+ }
+
+ d.s(i)
+ }
+
+ if e == nil {
+ panic("expected an error")
+ }
+
+ return nil, e
+}
+
+func DefaultDialer(addr string) Dialer {
+ var d RetryingDialer
+
+ d.f = func(addr string) (net.Conn, error) {
+ return net.Dial("tcp", addr)
+ }
+
+ d.s = func(i uint) {
+ var exp uint = i + 3
+ if exp > 12 {
+ exp = 12
+ }
+
+ // Sleep between 8ms and 4096ms
+ time.Sleep((1 << exp) * time.Millisecond)
+ }
+
+ d.Addr = addr
+
+ // Retry forever
+ d.MaxAttempts = 0
+
+ return d
+}
View
@@ -0,0 +1,54 @@
+package nats
+
+import (
+ "testing"
+ "fmt"
+ "net"
+)
+
+var ErrWhatever = fmt.Errorf("whatever")
+
+func TestDialMaxAttempts(t *testing.T) {
+ d := DefaultDialer("address").(RetryingDialer)
+
+ var i uint = 0
+
+ // Fail every time
+ d.f = func(addr string) (net.Conn, error) {
+ i++
+ return nil, ErrWhatever
+ }
+
+ // Don't sleep
+ d.s = func(i uint) {
+ }
+
+ d.MaxAttempts = 2
+
+ _, e := d.Dial()
+ if e != ErrWhatever {
+ t.Errorf("Expected: %#v, got: %#v", ErrWhatever, e)
+ return
+ }
+
+ if i != d.MaxAttempts {
+ t.Errorf("Expected: %#v, got: %#v", d.MaxAttempts, i)
+ return
+ }
+}
+
+func TestDialSuccess(t *testing.T) {
+ d := DefaultDialer("address").(RetryingDialer)
+
+ // Succeed
+ d.f = func(addr string) (net.Conn, error) {
+ n, _ := net.Pipe()
+ return n, nil
+ }
+
+ _, e := d.Dial()
+ if e != nil {
+ t.Errorf("Error: %#v", e)
+ return
+ }
+}

0 comments on commit 2c5dc36

Please sign in to comment.