Skip to content

Commit

Permalink
Add Client.Timeout to allow limiting total exchange duration (#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
wbond authored and miekg committed Apr 19, 2016
1 parent a5cc44d commit c9d1302
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 5 deletions.
26 changes: 21 additions & 5 deletions client.go
Expand Up @@ -28,9 +28,10 @@ type Client struct {
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
UDPSize uint16 // minimum receive buffer for UDP messages
TLSConfig *tls.Config // TLS connection configuration
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
Timeout time.Duration // a cumulative timeout for dial, write and read, defaults to 0 (disabled) - overrides DialTimeout, ReadTimeout and WriteTimeout when non-zero
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds - overridden by Timeout when that value is non-zero
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds - overridden by Timeout when that value is non-zero
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified
SingleInflight bool // if true suppress multiple outstanding queries for the same Qname, Qtype and Qclass
group singleflight
Expand Down Expand Up @@ -129,6 +130,9 @@ func (c *Client) Exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
}

func (c *Client) dialTimeout() time.Duration {
if c.Timeout != 0 {
return c.Timeout
}
if c.DialTimeout != 0 {
return c.DialTimeout
}
Expand Down Expand Up @@ -170,6 +174,11 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
}
}

var deadline time.Time
if c.Timeout != 0 {
deadline = time.Now().Add(c.Timeout)
}

if tls {
co, err = DialTimeoutWithTLS(network, a, c.TLSConfig, c.dialTimeout())
} else {
Expand All @@ -192,12 +201,12 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
}

co.TsigSecret = c.TsigSecret
co.SetWriteDeadline(time.Now().Add(c.writeTimeout()))
co.SetWriteDeadline(deadlineOrTimeout(deadline, c.writeTimeout()))
if err = co.WriteMsg(m); err != nil {
return nil, 0, err
}

co.SetReadDeadline(time.Now().Add(c.readTimeout()))
co.SetReadDeadline(deadlineOrTimeout(deadline, c.readTimeout()))
r, err = co.ReadMsg()
if err == nil && r.Id != m.Id {
err = ErrId
Expand Down Expand Up @@ -434,3 +443,10 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
}
return conn, nil
}

func deadlineOrTimeout(deadline time.Time, timeout time.Duration) time.Time {
if deadline.IsZero() {
return time.Now().Add(timeout)
}
return deadline
}
48 changes: 48 additions & 0 deletions client_test.go
Expand Up @@ -419,3 +419,51 @@ func TestTruncatedMsg(t *testing.T) {
t.Fail()
}
}

func TestTimeout(t *testing.T) {
// Set up a dummy UDP server that won't respond
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("unable to resolve local udp address: %v", err)
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer conn.Close()
addrstr := conn.LocalAddr().String()

// Message to send
m := new(Msg)
m.SetQuestion("miek.nl.", TypeTXT)

// Use a channel + timeout to ensure we don't get stuck if the
// Client Timeout is not working properly
done := make(chan struct{})

timeout := time.Millisecond
allowable := timeout + (10 * time.Millisecond)
abortAfter := timeout + (100 * time.Millisecond)

start := time.Now()

go func() {
c := &Client{Timeout: timeout}
_, _, err := c.Exchange(m, addrstr)
if err == nil {
t.Error("no timeout using Client")
}
done <- struct{}{}
}()

select {
case <-done:
case <-time.After(abortAfter):
}

length := time.Since(start)

if length > allowable {
t.Errorf("exchange took longer (%v) than specified Timeout (%v)", length, timeout)
}
}

0 comments on commit c9d1302

Please sign in to comment.