Skip to content

Commit

Permalink
net/http/httptest: change Server to use http.Server.ConnState for acc…
Browse files Browse the repository at this point in the history
…ounting

With this CL, httptest.Server now uses connection-level accounting of
outstanding requests instead of ServeHTTP-level accounting. This is
more robust and results in a non-racy shutdown.

This is much easier now that net/http.Server has the ConnState hook.

Fixes #12789
Fixes #12781

Change-Id: I098cf334a6494316acb66cd07df90766df41764b
  • Loading branch information
bradfitz committed Oct 12, 2015
1 parent 6f77278 commit a35d7b3
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 60 deletions.
149 changes: 89 additions & 60 deletions src/net/http/httptest/server.go
Expand Up @@ -14,6 +14,7 @@ import (
"net/http"
"os"
"sync"
"time"
)

// A Server is an HTTP server listening on a system-chosen port on the
Expand All @@ -34,24 +35,10 @@ type Server struct {
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
}

// historyListener keeps track of all connections that it's ever
// accepted.
type historyListener struct {
net.Listener
sync.Mutex // protects history
history []net.Conn
}

func (hs *historyListener) Accept() (c net.Conn, err error) {
c, err = hs.Listener.Accept()
if err == nil {
hs.Lock()
hs.history = append(hs.history, c)
hs.Unlock()
}
return
mu sync.Mutex // guards conns
closed bool
conns map[net.Conn]http.ConnState // except terminal states
}

func newLocalListener() net.Listener {
Expand Down Expand Up @@ -103,10 +90,9 @@ func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}
s.Listener = &historyListener{Listener: s.Listener}
s.URL = "http://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
s.wrap()
s.goServe()
if *serve != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
select {}
Expand Down Expand Up @@ -134,23 +120,10 @@ func (s *Server) StartTLS() {
if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert}
}
tlsListener := tls.NewListener(s.Listener, s.TLS)

s.Listener = &historyListener{Listener: tlsListener}
s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
s.wrapHandler()
go s.Config.Serve(s.Listener)
}

func (s *Server) wrapHandler() {
h := s.Config.Handler
if h == nil {
h = http.DefaultServeMux
}
s.Config.Handler = &waitGroupHandler{
s: s,
h: h,
}
s.wrap()
s.goServe()
}

// NewTLSServer starts and returns a new Server using TLS.
Expand All @@ -164,40 +137,96 @@ func NewTLSServer(handler http.Handler) *Server {
// Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
func (s *Server) Close() {
s.Listener.Close()
s.wg.Wait()
s.CloseClientConnections()
if t, ok := http.DefaultTransport.(*http.Transport); ok {
t.CloseIdleConnections()
s.mu.Lock()
firstClose := !s.closed
if firstClose {
s.closed = true
s.Listener.Close()
s.Config.SetKeepAlivesEnabled(false)
for c, st := range s.conns {
if st == http.StateIdle {
c.Close()
}
}
}
s.mu.Unlock()

// Close any connection whose Accept was in-flight but not yet in
// http.Server.serve before we closed the Listener. In practice this
// only affects (is needed for) net/http's TestTimeoutHandlerRaceHeader.
// The new connections come from the Transport code's "socket late
// binding" code (see that phrase in transport.go for details).
// TODO(bradfitz): understand this more. but in practice this looks to
// be robust over many runs of the full net/http tests.
if firstClose {
var t *time.Timer // guarded by s.mu
const closeInterval = 50 * time.Millisecond
s.mu.Lock()
t = time.AfterFunc(closeInterval, func() {
s.mu.Lock()
defer s.mu.Unlock()
for c, st := range s.conns {
if st == http.StateNew {
c.Close()
}
}
t.Reset(closeInterval)
})
s.mu.Unlock() // now that t is assigned to.
defer t.Stop()
}

s.wg.Wait()
}

// CloseClientConnections closes any currently open HTTP connections
// CloseClientConnections closes any currently-open HTTP connections
// to the test Server.
func (s *Server) CloseClientConnections() {
hl, ok := s.Listener.(*historyListener)
if !ok {
return
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.conns {
c.Close()
delete(s.conns, c)
}
hl.Lock()
for _, conn := range hl.history {
conn.Close()
}
hl.Unlock()
}

// waitGroupHandler wraps a handler, incrementing and decrementing a
// sync.WaitGroup on each request, to enable Server.Close to block
// until outstanding requests are finished.
type waitGroupHandler struct {
s *Server
h http.Handler // non-nil
func (s *Server) goServe() {
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.Config.Serve(s.Listener)
}()
}

func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.s.wg.Add(1)
defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
h.h.ServeHTTP(w, r)
// wrap installs the connection state-tracking hook to know which
// connections are idle.
func (s *Server) wrap() {
oldHook := s.Config.ConnState
s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
s.mu.Lock()
defer s.mu.Unlock()
switch cs {
case http.StateNew:
s.wg.Add(1)
if s.conns == nil {
s.conns = make(map[net.Conn]http.ConnState)
}
s.conns[c] = cs
case http.StateActive:
s.conns[c] = cs
case http.StateIdle:
if s.closed {
c.Close() // will cause transition to Closed later
}
s.conns[c] = cs
case http.StateHijacked, http.StateClosed:
delete(s.conns, c)
s.wg.Done()
}
if oldHook != nil {
oldHook(c, cs)
}
}
}

// localhostCert is a PEM-encoded TLS cert with SAN IPs
Expand Down
28 changes: 28 additions & 0 deletions src/net/http/httptest/server_test.go
Expand Up @@ -27,3 +27,31 @@ func TestServer(t *testing.T) {
t.Errorf("got %q, want hello", string(got))
}
}

// Issue 12781
func TestGetAfterClose(t *testing.T) {
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("hello"))
}))

res, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if string(got) != "hello" {
t.Fatalf("got %q, want hello", string(got))
}

ts.Close()

res, err = http.Get(ts.URL)
if err == nil {
body, _ := ioutil.ReadAll(res.Body)
t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body)
}

}

0 comments on commit a35d7b3

Please sign in to comment.