From 2c802e9980427a1d47384b222dcbf4c9b4e84944 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 15 Mar 2019 14:58:55 +0000 Subject: [PATCH] net/http: add Server BaseContext & ConnContext fields to control early context Fixes golang/go#30694 Change-Id: I12a0a870e4aee6576e879d88a4868666ef448298 Reviewed-on: https://go-review.googlesource.com/c/go/+/167681 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: JP Sugarbroad Reviewed-by: Brad Fitzpatrick --- src/net/http/serve_test.go | 37 +++++++++++++++++++++++++++++++++++++ src/net/http/server.go | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index ea6d7c2fda2a1..f10a4272abb44 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -6034,6 +6034,43 @@ func TestStripPortFromHost(t *testing.T) { } } +func TestServerContexts(t *testing.T) { + setParallel(t) + defer afterTest(t) + type baseKey struct{} + type connKey struct{} + ch := make(chan context.Context, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ch <- r.Context() + })) + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") + } + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") + } + ts.Start() + defer ts.Close() + res, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + ctx := <-ch + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("base context key = %#v; want %q", got, want) + } + if got, want := ctx.Value(connKey{}), "conn"; got != want { + t.Errorf("conn context key = %#v; want %q", got, want) + } +} + func BenchmarkResponseStatusLine(b *testing.B) { b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { diff --git a/src/net/http/server.go b/src/net/http/server.go index 14f74285c12ad..bc6d93bce096a 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -2542,6 +2542,20 @@ type Server struct { // If nil, logging is done via the log package's standard logger. ErrorLog *log.Logger + // BaseContext optionally specifies a function that returns + // the base context for incoming requests on this server. + // The provided Listener is the specific Listener that's + // about to start accepting requests. + // If BaseContext is nil, the default is context.Background(). + // If non-nil, it must return a non-nil context. + BaseContext func(net.Listener) context.Context + + // ConnContext optionally specifies a function that modifies + // the context used for a newly connection c. The provided ctx + // is derived from the base context and has a ServerContextKey + // value. + ConnContext func(ctx context.Context, c net.Conn) context.Context + disableKeepAlives int32 // accessed atomically. inShutdown int32 // accessed atomically (non-zero means we're in Shutdown) nextProtoOnce sync.Once // guards setupHTTP2_* init @@ -2838,6 +2852,7 @@ func (srv *Server) Serve(l net.Listener) error { fn(srv, l) // call hook with unwrapped listener } + origListener := l l = &onceCloseListener{Listener: l} defer l.Close() @@ -2850,8 +2865,16 @@ func (srv *Server) Serve(l net.Listener) error { } defer srv.trackListener(&l, false) - var tempDelay time.Duration // how long to sleep on accept failure - baseCtx := context.Background() // base is always background, per Issue 16220 + var tempDelay time.Duration // how long to sleep on accept failure + + baseCtx := context.Background() + if srv.BaseContext != nil { + baseCtx = srv.BaseContext(origListener) + if baseCtx == nil { + panic("BaseContext returned a nil context") + } + } + ctx := context.WithValue(baseCtx, ServerContextKey, srv) for { rw, e := l.Accept() @@ -2876,6 +2899,12 @@ func (srv *Server) Serve(l net.Listener) error { } return e } + if cc := srv.ConnContext; cc != nil { + ctx = cc(ctx, rw) + if ctx == nil { + panic("ConnContext returned nil") + } + } tempDelay = 0 c := srv.newConn(rw) c.setState(c.rwc, StateNew) // before Serve can return