Skip to content

Commit

Permalink
feat(group): add path validations
Browse files Browse the repository at this point in the history
  • Loading branch information
savsgio committed Feb 28, 2023
1 parent c3fcfb3 commit 8341673
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 2 deletions.
32 changes: 32 additions & 0 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,84 @@ import "github.com/valyala/fasthttp"
// Group returns a new group.
// Path auto-correction, including trailing slashes, is enabled by default.
func (g *Group) Group(path string) *Group {
validatePath(path)

if len(g.prefix) > 0 && path == "/" {
return g
}

return g.router.Group(g.prefix + path)
}

// GET is a shortcut for group.Handle(fasthttp.MethodGet, path, handler)
func (g *Group) GET(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.GET(g.prefix+path, handler)
}

// HEAD is a shortcut for group.Handle(fasthttp.MethodHead, path, handler)
func (g *Group) HEAD(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.HEAD(g.prefix+path, handler)
}

// POST is a shortcut for group.Handle(fasthttp.MethodPost, path, handler)
func (g *Group) POST(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.POST(g.prefix+path, handler)
}

// PUT is a shortcut for group.Handle(fasthttp.MethodPut, path, handler)
func (g *Group) PUT(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.PUT(g.prefix+path, handler)
}

// PATCH is a shortcut for group.Handle(fasthttp.MethodPatch, path, handler)
func (g *Group) PATCH(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.PATCH(g.prefix+path, handler)
}

// DELETE is a shortcut for group.Handle(fasthttp.MethodDelete, path, handler)
func (g *Group) DELETE(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.DELETE(g.prefix+path, handler)
}

// OPTIONS is a shortcut for group.Handle(fasthttp.MethodOptions, path, handler)
func (g *Group) CONNECT(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.CONNECT(g.prefix+path, handler)
}

// OPTIONS is a shortcut for group.Handle(fasthttp.MethodOptions, path, handler)
func (g *Group) OPTIONS(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.OPTIONS(g.prefix+path, handler)
}

// OPTIONS is a shortcut for group.Handle(fasthttp.MethodOptions, path, handler)
func (g *Group) TRACE(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.TRACE(g.prefix+path, handler)
}

// ANY is a shortcut for group.Handle(router.MethodWild, path, handler)
//
// WARNING: Use only for routes where the request method is not important
func (g *Group) ANY(path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.ANY(g.prefix+path, handler)
}

