Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 60 additions & 163 deletions router.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package echo

import (
"bytes"
"net/http"
"sort"
"strings"
)

type (
Expand All @@ -19,31 +20,21 @@ type (
prefix string
parent *node
staticChildren children
ppath string
pnames []string
methodHandler *methodHandler
handlers map[string]methodHandler
paramChild *node
anyChild *node
// isLeaf indicates that node does not have child routes
isLeaf bool
// isHandler indicates that node has at least one handler registered to it
isHandler bool
isHandler bool
allowHeader string
}
kind uint8
children []*node
methodHandler struct {
connect HandlerFunc
delete HandlerFunc
get HandlerFunc
head HandlerFunc
options HandlerFunc
patch HandlerFunc
post HandlerFunc
propfind HandlerFunc
put HandlerFunc
trace HandlerFunc
report HandlerFunc
allowHeader string
ppath string
pnames []string
handlerFunc HandlerFunc
}
)

Expand All @@ -56,70 +47,11 @@ const (
anyLabel = byte('*')
)

func (m *methodHandler) isHandler() bool {
return m.connect != nil ||
m.delete != nil ||
m.get != nil ||
m.head != nil ||
m.options != nil ||
m.patch != nil ||
m.post != nil ||
m.propfind != nil ||
m.put != nil ||
m.trace != nil ||
m.report != nil
}

func (m *methodHandler) updateAllowHeader() {
buf := new(bytes.Buffer)
buf.WriteString(http.MethodOptions)

if m.connect != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodConnect)
}
if m.delete != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodDelete)
}
if m.get != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodGet)
}
if m.head != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodHead)
}
if m.patch != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPatch)
}
if m.post != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPost)
}
if m.propfind != nil {
buf.WriteString(", PROPFIND")
}
if m.put != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPut)
}
if m.trace != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodTrace)
}
if m.report != nil {
buf.WriteString(", REPORT")
}
m.allowHeader = buf.String()
}

// NewRouter returns a new Router instance.
func NewRouter(e *Echo) *Router {
return &Router{
tree: &node{
methodHandler: new(methodHandler),
handlers: make(map[string]methodHandler),
},
routes: map[string]*Route{},
echo: e,
Expand Down Expand Up @@ -184,6 +116,12 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
*r.echo.maxParam = paramLen
}

handler := methodHandler{
ppath: ppath,
pnames: pnames,
handlerFunc: h,
}

currentNode := r.tree // Current node as root
if currentNode == nil {
panic("echo: invalid method")
Expand All @@ -209,9 +147,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.prefix = search
if h != nil {
currentNode.kind = t
currentNode.addHandler(method, h)
currentNode.ppath = ppath
currentNode.pnames = pnames
currentNode.addHandler(method, handler)
}
currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
} else if lcpLen < prefixLen {
Expand All @@ -221,9 +157,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.prefix[lcpLen:],
currentNode,
currentNode.staticChildren,
currentNode.methodHandler,
currentNode.ppath,
currentNode.pnames,
currentNode.handlers,
currentNode.paramChild,
currentNode.anyChild,
)
Expand All @@ -243,9 +177,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
currentNode.label = currentNode.prefix[0]
currentNode.prefix = currentNode.prefix[:lcpLen]
currentNode.staticChildren = nil
currentNode.methodHandler = new(methodHandler)
currentNode.ppath = ""
currentNode.pnames = nil
currentNode.handlers = make(map[string]methodHandler)
currentNode.paramChild = nil
currentNode.anyChild = nil
currentNode.isLeaf = false
Expand All @@ -257,13 +189,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
if lcpLen == searchLen {
// At parent node
currentNode.kind = t
currentNode.addHandler(method, h)
currentNode.ppath = ppath
currentNode.pnames = pnames
currentNode.addHandler(method, handler)
} else {
// Create child node
n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
n.addHandler(method, h)
n = newNode(t, search[lcpLen:], currentNode, nil, nil, nil, nil)
n.addHandler(method, handler)
// Only Static children could reach here
currentNode.addStaticChild(n)
}
Expand All @@ -277,8 +207,8 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
continue
}
// Create child node
n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil)
n.addHandler(method, h)
n := newNode(t, search, currentNode, nil, nil, nil, nil)
n.addHandler(method, handler)
switch t {
case staticKind:
currentNode.addStaticChild(n)
Expand All @@ -291,31 +221,28 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
} else {
// Node already exists
if h != nil {
currentNode.addHandler(method, h)
currentNode.ppath = ppath
if len(currentNode.pnames) == 0 { // Issue #729
currentNode.pnames = pnames
}
currentNode.addHandler(method, handler)
}
}
return
}
}

