diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 5baa3ecf5..036df76a5 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -259,7 +259,7 @@ func TestNotAbsolutePath(t *testing.T) { go engine.Run() time.Sleep(200 * time.Microsecond) - s := "POST ?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) ctx := app.NewContext(0) @@ -270,7 +270,7 @@ func TestNotAbsolutePath(t *testing.T) { assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body()) - s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) ctx = app.NewContext(0) @@ -291,7 +291,7 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { go engine.Run() time.Sleep(200 * time.Microsecond) - s := "POST ?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) ctx := app.NewContext(0) @@ -302,7 +302,7 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, default400Body, ctx.Response.Body()) - s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) ctx = app.NewContext(0) diff --git a/pkg/protocol/uri.go b/pkg/protocol/uri.go index 6bba7a9f2..4fd8788e5 100644 --- a/pkg/protocol/uri.go +++ b/pkg/protocol/uri.go @@ -49,6 +49,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" + "github.com/cloudwego/hertz/pkg/common/hlog" ) // AcquireURI returns an empty URI instance from the pool. @@ -373,6 +374,34 @@ func (u *URI) Parse(host, uri []byte) { u.parse(host, uri, false) } +// Maybe rawURL is of the form scheme:path. +// (Scheme must be [a-zA-Z][a-zA-Z0-9+-.]*) +// If so, return scheme, path; else return nil, rawURL. +func getScheme(rawURL []byte) (scheme, path []byte) { + for i := 0; i < len(rawURL); i++ { + c := rawURL[i] + switch { + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + // do nothing + case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.': + if i == 0 { + return nil, rawURL + } + case c == ':': + if i == 0 { + hlog.Errorf("error happened when try to parse the rawURL(%s): missing protocol scheme", rawURL) + return nil, nil + } + return rawURL[:i], rawURL[i+1:] + default: + // we have encountered an invalid character, + // so there is no valid scheme + return nil, rawURL + } + } + return nil, rawURL +} + func (u *URI) parse(host, uri []byte, isTLS bool) { u.Reset() @@ -455,20 +484,14 @@ func stringContainsCTLByte(s []byte) bool { } func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) { - n := bytes.Index(uri, bytestr.StrSlashSlash) - if n < 0 { - return bytestr.StrHTTP, host, uri - } - scheme := uri[:n] - if bytes.IndexByte(scheme, '/') >= 0 { + scheme, path := getScheme(uri) + + if scheme == nil { return bytestr.StrHTTP, host, uri } - if len(scheme) > 0 && scheme[len(scheme)-1] == ':' { - scheme = scheme[:len(scheme)-1] - } - n += len(bytestr.StrSlashSlash) - uri = uri[n:] - n = bytes.IndexByte(uri, '/') + + uri = path[len(bytestr.StrSlashSlash):] + n := bytes.IndexByte(uri, '/') if n < 0 { // A hack for bogus urls like foobar.com?a=b without // slash after host. @@ -587,6 +610,9 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte { if len(u.scheme) > 0 { schemeOriginal = append([]byte(nil), u.scheme...) } + if n == 0 { + newURI = bytes.Join([][]byte{u.scheme, bytestr.StrColon, newURI}, nil) + } u.Parse(nil, newURI) if len(schemeOriginal) > 0 && len(u.scheme) == 0 { u.scheme = append(u.scheme[:0], schemeOriginal...) diff --git a/pkg/protocol/uri_test.go b/pkg/protocol/uri_test.go index e245faf26..3578dda34 100644 --- a/pkg/protocol/uri_test.go +++ b/pkg/protocol/uri_test.go @@ -42,6 +42,7 @@ package protocol import ( + "bytes" "path/filepath" "reflect" "runtime" @@ -468,3 +469,35 @@ func TestParseURI(t *testing.T) { uri := string(ParseURI(expectURI).FullURI()) assert.DeepEqual(t, expectURI, uri) } + +func TestSplitHostURI(t *testing.T) { + cases := []struct { + host, uri []byte + wantScheme, wantHost, wantPath []byte + }{ + { + []byte("example.com"), []byte("/foobar"), + []byte("http"), []byte("example.com"), []byte("/foobar"), + }, + { + []byte("example2.com"), []byte("http://example2.com"), + []byte("http"), []byte("example2.com"), []byte("/"), + }, + { + []byte("example2.com"), []byte("http://example3.com"), + []byte("http"), []byte("example3.com"), []byte("/"), + }, + { + []byte("example3.com"), []byte("https://foobar.com?a=b"), + []byte("https"), []byte("foobar.com"), []byte("?a=b"), + }, + } + + for _, c := range cases { + gotScheme, gotHost, gotPath := splitHostURI(c.host, c.uri) + if !bytes.Equal(gotScheme, c.wantScheme) || !bytes.Equal(gotHost, c.wantHost) || !bytes.Equal(gotPath, c.wantPath) { + t.Errorf("splitHostURI(%q, %q) == (%q, %q, %q), want (%q, %q, %q)", + c.host, c.uri, gotScheme, gotHost, gotPath, c.wantScheme, c.wantHost, c.wantPath) + } + } +} diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 168c79b48..6ab2de670 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -89,6 +89,8 @@ var ( default404Body = []byte("404 page not found") default405Body = []byte("405 method not allowed") default400Body = []byte("400 bad request") + + requiredHostBody = []byte("missing required Host header") ) type hijackConn struct { @@ -721,6 +723,13 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { } rPath := string(ctx.Request.URI().Path()) + + // align with https://datatracker.ietf.org/doc/html/rfc2616#section-5.2 + if len(ctx.Request.Host()) == 0 && ctx.Request.Header.IsHTTP11() && bytesconv.B2s(ctx.Request.Method()) != consts.MethodConnect { + serveError(c, ctx, consts.StatusBadRequest, requiredHostBody) + return + } + httpMethod := bytesconv.B2s(ctx.Request.Header.Method()) unescape := false if engine.options.UseRawPath {