Skip to content

Commit

Permalink
Fix bad and overflow redirections #50
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio Andres Virviescas Santana committed Apr 3, 2021
1 parent cc03ff3 commit 030ebe2
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 109 deletions.
161 changes: 65 additions & 96 deletions radix/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,128 +274,97 @@ func (n *node) add(path, fullPath string, handler fasthttp.RequestHandler) (*nod
}

func (n *node) getFromChild(path string, ctx *fasthttp.RequestCtx) (fasthttp.RequestHandler, bool) {
var parent *node

parentIndex, childIndex := 0, 0

walk:
for {
for _, child := range n.children[childIndex:] {
childIndex++
for _, child := range n.children {
switch child.nType {
case static:

switch child.nType {
case static:
// Checks if the first byte is equal
// It's faster than compare strings
if path[0] != child.path[0] {
continue
}

// Checks if the first byte is equal
// It's faster than compare strings
if path[0] != child.path[0] {
if len(path) > len(child.path) {
if path[:len(child.path)] != child.path {
continue
}

if len(path) > len(child.path) {
if path[:len(child.path)] != child.path {
continue
}

path = path[len(child.path):]

parent = n
n = child

parentIndex = childIndex
childIndex = 0

continue walk

} else if path == child.path {
switch {
case child.tsr:
return nil, true
case child.handler != nil:
return child.handler, false
case child.wildcard != nil:
if ctx != nil {
ctx.SetUserValue(child.wildcard.paramKey, "")
}

return child.wildcard.handler, false
h, tsr := child.getFromChild(path[len(child.path):], ctx)
if h != nil || tsr {
return h, tsr
}
} else if path == child.path {
switch {
case child.tsr:
return nil, true
case child.handler != nil:
return child.handler, false
case child.wildcard != nil:
if ctx != nil {
ctx.SetUserValue(child.wildcard.paramKey, "")
}

return nil, false
return child.wildcard.handler, false
}

case param:
end := segmentEndIndex(path, false)
values := []string{copyString(path[:end])}

if child.paramRegex != nil {
end, values = child.findEndIndexAndValues(path[:end])
if end == -1 {
continue
}
}
return nil, false
}

if len(path) > end {
h, tsr := child.getFromChild(path[end:], ctx)
if tsr {
return nil, tsr
} else if h != nil {
if ctx != nil {
for i, key := range child.paramKeys {
ctx.SetUserValue(key, values[i])
}
}
case param:
end := segmentEndIndex(path, false)
values := []string{copyString(path[:end])}

return h, false
}
if child.paramRegex != nil {
end, values = child.findEndIndexAndValues(path[:end])
if end == -1 {
continue
}
}

} else if len(path) == end {
switch {
case child.tsr:
return nil, true
case child.handler == nil:
// try another child
continue
case ctx != nil:
if len(path) > end {
h, tsr := child.getFromChild(path[end:], ctx)
if tsr {
return nil, tsr
} else if h != nil {
if ctx != nil {
for i, key := range child.paramKeys {
ctx.SetUserValue(key, values[i])
}
}

return child.handler, false
return h, false
}

default:
panic("invalid node type")
}

}

// Go back and continue with the remaining children of the parent
// to try to discover the correct child node
// if the parent has a child node of type param
//
// See: https://github.com/fasthttp/router/issues/37
if parent != nil && parent.hasWildChild && len(parent.children[parentIndex:]) > 0 {
path = n.path + path
childIndex = parentIndex
} else if len(path) == end {
switch {
case child.tsr:
return nil, true
case child.handler == nil:
// try another child
continue
case ctx != nil:
for i, key := range child.paramKeys {
ctx.SetUserValue(key, values[i])
}
}

n = parent
parent = nil
return child.handler, false
}

continue walk
default:
panic("invalid node type")
}
}

if n.wildcard != nil {
if ctx != nil {
ctx.SetUserValue(n.wildcard.paramKey, copyString(path))
}

return n.wildcard.handler, false
if n.wildcard != nil {
if ctx != nil {
ctx.SetUserValue(n.wildcard.paramKey, copyString(path))
}

return nil, false
return n.wildcard.handler, false
}

return nil, false
}

func (n *node) find(path string, buf *bytebufferpool.ByteBuffer) (bool, bool) {
Expand Down
4 changes: 3 additions & 1 deletion radix/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ func TestTreeAddAndGet(t *testing.T) {
"/α",
"/β",
"/hello/test",
"/hello/tooth",
"/hello/{name}",
}

Expand All @@ -137,7 +138,8 @@ func TestTreeAddAndGet(t *testing.T) {
{"/α", false, "/α", nil},
{"/β", false, "/β", nil},
{"/hello/test", false, "/hello/test", nil},
{"/hello/test1", false, "/hello/{name}", map[string]interface{}{"name": "test1"}},
{"/hello/tooth", false, "/hello/tooth", nil},
{"/hello/testastretta", false, "/hello/{name}", map[string]interface{}{"name": "testastretta"}},
{"/hello/tes", false, "/hello/{name}", map[string]interface{}{"name": "tes"}},
{"/hello/test/bye", true, "", nil},
})
Expand Down
12 changes: 1 addition & 11 deletions radix/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,7 @@ func (t *Tree) Get(path string, ctx *fasthttp.RequestCtx) (fasthttp.RequestHandl

path = path[len(t.root.path):]

handler, tsr := t.root.getFromChild(path, ctx)
if handler == nil {
if t.root.wildcard != nil {
if ctx != nil {
ctx.SetUserValue(t.root.wildcard.paramKey, path)
}
return t.root.wildcard.handler, false
}
}

return handler, tsr
return t.root.getFromChild(path, ctx)

} else if path == t.root.path {
switch {
Expand Down
1 change: 0 additions & 1 deletion radix/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ func Test_AddWithParam(t *testing.T) {

// Not found
testHandlerAndParams(t, tree, "/api/prefixV1_1111_sufix/fake", nil, false, nil)

}

func Test_TreeRootWildcard(t *testing.T) {
Expand Down
25 changes: 25 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1151,3 +1151,28 @@ func BenchmarkRouterRedirectTrailingSlash(b *testing.B) {
r.Handler(ctx)
}
}

func Benchmark_Get(b *testing.B) {
handler := func(ctx *fasthttp.RequestCtx) {}

r := New()

r.GET("/", handler)
r.GET("/plaintext", handler)
r.GET("/json", handler)
r.GET("/fortune", handler)
r.GET("/fortune-quick", handler)
r.GET("/db", handler)
r.GET("/queries", handler)
r.GET("/update", handler)

ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.SetMethod("GET")
ctx.Request.SetRequestURI("/update")

b.ResetTimer()

for i := 0; i < b.N; i++ {
r.Handler(ctx)
}
}

0 comments on commit 030ebe2

Please sign in to comment.