diff --git a/echo.go b/echo.go index f984f4056..c8dd3fed4 100644 --- a/echo.go +++ b/echo.go @@ -8,7 +8,7 @@ import ( "io" "log" "net/http" - "path/filepath" + spath "path" "reflect" "runtime" "strings" @@ -135,8 +135,7 @@ const ( Upgrade = "Upgrade" Vary = "Vary" WWWAuthenticate = "WWW-Authenticate" - XForwardedFor = "X-Forwarded-For" - XRealIP = "X-Real-IP" + //----------- // Protocols //----------- @@ -170,6 +169,15 @@ var ( //---------------- // Error handlers //---------------- + // methodNotAllowedHandler - handler to respond with a http.StatusMethodNotAllowed status + // which is applicable when the route is correct, but the method for the route isn't allowed + // for the route + methodNotAllowedHandler = func(c *Context, allowedMethods ...string) func(c *Context) error { + return func(c *Context) error { + c.response.Header().Add("Allow", strings.Join(allowedMethods, ", ")) + return NewHTTPError(http.StatusMethodNotAllowed) + } + } notFoundHandler = func(c *Context) error { return NewHTTPError(http.StatusNotFound) @@ -180,6 +188,8 @@ var ( } ) +var runtimeGOOS = runtime.GOOS + // New creates an instance of Echo. func New() (e *Echo) { e = &Echo{maxParam: new(int)} @@ -192,7 +202,7 @@ func New() (e *Echo) { // Defaults //---------- - if runtime.GOOS == "windows" { + if runtimeGOOS == "windows" { e.DisableColoredLog() } e.HTTP2() @@ -221,9 +231,12 @@ func (e *Echo) Router() *Router { return e.router } +var colorDisable = color.Disable + // DisableColoredLog disables colored log. func (e *Echo) DisableColoredLog() { - color.Disable() + colorDisable() + } // HTTP2 enables HTTP2 support. @@ -376,14 +389,14 @@ func (e *Echo) Static(path, dir string) { // ServeDir serves files from a directory. func (e *Echo) ServeDir(path, dir string) { e.Get(path+"*", func(c *Context) error { - return serveFile(dir, c.P(0), c) // Param `_*` + return serveFile(dir, c.P(0), c) // Param `_name` }) } // ServeFile serves a file. func (e *Echo) ServeFile(path, file string) { e.Get(path, func(c *Context) error { - dir, file := filepath.Split(file) + dir, file := spath.Split(file) return serveFile(dir, file, c) }) } @@ -397,7 +410,7 @@ func serveFile(dir, file string, c *Context) error { fi, _ := f.Stat() if fi.IsDir() { - file = filepath.Join(file, indexFile) + file = spath.Join(file, indexFile) f, err = fs.Open(file) if err != nil { return NewHTTPError(http.StatusForbidden) diff --git a/echo_test.go b/echo_test.go index 8a4690855..6ddb1d866 100644 --- a/echo_test.go +++ b/echo_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "runtime" "testing" "reflect" @@ -12,6 +13,7 @@ import ( "errors" + "github.com/labstack/gommon/color" "github.com/stretchr/testify/assert" "golang.org/x/net/websocket" ) @@ -23,6 +25,20 @@ type ( } ) +func TestNewRuntimeGOOS(t *testing.T) { + changedColor := make(chan bool) + colorDisable = func() { changedColor <- true } + runtimeGOOS = "windows" + defer func() { + colorDisable = color.Disable + runtimeGOOS = runtime.GOOS + }() + go func() { + New() + }() + assert.True(t, <-changedColor) +} + func TestEcho(t *testing.T) { e := New() req, _ := http.NewRequest(GET, "/", nil) @@ -387,7 +403,7 @@ func TestEchoBadRequest(t *testing.T) { r, _ := http.NewRequest("INVALID", "/files", nil) w := httptest.NewRecorder() e.ServeHTTP(w, r) - assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) } func TestEchoHTTPError(t *testing.T) { diff --git a/middleware/logger.go b/middleware/logger.go index 6249cde74..19d1301d5 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -2,7 +2,6 @@ package middleware import ( "log" - "net" "time" "github.com/labstack/echo" @@ -12,30 +11,19 @@ import ( func Logger() echo.MiddlewareFunc { return func(h echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { - req := c.Request() - res := c.Response() - - remoteAddr := req.RemoteAddr - if ip := req.Header.Get(echo.XRealIP); ip != "" { - remoteAddr = ip - } else if ip = req.Header.Get(echo.XForwardedFor); ip != "" { - remoteAddr = ip - } - remoteAddr, _, _ = net.SplitHostPort(remoteAddr) - start := time.Now() if err := h(c); err != nil { c.Error(err) } stop := time.Now() - method := req.Method - path := req.URL.Path + method := c.Request().Method + path := c.Request().URL.Path if path == "" { path = "/" } - size := res.Size() + size := c.Response().Size() - n := res.Status() + n := c.Response().Status() code := color.Green(n) switch { case n >= 500: @@ -46,7 +34,7 @@ func Logger() echo.MiddlewareFunc { code = color.Cyan(n) } - log.Printf("%s %s %s %s %s %d", remoteAddr, method, path, code, stop.Sub(start), size) + log.Printf("%s %s %s %s %d", method, path, code, stop.Sub(start), size) return nil } } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 69c0ca460..e46019f3d 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -2,37 +2,20 @@ package middleware import ( "errors" + "github.com/labstack/echo" "net/http" "net/http/httptest" "testing" - - "github.com/labstack/echo" ) func TestLogger(t *testing.T) { - // Note: Just for the test coverage, not a real test. e := echo.New() req, _ := http.NewRequest(echo.GET, "/", nil) rec := httptest.NewRecorder() c := echo.NewContext(req, echo.NewResponse(rec), e) - // With X-Real-IP - req.Header.Add(echo.XRealIP, "127.0.0.1") - h := func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } - Logger()(h)(c) - - // With X-Forwarded-For - req.Header.Del(echo.XRealIP) - req.Header.Add(echo.XForwardedFor, "127.0.0.1") - h = func(c *echo.Context) error { - return c.String(http.StatusOK, "test") - } - Logger()(h)(c) - // Status 2xx - h = func(c *echo.Context) error { + h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } Logger()(h)(c) diff --git a/router.go b/router.go index 7b156fa99..9164f818f 100644 --- a/router.go +++ b/router.go @@ -4,17 +4,9 @@ import "net/http" type ( Router struct { - connectTree *node - deleteTree *node - getTree *node - headTree *node - optionsTree *node - patchTree *node - postTree *node - putTree *node - traceTree *node - routes []Route - echo *Echo + tree *node + routes []Route + echo *Echo } node struct { typ ntype @@ -22,14 +14,131 @@ type ( prefix string parent *node children children - handler HandlerFunc - pnames []string - echo *Echo + //handler map[string]HandlerFunc + handler *handler + pnames []string + echo *Echo } ntype uint8 children []*node ) +type handler struct { + Connect HandlerFunc + Delete HandlerFunc + Get HandlerFunc + Head HandlerFunc + Options HandlerFunc + Patch HandlerFunc + Post HandlerFunc + Put HandlerFunc + Trace HandlerFunc + allowedMethods string +} + +func (h *handler) CopyTo(v *handler) { + v.Get = h.Get + v.Connect = h.Connect + v.Delete = h.Delete + v.Get = h.Get + v.Head = h.Head + v.Options = h.Options + v.Patch = h.Patch + v.Post = h.Post + v.Put = h.Put + v.Trace = h.Trace + v.allowedMethods = h.allowedMethods +} + +func (h *handler) addToAllowedMethods(method string) { + if h.allowedMethods == "" { + h.allowedMethods = method + } else { + h.allowedMethods = h.allowedMethods + ", " + method + } +} + +func (h *handler) AddMethodHandler(method string, handler HandlerFunc) { + if h != nil { + if method == GET { + h.addToAllowedMethods(method) + h.Get = handler + } + if method == HEAD { + h.addToAllowedMethods(method) + h.Head = handler + } + if method == POST { + h.addToAllowedMethods(method) + h.Post = handler + } + if method == OPTIONS { + h.addToAllowedMethods(method) + h.Options = handler + } + if method == PUT { + h.addToAllowedMethods(method) + h.Put = handler + } + if method == DELETE { + h.addToAllowedMethods(method) + h.Delete = handler + } + if method == PATCH { + h.addToAllowedMethods(method) + h.Patch = handler + } + if method == CONNECT { + h.addToAllowedMethods(method) + h.Connect = handler + } + if method == TRACE { + h.addToAllowedMethods(method) + h.Trace = handler + } + } +} + +func (h *handler) GetMethodHandler(method string) (HandlerFunc, string) { + l := len(method) + firstChar := method[0] + secondChar := method[1] + if l == 3 { + if uint16(firstChar)<<8|uint16(secondChar) == 0x4745 { + return h.Get, "" + } + if uint16(firstChar)<<8|uint16(secondChar) == 0x5055 { + return h.Put, "" + } + } else if l == 4 { + if uint16(firstChar)<<8|uint16(secondChar) == 0x504f { + return h.Post, "" + } + if uint16(firstChar)<<8|uint16(secondChar) == 0x4845 { + return h.Head, "" + } + } else if l == 5 { + if uint16(firstChar)<<8|uint16(secondChar) == 0x5452 { + return h.Trace, "" + } + if uint16(firstChar)<<8|uint16(secondChar) == 0x5041 { + return h.Patch, "" + } + } else if l == 6 { + if uint16(firstChar)<<8|uint16(secondChar) == 0x4445 { + return h.Delete, "" + } + } else if l == 7 { + if uint16(firstChar)<<8|uint16(secondChar) == 0x4f50 { + return h.Options, "" + } + if uint16(firstChar)<<8|uint16(secondChar) == 0x434f { + return h.Connect, "" + } + } + return nil, h.allowedMethods +} + const ( stype ntype = iota ptype @@ -38,17 +147,17 @@ const ( func NewRouter(e *Echo) *Router { return &Router{ - connectTree: new(node), - deleteTree: new(node), - getTree: new(node), - headTree: new(node), - optionsTree: new(node), - patchTree: new(node), - postTree: new(node), - putTree: new(node), - traceTree: new(node), - routes: []Route{}, - echo: e, + // tree is base node for the search tree for all routes, each node + // therein contains a handler string->HandlerFunc map. This allows + // us to include the method applicable within the tree, allowing us to + // detect if routes should not allow particular methods, and making the + // router more clear + tree: &node{ + //handler: make(map[string]HandlerFunc), + handler: new(handler), + }, + routes: []Route{}, + echo: e, } } @@ -74,7 +183,7 @@ func (r *Router) Add(method, path string, h HandlerFunc, e *Echo) { r.insert(method, path[:i], nil, ptype, pnames, e) } else if path[i] == '*' { r.insert(method, path[:i], nil, stype, nil, e) - pnames = append(pnames, "_*") + pnames = append(pnames, "_name") r.insert(method, path[:i+1], h, mtype, pnames, e) return } @@ -90,8 +199,8 @@ func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []st *e.maxParam = l } - cn := r.findTree(method) // Current node as root - if cn == nil { + cn := r.tree + if !validMethod(method) { panic("echo => invalid method") } search := path @@ -115,20 +224,31 @@ func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []st cn.prefix = search if h != nil { cn.typ = t - cn.handler = h + // handler is a map of methods to applicable handlers, map the inserted method to the + // handler + //cn.handler = map[string]HandlerFunc{method: h} + cn.handler = new(handler) + cn.handler.AddMethodHandler(method, h) cn.pnames = pnames cn.echo = e } } else if l < pl { // Split node - n := newNode(cn.typ, cn.prefix[l:], cn, cn.children, cn.handler, cn.pnames, cn.echo) + //newHandler := map[string]HandlerFunc{} + newHandler := new(handler) + //for k, v := range cn.handler { + //newHandler[k] = v + //} + cn.handler.CopyTo(newHandler) + n := newNode(cn.typ, cn.prefix[l:], cn, cn.children, newHandler, cn.pnames, cn.echo) // Reset parent node cn.typ = stype cn.label = cn.prefix[0] cn.prefix = cn.prefix[:l] cn.children = nil - cn.handler = nil + //cn.handler = map[string]HandlerFunc{} + cn.handler = new(handler) cn.pnames = nil cn.echo = nil @@ -137,12 +257,16 @@ func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []st if l == sl { // At parent node cn.typ = t - cn.handler = h + // add the handler to the node's map of methods to handlers + //cn.handler[method] = h + cn.handler.AddMethodHandler(method, h) cn.pnames = pnames cn.echo = e } else { // Create child node - n = newNode(t, search[l:], cn, nil, h, pnames, e) + newHandler := new(handler) + newHandler.AddMethodHandler(method, h) + n = newNode(t, search[l:], cn, nil, newHandler, pnames, e) cn.addChild(n) } } else if l < sl { @@ -154,12 +278,16 @@ func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []st continue } // Create child node - n := newNode(t, search, cn, nil, h, pnames, e) + newHandler := new(handler) + newHandler.AddMethodHandler(method, h) + n := newNode(t, search, cn, nil, newHandler, pnames, e) cn.addChild(n) } else { // Node already exists if h != nil { - cn.handler = h + // add the handler to the node's map of methods to handlers + //cn.handler[method] = h + cn.handler.AddMethodHandler(method, h) cn.pnames = pnames cn.echo = e } @@ -168,16 +296,18 @@ func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []st } } -func newNode(t ntype, pre string, p *node, c children, h HandlerFunc, pnames []string, e *Echo) *node { +// newNode - create a new router tree node +func newNode(t ntype, pre string, p *node, c children, h *handler, pnames []string, e *Echo) *node { return &node{ typ: t, label: pre[0], prefix: pre, parent: p, children: c, - handler: h, - pnames: pnames, - echo: e, + // create a handler method to handler map for this node + handler: h, + pnames: pnames, + echo: e, } } @@ -212,74 +342,28 @@ func (n *node) findChildWithType(t ntype) *node { return nil } -func (r *Router) findTree(method string) (n *node) { - switch method[0] { - case 'G': // GET - m := uint32(method[2])<<8 | uint32(method[1])<<16 | uint32(method[0])<<24 - if m == 0x47455400 { - n = r.getTree - } - case 'P': // POST, PUT or PATCH - switch method[1] { - case 'O': // POST - m := uint32(method[3]) | uint32(method[2])<<8 | uint32(method[1])<<16 | - uint32(method[0])<<24 - if m == 0x504f5354 { - n = r.postTree - } - case 'U': // PUT - m := uint32(method[2])<<8 | uint32(method[1])<<16 | uint32(method[0])<<24 - if m == 0x50555400 { - n = r.putTree - } - case 'A': // PATCH - m := uint64(method[4])<<24 | uint64(method[3])<<32 | uint64(method[2])<<40 | - uint64(method[1])<<48 | uint64(method[0])<<56 - if m == 0x5041544348000000 { - n = r.patchTree - } - } - case 'D': // DELETE - m := uint64(method[5])<<16 | uint64(method[4])<<24 | uint64(method[3])<<32 | - uint64(method[2])<<40 | uint64(method[1])<<48 | uint64(method[0])<<56 - if m == 0x44454c4554450000 { - n = r.deleteTree - } - case 'C': // CONNECT - m := uint64(method[6])<<8 | uint64(method[5])<<16 | uint64(method[4])<<24 | - uint64(method[3])<<32 | uint64(method[2])<<40 | uint64(method[1])<<48 | - uint64(method[0])<<56 - if m == 0x434f4e4e45435400 { - n = r.connectTree - } - case 'H': // HEAD - m := uint32(method[3]) | uint32(method[2])<<8 | uint32(method[1])<<16 | - uint32(method[0])<<24 - if m == 0x48454144 { - n = r.headTree - } - case 'O': // OPTIONS - m := uint64(method[6])<<8 | uint64(method[5])<<16 | uint64(method[4])<<24 | - uint64(method[3])<<32 | uint64(method[2])<<40 | uint64(method[1])<<48 | - uint64(method[0])<<56 - if m == 0x4f5054494f4e5300 { - n = r.optionsTree - } - case 'T': // TRACE - m := uint64(method[4])<<24 | uint64(method[3])<<32 | uint64(method[2])<<40 | - uint64(method[1])<<48 | uint64(method[0])<<56 - if m == 0x5452414345000000 { - n = r.traceTree +//validMethod - validate that the http method is valid. +func validMethod(method string) bool { + var ok = false + for _, v := range methods { + if v == method { + ok = true + break } } - return + return ok } func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, e *Echo) { + // get tree base node from the router + cn := r.tree + + e = cn.echo h = notFoundHandler - cn := r.findTree(method) // Current node as root - if cn == nil { - h = badRequestHandler + + if !validMethod(method) { + // if the method is completely invalid + h = methodNotAllowedHandler(ctx, cn.handler.allowedMethods) return } @@ -306,10 +390,19 @@ func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, e *Echo for { if search == "" { if cn.handler != nil { - // Found - ctx.pnames = cn.pnames - h = cn.handler + // Found route, check if method is applicable + //var ok = false + //h, ok = cn.handler[method] e = cn.echo + //if !ok { + theHandler, allowedMethods := cn.handler.GetMethodHandler(method) + if theHandler == nil { + // route is valid, but method is not allowed, 405 + h = methodNotAllowedHandler(ctx, allowedMethods) + return + } + ctx.pnames = cn.pnames + h = theHandler } return } @@ -347,13 +440,12 @@ func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, e *Echo } if search == "" { - if cn.handler == nil { - // Look up for match-any, might have an empty value for *, e.g. - // serving a directory. Issue #207 - cn = cn.findChildWithType(mtype) - ctx.pvalues[len(cn.pnames)-1] = "" + // TODO: Needs improvement + if cn.findChildWithType(mtype) == nil { + continue } - continue + // Empty value + goto MatchAny } // Static node @@ -391,11 +483,11 @@ func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, e *Echo // Match-any node MatchAny: - // c = cn.getChild() + // c = cn.getChild() c = cn.findChildWithType(mtype) if c != nil { cn = c - ctx.pvalues[len(cn.pnames)-1] = search + ctx.pvalues[len(ctx.pvalues)-1] = search search = "" // End search continue } diff --git a/router_test.go b/router_test.go index 40a1e62b7..77a001783 100644 --- a/router_test.go +++ b/router_test.go @@ -321,32 +321,19 @@ func TestRouterTwoParam(t *testing.T) { func TestRouterMatchAny(t *testing.T) { e := New() r := e.router - - // Routes - r.Add(GET, "/", func(*Context) error { - return nil - }, e) - r.Add(GET, "/*", func(*Context) error { - return nil - }, e) r.Add(GET, "/users/*", func(*Context) error { return nil }, e) c := NewContext(nil, nil, e) - h, _ := r.Find(GET, "/", c) + h, _ := r.Find(GET, "/users/", c) if assert.NotNil(t, h) { assert.Equal(t, "", c.P(0)) } - h, _ = r.Find(GET, "/download", c) - if assert.NotNil(t, h) { - assert.Equal(t, "download", c.P(0)) - } - - h, _ = r.Find(GET, "/users/joe", c) + h, _ = r.Find(GET, "/users/1", c) if assert.NotNil(t, h) { - assert.Equal(t, "joe", c.P(0)) + assert.Equal(t, "1", c.P(0)) } } @@ -394,7 +381,7 @@ func TestRouterMultiRoute(t *testing.T) { r.Add(GET, "/users/:id", func(c *Context) error { return nil }, e) - c := NewContext(nil, nil, e) + c := NewContext(nil, new(Response), e) // Route > /users h, _ := r.Find(GET, "/users", c) @@ -415,6 +402,14 @@ func TestRouterMultiRoute(t *testing.T) { he := h(c).(*HTTPError) assert.Equal(t, http.StatusNotFound, he.code) } + + // Invalid Method for Resource + c.response.writer = httptest.NewRecorder() + h, _ = r.Find("INVALID", "/users", c) + if assert.IsType(t, new(HTTPError), h(c)) { + he := h(c).(*HTTPError) + assert.Equal(t, http.StatusMethodNotAllowed, he.code) + } } func TestRouterPriority(t *testing.T) { @@ -499,7 +494,7 @@ func TestRouterPriority(t *testing.T) { if assert.NotNil(t, h) { h(c) assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "joe/books", c.Param("_*")) + assert.Equal(t, "joe/books", c.Param("_name")) } } @@ -553,7 +548,10 @@ func TestRouterAPI(t *testing.T) { return nil }, e) } - c := NewContext(nil, nil, e) + + response := NewResponse(httptest.NewRecorder()) + + c := NewContext(nil, response, e) for _, route := range api { h, _ := r.Find(route.Method, route.Path, c) if assert.NotNil(t, h) {