Skip to content

Commit

Permalink
Added timeout to listener
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtoacart committed Nov 25, 2014
1 parent c66c119 commit ac3ed04
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
27 changes: 20 additions & 7 deletions serveme.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import (

var (
log = golog.LoggerFor("serveme")

defaultDialTimeout = 30 * time.Second
defaultListenerTimeout = 30 * time.Second
)

// ServerId is an opaque identifier for a server used by the signaling channel.
Expand Down Expand Up @@ -69,7 +72,7 @@ func At(network, address string) (*Dialer, error) {
}
requests := make(chan *Request, 1000)
d := &Dialer{
Timeout: 30 * time.Second,
Timeout: defaultDialTimeout,
Requests: requests,
requests: requests,
network: network,
Expand Down Expand Up @@ -134,8 +137,7 @@ func (d *Dialer) run() {
log.Tracef("Unable to accept: %s", err)
return
}
fr := framed.NewReader(conn)
_, err = fr.Read(b)
_, err = framed.NewReader(conn).Read(b)
if err != nil {
log.Tracef("Unable to read conn id bytes: %s", err)
conn.Close()
Expand All @@ -161,6 +163,12 @@ func (d *Dialer) run() {

// Listener implements the net.Listener interface.
type Listener struct {
// Timeout controls how long the Listener is willing to take for dialing
// the client and writing the connection id. It defaults to 30 seconds.
Timeout time.Duration

// Requests is a channel on which to post requests to dial a dialer received
// from the signaling channel.
Requests chan<- *Request
requests chan *Request
}
Expand All @@ -173,6 +181,7 @@ type Listener struct {
func Listen() *Listener {
requests := make(chan *Request, 1000)
return &Listener{
Timeout: defaultListenerTimeout,
Requests: requests,
requests: requests,
}
Expand All @@ -182,14 +191,18 @@ func Listen() *Listener {
func (l *Listener) Accept() (net.Conn, error) {
req := <-l.requests
log.Tracef("Dialing %s %s", req.Network, req.Address)
conn, err := net.Dial(req.Network, req.Address)
start := time.Now()
conn, err := net.DialTimeout(req.Network, req.Address, l.Timeout)
if err != nil {
return nil, err
}
fw := framed.NewWriter(conn)
dialTime := time.Now().Sub(start)

_, _, err = withtimeout.Do(l.Timeout-dialTime, func() (interface{}, error) {
log.Trace("Writing connection id to identify connection")
return framed.NewWriter(conn).Write(req.ID.ToBytes())
})

log.Trace("Writing connection id to identify connection")
_, err = fw.Write(req.ID.ToBytes())
if err != nil {
conn.Close()
return nil, fmt.Errorf("Unable to write connection id: %s", err)
Expand Down
29 changes: 28 additions & 1 deletion serveme_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func Test(t *testing.T) {
assert.Equal(t, []byte("Message 2"), b2[:n2], "Conn1 should have read 'Message 2'")
}

func TestRapidTimeout(t *testing.T) {
func TestRapidDialTimeout(t *testing.T) {
numFilesAtStart := countTCPFiles()
dialer, err := At("tcp", "localhost:0")
if err != nil {
Expand All @@ -104,6 +104,33 @@ func TestRapidTimeout(t *testing.T) {
}
}

func TestRapidListenerTimeout(t *testing.T) {
numFilesAtStart := countTCPFiles()
dialer, err := At("tcp", "localhost:0")
if err != nil {
t.Fatalf("Unable to start dialer: %s", err)
}
l := Listen()
l.Timeout = 1 * time.Nanosecond
go func() {
dialer.Dial(nil)
}()

defer func() {
dialer.Close()
l.Close()
numFilesAtEnd := countTCPFiles()
assert.Equal(t, numFilesAtStart, numFilesAtEnd, "Number of TCP file descriptors should have remained constant between start and end of test")
}()

l.Requests <- <-dialer.Requests
_, err = l.Accept()
assert.Error(t, err, "Accept should have errored")
if err != nil {
assert.Contains(t, err.Error(), "timeout", "Error should have been timeout error")
}
}

func testListener(t *testing.T, msg string) *Listener {
l := Listen()
go func() {
Expand Down

0 comments on commit ac3ed04

Please sign in to comment.