Skip to content

Commit

Permalink
feat: add new option r.handleFallbackRoute
Browse files Browse the repository at this point in the history
  • Loading branch information
inhere committed Aug 13, 2020
1 parent f28ae60 commit 859c064
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 88 deletions.
6 changes: 3 additions & 3 deletions dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ var internal405Handler HandlerFunc = func(c *Context) {
// ServeHTTP for handle HTTP request, response data to client.
func (r *Router) ServeHTTP(res http.ResponseWriter, req *http.Request) {
// get new context
ctx := r.pool.Get().(*Context)
ctx := r.ctxPool.Get().(*Context)
ctx.Init(res, req)

// handle HTTP Request
Expand All @@ -125,14 +125,14 @@ func (r *Router) ServeHTTP(res http.ResponseWriter, req *http.Request) {
// reset data
ctx.Reset()
// release data
r.pool.Put(ctx)
r.ctxPool.Put(ctx)
}

// HandleContext handle a given context
func (r *Router) HandleContext(c *Context) {
c.Reset()
r.handleHTTPRequest(c)
r.pool.Put(c)
r.ctxPool.Put(c)
}

// handle HTTP Request
Expand Down
12 changes: 7 additions & 5 deletions parse_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ func (r *Router) Match(method, path string) (result *MatchResult) {
}
}

// if has fallback route. router->Any("/*", handler)
key := method + "/*"
if route, ok := r.stableRoutes[key]; ok {
return newFoundResult(route, nil)
// handle fallback route. add by: router->Any("/*", handler)
if r.handleFallbackRoute {
key := method + "/*"
if route, ok := r.stableRoutes[key]; ok {
return newFoundResult(route, nil)
}
}

// handle method not allowed. will find allowed methods
Expand All @@ -170,7 +172,7 @@ func (r *Router) Match(method, path string) (result *MatchResult) {
}
}

// don't handle method not allowed, return not found
// don't handle method not allowed, will return not found
return
}

