diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 00824e754cb36..fbed45070c880 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -87,6 +87,12 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { return } +func (t *Transport) IdleConnKeyCountForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConn) +} + func (t *Transport) IdleConnStrsForTesting() []string { var ret []string t.idleMu.Lock() diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index da7c02578cd5a..f8398adb9239b 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -2982,10 +2982,6 @@ func (s *http2Server) maxConcurrentStreams() uint32 { return http2defaultMaxStreams } -// List of funcs for ConfigureServer to run. Both h1 and h2 are guaranteed -// to be non-nil. -var http2configServerFuncs []func(h1 *Server, h2 *http2Server) error - // ConfigureServer adds HTTP/2 support to a net/http Server. // // The configuration conf may be nil. @@ -3512,6 +3508,11 @@ func (sc *http2serverConn) serve() { sc.idleTimerCh = sc.idleTimer.C } + var gracefulShutdownCh <-chan struct{} + if sc.hs != nil { + gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs) + } + go sc.readFrames() settingsTimer := time.NewTimer(http2firstSettingsTimeout) @@ -3539,6 +3540,9 @@ func (sc *http2serverConn) serve() { case <-settingsTimer.C: sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) return + case <-gracefulShutdownCh: + gracefulShutdownCh = nil + sc.goAwayIn(http2ErrCodeNo, 0) case <-sc.shutdownTimerCh: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return @@ -3548,6 +3552,10 @@ func (sc *http2serverConn) serve() { case fn := <-sc.testHookCh: fn(loopNum) } + + if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame { + return + } } } @@ -3803,7 +3811,7 @@ func (sc *http2serverConn) scheduleFrameWrite() { sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) continue } - if !sc.inGoAway { + if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { if wr, ok := sc.writeSched.Pop(); ok { sc.startFrameWrite(wr) continue @@ -3821,14 +3829,23 @@ func (sc *http2serverConn) scheduleFrameWrite() { func (sc *http2serverConn) goAway(code http2ErrCode) { sc.serveG.check() - if sc.inGoAway { - return - } + var forceCloseIn time.Duration if code != http2ErrCodeNo { - sc.shutDownIn(250 * time.Millisecond) + forceCloseIn = 250 * time.Millisecond } else { - sc.shutDownIn(1 * time.Second) + forceCloseIn = 1 * time.Second + } + sc.goAwayIn(code, forceCloseIn) +} + +func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) { + sc.serveG.check() + if sc.inGoAway { + return + } + if forceCloseIn != 0 { + sc.shutDownIn(forceCloseIn) } sc.inGoAway = true sc.needToSendGoAway = true @@ -5264,6 +5281,31 @@ var http2badTrailer = map[string]bool{ "Www-Authenticate": true, } +// h1ServerShutdownChan returns a channel that will be closed when the +// provided *http.Server wants to shut down. +// +// This is a somewhat hacky way to get at http1 innards. It works +// when the http2 code is bundled into the net/http package in the +// standard library. The alternatives ended up making the cmd/go tool +// depend on http Servers. This is the lightest option for now. +// This is tested via the TestServeShutdown* tests in net/http. +func http2h1ServerShutdownChan(hs *Server) <-chan struct{} { + if fn := http2testh1ServerShutdownChan; fn != nil { + return fn(hs) + } + var x interface{} = hs + type I interface { + getDoneChan() <-chan struct{} + } + if hs, ok := x.(I); ok { + return hs.getDoneChan() + } + return nil +} + +// optional test hook for h1ServerShutdownChan. +var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go index c6c38ffcaedae..aaae67cf2918e 100644 --- a/src/net/http/http_test.go +++ b/src/net/http/http_test.go @@ -12,8 +12,13 @@ import ( "os/exec" "reflect" "testing" + "time" ) +func init() { + shutdownPollInterval = 5 * time.Millisecond +} + func TestForeachHeaderElement(t *testing.T) { tests := []struct { in string diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 7834478352cd6..f855c358221da 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -4832,3 +4832,96 @@ func TestServerIdleTimeout(t *testing.T) { t.Fatal("copy byte succeeded; want err") } } + +func get(t *testing.T, c *Client, url string) string { + res, err := c.Get(url) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + return string(slurp) +} + +// Tests that calls to Server.SetKeepAlivesEnabled(false) closes any +// currently-open connections. +func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skipping on nacl; see golang.org/issue/17695") + } + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, r.RemoteAddr) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + get := func() string { return get(t, c, ts.URL) } + + a1, a2 := get(), get() + if a1 != a2 { + t.Fatal("expected first two requests on same connection") + } + var idle0 int + if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool { + idle0 = tr.IdleConnKeyCountForTesting() + return idle0 == 1 + }) { + t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0) + } + + ts.Config.SetKeepAlivesEnabled(false) + + var idle1 int + if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool { + idle1 = tr.IdleConnKeyCountForTesting() + return idle1 == 0 + }) { + t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1) + } + + a3 := get() + if a3 == a2 { + t.Fatal("expected third request on new connection") + } +} + +func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) } +func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) } + +func testServerShutdown(t *testing.T, h2 bool) { + defer afterTest(t) + var doShutdown func() // set later + var shutdownRes = make(chan error, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + go doShutdown() + // Shutdown is graceful, so it should not interrupt + // this in-flight response. Add a tiny sleep here to + // increase the odds of a failure if shutdown has + // bugs. + time.Sleep(20 * time.Millisecond) + io.WriteString(w, r.RemoteAddr) + })) + defer cst.close() + + doShutdown = func() { + shutdownRes <- cst.ts.Config.Shutdown(context.Background()) + } + get(t, cst.c, cst.ts.URL) // calls t.Fail on failure + + if err := <-shutdownRes; err != nil { + t.Fatalf("Shutdown: %v", err) + } + + res, err := cst.c.Get(cst.ts.URL) + if err == nil { + res.Body.Close() + t.Fatal("second request should fail. server should be shut down") + } +} diff --git a/src/net/http/server.go b/src/net/http/server.go index 51a66c37d53f0..eae065f6730f4 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -248,6 +248,8 @@ type conn struct { curReq atomic.Value // of *response (which has a Request in it) + curState atomic.Value // of ConnectionState + // mu guards hijackedv mu sync.Mutex @@ -1586,11 +1588,30 @@ func validNPN(proto string) bool { } func (c *conn) setState(nc net.Conn, state ConnState) { - if hook := c.server.ConnState; hook != nil { + srv := c.server + switch state { + case StateNew: + srv.trackConn(c, true) + case StateHijacked, StateClosed: + srv.trackConn(c, false) + } + c.curState.Store(connStateInterface[state]) + if hook := srv.ConnState; hook != nil { hook(nc, state) } } +// connStateInterface is an array of the interface{} versions of +// ConnState values, so we can use them in atomic.Values later without +// paying the cost of shoving their integers in an interface{}. +var connStateInterface = [...]interface{}{ + StateNew: StateNew, + StateActive: StateActive, + StateIdle: StateIdle, + StateHijacked: StateHijacked, + StateClosed: StateClosed, +} + // badRequestError is a literal string (used by in the server in HTML, // unescaped) to tell the user why their request was bad. It should // be plain text without user info or other embedded errors. @@ -2247,8 +2268,120 @@ type Server struct { ErrorLog *log.Logger disableKeepAlives int32 // accessed atomically. + inShutdown int32 // accessed atomically (non-zero means we're in Shutdown) nextProtoOnce sync.Once // guards setupHTTP2_* init nextProtoErr error // result of http2.ConfigureServer if used + + mu sync.Mutex + listeners map[net.Listener]struct{} + activeConn map[*conn]struct{} + doneChan chan struct{} +} + +func (s *Server) getDoneChan() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.getDoneChanLocked() +} + +func (s *Server) getDoneChanLocked() chan struct{} { + if s.doneChan == nil { + s.doneChan = make(chan struct{}) + } + return s.doneChan +} + +func (s *Server) closeDoneChanLocked() { + ch := s.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. + close(ch) + } +} + +// Close immediately closes all active net.Listeners and connections, +// regardless of their state. For a graceful shutdown, use Shutdown. +func (s *Server) Close() error { + s.mu.Lock() + defer s.mu.Lock() + s.closeDoneChanLocked() + err := s.closeListenersLocked() + for c := range s.activeConn { + c.rwc.Close() + delete(s.activeConn, c) + } + return err +} + +// shutdownPollInterval is how often we poll for quiescence +// during Server.Shutdown. This is lower during tests, to +// speed up tests. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +var shutdownPollInterval = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +func (s *Server) Shutdown(ctx context.Context) error { + atomic.AddInt32(&s.inShutdown, 1) + defer atomic.AddInt32(&s.inShutdown, -1) + + s.mu.Lock() + lnerr := s.closeListenersLocked() + s.closeDoneChanLocked() + s.mu.Unlock() + + ticker := time.NewTicker(shutdownPollInterval) + defer ticker.Stop() + for { + if s.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// closeIdleConns closes all idle connections and reports whether the +// server is quiescent. +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.activeConn { + st, ok := c.curState.Load().(ConnState) + if !ok || st != StateIdle { + quiescent = false + continue + } + c.rwc.Close() + delete(s.activeConn, c) + } + return quiescent +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(s.listeners, ln) + } + return err } // A ConnState represents the state of a client connection to a server. @@ -2361,6 +2494,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS) } +var ErrServerClosed = errors.New("http: Server closed") + // Serve accepts incoming connections on the Listener l, creating a // new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. @@ -2370,7 +2505,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { // srv.TLSConfig is non-nil and doesn't include the string "h2" in // Config.NextProtos, HTTP/2 support is not enabled. // -// Serve always returns a non-nil error. +// Serve always returns a non-nil error. After Shutdown or Close, the +// returned error is ErrServerClosed. func (srv *Server) Serve(l net.Listener) error { defer l.Close() if fn := testHookServerServe; fn != nil { @@ -2382,12 +2518,20 @@ func (srv *Server) Serve(l net.Listener) error { return err } + srv.trackListener(l, true) + defer srv.trackListener(l, false) + baseCtx := context.Background() // base is always background, per Issue 16220 ctx := context.WithValue(baseCtx, ServerContextKey, srv) ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr()) for { rw, e := l.Accept() if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond @@ -2410,6 +2554,37 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (s *Server) trackListener(ln net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(s.listeners) == 0 && len(s.activeConn) == 0 { + s.doneChan = nil + } + s.listeners[ln] = struct{}{} + } else { + delete(s.listeners, ln) + } +} + +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + func (s *Server) idleTimeout() time.Duration { if s.IdleTimeout != 0 { return s.IdleTimeout @@ -2425,7 +2600,11 @@ func (s *Server) readHeaderTimeout() time.Duration { } func (s *Server) doKeepAlives() bool { - return atomic.LoadInt32(&s.disableKeepAlives) == 0 + return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() +} + +func (s *Server) shuttingDown() bool { + return atomic.LoadInt32(&s.inShutdown) != 0 } // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. @@ -2435,9 +2614,21 @@ func (s *Server) doKeepAlives() bool { func (srv *Server) SetKeepAlivesEnabled(v bool) { if v { atomic.StoreInt32(&srv.disableKeepAlives, 0) - } else { - atomic.StoreInt32(&srv.disableKeepAlives, 1) + return } + atomic.StoreInt32(&srv.disableKeepAlives, 1) + + // Close idle HTTP/1 conns: + srv.closeIdleConns() + + // Close HTTP/2 conns, as soon as they become idle, but reset + // the chan so future conns (if the listener is still active) + // still work and don't get a GOAWAY immediately, before their + // first request: + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() // closes http2 conns + srv.doneChan = nil } func (s *Server) logf(format string, args ...interface{}) {