Skip to content

Commit

Permalink
[update] changed tree type in router struct
Browse files Browse the repository at this point in the history
  • Loading branch information
bmf-san committed Jan 12, 2023
1 parent 6c4c10b commit 6f667a9
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 541 deletions.
44 changes: 26 additions & 18 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

// Router represents the router which handles routing.
type Router struct {
tree *tree
tree map[string]*tree
NotFoundHandler http.Handler
MethodNotAllowedHandler http.Handler
DefaultOPTIONSHandler http.Handler
Expand All @@ -33,9 +33,9 @@ var (
)

// NewRouter creates a new router.
func NewRouter() Router {
return Router{
tree: newTree(),
func NewRouter() *Router {
return &Router{
tree: map[string]*tree{},
}
}

Expand All @@ -45,13 +45,13 @@ func (r *Router) UseGlobal(mws ...middleware) {
}

// Use sets middlewares.
func (r Router) Use(mws ...middleware) Router {
func (r *Router) Use(mws ...middleware) *Router {
nm := NewMiddlewares(mws)
tmpRoute.middlewares = nm
return r
}

func (r Router) Methods(methods ...string) Router {
func (r *Router) Methods(methods ...string) *Router {
tmpRoute.methods = append(tmpRoute.methods, methods...)
return r
}
Expand All @@ -64,37 +64,45 @@ func (r Router) Handler(path string, handler http.Handler) {
}

// Handle handles a route.
func (r Router) Handle() {
r.tree.Insert(tmpRoute.methods, tmpRoute.path, tmpRoute.handler, tmpRoute.middlewares)
func (r *Router) Handle() {
for i := 0; i < len(tmpRoute.methods); i++ {
_, ok := r.tree[tmpRoute.methods[i]]
if !ok {
r.tree[tmpRoute.methods[i]] = newTree()
}
r.tree[tmpRoute.methods[i]].Insert(tmpRoute.path, tmpRoute.handler, tmpRoute.middlewares)
}
tmpRoute = &route{}
}

// ServeHTTP dispatches the request to the handler whose
// pattern most closely matches the request URL.
func (r Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
method := req.Method
if method == http.MethodOptions {
if r.DefaultOPTIONSHandler != nil {
r.DefaultOPTIONSHandler.ServeHTTP(w, req)
return
}
}
action, params, err := r.tree.Search(method, req.URL.Path)
if err == ErrNotFound {
if r.NotFoundHandler == nil {
http.NotFoundHandler().ServeHTTP(w, req)

t, ok := r.tree[method]
if !ok {
if r.MethodNotAllowedHandler == nil {
methodNotAllowedHandler().ServeHTTP(w, req)
return
}
r.NotFoundHandler.ServeHTTP(w, req)
r.MethodNotAllowedHandler.ServeHTTP(w, req)
return
}

if err == ErrMethodNotAllowed {
if r.MethodNotAllowedHandler == nil {
methodNotAllowedHandler().ServeHTTP(w, req)
action, params, err := t.Search(req.URL.Path)
if err == ErrNotFound {
if r.NotFoundHandler == nil {
http.NotFoundHandler().ServeHTTP(w, req)
return
}
r.MethodNotAllowedHandler.ServeHTTP(w, req)
r.NotFoundHandler.ServeHTTP(w, req)
return
}

Expand Down
5 changes: 3 additions & 2 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (

func TestNewRouter(t *testing.T) {
actual := NewRouter()
expected := Router{
tree: newTree(),
expected := &Router{
tree: map[string]*tree{},
}

if !reflect.DeepEqual(actual, expected) {
Expand Down Expand Up @@ -323,6 +323,7 @@ func TestCustomErrorHandler(t *testing.T) {
code: http.StatusNotFound,
body: "statusnotfound",
},
// NewRouterのtreeの対応関連
{
path: "/custommethodnotallowed",
method: http.MethodPost,
Expand Down
40 changes: 18 additions & 22 deletions trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ type tree struct {
// node is a node of tree.
type node struct {
label string
actions map[string]*action // key is method
children []*node // key is label of next nodes
action *action // key is method
children []*node // key is label of next nodes
}

// action is an action.
Expand All @@ -41,7 +41,7 @@ func newTree() *tree {
return &tree{
node: &node{
label: "/",
actions: make(map[string]*action),
action: &action{},
children: []*node{},
},
}
Expand Down Expand Up @@ -77,17 +77,15 @@ func (n *node) getChild(label string) *node {
}

// Insert inserts a route definition to tree.
func (t *tree) Insert(methods []string, path string, handler http.Handler, mws middlewares) {
func (t *tree) Insert(path string, handler http.Handler, mws middlewares) {
path = cleanPath(path)
curNode := t.node

if path == "/" {
curNode.label = path
for i := 0; i < len(methods); i++ {
curNode.actions[methods[i]] = &action{
middlewares: mws,
handler: handler,
}
curNode.action = &action{
middlewares: mws,
handler: handler,
}
return
}
Expand Down Expand Up @@ -134,7 +132,7 @@ func (t *tree) Insert(methods []string, path string, handler http.Handler, mws m
if nextNode == nil {
child := &node{
label: l,
actions: make(map[string]*action),
action: &action{},
children: []*node{},
}
curNode.children = append(curNode.children, child)
Expand All @@ -149,11 +147,9 @@ func (t *tree) Insert(methods []string, path string, handler http.Handler, mws m
// If there is already registered data, overwrite it.
if i == cnt-1 {
curNode.label = l
for j := 0; j < len(methods); j++ {
curNode.actions[methods[j]] = &action{
middlewares: mws,
handler: handler,
}
curNode.action = &action{
middlewares: mws,
handler: handler,
}
break
}
Expand Down Expand Up @@ -195,11 +191,11 @@ func (rc *regCache) getReg(ptn string) (*regexp.Regexp, error) {
var regC = &regCache{}

// Search searches a path from a tree.
func (t *tree) Search(method string, path string) (*action, []Param, error) {
func (t *tree) Search(path string) (*action, []Param, error) {
path = cleanPath(path)
curNode := t.node

if path == "/" && curNode.actions[method] == nil {
if path == "/" && curNode.action == nil {
return nil, nil, ErrNotFound
}

Expand Down Expand Up @@ -297,15 +293,15 @@ func (t *tree) Search(method string, path string) (*action, []Param, error) {
}
}

actions := curNode.actions[method]
if actions == nil {
action := curNode.action
if action.handler == nil {
// no matching handler and middlewares was found.
return nil, nil, ErrMethodNotAllowed
return nil, nil, ErrNotFound
}
if params == nil {
return actions, nil, nil
return action, nil, nil
}
return actions, *params, nil
return action, *params, nil
}

// getPattern gets a pattern from a label.
Expand Down
Loading

0 comments on commit 6f667a9

Please sign in to comment.