Skip to content

Commit

Permalink
router: support Routes accepting multiple strings as method (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
unknwon committed Jun 1, 2022
1 parent fd701bd commit 7e355eb
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
3 changes: 2 additions & 1 deletion handler.go
Expand Up @@ -5,6 +5,7 @@
package flamego

import (
"fmt"
"net/http"
"reflect"

Expand Down Expand Up @@ -52,7 +53,7 @@ func (invoke teapotInvoker) Invoke([]interface{}) ([]reflect.Value, error) {
// gain up to 3x performance improvement.
func validateAndWrapHandler(h Handler, wrapper func(Handler) Handler) Handler {
if reflect.TypeOf(h).Kind() != reflect.Func {
panic("handler must be a callable function")
panic(fmt.Sprintf("handler must be a callable function, but got %T", h))
}

if inject.IsFastInvoker(h) {
Expand Down
25 changes: 21 additions & 4 deletions router.go
Expand Up @@ -50,10 +50,12 @@ type Router interface {
Trace(routePath string, handlers ...Handler) *Route
// Any is a shortcut for `r.Route("*", routePath, handlers)`.
Any(routePath string, handlers ...Handler) *Route
// Routes is a shortcut of adding same handlers for different HTTP methods.
// Routes is a shortcut of adding route with same list of handlers for different
// HTTP methods.
//
// Example:
// f.Routes("/", "GET,POST", handlers)
// f.Routes("/", http.MethodGet, http.MethodPost, handlers...)
// f.Routes("/", "GET,POST", handlers...)
Routes(routePath, methods string, handlers ...Handler) *Route
// NotFound configures a http.HandlerFunc to be called when no matching route is
// found. When it is not set, http.NotFound is used. Be sure to set
Expand Down Expand Up @@ -272,9 +274,24 @@ func (r *router) Routes(routePath, methods string, handlers ...Handler) *Route {
panic("empty methods")
}

var route *Route
var ms []string
for _, m := range strings.Split(methods, ",") {
route = r.Route(strings.TrimSpace(m), routePath, handlers)
ms = append(ms, strings.TrimSpace(m))
}

// Collect methods from handlers if they are strings
for i, h := range handlers {
m, ok := h.(string)
if !ok {
handlers = handlers[i:]
break
}
ms = append(ms, m)
}

var route *Route
for _, m := range ms {
route = r.Route(m, routePath, handlers)
}
return route
}
Expand Down
49 changes: 36 additions & 13 deletions router_test.go
Expand Up @@ -119,29 +119,52 @@ func TestRouter_Route(t *testing.T) {

func TestRouter_Routes(t *testing.T) {
ctx := newMockContext()
contextCreator := func(w http.ResponseWriter, r *http.Request, params route.Params, handlers []Handler, urlPath urlPather) internalContext {
contextCreator := func(_ http.ResponseWriter, _ *http.Request, params route.Params, _ []Handler, _ urlPather) internalContext {
ctx.MockContext.ParamFunc.SetDefaultHook(func(s string) string {
return params[s]
})
return ctx
}
r := newRouter(contextCreator)

r.Routes("/routes", "GET,POST", func() {})
t.Run("use single string", func(t *testing.T) {
r := newRouter(contextCreator)

for _, m := range []string{http.MethodGet, http.MethodPost} {
gotRoute := ""
ctx.run_ = func() { gotRoute = ctx.Param("route") }
r.Routes("/routes", "GET,POST", func() {})

resp := httptest.NewRecorder()
req, err := http.NewRequest(m, "/routes", nil)
assert.Nil(t, err)
for _, m := range []string{http.MethodGet, http.MethodPost} {
gotRoute := ""
ctx.run_ = func() { gotRoute = ctx.Param("route") }

r.ServeHTTP(resp, req)
resp := httptest.NewRecorder()
req, err := http.NewRequest(m, "/routes", nil)
assert.Nil(t, err)

assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "/routes", gotRoute)
}
r.ServeHTTP(resp, req)

assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "/routes", gotRoute)
}
})

t.Run("use multiple strings", func(t *testing.T) {
r := newRouter(contextCreator)

r.Routes("/routes", http.MethodGet, http.MethodPost, func() {})

for _, m := range []string{http.MethodGet, http.MethodPost} {
gotRoute := ""
ctx.run_ = func() { gotRoute = ctx.Param("route") }

resp := httptest.NewRecorder()
req, err := http.NewRequest(m, "/routes", nil)
assert.Nil(t, err)

r.ServeHTTP(resp, req)

assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "/routes", gotRoute)
}
})
}

func TestRouter_AutoHead(t *testing.T) {
Expand Down

0 comments on commit 7e355eb

Please sign in to comment.