From cdfcd6de2386f7dec2627ddfa3fd07d6d1641801 Mon Sep 17 00:00:00 2001 From: roc Date: Wed, 6 Mar 2024 16:58:14 +0800 Subject: [PATCH] remove tests from std net/http --- http_test.go | 64 - response_test.go | 73 - transfer_test.go | 364 --- transport_internal_test.go | 188 -- transport_test.go | 5996 ------------------------------------ 5 files changed, 6685 deletions(-) delete mode 100644 http_test.go delete mode 100644 response_test.go delete mode 100644 transfer_test.go delete mode 100644 transport_internal_test.go delete mode 100644 transport_test.go diff --git a/http_test.go b/http_test.go deleted file mode 100644 index 7dca3a45..00000000 --- a/http_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package req - -import ( - "reflect" - "testing" -) - -func TestForeachHeaderElement(t *testing.T) { - tests := []struct { - in string - want []string - }{ - {"Foo", []string{"Foo"}}, - {" Foo", []string{"Foo"}}, - {"Foo ", []string{"Foo"}}, - {" Foo ", []string{"Foo"}}, - - {"foo", []string{"foo"}}, - {"anY-cAsE", []string{"anY-cAsE"}}, - - {"", nil}, - {",,,, , ,, ,,, ,", nil}, - - {" Foo,Bar, Baz,lower,,Quux ", []string{"Foo", "Bar", "Baz", "lower", "Quux"}}, - } - for _, tt := range tests { - var got []string - foreachHeaderElement(tt.in, func(v string) { - got = append(got, v) - }) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("foreachHeaderElement(%q) = %q; want %q", tt.in, got, tt.want) - } - } -} - -func TestCleanHost(t *testing.T) { - tests := []struct { - in, want string - }{ - {"www.google.com", "www.google.com"}, - {"www.google.com foo", "www.google.com"}, - {"www.google.com/foo", "www.google.com"}, - {" first character is a space", ""}, - {"[1::6]:8080", "[1::6]:8080"}, - - // Punycode: - {"гофер.рф/foo", "xn--c1ae0ajs.xn--p1ai"}, - {"bücher.de", "xn--bcher-kva.de"}, - {"bücher.de:8080", "xn--bcher-kva.de:8080"}, - // Verify we convert to lowercase before punycode: - {"BÜCHER.de", "xn--bcher-kva.de"}, - {"BÜCHER.de:8080", "xn--bcher-kva.de:8080"}, - // Verify we normalize to NFC before punycode: - {"gophér.nfc", "xn--gophr-esa.nfc"}, // NFC input; no work needed - {"goph\u0065\u0301r.nfd", "xn--gophr-esa.nfd"}, // NFD input - } - for _, tt := range tests { - got := cleanHost(tt.in) - if tt.want != got { - t.Errorf("cleanHost(%q) = %q, want %q", tt.in, got, tt.want) - } - } -} diff --git a/response_test.go b/response_test.go deleted file mode 100644 index 7ec82376..00000000 --- a/response_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package req - -import ( - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" - "net/http" - "testing" -) - -type User struct { - Name string `json:"name" xml:"name"` -} - -type Message struct { - Message string `json:"message"` -} - -func TestUnmarshalJson(t *testing.T) { - var user User - resp, err := tc().R().Get("/json") - assertSuccess(t, resp, err) - err = resp.UnmarshalJson(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) -} - -func TestUnmarshalXml(t *testing.T) { - var user User - resp, err := tc().R().Get("/xml") - assertSuccess(t, resp, err) - err = resp.UnmarshalXml(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) -} - -func TestUnmarshal(t *testing.T) { - var user User - resp, err := tc().R().Get("/xml") - assertSuccess(t, resp, err) - err = resp.Unmarshal(&user) - tests.AssertNoError(t, err) - tests.AssertEqual(t, "roc", user.Name) -} - -func TestResponseResult(t *testing.T) { - resp, _ := tc().R().SetResult(&User{}).Get("/json") - user, ok := resp.Result().(*User) - if !ok { - t.Fatal("Response.Result() should return *User") - } - tests.AssertEqual(t, "roc", user.Name) - - tests.AssertEqual(t, true, resp.TotalTime() > 0) - tests.AssertEqual(t, false, resp.ReceivedAt().IsZero()) -} - -func TestResponseError(t *testing.T) { - resp, _ := tc().R().SetError(&Message{}).Get("/json?error=yes") - msg, ok := resp.Error().(*Message) - if !ok { - t.Fatal("Response.Error() should return *Message") - } - tests.AssertEqual(t, "not allowed", msg.Message) -} - -func TestResponseWrap(t *testing.T) { - resp, err := tc().R().Get("/json") - assertSuccess(t, resp, err) - tests.AssertEqual(t, true, resp.GetStatusCode() == http.StatusOK) - tests.AssertEqual(t, true, resp.GetStatus() == "200 OK") - tests.AssertEqual(t, true, resp.GetHeader(header.ContentType) == header.JsonContentType) - tests.AssertEqual(t, true, len(resp.GetHeaderValues(header.ContentType)) == 1) -} diff --git a/transfer_test.go b/transfer_test.go deleted file mode 100644 index 0721aeed..00000000 --- a/transfer_test.go +++ /dev/null @@ -1,364 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package req - -import ( - "bufio" - "bytes" - "crypto/rand" - "fmt" - "io" - "net/http" - "os" - "reflect" - "strings" - "testing" -) - -func TestBodyReadBadTrailer(t *testing.T) { - b := &body{ - src: strings.NewReader("foobar"), - hdr: true, // force reading the trailer - r: bufio.NewReader(strings.NewReader("")), - } - buf := make([]byte, 7) - n, err := b.Read(buf[:3]) - got := string(buf[:n]) - if got != "foo" || err != nil { - t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err) - } - - n, err = b.Read(buf[:]) - got = string(buf[:n]) - if got != "bar" || err != nil { - t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err) - } - - n, err = b.Read(buf[:]) - got = string(buf[:n]) - if err == nil { - t.Errorf("final Read was successful (%q), expected error from trailer read", got) - } -} - -func TestFinalChunkedBodyReadEOF(t *testing.T) { - res, err := http.ReadResponse(bufio.NewReader(strings.NewReader( - "HTTP/1.1 200 OK\r\n"+ - "Transfer-Encoding: chunked\r\n"+ - "\r\n"+ - "0a\r\n"+ - "Body here\n\r\n"+ - "09\r\n"+ - "continued\r\n"+ - "0\r\n"+ - "\r\n")), nil) - if err != nil { - t.Fatal(err) - } - want := "Body here\ncontinued" - buf := make([]byte, len(want)) - n, err := res.Body.Read(buf) - if n != len(want) || err != io.EOF { - t.Logf("body = %#v", res.Body) - t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want)) - } - if string(buf) != want { - t.Errorf("buf = %q; want %q", buf, want) - } -} - -func TestDetectInMemoryReaders(t *testing.T) { - pr, _ := io.Pipe() - tests := []struct { - r io.Reader - want bool - }{ - {pr, false}, - - {bytes.NewReader(nil), true}, - {bytes.NewBuffer(nil), true}, - {strings.NewReader(""), true}, - - {io.NopCloser(pr), false}, - - {io.NopCloser(bytes.NewReader(nil)), true}, - {io.NopCloser(bytes.NewBuffer(nil)), true}, - {io.NopCloser(strings.NewReader("")), true}, - } - for i, tt := range tests { - got := isKnownInMemoryReader(tt.r) - if got != tt.want { - t.Errorf("%d: got = %v; want %v", i, got, tt.want) - } - } -} - -type mockTransferWriter struct { - CalledReader io.Reader - WriteCalled bool -} - -var _ io.ReaderFrom = (*mockTransferWriter)(nil) - -func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) { - w.CalledReader = r - return io.Copy(io.Discard, r) -} - -func (w *mockTransferWriter) Write(p []byte) (int, error) { - w.WriteCalled = true - return io.Discard.Write(p) -} - -func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { - fileType := reflect.TypeOf(&os.File{}) - bufferType := reflect.TypeOf(&bytes.Buffer{}) - - nBytes := int64(1 << 10) - newFileFunc := func() (r io.Reader, done func(), err error) { - f, err := os.CreateTemp("", "net-http-newfilefunc") - if err != nil { - return nil, nil, err - } - - // Write some bytes to the file to enable reading. - if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { - return nil, nil, fmt.Errorf("failed to write data to file: %v", err) - } - if _, err := f.Seek(0, 0); err != nil { - return nil, nil, fmt.Errorf("failed to seek to front: %v", err) - } - - done = func() { - f.Close() - os.Remove(f.Name()) - } - - return f, done, nil - } - - newBufferFunc := func() (io.Reader, func(), error) { - return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil - } - - cases := []struct { - name string - bodyFunc func() (io.Reader, func(), error) - method string - contentLength int64 - transferEncoding []string - limitedReader bool - expectedReader reflect.Type - expectedWrite bool - }{ - { - name: "file, non-chunked, size set", - bodyFunc: newFileFunc, - method: "PUT", - contentLength: nBytes, - limitedReader: true, - expectedReader: fileType, - }, - { - name: "file, non-chunked, size set, nopCloser wrapped", - method: "PUT", - bodyFunc: func() (io.Reader, func(), error) { - r, cleanup, err := newFileFunc() - return io.NopCloser(r), cleanup, err - }, - contentLength: nBytes, - limitedReader: true, - expectedReader: fileType, - }, - { - name: "file, non-chunked, negative size", - method: "PUT", - bodyFunc: newFileFunc, - contentLength: -1, - expectedReader: fileType, - }, - { - name: "file, non-chunked, CONNECT, negative size", - method: "CONNECT", - bodyFunc: newFileFunc, - contentLength: -1, - expectedReader: fileType, - }, - { - name: "file, chunked", - method: "PUT", - bodyFunc: newFileFunc, - transferEncoding: []string{"chunked"}, - expectedWrite: true, - }, - { - name: "buffer, non-chunked, size set", - bodyFunc: newBufferFunc, - method: "PUT", - contentLength: nBytes, - limitedReader: true, - expectedReader: bufferType, - }, - { - name: "buffer, non-chunked, size set, nopCloser wrapped", - method: "PUT", - bodyFunc: func() (io.Reader, func(), error) { - r, cleanup, err := newBufferFunc() - return io.NopCloser(r), cleanup, err - }, - contentLength: nBytes, - limitedReader: true, - expectedReader: bufferType, - }, - { - name: "buffer, non-chunked, negative size", - method: "PUT", - bodyFunc: newBufferFunc, - contentLength: -1, - expectedWrite: true, - }, - { - name: "buffer, non-chunked, CONNECT, negative size", - method: "CONNECT", - bodyFunc: newBufferFunc, - contentLength: -1, - expectedWrite: true, - }, - { - name: "buffer, chunked", - method: "PUT", - bodyFunc: newBufferFunc, - transferEncoding: []string{"chunked"}, - expectedWrite: true, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - body, cleanup, err := tc.bodyFunc() - if err != nil { - t.Fatal(err) - } - defer cleanup() - - mw := &mockTransferWriter{} - tw := &transferWriter{ - Body: body, - ContentLength: tc.contentLength, - TransferEncoding: tc.transferEncoding, - } - - if err := tw.writeBody(mw, nil); err != nil { - t.Fatal(err) - } - - if tc.expectedReader != nil { - if mw.CalledReader == nil { - t.Fatal("did not call ReadFrom") - } - - var actualReader reflect.Type - lr, ok := mw.CalledReader.(*io.LimitedReader) - if ok && tc.limitedReader { - actualReader = reflect.TypeOf(lr.R) - } else { - actualReader = reflect.TypeOf(mw.CalledReader) - } - - if tc.expectedReader != actualReader { - t.Fatalf("got reader %s want %s", actualReader, tc.expectedReader) - } - } - - if tc.expectedWrite && !mw.WriteCalled { - t.Fatal("did not invoke Write") - } - }) - } -} - -func TestParseTransferEncoding(t *testing.T) { - tests := []struct { - hdr http.Header - wantErr error - }{ - { - hdr: http.Header{"Transfer-Encoding": {"fugazi"}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}}, - wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked, chunked" "identity" "chunked"]`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {""}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: ""`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked, identity"}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: "chunked, identity"`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked", "identity"}}, - wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked" "identity"]`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"\x0bchunked"}}, - wantErr: &unsupportedTEError{`unsupported transfer encoding: "\vchunked"`}, - }, - { - hdr: http.Header{"Transfer-Encoding": {"chunked"}}, - wantErr: nil, - }, - } - - for i, tt := range tests { - tr := &transferReader{ - Header: tt.hdr, - ProtoMajor: 1, - ProtoMinor: 1, - } - gotErr := tr.parseTransferEncoding() - if !reflect.DeepEqual(gotErr, tt.wantErr) { - t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr) - } - } -} - -// issue 39017 - disallow Content-Length values such as "+3" -func TestParseContentLength(t *testing.T) { - tests := []struct { - cl string - wantErr error - }{ - { - cl: "3", - wantErr: nil, - }, - { - cl: "+3", - wantErr: badStringError("bad Content-Length", "+3"), - }, - { - cl: "-3", - wantErr: badStringError("bad Content-Length", "-3"), - }, - { - // max int64, for safe conversion before returning - cl: "9223372036854775807", - wantErr: nil, - }, - { - cl: "9223372036854775808", - wantErr: badStringError("bad Content-Length", "9223372036854775808"), - }, - } - - for _, tt := range tests { - if _, gotErr := parseContentLength(tt.cl); !reflect.DeepEqual(gotErr, tt.wantErr) { - t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr) - } - } -} diff --git a/transport_internal_test.go b/transport_internal_test.go deleted file mode 100644 index 91bea4cd..00000000 --- a/transport_internal_test.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// White-box tests for transport.go (in package http instead of http_test). - -package req - -import ( - "context" - "errors" - "github.com/imroc/req/v3/internal/http2" - "github.com/imroc/req/v3/internal/tests" - "net" - "net/http" - "strings" - "testing" -) - -func withT(r *http.Request, t *testing.T) *http.Request { - return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) -} - -// Issue 15446: incorrect wrapping of errors when server closes an idle connection. -func TestTransportPersistConnReadLoopEOF(t *testing.T) { - ln := tests.NewLocalListener(t) - defer ln.Close() - - connc := make(chan net.Conn, 1) - go func() { - defer close(connc) - c, err := ln.Accept() - if err != nil { - t.Error(err) - return - } - connc <- c - }() - - tr := new(Transport) - req, _ := http.NewRequest("GET", "http://"+ln.Addr().String(), nil) - req = withT(req, t) - treq := &transportRequest{Request: req} - cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} - pc, err := tr.getConn(treq, cm) - if err != nil { - t.Fatal(err) - } - defer pc.close(errors.New("test over")) - - conn := <-connc - if conn == nil { - // Already called t.Error in the accept goroutine. - return - } - conn.Close() // simulate the server hanging up on the client - - _, err = pc.roundTrip(treq) - if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err) - } - - <-pc.closech - err = pc.closed - if !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) - } -} - -func isNothingWrittenError(err error) bool { - _, ok := err.(nothingWrittenError) - return ok -} - -func isTransportReadFromServerError(err error) bool { - _, ok := err.(transportReadFromServerError) - return ok -} - -func dummyRequest(method string) *http.Request { - req, err := http.NewRequest(method, "http://fake.tld/", nil) - if err != nil { - panic(err) - } - return req -} -func dummyRequestWithBody(method string) *http.Request { - req, err := http.NewRequest(method, "http://fake.tld/", strings.NewReader("foo")) - if err != nil { - panic(err) - } - return req -} - -func dummyRequestWithBodyNoGetBody(method string) *http.Request { - req := dummyRequestWithBody(method) - req.GetBody = nil - return req -} - -// issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn. -type issue22091Error struct{} - -func (issue22091Error) IsHTTP2NoCachedConnError() {} -func (issue22091Error) Error() string { return "issue22091Error" } - -func TestTransportShouldRetryRequest(t *testing.T) { - tests := []struct { - pc *persistConn - req *http.Request - - err error - want bool - }{ - 0: { - pc: &persistConn{reused: false}, - req: dummyRequest("POST"), - err: nothingWrittenError{}, - want: false, - }, - 1: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: nothingWrittenError{}, - want: true, - }, - 2: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: http2.ErrNoCachedConn, - want: true, - }, - 3: { - pc: nil, - req: nil, - err: issue22091Error{}, // like an external http2ErrNoCachedConn - want: true, - }, - 4: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: errMissingHost, - want: false, - }, - 5: { - pc: &persistConn{reused: true}, - req: dummyRequest("POST"), - err: transportReadFromServerError{}, - want: false, - }, - 6: { - pc: &persistConn{reused: true}, - req: dummyRequest("GET"), - err: transportReadFromServerError{}, - want: true, - }, - 7: { - pc: &persistConn{reused: true}, - req: dummyRequest("GET"), - err: errServerClosedIdle, - want: true, - }, - 8: { - pc: &persistConn{reused: true}, - req: dummyRequestWithBody("POST"), - err: nothingWrittenError{}, - want: true, - }, - 9: { - pc: &persistConn{reused: true}, - req: dummyRequestWithBodyNoGetBody("POST"), - err: nothingWrittenError{}, - want: false, - }, - } - for i, tt := range tests { - got := tt.pc.shouldRetryRequest(tt.req, tt.err) - if got != tt.want { - t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want) - } - } -} - -type roundTripFunc func(r *http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { - return f(r) -} diff --git a/transport_test.go b/transport_test.go deleted file mode 100644 index 1858540b..00000000 --- a/transport_test.go +++ /dev/null @@ -1,5996 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Tests for transport.go. -// -// More tests are in clientserver_test.go (for things testing both client & server for both -// HTTP/1 and HTTP/2). This - -package req - -import ( - "bufio" - "bytes" - "compress/gzip" - "context" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "encoding/binary" - "errors" - "fmt" - "github.com/imroc/req/v3/internal/common" - "github.com/imroc/req/v3/internal/tests" - "github.com/imroc/req/v3/internal/transport" - "go/token" - "golang.org/x/net/http/httpproxy" - nethttp2 "golang.org/x/net/http2" - "io" - "log" - mrand "math/rand" - "net" - "net/http" - "net/http/httptest" - "net/http/httptrace" - "net/http/httputil" - "net/textproto" - "net/url" - "os" - "reflect" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "testing" - "testing/iotest" - "time" - - "golang.org/x/net/http/httpguts" -) - -func (t *Transport) NumPendingRequestsForTesting() int { - t.reqMu.Lock() - defer t.reqMu.Unlock() - return len(t.reqCanceler) -} - -func (t *Transport) IdleConnKeysForTesting() (keys []string) { - keys = make([]string, 0) - t.idleMu.Lock() - defer t.idleMu.Unlock() - for key := range t.idleConn { - keys = append(keys, key.String()) - } - sort.Strings(keys) - return -} - -func (t *Transport) IdleConnKeyCountForTesting() int { - t.idleMu.Lock() - defer t.idleMu.Unlock() - return len(t.idleConn) -} - -func (t *Transport) IdleConnStrsForTesting() []string { - var ret []string - t.idleMu.Lock() - defer t.idleMu.Unlock() - for _, conns := range t.idleConn { - for _, pc := range conns { - ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String()) - } - } - sort.Strings(ret) - return ret -} - -func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { - t.idleMu.Lock() - defer t.idleMu.Unlock() - key := connectMethodKey{"", scheme, addr, false} - cacheKey := key.String() - for k, conns := range t.idleConn { - if k.String() == cacheKey { - return len(conns) - } - } - return 0 -} - -func (t *Transport) IdleConnWaitMapSizeForTesting() int { - t.idleMu.Lock() - defer t.idleMu.Unlock() - return len(t.idleConnWait) -} - -func (t *Transport) IsIdleForTesting() bool { - t.idleMu.Lock() - defer t.idleMu.Unlock() - return t.closeIdle -} - -func (t *Transport) QueueForIdleConnForTesting() { - t.queueForIdleConn(nil) -} - -// PutIdleTestConn reports whether it was able to insert a fresh -// persistConn for scheme, addr into the idle connection pool. -func (t *Transport) PutIdleTestConn(scheme, addr string) bool { - c, _ := net.Pipe() - key := connectMethodKey{"", scheme, addr, false} - - if t.MaxConnsPerHost > 0 { - // Transport is tracking conns-per-host. - // Increment connection count to account - // for new persistConn created below. - t.connsPerHostMu.Lock() - if t.connsPerHost == nil { - t.connsPerHost = make(map[connectMethodKey]int) - } - t.connsPerHost[key]++ - t.connsPerHostMu.Unlock() - } - - return t.tryPutIdleConn(&persistConn{ - t: t, - conn: c, // dummy - closech: make(chan struct{}), // so it can be closed - cacheKey: key, - }) == nil -} - -// PutIdleTestConnH2 reports whether it was able to insert a fresh -// HTTP/2 persistConn for scheme, addr into the idle connection pool. -func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt http.RoundTripper) bool { - key := connectMethodKey{"", scheme, addr, false} - - if t.MaxConnsPerHost > 0 { - // Transport is tracking conns-per-host. - // Increment connection count to account - // for new persistConn created below. - t.connsPerHostMu.Lock() - if t.connsPerHost == nil { - t.connsPerHost = make(map[connectMethodKey]int) - } - t.connsPerHost[key]++ - t.connsPerHostMu.Unlock() - } - - return t.tryPutIdleConn(&persistConn{ - t: t, - alt: alt, - cacheKey: key, - }) == nil -} - -// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close -// and then verify that the final 2 responses get errors back. - -// hostPortHandler writes back the client's "host:port". -var hostPortHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.FormValue("close") == "true" { - w.Header().Set("Connection", "close") - } - w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close)) - w.Write([]byte(r.RemoteAddr)) -}) - -// testCloseConn is a net.Conn tracked by a testConnSet. -type testCloseConn struct { - net.Conn - set *testConnSet -} - -func (c *testCloseConn) Close() error { - c.set.remove(c) - return c.Conn.Close() -} - -// testConnSet tracks a set of TCP connections and whether they've -// been closed. -type testConnSet struct { - t *testing.T - mu sync.Mutex // guards closed and list - closed map[net.Conn]bool - list []net.Conn // in order created -} - -func (tcs *testConnSet) insert(c net.Conn) { - tcs.mu.Lock() - defer tcs.mu.Unlock() - tcs.closed[c] = false - tcs.list = append(tcs.list, c) -} - -func (tcs *testConnSet) remove(c net.Conn) { - tcs.mu.Lock() - defer tcs.mu.Unlock() - tcs.closed[c] = true -} - -// some tests use this to manage raw tcp connections for later inspection -func makeTestDial(t *testing.T) (*testConnSet, func(ctx context.Context, n, addr string) (net.Conn, error)) { - connSet := &testConnSet{ - t: t, - closed: make(map[net.Conn]bool), - } - dial := func(_ context.Context, n, addr string) (net.Conn, error) { - c, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - tc := &testCloseConn{c, connSet} - connSet.insert(tc) - return tc, nil - } - return connSet, dial -} - -func (tcs *testConnSet) check(t *testing.T) { - tcs.mu.Lock() - defer tcs.mu.Unlock() - for i := 4; i >= 0; i-- { - for i, c := range tcs.list { - if tcs.closed[c] { - continue - } - if i != 0 { - tcs.mu.Unlock() - time.Sleep(50 * time.Millisecond) - tcs.mu.Lock() - continue - } - t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) - } - } -} - -func TestReuseRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("{}")) - })) - defer ts.Close() - - c := tc().httpClient - req, _ := http.NewRequest("GET", ts.URL, nil) - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - err = res.Body.Close() - if err != nil { - t.Fatal(err) - } - - res, err = c.Do(req) - if err != nil { - t.Fatal(err) - } - err = res.Body.Close() - if err != nil { - t.Fatal(err) - } -} - -// Two subsequent requests and verify their response is the same. -// The response from the server is our own IP:port -func TestTransportKeepAlives(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - c := tc().httpClient - for _, disableKeepAlive := range []bool{false, true} { - c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive - fetch := func(n int) string { - res, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) - } - return string(body) - } - - body1 := fetch(1) - body2 := fetch(2) - - bodiesDiffer := body1 != body2 - if bodiesDiffer != disableKeepAlive { - t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", - disableKeepAlive, bodiesDiffer, body1, body2) - } - } -} - -func interestingGoroutines() (gs []string) { - buf := make([]byte, 2<<20) - buf = buf[:runtime.Stack(buf, true)] - for _, g := range strings.Split(string(buf), "\n\n") { - sl := strings.SplitN(g, "\n", 2) - if len(sl) != 2 { - continue - } - stack := strings.TrimSpace(sl[1]) - if stack == "" || - strings.Contains(stack, "testing.(*M).before.func1") || - strings.Contains(stack, "os/signal.signal_recv") || - strings.Contains(stack, "created by net.startServer") || - strings.Contains(stack, "created by testing.RunTests") || - strings.Contains(stack, "closeWriteAndWait") || - strings.Contains(stack, "testing.Main(") || - // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) - strings.Contains(stack, "runtime.goexit") || - strings.Contains(stack, "created by runtime.gc") || - strings.Contains(stack, "net/http_test.interestingGoroutines") || - strings.Contains(stack, "runtime.MHeap_Scavenger") { - continue - } - gs = append(gs, stack) - } - sort.Strings(gs) - return -} - -func afterTest(t testing.TB) { - http.DefaultTransport.(*http.Transport).CloseIdleConnections() - if testing.Short() { - return - } - // var bad string - // badSubstring := map[string]string{ - // ").readLoop(": "a Transport", - // ").writeLoop(": "a Transport", - // "created by net/http/httptest.(*Server).Start": "an httptest.Server", - // "timeoutHandler": "a TimeoutHandler", - // "net.(*netFD).connect(": "a timing out dial", - // ").noteClientGone(": "a closenotifier sender", - // } - // var stacks string - // for i := 0; i < 10; i++ { - // bad = "" - // stacks = strings.Join(interestingGoroutines(), "\n\n") - // for substr, what := range badSubstring { - // if strings.Contains(stacks, substr) { - // bad = what - // } - // } - // if bad == "" { - // return - // } - // // Bad stuff found, but goroutines might just still be - // // shutting down, so give it some time. - // time.Sleep(250 * time.Millisecond) - // } - // t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) -} - -func TestTransportConnectionCloseOnResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - connSet, testDial := makeTestDial(t) - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.DialContext = testDial - - for _, connectionClose := range []bool{false, true} { - fetch := func(n int) string { - req := new(http.Request) - var err error - req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) - if err != nil { - t.Fatalf("URL parse error: %v", err) - } - req.Method = "GET" - req.Proto = "HTTP/1.1" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - - res, err := c.Do(req) - if err != nil { - t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) - } - defer res.Body.Close() - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) - } - return string(body) - } - - body1 := fetch(1) - body2 := fetch(2) - bodiesDiffer := body1 != body2 - if bodiesDiffer != connectionClose { - t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", - connectionClose, bodiesDiffer, body1, body2) - } - - tr.CloseIdleConnections() - } - - connSet.check(t) -} - -// TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse -// an underlying TCP connection after making an http.Request with Request.Close set. -// -// It tests the behavior by making an HTTP request to a server which -// describes the source source connection it got (remote port number + -// address of its net.Conn) -func TestTransportConnectionCloseOnRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - connSet, testDial := makeTestDial(t) - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.DialContext = testDial - for _, reqClose := range []bool{false, true} { - fetch := func(n int) string { - req := new(http.Request) - var err error - req.URL, err = url.Parse(ts.URL) - if err != nil { - t.Fatalf("URL parse error: %v", err) - } - req.Method = "GET" - req.Proto = "HTTP/1.1" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - req.Close = reqClose - - res, err := c.Do(req) - if err != nil { - t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err) - } - if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want { - t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v", - reqClose, got, !reqClose) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err) - } - return string(body) - } - - body1 := fetch(1) - body2 := fetch(2) - - got := 1 - if body1 != body2 { - got++ - } - want := 1 - if reqClose { - want = 2 - } - if got != want { - t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q", - reqClose, got, want, body1, body2) - } - - tr.CloseIdleConnections() - } - - connSet.check(t) -} - -// if the Transport's DisableKeepAlives is set, all requests should -// send Connection: close. -// HTTP/1-only (Connection: close doesn't exist in h2) -func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).DisableKeepAlives = true - - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - if res.Header.Get("X-Saw-Close") != "true" { - t.Errorf("handler didn't see Connection: close ") - } -} - -// Test that Transport only sends one "Connection: close", regardless of -// how "close" was indicated. -func TestTransportRespectRequestWantsClose(t *testing.T) { - tests := []struct { - disableKeepAlives bool - close bool - }{ - {disableKeepAlives: false, close: false}, - {disableKeepAlives: false, close: true}, - {disableKeepAlives: true, close: false}, - {disableKeepAlives: true, close: true}, - } - - for _, testCase := range tests { - t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", testCase.disableKeepAlives, testCase.close), - func(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).DisableKeepAlives = testCase.disableKeepAlives - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - count := 0 - trace := &httptrace.ClientTrace{ - WroteHeaderField: func(key string, field []string) { - if key != "Connection" { - return - } - if httpguts.HeaderValuesContainsToken(field, "close") { - count += 1 - } - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - req.Close = testCase.close - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - if want := testCase.disableKeepAlives || testCase.close; count > 1 || (count == 1) != want { - t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) - } - }) - } - -} - -func TestTransportIdleCacheKeys(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - - if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { - t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) - } - - resp, err := c.Get(ts.URL) - if err != nil { - t.Error(err) - } - io.ReadAll(resp.Body) - - keys := tr.IdleConnKeysForTesting() - if e, g := 1, len(keys); e != g { - t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) - } - - if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { - t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) - } - - tr.CloseIdleConnections() - if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { - t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) - } -} - -// Tests that the HTTP transport re-uses connections when a client -// reads to the end of a response Body without closing it. -func TestTransportReadToEndReusesConn(t *testing.T) { - defer afterTest(t) - const msg = "foobar" - - var addrSeen map[string]int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - addrSeen[r.RemoteAddr]++ - if r.URL.Path == "/chunked/" { - w.WriteHeader(200) - w.(http.Flusher).Flush() - } else { - w.Header().Set("Content-Length", strconv.Itoa(len(msg))) - w.WriteHeader(200) - } - w.Write([]byte(msg)) - })) - defer ts.Close() - - buf := make([]byte, len(msg)) - - for pi, path := range []string{"/content-length/", "/chunked/"} { - wantLen := []int{len(msg), -1}[pi] - addrSeen = make(map[string]int) - for i := 0; i < 3; i++ { - res, err := http.Get(ts.URL + path) - if err != nil { - t.Errorf("Get %s: %v", path, err) - continue - } - // We want to close this body eventually (before the - // defer afterTest at top runs), but not before the - // len(addrSeen) check at the bottom of this test, - // since Closing this early in the loop would risk - // making connections be re-used for the wrong reason. - defer res.Body.Close() - - if res.ContentLength != int64(wantLen) { - t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) - } - n, err := res.Body.Read(buf) - if n != len(msg) || err != io.EOF { - t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) - } - } - if len(addrSeen) != 1 { - t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) - } - } -} - -func TestTransportMaxPerHostIdleConns(t *testing.T) { - defer afterTest(t) - stop := make(chan struct{}) // stop marks the exit of main Test goroutine - defer close(stop) - - resch := make(chan string) - gotReq := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotReq <- true - var msg string - select { - case <-stop: - return - case msg = <-resch: - } - _, err := w.Write([]byte(msg)) - if err != nil { - t.Errorf("Write: %v", err) - return - } - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - maxIdleConnsPerHost := 2 - tr.MaxIdleConnsPerHost = maxIdleConnsPerHost - - // Start 3 outstanding requests and wait for the server to get them. - // Their responses will hang until we write to resch, though. - donech := make(chan bool) - doReq := func() { - defer func() { - select { - case <-stop: - return - case donech <- t.Failed(): - } - }() - resp, err := c.Get(ts.URL) - if err != nil { - t.Error(err) - return - } - if _, err := io.ReadAll(resp.Body); err != nil { - t.Errorf("ReadAll: %v", err) - return - } - } - go doReq() - <-gotReq - go doReq() - <-gotReq - go doReq() - <-gotReq - - if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { - t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) - } - - resch <- "res1" - <-donech - keys := tr.IdleConnKeysForTesting() - if e, g := 1, len(keys); e != g { - t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) - } - addr := ts.Listener.Addr().String() - cacheKey := "|http|" + addr - if keys[0] != cacheKey { - t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) - } - if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { - t.Errorf("after first response, expected %d idle conns; got %d", e, g) - } - - resch <- "res2" - <-donech - if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { - t.Errorf("after second response, idle conns = %d; want %d", g, w) - } - - resch <- "res3" - <-donech - if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { - t.Errorf("after third response, idle conns = %d; want %d", g, w) - } -} - -func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("foo")) - if err != nil { - t.Fatalf("Write: %v", err) - } - })) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - dialStarted := make(chan struct{}) - stallDial := make(chan struct{}) - tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { - dialStarted <- struct{}{} - <-stallDial - return net.Dial(network, addr) - } - - tr.DisableKeepAlives = true - tr.MaxConnsPerHost = 1 - - preDial := make(chan struct{}) - reqComplete := make(chan struct{}) - doReq := func(reqId string) { - req, _ := http.NewRequest("GET", ts.URL, nil) - trace := &httptrace.ClientTrace{ - GetConn: func(hostPort string) { - preDial <- struct{}{} - }, - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - resp, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("unexpected error for request %s: %v", reqId, err) - } - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Errorf("unexpected error for request %s: %v", reqId, err) - } - reqComplete <- struct{}{} - } - // get req1 to dial-in-progress - go doReq("req1") - <-preDial - <-dialStarted - - // get req2 to waiting on conns per host to go down below max - go doReq("req2") - <-preDial - select { - case <-dialStarted: - t.Error("req2 dial started while req1 dial in progress") - return - default: - } - - // let req1 complete - stallDial <- struct{}{} - <-reqComplete - - // let req2 complete - <-dialStarted - stallDial <- struct{}{} - <-reqComplete -} - -func TestTransportMaxConnsPerHost(t *testing.T) { - defer afterTest(t) - - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("foo")) - if err != nil { - t.Fatalf("Write: %v", err) - } - }) - - testMaxConns := func(scheme string, ts *httptest.Server) { - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.MaxConnsPerHost = 1 - - mu := sync.Mutex{} - var conns []net.Conn - var dialCnt, gotConnCnt, tlsHandshakeCnt int32 - tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { - atomic.AddInt32(&dialCnt, 1) - c, err := net.Dial(network, addr) - mu.Lock() - defer mu.Unlock() - conns = append(conns, c) - return c, err - } - - doReq := func() { - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConnCnt, 1) - } - }, - TLSHandshakeStart: func() { - atomic.AddInt32(&tlsHandshakeCnt, 1) - }, - } - req, _ := http.NewRequest("GET", ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - - resp, err := c.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read body failed: %v", err) - } - } - - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doReq() - }() - } - wg.Wait() - - expected := int32(tr.MaxConnsPerHost) - if dialCnt != expected { - t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) - } - if gotConnCnt != expected { - t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) - } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) - } - - if t.Failed() { - t.FailNow() - } - - mu.Lock() - for _, c := range conns { - c.Close() - } - conns = nil - mu.Unlock() - tr.CloseIdleConnections() - - doReq() - expected++ - if dialCnt != expected { - t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) - } - if gotConnCnt != expected { - t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) - } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) - } - } - - testMaxConns("http", httptest.NewServer(h)) - testMaxConns("https", httptest.NewTLSServer(h)) - - ts := httptest.NewUnstartedServer(h) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - testMaxConns("http2", ts) -} - -func TestTransportRemovesDeadIdleConnections(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - - doReq := func(name string) string { - // Do a POST instead of a GET to prevent the Transport's - // idempotent request retry logic from kicking in... - res, err := c.Post(ts.URL, "", nil) - if err != nil { - t.Fatalf("%s: %v", name, err) - } - if res.StatusCode != 200 { - t.Fatalf("%s: %v", name, res.Status) - } - defer res.Body.Close() - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("%s: %v", name, err) - } - return string(slurp) - } - - first := doReq("first") - keys1 := tr.IdleConnKeysForTesting() - - ts.CloseClientConnections() - - var keys2 []string - if !tests.WaitCondition(3*time.Second, 50*time.Millisecond, func() bool { - keys2 = tr.IdleConnKeysForTesting() - return len(keys2) == 0 - }) { - t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) - } - - second := doReq("second") - if first == second { - t.Errorf("expected a different connection between requests. got %q both times", first) - } -} - -// ExportCloseTransportConnsAbruptly closes all idle connections from -// tr in an abrupt way, just reaching into the underlying Conns and -// closing them, without telling the Transport or its persistConns -// that it's doing so. This is to simulate the server closing connections -// on the Transport. -func ExportCloseTransportConnsAbruptly(tr *Transport) { - tr.idleMu.Lock() - for _, pcs := range tr.idleConn { - for _, pc := range pcs { - pc.conn.Close() - } - } - tr.idleMu.Unlock() -} - -// Test that the Transport notices when a server hangs up on its -// unexpectedly (a keep-alive connection is closed). -func TestTransportServerClosingUnexpectedly(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() - c := tc().httpClient - - fetch := func(n, retries int) string { - condFatalf := func(format string, arg ...interface{}) { - if retries <= 0 { - t.Fatalf(format, arg...) - } - t.Logf("retrying shortly after expected error: "+format, arg...) - time.Sleep(time.Second / time.Duration(retries)) - } - for retries >= 0 { - retries-- - res, err := c.Get(ts.URL) - if err != nil { - condFatalf("error in req #%d, GET: %v", n, err) - continue - } - body, err := io.ReadAll(res.Body) - if err != nil { - condFatalf("error in req #%d, ReadAll: %v", n, err) - continue - } - res.Body.Close() - return string(body) - } - panic("unreachable") - } - - body1 := fetch(1, 0) - body2 := fetch(2, 0) - - // Close all the idle connections in a way that's similar to - // the server hanging up on us. We don't use - // httptest.Server.CloseClientConnections because it's - // best-effort and stops blocking after 5 seconds. On a loaded - // machine running many tests concurrently it's possible for - // that method to be async and cause the body3 fetch below to - // run on an old connection. This function is synchronous. - ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) - - body3 := fetch(3, 5) - - if body1 != body2 { - t.Errorf("expected body1 and body2 to be equal") - } - if body2 == body3 { - t.Errorf("expected body2 and body3 to be different") - } -} - -// Test for https://golang.org/issue/2616 (appropriate issue number) -// This fails pretty reliably with GOMAXPROCS=100 or something high. -func TestStressSurpriseServerCloses(t *testing.T) { - defer afterTest(t) - if testing.Short() { - t.Skip("skipping test in short mode") - } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "5") - w.Header().Set("Content-Type", "text/plain") - w.Write([]byte("Hello")) - w.(http.Flusher).Flush() - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Flush() - conn.Close() - })) - defer ts.Close() - c := tc().httpClient - - // Do a bunch of traffic from different goroutines. Send to activityc - // after each request completes, regardless of whether it failed. - // If these are too high, OS X exhausts its ephemeral ports - // and hangs waiting for them to transition TCP states. That's - // not what we want to test. TODO(bradfitz): use an io.Pipe - // dialer for this test instead? - const ( - numClients = 20 - reqsPerClient = 25 - ) - activityc := make(chan bool) - for i := 0; i < numClients; i++ { - go func() { - for i := 0; i < reqsPerClient; i++ { - res, err := c.Get(ts.URL) - if err == nil { - // We expect errors since the server is - // hanging up on us after telling us to - // send more requests, so we don't - // actually care what the error is. - // But we want to close the body in cases - // where we won the race. - res.Body.Close() - } - if !<-activityc { // Receives false when close(activityc) is executed - return - } - } - }() - } - - // Make sure all the request come back, one way or another. - for i := 0; i < numClients*reqsPerClient; i++ { - select { - case activityc <- true: - case <-time.After(5 * time.Second): - close(activityc) - t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") - } - } -} - -// TestTransportHeadResponses verifies that we deal with Content-Lengths -// with no bodies properly -func TestTransportHeadResponses(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "HEAD" { - panic("expected HEAD; got " + r.Method) - } - w.Header().Set("Content-Length", "123") - w.WriteHeader(200) - })) - defer ts.Close() - c := tc().httpClient - - for i := 0; i < 2; i++ { - res, err := c.Head(ts.URL) - if err != nil { - t.Errorf("error on loop %d: %v", i, err) - continue - } - if e, g := "123", res.Header.Get("Content-Length"); e != g { - t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) - } - if e, g := int64(123), res.ContentLength; e != g { - t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) - } - if all, err := io.ReadAll(res.Body); err != nil { - t.Errorf("loop %d: Body ReadAll: %v", i, err) - } else if len(all) != 0 { - t.Errorf("Bogus body %q", all) - } - } -} - -// All test hooks must be non-nil so they can be called directly, -// but the tests use nil to mean hook disabled. -func unnilTestHook(f *func()) { - if *f == nil { - *f = nop - } -} - -func SetReadLoopBeforeNextReadHook(f func()) { - testHookMu.Lock() - defer testHookMu.Unlock() - unnilTestHook(&f) - testHookReadLoopBeforeNextRead = f -} - -// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding -// on responses to HEAD requests. -func TestTransportHeadChunkedResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "HEAD" { - panic("expected HEAD; got " + r.Method) - } - w.Header().Set("Transfer-Encoding", "chunked") // client should ignore - w.Header().Set("x-client-ipport", r.RemoteAddr) - w.WriteHeader(200) - })) - defer ts.Close() - c := tc().httpClient - - // Ensure that we wait for the readLoop to complete before - // calling Head again - didRead := make(chan bool) - SetReadLoopBeforeNextReadHook(func() { didRead <- true }) - defer SetReadLoopBeforeNextReadHook(nil) - - res1, err := c.Head(ts.URL) - <-didRead - - if err != nil { - t.Fatalf("request 1 error: %v", err) - } - - res2, err := c.Head(ts.URL) - <-didRead - - if err != nil { - t.Fatalf("request 2 error: %v", err) - } - if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { - t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) - } -} - -var roundTripTests = []struct { - accept string - expectAccept string - compressed bool -}{ - // Requests with no accept-encoding header use transparent compression - {"", "gzip", false}, - // Requests with other accept-encoding should pass through unmodified - {"foo", "foo", false}, - // Requests with accept-encoding == gzip should be passed through - {"gzip", "gzip", true}, -} - -// Test that the modification made to the Request by the http.RoundTripper is cleaned up -func TestRoundTripGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) - const responseBody = "test response body" - ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - accept := req.Header.Get("Accept-Encoding") - if expect := req.FormValue("expect_accept"); accept != expect { - t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", - req.FormValue("testnum"), accept, expect) - } - if accept == "gzip" { - rw.Header().Set("Content-Encoding", "gzip") - gz := gzip.NewWriter(rw) - gz.Write([]byte(responseBody)) - gz.Close() - } else { - rw.Header().Set("Content-Encoding", accept) - rw.Write([]byte(responseBody)) - } - })) - defer ts.Close() - tr := tc().GetTransport() - - for i, test := range roundTripTests { - // Test basic request (no accept-encoding) - req, _ := http.NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) - if test.accept != "" { - req.Header.Set("Accept-Encoding", test.accept) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("%d. RoundTrip: %v", i, err) - continue - } - var body []byte - if test.compressed { - var r *gzip.Reader - r, err = gzip.NewReader(res.Body) - if err != nil { - t.Errorf("%d. gzip NewReader: %v", i, err) - continue - } - body, err = io.ReadAll(r) - res.Body.Close() - } else { - body, err = io.ReadAll(res.Body) - } - if err != nil { - t.Errorf("%d. Error: %q", i, err) - continue - } - if g, e := string(body), responseBody; g != e { - t.Errorf("%d. body = %q; want %q", i, g, e) - } - if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { - t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) - } - if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { - t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) - } - } - -} - -func TestTransportGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) - const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - const nRandBytes = 1024 * 1024 - ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - if req.Method == "HEAD" { - if g := req.Header.Get("Accept-Encoding"); g != "" { - t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) - } - return - } - if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { - t.Errorf("Accept-Encoding = %q, want %q", g, e) - } - rw.Header().Set("Content-Encoding", "gzip") - - var w io.Writer = rw - var buf bytes.Buffer - if req.FormValue("chunked") == "0" { - w = &buf - defer io.Copy(rw, &buf) - defer func() { - rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) - }() - } - gz := gzip.NewWriter(w) - gz.Write([]byte(testString)) - if req.FormValue("body") == "large" { - io.CopyN(gz, rand.Reader, nRandBytes) - } - gz.Close() - })) - defer ts.Close() - c := tc().httpClient - - for _, chunked := range []string{"1", "0"} { - // First fetch something large, but only read some of it. - res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) - if err != nil { - t.Fatalf("large get: %v", err) - } - buf := make([]byte, len(testString)) - n, err := io.ReadFull(res.Body, buf) - if err != nil { - t.Fatalf("partial read of large response: size=%d, %v", n, err) - } - if e, g := testString, string(buf); e != g { - t.Errorf("partial read got %q, expected %q", g, e) - } - res.Body.Close() - // Read on the body, even though it's closed - n, err = res.Body.Read(buf) - if n != 0 || err == nil { - t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) - } - - // Then something small. - res, err = c.Get(ts.URL + "/?chunked=" + chunked) - if err != nil { - t.Fatal(err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if g, e := string(body), testString; g != e { - t.Fatalf("body = %q; want %q", g, e) - } - if g, e := res.Header.Get("Content-Encoding"), ""; g != e { - t.Fatalf("Content-Encoding = %q; want %q", g, e) - } - - // Read on the body after it's been fully read: - n, err = res.Body.Read(buf) - if n != 0 || err == nil { - t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) - } - res.Body.Close() - n, err = res.Body.Read(buf) - if n != 0 || err == nil { - t.Errorf("expected Read error after Close; got %d, %v", n, err) - } - } - - // And a HEAD request too, because they're always weird. - res, err := c.Head(ts.URL) - if err != nil { - t.Fatalf("Head: %v", err) - } - if res.StatusCode != 200 { - t.Errorf("Head status=%d; want=200", res.StatusCode) - } -} - -// setParallel marks t as a parallel test if we're in short mode -// (all.bash), but as a serial test otherwise. Using t.Parallel isn't -// compatible with the afterTest func in non-short mode. -func setParallel(t *testing.T) { - if testing.Short() { - t.Parallel() - } -} - -// If a request has Expect:100-continue header, the request blocks sending body until the first response. -// Premature consumption of the request body should not be occurred. -func TestTransportExpect100Continue(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - switch req.URL.Path { - case "/100": - // This endpoint implicitly responds 100 Continue and reads body. - if _, err := io.Copy(io.Discard, req.Body); err != nil { - t.Error("Failed to read Body", err) - } - rw.WriteHeader(http.StatusOK) - case "/200": - // Go 1.5 adds Connection: close header if the client expect - // continue but not entire request body is consumed. - rw.WriteHeader(http.StatusOK) - case "/500": - rw.WriteHeader(http.StatusInternalServerError) - case "/keepalive": - // This hijacked endpoint responds error without Connection:close. - _, bufrw, err := rw.(http.Hijacker).Hijack() - if err != nil { - log.Fatal(err) - } - bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") - bufrw.WriteString("Content-Length: 0\r\n\r\n") - bufrw.Flush() - case "/timeout": - // This endpoint tries to read body without 100 (Continue) response. - // After ExpectContinueTimeout, the reading will be started. - conn, bufrw, err := rw.(http.Hijacker).Hijack() - if err != nil { - log.Fatal(err) - } - if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { - t.Error("Failed to read Body", err) - } - bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") - bufrw.Flush() - conn.Close() - } - - })) - defer ts.Close() - - tests := []struct { - path string - body []byte - sent int - status int - }{ - {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. - {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. - {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. - {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. - {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. - } - - c := tc().httpClient - for i, v := range tests { - tr := T() - tr.ExpectContinueTimeout = 2 * time.Second - defer tr.CloseIdleConnections() - c.Transport = tr - body := bytes.NewReader(v.body) - req, err := http.NewRequest("PUT", ts.URL+v.path, body) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Expect", "100-continue") - req.ContentLength = int64(len(v.body)) - - resp, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - - sent := len(v.body) - body.Len() - if v.status != resp.StatusCode { - t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) - } - if v.sent != sent { - t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) - } - } -} - -func TestSOCKS5Proxy(t *testing.T) { - defer afterTest(t) - ch := make(chan string, 1) - l := tests.NewLocalListener(t) - defer l.Close() - defer close(ch) - proxy := func(t *testing.T) { - s, err := l.Accept() - if err != nil { - t.Errorf("socks5 proxy Accept(): %v", err) - return - } - defer s.Close() - var buf [22]byte - if _, err := io.ReadFull(s, buf[:3]); err != nil { - t.Errorf("socks5 proxy initial read: %v", err) - return - } - if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { - t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) - return - } - if _, err := s.Write([]byte{5, 0}); err != nil { - t.Errorf("socks5 proxy initial write: %v", err) - return - } - if _, err := io.ReadFull(s, buf[:4]); err != nil { - t.Errorf("socks5 proxy second read: %v", err) - return - } - if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { - t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) - return - } - var ipLen int - switch buf[3] { - case 1: - ipLen = net.IPv4len - case 4: - ipLen = net.IPv6len - default: - t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) - return - } - if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { - t.Errorf("socks5 proxy address read: %v", err) - return - } - ip := net.IP(buf[4 : ipLen+4]) - port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) - copy(buf[:3], []byte{5, 0, 0}) - if _, err := s.Write(buf[:ipLen+6]); err != nil { - t.Errorf("socks5 proxy connect write: %v", err) - return - } - ch <- fmt.Sprintf("proxy for %s:%d", ip, port) - - // Implement proxying. - targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) - targetConn, err := net.Dial("tcp", targetHost) - if err != nil { - t.Errorf("net.Dial failed") - return - } - go io.Copy(targetConn, s) - io.Copy(s, targetConn) // Wait for the client to close the socket. - targetConn.Close() - } - - pu, err := url.Parse("socks5://" + l.Addr().String()) - if err != nil { - t.Fatal(err) - } - - sentinelHeader := "X-Sentinel" - sentinelValue := "12345" - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set(sentinelHeader, sentinelValue) - }) - for _, useTLS := range []bool{false, true} { - t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { - var ts *httptest.Server - if useTLS { - ts = httptest.NewTLSServer(h) - } else { - ts = httptest.NewServer(h) - } - go proxy(t) - c := tc().httpClient - c.Transport.(*Transport).Proxy = http.ProxyURL(pu) - r, err := c.Head(ts.URL) - if err != nil { - t.Fatal(err) - } - if r.Header.Get(sentinelHeader) != sentinelValue { - t.Errorf("Failed to retrieve sentinel value") - } - var got string - select { - case got = <-ch: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to socks5 proxy") - } - ts.Close() - tsu, err := url.Parse(ts.URL) - if err != nil { - t.Fatal(err) - } - want := "proxy for " + tsu.Host - if got != want { - t.Errorf("got %q, want %q", got, want) - } - }) - } -} - -func TestTransportProxy(t *testing.T) { - defer afterTest(t) - testCases := []struct{ httpsSite, httpsProxy bool }{ - {false, false}, - {false, true}, - {true, false}, - {true, true}, - } - for _, testCase := range testCases { - httpsSite := testCase.httpsSite - httpsProxy := testCase.httpsProxy - t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { - siteCh := make(chan *http.Request, 1) - h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - siteCh <- r - }) - proxyCh := make(chan *http.Request, 1) - h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxyCh <- r - // Implement an entire CONNECT proxy - if r.Method == "CONNECT" { - hijacker, ok := w.(http.Hijacker) - if !ok { - t.Errorf("hijack not allowed") - return - } - clientConn, _, err := hijacker.Hijack() - if err != nil { - t.Errorf("hijacking failed") - return - } - res := &http.Response{ - StatusCode: http.StatusOK, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - - targetConn, err := net.Dial("tcp", r.URL.Host) - if err != nil { - t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) - return - } - - if err := res.Write(clientConn); err != nil { - t.Errorf("Writing 200 OK failed: %v", err) - return - } - - go io.Copy(targetConn, clientConn) - go func() { - io.Copy(clientConn, targetConn) - targetConn.Close() - }() - } - }) - var ts *httptest.Server - if httpsSite { - ts = httptest.NewTLSServer(h1) - } else { - ts = httptest.NewServer(h1) - } - var proxy *httptest.Server - if httpsProxy { - proxy = httptest.NewTLSServer(h2) - } else { - proxy = httptest.NewServer(h2) - } - - pu, err := url.Parse(proxy.URL) - if err != nil { - t.Fatal(err) - } - - // If neither server is HTTPS or both are, then c may be derived from either. - // If only one server is HTTPS, c must be derived from that server in order - // to ensure that it is configured to use the fake root CA from testcert.go. - c := tc().httpClient - - c.Transport.(*Transport).Proxy = http.ProxyURL(pu) - if _, err := c.Head(ts.URL); err != nil { - t.Error(err) - } - var got *http.Request - select { - case got = <-proxyCh: - case <-time.After(5 * time.Second): - t.Fatal("timeout connecting to http proxy") - } - c.Transport.(*Transport).CloseIdleConnections() - ts.Close() - proxy.Close() - if httpsSite { - // First message should be a CONNECT, asking for a socket to the real server, - if got.Method != "CONNECT" { - t.Errorf("Wrong method for secure proxying: %q", got.Method) - } - gotHost := got.URL.Host - pu, err := url.Parse(ts.URL) - if err != nil { - t.Fatal("Invalid site URL") - } - if wantHost := pu.Host; gotHost != wantHost { - t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) - } - - // The next message on the channel should be from the site's server. - next := <-siteCh - if next.Method != "HEAD" { - t.Errorf("Wrong method at destination: %s", next.Method) - } - if nextURL := next.URL.String(); nextURL != "/" { - t.Errorf("Wrong URL at destination: %s", nextURL) - } - } else { - if got.Method != "HEAD" { - t.Errorf("Wrong method for destination: %q", got.Method) - } - gotURL := got.URL.String() - wantURL := ts.URL + "/" - if gotURL != wantURL { - t.Errorf("Got URL %q, want %q", gotURL, wantURL) - } - } - }) - } -} - -// Issue 28012: verify that the Transport closes its TCP connection to http proxies -// when they're slow to reply to HTTPS CONNECT responses. -func TestTransportProxyHTTPSConnectLeak(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ln := tests.NewLocalListener(t) - defer ln.Close() - listenerDone := make(chan struct{}) - go func() { - defer close(listenerDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Accept: %v", err) - return - } - defer c.Close() - // Read the CONNECT request - br := bufio.NewReader(c) - cr, err := http.ReadRequest(br) - if err != nil { - t.Errorf("proxy server failed to read CONNECT request") - return - } - if cr.Method != "CONNECT" { - t.Errorf("unexpected method %q", cr.Method) - return - } - - // Now hang and never write a response; instead, cancel the request and wait - // for the client to close. - // (Prior to Issue 28012 being fixed, we never closed.) - cancel() - var buf [1]byte - _, err = br.Read(buf[:]) - if err != io.EOF { - t.Errorf("proxy server Read err = %v; want EOF", err) - } - return - }() - - tr := T().SetProxy(func(*http.Request) (*url.URL, error) { - return url.Parse("http://" + ln.Addr().String()) - }) - c := &http.Client{ - Transport: tr, - } - req, err := http.NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) - if err != nil { - t.Fatal(err) - } - _, err = c.Do(req) - if err == nil { - t.Errorf("unexpected Get success") - } - - // Wait unconditionally for the listener goroutine to exit: this should never - // hang, so if it does we want a full goroutine dump — and that's exactly what - // the testing package will give us when the test run times out. - <-listenerDone -} - -// Issue 16997: test transport dial preserves typed errors -func TestTransportDialPreservesNetOpProxyError(t *testing.T) { - defer afterTest(t) - - var errDial = errors.New("some dial error") - - tr := T().SetProxy(func(*http.Request) (*url.URL, error) { - return url.Parse("http://proxy.fake.tld/") - }).SetDial(func(context.Context, string, string) (net.Conn, error) { - return nil, errDial - }) - defer tr.CloseIdleConnections() - - c := &http.Client{Transport: tr} - req, _ := http.NewRequest("GET", "http://fake.tld", nil) - res, err := c.Do(req) - if err == nil { - res.Body.Close() - t.Fatal("wanted a non-nil error") - } - - uerr, ok := err.(*url.Error) - if !ok { - t.Fatalf("got %T, want *url.Error", err) - } - oe, ok := uerr.Err.(*net.OpError) - if !ok { - t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) - } - want := &net.OpError{ - Op: "proxyconnect", - Net: "tcp", - Err: errDial, // original error, unwrapped. - } - if !reflect.DeepEqual(oe, want) { - t.Errorf("Got error %#v; want %#v", oe, want) - } -} - -// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader. -// -// (A bug caused dialConn to instead write the per-request Proxy-Authorization -// header through to the shared Header instance, introducing a data race.) -func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - - proxy := httptest.NewTLSServer(http.NotFoundHandler()) - defer proxy.Close() - c := tc().httpClient - - tr := c.Transport.(*Transport) - tr.Proxy = func(*http.Request) (*url.URL, error) { - u, _ := url.Parse(proxy.URL) - u.User = url.UserPassword("aladdin", "opensesame") - return u, nil - } - h := tr.ProxyConnectHeader - if h == nil { - h = make(http.Header) - } - tr.ProxyConnectHeader = h.Clone() - - req, err := http.NewRequest("GET", "https://golang.fake.tld/", nil) - if err != nil { - t.Fatal(err) - } - _, err = c.Do(req) - if err == nil { - t.Errorf("unexpected Get success") - } - - if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { - t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) - } -} - -// TestTransportGzipRecursive sends a gzip quine and checks that the -// client gets the same value back. This is more cute than anything, -// but checks that we don't recurse forever, and checks that -// Content-Encoding is removed. -func TestTransportGzipRecursive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Encoding", "gzip") - w.Write(rgz) - })) - defer ts.Close() - - c := tc().httpClient - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - body, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(body, rgz) { - t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", - body, rgz) - } - if g, e := res.Header.Get("Content-Encoding"), ""; g != e { - t.Fatalf("Content-Encoding = %q; want %q", g, e) - } -} - -// golang.org/issue/7750: request fails when server replies with -// a short gzip body -func TestTransportGzipShort(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Encoding", "gzip") - w.Write([]byte{0x1f, 0x8b}) - })) - defer ts.Close() - - c := tc().httpClient - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - _, err = io.ReadAll(res.Body) - if err == nil { - t.Fatal("Expect an error from reading a body.") - } - if err != io.ErrUnexpectedEOF { - t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) - } -} - -// Wait until number of goroutines is no greater than nmax, or time out. -func waitNumGoroutine(nmax int) int { - nfinal := runtime.NumGoroutine() - for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { - time.Sleep(50 * time.Millisecond) - runtime.GC() - nfinal = runtime.NumGoroutine() - } - return nfinal -} - -// tests that persistent goroutine connections shut down when no longer desired. -func TestTransportPersistConnLeak(t *testing.T) { - // Not parallel: counts goroutines - defer afterTest(t) - - const numReq = 25 - gotReqCh := make(chan bool, numReq) - unblockCh := make(chan bool, numReq) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotReqCh <- true - <-unblockCh - w.Header().Set("Content-Length", "0") - w.WriteHeader(204) - })) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - - n0 := runtime.NumGoroutine() - - didReqCh := make(chan bool, numReq) - failed := make(chan bool, numReq) - for i := 0; i < numReq; i++ { - go func() { - res, err := c.Get(ts.URL) - didReqCh <- true - if err != nil { - t.Logf("client fetch error: %v", err) - failed <- true - return - } - res.Body.Close() - }() - } - - // Wait for all goroutines to be stuck in the Handler. - for i := 0; i < numReq; i++ { - select { - case <-gotReqCh: - // ok - case <-failed: - // Not great but not what we are testing: - // sometimes an overloaded system will fail to make all the connections. - } - } - - nhigh := runtime.NumGoroutine() - - // Tell all handlers to unblock and reply. - close(unblockCh) - - // Wait for all HTTP clients to be done. - for i := 0; i < numReq; i++ { - <-didReqCh - } - - tr.CloseIdleConnections() - nfinal := waitNumGoroutine(n0 + 5) - - growth := nfinal - n0 - - // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. - // Previously we were leaking one per numReq. - if int(growth) > 5 { - t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) - t.Error("too many new goroutines") - } -} - -// golang.org/issue/4531: Transport leaks goroutines when -// request.ContentLength is explicitly short -func TestTransportPersistConnLeakShortBody(t *testing.T) { - // Not parallel: measures goroutines. - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - })) - defer ts.Close() - c := tc().httpClient - tr := c.Transport.(*Transport) - - n0 := runtime.NumGoroutine() - body := []byte("Hello") - for i := 0; i < 20; i++ { - req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(body)) - if err != nil { - t.Fatal(err) - } - req.ContentLength = int64(len(body) - 2) // explicitly short - _, err = c.Do(req) - if err == nil { - t.Fatal("Expect an error from writing too long of a body.") - } - } - nhigh := runtime.NumGoroutine() - tr.CloseIdleConnections() - nfinal := waitNumGoroutine(n0 + 5) - - growth := nfinal - n0 - - // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. - // Previously we were leaking one per numReq. - t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) - if int(growth) > 5 { - t.Error("too many new goroutines") - } -} - -// A countedConn is a net.Conn that decrements an atomic counter when finalized. -type countedConn struct { - net.Conn -} - -// A countingDialer dials connections and counts the number that remain reachable. -type countingDialer struct { - dialer net.Dialer - mu sync.Mutex - total, live int64 -} - -func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := d.dialer.DialContext(ctx, network, address) - if err != nil { - return nil, err - } - - counted := new(countedConn) - counted.Conn = conn - - d.mu.Lock() - defer d.mu.Unlock() - d.total++ - d.live++ - - runtime.SetFinalizer(counted, d.decrement) - return counted, nil -} - -func (d *countingDialer) decrement(*countedConn) { - d.mu.Lock() - defer d.mu.Unlock() - d.live-- -} - -func (d *countingDialer) Read() (total, live int64) { - d.mu.Lock() - defer d.mu.Unlock() - return d.total, d.live -} - -func TestTransportPersistConnLeakNeverIdle(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Close every connection so that it cannot be kept alive. - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Errorf("Hijack failed unexpectedly: %v", err) - return - } - conn.Close() - })) - defer ts.Close() - - var d countingDialer - c := tc().httpClient - c.Transport.(*Transport).DialContext = d.DialContext - - body := []byte("Hello") - for i := 0; ; i++ { - total, live := d.Read() - if live < total { - break - } - if i >= 1<<12 { - t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) - } - - req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(body)) - if err != nil { - t.Fatal(err) - } - _, err = c.Do(req) - if err == nil { - t.Fatal("expected broken connection") - } - - runtime.GC() - } -} - -type countedContext struct { - context.Context -} - -type contextCounter struct { - mu sync.Mutex - live int64 -} - -func (cc *contextCounter) Track(ctx context.Context) context.Context { - counted := new(countedContext) - counted.Context = ctx - cc.mu.Lock() - defer cc.mu.Unlock() - cc.live++ - runtime.SetFinalizer(counted, cc.decrement) - return counted -} - -func (cc *contextCounter) decrement(*countedContext) { - cc.mu.Lock() - defer cc.mu.Unlock() - cc.live-- -} - -func (cc *contextCounter) Read() (live int64) { - cc.mu.Lock() - defer cc.mu.Unlock() - return cc.live -} - -// This used to crash; https://golang.org/issue/3266 -func TestTransportIdleConnCrash(t *testing.T) { - defer afterTest(t) - var tr *Transport - - unblockCh := make(chan bool, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-unblockCh - tr.CloseIdleConnections() - })) - defer ts.Close() - c := tc().httpClient - tr = c.Transport.(*Transport) - - didreq := make(chan bool) - go func() { - res, err := c.Get(ts.URL) - if err != nil { - t.Error(err) - } else { - res.Body.Close() // returns idle conn - } - didreq <- true - }() - unblockCh <- true - <-didreq -} - -// Test that the transport doesn't close the TCP connection early, -// before the response body has been read. This was a regression -// which sadly lacked a triggering test. The large response body made -// the old race easier to trigger. -func TestIssue3644(t *testing.T) { - defer afterTest(t) - const numFoos = 5000 - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Connection", "close") - for i := 0; i < numFoos; i++ { - w.Write([]byte("foo ")) - } - })) - defer ts.Close() - c := tc().httpClient - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - bs, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if len(bs) != numFoos*len("foo ") { - t.Errorf("unexpected response length") - } -} - -// Test that a client receives a server's reply, even if the server doesn't read -// the entire request body. -func TestIssue3595(t *testing.T) { - setParallel(t) - defer afterTest(t) - const deniedMsg = "sorry, denied." - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, deniedMsg, http.StatusUnauthorized) - })) - defer ts.Close() - c := tc().httpClient - res, err := c.Post(ts.URL, "application/octet-stream", tests.NeverEnding('a')) - if err != nil { - t.Errorf("Post: %v", err) - return - } - got, err := io.ReadAll(res.Body) - if err != nil { - t.Fatalf("Body ReadAll: %v", err) - } - if !strings.Contains(string(got), deniedMsg) { - t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) - } -} - -// From https://golang.org/issue/4454 , -// "client fails to handle requests with no body and chunked encoding" -func TestChunkedNoContent(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - })) - defer ts.Close() - - c := tc().httpClient - for _, closeBody := range []bool{true, false} { - const n = 4 - for i := 1; i <= n; i++ { - res, err := c.Get(ts.URL) - if err != nil { - t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) - } else { - if closeBody { - res.Body.Close() - } - } - } - } -} - -// SetPendingDialHooks sets the hooks that run before and after handling -// pending dials. -func SetPendingDialHooks(before, after func()) { - unnilTestHook(&before) - unnilTestHook(&after) - testHookPrePendingDial, testHookPostPendingDial = before, after -} - -func TestTransportConcurrency(t *testing.T) { - // Not parallel: uses global test hooks. - defer afterTest(t) - maxProcs, numReqs := 16, 500 - if testing.Short() { - maxProcs, numReqs = 4, 50 - } - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%v", r.FormValue("echo")) - })) - defer ts.Close() - - var wg sync.WaitGroup - wg.Add(numReqs) - - // Due to the Transport's "socket late binding" (see - // idleConnCh in transport.go), the numReqs HTTP requests - // below can finish with a dial still outstanding. To keep - // the leak checker happy, keep track of pending dials and - // wait for them to finish (and be closed or returned to the - // idle pool) before we close idle connections. - SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) - defer SetPendingDialHooks(nil, nil) - - c := tc().httpClient - reqs := make(chan string) - defer close(reqs) - - for i := 0; i < maxProcs*2; i++ { - go func() { - for req := range reqs { - res, err := c.Get(ts.URL + "/?echo=" + req) - if err != nil { - if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") { - // https://go.dev/issue/52168: this test was observed to fail with - // ECONNRESET errors in Dial on various netbsd builders. - t.Logf("error on req %s: %v", req, err) - t.Logf("(see https://go.dev/issue/52168)") - } else { - t.Errorf("error on req %s: %v", req, err) - } - wg.Done() - continue - } - all, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("read error on req %s: %v", req, err) - } else if string(all) != req { - t.Errorf("body of req %s = %q; want %q", req, all, req) - } - res.Body.Close() - wg.Done() - } - }() - } - for i := 0; i < numReqs; i++ { - reqs <- fmt.Sprintf("request-%d", i) - } - wg.Wait() -} - -// loggingConn is used for debugging. -type loggingConn struct { - name string - net.Conn -} - -var ( - uniqNameMu sync.Mutex - uniqNameNext = make(map[string]int) -) - -func newLoggingConn(baseName string, c net.Conn) net.Conn { - uniqNameMu.Lock() - defer uniqNameMu.Unlock() - uniqNameNext[baseName]++ - return &loggingConn{ - name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), - Conn: c, - } -} - -func (c *loggingConn) Write(p []byte) (n int, err error) { - log.Printf("%s.Write(%d) = ....", c.name, len(p)) - n, err = c.Conn.Write(p) - log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) - return -} - -func (c *loggingConn) Read(p []byte) (n int, err error) { - log.Printf("%s.Read(%d) = ....", c.name, len(p)) - n, err = c.Conn.Read(p) - log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) - return -} - -func (c *loggingConn) Close() (err error) { - log.Printf("%s.Close() = ...", c.name) - err = c.Conn.Close() - log.Printf("%s.Close() = %v", c.name, err) - return -} - -func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - const debug = false - mux := http.NewServeMux() - mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { - io.Copy(w, tests.NeverEnding('a')) - }) - ts := httptest.NewServer(mux) - defer ts.Close() - timeout := 100 * time.Millisecond - - c := tc().httpClient - c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = newLoggingConn("client", conn) - } - return conn, nil - } - - getFailed := false - nRuns := 5 - if testing.Short() { - nRuns = 1 - } - for i := 0; i < nRuns; i++ { - if debug { - println("run", i+1, "of", nRuns) - } - sres, err := c.Get(ts.URL + "/get") - if err != nil { - if !getFailed { - // Make the timeout longer, once. - getFailed = true - t.Logf("increasing timeout") - i-- - timeout *= 10 - continue - } - t.Errorf("Error issuing GET: %v", err) - break - } - _, err = io.Copy(io.Discard, sres.Body) - if err == nil { - t.Errorf("Unexpected successful copy") - break - } - } - if debug { - println("tests complete; waiting for handlers to finish") - } -} - -func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - const debug = false - mux := http.NewServeMux() - mux.HandleFunc("/get", func(w http.ResponseWriter, r *http.Request) { - io.Copy(w, tests.NeverEnding('a')) - }) - mux.HandleFunc("/put", func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - io.Copy(io.Discard, r.Body) - }) - ts := httptest.NewServer(mux) - timeout := 100 * time.Millisecond - - c := tc().httpClient - c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = newLoggingConn("client", conn) - } - return conn, nil - } - - getFailed := false - nRuns := 5 - if testing.Short() { - nRuns = 1 - } - for i := 0; i < nRuns; i++ { - if debug { - println("run", i+1, "of", nRuns) - } - sres, err := c.Get(ts.URL + "/get") - if err != nil { - if !getFailed { - // Make the timeout longer, once. - getFailed = true - t.Logf("increasing timeout") - i-- - timeout *= 10 - continue - } - t.Errorf("Error issuing GET: %v", err) - break - } - req, _ := http.NewRequest("PUT", ts.URL+"/put", sres.Body) - _, err = c.Do(req) - if err == nil { - sres.Body.Close() - t.Errorf("Unexpected successful PUT") - break - } - sres.Body.Close() - } - if debug { - println("tests complete; waiting for handlers to finish") - } - ts.Close() -} - -func reqWithT(r *http.Request, t *testing.T) *http.Request { - return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) -} - -func TestTransportResponseHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - if testing.Short() { - t.Skip("skipping timeout test in -short mode") - } - inHandler := make(chan bool, 1) - mux := http.NewServeMux() - mux.HandleFunc("/fast", func(w http.ResponseWriter, r *http.Request) { - inHandler <- true - }) - mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) { - inHandler <- true - time.Sleep(2 * time.Second) - }) - ts := httptest.NewServer(mux) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond - - tests := []struct { - path string - want int - wantErr string - }{ - {path: "/fast", want: 200}, - {path: "/slow", wantErr: "timeout awaiting response headers"}, - {path: "/fast", want: 200}, - } - for i, tt := range tests { - req, _ := http.NewRequest("GET", ts.URL+tt.path, nil) - req = reqWithT(req, t) - res, err := c.Do(req) - select { - case <-inHandler: - case <-time.After(5 * time.Second): - t.Errorf("never entered handler for test index %d, %s", i, tt.path) - continue - } - if err != nil { - uerr, ok := err.(*url.Error) - if !ok { - t.Errorf("error is not an url.Error; got: %#v", err) - continue - } - nerr, ok := uerr.Err.(net.Error) - if !ok { - t.Errorf("error does not satisfy net.Error interface; got: %#v", err) - continue - } - if !nerr.Timeout() { - t.Errorf("want timeout error; got: %q", nerr) - continue - } - if strings.Contains(err.Error(), tt.wantErr) { - continue - } - t.Errorf("%d. unexpected error: %v", i, err) - continue - } - if tt.wantErr != "" { - t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) - continue - } - if res.StatusCode != tt.want { - t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) - } - } -} - -func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { - setParallel(t) - defer afterTest(t) - if testing.Short() { - t.Skip("skipping test in -short mode") - } - unblockc := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-unblockc - })) - defer ts.Close() - defer close(unblockc) - - c := tc().httpClient - tr := c.Transport.(*Transport) - - donec := make(chan bool) - req, _ := http.NewRequest("GET", ts.URL, body) - go func() { - defer close(donec) - c.Do(req) - }() - start := time.Now() - timeout := 10 * time.Second - for time.Since(start) < timeout { - time.Sleep(100 * time.Millisecond) - tr.CancelRequest(req) - select { - case <-donec: - return - default: - } - } - t.Errorf("Do of canceled request has not returned after %v", timeout) -} - -func TestCancelRequestWithChannel(t *testing.T) { - setParallel(t) - defer afterTest(t) - if testing.Short() { - t.Skip("skipping test in -short mode") - } - unblockc := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello") - w.(http.Flusher).Flush() // send headers and some body - <-unblockc - })) - defer ts.Close() - defer close(unblockc) - - c := tc().httpClient - tr := c.Transport.(*Transport) - - req, _ := http.NewRequest("GET", ts.URL, nil) - ch := make(chan struct{}) - req.Cancel = ch - - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - go func() { - time.Sleep(1 * time.Second) - close(ch) - }() - t0 := time.Now() - body, err := io.ReadAll(res.Body) - d := time.Since(t0) - - if err != common.ErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } - if string(body) != "Hello" { - t.Errorf("Body = %q; want Hello", body) - } - if d < 500*time.Millisecond { - t.Errorf("expected ~1 second delay; got %v", d) - } - // Verify no outstanding requests after readLoop/writeLoop - // goroutines shut down. - for tries := 5; tries > 0; tries-- { - n := tr.NumPendingRequestsForTesting() - if n == 0 { - break - } - time.Sleep(100 * time.Millisecond) - if tries == 1 { - t.Errorf("pending requests = %d; want 0", n) - } - } -} - -func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, false) -} -func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, true) -} -func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { - setParallel(t) - defer afterTest(t) - unblockc := make(chan bool) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-unblockc - })) - defer ts.Close() - defer close(unblockc) - - c := tc().httpClient - - req, _ := http.NewRequest("GET", ts.URL, nil) - if withCtx { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - req = req.WithContext(ctx) - } else { - ch := make(chan struct{}) - req.Cancel = ch - close(ch) - } - - _, err := c.Do(req) - if ue, ok := err.(*url.Error); ok { - err = ue.Err - } - if withCtx { - if err != context.Canceled { - t.Errorf("Do error = %v; want %v", err, context.Canceled) - } - } else { - if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancellation", err) - } - } -} - -// Issue 11020. The returned error message should be errRequestCanceled -func TestTransportCancelBeforeResponseHeaders(t *testing.T) { - defer afterTest(t) - - serverConnCh := make(chan net.Conn, 1) - tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { - cc, sc := net.Pipe() - serverConnCh <- sc - return cc, nil - }) - defer tr.CloseIdleConnections() - errc := make(chan error, 1) - req, _ := http.NewRequest("GET", "http://example.com/", nil) - go func() { - _, err := tr.RoundTrip(req) - errc <- err - }() - - sc := <-serverConnCh - verb := make([]byte, 3) - if _, err := io.ReadFull(sc, verb); err != nil { - t.Errorf("Error reading HTTP verb from server: %v", err) - } - if string(verb) != "GET" { - t.Errorf("server received %q; want GET", verb) - } - defer sc.Close() - - tr.CancelRequest(req) - - err := <-errc - if err == nil { - t.Fatalf("unexpected success from RoundTrip") - } - if err != common.ErrRequestCanceled { - t.Errorf("RoundTrip error = %v; want errRequestCanceled", err) - } -} - -// golang.org/issue/3672 -- Client can't close HTTP stream -// Calling Close on a Response.Body used to just read until EOF. -// Now it actually closes the TCP connection. -func TestTransportCloseResponseBody(t *testing.T) { - defer afterTest(t) - writeErr := make(chan error, 1) - msg := []byte("young\n") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for { - _, err := w.Write(msg) - if err != nil { - writeErr <- err - return - } - w.(http.Flusher).Flush() - } - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - - req, _ := http.NewRequest("GET", ts.URL, nil) - defer tr.CancelRequest(req) - - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - - const repeats = 3 - buf := make([]byte, len(msg)*repeats) - want := bytes.Repeat(msg, repeats) - - _, err = io.ReadFull(res.Body, buf) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf, want) { - t.Fatalf("read %q; want %q", buf, want) - } - didClose := make(chan error, 1) - go func() { - didClose <- res.Body.Close() - }() - select { - case err := <-didClose: - if err != nil { - t.Errorf("Close = %v", err) - } - case <-time.After(10 * time.Second): - t.Fatal("too long waiting for close") - } - select { - case err := <-writeErr: - if err == nil { - t.Errorf("expected non-nil write error") - } - case <-time.After(10 * time.Second): - t.Fatal("too long waiting for write error") - } -} - -func TestTransportNoHost(t *testing.T) { - defer afterTest(t) - tr := T() - _, err := tr.RoundTrip(&http.Request{ - Header: make(http.Header), - URL: &url.URL{ - Scheme: "http", - }, - }) - want := "http: no Host in request URL" - if got := fmt.Sprint(err); got != want { - t.Errorf("error = %v; want %q", err, want) - } -} - -// Issue 13311 -func TestTransportEmptyMethod(t *testing.T) { - req, _ := http.NewRequest("GET", "http://foo.com/", nil) - req.Method = "" // docs say "For client requests an empty string means GET" - got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport - if err != nil { - t.Fatal(err) - } - if !strings.Contains(string(got), "GET ") { - t.Fatalf("expected substring 'GET '; got: %s", got) - } -} - -func TestTransportSocketLateBinding(t *testing.T) { - setParallel(t) - defer afterTest(t) - - mux := http.NewServeMux() - fooGate := make(chan bool, 1) - mux.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("foo-ipport", r.RemoteAddr) - w.(http.Flusher).Flush() - <-fooGate - }) - mux.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("bar-ipport", r.RemoteAddr) - }) - ts := httptest.NewServer(mux) - defer ts.Close() - - dialGate := make(chan bool, 1) - c := tc().httpClient - c.Transport.(*Transport).DialContext = func(_ context.Context, n, addr string) (net.Conn, error) { - if <-dialGate { - return net.Dial(n, addr) - } - return nil, errors.New("manually closed") - } - - dialGate <- true // only allow one dial - fooRes, err := c.Get(ts.URL + "/foo") - if err != nil { - t.Fatal(err) - } - fooAddr := fooRes.Header.Get("foo-ipport") - if fooAddr == "" { - t.Fatal("No addr on /foo request") - } - time.AfterFunc(200*time.Millisecond, func() { - // let the foo response finish so we can use its - // connection for /bar - fooGate <- true - io.Copy(io.Discard, fooRes.Body) - fooRes.Body.Close() - }) - - barRes, err := c.Get(ts.URL + "/bar") - if err != nil { - t.Fatal(err) - } - barAddr := barRes.Header.Get("bar-ipport") - if barAddr != fooAddr { - t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) - } - barRes.Body.Close() - dialGate <- false -} - -type dummyAddr string -type oneConnListener struct { - conn net.Conn -} - -func (l *oneConnListener) Accept() (c net.Conn, err error) { - c = l.conn - if c == nil { - err = io.EOF - return - } - err = nil - l.conn = nil - return -} - -func (l *oneConnListener) Close() error { - return nil -} - -func (l *oneConnListener) Addr() net.Addr { - return dummyAddr("test-address") -} - -func (a dummyAddr) Network() string { - return string(a) -} - -func (a dummyAddr) String() string { - return string(a) -} - -type noopConn struct{} - -func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") } -func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") } -func (noopConn) SetDeadline(t time.Time) error { return nil } -func (noopConn) SetReadDeadline(t time.Time) error { return nil } -func (noopConn) SetWriteDeadline(t time.Time) error { return nil } - -type rwTestConn struct { - io.Reader - io.Writer - noopConn - - closeFunc func() error // called if non-nil - closec chan bool // else, if non-nil, send value to it on close -} - -func (c *rwTestConn) Close() error { - if c.closeFunc != nil { - return c.closeFunc() - } - select { - case c.closec <- true: - default: - } - return nil -} - -// Issue 2184 -func TestTransportReading100Continue(t *testing.T) { - defer afterTest(t) - - const numReqs = 5 - reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } - reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } - - send100Response := func(w *io.PipeWriter, r *io.PipeReader) { - defer w.Close() - defer r.Close() - br := bufio.NewReader(r) - n := 0 - for { - n++ - req, err := http.ReadRequest(br) - if err == io.EOF { - return - } - if err != nil { - t.Error(err) - return - } - slurp, err := io.ReadAll(req.Body) - if err != nil { - t.Errorf("Server request body slurp: %v", err) - return - } - id := req.Header.Get("Request-Id") - resCode := req.Header.Get("X-Want-Response-Code") - if resCode == "" { - resCode = "100 Continue" - if string(slurp) != reqBody(n) { - t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) - } - } - body := fmt.Sprintf("Response number %d", n) - v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s -Date: Thu, 28 Feb 2013 17:55:41 GMT - -HTTP/1.1 200 OK -Content-Type: text/html -Echo-Request-Id: %s -Content-Length: %d - -%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) - w.Write(v) - if id == reqID(numReqs) { - return - } - } - - } - - tr := T().SetDial(func(_ context.Context, n, addr string) (net.Conn, error) { - sr, sw := io.Pipe() // server read/write - cr, cw := io.Pipe() // client read/write - conn := &rwTestConn{ - Reader: cr, - Writer: sw, - closeFunc: func() error { - sw.Close() - cw.Close() - return nil - }, - } - go send100Response(cw, sr) - return conn, nil - }) - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - testResponse := func(req *http.Request, name string, wantCode int) { - t.Helper() - res, err := c.Do(req) - if err != nil { - t.Fatalf("%s: Do: %v", name, err) - } - if res.StatusCode != wantCode { - t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) - } - if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { - t.Errorf("%s: response id %q != request id %q", name, idBack, id) - } - _, err = io.ReadAll(res.Body) - if err != nil { - t.Fatalf("%s: Slurp error: %v", name, err) - } - } - - // Few 100 responses, making sure we're not off-by-one. - for i := 1; i <= numReqs; i++ { - req, _ := http.NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) - req.Header.Set("Request-Id", reqID(i)) - testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) - } -} - -type clientServerTest struct { - t *testing.T - h2 bool - h http.Handler - ts *httptest.Server - tr *Transport - c *http.Client -} - -func (t *clientServerTest) close() { - t.tr.CloseIdleConnections() - t.ts.Close() -} - -func (t *clientServerTest) getURL(u string) string { - res, err := t.c.Get(u) - if err != nil { - t.t.Fatal(err) - } - defer res.Body.Close() - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.t.Fatal(err) - } - return string(slurp) -} - -func (t *clientServerTest) scheme() string { - if t.h2 { - return "https" - } - return "http" -} - -const ( - h1Mode = false - h2Mode = true -) - -var quietLog = log.New(io.Discard, "", 0) - -var optQuietLog = func(ts *httptest.Server) { - ts.Config.ErrorLog = quietLog -} - -func newClientServerTest(t *testing.T, h2 bool, h http.Handler, opts ...interface{}) *clientServerTest { - cst := &clientServerTest{ - t: t, - h2: h2, - h: h, - tr: T(), - } - cst.c = &http.Client{Transport: cst.tr} - cst.ts = httptest.NewUnstartedServer(h) - - for _, opt := range opts { - switch opt := opt.(type) { - case func(*Transport): - opt(cst.tr) - case func(*httptest.Server): - opt(cst.ts) - default: - t.Fatalf("unhandled option type %T", opt) - } - } - - if !h2 { - cst.ts.Start() - return cst - } - nethttp2.ConfigureServer(cst.ts.Config, nil) - cst.ts.TLS = cst.ts.Config.TLSConfig - cst.ts.StartTLS() - - cst.tr.TLSClientConfig.InsecureSkipVerify = true - return cst -} - -// Issue 17739: the HTTP client must ignore any unknown 1xx -// informational responses before the actual response. -func TestTransportIgnore1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) - buf.Flush() - conn.Close() - })) - defer cst.close() - cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway - - var got bytes.Buffer - - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - Got1xxResponse: func(code int, header textproto.MIMEHeader) error { - fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) - return nil - }, - })) - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - - res.Write(&got) - want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" - if got.String() != want { - t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want) - } -} - -func TestTransportLimits1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - for i := 0; i < 10; i++ { - buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) - } - buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway - - res, err := cst.c.Get(cst.ts.URL) - if res != nil { - defer res.Body.Close() - } - got := fmt.Sprint(err) - wantSub := "too many 1xx informational responses" - if !strings.Contains(got, wantSub) { - t.Errorf("Get error = %v; want substring %q", err, wantSub) - } -} - -// Issue 26161: the HTTP client must treat 101 responses -// as the final response. -func TestTransportTreat101Terminal(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) - buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - if res.StatusCode != http.StatusSwitchingProtocols { - t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) - } -} - -type proxyFromEnvTest struct { - req string // URL to fetch; blank means "http://example.com" - - env string // HTTP_PROXY - httpsenv string // HTTPS_PROXY - noenv string // NO_PROXY - reqmeth string // REQUEST_METHOD - - want string - wanterr error -} - -func (t proxyFromEnvTest) String() string { - var buf bytes.Buffer - space := func() { - if buf.Len() > 0 { - buf.WriteByte(' ') - } - } - if t.env != "" { - fmt.Fprintf(&buf, "http_proxy=%q", t.env) - } - if t.httpsenv != "" { - space() - fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) - } - if t.noenv != "" { - space() - fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) - } - if t.reqmeth != "" { - space() - fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) - } - req := "http://example.com" - if t.req != "" { - req = t.req - } - space() - fmt.Fprintf(&buf, "req=%q", req) - return strings.TrimSpace(buf.String()) -} - -var proxyFromEnvTests = []proxyFromEnvTest{ - {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, - {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, - {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, - {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, - {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, - {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, - {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, - - // Don't use secure for http - {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, - // Use secure for https. - {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, - {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, - - // Issue 16405: don't use HTTP_PROXY in a CGI environment, - // where HTTP_PROXY can be attacker-controlled. - {env: "http://10.1.2.3:8080", reqmeth: "POST", - want: "", - wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, - - {want: ""}, - - {noenv: "example.com", req: "http://example.com/", env: "proxy", want: ""}, - {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, - {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, - {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: ""}, - {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, -} - -func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *http.Request) (*url.URL, error)) { - t.Helper() - reqURL := tt.req - if reqURL == "" { - reqURL = "http://example.com" - } - req, _ := http.NewRequest("GET", reqURL, nil) - url, err := proxyForRequest(req) - if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { - t.Errorf("%v: got error = %q, want %q", tt, g, e) - return - } - if got := fmt.Sprintf("%s", url); got != tt.want { - t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) - } -} - -func ResetProxyEnv() { - for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy", "REQUEST_METHOD"} { - os.Unsetenv(v) - } -} - -func TestProxyFromEnvironment(t *testing.T) { - ResetProxyEnv() - defer ResetProxyEnv() - for _, tt := range proxyFromEnvTests { - testProxyForRequest(t, tt, func(req *http.Request) (*url.URL, error) { - os.Setenv("HTTP_PROXY", tt.env) - os.Setenv("HTTPS_PROXY", tt.httpsenv) - os.Setenv("NO_PROXY", tt.noenv) - os.Setenv("REQUEST_METHOD", tt.reqmeth) - return httpproxy.FromEnvironment().ProxyFunc()(req.URL) - }) - } -} - -func TestProxyFromEnvironmentLowerCase(t *testing.T) { - ResetProxyEnv() - defer ResetProxyEnv() - for _, tt := range proxyFromEnvTests { - testProxyForRequest(t, tt, func(req *http.Request) (*url.URL, error) { - os.Setenv("http_proxy", tt.env) - os.Setenv("https_proxy", tt.httpsenv) - os.Setenv("no_proxy", tt.noenv) - os.Setenv("REQUEST_METHOD", tt.reqmeth) - return httpproxy.FromEnvironment().ProxyFunc()(req.URL) - }) - } -} - -func TestIdleConnChannelLeak(t *testing.T) { - // Not parallel: uses global test hooks. - var mu sync.Mutex - var n int - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - n++ - mu.Unlock() - })) - defer ts.Close() - - const nReqs = 5 - didRead := make(chan bool, nReqs) - SetReadLoopBeforeNextReadHook(func() { didRead <- true }) - defer SetReadLoopBeforeNextReadHook(nil) - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.DialContext = func(_ context.Context, netw, addr string) (net.Conn, error) { - return net.Dial(netw, ts.Listener.Addr().String()) - } - - // First, without keep-alives. - for _, disableKeep := range []bool{true, false} { - tr.DisableKeepAlives = disableKeep - for i := 0; i < nReqs; i++ { - _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) - if err != nil { - t.Fatal(err) - } - // Note: no res.Body.Close is needed here, since the - // response Content-Length is zero. Perhaps the test - // should be more explicit and use a HEAD, but tests - // elsewhere guarantee that zero byte responses generate - // a "Content-Length: 0" instead of chunking. - } - - // At this point, each of the 5 Transport.readLoop goroutines - // are scheduling noting that there are no response bodies (see - // earlier comment), and are then calling putIdleConn, which - // decrements this count. Usually that happens quickly, which is - // why this test has seemed to work for ages. But it's still - // racey: we have wait for them to finish first. See Issue 10427 - for i := 0; i < nReqs; i++ { - <-didRead - } - - if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { - t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) - } - } -} - -// Verify the status quo: that the Client.Post function coerces its -// body into a ReadCloser if it's a Closer, and that the Transport -// then closes it. -func TestTransportClosesRequestBody(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.Copy(io.Discard, r.Body) - })) - defer ts.Close() - - c := tc().httpClient - - closes := 0 - - res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - if closes != 1 { - t.Errorf("closes = %d; want 1", closes) - } -} - -func TestTransportTLSHandshakeTimeout(t *testing.T) { - defer afterTest(t) - if testing.Short() { - t.Skip("skipping in short mode") - } - ln := tests.NewLocalListener(t) - defer ln.Close() - testdonec := make(chan struct{}) - defer close(testdonec) - - go func() { - c, err := ln.Accept() - if err != nil { - t.Error(err) - return - } - <-testdonec - c.Close() - }() - - getdonec := make(chan struct{}) - go func() { - defer close(getdonec) - tr := T().SetDial(func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("tcp", ln.Addr().String()) - }).SetTLSHandshakeTimeout(250 * time.Millisecond) - cl := &http.Client{Transport: tr} - _, err := cl.Get("https://dummy.tld/") - if err == nil { - t.Error("expected error") - return - } - ue, ok := err.(*url.Error) - if !ok { - t.Errorf("expected url.Error; got %#v", err) - return - } - ne, ok := ue.Err.(net.Error) - if !ok { - t.Errorf("expected net.Error; got %#v", err) - return - } - if !ne.Timeout() { - t.Errorf("expected timeout error; got %v", err) - } - if !strings.Contains(err.Error(), "handshake timeout") { - t.Errorf("expected 'handshake timeout' in error; got %v", err) - } - }() - select { - case <-getdonec: - case <-time.After(5 * time.Second): - t.Error("test timeout; TLS handshake hung?") - } -} - -// Trying to repro golang.org/issue/3514 -func TestTLSServerClosesConnection(t *testing.T) { - defer afterTest(t) - - closedc := make(chan bool, 1) - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/keep-alive-then-die") { - conn, _, _ := w.(http.Hijacker).Hijack() - conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) - conn.Close() - closedc <- true - return - } - fmt.Fprintf(w, "hello") - })) - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - - var nSuccess = 0 - var errs []error - const trials = 20 - for i := 0; i < trials; i++ { - tr.CloseIdleConnections() - res, err := c.Get(ts.URL + "/keep-alive-then-die") - if err != nil { - t.Fatal(err) - } - <-closedc - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if string(slurp) != "foo" { - t.Errorf("Got %q, want foo", slurp) - } - - // Now try again and see if we successfully - // pick a new connection. - res, err = c.Get(ts.URL + "/") - if err != nil { - errs = append(errs, err) - continue - } - slurp, err = io.ReadAll(res.Body) - if err != nil { - errs = append(errs, err) - continue - } - nSuccess++ - } - if nSuccess > 0 { - t.Logf("successes = %d of %d", nSuccess, trials) - } else { - t.Errorf("All runs failed:") - } - for _, err := range errs { - t.Logf(" err: %v", err) - } -} - -// byteFromChanReader is an io.Reader that reads a single byte at a -// time from the channel. When the channel is closed, the reader -// returns io.EOF. -type byteFromChanReader chan byte - -func (c byteFromChanReader) Read(p []byte) (n int, err error) { - if len(p) == 0 { - return - } - b, ok := <-c - if !ok { - return 0, io.EOF - } - p[0] = b - return 1, nil -} - -// Verifies that the Transport doesn't reuse a connection in the case -// where the server replies before the request has been fully -// written. We still honor that reply (see TestIssue3595), but don't -// send future requests on the connection because it's then in a -// questionable state. -// golang.org/issue/7569 -func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) - var sconn struct { - sync.Mutex - c net.Conn - } - var getOkay bool - closeConn := func() { - sconn.Lock() - defer sconn.Unlock() - if sconn.c != nil { - sconn.c.Close() - sconn.c = nil - if !getOkay { - t.Logf("Closed server connection") - } - } - } - defer closeConn() - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" { - io.WriteString(w, "bar") - return - } - conn, _, _ := w.(http.Hijacker).Hijack() - sconn.Lock() - sconn.c = conn - sconn.Unlock() - conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive - go io.Copy(io.Discard, conn) - })) - defer ts.Close() - c := tc().httpClient - - const bodySize = 256 << 10 - finalBit := make(byteFromChanReader, 1) - req, _ := http.NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(tests.NeverEnding('x'), bodySize-1), finalBit)) - req.ContentLength = bodySize - res, err := c.Do(req) - if err := wantBody(res, err, "foo"); err != nil { - t.Errorf("POST response: %v", err) - } - donec := make(chan bool) - go func() { - defer close(donec) - res, err = c.Get(ts.URL) - if err := wantBody(res, err, "bar"); err != nil { - t.Errorf("GET response: %v", err) - return - } - getOkay = true // suppress test noise - }() - time.AfterFunc(5*time.Second, closeConn) - select { - case <-donec: - finalBit <- 'x' // unblock the writeloop of the first Post - close(finalBit) - case <-time.After(7 * time.Second): - t.Fatal("timeout waiting for GET request to finish") - } -} - -// Tests that we don't leak Transport persistConn.readLoop goroutines -// when a server hangs up immediately after saying it would keep-alive. -func TestTransportIssue10457(t *testing.T) { - defer afterTest(t) // used to fail in goroutine leak check - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Send a response with no body, keep-alive - // (implicit), and then lie and immediately close the - // connection. This forces the Transport's readLoop to - // immediately Peek an io.EOF and get to the point - // that used to hang. - conn, _, _ := w.(http.Hijacker).Hijack() - conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive - conn.Close() - })) - defer ts.Close() - c := tc().httpClient - - res, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("Get: %v", err) - } - defer res.Body.Close() - - // Just a sanity check that we at least get the response. The real - // test here is that the "defer afterTest" above doesn't find any - // leaked goroutines. - if got, want := res.Header.Get("Foo"), "Bar"; got != want { - t.Errorf("Foo header = %q; want %q", got, want) - } -} - -type closerFunc func() error - -func (f closerFunc) Close() error { return f() } - -type writerFuncConn struct { - net.Conn - write func(p []byte) (n int, err error) -} - -func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } - -func hookSetter(dst *func()) func(func()) { - return func(fn func()) { - unnilTestHook(&fn) - *dst = fn - } -} - -var ( - SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip) - SetRoundTripRetried = hookSetter(&testHookRoundTripRetried) -) - -// Issue 6981 -func TestTransportClosesBodyOnError(t *testing.T) { - setParallel(t) - defer afterTest(t) - readBody := make(chan error, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := io.ReadAll(r.Body) - readBody <- err - })) - defer ts.Close() - c := tc().httpClient - fakeErr := errors.New("fake error") - didClose := make(chan bool, 1) - req, _ := http.NewRequest("POST", ts.URL, struct { - io.Reader - io.Closer - }{ - io.MultiReader(io.LimitReader(tests.NeverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), - closerFunc(func() error { - select { - case didClose <- true: - default: - } - return nil - }), - }) - res, err := c.Do(req) - if res != nil { - defer res.Body.Close() - } - if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { - t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) - } - select { - case err := <-readBody: - if err == nil { - t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") - } - case <-time.After(5 * time.Second): - t.Error("timeout waiting for server handler to complete") - } - select { - case <-didClose: - default: - t.Errorf("didn't see Body.Close") - } -} - -func TestTransportDialTLS(t *testing.T) { - setParallel(t) - defer afterTest(t) - var mu sync.Mutex // guards following - var gotReq, didDial bool - - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - gotReq = true - mu.Unlock() - })) - defer ts.Close() - c := tc().httpClient - c.Transport.(*Transport).DialTLSContext = func(_ context.Context, netw, addr string) (net.Conn, error) { - mu.Lock() - didDial = true - mu.Unlock() - c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) - if err != nil { - return nil, err - } - return c, c.Handshake() - } - - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - mu.Lock() - if !gotReq { - t.Error("didn't get request") - } - if !didDial { - t.Error("didn't use dial hook") - } -} - -// Test for issue 8755 -// Ensure that if a proxy returns an error, it is exposed by RoundTrip -func TestRoundTripReturnsProxyError(t *testing.T) { - badProxy := func(*http.Request) (*url.URL, error) { - return nil, errors.New("errorMessage") - } - - tr := T().SetProxy(badProxy) - - req, _ := http.NewRequest("GET", "http://example.com", nil) - - _, err := tr.RoundTrip(req) - - if err == nil { - t.Error("Expected proxy error to be returned by RoundTrip") - } -} - -// tests that putting an idle conn after a call to CloseIdleConns does return it -func TestTransportCloseIdleConnsThenReturn(t *testing.T) { - tr := T() - wantIdle := func(when string, n int) bool { - got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn - if got == n { - return true - } - t.Errorf("%s: idle conns = %d; want %d", when, got, n) - return false - } - wantIdle("start", 0) - if !tr.PutIdleTestConn("http", "example.com") { - t.Fatal("put failed") - } - if !tr.PutIdleTestConn("http", "example.com") { - t.Fatal("second put failed") - } - wantIdle("after put", 2) - tr.CloseIdleConnections() - if !tr.IsIdleForTesting() { - t.Error("should be idle after CloseIdleConnections") - } - wantIdle("after close idle", 0) - if tr.PutIdleTestConn("http", "example.com") { - t.Fatal("put didn't fail") - } - wantIdle("after second put", 0) - - tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode - if tr.IsIdleForTesting() { - t.Error("shouldn't be idle after QueueForIdleConnForTesting") - } - if !tr.PutIdleTestConn("http", "example.com") { - t.Fatal("after re-activation") - } - wantIdle("after final put", 1) -} - -// Test for issue 34282 -// Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn -func TestTransportTraceGotConnH2IdleConns(t *testing.T) { - tr := T() - wantIdle := func(when string, n int) bool { - got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 - if got == n { - return true - } - t.Errorf("%s: idle conns = %d; want %d", when, got, n) - return false - } - wantIdle("start", 0) - alt := funcRoundTripper(func() {}) - if !tr.PutIdleTestConnH2("https", "example.com:443", alt) { - t.Fatal("put failed") - } - wantIdle("after put", 1) - ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - GotConn: func(httptrace.GotConnInfo) { - // tr.getConn should leave it for the HTTP/2 alt to call GotConn. - t.Error("GotConn called") - }, - }) - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil) - _, err := tr.RoundTrip(req) - if err != errFakeRoundTrip { - t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) - } - wantIdle("after round trip", 1) -} - -func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } - - trFunc := func(tr *Transport) { - tr.MaxConnsPerHost = 1 - tr.MaxIdleConnsPerHost = 1 - tr.IdleConnTimeout = 10 * time.Millisecond - } - cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), trFunc) - defer cst.close() - - if _, err := cst.c.Get(cst.ts.URL); err != nil { - t.Fatalf("got error: %s", err) - } - - time.Sleep(100 * time.Millisecond) - got := make(chan error) - go func() { - if _, err := cst.c.Get(cst.ts.URL); err != nil { - got <- err - } - close(got) - }() - - timeout := time.NewTimer(5 * time.Second) - defer timeout.Stop() - select { - case err := <-got: - if err != nil { - t.Fatalf("got error: %s", err) - } - case <-timeout.C: - t.Fatal("request never completed") - } -} - -// This tests that a client requesting a content range won't also -// implicitly ask for gzip support. If they want that, they need to do it -// on their own. -// golang.org/issue/8923 -func TestTransportRangeAndGzip(t *testing.T) { - defer afterTest(t) - reqc := make(chan *http.Request, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - reqc <- r - })) - defer ts.Close() - c := tc().httpClient - - req, _ := http.NewRequest("GET", ts.URL, nil) - req.Header.Set("Range", "bytes=7-11") - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - - select { - case r := <-reqc: - if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - t.Error("Transport advertised gzip support in the Accept header") - } - if r.Header.Get("Range") == "" { - t.Error("no Range in request") - } - case <-time.After(10 * time.Second): - t.Fatal("timeout") - } - res.Body.Close() -} - -// Test for issue 10474 -func TestTransportResponseCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // important that this response has a body. - var b [1024]byte - w.Write(b[:]) - })) - defer ts.Close() - tr := tc().GetTransport() - - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - // If we do an early close, Transport just throws the connection away and - // doesn't reuse it. In order to trigger the bug, it has to reuse the connection - // so read the body - if _, err := io.Copy(io.Discard, res.Body); err != nil { - t.Fatal(err) - } - - req2, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - tr.CancelRequest(req) - res, err = tr.RoundTrip(req2) - if err != nil { - t.Fatal(err) - } - res.Body.Close() -} - -// Test for issue 19248: Content-Encoding's value is case insensitive. -func TestTransportContentEncodingCaseInsensitive(t *testing.T) { - setParallel(t) - defer afterTest(t) - for _, ce := range []string{"gzip", "GZIP"} { - ce := ce - t.Run(ce, func(t *testing.T) { - const encodedString = "Hello Gopher" - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Encoding", ce) - gz := gzip.NewWriter(w) - gz.Write([]byte(encodedString)) - gz.Close() - })) - defer ts.Close() - - res, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatal(err) - } - - body, err := io.ReadAll(res.Body) - res.Body.Close() - if err != nil { - t.Fatal(err) - } - - if string(body) != encodedString { - t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) - } - }) - } -} - -func TestTransportDialCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer ts.Close() - tr := tc().GetTransport() - - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - SetEnterRoundTripHook(func() { - tr.CancelRequest(req) - }) - defer SetEnterRoundTripHook(nil) - res, err := tr.RoundTrip(req) - if err != common.ErrRequestCanceled { - t.Errorf("expected canceled request error; got %v", err) - if err == nil { - res.Body.Close() - } - } -} - -// logWritesConn is a net.Conn that logs each Write call to writes -// and then proxies to w. -// It proxies Read calls to a reader it receives from rch. -type logWritesConn struct { - net.Conn // nil. crash on use. - - w io.Writer - - rch <-chan io.Reader - r io.Reader // nil until received by rch - - mu sync.Mutex - writes []string -} - -func (c *logWritesConn) Write(p []byte) (n int, err error) { - c.mu.Lock() - defer c.mu.Unlock() - c.writes = append(c.writes, string(p)) - return c.w.Write(p) -} - -func (c *logWritesConn) Read(p []byte) (n int, err error) { - if c.r == nil { - c.r = <-c.rch - } - return c.r.Read(p) -} - -func (c *logWritesConn) Close() error { return nil } - -// Issue 6574 -func TestTransportFlushesBodyChunks(t *testing.T) { - defer afterTest(t) - resBody := make(chan io.Reader, 1) - connr, connw := io.Pipe() // connection pipe pair - lw := &logWritesConn{ - rch: resBody, - w: connw, - } - tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { - return lw, nil - }) - bodyr, bodyw := io.Pipe() // body pipe pair - go func() { - defer bodyw.Close() - for i := 0; i < 3; i++ { - fmt.Fprintf(bodyw, "num%d\n", i) - } - }() - resc := make(chan *http.Response) - go func() { - req, _ := http.NewRequest("POST", "http://localhost:8080", bodyr) - req.Header.Set("User-Agent", "x") // known value for test - res, err := tr.RoundTrip(req) - if err != nil { - t.Errorf("RoundTrip: %v", err) - close(resc) - return - } - resc <- res - - }() - // Fully consume the request before checking the Write log vs. want. - req, err := http.ReadRequest(bufio.NewReader(connr)) - if err != nil { - t.Fatal(err) - } - io.Copy(io.Discard, req.Body) - - // Unblock the transport's roundTrip goroutine. - resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") - res, ok := <-resc - if !ok { - return - } - defer res.Body.Close() - - want := []string{ - "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n", - "5\r\nnum0\n\r\n", - "5\r\nnum1\n\r\n", - "5\r\nnum2\n\r\n", - "0\r\n\r\n", - } - if !reflect.DeepEqual(lw.writes, want) { - t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want) - } -} - -// Issue 22088: flush Transport request headers if we're not sure the body won't block on read. -func TestTransportFlushesRequestHeader(t *testing.T) { - defer afterTest(t) - gotReq := make(chan struct{}) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(gotReq) - })) - defer cst.close() - - pr, pw := io.Pipe() - req, err := http.NewRequest("POST", cst.ts.URL, pr) - if err != nil { - t.Fatal(err) - } - gotRes := make(chan struct{}) - go func() { - defer close(gotRes) - res, err := cst.tr.RoundTrip(req) - if err != nil { - t.Error(err) - return - } - res.Body.Close() - }() - - select { - case <-gotReq: - pw.Close() - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for handler to get request") - } - <-gotRes -} - -// Issue 11745. -func TestTransportPrefersResponseOverWriteError(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } - defer afterTest(t) - const contentLengthLimit = 1024 * 1024 // 1MB - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.ContentLength >= contentLengthLimit { - w.WriteHeader(http.StatusBadRequest) - r.Body.Close() - return - } - w.WriteHeader(http.StatusOK) - })) - defer ts.Close() - c := tc().httpClient - - fail := 0 - count := 100 - bigBody := strings.Repeat("a", contentLengthLimit*2) - for i := 0; i < count; i++ { - req, err := http.NewRequest("PUT", ts.URL, strings.NewReader(bigBody)) - if err != nil { - t.Fatal(err) - } - resp, err := c.Do(req) - if err != nil { - fail++ - t.Logf("%d = %#v", i, err) - if ue, ok := err.(*url.Error); ok { - t.Logf("urlErr = %#v", ue.Err) - if ne, ok := ue.Err.(*net.OpError); ok { - t.Logf("netOpError = %#v", ne.Err) - } - } - } else { - resp.Body.Close() - if resp.StatusCode != 400 { - t.Errorf("Expected status code 400, got %v", resp.Status) - } - } - } - if fail > 0 { - t.Errorf("Failed %v out of %v\n", fail, count) - } -} - -// Issue 13633: there was a race where we returned bodyless responses -// to callers before recycling the persistent connection, which meant -// a client doing two subsequent requests could end up on different -// connections. It's somewhat harmless but enough tests assume it's -// not true in order to test other things that it's worth fixing. -// Plus it's nice to be consistent and not have timing-dependent -// behavior. -func TestTransportReuseConnEmptyResponseBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Addr", r.RemoteAddr) - // Empty response body. - })) - defer cst.close() - n := 100 - if testing.Short() { - n = 10 - } - var firstAddr string - for i := 0; i < n; i++ { - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - log.Fatal(err) - } - addr := res.Header.Get("X-Addr") - if i == 0 { - firstAddr = addr - } else if addr != firstAddr { - t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) - } - res.Body.Close() - } -} - -func TestTransportReuseConnectionGzipChunked(t *testing.T) { - testTransportReuseConnectionGzip(t, true) -} - -func TestTransportReuseConnectionGzipContentLength(t *testing.T) { - testTransportReuseConnectionGzip(t, false) -} - -// Make sure we re-use underlying TCP connection for gzipped responses too. -func testTransportReuseConnectionGzip(t *testing.T, chunked bool) { - setParallel(t) - defer afterTest(t) - addr := make(chan string, 2) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - addr <- r.RemoteAddr - w.Header().Set("Content-Encoding", "gzip") - if chunked { - w.(http.Flusher).Flush() - } - w.Write(rgz) // arbitrary gzip response - })) - defer ts.Close() - c := tc().httpClient - - for i := 0; i < 2; i++ { - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - buf := make([]byte, len(rgz)) - if n, err := io.ReadFull(res.Body, buf); err != nil { - t.Errorf("%d. ReadFull = %v, %v", i, n, err) - } - // Note: no res.Body.Close call. It should work without it, - // since the flate.Reader's internal buffering will hit EOF - // and that should be sufficient. - } - a1, a2 := <-addr, <-addr - if a1 != a2 { - t.Fatalf("didn't reuse connection") - } -} - -func TestTransportResponseHeaderLength(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/long" { - w.Header().Set("Long", strings.Repeat("a", 1<<20)) - } - })) - defer ts.Close() - c := tc().httpClient - c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 - - if res, err := c.Get(ts.URL); err != nil { - t.Fatal(err) - } else { - res.Body.Close() - } - - res, err := c.Get(ts.URL + "/long") - if err == nil { - defer res.Body.Close() - var n int64 - for k, vv := range res.Header { - for _, v := range vv { - n += int64(len(k)) + int64(len(v)) - } - } - t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) - } - if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { - t.Errorf("got error: %v; want %q", err, want) - } -} - -type lookupIPAltResolverKey struct{} - -func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { - defer afterTest(t) - const resBody = "some body" - gotWroteReqEvent := make(chan struct{}, 500) - cst := newClientServerTest(t, h2, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" { - // Do nothing for the second request. - return - } - if _, err := io.ReadAll(r.Body); err != nil { - t.Error(err) - } - if !noHooks { - select { - case <-gotWroteReqEvent: - case <-time.After(5 * time.Second): - t.Error("timeout waiting for WroteRequest event") - } - } - io.WriteString(w, resBody) - })) - defer cst.close() - - cst.tr.ExpectContinueTimeout = 1 * time.Second - - var mu sync.Mutex // guards buf - var buf bytes.Buffer - logf := func(format string, args ...interface{}) { - mu.Lock() - defer mu.Unlock() - fmt.Fprintf(&buf, format, args...) - buf.WriteByte('\n') - } - - addrStr := cst.ts.Listener.Addr().String() - ip, port, err := net.SplitHostPort(addrStr) - if err != nil { - t.Fatal(err) - } - - // Install a fake DNS server. - ctx := context.WithValue(context.Background(), lookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { - if host != "dns-is-faked.golang" { - t.Errorf("unexpected DNS host lookup for %q/%q", network, host) - return nil, nil - } - return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil - }) - - body := "some body" - req, _ := http.NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body)) - req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"} - trace := &httptrace.ClientTrace{ - GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, - GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, - GotFirstResponseByte: func() { logf("first response byte") }, - PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, - DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, - DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, - ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, - ConnectDone: func(network, addr string, err error) { - if err != nil { - t.Errorf("ConnectDone: %v", err) - } - logf("ConnectDone: connected to %s %s = %v", network, addr, err) - }, - WroteHeaderField: func(key string, value []string) { - logf("WroteHeaderField: %s: %v", key, value) - }, - WroteHeaders: func() { - logf("WroteHeaders") - }, - Wait100Continue: func() { logf("Wait100Continue") }, - Got100Continue: func() { logf("Got100Continue") }, - WroteRequest: func(e httptrace.WroteRequestInfo) { - logf("WroteRequest: %+v", e) - gotWroteReqEvent <- struct{}{} - }, - } - if h2 { - trace.TLSHandshakeStart = func() { logf("tls handshake start") } - trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { - logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) - } - } - if noHooks { - // zero out all func pointers, trying to get some path to crash - *trace = httptrace.ClientTrace{} - } - req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) - - req.Header.Set("Expect", "100-continue") - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - logf("got roundtrip.response") - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - logf("consumed body") - if string(slurp) != resBody || res.StatusCode != 200 { - t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) - } - res.Body.Close() - - if noHooks { - // Done at this point. Just testing a full HTTP - // requests can happen with a trace pointing to a zero - // ClientTrace, full of nil func pointers. - return - } - - mu.Lock() - got := buf.String() - mu.Unlock() - - wantOnce := func(sub string) { - if strings.Count(got, sub) != 1 { - t.Errorf("expected substring %q exactly once in output.", sub) - } - } - wantOnceOrMore := func(sub string) { - if strings.Count(got, sub) == 0 { - t.Errorf("expected substring %q at least once in output.", sub) - } - } - wantOnce("Getting conn for dns-is-faked.golang:" + port) - wantOnce("DNS start: {Host:dns-is-faked.golang}") - wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err: Coalesced:false}") - wantOnce("got conn: {") - wantOnceOrMore("Connecting to tcp " + addrStr) - wantOnceOrMore("connected to tcp " + addrStr + " = ") - wantOnce("Reused:false WasIdle:false IdleTime:0s") - wantOnce("first response byte") - if h2 { - wantOnce("tls handshake start") - wantOnce("tls handshake done") - } else { - wantOnce("PutIdleConn = ") - wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]") - // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the - // WroteHeaderField hook is not yet implemented in h2.) - wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port)) - wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body))) - wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]") - wantOnce("WroteHeaderField: Accept-Encoding: [gzip]") - } - wantOnce("WroteHeaders") - wantOnce("Wait100Continue") - wantOnce("Got100Continue") - wantOnce("WroteRequest: {Err:}") - if strings.Contains(got, " to udp ") { - t.Errorf("should not see UDP (DNS) connections") - } - if t.Failed() { - t.Errorf("Output:\n%s", got) - } - - // And do a second request: - req, _ = http.NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil) - req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) - res, err = cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - if res.StatusCode != 200 { - t.Fatal(res.Status) - } - res.Body.Close() - - mu.Lock() - got = buf.String() - mu.Unlock() - - sub := "Getting conn for dns-is-faked.golang:" - if gotn, want := strings.Count(got, sub), 2; gotn != want { - t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got) - } - -} - -func TestTransportEventTraceTLSVerify(t *testing.T) { - var mu sync.Mutex - var buf bytes.Buffer - logf := func(format string, args ...interface{}) { - mu.Lock() - defer mu.Unlock() - fmt.Fprintf(&buf, format, args...) - buf.WriteByte('\n') - } - - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Error("Unexpected request") - })) - defer ts.Close() - ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { - logf("%s", p) - return len(p), nil - }), "", 0) - - certpool := x509.NewCertPool() - certpool.AddCert(ts.Certificate()) - - tr := T().SetTLSClientConfig(&tls.Config{ - ServerName: "dns-is-faked.golang", - RootCAs: certpool, - }) - c := &http.Client{Transport: tr} - - trace := &httptrace.ClientTrace{ - TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, - TLSHandshakeDone: func(s tls.ConnectionState, err error) { - logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) - }, - } - - req, _ := http.NewRequest("GET", ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) - _, err := c.Do(req) - if err == nil { - t.Error("Expected request to fail TLS verification") - } - - mu.Lock() - got := buf.String() - mu.Unlock() - - wantOnce := func(sub string) { - if strings.Count(got, sub) != 1 { - t.Errorf("expected substring %q exactly once in output.", sub) - } - } - - wantOnce("TLSHandshakeStart") - wantOnce("TLSHandshakeDone") - wantOnce("x509: certificate is valid for example.com") - - if t.Failed() { - t.Errorf("Output:\n%s", got) - } -} - -var ( - isDNSHijackedOnce sync.Once - isDNSHijacked bool -) - -func skipIfDNSHijacked(t *testing.T) { - // Skip this test if the user is using a shady/ISP - // DNS server hijacking queries. - // See issues 16732, 16716. - isDNSHijackedOnce.Do(func() { - addrs, _ := net.LookupHost("dns-should-not-resolve.golang") - isDNSHijacked = len(addrs) != 0 - }) - if isDNSHijacked { - t.Skip("skipping; test requires non-hijacking DNS server") - } -} - -func TestTransportEventTraceRealDNS(t *testing.T) { - skipIfDNSHijacked(t) - defer afterTest(t) - tr := T() - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - var mu sync.Mutex // guards buf - var buf bytes.Buffer - logf := func(format string, args ...interface{}) { - mu.Lock() - defer mu.Unlock() - fmt.Fprintf(&buf, format, args...) - buf.WriteByte('\n') - } - - req, _ := http.NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) - trace := &httptrace.ClientTrace{ - DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, - DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, - ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, - ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, - } - req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) - - resp, err := c.Do(req) - if err == nil { - resp.Body.Close() - t.Fatal("expected error during DNS lookup") - } - - mu.Lock() - got := buf.String() - mu.Unlock() - - wantSub := func(sub string) { - if !strings.Contains(got, sub) { - t.Errorf("expected substring %q in output.", sub) - } - } - wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") - wantSub("DNSDone: {Addrs:[] Err:") - if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { - t.Errorf("should not see Connect events") - } - if t.Failed() { - t.Errorf("Output:\n%s", got) - } -} - -// Issue 14353: port can only contain digits. -func TestTransportRejectsAlphaPort(t *testing.T) { - res, err := http.Get("http://dummy.tld:123foo/bar") - if err == nil { - res.Body.Close() - t.Fatal("unexpected success") - } - ue, ok := err.(*url.Error) - if !ok { - t.Fatalf("got %#v; want *url.Error", err) - } - got := ue.Err.Error() - want := `invalid port ":123foo" after host` - if got != want { - t.Errorf("got error %q; want %q", got, want) - } -} - -// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 -// connections. The http2 test is done in TestTransportEventTrace_h2 -func TestTLSHandshakeTrace(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer ts.Close() - - var mu sync.Mutex - var start, done bool - trace := &httptrace.ClientTrace{ - TLSHandshakeStart: func() { - mu.Lock() - defer mu.Unlock() - start = true - }, - TLSHandshakeDone: func(s tls.ConnectionState, err error) { - mu.Lock() - defer mu.Unlock() - done = true - if err != nil { - t.Fatal("Expected error to be nil but was:", err) - } - }, - } - - c := tc().httpClient - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal("Unable to construct test request:", err) - } - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - - r, err := c.Do(req) - if err != nil { - t.Fatal("Unexpected error making request:", err) - } - r.Body.Close() - mu.Lock() - defer mu.Unlock() - if !start { - t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") - } - if !done { - t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") - } -} - -// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an -// HTTP/2 connection was established but its caller no longer -// wanted it. (Assuming the connection cache was enabled, which it is -// by default) -// -// This test reproduced the crash by setting the IdleConnTimeout low -// (to make the test reasonable) and then making a request which is -// canceled by the DialTLS hook, which then also waits to return the -// real connection until after the RoundTrip saw the error. Then we -// know the successful tls.Dial from DialTLS will need to go into the -// idle pool. Then we give it a of time to explode. -func TestIdleConnH2Crash(t *testing.T) { - setParallel(t) - cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // nothing - })) - defer cst.close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - sawDoErr := make(chan bool, 1) - testDone := make(chan struct{}) - defer close(testDone) - - cst.tr.IdleConnTimeout = 5 * time.Millisecond - cst.tr.DialTLSContext = func(_ context.Context, network, addr string) (net.Conn, error) { - c, err := tls.Dial(network, addr, &tls.Config{ - InsecureSkipVerify: true, - NextProtos: []string{"h2"}, - }) - if err != nil { - t.Error(err) - return nil, err - } - if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" { - t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2") - c.Close() - return nil, errors.New("bogus") - } - - cancel() - - failTimer := time.NewTimer(5 * time.Second) - defer failTimer.Stop() - select { - case <-sawDoErr: - case <-testDone: - case <-failTimer.C: - t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail") - } - return c, nil - } - - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - req = req.WithContext(ctx) - res, err := cst.c.Do(req) - if err == nil { - res.Body.Close() - t.Fatal("unexpected success") - } - sawDoErr <- true - - // Wait for the explosion. - time.Sleep(cst.tr.IdleConnTimeout * 10) -} - -type funcConn struct { - net.Conn - read func([]byte) (int, error) - write func([]byte) (int, error) -} - -func (c funcConn) Read(p []byte) (int, error) { return c.read(p) } -func (c funcConn) Write(p []byte) (int, error) { return c.write(p) } -func (c funcConn) Close() error { return nil } - -// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek -// back to the caller. -func TestTransportReturnsPeekError(t *testing.T) { - errValue := errors.New("specific error value") - - wrote := make(chan struct{}) - var wroteOnce sync.Once - - tr := T().SetDial(func(_ context.Context, network, addr string) (net.Conn, error) { - c := funcConn{ - read: func([]byte) (int, error) { - <-wrote - return 0, errValue - }, - write: func(p []byte) (int, error) { - wroteOnce.Do(func() { close(wrote) }) - return len(p), nil - }, - } - return c, nil - }) - - _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) - if err != errValue { - t.Errorf("error = %#v; want %v", err, errValue) - } -} - -// Issue 13290: send User-Agent in proxy CONNECT -func TestTransportProxyConnectHeader(t *testing.T) { - defer afterTest(t) - reqc := make(chan *http.Request, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "CONNECT" { - t.Errorf("method = %q; want CONNECT", r.Method) - } - reqc <- r - c, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Errorf("Hijack: %v", err) - return - } - c.Close() - })) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).Proxy = func(r *http.Request) (*url.URL, error) { - return url.Parse(ts.URL) - } - c.Transport.(*Transport).ProxyConnectHeader = http.Header{ - "User-Agent": {"foo"}, - "Other": {"bar"}, - } - - res, err := c.Get("https://dummy.tld/") // https to force a CONNECT - if err == nil { - res.Body.Close() - t.Errorf("unexpected success") - } - select { - case <-time.After(3 * time.Second): - t.Fatal("timeout") - case r := <-reqc: - if got, want := r.Header.Get("User-Agent"), "foo"; got != want { - t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) - } - if got, want := r.Header.Get("Other"), "bar"; got != want { - t.Errorf("CONNECT request Other = %q; want %q", got, want) - } - } -} - -func TestTransportProxyGetConnectHeader(t *testing.T) { - defer afterTest(t) - reqc := make(chan *http.Request, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "CONNECT" { - t.Errorf("method = %q; want CONNECT", r.Method) - } - reqc <- r - c, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Errorf("Hijack: %v", err) - return - } - c.Close() - })) - defer ts.Close() - - c := tc().httpClient - c.Transport.(*Transport).Proxy = func(r *http.Request) (*url.URL, error) { - return url.Parse(ts.URL) - } - // These should be ignored: - c.Transport.(*Transport).ProxyConnectHeader = http.Header{ - "User-Agent": {"foo"}, - "Other": {"bar"}, - } - c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (http.Header, error) { - return http.Header{ - "User-Agent": {"foo2"}, - "Other": {"bar2"}, - }, nil - } - - res, err := c.Get("https://dummy.tld/") // https to force a CONNECT - if err == nil { - res.Body.Close() - t.Errorf("unexpected success") - } - select { - case <-time.After(3 * time.Second): - t.Fatal("timeout") - case r := <-reqc: - if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { - t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) - } - if got, want := r.Header.Get("Other"), "bar2"; got != want { - t.Errorf("CONNECT request Other = %q; want %q", got, want) - } - } -} - -var errFakeRoundTrip = errors.New("fake roundtrip") - -type funcRoundTripper func() - -func (fn funcRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { - fn() - return nil, errFakeRoundTrip -} - -func wantBody(res *http.Response, err error, want string) error { - if err != nil { - return err - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - return fmt.Errorf("error reading body: %v", err) - } - if string(slurp) != want { - return fmt.Errorf("body = %q; want %q", slurp, want) - } - if err := res.Body.Close(); err != nil { - return fmt.Errorf("body Close = %v", err) - } - return nil -} - -type countCloseReader struct { - n *int - io.Reader -} - -func (cr countCloseReader) Close() error { - (*cr.n)++ - return nil -} - -// rgz is a gzip quine that uncompresses to itself. -var rgz = []byte{ - 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, - 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, - 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, - 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, - 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, - 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, - 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, - 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, - 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, - 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, - 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, - 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, - 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, - 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, - 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, - 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, - 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, - 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, - 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, - 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, - 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, - 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, - 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, - 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, - 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, - 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, - 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, - 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, - 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, - 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, - 0x00, 0x00, -} - -// Ensure that a missing status doesn't make the server panic -// See Issue https://golang.org/issues/21701 -func TestMissingStatusNoPanic(t *testing.T) { - t.Parallel() - - const want = "unknown status code" - - ln := tests.NewLocalListener(t) - addr := ln.Addr().String() - done := make(chan bool) - fullAddrURL := fmt.Sprintf("http://%s", addr) - raw := "HTTP/1.1 400\r\n" + - "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + - "Content-Type: text/html; charset=utf-8\r\n" + - "Content-Length: 10\r\n" + - "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + - "Vary: Accept-Encoding\r\n\r\n" + - "Aloha Olaa" - - go func() { - defer close(done) - - conn, _ := ln.Accept() - if conn != nil { - io.WriteString(conn, raw) - io.ReadAll(conn) - conn.Close() - } - }() - - proxyURL, err := url.Parse(fullAddrURL) - if err != nil { - t.Fatalf("proxyURL: %v", err) - } - - tr := T().SetProxy(http.ProxyURL(proxyURL)) - - req, _ := http.NewRequest("GET", "https://golang.org/", nil) - res, err, panicked := doFetchCheckPanic(tr, req) - if panicked { - t.Error("panicked, expecting an error") - } - if res != nil && res.Body != nil { - io.Copy(io.Discard, res.Body) - res.Body.Close() - } - - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("got=%v want=%q", err, want) - } - - ln.Close() - <-done -} - -func doFetchCheckPanic(tr *Transport, req *http.Request) (res *http.Response, err error, panicked bool) { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - res, err = tr.RoundTrip(req) - return -} - -// Issue 22330: do not allow the response body to be read when the status code -// forbids a response body. -func TestNoBodyOnChunked304Response(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - - // Our test server above is sending back bogus data after the - // response (the "0\r\n\r\n" part), which causes the Transport - // code to log spam. Disable keep-alives so we never even try - // to reuse the connection. - cst.tr.DisableKeepAlives = true - - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - t.Fatal(err) - } - - if res.Body != NoBody { - t.Errorf("Unexpected body on 304 response") - } -} - -type funcWriter func([]byte) (int, error) - -func (f funcWriter) Write(p []byte) (int, error) { return f(p) } - -type doneContext struct { - context.Context - err error -} - -func (doneContext) Done() <-chan struct{} { - c := make(chan struct{}) - close(c) - return c -} - -func (d doneContext) Err() error { return d.err } - -// Issue 25852: Transport should check whether Context is done early. -func TestTransportCheckContextDoneEarly(t *testing.T) { - tr := T() - req, _ := http.NewRequest("GET", "http://fake.example/", nil) - wantErr := errors.New("some error") - req = req.WithContext(doneContext{context.Background(), wantErr}) - _, err := tr.RoundTrip(req) - if err != wantErr { - t.Errorf("error = %v; want %v", err, wantErr) - } -} - -// Issue 23399: verify that if a client request times out, the Transport's -// conn is closed so that it's not reused. -// -// This is the test variant that times out before the server replies with -// any response headers. -func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) - inHandler := make(chan net.Conn, 1) - handlerReadReturned := make(chan bool, 1) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - inHandler <- conn - n, err := conn.Read([]byte{0}) - if n != 0 || err != io.EOF { - t.Errorf("unexpected Read result: %v, %v", n, err) - } - handlerReadReturned <- true - })) - defer cst.close() - - const timeout = 50 * time.Millisecond - cst.c.Timeout = timeout - - _, err := cst.c.Get(cst.ts.URL) - if err == nil { - t.Fatal("unexpected Get succeess") - } - - select { - case c := <-inHandler: - select { - case <-handlerReadReturned: - // Success. - return - case <-time.After(5 * time.Second): - t.Error("Handler's conn.Read seems to be stuck in Read") - c.Close() // close it to unblock Handler - } - case <-time.After(timeout * 10): - // If we didn't get into the Handler in 50ms, that probably means - // the builder was just slow and the Get failed in that time - // but never made it to the server. That's fine. We'll usually - // test the part above on faster machines. - t.Skip("skipping test on slow builder") - } -} - -// Issue 23399: verify that if a client request times out, the Transport's -// conn is closed so that it's not reused. -// -// This is the test variant that has the server send response headers -// first, and time out during the write of the response body. -func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) - inHandler := make(chan net.Conn, 1) - handlerResult := make(chan error, 1) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "100") - w.(http.Flusher).Flush() - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - conn.Write([]byte("foo")) - inHandler <- conn - n, err := conn.Read([]byte{0}) - // The error should be io.EOF or "read tcp - // 127.0.0.1:35827->127.0.0.1:40290: read: connection - // reset by peer" depending on timing. Really we just - // care that it returns at all. But if it returns with - // data, that's weird. - if n != 0 || err == nil { - handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err) - return - } - handlerResult <- nil - })) - defer cst.close() - - // Set Timeout to something very long but non-zero to exercise - // the codepaths that check for it. But rather than wait for it to fire - // (which would make the test slow), we send on the req.Cancel channel instead, - // which happens to exercise the same code paths. - cst.c.Timeout = time.Minute // just to be non-zero, not to hit it. - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - cancel := make(chan struct{}) - req.Cancel = cancel - - res, err := cst.c.Do(req) - if err != nil { - select { - case <-inHandler: - t.Fatalf("Get error: %v", err) - default: - // Failed before entering handler. Ignore result. - t.Skip("skipping test on slow builder") - } - } - - close(cancel) - got, err := io.ReadAll(res.Body) - if err == nil { - t.Fatalf("unexpected success; read %q, nil", got) - } - - select { - case c := <-inHandler: - select { - case err := <-handlerResult: - if err != nil { - t.Errorf("handler: %v", err) - } - return - case <-time.After(5 * time.Second): - t.Error("Handler's conn.Read seems to be stuck in Read") - c.Close() // close it to unblock Handler - } - case <-time.After(5 * time.Second): - t.Fatal("timeout") - } -} - -func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { - setParallel(t) - defer afterTest(t) - done := make(chan struct{}) - defer close(done) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - defer conn.Close() - io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") - bs := bufio.NewScanner(conn) - bs.Scan() - fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) - <-done - })) - defer cst.close() - - req, _ := http.NewRequest("GET", cst.ts.URL, nil) - req.Header.Set("Upgrade", "foo") - req.Header.Set("Connection", "upgrade") - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - if res.StatusCode != 101 { - t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) - } - rwc, ok := res.Body.(io.ReadWriteCloser) - if !ok { - t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) - } - defer rwc.Close() - bs := bufio.NewScanner(rwc) - if !bs.Scan() { - t.Fatalf("expected readable input") - } - if got, want := bs.Text(), "Some buffered data"; got != want { - t.Errorf("read %q; want %q", got, want) - } - io.WriteString(rwc, "echo\n") - if !bs.Scan() { - t.Fatalf("expected another line") - } - if got, want := bs.Text(), "ECHO"; got != want { - t.Errorf("read %q; want %q", got, want) - } -} - -func TestTransportRequestReplayable(t *testing.T) { - someBody := io.NopCloser(strings.NewReader("")) - tests := []struct { - name string - req *http.Request - want bool - }{ - { - name: "GET", - req: &http.Request{Method: "GET"}, - want: true, - }, - { - name: "GET_http.NoBody", - req: &http.Request{Method: "GET", Body: NoBody}, - want: true, - }, - { - name: "GET_body", - req: &http.Request{Method: "GET", Body: someBody}, - want: false, - }, - { - name: "POST", - req: &http.Request{Method: "POST"}, - want: false, - }, - { - name: "POST_idempotency-key", - req: &http.Request{Method: "POST", Header: http.Header{"Idempotency-Key": {"x"}}}, - want: true, - }, - { - name: "POST_x-idempotency-key", - req: &http.Request{Method: "POST", Header: http.Header{"X-Idempotency-Key": {"x"}}}, - want: true, - }, - { - name: "POST_body", - req: &http.Request{Method: "POST", Header: http.Header{"Idempotency-Key": {"x"}}, Body: someBody}, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := isReplayable(tt.req) - if got != tt.want { - t.Errorf("replyable = %v; want %v", got, tt.want) - } - }) - } -} - -// testMockTCPConn is a mock TCP connection used to test that -// ReadFrom is called when sending the request body. -type testMockTCPConn struct { - *net.TCPConn - - ReadFromCalled bool -} - -func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { - c.ReadFromCalled = true - return c.TCPConn.ReadFrom(r) -} - -func TestTransportRequestWriteRoundTrip(t *testing.T) { - nBytes := int64(1 << 10) - newFileFunc := func() (r io.Reader, done func(), err error) { - f, err := os.CreateTemp("", "net-http-newfilefunc") - if err != nil { - return nil, nil, err - } - - // Write some bytes to the file to enable reading. - if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { - return nil, nil, fmt.Errorf("failed to write data to file: %v", err) - } - if _, err := f.Seek(0, 0); err != nil { - return nil, nil, fmt.Errorf("failed to seek to front: %v", err) - } - - done = func() { - f.Close() - os.Remove(f.Name()) - } - - return f, done, nil - } - - newBufferFunc := func() (io.Reader, func(), error) { - return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil - } - - cases := []struct { - name string - readerFunc func() (io.Reader, func(), error) - contentLength int64 - expectedReadFrom bool - }{ - { - name: "file, length", - readerFunc: newFileFunc, - contentLength: nBytes, - expectedReadFrom: true, - }, - { - name: "file, no length", - readerFunc: newFileFunc, - }, - { - name: "file, negative length", - readerFunc: newFileFunc, - contentLength: -1, - }, - { - name: "buffer", - contentLength: nBytes, - readerFunc: newBufferFunc, - }, - { - name: "buffer, no length", - readerFunc: newBufferFunc, - }, - { - name: "buffer, length -1", - contentLength: -1, - readerFunc: newBufferFunc, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - r, cleanup, err := tc.readerFunc() - if err != nil { - t.Fatal(err) - } - defer cleanup() - - tConn := &testMockTCPConn{} - trFunc := func(tr *Transport) { - tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - var d net.Dialer - conn, err := d.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) - } - - tConn.TCPConn = tcpConn - return tConn, nil - } - } - - cst := newClientServerTest( - t, - h1Mode, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.Copy(io.Discard, r.Body) - r.Body.Close() - w.WriteHeader(200) - }), - trFunc, - ) - defer cst.close() - - req, err := http.NewRequest("PUT", cst.ts.URL, r) - if err != nil { - t.Fatal(err) - } - req.ContentLength = tc.contentLength - req.Header.Set("Content-Type", "application/octet-stream") - resp, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - t.Fatalf("status code = %d; want 200", resp.StatusCode) - } - - if !tConn.ReadFromCalled && tc.expectedReadFrom { - t.Fatalf("did not call ReadFrom") - } - - if tConn.ReadFromCalled && !tc.expectedReadFrom { - t.Fatalf("ReadFrom was unexpectedly invoked") - } - }) - } -} - -func TestTransportClone(t *testing.T) { - tr := &Transport{ - Headers: http.Header{ - "test-key": []string{"test-value"}, - }, - Cookies: []*http.Cookie{ - { - Name: "test", - Value: "test", - }, - }, - forceHttpVersion: h1, - Options: transport.Options{ - Proxy: func(*http.Request) (*url.URL, error) { panic("") }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, - DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, - TLSClientConfig: new(tls.Config), - TLSHandshakeTimeout: time.Second, - DisableKeepAlives: true, - DisableCompression: true, - MaxIdleConns: 1, - MaxIdleConnsPerHost: 1, - MaxConnsPerHost: 1, - IdleConnTimeout: time.Second, - ResponseHeaderTimeout: time.Second, - ExpectContinueTimeout: time.Second, - ProxyConnectHeader: http.Header{}, - GetProxyConnectHeader: func(context.Context, *url.URL, string) (http.Header, error) { return nil, nil }, - MaxResponseHeaderBytes: 1, - ReadBufferSize: 1, - WriteBufferSize: 1, - Debugf: func(format string, v ...interface{}) {}, - }, - } - tr2 := tr.Clone() - rv := reflect.ValueOf(tr2).Elem() - rt := rv.Type() - for i := 0; i < rt.NumField(); i++ { - sf := rt.Field(i) - if !token.IsExported(sf.Name) { - continue - } - if rv.Field(i).IsZero() { - t.Errorf("cloned field t2.%s is zero", sf.Name) - } - } - - // But test that a nil TLSNextProto is kept nil: - tr = new(Transport) - tr2 = tr.Clone() -} - -func TestIs408(t *testing.T) { - tests := []struct { - in string - want bool - }{ - {"HTTP/1.0 408", true}, - {"HTTP/1.1 408", true}, - {"HTTP/1.8 408", true}, - {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. - {"HTTP/1.1 408 ", true}, - {"HTTP/1.1 40", false}, - {"http/1.0 408", false}, - {"HTTP/1-1 408", false}, - } - for _, tt := range tests { - if got := is408Message([]byte(tt.in)); got != tt.want { - t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) - } - } -} - -func TestTransportIgnores408(t *testing.T) { - // Not parallel. Relies on mutating the log package's global Output. - defer log.SetOutput(log.Writer()) - - var logout bytes.Buffer - log.SetOutput(&logout) - - defer afterTest(t) - const target = "backend:443" - - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - nc, _, err := w.(http.Hijacker).Hijack() - if err != nil { - t.Error(err) - return - } - defer nc.Close() - nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) - nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail - })) - defer cst.close() - req, err := http.NewRequest("GET", cst.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - res, err := cst.c.Do(req) - if err != nil { - t.Fatal(err) - } - slurp, err := io.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - if err != nil { - t.Fatal(err) - } - if string(slurp) != "ok" { - t.Fatalf("got %q; want ok", slurp) - } - - t0 := time.Now() - for i := 0; i < 50; i++ { - time.Sleep(time.Duration(i) * 5 * time.Millisecond) - if cst.tr.IdleConnKeyCountForTesting() == 0 { - if got := logout.String(); got != "" { - t.Fatalf("expected no log output; got: %s", got) - } - return - } - } - t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) -} - -func TestInvalidHeaderResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, buf, _ := w.(http.Hijacker).Hijack() - buf.Write([]byte("HTTP/1.1 200 OK\r\n" + - "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + - "Content-Type: text/html; charset=utf-8\r\n" + - "Content-Length: 0\r\n" + - "Foo : bar\r\n\r\n")) - buf.Flush() - conn.Close() - })) - defer cst.close() - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - t.Fatal(err) - } - defer res.Body.Close() - if v := res.Header.Get("Foo"); v != "" { - t.Errorf(`unexpected "Foo" header: %q`, v) - } - if v := res.Header.Get("Foo "); v != "bar" { - t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") - } -} - -type bodyCloser bool - -func (bc *bodyCloser) Close() error { - *bc = true - return nil -} -func (bc *bodyCloser) Read(b []byte) (n int, err error) { - return 0, io.EOF -} - -// Issue 35015: ensure that Transport closes the body on any error -// with an invalid request, as promised by Client.Do docs. -func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { - cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Errorf("Should not have been invoked") - })) - defer cst.Close() - - u, _ := url.Parse(cst.URL) - - tests := []struct { - name string - req *http.Request - wantErr string - }{ - { - name: "invalid method", - req: &http.Request{ - Method: " ", - URL: u, - }, - wantErr: "invalid method", - }, - { - name: "nil URL", - req: &http.Request{ - Method: "GET", - }, - wantErr: "nil Request.URL", - }, - { - name: "invalid header key", - req: &http.Request{ - Method: "GET", - Header: http.Header{"💡": {"emoji"}}, - URL: u, - }, - wantErr: "invalid header field name", - }, - { - name: "invalid header value", - req: &http.Request{ - Method: "POST", - Header: http.Header{"key": {"\x19"}}, - URL: u, - }, - wantErr: "invalid header field value", - }, - { - name: "non HTTP(s) scheme", - req: &http.Request{ - Method: "POST", - URL: &url.URL{Scheme: "faux"}, - }, - wantErr: "unsupported protocol scheme", - }, - { - name: "no Host in URL", - req: &http.Request{ - Method: "POST", - URL: &url.URL{Scheme: "http"}, - }, - wantErr: "no Host", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var bc bodyCloser - req := tt.req - req.Body = &bc - _, err := DefaultClient().httpClient.Do(tt.req) - if err == nil { - t.Fatal("Expected an error") - } - if !bc { - t.Fatal("Expected body to have been closed") - } - if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) { - t.Fatalf("Error mismatch\n\t%q\ndoes not contain\n\t%q", g, w) - } - }) - } -} - -// breakableConn is a net.Conn wrapper with a Write method -// that will fail when its brokenState is true. -type breakableConn struct { - net.Conn - *brokenState -} - -type brokenState struct { - sync.Mutex - broken bool -} - -func (w *breakableConn) Write(b []byte) (n int, err error) { - w.Lock() - defer w.Unlock() - if w.broken { - return 0, errors.New("some write error") - } - return w.Conn.Write(b) -} - -// Issue 34978: don't cache a broken HTTP/2 connection -func TestDontCacheBrokenHTTP2Conn(t *testing.T) { - cst := newClientServerTest(t, h2Mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), optQuietLog) - defer cst.close() - - var brokenState brokenState - - const numReqs = 5 - var numDials, gotConns uint32 // atomic - - cst.tr.DialContext = func(_ context.Context, netw, addr string) (net.Conn, error) { - atomic.AddUint32(&numDials, 1) - c, err := net.Dial(netw, addr) - if err != nil { - t.Errorf("unexpected Dial error: %v", err) - return nil, err - } - return &breakableConn{c, &brokenState}, err - } - - for i := 1; i <= numReqs; i++ { - brokenState.Lock() - brokenState.broken = false - brokenState.Unlock() - - // doBreak controls whether we break the TCP connection after the TLS - // handshake (before the HTTP/2 handshake). We test a few failures - // in a row followed by a final success. - doBreak := i != numReqs - - ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - GotConn: func(info httptrace.GotConnInfo) { - t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) - atomic.AddUint32(&gotConns, 1) - }, - TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { - brokenState.Lock() - defer brokenState.Unlock() - if doBreak { - brokenState.broken = true - } - }, - }) - req, err := http.NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) - if err != nil { - t.Fatal(err) - } - _, err = cst.c.Do(req) - if doBreak != (err != nil) { - t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) - } - } - if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { - t.Errorf("GotConn calls = %v; want %v", got, want) - } - if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { - t.Errorf("Dials = %v; want %v", got, want) - } -} - -// Issue 34941 -// When the client has too many concurrent requests on a single connection, -// http.http2noCachedConnError is reported on multiple requests. There should -// only be one decrement regardless of the number of failures. -func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { - defer afterTest(t) - - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("foo")) - if err != nil { - t.Fatalf("Write: %v", err) - } - }) - - ts := httptest.NewUnstartedServer(h) - ts.EnableHTTP2 = true - ts.StartTLS() - defer ts.Close() - - c := tc().httpClient - tr := c.Transport.(*Transport) - tr.MaxConnsPerHost = 1 - - errCh := make(chan error, 300) - doReq := func() { - resp, err := c.Get(ts.URL) - if err != nil { - errCh <- fmt.Errorf("request failed: %v", err) - return - } - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - errCh <- fmt.Errorf("read body failed: %v", err) - } - } - - var wg sync.WaitGroup - for i := 0; i < 300; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doReq() - }() - } - wg.Wait() - close(errCh) - - for err := range errCh { - t.Errorf("error occurred: %v", err) - } -} - -// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers -// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. -func TestTransportRejectsSignInContentLength(t *testing.T) { - cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", "+3") - w.Write([]byte("abc")) - })) - defer cst.Close() - - c := cst.Client() - res, err := c.Get(cst.URL) - if err == nil || res != nil { - t.Fatal("Expected a non-nil error and a nil http.Response") - } - if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { - t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) - } -} - -// dumpConn is a net.Conn which writes to Writer and reads from Reader -type dumpConn struct { - io.Writer - io.Reader -} - -func (c *dumpConn) Close() error { return nil } -func (c *dumpConn) LocalAddr() net.Addr { return nil } -func (c *dumpConn) RemoteAddr() net.Addr { return nil } -func (c *dumpConn) SetDeadline(t time.Time) error { return nil } -func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } -func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } - -// delegateReader is a reader that delegates to another reader, -// once it arrives on a channel. -type delegateReader struct { - c chan io.Reader - r io.Reader // nil until received from c -} - -func (r *delegateReader) Read(p []byte) (int, error) { - if r.r == nil { - var ok bool - if r.r, ok = <-r.c; !ok { - return 0, errors.New("delegate closed") - } - } - return r.r.Read(p) -} - -func testTransportRace(req *http.Request) { - save := req.Body - pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() - dr := &delegateReader{c: make(chan io.Reader)} - - t := T().SetDial(func(_ context.Context, net, addr string) (net.Conn, error) { - return &dumpConn{pw, dr}, nil - }) - defer t.CloseIdleConnections() - - quitReadCh := make(chan struct{}) - // Wait for the request before replying with a dummy response: - go func() { - defer close(quitReadCh) - - req, err := http.ReadRequest(bufio.NewReader(pr)) - if err == nil { - // Ensure all the body is read; otherwise - // we'll get a partial dump. - io.Copy(io.Discard, req.Body) - req.Body.Close() - } - select { - case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): - case quitReadCh <- struct{}{}: - // Ensure delegate is closed so Read doesn't block forever. - close(dr.c) - } - }() - - t.RoundTrip(req) - - // Ensure the reader returns before we reset req.Body to prevent - // a data race on req.Body. - pw.Close() - <-quitReadCh - - req.Body = save -} - -// Issue 37669 -// Test that a cancellation doesn't result in a data race due to the writeLoop -// goroutine being left running, if the caller mutates the processed Request -// upon completion. -func TestErrorWriteLoopRace(t *testing.T) { - if testing.Short() { - return - } - t.Parallel() - for i := 0; i < 1000; i++ { - delay := time.Duration(mrand.Intn(5)) * time.Millisecond - ctx, cancel := context.WithTimeout(context.Background(), delay) - defer cancel() - - r := bytes.NewBuffer(make([]byte, 10000)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://example.com", r) - if err != nil { - t.Fatal(err) - } - - testTransportRace(req) - } -} - -// Issue 41600 -// Test that a new request which uses the connection of an active request -// cannot cause it to be canceled as well. -func TestCancelRequestWhenSharingConnection(t *testing.T) { - reqc := make(chan chan struct{}, 2) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - ch := make(chan struct{}, 1) - reqc <- ch - <-ch - w.Header().Add("Content-Length", "0") - })) - defer ts.Close() - - client := tc().httpClient - transport := client.Transport.(*Transport) - transport.MaxIdleConns = 1 - transport.MaxConnsPerHost = 1 - - var wg sync.WaitGroup - - wg.Add(1) - putidlec := make(chan chan struct{}) - go func() { - defer wg.Done() - ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ - PutIdleConn: func(error) { - // Signal that the idle conn has been returned to the pool, - // and wait for the order to proceed. - ch := make(chan struct{}) - putidlec <- ch - <-ch - }, - }) - req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL, nil) - res, err := client.Do(req) - if err == nil { - res.Body.Close() - } - if err != nil { - t.Errorf("request 1: got err %v, want nil", err) - } - }() - - // Wait for the first request to receive a response and return the - // connection to the idle pool. - r1c := <-reqc - close(r1c) - idlec := <-putidlec - - wg.Add(1) - cancelctx, cancel := context.WithCancel(context.Background()) - go func() { - defer wg.Done() - req, _ := http.NewRequestWithContext(cancelctx, "GET", ts.URL, nil) - res, err := client.Do(req) - if err == nil { - res.Body.Close() - } - if !errors.Is(err, context.Canceled) { - t.Errorf("request 2: got err %v, want Canceled", err) - } - }() - - // Wait for the second request to arrive at the server, and then cancel - // the request context. - r2c := <-reqc - cancel() - - // Give the cancelation a moment to take effect, and then unblock the first request. - time.Sleep(1 * time.Millisecond) - close(idlec) - - close(r2c) - wg.Wait() -}