Skip to content

Commit

Permalink
fix: dialtimeout is zero when not set reqTimeout (#1014)
Browse files Browse the repository at this point in the history
  • Loading branch information
Duslia committed Dec 4, 2023
1 parent 22aafdc commit 21aae1d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pkg/protocol/http1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo
begin := req.Options().StartTime()

dialTimeout := rc.dialTimeout
if reqTimeout < dialTimeout || dialTimeout == 0 {
if (reqTimeout > 0 && reqTimeout < dialTimeout) || dialTimeout == 0 {
dialTimeout = reqTimeout
}
cc, inPool, err := c.acquireConn(dialTimeout)
Expand Down
49 changes: 35 additions & 14 deletions pkg/protocol/http1/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {

c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return mock.SlowReadDialer(addr)
}),
MaxConns: 1,
Expand Down Expand Up @@ -212,16 +212,16 @@ func testContinueReadResponseBodyStream(t *testing.T, header, body string, maxBo
}
}

func newSlowConnDialer(dialer func(network, addr string) (network.Conn, error)) network.Dialer {
func newSlowConnDialer(dialer func(network, addr string, timeout time.Duration) (network.Conn, error)) network.Dialer {
return &mockDialer{customDialConn: dialer}
}

type mockDialer struct {
customDialConn func(network, addr string) (network.Conn, error)
customDialConn func(network, addr string, timeout time.Duration) (network.Conn, error)
}

func (m *mockDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) {
return m.customDialConn(network, address)
return m.customDialConn(network, address, timeout)
}

func (m *mockDialer) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) {
Expand All @@ -244,7 +244,7 @@ func (s *slowDialer) DialConnection(network, address string, timeout time.Durati
func TestReadTimeoutPriority(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return mock.SlowReadDialer(addr)
}),
MaxConns: 1,
Expand Down Expand Up @@ -274,7 +274,7 @@ func TestReadTimeoutPriority(t *testing.T) {
func TestDoNonNilReqResp(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return &writeErrConn{
Conn: mock.NewConn("HTTP/1.1 400 OK\nContent-Length: 6\n\n123456"),
},
Expand All @@ -295,7 +295,7 @@ func TestDoNonNilReqResp(t *testing.T) {
func TestDoNonNilReqResp1(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return &writeErrConn{
Conn: mock.NewConn(""),
},
Expand All @@ -314,7 +314,7 @@ func TestDoNonNilReqResp1(t *testing.T) {
func TestWriteTimeoutPriority(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return mock.SlowWriteDialer(addr)
}),
MaxConns: 1,
Expand Down Expand Up @@ -376,7 +376,7 @@ func TestStateObserve(t *testing.T) {
}{}
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return mock.SlowReadDialer(addr)
}),
StateObserve: func(hcs config.HostClientState) {
Expand Down Expand Up @@ -404,7 +404,7 @@ func TestStateObserve(t *testing.T) {
func TestCachedTLSConfig(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return mock.SlowReadDialer(addr)
}),
TLSConfig: &tls.Config{
Expand All @@ -426,7 +426,7 @@ func TestRetry(t *testing.T) {
var times int32
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
times++
if times < 3 {
return &retryConn{
Expand Down Expand Up @@ -486,7 +486,7 @@ func (w retryConn) SetWriteTimeout(t time.Duration) error {
func TestConnInPoolRetry(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return mock.NewOneTimeConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil
}),
},
Expand Down Expand Up @@ -518,7 +518,7 @@ func TestConnInPoolRetry(t *testing.T) {
func TestConnNotRetry(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return mock.NewBrokenConn(""), nil
}),
},
Expand Down Expand Up @@ -558,7 +558,7 @@ func TestStreamNoContent(t *testing.T) {

c := &HostClient{
ClientOptions: &ClientOptions{
Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) {
Dialer: newSlowConnDialer(func(network, addr string, timeout time.Duration) (network.Conn, error) {
return conn, nil
}),
},
Expand All @@ -576,3 +576,24 @@ func TestStreamNoContent(t *testing.T) {

assert.True(t, conn.isClose)
}

func TestDialTimeout(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
DialTimeout: time.Second * 10,
Dialer: &mockDialer{
customDialConn: func(network, addr string, timeout time.Duration) (network.Conn, error) {
assert.DeepEqual(t, time.Second*10, timeout)
return nil, errors.New("test error")
},
},
},
Addr: "foobar",
}

req := protocol.AcquireRequest()
req.SetRequestURI("http://foobar/baz")
resp := protocol.AcquireResponse()

c.Do(context.Background(), req, resp)
}

0 comments on commit 21aae1d

Please sign in to comment.