diff --git a/internal/bisect/bisect.go b/internal/bisect/bisect.go new file mode 100644 index 00000000..3e5a6849 --- /dev/null +++ b/internal/bisect/bisect.go @@ -0,0 +1,794 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package bisect can be used by compilers and other programs +// to serve as a target for the bisect debugging tool. +// See [golang.org/x/tools/cmd/bisect] for details about using the tool. +// +// To be a bisect target, allowing bisect to help determine which of a set of independent +// changes provokes a failure, a program needs to: +// +// 1. Define a way to accept a change pattern on its command line or in its environment. +// The most common mechanism is a command-line flag. +// The pattern can be passed to [New] to create a [Matcher], the compiled form of a pattern. +// +// 2. Assign each change a unique ID. One possibility is to use a sequence number, +// but the most common mechanism is to hash some kind of identifying information +// like the file and line number where the change might be applied. +// [Hash] hashes its arguments to compute an ID. +// +// 3. Enable each change that the pattern says should be enabled. +// The [Matcher.ShouldEnable] method answers this question for a given change ID. +// +// 4. Print a report identifying each change that the pattern says should be printed. +// The [Matcher.ShouldPrint] method answers this question for a given change ID. +// The report consists of one more lines on standard error or standard output +// that contain a “match marker”. [Marker] returns the match marker for a given ID. +// When bisect reports a change as causing the failure, it identifies the change +// by printing the report lines with the match marker removed. +// +// # Example Usage +// +// A program starts by defining how it receives the pattern. In this example, we will assume a flag. 
+// The next step is to compile the pattern: +// +// m, err := bisect.New(patternFlag) +// if err != nil { +// log.Fatal(err) +// } +// +// Then, each time a potential change is considered, the program computes +// a change ID by hashing identifying information (source file and line, in this case) +// and then calls m.ShouldPrint and m.ShouldEnable to decide whether to +// print and enable the change, respectively. The two can return different values +// depending on whether bisect is trying to find a minimal set of changes to +// disable or to enable to provoke the failure. +// +// It is usually helpful to write a helper function that accepts the identifying information +// and then takes care of hashing, printing, and reporting whether the identified change +// should be enabled. For example, a helper for changes identified by a file and line number +// would be: +// +// func ShouldEnable(file string, line int) bool { +// h := bisect.Hash(file, line) +// if m.ShouldPrint(h) { +// fmt.Fprintf(os.Stderr, "%v %s:%d\n", bisect.Marker(h), file, line) +// } +// return m.ShouldEnable(h) +// } +// +// Finally, note that New returns a nil Matcher when there is no pattern, +// meaning that the target is not running under bisect at all, +// so all changes should be enabled and none should be printed. +// In that common case, the computation of the hash can be avoided entirely +// by checking for m == nil first: +// +// func ShouldEnable(file string, line int) bool { +// if m == nil { +// return true +// } +// h := bisect.Hash(file, line) +// if m.ShouldPrint(h) { +// fmt.Fprintf(os.Stderr, "%v %s:%d\n", bisect.Marker(h), file, line) +// } +// return m.ShouldEnable(h) +// } +// +// When the identifying information is expensive to format, this code can call +// [Matcher.MarkerOnly] to find out whether short report lines containing only the +// marker are permitted for a given run. 
(Bisect permits such lines when it is +// still exploring the space of possible changes and will not be showing the +// output to the user.) If so, the client can choose to print only the marker: +// +// func ShouldEnable(file string, line int) bool { +// if m == nil { +// return true +// } +// h := bisect.Hash(file, line) +// if m.ShouldPrint(h) { +// if m.MarkerOnly() { +// bisect.PrintMarker(os.Stderr, h) +// } else { +// fmt.Fprintf(os.Stderr, "%v %s:%d\n", bisect.Marker(h), file, line) +// } +// } +// return m.ShouldEnable(h) +// } +// +// This specific helper – deciding whether to enable a change identified by +// file and line number and printing about the change when necessary – is +// provided by the [Matcher.FileLine] method. +// +// Another common usage is deciding whether to make a change in a function +// based on the caller's stack, to identify the specific calling contexts that the +// change breaks. The [Matcher.Stack] method takes care of obtaining the stack, +// printing it when necessary, and reporting whether to enable the change +// based on that stack. +// +// # Pattern Syntax +// +// Patterns are generated by the bisect tool and interpreted by [New]. +// Users should not have to understand the patterns except when +// debugging a target's bisect support or debugging the bisect tool itself. +// +// The pattern syntax selecting a change is a sequence of bit strings +// separated by + and - operators. Each bit string denotes the set of +// changes with IDs ending in those bits, + is set addition, - is set subtraction, +// and the expression is evaluated in the usual left-to-right order. +// The special binary number “y” denotes the set of all changes, +// standing in for the empty bit string. +// In the expression, all the + operators must appear before all the - operators. +// A leading + adds to an empty set. A leading - subtracts from the set of all +// possible suffixes. 
+// +// For example: +// +// - “01+10” and “+01+10” both denote the set of changes +// with IDs ending with the bits 01 or 10. +// +// - “01+10-1001” denotes the set of changes with IDs +// ending with the bits 01 or 10, but excluding those ending in 1001. +// +// - “-01-1000” and “y-01-1000” both denote the set of all changes +// with IDs not ending in 01 nor 1000. +// +// - “0+1-01+001” is not a valid pattern, because all the + operators do not +// appear before all the - operators. +// +// In the syntaxes described so far, the pattern specifies the changes to +// enable and report. If a pattern is prefixed by a “!”, the meaning +// changes: the pattern specifies the changes to DISABLE and report. This +// mode of operation is needed when a program passes with all changes +// enabled but fails with no changes enabled. In this case, bisect +// searches for minimal sets of changes to disable. +// Put another way, the leading “!” inverts the result from [Matcher.ShouldEnable] +// but does not invert the result from [Matcher.ShouldPrint]. +// +// As a convenience for manual debugging, “n” is an alias for “!y”, +// meaning to disable and report all changes. +// +// Finally, a leading “v” in the pattern indicates that the reports will be shown +// to the user of bisect to describe the changes involved in a failure. +// At the API level, the leading “v” causes [Matcher.Visible] to return true. +// See the next section for details. +// +// # Match Reports +// +// The target program must enable only those changes matched +// by the pattern, and it must print a match report for each such change. +// A match report consists of one or more lines of text that will be +// printed by the bisect tool to describe a change implicated in causing +// a failure. Each line in the report for a given change must contain a +// match marker with that change ID, as returned by [Marker]. +// The markers are elided when displaying the lines to the user. 
+// +// A match marker has the form “[bisect-match 0x1234]” where +// 0x1234 is the change ID in hexadecimal. +// An alternate form is “[bisect-match 010101]”, giving the change ID in binary. +// +// When [Matcher.Visible] returns false, the match reports are only +// being processed by bisect to learn the set of enabled changes, +// not shown to the user, meaning that each report can be a match +// marker on a line by itself, eliding the usual textual description. +// When the textual description is expensive to compute, +// checking [Matcher.Visible] can help avoid that expense +// in most runs. +package bisect + +import ( + "runtime" + "sync" + "sync/atomic" + "unsafe" +) + +// New creates and returns a new Matcher implementing the given pattern. +// The pattern syntax is defined in the package doc comment. +// +// In addition to the pattern syntax, New("") returns nil, nil. +// The nil *Matcher is valid for use: it returns true from ShouldEnable +// and false from ShouldPrint for all changes. Callers can avoid calling +// [Hash], [Matcher.ShouldEnable], and [Matcher.ShouldPrint] entirely +// when they recognize the nil Matcher. +func New(pattern string) (*Matcher, error) { + if pattern == "" { + return nil, nil + } + + m := new(Matcher) + + p := pattern + // Special case for leading 'q' so that 'qn' quietly disables, e.g. fmahash=qn to disable fma + // Any instance of 'v' disables 'q'. + if len(p) > 0 && p[0] == 'q' { + m.quiet = true + p = p[1:] + if p == "" { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + } + // Allow multiple v, so that “bisect cmd vPATTERN” can force verbose all the time. + for len(p) > 0 && p[0] == 'v' { + m.verbose = true + m.quiet = false + p = p[1:] + if p == "" { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + } + + // Allow multiple !, each negating the last, so that “bisect cmd !PATTERN” works + // even when bisect chooses to add its own !. 
+ m.enable = true + for len(p) > 0 && p[0] == '!' { + m.enable = !m.enable + p = p[1:] + if p == "" { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + } + + if p == "n" { + // n is an alias for !y. + m.enable = !m.enable + p = "y" + } + + // Parse actual pattern syntax. + result := true + bits := uint64(0) + start := 0 + wid := 1 // 1-bit (binary); sometimes 4-bit (hex) + for i := 0; i <= len(p); i++ { + // Imagine a trailing - at the end of the pattern to flush final suffix + c := byte('-') + if i < len(p) { + c = p[i] + } + if i == start && wid == 1 && c == 'x' { // leading x for hex + start = i + 1 + wid = 4 + continue + } + switch c { + default: + return nil, &parseError{"invalid pattern syntax: " + pattern} + case '2', '3', '4', '5', '6', '7', '8', '9': + if wid != 4 { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + fallthrough + case '0', '1': + bits <<= wid + bits |= uint64(c - '0') + case 'a', 'b', 'c', 'd', 'e', 'f', 'A', 'B', 'C', 'D', 'E', 'F': + if wid != 4 { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + bits <<= 4 + bits |= uint64(c&^0x20 - 'A' + 10) + case 'y': + if i+1 < len(p) && (p[i+1] == '0' || p[i+1] == '1') { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + bits = 0 + case '+', '-': + if c == '+' && result == false { + // Have already seen a -. Should be - from here on. + return nil, &parseError{"invalid pattern syntax (+ after -): " + pattern} + } + if i > 0 { + n := (i - start) * wid + if n > 64 { + return nil, &parseError{"pattern bits too long: " + pattern} + } + if n <= 0 { + return nil, &parseError{"invalid pattern syntax: " + pattern} + } + if p[start] == 'y' { + n = 0 + } + mask := uint64(1)<= 0; i-- { + c := &m.list[i] + if id&c.mask == c.bits { + return c.result + } + } + return false +} + +// FileLine reports whether the change identified by file and line should be enabled. 
+// If the change should be printed, FileLine prints a one-line report to w. +func (m *Matcher) FileLine(w Writer, file string, line int) bool { + if m == nil { + return true + } + return m.fileLine(w, file, line) +} + +// fileLine does the real work for FileLine. +// This lets FileLine's body handle m == nil and potentially be inlined. +func (m *Matcher) fileLine(w Writer, file string, line int) bool { + h := Hash(file, line) + if m.ShouldPrint(h) { + if m.MarkerOnly() { + PrintMarker(w, h) + } else { + printFileLine(w, h, file, line) + } + } + return m.ShouldEnable(h) +} + +// printFileLine prints a non-marker-only report for file:line to w. +func printFileLine(w Writer, h uint64, file string, line int) error { + const markerLen = 40 // overestimate + b := make([]byte, 0, markerLen+len(file)+24) + b = AppendMarker(b, h) + b = appendFileLine(b, file, line) + b = append(b, '\n') + _, err := w.Write(b) + return err +} + +// appendFileLine appends file:line to dst, returning the extended slice. +func appendFileLine(dst []byte, file string, line int) []byte { + dst = append(dst, file...) + dst = append(dst, ':') + u := uint(line) + if line < 0 { + dst = append(dst, '-') + u = -u + } + var buf [24]byte + i := len(buf) + for i == len(buf) || u > 0 { + i-- + buf[i] = '0' + byte(u%10) + u /= 10 + } + dst = append(dst, buf[i:]...) + return dst +} + +// MatchStack assigns the current call stack a change ID. +// If the stack should be printed, MatchStack prints it. +// Then MatchStack reports whether a change at the current call stack should be enabled. +func (m *Matcher) Stack(w Writer) bool { + if m == nil { + return true + } + return m.stack(w) +} + +// stack does the real work for Stack. +// This lets stack's body handle m == nil and potentially be inlined. +func (m *Matcher) stack(w Writer) bool { + const maxStack = 16 + var stk [maxStack]uintptr + n := runtime.Callers(2, stk[:]) + // caller #2 is not for printing; need it to normalize PCs if ASLR. 
+ if n <= 1 { + return false + } + + base := stk[0] + // normalize PCs + for i := range stk[:n] { + stk[i] -= base + } + + h := Hash(stk[:n]) + if m.ShouldPrint(h) { + var d *dedup + for { + d = m.dedup.Load() + if d != nil { + break + } + d = new(dedup) + if m.dedup.CompareAndSwap(nil, d) { + break + } + } + + if m.MarkerOnly() { + if !d.seenLossy(h) { + PrintMarker(w, h) + } + } else { + if !d.seen(h) { + // Restore PCs in stack for printing + for i := range stk[:n] { + stk[i] += base + } + printStack(w, h, stk[1:n]) + } + } + } + return m.ShouldEnable(h) +} + +// Writer is the same interface as io.Writer. +// It is duplicated here to avoid importing io. +type Writer interface { + Write([]byte) (int, error) +} + +// PrintMarker prints to w a one-line report containing only the marker for h. +// It is appropriate to use when [Matcher.ShouldPrint] and [Matcher.MarkerOnly] both return true. +func PrintMarker(w Writer, h uint64) error { + var buf [50]byte + b := AppendMarker(buf[:0], h) + b = append(b, '\n') + _, err := w.Write(b) + return err +} + +// printStack prints to w a multi-line report containing a formatting of the call stack stk, +// with each line preceded by the marker for h. +func printStack(w Writer, h uint64, stk []uintptr) error { + buf := make([]byte, 0, 2048) + + var prefixBuf [100]byte + prefix := AppendMarker(prefixBuf[:0], h) + + frames := runtime.CallersFrames(stk) + for { + f, more := frames.Next() + buf = append(buf, prefix...) + buf = append(buf, f.Func.Name()...) + buf = append(buf, "()\n"...) + buf = append(buf, prefix...) + buf = append(buf, '\t') + buf = appendFileLine(buf, f.File, f.Line) + buf = append(buf, '\n') + if !more { + break + } + } + buf = append(buf, prefix...) + buf = append(buf, '\n') + _, err := w.Write(buf) + return err +} + +// Marker returns the match marker text to use on any line reporting details +// about a match of the given ID. +// It always returns the hexadecimal format. 
+func Marker(id uint64) string { + return string(AppendMarker(nil, id)) +} + +// AppendMarker is like [Marker] but appends the marker to dst. +func AppendMarker(dst []byte, id uint64) []byte { + const prefix = "[bisect-match 0x" + var buf [len(prefix) + 16 + 1]byte + copy(buf[:], prefix) + for i := 0; i < 16; i++ { + buf[len(prefix)+i] = "0123456789abcdef"[id>>60] + id <<= 4 + } + buf[len(prefix)+16] = ']' + return append(dst, buf[:]...) +} + +// CutMarker finds the first match marker in line and removes it, +// returning the shortened line (with the marker removed), +// the ID from the match marker, +// and whether a marker was found at all. +// If there is no marker, CutMarker returns line, 0, false. +func CutMarker(line string) (short string, id uint64, ok bool) { + // Find first instance of prefix. + prefix := "[bisect-match " + i := 0 + for ; ; i++ { + if i >= len(line)-len(prefix) { + return line, 0, false + } + if line[i] == '[' && line[i:i+len(prefix)] == prefix { + break + } + } + + // Scan to ]. + j := i + len(prefix) + for j < len(line) && line[j] != ']' { + j++ + } + if j >= len(line) { + return line, 0, false + } + + // Parse id. + idstr := line[i+len(prefix) : j] + if len(idstr) >= 3 && idstr[:2] == "0x" { + // parse hex + if len(idstr) > 2+16 { // max 0x + 16 digits + return line, 0, false + } + for i := 2; i < len(idstr); i++ { + id <<= 4 + switch c := idstr[i]; { + case '0' <= c && c <= '9': + id |= uint64(c - '0') + case 'a' <= c && c <= 'f': + id |= uint64(c - 'a' + 10) + case 'A' <= c && c <= 'F': + id |= uint64(c - 'A' + 10) + } + } + } else { + if idstr == "" || len(idstr) > 64 { // min 1 digit, max 64 digits + return line, 0, false + } + // parse binary + for i := 0; i < len(idstr); i++ { + id <<= 1 + switch c := idstr[i]; c { + default: + return line, 0, false + case '0', '1': + id |= uint64(c - '0') + } + } + } + + // Construct shortened line. 
+ // Remove at most one space from around the marker, + // so that "foo [marker] bar" shortens to "foo bar". + j++ // skip ] + if i > 0 && line[i-1] == ' ' { + i-- + } else if j < len(line) && line[j] == ' ' { + j++ + } + short = line[:i] + line[j:] + return short, id, true +} + +// Hash computes a hash of the data arguments, +// each of which must be of type string, byte, int, uint, int32, uint32, int64, uint64, uintptr, or a slice of one of those types. +func Hash(data ...any) uint64 { + h := offset64 + for _, v := range data { + switch v := v.(type) { + default: + // Note: Not printing the type, because reflect.ValueOf(v) + // would make the interfaces prepared by the caller escape + // and therefore allocate. This way, Hash(file, line) runs + // without any allocation. It should be clear from the + // source code calling Hash what the bad argument was. + panic("bisect.Hash: unexpected argument type") + case string: + h = fnvString(h, v) + case byte: + h = fnv(h, v) + case int: + h = fnvUint64(h, uint64(v)) + case uint: + h = fnvUint64(h, uint64(v)) + case int32: + h = fnvUint32(h, uint32(v)) + case uint32: + h = fnvUint32(h, v) + case int64: + h = fnvUint64(h, uint64(v)) + case uint64: + h = fnvUint64(h, v) + case uintptr: + h = fnvUint64(h, uint64(v)) + case []string: + for _, x := range v { + h = fnvString(h, x) + } + case []byte: + for _, x := range v { + h = fnv(h, x) + } + case []int: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + case []uint: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + case []int32: + for _, x := range v { + h = fnvUint32(h, uint32(x)) + } + case []uint32: + for _, x := range v { + h = fnvUint32(h, x) + } + case []int64: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + case []uint64: + for _, x := range v { + h = fnvUint64(h, x) + } + case []uintptr: + for _, x := range v { + h = fnvUint64(h, uint64(x)) + } + } + } + return h +} + +// Trivial error implementation, here to avoid importing errors. 
+ +// parseError is a trivial error implementation, +// defined here to avoid importing errors. +type parseError struct{ text string } + +func (e *parseError) Error() string { return e.text } + +// FNV-1a implementation. See Go's hash/fnv/fnv.go. +// Copied here for simplicity (can handle integers more directly) +// and to avoid importing hash/fnv. + +const ( + offset64 uint64 = 14695981039346656037 + prime64 uint64 = 1099511628211 +) + +func fnv(h uint64, x byte) uint64 { + h ^= uint64(x) + h *= prime64 + return h +} + +func fnvString(h uint64, x string) uint64 { + for i := 0; i < len(x); i++ { + h ^= uint64(x[i]) + h *= prime64 + } + return h +} + +func fnvUint64(h uint64, x uint64) uint64 { + for i := 0; i < 8; i++ { + h ^= x & 0xFF + x >>= 8 + h *= prime64 + } + return h +} + +func fnvUint32(h uint64, x uint32) uint64 { + for i := 0; i < 4; i++ { + h ^= uint64(x & 0xFF) + x >>= 8 + h *= prime64 + } + return h +} + +// A dedup is a deduplicator for call stacks, so that we only print +// a report for new call stacks, not for call stacks we've already +// reported. +// +// It has two modes: an approximate but lock-free mode that +// may still emit some duplicates, and a precise mode that uses +// a lock and never emits duplicates. +type dedup struct { + // 128-entry 4-way, lossy cache for seenLossy + recent [128][4]uint64 + + // complete history for seen + mu sync.Mutex + m map[uint64]bool +} + +// seen records that h has now been seen and reports whether it was seen before. +// When seen returns false, the caller is expected to print a report for h. +func (d *dedup) seen(h uint64) bool { + d.mu.Lock() + if d.m == nil { + d.m = make(map[uint64]bool) + } + seen := d.m[h] + d.m[h] = true + d.mu.Unlock() + return seen +} + +// seenLossy is a variant of seen that avoids a lock by using a cache of recently seen hashes. +// Each cache entry is N-way set-associative: h can appear in any of the slots. 
+// If h does not appear in any of them, then it is inserted into a random slot, +// overwriting whatever was there before. +func (d *dedup) seenLossy(h uint64) bool { + cache := &d.recent[uint(h)%uint(len(d.recent))] + for i := 0; i < len(cache); i++ { + if atomic.LoadUint64(&cache[i]) == h { + return true + } + } + + // Compute index in set to evict as hash of current set. + ch := offset64 + for _, x := range cache { + ch = fnvUint64(ch, x) + } + atomic.StoreUint64(&cache[uint(ch)%uint(len(cache))], h) + return false +} diff --git a/internal/godebug/godebug.go b/internal/godebug/godebug.go index ac434e5f..f5a0f53a 100644 --- a/internal/godebug/godebug.go +++ b/internal/godebug/godebug.go @@ -2,33 +2,244 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package godebug parses the GODEBUG environment variable. +// Package godebug makes the settings in the $GODEBUG environment variable +// available to other packages. These settings are often used for compatibility +// tweaks, when we need to change a default behavior but want to let users +// opt back in to the original. For example GODEBUG=http2server=0 disables +// HTTP/2 support in the net/http server. +// +// In typical usage, code should declare a Setting as a global +// and then call Value each time the current setting value is needed: +// +// var http2server = godebug.New("http2server") +// +// func ServeConn(c net.Conn) { +// if http2server.Value() == "0" { +// disallow HTTP/2 +// ... +// } +// ... +// } +// +// Each time a non-default setting causes a change in program behavior, +// code should call [Setting.IncNonDefault] to increment a counter that can +// be reported by [runtime/metrics.Read]. +// Note that counters used with IncNonDefault must be added to +// various tables in other packages. See the [Setting.IncNonDefault] +// documentation for details. package godebug -import "os" +// Note: Be careful about new imports here. 
Any package +// that internal/godebug imports cannot itself import internal/godebug, +// meaning it cannot introduce a GODEBUG setting of its own. +// We keep imports to the absolute bare minimum. +import ( + "sync" + "sync/atomic" + _ "unsafe" // go:linkname -// Get returns the value for the provided GODEBUG key. -func Get(key string) string { - return get(os.Getenv("GODEBUG"), key) + "github.com/imroc/req/v3/internal/bisect" + "github.com/imroc/req/v3/internal/godebugs" +) + +// A Setting is a single setting in the $GODEBUG environment variable. +type Setting struct { + name string + once sync.Once + *setting +} + +type setting struct { + value atomic.Pointer[value] + nonDefaultOnce sync.Once + nonDefault atomic.Uint64 + info *godebugs.Info +} + +type value struct { + text string + bisect *bisect.Matcher +} + +// New returns a new Setting for the $GODEBUG setting with the given name. +// +// GODEBUGs meant for use by end users must be listed in ../godebugs/table.go, +// which is used for generating and checking various documentation. +// If the name is not listed in that table, New will succeed but calling Value +// on the returned Setting will panic. +// To disable that panic for access to an undocumented setting, +// prefix the name with a #, as in godebug.New("#gofsystrace"). +// The # is a signal to New but not part of the key used in $GODEBUG. +func New(name string) *Setting { + return &Setting{name: name} +} + +// Name returns the name of the setting. +func (s *Setting) Name() string { + if s.name != "" && s.name[0] == '#' { + return s.name[1:] + } + return s.name } -// get returns the value part of key=value in s (a GODEBUG value). -func get(s, key string) string { - for i := 0; i < len(s)-len(key)-1; i++ { - if i > 0 && s[i-1] != ',' { - continue +// Undocumented reports whether this is an undocumented setting. 
+func (s *Setting) Undocumented() bool { + return s.name != "" && s.name[0] == '#' +} + +// String returns a printable form for the setting: name=value. +func (s *Setting) String() string { + return s.Name() + "=" + s.Value() +} + +// IncNonDefault increments the non-default behavior counter +// associated with the given setting. +// This counter is exposed in the runtime/metrics value +// /godebug/non-default-behavior/:events. +// +// Note that Value must be called at least once before IncNonDefault. +func (s *Setting) IncNonDefault() { + s.nonDefaultOnce.Do(s.register) + s.nonDefault.Add(1) +} + +func (s *Setting) register() { + if s.info == nil || s.info.Opaque { + panic("godebug: unexpected IncNonDefault of " + s.name) + } +} + +// cache is a cache of all the GODEBUG settings, +// a locked map[string]*atomic.Pointer[string]. +// +// All Settings with the same name share a single +// *atomic.Pointer[string], so that when GODEBUG +// changes only that single atomic string pointer +// needs to be updated. +// +// A name appears in the values map either if it is the +// name of a Setting for which Value has been called +// at least once, or if the name has ever appeared in +// a name=value pair in the $GODEBUG environment variable. +// Once entered into the map, the name is never removed. +var cache sync.Map // name string -> value *atomic.Pointer[string] + +var empty value + +// Value returns the current value for the GODEBUG setting s. +// +// Value maintains an internal cache that is synchronized +// with changes to the $GODEBUG environment variable, +// making Value efficient to call as frequently as needed. +// Clients should therefore typically not attempt their own +// caching of Value's result. 
+func (s *Setting) Value() string { + s.once.Do(func() { + s.setting = lookup(s.Name()) + if s.info == nil && !s.Undocumented() { + panic("godebug: Value of name not listed in godebugs.All: " + s.name) } - afterKey := s[i+len(key):] - if afterKey[0] != '=' || s[i:i+len(key)] != key { - continue + }) + v := *s.value.Load() + if v.bisect != nil && !v.bisect.Stack(&stderr) { + return "" + } + return v.text +} + +// lookup returns the unique *setting value for the given name. +func lookup(name string) *setting { + if v, ok := cache.Load(name); ok { + return v.(*setting) + } + s := new(setting) + s.info = godebugs.Lookup(name) + s.value.Store(&empty) + if v, loaded := cache.LoadOrStore(name, s); loaded { + // Lost race: someone else created it. Use theirs. + return v.(*setting) + } + + return s +} + +func newIncNonDefault(name string) func() { + s := New(name) + s.Value() + return s.IncNonDefault +} + +var updateMu sync.Mutex + +// update records an updated GODEBUG setting. +// def is the default GODEBUG setting for the running binary, +// and env is the current value of the $GODEBUG environment variable. +func update(def, env string) { + updateMu.Lock() + defer updateMu.Unlock() + + // Update all the cached values, creating new ones as needed. + // We parse the environment variable first, so that any settings it has + // are already locked in place (did[name] = true) before we consider + // the defaults. + did := make(map[string]bool) + parse(did, env) + parse(did, def) + + // Clear any cached values that are no longer present. + cache.Range(func(name, s any) bool { + if !did[name.(string)] { + s.(*setting).value.Store(&empty) } - val := afterKey[1:] - for i, b := range val { - if b == ',' { - return val[:i] + return true + }) +} + +// parse parses the GODEBUG setting string s, +// which has the form k=v,k2=v2,k3=v3. +// Later settings override earlier ones. +// Parse only updates settings k=v for which did[k] = false. 
+// It also sets did[k] = true for settings that it updates. +// Each value v can also have the form v#pattern, +// in which case the GODEBUG is only enabled for call stacks +// matching pattern, for use with golang.org/x/tools/cmd/bisect. +func parse(did map[string]bool, s string) { + // Scan the string backward so that later settings are used + // and earlier settings are ignored. + // Note that a forward scan would cause cached values + // to temporarily use the ignored value before being + // updated to the "correct" one. + end := len(s) + eq := -1 + for i := end - 1; i >= -1; i-- { + if i == -1 || s[i] == ',' { + if eq >= 0 { + name, arg := s[i+1:eq], s[eq+1:end] + if !did[name] { + did[name] = true + v := &value{text: arg} + for j := 0; j < len(arg); j++ { + if arg[j] == '#' { + v.text = arg[:j] + v.bisect, _ = bisect.New(arg[j+1:]) + break + } + } + lookup(name).value.Store(v) + } } + eq = -1 + end = i + } else if s[i] == '=' { + eq = i } - return val } - return "" +} + +type runtimeStderr struct{} + +var stderr runtimeStderr + +func (*runtimeStderr) Write(b []byte) (int, error) { + return len(b), nil } diff --git a/internal/godebug/godebug_test.go b/internal/godebug/godebug_test.go deleted file mode 100644 index 41b9117b..00000000 --- a/internal/godebug/godebug_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
- -package godebug - -import "testing" - -func TestGet(t *testing.T) { - tests := []struct { - godebug string - key string - want string - }{ - {"", "", ""}, - {"", "foo", ""}, - {"foo=bar", "foo", "bar"}, - {"foo=bar,after=x", "foo", "bar"}, - {"before=x,foo=bar,after=x", "foo", "bar"}, - {"before=x,foo=bar", "foo", "bar"}, - {",,,foo=bar,,,", "foo", "bar"}, - {"foodecoy=wrong,foo=bar", "foo", "bar"}, - {"foo=", "foo", ""}, - {"foo", "foo", ""}, - {",foo", "foo", ""}, - {"foo=bar,baz", "loooooooong", ""}, - } - for _, tt := range tests { - got := get(tt.godebug, tt.key) - if got != tt.want { - t.Errorf("get(%q, %q) = %q; want %q", tt.godebug, tt.key, got, tt.want) - } - } -} diff --git a/internal/godebugs/table.go b/internal/godebugs/table.go new file mode 100644 index 00000000..d5ac707a --- /dev/null +++ b/internal/godebugs/table.go @@ -0,0 +1,78 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package godebugs provides a table of known GODEBUG settings, +// for use by a variety of other packages, including internal/godebug, +// runtime, runtime/metrics, and cmd/go/internal/load. +package godebugs + +// An Info describes a single known GODEBUG setting. +type Info struct { + Name string // name of the setting ("panicnil") + Package string // package that uses the setting ("runtime") + Changed int // minor version when default changed, if any; 21 means Go 1.21 + Old string // value that restores behavior prior to Changed + Opaque bool // setting does not export information to runtime/metrics using [internal/godebug.Setting.IncNonDefault] +} + +// All is the table of known settings, sorted by Name. +// +// Note: After adding entries to this table, run 'go generate runtime/metrics' +// to update the runtime/metrics doc comment. +// (Otherwise the runtime/metrics test will fail.) 
+// +// Note: After adding entries to this table, update the list in doc/godebug.md as well. +// (Otherwise the test in this package will fail.) +var All = []Info{ + {Name: "execerrdot", Package: "os/exec"}, + {Name: "gocachehash", Package: "cmd/go"}, + {Name: "gocachetest", Package: "cmd/go"}, + {Name: "gocacheverify", Package: "cmd/go"}, + {Name: "gotypesalias", Package: "go/types"}, + {Name: "http2client", Package: "net/http"}, + {Name: "http2debug", Package: "net/http", Opaque: true}, + {Name: "http2server", Package: "net/http"}, + {Name: "httplaxcontentlength", Package: "net/http", Changed: 22, Old: "1"}, + {Name: "httpmuxgo121", Package: "net/http", Changed: 22, Old: "1"}, + {Name: "installgoroot", Package: "go/build"}, + {Name: "jstmpllitinterp", Package: "html/template"}, + //{Name: "multipartfiles", Package: "mime/multipart"}, + {Name: "multipartmaxheaders", Package: "mime/multipart"}, + {Name: "multipartmaxparts", Package: "mime/multipart"}, + {Name: "multipathtcp", Package: "net"}, + {Name: "netdns", Package: "net", Opaque: true}, + {Name: "panicnil", Package: "runtime", Changed: 21, Old: "1"}, + {Name: "randautoseed", Package: "math/rand"}, + {Name: "tarinsecurepath", Package: "archive/tar"}, + {Name: "tls10server", Package: "crypto/tls", Changed: 22, Old: "1"}, + {Name: "tlsmaxrsasize", Package: "crypto/tls"}, + {Name: "tlsrsakex", Package: "crypto/tls", Changed: 22, Old: "1"}, + {Name: "tlsunsafeekm", Package: "crypto/tls", Changed: 22, Old: "1"}, + {Name: "winreadlinkvolume", Package: "os", Changed: 22, Old: "0"}, + {Name: "winsymlink", Package: "os", Changed: 22, Old: "0"}, + {Name: "x509sha1", Package: "crypto/x509"}, + {Name: "x509usefallbackroots", Package: "crypto/x509"}, + {Name: "x509usepolicies", Package: "crypto/x509"}, + {Name: "zipinsecurepath", Package: "archive/zip"}, +} + +// Lookup returns the Info with the given name. +func Lookup(name string) *Info { + // binary search, avoiding import of sort. 
+ lo := 0 + hi := len(All) + for lo < hi { + m := int(uint(lo+hi) >> 1) + mid := All[m].Name + if name == mid { + return &All[m] + } + if name < mid { + hi = m + } else { + lo = m + 1 + } + } + return nil +} diff --git a/request_test.go b/request_test.go index ac30f85b..a326e1f2 100644 --- a/request_test.go +++ b/request_test.go @@ -5,8 +5,6 @@ import ( "encoding/json" "encoding/xml" "fmt" - "github.com/imroc/req/v3/internal/header" - "github.com/imroc/req/v3/internal/tests" "io" "net/http" "net/url" @@ -15,6 +13,9 @@ import ( "strings" "testing" "time" + + "github.com/imroc/req/v3/internal/header" + "github.com/imroc/req/v3/internal/tests" ) func TestMustSendMethods(t *testing.T) { @@ -635,7 +636,6 @@ func testQueryParam(t *testing.T, c *Client) { Get("/query-parameter") assertSuccess(t, resp, err) tests.AssertEqual(t, "key1=value1&key1=value11&key2=value2&key2=value22&key3=value3&key4=value4&key4=value44&key5=value5&key6=value6&key6=value66", resp.String()) - } func TestPathParam(t *testing.T) { @@ -963,7 +963,7 @@ func (r *SlowReader) Read(p []byte) (int, error) { func TestUploadCallback(t *testing.T) { r := tc().R() - file := "transport_test.go" + file := "transport.go" fileInfo, err := os.Stat(file) if err != nil { t.Fatal(err) diff --git a/roundtrip_js.go b/roundtrip_js.go index af51f13e..9c6b6c4a 100644 --- a/roundtrip_js.go +++ b/roundtrip_js.go @@ -12,7 +12,10 @@ import ( "io" "net/http" "strconv" + "strings" "syscall/js" + + "github.com/imroc/req/v3/internal/ascii" ) var uint8Array = js.Global().Get("Uint8Array") @@ -45,45 +48,17 @@ const jsFetchRedirect = "js.fetch:redirect" // the browser globals. var jsFetchMissing = js.Global().Get("fetch").IsUndefined() -// jsFetchDisabled will be true if the "process" global is present. -// We use this as an indicator that we're running in Node.js. We -// want to disable the Fetch API in Node.js because it breaks -// our wasm tests. See https://go.dev/issue/57613 for more information. 
-var jsFetchDisabled = !js.Global().Get("process").IsUndefined() - -// Determine whether the JS runtime supports streaming request bodies. -// Courtesy: https://developer.chrome.com/articles/fetch-streaming-requests/#feature-detection -func supportsPostRequestStreams() bool { - requestOpt := js.Global().Get("Object").New() - requestBody := js.Global().Get("ReadableStream").New() - - requestOpt.Set("method", "POST") - requestOpt.Set("body", requestBody) - - // There is quite a dance required to define a getter if you do not have the { get property() { ... } } - // syntax available. However, it is possible: - // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Functions/get#defining_a_getter_on_existing_objects_using_defineproperty - duplexCalled := false - duplexGetterObj := js.Global().Get("Object").New() - duplexGetterFunc := js.FuncOf(func(this js.Value, args []js.Value) any { - duplexCalled = true - return "half" - }) - defer duplexGetterFunc.Release() - duplexGetterObj.Set("get", duplexGetterFunc) - js.Global().Get("Object").Call("defineProperty", requestOpt, "duplex", duplexGetterObj) - - // Slight difference here between the aforementioned example: Non-browser-based runtimes - // do not have a non-empty API Base URL (https://html.spec.whatwg.org/multipage/webappapis.html#api-base-url) - // so we have to supply a valid URL here. - requestObject := js.Global().Get("Request").New("https://www.example.org", requestOpt) - - hasContentTypeHeader := requestObject.Get("headers").Call("has", "Content-Type").Bool() - - return duplexCalled && !hasContentTypeHeader -} +// jsFetchDisabled controls whether the use of Fetch API is disabled. +// It's set to true when we detect we're running in Node.js, so that +// RoundTrip ends up talking over the same fake network the HTTP servers +// currently use in various tests and examples. See go.dev/issue/57613. +// +// TODO(go.dev/issue/60810): See if it's viable to test the Fetch API +// code path. 
+var jsFetchDisabled = js.Global().Get("process").Type() == js.TypeObject && + strings.HasPrefix(js.Global().Get("process").Get("argv0").String(), "node") -// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. +// RoundTrip implements the [RoundTripper] interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { // The Transport has a documented contract that states that if the DialContext or // DialTLSContext functions are set, they will be used to set up the connections. @@ -131,60 +106,25 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } opt.Set("headers", headers) - var readableStreamStart, readableStreamPull, readableStreamCancel js.Func if req.Body != nil { - if !supportsPostRequestStreams() { - body, err := io.ReadAll(req.Body) - if err != nil { - req.Body.Close() // RoundTrip must always close the body, including on errors. - return nil, err - } - if len(body) != 0 { - buf := uint8Array.New(len(body)) - js.CopyBytesToJS(buf, body) - opt.Set("body", buf) - } - } else { - readableStreamCtorArg := js.Global().Get("Object").New() - readableStreamCtorArg.Set("type", "bytes") - readableStreamCtorArg.Set("autoAllocateChunkSize", t.writeBufferSize()) - - readableStreamPull = js.FuncOf(func(this js.Value, args []js.Value) any { - controller := args[0] - byobRequest := controller.Get("byobRequest") - if byobRequest.IsNull() { - controller.Call("close") - } - - byobRequestView := byobRequest.Get("view") - - bodyBuf := make([]byte, byobRequestView.Get("byteLength").Int()) - readBytes, readErr := io.ReadFull(req.Body, bodyBuf) - if readBytes > 0 { - buf := uint8Array.New(byobRequestView.Get("buffer")) - js.CopyBytesToJS(buf, bodyBuf) - byobRequest.Call("respond", readBytes) - } - - if readErr == io.EOF || readErr == io.ErrUnexpectedEOF { - controller.Call("close") - } else if readErr != nil { - readErrCauseObject := js.Global().Get("Object").New() - 
readErrCauseObject.Set("cause", readErr.Error()) - readErr := js.Global().Get("Error").New("io.ReadFull failed while streaming POST body", readErrCauseObject) - controller.Call("error", readErr) - } - // Note: This a return from the pull callback of the controller and *not* RoundTrip(). - return nil - }) - readableStreamCtorArg.Set("pull", readableStreamPull) - - opt.Set("body", js.Global().Get("ReadableStream").New(readableStreamCtorArg)) - // There is a requirement from the WHATWG fetch standard that the duplex property of - // the object given as the options argument to the fetch call be set to 'half' - // when the body property of the same options object is a ReadableStream: - // https://fetch.spec.whatwg.org/#dom-requestinit-duplex - opt.Set("duplex", "half") + // TODO(johanbrandhorst): Stream request body when possible. + // See https://bugs.chromium.org/p/chromium/issues/detail?id=688906 for Blink issue. + // See https://bugzilla.mozilla.org/show_bug.cgi?id=1387483 for Firefox issue. + // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. + // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API + // and browser support. + // NOTE(haruyama480): Ensure HTTP/1 fallback exists. + // See https://go.dev/issue/61889 for discussion. + body, err := io.ReadAll(req.Body) + if err != nil { + req.Body.Close() // RoundTrip must always close the body, including on errors. 
+ return nil, err + } + req.Body.Close() + if len(body) != 0 { + buf := uint8Array.New(len(body)) + js.CopyBytesToJS(buf, body) + opt.Set("body", buf) } } @@ -197,11 +137,6 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { success = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() - readableStreamCancel.Release() - readableStreamPull.Release() - readableStreamStart.Release() - - req.Body.Close() result := args[0] header := http.Header{} @@ -252,11 +187,21 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } code := result.Get("status").Int() + + uncompressed := false + if ascii.EqualFold(header.Get("Content-Encoding"), "gzip") { + // The fetch api will decode the gzip, but Content-Encoding not be deleted. + header.Del("Content-Encoding") + header.Del("Content-Length") + contentLength = -1 + uncompressed = true + } respCh <- &http.Response{ Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), StatusCode: code, Header: header, ContentLength: contentLength, + Uncompressed: uncompressed, Body: body, Request: req, } @@ -266,11 +211,6 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { failure = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() - readableStreamCancel.Release() - readableStreamPull.Release() - readableStreamStart.Release() - - req.Body.Close() err := args[0] // The error is a JS Error type diff --git a/server.go b/server.go new file mode 100644 index 00000000..8cd25f11 --- /dev/null +++ b/server.go @@ -0,0 +1,18 @@ +package req + +import "sync" + +const copyBufPoolSize = 32 * 1024 + +var copyBufPool = sync.Pool{New: func() any { return new([copyBufPoolSize]byte) }} + +func getCopyBuf() []byte { + return copyBufPool.Get().(*[copyBufPoolSize]byte)[:] +} + +func putCopyBuf(b []byte) { + if len(b) != copyBufPoolSize { + panic("trying to put back buffer of the wrong size in the copyBufPool") + } + 
copyBufPool.Put((*[copyBufPoolSize]byte)(b)) +} diff --git a/textproto_reader.go b/textproto_reader.go index 46103c07..1c09872a 100644 --- a/textproto_reader.go +++ b/textproto_reader.go @@ -7,12 +7,13 @@ package req import ( "bufio" "bytes" + "errors" "fmt" - "github.com/imroc/req/v3/internal/dump" - "github.com/imroc/req/v3/internal/util" + "math" "net/textproto" - "strings" "sync" + + "github.com/imroc/req/v3/internal/dump" ) func isASCIILetter(b byte) bool { @@ -20,6 +21,10 @@ func isASCIILetter(b byte) bool { return 'a' <= b && b <= 'z' } +// TODO: This should be a distinguishable error (ErrMessageTooLarge) +// to allow mime/multipart to detect it. +var errMessageTooLarge = errors.New("message too large") + // A textprotoReader implements convenience methods for reading requests // or responses from a text protocol network connection. type textprotoReader struct { @@ -67,11 +72,14 @@ func newTextprotoReader(r *bufio.Reader, ds dump.Dumpers) *textprotoReader { // ReadLine reads a single line from r, // eliding the final \n or \r\n from the returned string. func (r *textprotoReader) ReadLine() (string, error) { - line, err := r.readLineSlice() + line, err := r.readLineSlice(-1) return string(line), err } -func (r *textprotoReader) readLineSlice() ([]byte, error) { +// readLineSlice reads a single line from r, +// up to lim bytes long (or unlimited if lim is less than 0), +// eliding the final \r or \r\n from the returned string. +func (r *textprotoReader) readLineSlice(lim int64) ([]byte, error) { var line []byte for { @@ -79,6 +87,9 @@ func (r *textprotoReader) readLineSlice() ([]byte, error) { if err != nil { return nil, err } + if lim >= 0 && int64(len(line))+int64(len(l)) > lim { + return nil, errMessageTooLarge + } // Avoid the copy if the first call produced a full line. if line == nil && !more { return l, nil @@ -109,13 +120,14 @@ func trim(s []byte) []byte { // returning a byte slice with all lines. 
The validateFirstLine function // is run on the first read line, and if it returns an error then this // error is returned from readContinuedLineSlice. -func (r *textprotoReader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) { +// It reads up to lim bytes of data (or unlimited if lim is less than 0). +func (r *textprotoReader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) { if validateFirstLine == nil { return nil, fmt.Errorf("missing validateFirstLine func") } // Read the first line. - line, err := r.readLineSlice() + line, err := r.readLineSlice(lim) if err != nil { return nil, err } @@ -143,13 +155,21 @@ func (r *textprotoReader) readContinuedLineSlice(validateFirstLine func([]byte) // copy the slice into buf. r.buf = append(r.buf[:0], trim(line)...) + if lim < 0 { + lim = math.MaxInt64 + } + lim -= int64(len(r.buf)) + // Read continuation lines. for r.skipSpace() > 0 { - line, err := r.readLineSlice() + r.buf = append(r.buf, ' ') + if int64(len(r.buf)) >= lim { + return nil, errMessageTooLarge + } + line, err := r.readLineSlice(lim - int64(len(r.buf))) if err != nil { break } - r.buf = append(r.buf, ' ') r.buf = append(r.buf, trim(line)...) } return r.buf, nil @@ -186,7 +206,7 @@ var colon = []byte(":") // ReadMIMEHeader reads a MIME-style header from r. // The header is a sequence of possibly continued Key: Value lines // ending in a blank line. -// The returned map m maps CanonicalMIMEHeaderKey(key) to a +// The returned map m maps [CanonicalMIMEHeaderKey](key) to a // sequence of values in the same order encountered in the input. // // For example, consider this input: @@ -203,20 +223,36 @@ var colon = []byte(":") // "Long-Key": {"Even Longer Value"}, // } func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { + return r.readMIMEHeader(math.MaxInt64, math.MaxInt64) +} + +// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. 
+// It is called by the mime/multipart package. +func (r *textprotoReader) readMIMEHeader(maxMemory, maxHeaders int64) (textproto.MIMEHeader, error) { // Avoid lots of small slice allocations later by allocating one // large one ahead of time which we'll cut up into smaller // slices. If this isn't big enough later, we allocate small ones. var strs []string - hint := r.upcomingHeaderNewlines() + hint := r.upcomingHeaderKeys() if hint > 0 { + if hint > 1000 { + hint = 1000 // set a cap to avoid overallocation + } strs = make([]string, hint) } m := make(textproto.MIMEHeader, hint) + // Account for 400 bytes of overhead for the MIMEHeader, plus 200 bytes per entry. + // Benchmarking map creation as of go1.20, a one-entry MIMEHeader is 416 bytes and large + // MIMEHeaders average about 200 bytes per entry. + maxMemory -= 400 + const mapEntryOverhead = 200 + // The first line cannot start with a leading space. if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { - line, err := r.readLineSlice() + const errorLimit = 80 // arbitrary limit on how much of the line we'll quote + line, err := r.readLineSlice(errorLimit) if err != nil { return m, err } @@ -224,29 +260,43 @@ func (r *textprotoReader) ReadMIMEHeader() (textproto.MIMEHeader, error) { } for { - kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon) + kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon) if len(kv) == 0 { return m, err } // Key ends at first colon. 
- k, v, ok := util.CutBytes(kv, colon) + k, v, ok := bytes.Cut(kv, colon) + if !ok { + return m, protocolError("malformed MIME header line: " + string(kv)) + } + key, ok := canonicalMIMEHeaderKey(k) if !ok { return m, protocolError("malformed MIME header line: " + string(kv)) } - key := canonicalMIMEHeaderKey(k) + for _, c := range v { + if !validHeaderValueByte(c) { + return m, protocolError("malformed MIME header line: " + string(kv)) + } + } - // As per RFC 7230 field-name is a token, tokens consist of one or more chars. - // We could return a protocolError here, but better to be liberal in what we - // accept, so if we get an empty key, skip it. - if key == "" { - continue + maxHeaders-- + if maxHeaders < 0 { + return nil, errMessageTooLarge } // Skip initial spaces in value. - value := strings.TrimLeft(string(v), " \t") + value := string(bytes.TrimLeft(v, " \t")) vv := m[key] + if vv == nil { + maxMemory -= int64(len(key)) + maxMemory -= mapEntryOverhead + } + maxMemory -= int64(len(value)) + if maxMemory < 0 { + return m, errMessageTooLarge + } if vv == nil && len(strs) > 0 { // More than likely this will be a single-element key. // Most headers aren't multi-valued. @@ -277,9 +327,9 @@ func mustHaveFieldNameColon(line []byte) error { var nl = []byte("\n") -// upcomingHeaderNewlines returns an approximation of the number of newlines +// upcomingHeaderKeys returns an approximation of the number of keys // that will be in this header. If it gets confused, it returns 0. -func (r *textprotoReader) upcomingHeaderNewlines() (n int) { +func (r *textprotoReader) upcomingHeaderKeys() (n int) { // Try to determine the 'hint' size. 
r.R.Peek(1) // force a buffer load if empty s := r.R.Buffered() @@ -287,7 +337,20 @@ func (r *textprotoReader) upcomingHeaderNewlines() (n int) { return } peek, _ := r.R.Peek(s) - return bytes.Count(peek, nl) + for len(peek) > 0 && n < 1000 { + var line []byte + line, peek, _ = bytes.Cut(peek, nl) + if len(line) == 0 || (len(line) == 1 && line[0] == '\r') { + // Blank line separating headers from the body. + break + } + if line[0] == ' ' || line[0] == '\t' { + // Folded continuation of the previous line. + continue + } + n++ + } + return n } const toLower = 'a' - 'A' @@ -310,14 +373,33 @@ func validHeaderFieldByte(b byte) bool { // // For invalid inputs (if a contains spaces or non-token bytes), a // is unchanged and a string copy is returned. -func canonicalMIMEHeaderKey(a []byte) string { +// +// ok is true if the header key contains only valid characters and spaces. +// ReadMIMEHeader accepts header keys containing spaces, but does not +// canonicalize them. +func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) { + if len(a) == 0 { + return "", false + } + // See if a looks like a header key. If not, return it unchanged. + noCanon := false for _, c := range a { if validHeaderFieldByte(c) { continue } // Don't canonicalize. - return string(a) + if c == ' ' { + // We accept invalid headers with a space before the + // colon, but must not canonicalize them. + // See https://go.dev/issue/34540. 
+ noCanon = true + continue + } + return string(a), false + } + if noCanon { + return string(a), true } upper := true @@ -334,13 +416,40 @@ func canonicalMIMEHeaderKey(a []byte) string { a[i] = c upper = c == '-' // for next time } + commonHeaderOnce.Do(initCommonHeader) // The compiler recognizes m[string(byteSlice)] as a special // case, so a copy of a's bytes into a new string does not // happen in this map lookup: if v := commonHeader[string(a)]; v != "" { - return v + return v, true } - return string(a) + return string(a), true +} + +// validHeaderValueByte reports whether c is a valid byte in a header +// field value. RFC 7230 says: +// +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// +// RFC 5234 says: +// +// HTAB = %x09 +// SP = %x20 +// VCHAR = %x21-7E +func validHeaderValueByte(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c >= 128, then 1<>64)) == 0 } // commonHeader interns common header strings. diff --git a/transfer.go b/transfer.go index c7c623f8..92a5d305 100644 --- a/transfer.go +++ b/transfer.go @@ -9,9 +9,6 @@ import ( "bytes" "errors" "fmt" - "github.com/imroc/req/v3/internal" - "github.com/imroc/req/v3/internal/ascii" - "github.com/imroc/req/v3/internal/dump" "io" "net/http" "net/textproto" @@ -22,6 +19,11 @@ import ( "sync" "time" + "github.com/imroc/req/v3/internal" + "github.com/imroc/req/v3/internal/ascii" + "github.com/imroc/req/v3/internal/dump" + "github.com/imroc/req/v3/internal/godebug" + "golang.org/x/net/http/httpguts" ) @@ -317,7 +319,7 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error // OS-level optimizations in the event that the body is an // *os.File. 
if t.Body != nil { - var body = t.unwrapBody() + body := t.unwrapBody() if chunked(t.TransferEncoding) { if bw, ok := rw.(*bufio.Writer); ok { rw = &internal.FlushAfterChunkWriter{Writer: bw} @@ -386,7 +388,9 @@ func (t *transferWriter) writeBody(w io.Writer, dumps []*dump.Dumper) (err error // // This function is only intended for use in writeBody. func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) { - n, err = io.Copy(dst, src) + buf := getCopyBuf() + defer putCopyBuf(buf) + n, err = io.CopyBuffer(dst, src, buf) if err != nil && err != io.EOF { t.bodyReadError = err } @@ -478,7 +482,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { return err } if isResponse && t.RequestMethod == "HEAD" { - if n, err := parseContentLength(headerGet(t.Header, "Content-Length")); err != nil { + if n, err := parseContentLength(t.Header["Content-Length"]); err != nil { return err } else { t.ContentLength = n @@ -582,19 +586,6 @@ func (t *transferReader) parseTransferEncoding() error { return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} } - // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field - // in any message that contains a Transfer-Encoding header field." - // - // but also: "If a message is received with both a Transfer-Encoding and a - // Content-Length header field, the Transfer-Encoding overrides the - // Content-Length. Such a message might indicate an attempt to perform - // request smuggling (Section 9.5) or response splitting (Section 9.4) and - // ought to be handled as an error. A sender MUST remove the received - // Content-Length field prior to forwarding such a message downstream." - // - // Reportedly, these appear in the wild. - delete(t.Header, "Content-Length") - t.Chunked = true return nil } @@ -602,7 +593,8 @@ func (t *transferReader) parseTransferEncoding() error { // Determine the expected body length, using RFC 7230 Section 3.3. 
This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. -func fixLength(isResponse bool, status int, requestMethod string, header http.Header, chunked bool) (int64, error) { +func fixLength(isResponse bool, status int, requestMethod string, header http.Header, chunked bool) (n int64, err error) { + isRequest := !isResponse contentLens := header["Content-Length"] // Hardening against HTTP request smuggling @@ -625,6 +617,14 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He contentLens = header["Content-Length"] } + // Reject requests with invalid Content-Length headers. + if len(contentLens) > 0 { + n, err = parseContentLength(contentLens) + if err != nil { + return -1, err + } + } + // Logic based on response type or status if isResponse && noResponseBodyExpected(requestMethod) { return 0, nil @@ -637,25 +637,43 @@ func fixLength(isResponse bool, status int, requestMethod string, header http.He return 0, nil } + // According to RFC 9112, "If a message is received with both a + // Transfer-Encoding and a Content-Length header field, the Transfer-Encoding + // overrides the Content-Length. Such a message might indicate an attempt to + // perform request smuggling (Section 11.2) or response splitting (Section 11.1) + // and ought to be handled as an error. An intermediary that chooses to forward + // the message MUST first remove the received Content-Length field and process + // the Transfer-Encoding (as described below) prior to forwarding the message downstream." + // + // Chunked-encoding requests with either valid Content-Length + // headers or no Content-Length headers are accepted after removing + // the Content-Length field from header. 
+ // // Logic based on Transfer-Encoding if chunked { + header.Del("Content-Length") return -1, nil } // Logic based on Content-Length - var cl string - if len(contentLens) == 1 { - cl = textproto.TrimString(contentLens[0]) - } - if cl != "" { - n, err := parseContentLength(cl) - if err != nil { - return -1, err - } + if len(contentLens) > 0 { return n, nil } + header.Del("Content-Length") + if isRequest { + // RFC 7230 neither explicitly permits nor forbids an + // entity-body on a GET request so we permit one if + // declared, but we default to 0 here (not -1 below) + // if there's no mention of a body. + // Likewise, all other request methods are assumed to have + // no body if neither Transfer-Encoding chunked nor a + // Content-Length are set. + return 0, nil + } + + // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } @@ -955,19 +973,31 @@ func (bl bodyLocked) Read(p []byte) (n int, err error) { return bl.b.readLocked(p) } -// parseContentLength trims whitespace from s and returns -1 if no value -// is set, or the value if it's >= 0. -func parseContentLength(cl string) (int64, error) { - cl = textproto.TrimString(cl) - if cl == "" { +var laxContentLength = godebug.New("httplaxcontentlength") + +// parseContentLength checks that the header is valid and then trims +// whitespace. It returns -1 if no value is set otherwise the value +// if it's >= 0. +func parseContentLength(clHeaders []string) (int64, error) { + if len(clHeaders) == 0 { return -1, nil } + cl := textproto.TrimString(clHeaders[0]) + + // The Content-Length must be a valid numeric value. 
+ // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13 + if cl == "" { + if laxContentLength.Value() == "1" { + laxContentLength.IncNonDefault() + return -1, nil + } + return 0, badStringError("invalid empty Content-Length", cl) + } n, err := strconv.ParseUint(cl, 10, 63) if err != nil { return 0, badStringError("bad Content-Length", cl) } return int64(n), nil - } // finishAsyncByteRead finishes reading the 1-byte sniff @@ -991,11 +1021,13 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { return } -var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) -var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { - io.Reader - io.WriterTo -}{})) +var ( + nopCloserType = reflect.TypeOf(io.NopCloser(nil)) + nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct { + io.Reader + io.WriterTo + }{})) +) // unwrapNopCloser return the underlying reader and true if r is a NopCloser // else it return false. diff --git a/transport.go b/transport.go index 2fdcff3c..e79c005f 100644 --- a/transport.go +++ b/transport.go @@ -1583,6 +1583,13 @@ func (w *wantConn) waiting() bool { } } +// getCtxForDial returns context for dial or nil if connection was delivered or canceled. +func (w *wantConn) getCtxForDial() context.Context { + w.mu.Lock() + defer w.mu.Unlock() + return w.ctx +} + // tryDeliver attempts to deliver pc, err to w and reports whether it succeeded. 
func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { w.mu.Lock() @@ -1592,6 +1599,7 @@ func (w *wantConn) tryDeliver(pc *persistConn, err error) bool { return false } + w.ctx = nil w.pc = pc w.err = err if w.pc == nil && w.err == nil { @@ -1609,6 +1617,7 @@ func (w *wantConn) cancel(t *Transport, err error) { close(w.ready) // catch misbehavior in future delivery } pc := w.pc + w.ctx = nil w.pc = nil w.err = err w.mu.Unlock() @@ -1814,6 +1823,11 @@ func (t *Transport) queueForDial(w *wantConn) { // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. func (t *Transport) dialConnFor(w *wantConn) { defer w.afterDial() + ctx := w.getCtxForDial() + if ctx == nil { + t.decConnsPerHost(w.key) + return + } pc, err := t.dialConn(w.ctx, w.cm) delivered := w.tryDeliver(pc, err) @@ -1882,7 +1896,6 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace, forProxy bool) error { - // Initiate TLS and check remote host name against certificate. cfg := cloneTLSConfig(pc.t.TLSClientConfig) if cfg.ServerName == "" { @@ -1912,6 +1925,11 @@ func (pc *persistConn) addTLS(ctx context.Context, name string, trace *httptrace }() if err := <-errc; err != nil { plainConn.Close() + if err == (tlsHandshakeTimeoutError{}) { + // Now that we have closed the connection, + // wait for the call to HandshakeContext to return. 
+ <-errc + } if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) } @@ -2148,6 +2166,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if t.OnProxyConnectResponse != nil { err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp) if err != nil { + conn.Close() return nil, err } } @@ -2690,7 +2709,6 @@ func (pc *persistConn) readLoop() { waitForBodyRead <- false <-eofc // will be closed by deferred call at the end of the function return nil - }, fn: func(err error) error { isEOF := err == io.EOF @@ -2737,7 +2755,7 @@ func (pc *persistConn) readLoop() { } case <-rc.req.Cancel: alive = false - pc.t.CancelRequest(rc.req) + pc.t.cancelRequest(rc.cancelKey, common.ErrRequestCanceled) case <-rc.req.Context().Done(): alive = false pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) @@ -3227,16 +3245,18 @@ type writeRequest struct { continueCh <-chan struct{} } -type httpError struct { - err string - timeout bool +// httpTimeoutError represents a timeout. +// It implements net.Error and wraps context.DeadlineExceeded. +type timeoutError struct { + err string } -func (e *httpError) Error() string { return e.err } -func (e *httpError) Timeout() bool { return e.timeout } -func (e *httpError) Temporary() bool { return true } +func (e *timeoutError) Error() string { return e.err } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } +func (e *timeoutError) Is(err error) bool { return err == context.DeadlineExceeded } -var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} +var errTimeout error = &timeoutError{"net/http: timeout awaiting response headers"} var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?