func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node {
func newNode(t kind, pre string, p *node, sc children, handlers map[string]methodHandler, paramChildren, anyChildren *node) *node {
if handlers == nil {
handlers = make(map[string]methodHandler)
}
return &node{
kind: t,
label: pre[0],
prefix: pre,
parent: p,
staticChildren: sc,
ppath: ppath,
pnames: pnames,
methodHandler: mh,
handlers: handlers,
paramChild: paramChildren,
anyChild: anyChildren,
isLeaf: sc == nil && paramChildren == nil && anyChildren == nil,
isHandler: mh.isHandler(),
isHandler: len(handlers) > 0,
}
}

Expand Down Expand Up @@ -345,67 +272,31 @@ func (n *node) findChildWithLabel(l byte) *node {
return nil
}

func (n *node) addHandler(method string, h HandlerFunc) {
switch method {
case http.MethodConnect:
n.methodHandler.connect = h
case http.MethodDelete:
n.methodHandler.delete = h
case http.MethodGet:
n.methodHandler.get = h
case http.MethodHead:
n.methodHandler.head = h
case http.MethodOptions:
n.methodHandler.options = h
case http.MethodPatch:
n.methodHandler.patch = h
case http.MethodPost:
n.methodHandler.post = h
case PROPFIND:
n.methodHandler.propfind = h
case http.MethodPut:
n.methodHandler.put = h
case http.MethodTrace:
n.methodHandler.trace = h
case REPORT:
n.methodHandler.report = h
func (n *node) addHandler(method string, handler methodHandler) {
if handler.handlerFunc != nil {
n.handlers[method] = handler
}

n.methodHandler.updateAllowHeader()
if h != nil {
n.isHandler = true
} else {
n.isHandler = n.methodHandler.isHandler()
}
n.updateAllowHeader()
n.isHandler = len(n.handlers) != 0
}

func (n *node) findHandler(method string) HandlerFunc {
switch method {
case http.MethodConnect:
return n.methodHandler.connect
case http.MethodDelete:
return n.methodHandler.delete
case http.MethodGet:
return n.methodHandler.get
case http.MethodHead:
return n.methodHandler.head
case http.MethodOptions:
return n.methodHandler.options
case http.MethodPatch:
return n.methodHandler.patch
case http.MethodPost:
return n.methodHandler.post
case PROPFIND:
return n.methodHandler.propfind
case http.MethodPut:
return n.methodHandler.put
case http.MethodTrace:
return n.methodHandler.trace
case REPORT:
return n.methodHandler.report
default:
return nil
if m, ok := n.handlers[method]; ok {
return m.handlerFunc
}
return nil
}

func (n *node) updateAllowHeader() {
allowedMethods := []string{http.MethodOptions}
for method := range n.handlers {
allowedMethods = append(allowedMethods, method)
}
sort.Slice(allowedMethods[1:], func(i, j int) bool {
return allowedMethods[i+1] < allowedMethods[j+1]
})
n.allowHeader = strings.Join(allowedMethods, ", ")
}

func optionsMethodHandler(allowMethods string) func(c Context) error {
Expand Down Expand Up @@ -569,7 +460,11 @@ func (r *Router) Find(method, path string, c Context) {
if child := currentNode.anyChild; child != nil {
// If any node is found, use remaining path for paramValues
currentNode = child
paramValues[len(currentNode.pnames)-1] = search
if m, ok := currentNode.handlers[method]; ok {
paramValues[len(m.pnames)-1] = search
} else {
break
}
// update indexes/search in case we need to backtrack when no handler match is found
paramIndex++
searchIndex += +len(search)
Expand Down Expand Up @@ -613,13 +508,15 @@ func (r *Router) Find(method, path string, c Context) {

ctx.handler = NotFoundHandler
if currentNode.isHandler {
ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader)
ctx.Set(ContextKeyHeaderAllow, currentNode.allowHeader)
ctx.handler = MethodNotAllowedHandler
if method == http.MethodOptions {
ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader)
ctx.handler = optionsMethodHandler(currentNode.allowHeader)
}
}
}
ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames
if m, ok := currentNode.handlers[method]; ok {
ctx.path = m.ppath
ctx.pnames = m.pnames
}
}
22 changes: 22 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,28 @@ func TestRouterTwoParam(t *testing.T) {
assert.Equal(t, "1", c.Param("fid"))
}

// Issue #1726
// Issue #2201
func TestRouterTwoDifferentParam(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodPut, "/users/:vid/files/:gid", func(Context) error {
return nil
})
r.Add(http.MethodGet, "/users/:uid/files/:fid", func(Context) error {
return nil
})
c := e.NewContext(nil, nil).(*context)

r.Find(http.MethodGet, "/users/1/files/2", c)
assert.Equal(t, "1", c.Param("uid"))
assert.Equal(t, "2", c.Param("fid"))

r.Find(http.MethodPut, "/users/3/files/4", c)
assert.Equal(t, "3", c.Param("vid"))
assert.Equal(t, "4", c.Param("gid"))
}

// Issue #378
func TestRouterParamWithSlash(t *testing.T) {
e := New()
Expand Down