Expand All @@ -70,6 +96,8 @@ func (g *Group) ANY(path string, handler fasthttp.RequestHandler) {
//
// router.ServeFiles("/src/{filepath:*}", "./")
func (g *Group) ServeFiles(path string, rootPath string) {
validatePath(path)

g.router.ServeFiles(g.prefix+path, rootPath)
}

Expand All @@ -84,6 +112,8 @@ func (g *Group) ServeFiles(path string, rootPath string) {
//
// router.ServeFilesCustom("/src/{filepath:*}", *customFS)
func (g *Group) ServeFilesCustom(path string, fs *fasthttp.FS) {
validatePath(path)

g.router.ServeFilesCustom(g.prefix+path, fs)
}

Expand All @@ -96,5 +126,7 @@ func (g *Group) ServeFilesCustom(path string, fs *fasthttp.FS) {
// frequently used, non-standardized or custom methods (e.g. for internal
// communication with a proxy).
func (g *Group) Handle(method, path string, handler fasthttp.RequestHandler) {
validatePath(path)

g.router.Handle(method, g.prefix+path, handler)
}
67 changes: 67 additions & 0 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,61 @@ package router

import (
"bufio"
"reflect"
"strings"
"testing"

"github.com/valyala/fasthttp"
)

type routerGrouper interface {
Group(string) *Group
ServeFiles(path string, rootPath string)
ServeFilesCustom(path string, fs *fasthttp.FS)
}

func assertGroup(t *testing.T, gs ...routerGrouper) {
for i, g := range gs {
g2 := g.Group("/")

v1 := reflect.ValueOf(g)
v2 := reflect.ValueOf(g2)

if v1.String() != v2.String() { // router -> group
if v1.Pointer() == v2.Pointer() {
t.Errorf("[%d] equal pointers: %p == %p", i, g, g2)
}
} else { // group -> subgroup
if v1.Pointer() != v2.Pointer() {
t.Errorf("[%d] mismatch pointers: %p != %p", i, g, g2)
}
}

if err := catchPanic(func() { g.Group("v999") }); err == nil {
t.Error("an error was expected when a path does not begin with slash")
}

if err := catchPanic(func() { g.Group("/v999/") }); err == nil {
t.Error("an error was expected when a path has a trailing slash")
}

if err := catchPanic(func() { g.Group("") }); err == nil {
t.Error("an error was expected with an empty path")
}

if err := catchPanic(func() { g.ServeFiles("static/{filepath:*}", "./") }); err == nil {
t.Error("an error was expected when a path does not begin with slash")
}

if err := catchPanic(func() {
g.ServeFilesCustom("", &fasthttp.FS{Root: "./"})
}); err == nil {
t.Error("an error was expected with an empty path")
}

}
}

func TestGroup(t *testing.T) {
r1 := New()
r2 := r1.Group("/boo")
Expand All @@ -16,6 +65,8 @@ func TestGroup(t *testing.T) {
r5 := r4.Group("/foo")
r6 := r5.Group("/foo")

assertGroup(t, r1, r2, r3, r4, r5, r6)

hit := false

r1.POST("/foo", func(ctx *fasthttp.RequestCtx) {
Expand Down Expand Up @@ -113,6 +164,14 @@ func TestGroup_shortcutsAndHandle(t *testing.T) {

for _, fn := range shortcuts {
fn("/bar", func(_ *fasthttp.RequestCtx) {})

if err := catchPanic(func() { fn("buzz", func(_ *fasthttp.RequestCtx) {}) }); err == nil {
t.Error("an error was expected when a path does not begin with slash")
}

if err := catchPanic(func() { fn("", func(_ *fasthttp.RequestCtx) {}) }); err == nil {
t.Error("an error was expected with an empty path")
}
}

methods := httpMethods[:len(httpMethods)-1] // Avoid customs methods
Expand All @@ -128,6 +187,14 @@ func TestGroup_shortcutsAndHandle(t *testing.T) {
for _, method := range httpMethods {
g2.Handle(method, "/bar", func(_ *fasthttp.RequestCtx) {})

if err := catchPanic(func() { g2.Handle(method, "buzz", func(_ *fasthttp.RequestCtx) {}) }); err == nil {
t.Error("an error was expected when a path does not begin with slash")
}

if err := catchPanic(func() { g2.Handle(method, "", func(_ *fasthttp.RequestCtx) {}) }); err == nil {
t.Error("an error was expected with an empty path")
}

h, _ := r.Lookup(method, "/v1/foo/bar", nil)
if h == nil {
t.Errorf("Bad shorcurt")
Expand Down
10 changes: 8 additions & 2 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ func New() *Router {
// Group returns a new group.
// Path auto-correction, including trailing slashes, is enabled by default.
func (r *Router) Group(path string) *Group {
validatePath(path)

if path != "/" && strings.HasSuffix(path, "/") {
panic("group path must not end with a trailing slash")
}

return &Group{
router: r,
prefix: path,
Expand Down Expand Up @@ -216,10 +222,10 @@ func (r *Router) Handle(method, path string, handler fasthttp.RequestHandler) {
switch {
case len(method) == 0:
panic("method must not be empty")
case len(path) < 1 || path[0] != '/':
panic("path must begin with '/' in path '" + path + "'")
case handler == nil:
panic("handler must not be nil")
default:
validatePath(path)
}

r.registeredPaths[method] = append(r.registeredPaths[method], path)
Expand Down
10 changes: 10 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package router

import "strings"

func validatePath(path string) {
switch {
case len(path) == 0 || !strings.HasPrefix(path, "/"):
panic("path must begin with '/' in path '" + path + "'")
}
}
17 changes: 17 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package router

import "testing"

func Test_validatePath(t *testing.T) {
if err := catchPanic(func() { validatePath("") }); err == nil {
t.Error("an error was expected with an empty path")
}

if err := catchPanic(func() { validatePath("foo") }); err == nil {
t.Error("an error was expected with an empty path")
}

if err := catchPanic(func() { validatePath("/foo") }); err != nil {
t.Errorf("unexpected error: %v", err)
}
}

0 comments on commit 8341673

Please sign in to comment.