Expand Down
163 changes: 85 additions & 78 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ type methodRoutes map[string]routes
type Router struct {
// router name
Name string
// context pool
pool sync.Pool
// count routes
counter int
// context pool
ctxPool sync.Pool

// Static/stable/fixed routes, no path params.
// {
Expand Down Expand Up @@ -123,7 +123,7 @@ type Router struct {
// storage named routes. {"name": Route}
namedRoutes map[string]*Route
// TODO pool for storage MatchResult
// matchResultPool sync.Pool
matchResultPool sync.Pool

// some data for group
currentGroupPrefix string
Expand Down Expand Up @@ -155,6 +155,9 @@ type Router struct {
// maxMultipartMemory int64
// whether checks if another method is allowed for the current route. default is False
handleMethodNotAllowed bool
// whether handle the fallback route "/*"
// add by router->Any("/*", handler)
handleFallbackRoute bool

//
// Extends tools
Expand Down Expand Up @@ -190,14 +193,14 @@ func New(options ...func(*Router)) *Router {

// with some options
router.WithOptions(options...)
router.pool.New = func() interface{} {
router.ctxPool.New = func() interface{} {
return &Context{index: -1, router: router}
}

// match result pool
// router.matchResultPool.New = func() interface{} {
// return &MatchResult{Status: Found}
// }
router.matchResultPool.New = func() interface{} {
return &MatchResult{Status: Found}
}

return router
}
Expand Down Expand Up @@ -250,6 +253,11 @@ func StrictLastSlash(r *Router) {
// }
// }

// HandleFallbackRoute enable for the router
func HandleFallbackRoute(r *Router) {
r.handleFallbackRoute = true
}

// HandleMethodNotAllowed enable for the router
func HandleMethodNotAllowed(r *Router) {
r.handleMethodNotAllowed = true
Expand Down Expand Up @@ -350,76 +358,6 @@ func (r *Router) AddRoute(route *Route) *Route {
return route
}

func (r *Router) appendRoute(route *Route) {
// route check: methods, handler
route.goodInfo()
// format path and append group info
r.appendGroupInfo(route)
// print debug info
debugPrintRoute(route)

// has route name.
if route.name != "" {
r.namedRoutes[route.name] = route
}

// path is fixed(no param vars). eg. "/users"
if isFixedPath(route.path) {
path := route.path
for _, method := range route.methods {
key := method + path

r.counter++
r.stableRoutes[key] = route
}
return
}

// parsing route path with parameters
if first := r.parseParamRoute(route); first != "" {
for _, method := range route.methods {
key := method + first
rs, has := r.regularRoutes[key]
if !has {
rs = routes{}
}

r.counter++
r.regularRoutes[key] = append(rs, route)
}
return
}

// it's irregular param route
for _, method := range route.methods {
rs, has := r.irregularRoutes[method]
if has {
rs = routes{}
}

r.counter++
r.irregularRoutes[method] = append(rs, route)
}
}

func (r *Router) appendGroupInfo(route *Route) {
path := r.formatPath(route.path)
if r.currentGroupPrefix != "" {
path = r.formatPath(r.currentGroupPrefix + path)
}

if len(r.currentGroupHandlers) > 0 {
route.handlers = combineHandlers(r.currentGroupHandlers, route.handlers)

if finalSize := len(route.handlers); finalSize >= int(abortIndex) {
panicf("too many handlers(number: %d)", finalSize)
}
}

// re-set formatted path
route.path = path
}

// Group add an group routes, can with middleware
func (r *Router) Group(prefix string, register func(), middles ...HandlerFunc) {
prevPrefix := r.currentGroupPrefix
Expand Down Expand Up @@ -623,7 +561,6 @@ func (r *Router) Routes() (rs []RouteInfo) {
r.IterateRoutes(func(route *Route) {
rs = append(rs, route.Info())
})

return
}

Expand Down Expand Up @@ -691,3 +628,73 @@ func (r *Router) formatPath(path string) string {

return path
}

func (r *Router) appendRoute(route *Route) {
// route check: methods, handler
route.goodInfo()
// format path and append group info
r.appendGroupInfo(route)
// print debug info
debugPrintRoute(route)

// has route name.
if route.name != "" {
r.namedRoutes[route.name] = route
}

// path is fixed(no param vars). eg. "/users"
if isFixedPath(route.path) {
path := route.path
for _, method := range route.methods {
key := method + path

r.counter++
r.stableRoutes[key] = route
}
return
}

// parsing route path with parameters
if first := r.parseParamRoute(route); first != "" {
for _, method := range route.methods {
key := method + first
rs, has := r.regularRoutes[key]
if !has {
rs = routes{}
}

r.counter++
r.regularRoutes[key] = append(rs, route)
}
return
}

// it's irregular param route
for _, method := range route.methods {
rs, has := r.irregularRoutes[method]
if has {
rs = routes{}
}

r.counter++
r.irregularRoutes[method] = append(rs, route)
}
}

func (r *Router) appendGroupInfo(route *Route) {
path := r.formatPath(route.path)
if r.currentGroupPrefix != "" {
path = r.formatPath(r.currentGroupPrefix + path)
}

if len(r.currentGroupHandlers) > 0 {
route.handlers = combineHandlers(r.currentGroupHandlers, route.handlers)

if finalSize := len(route.handlers); finalSize >= int(abortIndex) {
panicf("too many handlers(number: %d)", finalSize)
}
}

// re-set formatted path
route.path = path
}
12 changes: 10 additions & 2 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,22 @@ func TestAddRoute(t *testing.T) {
ret = r.Match(POST, "/site")
is.Equal(Found, ret.Status)

// add fallback route
// fallback route(Need enable option: r.handleFallbackRoute)
r.Any("/*", emptyHandler)
for _, m := range anyMethods {
ret = r.Match(m, "/not-exist")
is.Equal(Found, ret.Status)
is.Equal(NotFound, ret.Status)
}

Debug(false)

r = New(HandleFallbackRoute)
// add fallback route
r.Any("/*", emptyHandler)
for _, m := range anyMethods {
ret = r.Match(m, "/not-exist")
is.Equal(Found, ret.Status)
}
}

func TestNameRoute(t *testing.T) {
Expand Down

0 comments on commit 859c064

Please sign in to comment.