diff --git a/router.go b/router.go index b5e50d94f..da8612303 100644 --- a/router.go +++ b/router.go @@ -1,8 +1,9 @@ package echo import ( - "bytes" "net/http" + "sort" + "strings" ) type ( @@ -19,31 +20,21 @@ type ( prefix string parent *node staticChildren children - ppath string - pnames []string - methodHandler *methodHandler + handlers map[string]methodHandler paramChild *node anyChild *node // isLeaf indicates that node does not have child routes isLeaf bool // isHandler indicates that node has at least one handler registered to it - isHandler bool + isHandler bool + allowHeader string } kind uint8 children []*node methodHandler struct { - connect HandlerFunc - delete HandlerFunc - get HandlerFunc - head HandlerFunc - options HandlerFunc - patch HandlerFunc - post HandlerFunc - propfind HandlerFunc - put HandlerFunc - trace HandlerFunc - report HandlerFunc - allowHeader string + ppath string + pnames []string + handlerFunc HandlerFunc } ) @@ -56,70 +47,11 @@ const ( anyLabel = byte('*') ) -func (m *methodHandler) isHandler() bool { - return m.connect != nil || - m.delete != nil || - m.get != nil || - m.head != nil || - m.options != nil || - m.patch != nil || - m.post != nil || - m.propfind != nil || - m.put != nil || - m.trace != nil || - m.report != nil -} - -func (m *methodHandler) updateAllowHeader() { - buf := new(bytes.Buffer) - buf.WriteString(http.MethodOptions) - - if m.connect != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodConnect) - } - if m.delete != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodDelete) - } - if m.get != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodGet) - } - if m.head != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodHead) - } - if m.patch != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodPatch) - } - if m.post != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodPost) - } - if m.propfind != nil { - buf.WriteString(", PROPFIND") - } - if m.put != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodPut) - } - if m.trace != nil { - buf.WriteString(", ") - buf.WriteString(http.MethodTrace) - } - if m.report != nil { - buf.WriteString(", REPORT") - } - m.allowHeader = buf.String() -} - // NewRouter returns a new Router instance. func NewRouter(e *Echo) *Router { return &Router{ tree: &node{ - methodHandler: new(methodHandler), + handlers: make(map[string]methodHandler), }, routes: map[string]*Route{}, echo: e, @@ -184,6 +116,12 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string *r.echo.maxParam = paramLen } + handler := methodHandler{ + ppath: ppath, + pnames: pnames, + handlerFunc: h, + } + currentNode := r.tree // Current node as root if currentNode == nil { panic("echo: invalid method") @@ -209,9 +147,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.prefix = search if h != nil { currentNode.kind = t - currentNode.addHandler(method, h) - currentNode.ppath = ppath - currentNode.pnames = pnames + currentNode.addHandler(method, handler) } currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else if lcpLen < prefixLen { @@ -221,9 +157,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.prefix[lcpLen:], currentNode, currentNode.staticChildren, - currentNode.methodHandler, - currentNode.ppath, - currentNode.pnames, + currentNode.handlers, currentNode.paramChild, currentNode.anyChild, ) @@ -243,9 +177,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.label = currentNode.prefix[0] currentNode.prefix = currentNode.prefix[:lcpLen] currentNode.staticChildren = nil - currentNode.methodHandler = new(methodHandler) - currentNode.ppath = "" - currentNode.pnames = nil + currentNode.handlers = make(map[string]methodHandler) currentNode.paramChild = nil currentNode.anyChild = nil currentNode.isLeaf = false @@ -257,13 +189,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string if lcpLen == searchLen { // At parent node currentNode.kind = t - currentNode.addHandler(method, h) - currentNode.ppath = ppath - currentNode.pnames = pnames + currentNode.addHandler(method, handler) } else { // Create child node - n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) - n.addHandler(method, h) + n = newNode(t, search[lcpLen:], currentNode, nil, nil, nil, nil) + n.addHandler(method, handler) // Only Static children could reach here currentNode.addStaticChild(n) } @@ -277,8 +207,8 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string continue } // Create child node - n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) - n.addHandler(method, h) + n := newNode(t, search, currentNode, nil, nil, nil, nil) + n.addHandler(method, handler) switch t { case staticKind: currentNode.addStaticChild(n) @@ -291,31 +221,28 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string } else { // Node already exists if h != nil { - currentNode.addHandler(method, h) - currentNode.ppath = ppath - if len(currentNode.pnames) == 0 { // Issue #729 - currentNode.pnames = pnames - } + currentNode.addHandler(method, handler) } } return } } -func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node { +func newNode(t kind, pre string, p *node, sc children, handlers map[string]methodHandler, paramChildren, anyChildren *node) *node { + if handlers == nil { + handlers = make(map[string]methodHandler) + } return &node{ kind: t, label: pre[0], prefix: pre, parent: p, staticChildren: sc, - ppath: ppath, - pnames: pnames, - methodHandler: mh, + handlers: handlers, paramChild: paramChildren, anyChild: anyChildren, isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, - isHandler: mh.isHandler(), + isHandler: len(handlers) > 0, } } @@ -345,67 +272,31 @@ func (n *node) findChildWithLabel(l byte) *node { return nil } -func (n *node) addHandler(method string, h HandlerFunc) { - switch method { - case http.MethodConnect: - n.methodHandler.connect = h - case http.MethodDelete: - n.methodHandler.delete = h - case http.MethodGet: - n.methodHandler.get = h - case http.MethodHead: - n.methodHandler.head = h - case http.MethodOptions: - n.methodHandler.options = h - case http.MethodPatch: - n.methodHandler.patch = h - case http.MethodPost: - n.methodHandler.post = h - case PROPFIND: - n.methodHandler.propfind = h - case http.MethodPut: - n.methodHandler.put = h - case http.MethodTrace: - n.methodHandler.trace = h - case REPORT: - n.methodHandler.report = h +func (n *node) addHandler(method string, handler methodHandler) { + if handler.handlerFunc != nil { + n.handlers[method] = handler } - n.methodHandler.updateAllowHeader() - if h != nil { - n.isHandler = true - } else { - n.isHandler = n.methodHandler.isHandler() - } + n.updateAllowHeader() + n.isHandler = len(n.handlers) != 0 } func (n *node) findHandler(method string) HandlerFunc { - switch method { - case http.MethodConnect: - return n.methodHandler.connect - case http.MethodDelete: - return n.methodHandler.delete - case http.MethodGet: - return n.methodHandler.get - case http.MethodHead: - return n.methodHandler.head - case http.MethodOptions: - return n.methodHandler.options - case http.MethodPatch: - return n.methodHandler.patch - case http.MethodPost: - return n.methodHandler.post - case PROPFIND: - return n.methodHandler.propfind - case http.MethodPut: - return n.methodHandler.put - case http.MethodTrace: - return n.methodHandler.trace - case REPORT: - return n.methodHandler.report - default: - return nil + if m, ok := n.handlers[method]; ok { + return m.handlerFunc } + return nil +} + +func (n *node) updateAllowHeader() { + allowedMethods := []string{http.MethodOptions} + for method := range n.handlers { + allowedMethods = append(allowedMethods, method) + } + sort.Slice(allowedMethods[1:], func(i, j int) bool { + return allowedMethods[i+1] < allowedMethods[j+1] + }) + n.allowHeader = strings.Join(allowedMethods, ", ") } func optionsMethodHandler(allowMethods string) func(c Context) error { @@ -569,7 +460,11 @@ func (r *Router) Find(method, path string, c Context) { if child := currentNode.anyChild; child != nil { // If any node is found, use remaining path for paramValues currentNode = child - paramValues[len(currentNode.pnames)-1] = search + if m, ok := currentNode.handlers[method]; ok { + paramValues[len(m.pnames)-1] = search + } else { + break + } // update indexes/search in case we need to backtrack when no handler match is found paramIndex++ searchIndex += +len(search) @@ -613,13 +508,15 @@ func (r *Router) Find(method, path string, c Context) { ctx.handler = NotFoundHandler if currentNode.isHandler { - ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader) + ctx.Set(ContextKeyHeaderAllow, currentNode.allowHeader) ctx.handler = MethodNotAllowedHandler if method == http.MethodOptions { - ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader) + ctx.handler = optionsMethodHandler(currentNode.allowHeader) } } } - ctx.path = currentNode.ppath - ctx.pnames = currentNode.pnames + if m, ok := currentNode.handlers[method]; ok { + ctx.path = m.ppath + ctx.pnames = m.pnames + } } diff --git a/router_test.go b/router_test.go index 457566b90..5de534904 100644 --- a/router_test.go +++ b/router_test.go @@ -821,6 +821,28 @@ func TestRouterTwoParam(t *testing.T) { assert.Equal(t, "1", c.Param("fid")) } +// Issue #1726 +// Issue #2201 +func TestRouterTwoDifferentParam(t *testing.T) { + e := New() + r := e.router + r.Add(http.MethodPut, "/users/:vid/files/:gid", func(Context) error { + return nil + }) + r.Add(http.MethodGet, "/users/:uid/files/:fid", func(Context) error { + return nil + }) + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/users/1/files/2", c) + assert.Equal(t, "1", c.Param("uid")) + assert.Equal(t, "2", c.Param("fid")) + + r.Find(http.MethodPut, "/users/3/files/4", c) + assert.Equal(t, "3", c.Param("vid")) + assert.Equal(t, "4", c.Param("gid")) +} + // Issue #378 func TestRouterParamWithSlash(t *testing.T) { e := New()