Skip to content

Commit

Permalink
Implements a compressed trie for routing
Browse files Browse the repository at this point in the history
  • Loading branch information
jwilner committed Mar 2, 2019
1 parent 8663a54 commit fff0411
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 116 deletions.
242 changes: 154 additions & 88 deletions rte.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"github.com/jwilner/rte/internal/funcs"
"net/http"
"regexp"
"strings"
)

Expand All @@ -22,10 +23,6 @@ const (
MethodAll = "~"
)

const (
wildcard, wildcardSlash = "*", "*/"
)

// Middleware is shorthand for a function which can handle or modify a request, optionally invoke the next
// handler (or not), and modify (or set) a response.
type Middleware interface {
Expand Down Expand Up @@ -135,8 +132,8 @@ func Must(routes []Route) *Table {

// New builds routes into a Table or returns an error
func New(routes []Route) (*Table, error) {
t := new(Table)
t.root = newNode()
t := &Table{root: newNode(""), Default: http.NotFoundHandler()}
normalizer := regexp.MustCompile(`:[^/]*`)

maxVars := len(funcs.PathVars{})
for i, r := range routes {
Expand All @@ -156,28 +153,25 @@ func New(routes []Route) (*Table, error) {
return nil, Error{Type: ErrTypeNoInitialSlash, Idx: i, Route: r}
}

n := t.root
numPathParams := 0
for _, seg := range strings.SplitAfter(r.Path, "/")[1:] {
// normalize
seg, err := normalize(seg)
if err != nil {
return nil, Error{Type: ErrTypeInvalidSegment, Idx: i, Route: r, cause: err}
} else if seg == wildcard || seg == wildcardSlash {
numPathParams++
}
if strings.Contains(r.Path, "*") {
return nil, Error{Type: ErrTypeInvalidSegment, Idx: i, Route: r}
}

if n.children[seg] == nil {
n.children[seg] = newNode()
}
normalized := normalizer.ReplaceAllString(r.Path, "*")

n = n.children[seg]
var numPathParams int
for _, c := range normalized {
if c == '*' {
numPathParams++
}
}

if numPathParams > maxVars {
return nil, Error{Type: ErrTypeOutOfRange, Idx: i, Route: r}
}

if _, has := n.methods[r.Method]; has {
hndlrs := traverse(t.root, normalized)
if _, exists := hndlrs[r.Method]; exists {
return nil, Error{Type: ErrTypeDuplicateHandler, Idx: i, Route: r}
}

Expand All @@ -195,16 +189,102 @@ func New(routes []Route) (*Table, error) {
h = applyMiddleware(h, r.Middleware)
}

n.methods[r.Method] = h
hndlrs[r.Method] = h
}

t.Default = http.NotFoundHandler()

return t, nil
}

func newNode() *node {
return &node{children: make(map[string]*node), methods: make(map[string]funcs.Handler)}
func traverse(node *node, path string) map[string]funcs.Handler {
i := 0

child := node.get(path[i])
for child != nil {

// find point where label and path diverge (or one ends)
j := 0
for i < len(path) && j < len(child.label) && path[i] == child.label[j] {
i++
j++
}

if j == len(child.label) {
node = child
if i == len(path) {
break
}
child = node.get(path[i])
continue
}

// we've stopped in the middle of the current label, so the child
// will be pushed down in the tree -- update its label
label := child.label
child.label = label[j:]

if i == len(path) { // this is a prefix -- split the path and insert
newChild := newNode(label[:j])
newChild.add(child)
node.add(newChild)
return newChild.hndlrs
}

// they've diverged at j in the current label
newN := newNode(path[i:])

branch := newNode(label[:j])
branch.add(child)
branch.add(newN)
node.add(branch)

return newN.hndlrs
}

// node.edges[r.path[i]] == "" -- i.e., we hit a terminal node
if i == len(path) {
return node.hndlrs
}

// we've still got labels to consume -- add a child
ch := newNode(path[i:])
node.add(ch)
return ch.hndlrs
}

type node struct {
// index[i] == children[i].label[0] always
index []byte
children []*node
label string
hndlrs map[string]funcs.Handler
}

func newNode(label string) *node {
return &node{
hndlrs: make(map[string]funcs.Handler),
label: label,
}
}

func (n *node) add(n2 *node) {
for i, c := range n.index {
if c == n2.label[0] {
n.children[i] = n2
return
}
}

n.index = append(n.index, n2.label[0])
n.children = append(n.children, n2)
}

func (n *node) get(b byte) *node {
for i, ib := range n.index {
if ib == b {
return n.children[i]
}
}
return nil
}

func applyMiddleware(h funcs.Handler, mw Middleware) funcs.Handler {
Expand All @@ -215,82 +295,68 @@ func applyMiddleware(h funcs.Handler, mw Middleware) funcs.Handler {
}
}

func normalize(seg string) (string, error) {
switch {
case strings.ContainsAny(seg, "*"):
return "", fmt.Errorf("segment %q contains invalid characters", seg)
case seg == "", seg[0] != ':':
return seg, nil
case seg == ":", seg == ":/":
return "", fmt.Errorf("wildcard segment %q must have a name", seg)
case seg[len(seg)-1] == '/':
return wildcardSlash, nil
default:
return wildcard, nil
}
}

// Table manages the routing table and a default handler
type Table struct {
Default http.Handler
root *node
}

func (t *Table) ServeHTTP(w http.ResponseWriter, r *http.Request) {
node := t.root
pathIdx := 0
var (
i int
params funcs.PathVars
node = t.root
variables funcs.PathVars
varIdx = 0
hndlrs map[string]funcs.Handler
)
// Analogous to `SplitAfter`, but avoids an alloc for fun
// "" -> [], "/" -> [""], "/abc" -> ["/", "abc"], "/abc/" -> ["/", "abc/", ""]
if start := strings.IndexByte(r.URL.Path, '/') + 1; start != 0 {
for hitEnd := false; !hitEnd; {
var end int
if offset := strings.IndexByte(r.URL.Path[start:], '/'); offset != -1 {
end = start + offset + 1
} else {
end = len(r.URL.Path)
hitEnd = true
}

var pVarName string
if pVarName, node = node.match(r.URL.Path[start:end]); node == nil {
t.Default.ServeHTTP(w, r)
return
} else if pVarName != "" { // we've matched a path var
params[i] = pVarName
i++
}

start = end
outer:
for {
child := node.get(r.RequestURI[pathIdx])
if child == nil {
break outer
}
}

if h, ok := node.methods[r.Method]; ok {
h(w, r, params)
return
}
lblIdx := 0
for {
switch {
case r.RequestURI[pathIdx] == child.label[lblIdx]:
pathIdx++
lblIdx++
case child.label[lblIdx] == '*':
wcStart := pathIdx
for pathIdx < len(r.RequestURI) && r.RequestURI[pathIdx] != '/' {
pathIdx++
}
variables[varIdx] = r.RequestURI[wcStart:pathIdx]
varIdx++
lblIdx++
default:
break outer
}

if h, ok := node.methods[MethodAll]; ok {
h(w, r, params)
return
pathDone, labelDone := pathIdx == len(r.RequestURI), lblIdx == len(child.label)
switch {
case !pathDone && !labelDone:
continue
case pathDone && labelDone:
hndlrs = child.hndlrs
break outer
case pathDone:
break outer
case labelDone:
node = child
continue outer
}
}
}

t.Default.ServeHTTP(w, r)
}

type node struct {
children map[string]*node
methods map[string]funcs.Handler
}

func (n *node) match(seg string) (string, *node) {
if c := n.children[seg]; c != nil {
return "", c
} else if l := len(seg) - 1; l >= 0 && seg[l] == '/' {
return seg[:l], n.children[wildcardSlash]
} else {
return seg, n.children[wildcard]
switch {
case hndlrs[r.Method] != nil:
hndlrs[r.Method](w, r, variables)
case hndlrs[MethodAll] != nil:
hndlrs[MethodAll](w, r, variables)
default:
t.Default.ServeHTTP(w, r)
}
}
Loading

0 comments on commit fff0411

Please sign in to comment.