Skip to content

Commit

Permalink
fix: improve url parsing (#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
welkeyever committed Nov 13, 2023
1 parent 66f5338 commit 1279f53
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 16 deletions.
8 changes: 4 additions & 4 deletions pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func TestNotAbsolutePath(t *testing.T) {
go engine.Run()
time.Sleep(200 * time.Microsecond)

s := "POST ?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr := mock.NewZeroCopyReader(s)

ctx := app.NewContext(0)
Expand All @@ -270,7 +270,7 @@ func TestNotAbsolutePath(t *testing.T) {
assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode())
assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body())

s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr = mock.NewZeroCopyReader(s)

ctx = app.NewContext(0)
Expand All @@ -291,7 +291,7 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) {
go engine.Run()
time.Sleep(200 * time.Microsecond)

s := "POST ?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr := mock.NewZeroCopyReader(s)

ctx := app.NewContext(0)
Expand All @@ -302,7 +302,7 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) {
assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode())
assert.DeepEqual(t, default400Body, ctx.Response.Body())

s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr = mock.NewZeroCopyReader(s)

ctx = app.NewContext(0)
Expand Down
50 changes: 38 additions & 12 deletions pkg/protocol/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/cloudwego/hertz/internal/bytesconv"
"github.com/cloudwego/hertz/internal/bytestr"
"github.com/cloudwego/hertz/internal/nocopy"
"github.com/cloudwego/hertz/pkg/common/hlog"
)

// AcquireURI returns an empty URI instance from the pool.
Expand Down Expand Up @@ -373,6 +374,34 @@ func (u *URI) Parse(host, uri []byte) {
u.parse(host, uri, false)
}

// Maybe rawURL is of the form scheme:path.
// (Scheme must be [a-zA-Z][a-zA-Z0-9+-.]*)
// If so, return scheme, path; else return nil, rawURL.
func getScheme(rawURL []byte) (scheme, path []byte) {
for i := 0; i < len(rawURL); i++ {
c := rawURL[i]
switch {
case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
// do nothing
case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.':
if i == 0 {
return nil, rawURL
}
case c == ':':
if i == 0 {
hlog.Errorf("error happened when try to parse the rawURL(%s): missing protocol scheme", rawURL)
return nil, nil
}
return rawURL[:i], rawURL[i+1:]
default:
// we have encountered an invalid character,
// so there is no valid scheme
return nil, rawURL
}
}
return nil, rawURL
}

func (u *URI) parse(host, uri []byte, isTLS bool) {
u.Reset()

Expand Down Expand Up @@ -455,20 +484,14 @@ func stringContainsCTLByte(s []byte) bool {
}

func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) {
n := bytes.Index(uri, bytestr.StrSlashSlash)
if n < 0 {
return bytestr.StrHTTP, host, uri
}
scheme := uri[:n]
if bytes.IndexByte(scheme, '/') >= 0 {
scheme, path := getScheme(uri)

if scheme == nil {
return bytestr.StrHTTP, host, uri
}
if len(scheme) > 0 && scheme[len(scheme)-1] == ':' {
scheme = scheme[:len(scheme)-1]
}
n += len(bytestr.StrSlashSlash)
uri = uri[n:]
n = bytes.IndexByte(uri, '/')

uri = path[len(bytestr.StrSlashSlash):]
n := bytes.IndexByte(uri, '/')
if n < 0 {
// A hack for bogus urls like foobar.com?a=b without
// slash after host.
Expand Down Expand Up @@ -587,6 +610,9 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte {
if len(u.scheme) > 0 {
schemeOriginal = append([]byte(nil), u.scheme...)
}
if n == 0 {
newURI = bytes.Join([][]byte{u.scheme, bytestr.StrColon, newURI}, nil)
}
u.Parse(nil, newURI)
if len(schemeOriginal) > 0 && len(u.scheme) == 0 {
u.scheme = append(u.scheme[:0], schemeOriginal...)
Expand Down
33 changes: 33 additions & 0 deletions pkg/protocol/uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
package protocol

import (
"bytes"
"path/filepath"
"reflect"
"runtime"
Expand Down Expand Up @@ -468,3 +469,35 @@ func TestParseURI(t *testing.T) {
uri := string(ParseURI(expectURI).FullURI())
assert.DeepEqual(t, expectURI, uri)
}

func TestSplitHostURI(t *testing.T) {
cases := []struct {
host, uri []byte
wantScheme, wantHost, wantPath []byte
}{
{
[]byte("example.com"), []byte("/foobar"),
[]byte("http"), []byte("example.com"), []byte("/foobar"),
},
{
[]byte("example2.com"), []byte("http://example2.com"),
[]byte("http"), []byte("example2.com"), []byte("/"),
},
{
[]byte("example2.com"), []byte("http://example3.com"),
[]byte("http"), []byte("example3.com"), []byte("/"),
},
{
[]byte("example3.com"), []byte("https://foobar.com?a=b"),
[]byte("https"), []byte("foobar.com"), []byte("?a=b"),
},
}

for _, c := range cases {
gotScheme, gotHost, gotPath := splitHostURI(c.host, c.uri)
if !bytes.Equal(gotScheme, c.wantScheme) || !bytes.Equal(gotHost, c.wantHost) || !bytes.Equal(gotPath, c.wantPath) {
t.Errorf("splitHostURI(%q, %q) == (%q, %q, %q), want (%q, %q, %q)",
c.host, c.uri, gotScheme, gotHost, gotPath, c.wantScheme, c.wantHost, c.wantPath)
}
}
}
9 changes: 9 additions & 0 deletions pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ var (
default404Body = []byte("404 page not found")
default405Body = []byte("405 method not allowed")
default400Body = []byte("400 bad request")

requiredHostBody = []byte("missing required Host header")
)

type hijackConn struct {
Expand Down Expand Up @@ -721,6 +723,13 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) {
}

rPath := string(ctx.Request.URI().Path())

// align with https://datatracker.ietf.org/doc/html/rfc2616#section-5.2
if len(ctx.Request.Host()) == 0 && ctx.Request.Header.IsHTTP11() && bytesconv.B2s(ctx.Request.Method()) != consts.MethodConnect {
serveError(c, ctx, consts.StatusBadRequest, requiredHostBody)
return
}

httpMethod := bytesconv.B2s(ctx.Request.Header.Method())
unescape := false
if engine.options.UseRawPath {
Expand Down

0 comments on commit 1279f53

Please sign in to comment.