Skip to content

Commit

Permalink
net/http2: perform connection health check
Browse files Browse the repository at this point in the history
After the connection has been idle for a while, periodic pings are sent
over the connection to check its health. Unhealthy connection is closed
and removed from the connection pool.

Fixes golang/go#31643
  • Loading branch information
Chao Xu committed Mar 11, 2020
1 parent aa69164 commit 36607fe
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 5 deletions.
66 changes: 61 additions & 5 deletions http2/transport.go
Expand Up @@ -108,6 +108,19 @@ type Transport struct {
// waiting for their turn.
StrictMaxConcurrentStreams bool

// ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection.
// Note that a ping response will is considered a received frame, so if
// there is no other traffic on the connection, the health check will
// be performed every ReadIdleTimeout interval.
// If zero, no health check is performed.
ReadIdleTimeout time.Duration

// PingTimeout is the timeout after which the connection will be closed
// if a response to Ping is not received.
// Defaults to 15s.
PingTimeout time.Duration

// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
Expand All @@ -131,6 +144,14 @@ func (t *Transport) disableCompression() bool {
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}

func (t *Transport) pingTimeout() time.Duration {
if t.PingTimeout == 0 {
return 15 * time.Second
}
return t.PingTimeout

}

// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
// It returns an error if t1 has already been HTTP/2-enabled.
func ConfigureTransport(t1 *http.Transport) error {
Expand Down Expand Up @@ -674,6 +695,20 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
return cc, nil
}

func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
err := cc.Ping(ctx)
if err != nil {
cc.closeForLostPing()
cc.t.connPool().MarkDead(cc)
return
}
}

func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
Expand Down Expand Up @@ -834,14 +869,12 @@ func (cc *ClientConn) sendGoAway() error {
return nil
}

// Close closes the client connection immediately.
//
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *ClientConn) Close() error {
// closes the client connection immediately. In-flight requests are interrupted.
// err is sent to streams.
func (cc *ClientConn) closeForError(err error) error {
cc.mu.Lock()
defer cc.cond.Broadcast()
defer cc.mu.Unlock()
err := errors.New("http2: client connection force closed via ClientConn.Close")
for id, cs := range cc.streams {
select {
case cs.resc <- resAndError{err: err}:
Expand All @@ -854,6 +887,20 @@ func (cc *ClientConn) Close() error {
return cc.tconn.Close()
}

// Close closes the client connection immediately.
//
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *ClientConn) Close() error {
err := errors.New("http2: client connection force closed via ClientConn.Close")
return cc.closeForError(err)
}

// closes the client connection immediately. In-flight requests are interrupted.
func (cc *ClientConn) closeForLostPing() error {
err := errors.New("http2: client connection lost")
return cc.closeForError(err)
}

const maxAllocFrameSize = 512 << 10

// frameBuffer returns a scratch buffer suitable for writing DATA frames.
Expand Down Expand Up @@ -1706,8 +1753,17 @@ func (rl *clientConnReadLoop) run() error {
rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
gotReply := false // ever saw a HEADERS reply
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer
if readIdleTimeout != 0 {
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
}
for {
f, err := cc.fr.ReadFrame()
if t != nil {
t.Reset(readIdleTimeout)
}
if err != nil {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
Expand Down
160 changes: 160 additions & 0 deletions http2/transport_test.go
Expand Up @@ -3244,6 +3244,166 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
req.Header = http.Header{}
}

func TestTransportCloseAfterLostPing(t *testing.T) {
clientDone := make(chan struct{})
ct := newClientTester(t)
ct.tr.PingTimeout = 1 * time.Second
ct.tr.ReadIdleTimeout = 1 * time.Second
ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
defer close(clientDone)
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
_, err := ct.tr.RoundTrip(req)
if err == nil || !strings.Contains(err.Error(), "client connection lost") {
return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
}
return nil
}
ct.server = func() error {
ct.greet()
<-clientDone
return nil
}
ct.run()
}

func TestTransportPingWhenReading(t *testing.T) {
testCases := []struct {
name string
readIdleTimeout time.Duration
serverResponseInterval time.Duration
expectedPingCount int
}{
{
name: "two pings in each serverResponseInterval",
readIdleTimeout: 400 * time.Millisecond,
serverResponseInterval: 1000 * time.Millisecond,
expectedPingCount: 4,
},
{
name: "one ping in each serverResponseInterval",
readIdleTimeout: 700 * time.Millisecond,
serverResponseInterval: 1000 * time.Millisecond,
expectedPingCount: 2,
},
{
name: "zero ping in each serverResponseInterval",
readIdleTimeout: 1000 * time.Millisecond,
serverResponseInterval: 500 * time.Millisecond,
expectedPingCount: 0,
},
{
name: "0 readIdleTimeout means no ping",
readIdleTimeout: 0 * time.Millisecond,
serverResponseInterval: 500 * time.Millisecond,
expectedPingCount: 0,
},
}

for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testTransportPingWhenReading(t, tc.readIdleTimeout, tc.serverResponseInterval, tc.expectedPingCount)
})
}
}

func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
var pingCount int
clientDone := make(chan struct{})
ct := newClientTester(t)
ct.tr.PingTimeout = 10 * time.Millisecond
ct.tr.ReadIdleTimeout = readIdleTimeout
// guards the ct.fr.Write
var wmu sync.Mutex

ct.client = func() error {
defer ct.cc.(*net.TCPConn).CloseWrite()
defer close(clientDone)
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
res, err := ct.tr.RoundTrip(req)
if err != nil {
return fmt.Errorf("RoundTrip: %v", err)
}
defer res.Body.Close()
if res.StatusCode != 200 {
return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
}
_, err = ioutil.ReadAll(res.Body)
return err
}

ct.server = func() error {
ct.greet()
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
for {
f, err := ct.fr.ReadFrame()
if err != nil {
select {
case <-clientDone:
// If the client's done, it
// will have reported any
// errors on its side.
return nil
default:
return err
}
}
switch f := f.(type) {
case *WindowUpdateFrame, *SettingsFrame:
case *HeadersFrame:
if !f.HeadersEnded() {
return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
}
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
ct.fr.WriteHeaders(HeadersFrameParam{
StreamID: f.StreamID,
EndHeaders: true,
EndStream: false,
BlockFragment: buf.Bytes(),
})

go func() {
for i := 0; i < 2; i++ {
wmu.Lock()
if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
wmu.Unlock()
t.Error(err)
return
}
wmu.Unlock()
time.Sleep(serverResponseInterval)
}
wmu.Lock()
if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server data frame")); err != nil {
wmu.Unlock()
t.Error(err)
return
}
wmu.Unlock()
}()
case *PingFrame:
pingCount++
wmu.Lock()
if err := ct.fr.WritePing(true, f.Data); err != nil {
wmu.Unlock()
return err
}
wmu.Unlock()
default:
return fmt.Errorf("Unexpected client frame %v", f)
}
}
}
ct.run()
if e, a := expectedPingCount, pingCount; e != a {
t.Errorf("expected receiving %d pings, got %d pings", e, a)

}
}

func TestTransportRetryAfterGOAWAY(t *testing.T) {
var dialer struct {
sync.Mutex
Expand Down

0 comments on commit 36607fe

Please sign in to comment.