Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/next/6-stdlib/99-minor/net/http/httptest/31054.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The HTTP client returned by [Server.Client] will now redirect requests for
`example.com` and any subdomains to the server being tested.
34 changes: 32 additions & 2 deletions src/net/http/httptest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package httptest

import (
"context"
"crypto/tls"
"crypto/x509"
"flag"
Expand Down Expand Up @@ -126,8 +127,24 @@ func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}

if s.client == nil {
s.client = &http.Client{Transport: &http.Transport{}}
tr := &http.Transport{}
dialer := net.Dialer{}
// User code may set either of Dial or DialContext, with DialContext taking precedence.
// We set DialContext here to preserve any context values that are passed in,
// but fall back to Dial if the user has set it.
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if tr.Dial != nil {
return tr.Dial(network, addr)
}
if addr == "example.com:80" || strings.HasSuffix(addr, ".example.com:80") {
addr = s.Listener.Addr().String()
}
return dialer.DialContext(ctx, network, addr)
}
s.client = &http.Client{Transport: tr}

}
s.URL = "http://" + s.Listener.Addr().String()
s.wrap()
Expand Down Expand Up @@ -173,12 +190,23 @@ func (s *Server) StartTLS() {
}
certpool := x509.NewCertPool()
certpool.AddCert(s.certificate)
s.client.Transport = &http.Transport{
tr := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
ForceAttemptHTTP2: s.EnableHTTP2,
}
dialer := net.Dialer{}
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
if tr.Dial != nil {
return tr.Dial(network, addr)
}
if addr == "example.com:443" || strings.HasSuffix(addr, ".example.com:443") {
addr = s.Listener.Addr().String()
}
return dialer.DialContext(ctx, network, addr)
}
s.client.Transport = tr
s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
s.wrap()
Expand Down Expand Up @@ -300,6 +328,8 @@ func (s *Server) Certificate() *x509.Certificate {
// It is configured to trust the server's TLS test certificate and will
// close its idle connections on [Server.Close].
// Use Server.URL as the base URL to send requests to the server.
// The returned client will also redirect any requests to "example.com"
// or its subdomains to the server.
func (s *Server) Client() *http.Client {
return s.client
}
Expand Down
37 changes: 37 additions & 0 deletions src/net/http/httptest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,40 @@ func TestTLSServerWithHTTP2(t *testing.T) {
})
}
}

func TestClientExampleCom(t *testing.T) {
modes := []struct {
proto string
host string
}{
{"http", "example.com"},
{"http", "foo.example.com"},
{"https", "example.com"},
{"https", "foo.example.com"},
}

for _, tt := range modes {
t.Run(tt.proto+" "+tt.host, func(t *testing.T) {
cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("requested-hostname", r.Host)
}))
switch tt.proto {
case "https":
cst.EnableHTTP2 = true
cst.StartTLS()
default:
cst.Start()
}

defer cst.Close()

res, err := cst.Client().Get(tt.proto + "://" + tt.host)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
if got, want := res.Header.Get("requested-hostname"), tt.host; got != want {
t.Fatalf("Requested hostname mismatch\ngot: %q\nwant: %q", got, want)
}
})
}
}