From 6e11f45ebdbc7b0ee1367c80ea0a0c0ec52d6db5 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 16 Dec 2015 18:51:12 +0000 Subject: [PATCH] net/http: make Server validate Host headers Fixes #11206 (that we accept invalid bytes) Fixes #13624 (that we don't require a Host header in HTTP/1.1 per spec) Change-Id: I4138281d513998789163237e83bb893aeda43336 Reviewed-on: https://go-review.googlesource.com/17892 Reviewed-by: Russ Cox Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/http/request.go | 63 ++++++++++++++++++++++++++++++++++++-- src/net/http/serve_test.go | 58 ++++++++++++++++++++++++++++++++--- src/net/http/server.go | 27 ++++++++++++++-- 3 files changed, 139 insertions(+), 9 deletions(-) diff --git a/src/net/http/request.go b/src/net/http/request.go index 9f740422ede46..01575f33a5ae9 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -689,8 +689,9 @@ func putTextprotoReader(r *textproto.Reader) { } // ReadRequest reads and parses an incoming request from b. -func ReadRequest(b *bufio.Reader) (req *Request, err error) { +func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, true) } +func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) { tp := newTextprotoReader(b) req = new(Request) @@ -757,7 +758,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { if req.Host == "" { req.Host = req.Header.get("Host") } - delete(req.Header, "Host") + if deleteHostHeader { + delete(req.Header, "Host") + } fixPragmaCacheControl(req.Header) @@ -1060,3 +1063,59 @@ func (r *Request) isReplayable() bool { r.Method == "OPTIONS" || r.Method == "TRACE") } + +func validHostHeader(h string) bool { + // The latests spec is actually this: + // + // http://tools.ietf.org/html/rfc7230#section-5.4 + // Host = uri-host [ ":" port ] + // + // Where uri-host is: + // http://tools.ietf.org/html/rfc3986#section-3.2.2 + // + // But we're going to be much more lenient for now and just + // search for any byte that's not a valid byte in any of those + // expressions. + for i := 0; i < len(h); i++ { + if !validHostByte[h[i]] { + return false + } + } + return true +} + +// See the validHostHeader comment. +var validHostByte = [256]bool{ + '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, + '8': true, '9': true, + + 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true, + 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true, + 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true, + 'y': true, 'z': true, + + 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, + 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true, + 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true, + 'Y': true, 'Z': true, + + '!': true, // sub-delims + '$': true, // sub-delims + '%': true, // pct-encoded (and used in IPv6 zones) + '&': true, // sub-delims + '(': true, // sub-delims + ')': true, // sub-delims + '*': true, // sub-delims + '+': true, // sub-delims + ',': true, // sub-delims + '-': true, // unreserved + '.': true, // unreserved + ':': true, // IPv6address + Host expression's optional port + ';': true, // sub-delims + '=': true, // sub-delims + '[': true, + '\'': true, // sub-delims + ']': true, + '_': true, // unreserved + '~': true, // unreserved +} diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 3e84f2e11d3d7..31ba06a267497 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -2201,7 +2201,7 @@ func TestClientWriteShutdown(t *testing.T) { // buffered before chunk headers are added, not after chunk headers. func TestServerBufferedChunking(t *testing.T) { conn := new(testConn) - conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n")) + conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) conn.closec = make(chan bool, 1) ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -2934,9 +2934,9 @@ func TestCodesPreventingContentTypeAndBody(t *testing.T) { "GET / HTTP/1.0", "GET /header HTTP/1.0", "GET /more HTTP/1.0", - "GET / HTTP/1.1", - "GET /header HTTP/1.1", - "GET /more HTTP/1.1", + "GET / HTTP/1.1\nHost: foo", + "GET /header HTTP/1.1\nHost: foo", + "GET /more HTTP/1.1\nHost: foo", } { got := ht.rawResponse(req) wantStatus := fmt.Sprintf("%d %s", code, StatusText(code)) @@ -2957,7 +2957,7 @@ func TestContentTypeOkayOn204(t *testing.T) { w.Header().Set("Content-Type", "foo/bar") w.WriteHeader(204) })) - got := ht.rawResponse("GET / HTTP/1.1") + got := ht.rawResponse("GET / HTTP/1.1\nHost: foo") if !strings.Contains(got, "Content-Type: foo/bar") { t.Errorf("Response = %q; want Content-Type: foo/bar", got) } @@ -3628,6 +3628,54 @@ func testHandlerSetsBodyNil(t *testing.T, h2 bool) { } } +// Test that we validate the Host header. +func TestServerValidatesHostHeader(t *testing.T) { + tests := []struct { + proto string + host string + want int + }{ + {"HTTP/1.1", "", 400}, + {"HTTP/1.1", "Host: \r\n", 200}, + {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200}, + {"HTTP/1.1", "Host: foo.com\r\n", 200}, + {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200}, + {"HTTP/1.1", "Host: foo.com:80\r\n", 200}, + {"HTTP/1.1", "Host: ::1\r\n", 200}, + {"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it + {"HTTP/1.1", "Host: [::1]:80\r\n", 200}, + {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200}, + {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200}, + {"HTTP/1.1", "Host: \x06\r\n", 400}, + {"HTTP/1.1", "Host: \xff\r\n", 400}, + {"HTTP/1.1", "Host: {\r\n", 400}, + {"HTTP/1.1", "Host: }\r\n", 400}, + {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400}, + + // HTTP/1.0 can lack a host header, but if present + // must play by the rules too: + {"HTTP/1.0", "", 200}, + {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400}, + {"HTTP/1.0", "Host: \xff\r\n", 400}, + } + for _, tt := range tests { + conn := &testConn{closec: make(chan bool)} + io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n") + + ln := &oneConnListener{conn} + go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + <-conn.closec + res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) + if err != nil { + t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res) + continue + } + if res.StatusCode != tt.want { + t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want) + } + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() diff --git a/src/net/http/server.go b/src/net/http/server.go index cd5f9cf34f45e..a00085c2498df 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -686,7 +686,7 @@ func (c *conn) readRequest() (w *response, err error) { peek, _ := c.bufr.Peek(4) // ReadRequest will get err below c.bufr.Discard(numLeadingCRorLF(peek)) } - req, err := ReadRequest(c.bufr) + req, err := readRequest(c.bufr, false) c.mu.Unlock() if err != nil { if c.r.hitReadLimit() { @@ -697,6 +697,18 @@ func (c *conn) readRequest() (w *response, err error) { c.lastMethod = req.Method c.r.setInfiniteReadLimit() + hosts, haveHost := req.Header["Host"] + if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) { + return nil, badRequestError("missing required Host header") + } + if len(hosts) > 1 { + return nil, badRequestError("too many Host headers") + } + if len(hosts) == 1 && !validHostHeader(hosts[0]) { + return nil, badRequestError("malformed Host header") + } + delete(req.Header, "Host") + req.RemoteAddr = c.remoteAddr req.TLS = c.tlsState if body, ok := req.Body.(*body); ok { @@ -1334,6 +1346,13 @@ func (c *conn) setState(nc net.Conn, state ConnState) { } } +// badRequestError is a literal string (used by in the server in HTML, +// unescaped) to tell the user why their request was bad. It should +// be plain text without user info or other embeddded errors. +type badRequestError string + +func (e badRequestError) Error() string { return "Bad Request: " + string(e) } + // Serve a new connection. func (c *conn) serve() { c.remoteAddr = c.rwc.RemoteAddr().String() @@ -1399,7 +1418,11 @@ func (c *conn) serve() { if neterr, ok := err.(net.Error); ok && neterr.Timeout() { return // don't reply } - io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request") + var publicErr string + if v, ok := err.(badRequestError); ok { + publicErr = ": " + string(v) + } + io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr) return }