diff --git a/rte.go b/rte.go index d143bd8..ea4c849 100644 --- a/rte.go +++ b/rte.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/jwilner/rte/internal/funcs" "net/http" + "regexp" "strings" ) @@ -22,10 +23,6 @@ const ( MethodAll = "~" ) -const ( - wildcard, wildcardSlash = "*", "*/" -) - // Middleware is shorthand for a function which can handle or modify a request, optionally invoke the next // handler (or not), and modify (or set) a response. type Middleware interface { @@ -135,8 +132,8 @@ func Must(routes []Route) *Table { // New builds routes into a Table or returns an error func New(routes []Route) (*Table, error) { - t := new(Table) - t.root = newNode() + t := &Table{root: newNode(""), Default: http.NotFoundHandler()} + normalizer := regexp.MustCompile(`:[^/]*`) maxVars := len(funcs.PathVars{}) for i, r := range routes { @@ -156,28 +153,25 @@ func New(routes []Route) (*Table, error) { return nil, Error{Type: ErrTypeNoInitialSlash, Idx: i, Route: r} } - n := t.root - numPathParams := 0 - for _, seg := range strings.SplitAfter(r.Path, "/")[1:] { - // normalize - seg, err := normalize(seg) - if err != nil { - return nil, Error{Type: ErrTypeInvalidSegment, Idx: i, Route: r, cause: err} - } else if seg == wildcard || seg == wildcardSlash { - numPathParams++ - } + if strings.Contains(r.Path, "*") { + return nil, Error{Type: ErrTypeInvalidSegment, Idx: i, Route: r} + } - if n.children[seg] == nil { - n.children[seg] = newNode() - } + normalized := normalizer.ReplaceAllString(r.Path, "*") - n = n.children[seg] + var numPathParams int + for _, c := range normalized { + if c == '*' { + numPathParams++ + } } + if numPathParams > maxVars { return nil, Error{Type: ErrTypeOutOfRange, Idx: i, Route: r} } - if _, has := n.methods[r.Method]; has { + hndlrs := traverse(t.root, normalized) + if _, exists := hndlrs[r.Method]; exists { return nil, Error{Type: ErrTypeDuplicateHandler, Idx: i, Route: r} } @@ -195,16 +189,102 @@ func New(routes []Route) (*Table, error) { h = applyMiddleware(h, r.Middleware) } - n.methods[r.Method] = h + hndlrs[r.Method] = h } - t.Default = http.NotFoundHandler() - return t, nil } -func newNode() *node { - return &node{children: make(map[string]*node), methods: make(map[string]funcs.Handler)} +func traverse(node *node, path string) map[string]funcs.Handler { + i := 0 + + child := node.get(path[i]) + for child != nil { + + // find point where label and path diverge (or one ends) + j := 0 + for i < len(path) && j < len(child.label) && path[i] == child.label[j] { + i++ + j++ + } + + if j == len(child.label) { + node = child + if i == len(path) { + break + } + child = node.get(path[i]) + continue + } + + // we've stopped in the middle of the current label, so the child + // will be pushed down in the tree -- update its label + label := child.label + child.label = label[j:] + + if i == len(path) { // this is a prefix -- split the path and insert + newChild := newNode(label[:j]) + newChild.add(child) + node.add(newChild) + return newChild.hndlrs + } + + // they've diverged at j in the current label + newN := newNode(path[i:]) + + branch := newNode(label[:j]) + branch.add(child) + branch.add(newN) + node.add(branch) + + return newN.hndlrs + } + + // node.edges[r.path[i]] == "" -- i.e., we hit a terminal node + if i == len(path) { + return node.hndlrs + } + + // we've still got labels to consume -- add a child + ch := newNode(path[i:]) + node.add(ch) + return ch.hndlrs +} + +type node struct { + // index[i] == children[i].label[0] always + index []byte + children []*node + label string + hndlrs map[string]funcs.Handler +} + +func newNode(label string) *node { + return &node{ + hndlrs: make(map[string]funcs.Handler), + label: label, + } +} + +func (n *node) add(n2 *node) { + for i, c := range n.index { + if c == n2.label[0] { + n.children[i] = n2 + return + } + } + + n.index = append(n.index, n2.label[0]) + n.children = append(n.children, n2) +} + +func (n *node) get(b byte) *node { + for i, ib := range n.index { + if ib == b { + return n.children[i] + } + } + return nil } func applyMiddleware(h funcs.Handler, mw Middleware) funcs.Handler { @@ -215,21 +295,6 @@ func applyMiddleware(h funcs.Handler, mw Middleware) funcs.Handler { } } -func normalize(seg string) (string, error) { - switch { - case strings.ContainsAny(seg, "*"): - return "", fmt.Errorf("segment %q contains invalid characters", seg) - case seg == "", seg[0] != ':': - return seg, nil - case seg == ":", seg == ":/": - return "", fmt.Errorf("wildcard segment %q must have a name", seg) - case seg[len(seg)-1] == '/': - return wildcardSlash, nil - default: - return wildcard, nil - } -} - // Table manages the routing table and a default handler type Table struct { Default http.Handler @@ -237,60 +302,61 @@ type Table struct { } func (t *Table) ServeHTTP(w http.ResponseWriter, r *http.Request) { + node := t.root + pathIdx := 0 var ( - i int - params funcs.PathVars - node = t.root + variables funcs.PathVars + varIdx = 0 + hndlrs map[string]funcs.Handler ) - // Analogous to `SplitAfter`, but avoids an alloc for fun - // "" -> [], "/" -> [""], "/abc" -> ["/", "abc"], "/abc/" -> ["/", "abc/", ""] - if start := strings.IndexByte(r.URL.Path, '/') + 1; start != 0 { - for hitEnd := false; !hitEnd; { - var end int - if offset := strings.IndexByte(r.URL.Path[start:], '/'); offset != -1 { - end = start + offset + 1 - } else { - end = len(r.URL.Path) - hitEnd = true - } - - var pVarName string - if pVarName, node = node.match(r.URL.Path[start:end]); node == nil { - t.Default.ServeHTTP(w, r) - return - } else if pVarName != "" { // we've matched a path var - params[i] = pVarName - i++ - } - start = end +outer: + for { + child := node.get(r.RequestURI[pathIdx]) + if child == nil { + break outer } - } - if h, ok := node.methods[r.Method]; ok { - h(w, r, params) - return - } + lblIdx := 0 + for { + switch { + case r.RequestURI[pathIdx] == child.label[lblIdx]: + pathIdx++ + lblIdx++ + case child.label[lblIdx] == '*': + wcStart := pathIdx + for pathIdx < len(r.RequestURI) && r.RequestURI[pathIdx] != '/' { + pathIdx++ + } + variables[varIdx] = r.RequestURI[wcStart:pathIdx] + varIdx++ + lblIdx++ + default: + break outer + } - if h, ok := node.methods[MethodAll]; ok { - h(w, r, params) - return + pathDone, labelDone := pathIdx == len(r.RequestURI), lblIdx == len(child.label) + switch { + case !pathDone && !labelDone: + continue + case pathDone && labelDone: + hndlrs = child.hndlrs + break outer + case pathDone: + break outer + case labelDone: + node = child + continue outer + } + } } - t.Default.ServeHTTP(w, r) -} - -type node struct { - children map[string]*node - methods map[string]funcs.Handler -} - -func (n *node) match(seg string) (string, *node) { - if c := n.children[seg]; c != nil { - return "", c - } else if l := len(seg) - 1; l >= 0 && seg[l] == '/' { - return seg[:l], n.children[wildcardSlash] - } else { - return seg, n.children[wildcard] + switch { + case hndlrs[r.Method] != nil: + hndlrs[r.Method](w, r, variables) + case hndlrs[MethodAll] != nil: + hndlrs[MethodAll](w, r, variables) + default: + t.Default.ServeHTTP(w, r) } } diff --git a/rte_test.go b/rte_test.go index d899e74..8cddecd 100644 --- a/rte_test.go +++ b/rte_test.go @@ -64,12 +64,8 @@ func TestNew(t *testing.T) { ErrMsg: `route 0 "GET hi": no initial slash`, }, { - Name: "invalidSegmentMissingName", - Routes: rte.Routes("GET /:", func(w http.ResponseWriter, r *http.Request) {}), - WantErr: true, - ErrType: rte.ErrTypeInvalidSegment, - ErrIdx: 0, - ErrMsg: `route 0 "GET /:": invalid segment: wildcard segment ":" must have a name`, + Name: "name unrequired", + Routes: rte.Routes("GET /:", func(w http.ResponseWriter, r *http.Request, a string) {}), }, { Name: "invalidSegmentInvalidChar", @@ -77,7 +73,7 @@ func TestNew(t *testing.T) { WantErr: true, ErrType: rte.ErrTypeInvalidSegment, ErrIdx: 0, - ErrMsg: `route 0 "GET /*": invalid segment: segment "*" contains invalid characters`, + ErrMsg: `route 0 "GET /*": invalid segment`, }, { Name: "duplicate handler", @@ -157,6 +153,11 @@ func TestNew(t *testing.T) { }, } { t.Run(c.Name, func(t *testing.T) { + defer func() { + if p := recover(); p != nil { + t.Fatalf("panicked: %v", p) + } + }() _, err := rte.New(c.Routes) if c.WantErr != (err != nil) { t.Fatalf("want err %v, got %v", c.WantErr, err) @@ -167,7 +168,7 @@ func TestNew(t *testing.T) { case !ok: t.Fatalf("expected a rte.Error, got %T: %v", err, err) case e.Type != c.ErrType: - t.Fatalf("expected error type %v, but got %v", c.ErrType, e.Type) + t.Fatalf("expected error type %v, but got %v", c, e) case e.Idx != c.ErrIdx: t.Fatalf("expected error to occur with route %v, but got route %v", c.ErrIdx, e.Idx) case e.Error() != c.ErrMsg: @@ -234,79 +235,118 @@ func Test_matchPath(t *testing.T) { tests := []struct { name string req *http.Request - rte rte.Route + rte []rte.Route code int body string }{ { "match", httptest.NewRequest("GET", "/abc", nil), - rte.Route{Method: "GET", Path: "/abc", Handler: h200}, + rte.Routes("GET /abc", h200), 200, "null", }, { "wrong-method", httptest.NewRequest("PUT", "/abcd", nil), - rte.Route{Method: "POST", Path: "/abcd", Handler: h200}, + rte.Routes("POST /abcd", h200), 404, "404", }, { "match-trailing", httptest.NewRequest("HEAD", "/abc/", nil), - rte.Route{Method: "HEAD", Path: "/abc/", Handler: h200}, + rte.Routes("HEAD /abc/", h200), 200, "null", }, { "require-trailing", httptest.NewRequest("GET", "/abc/", nil), - rte.Route{Method: "GET", Path: "/abc", Handler: h200}, + rte.Routes("GET /abc", h200), + 404, "404", + }, + { + "nested-miss", + httptest.NewRequest("GET", "/abc/abcde", nil), + rte.Routes("GET /abc/abcdef", h200), + 404, "404", + }, + { + "unequal", + httptest.NewRequest("GET", "/abc/abcdeg24", nil), + rte.Routes("GET /abc/abcdef", h200), 404, "404", }, { "slash-match", httptest.NewRequest("GET", "/", nil), - rte.Route{Method: "GET", Path: "/", Handler: h200}, + rte.Routes("GET /", h200), 200, "null", }, { "wildcard-match", httptest.NewRequest("GET", "/abc", nil), - rte.Route{ - Method: "GET", Path: "/:whoo", - Handler: func(w http.ResponseWriter, r *http.Request, whoo string) { + rte.Routes( + "GET /:whoo", + func(w http.ResponseWriter, r *http.Request, whoo string) { _ = json.NewEncoder(w).Encode([]string{whoo}) }, - }, + ), 200, `["abc"]`, }, { "multiple-wildcard", httptest.NewRequest("GET", "/abc/123", nil), - rte.Route{ - Method: "GET", Path: "/:foo/:bar", - Handler: func(w http.ResponseWriter, r *http.Request, foo, bar string) { + rte.Routes( + "GET /:foo/:bar", + func(w http.ResponseWriter, r *http.Request, foo, bar string) { _ = json.NewEncoder(w).Encode([]string{foo, bar}) }, - }, + ), 200, `["abc","123"]`, }, { "match-method-not-allowed", httptest.NewRequest("GET", "/abc/123", nil), - rte.Route{ - Method: rte.MethodAll, Path: "/:foo/:bar", - Handler: func(w http.ResponseWriter, r *http.Request, foo, bar string) { + rte.Routes( + rte.MethodAll+" /:foo/:bar", + func(w http.ResponseWriter, r *http.Request, foo, bar string) { w.WriteHeader(http.StatusMethodNotAllowed) _ = json.NewEncoder(w).Encode([]string{foo, bar}) }, - }, + ), + 405, `["abc","123"]`, + }, + { + "", + httptest.NewRequest("GET", "/abc/123", nil), + rte.Routes( + rte.MethodAll+" /:foo/:bar", + func(w http.ResponseWriter, r *http.Request, foo, bar string) { + w.WriteHeader(http.StatusMethodNotAllowed) + _ = json.NewEncoder(w).Encode([]string{foo, bar}) + }, + ), 405, `["abc","123"]`, }, + + // multi route + { + "", + httptest.NewRequest("GET", "/abc/123", nil), + rte.Routes( + "GET /abc/:bar", + func(w http.ResponseWriter, r *http.Request, bar string) { + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode([]string{bar}) + }, + "GET /abc", h200, + ), + http.StatusAccepted, `["123"]`, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tbl := rte.Must([]rte.Route{tt.rte}) + tbl := rte.Must(tt.rte) tbl.Default = http.HandlerFunc(h404) w := httptest.NewRecorder() @@ -388,4 +428,3 @@ func TestMiddleware(t *testing.T) { }) } } -