From cd48f0153b85087ada42c34d32d4b5dc02bfae84 Mon Sep 17 00:00:00 2001 From: limpo1989 Date: Fri, 15 Dec 2023 22:14:32 +0800 Subject: [PATCH] Allowed overridden request in middlewares --- web/bind.go | 29 +- web/bind_test.go | 63 +++ web/context.go | 90 +-- web/router.go | 71 +-- web/router_test.go | 1315 ++++++++++++++++++++++++++++++++++++++++++++ web/tree.go | 14 +- web/tree_test.go | 15 +- 7 files changed, 1489 insertions(+), 108 deletions(-) create mode 100644 web/bind_test.go create mode 100644 web/router_test.go diff --git a/web/bind.go b/web/bind.go index 513b4bf3..6affbd0b 100644 --- a/web/bind.go +++ b/web/bind.go @@ -58,11 +58,11 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { switch h := fn.(type) { case http.HandlerFunc: - return warpHandlerCtx(h) + return warpContext(h) case http.Handler: - return warpHandlerCtx(h.ServeHTTP) + return warpContext(h.ServeHTTP) case func(http.ResponseWriter, *http.Request): - return warpHandlerCtx(h) + return warpContext(h) default: // valid func if err := validMappingFunc(fnType); nil != err { @@ -75,12 +75,8 @@ func Bind(fn interface{}, render Renderer) http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) { // param of context - ctx := request.Context() - webCtx := FromContext(ctx) - if nil == webCtx { - webCtx = &Context{Writer: writer, Request: request} - ctx = WithContext(request.Context(), webCtx) - } + webCtx := &Context{Writer: writer, Request: request} + ctx := WithContext(request.Context(), webCtx) defer func() { if nil != request.MultipartForm { @@ -205,22 +201,13 @@ func validMappingFunc(fnType reflect.Type) error { return nil } -func warpHandlerCtx(handler http.HandlerFunc) http.HandlerFunc { +func warpContext(handler http.HandlerFunc) http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) { - if nil != FromContext(request.Context()) { - handler.ServeHTTP(writer, request) - } else { - webCtx := &Context{Writer: writer, Request: request} - handler.ServeHTTP(writer, requestWithCtx(request, webCtx)) - } + webCtx := &Context{Writer: writer, Request: request} + handler.ServeHTTP(writer, request.WithContext(WithContext(request.Context(), webCtx))) } } -func requestWithCtx(r *http.Request, webCtx *Context) *http.Request { - ctx := WithContext(r.Context(), webCtx) - return r.WithContext(ctx) -} - func defaultJsonRender(ctx *Context, err error, result interface{}) { var code = 0 diff --git a/web/bind_test.go b/web/bind_test.go new file mode 100644 index 00000000..764d32a6 --- /dev/null +++ b/web/bind_test.go @@ -0,0 +1,63 @@ +package web + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "go-spring.dev/spring/internal/utils/assert" +) + +func TestBindWithoutParams(t *testing.T) { + + var handler = func(ctx context.Context) string { + webCtx := FromContext(ctx) + assert.NotNil(t, webCtx) + return "0987654321" + } + + request := httptest.NewRequest(http.MethodGet, "/get", strings.NewReader("{}")) + response := httptest.NewRecorder() + Bind(handler, RendererFunc(defaultJsonRender))(response, request) + assert.Equal(t, response.Body.String(), "{\"code\":0,\"data\":\"0987654321\"}\n") +} + +func TestBindWithParams(t *testing.T) { + var handler = func(ctx context.Context, req struct { + Username string `json:"username"` + Password string `json:"password"` + }) string { + webCtx := FromContext(ctx) + assert.NotNil(t, webCtx) + assert.Equal(t, req.Username, "aaa") + assert.Equal(t, req.Password, "88888888") + return "success" + } + + request := httptest.NewRequest(http.MethodPost, "/post", strings.NewReader(`{"username": "aaa", "password": "88888888"}`)) + request.Header.Add("Content-Type", "application/json") + response := httptest.NewRecorder() + Bind(handler, RendererFunc(defaultJsonRender))(response, request) + assert.Equal(t, response.Body.String(), "{\"code\":0,\"data\":\"success\"}\n") +} + +func TestBindWithParamsAndError(t *testing.T) { + var handler = func(ctx context.Context, req struct { + Username string `json:"username"` + Password string `json:"password"` + }) (string, error) { + webCtx := FromContext(ctx) + assert.NotNil(t, webCtx) + assert.Equal(t, req.Username, "aaa") + assert.Equal(t, req.Password, "88888888") + return "requestid: 9999999", Error(403, "user locked") + } + + request := httptest.NewRequest(http.MethodPost, "/post", strings.NewReader(`{"username": "aaa", "password": "88888888"}`)) + request.Header.Add("Content-Type", "application/json") + response := httptest.NewRecorder() + Bind(handler, RendererFunc(defaultJsonRender))(response, request) + assert.Equal(t, response.Body.String(), "{\"code\":403,\"message\":\"user locked\",\"data\":\"requestid: 9999999\"}\n") +} diff --git a/web/context.go b/web/context.go index 33a8e745..754598e8 100644 --- a/web/context.go +++ b/web/context.go @@ -54,33 +54,9 @@ type Context struct { // or to be sent by a client. Request *http.Request - routes Routes - // SameSite allows a server to define a cookie attribute making it impossible for // the browser to send this cookie along with cross-site requests. sameSite http.SameSite - - // URLParams are the stack of routeParams captured during the - // routing lifecycle across a stack of sub-routers. - urlParams RouteParams - - // routeParams matched for the current sub-router. It is - // intentionally unexported so it can't be tampered. - routeParams RouteParams - - // Routing path/method override used during the route search. - routePath string - routeMethod string - - // The endpoint routing pattern that matched the request URI path - // or `RoutePath` of the current sub-router. This value will update - // during the lifecycle of a request passing through a stack of - // sub-routers. - routePattern string - routePatterns []string - - methodNotAllowed bool - methodsAllowed []methodTyp } // Context returns the request's context. @@ -116,7 +92,10 @@ func (c *Context) Cookie(name string) (string, bool) { // PathParam returns the named variables in the request. func (c *Context) PathParam(name string) (string, bool) { - return c.urlParams.Get(name) + if ctx := FromRouteContext(c.Request.Context()); nil != ctx { + return ctx.URLParams.Get(name) + } + return "", false } // QueryParam returns the named query in the request. @@ -315,18 +294,53 @@ func (c *Context) ClientIP() string { return remoteIP.String() } +type routeContextKey struct{} + +func WithRouteContext(parent context.Context, ctx *RouteContext) context.Context { + return context.WithValue(parent, routeContextKey{}, ctx) +} + +func FromRouteContext(ctx context.Context) *RouteContext { + if v := ctx.Value(routeContextKey{}); v != nil { + return v.(*RouteContext) + } + return nil +} + +type RouteContext struct { + Routes Routes + // URLParams are the stack of routeParams captured during the + // routing lifecycle across a stack of sub-routers. + URLParams RouteParams + + // routeParams matched for the current sub-router. It is + // intentionally unexported so it can't be tampered. + routeParams RouteParams + + // Routing path/method override used during the route search. + RoutePath string + RouteMethod string + + // The endpoint routing pattern that matched the request URI path + // or `RoutePath` of the current sub-router. This value will update + // during the lifecycle of a request passing through a stack of + // sub-routers. + RoutePattern string + routePatterns []string + + methodNotAllowed bool + methodsAllowed []methodTyp +} + // Reset context to initial state -func (c *Context) Reset() { - c.Writer = nil - c.Request = nil - c.sameSite = 0 - c.routes = nil - c.routePath = "" - c.routeMethod = "" - c.routePattern = "" +func (c *RouteContext) Reset() { + c.Routes = nil + c.RoutePath = "" + c.RouteMethod = "" + c.RoutePattern = "" c.routePatterns = c.routePatterns[:0] - c.urlParams.Keys = c.urlParams.Keys[:0] - c.urlParams.Values = c.urlParams.Values[:0] + c.URLParams.Keys = c.URLParams.Keys[:0] + c.URLParams.Values = c.URLParams.Values[:0] c.routeParams.Keys = c.routeParams.Keys[:0] c.routeParams.Values = c.routeParams.Values[:0] c.methodNotAllowed = false @@ -345,9 +359,9 @@ func (s *RouteParams) Add(key, value string) { } func (s *RouteParams) Get(key string) (value string, ok bool) { - for index, k := range s.Keys { - if key == k { - return s.Values[index], true + for i := len(s.Keys) - 1; i >= 0; i-- { + if s.Keys[i] == key { + return s.Values[i], true } } return "", false diff --git a/web/router.go b/web/router.go index 2484dda5..82b07b4a 100644 --- a/web/router.go +++ b/web/router.go @@ -48,7 +48,7 @@ type Router interface { Renderer(renderer Renderer) // Group creates a new router group. - Group(pattern string) Router + Group(pattern string, fn ...func(subRouter Router)) Router // Handle registers a new route with a matcher for the URL pattern. Handle(pattern string, handler http.Handler) @@ -254,7 +254,7 @@ type Routes interface { // Match searches the routing tree for a handler that matches // the method/path - similar to routing a http request, but without // executing the handler thereafter. - Match(webCtx *Context, method, path string) bool + Match(ctx *RouteContext, method, path string) bool } // NewRouter returns a new router instance. @@ -262,7 +262,7 @@ func NewRouter() Router { return &routerGroup{ tree: &node{}, renderer: RendererFunc(defaultJsonRender), - pool: &sync.Pool{New: func() interface{} { return &Context{} }}, + pool: &sync.Pool{New: func() interface{} { return &RouteContext{} }}, } } @@ -313,25 +313,23 @@ func (rg *routerGroup) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - webCtx := FromContext(r.Context()) - if nil != webCtx { + ctx := FromRouteContext(r.Context()) + if nil != ctx { rg.handler.ServeHTTP(w, r) return } // get context from pool - webCtx = rg.pool.Get().(*Context) - webCtx.Writer = w - webCtx.Request = r - webCtx.routes = rg + ctx = rg.pool.Get().(*RouteContext) + ctx.Routes = rg // with context - r = r.WithContext(WithContext(r.Context(), webCtx)) + r = r.WithContext(WithRouteContext(r.Context(), ctx)) rg.handler.ServeHTTP(w, r) // put context to pool - webCtx.Reset() - rg.pool.Put(webCtx) + ctx.Reset() + rg.pool.Put(ctx) } @@ -346,23 +344,23 @@ func (rg *routerGroup) updateSubRoutes(fn func(subMux *routerGroup)) { } } -func (rg *routerGroup) nextRoutePath(webCtx *Context) string { +func (rg *routerGroup) nextRoutePath(ctx *RouteContext) string { routePath := "/" - nx := len(webCtx.routeParams.Keys) - 1 // index of last param in list - if nx >= 0 && webCtx.routeParams.Keys[nx] == "*" && len(webCtx.routeParams.Values) > nx { - routePath = "/" + webCtx.routeParams.Values[nx] + nx := len(ctx.routeParams.Keys) - 1 // index of last param in list + if nx >= 0 && ctx.routeParams.Keys[nx] == "*" && len(ctx.routeParams.Values) > nx { + routePath = "/" + ctx.routeParams.Values[nx] } return routePath } -// routeHTTP routes a http.Request through the routing tree to serve +// routeHTTP Routes a http.Request through the routing tree to serve // the matching handler for a particular http method. func (rg *routerGroup) routeHTTP(w http.ResponseWriter, r *http.Request) { // Grab the route context object - webCtx := FromContext(r.Context()) + ctx := FromRouteContext(r.Context()) // The request routing path - routePath := webCtx.routePath + routePath := ctx.RoutePath if routePath == "" { if r.URL.RawPath != "" { routePath = r.URL.RawPath @@ -374,22 +372,22 @@ func (rg *routerGroup) routeHTTP(w http.ResponseWriter, r *http.Request) { } } - if webCtx.routeMethod == "" { - webCtx.routeMethod = r.Method + if ctx.RouteMethod == "" { + ctx.RouteMethod = r.Method } - method, ok := methodMap[webCtx.routeMethod] + method, ok := methodMap[ctx.RouteMethod] if !ok { rg.NotAllowedHandler().ServeHTTP(w, r) return } // Find the route - if _, _, h := rg.tree.FindRoute(webCtx, method, routePath); h != nil { + if _, _, h := rg.tree.FindRoute(ctx, method, routePath); h != nil { h.ServeHTTP(w, r) return } - if webCtx.methodNotAllowed { + if ctx.methodNotAllowed { rg.NotAllowedHandler().ServeHTTP(w, r) } else { rg.NotFoundHandler().ServeHTTP(w, r) @@ -397,8 +395,11 @@ func (rg *routerGroup) routeHTTP(w http.ResponseWriter, r *http.Request) { } // Group creates a new router group. -func (rg *routerGroup) Group(pattern string) Router { +func (rg *routerGroup) Group(pattern string, fn ...func(subRouter Router)) Router { subRouter := &routerGroup{tree: &node{}, renderer: rg.renderer, pool: rg.pool} + for _, f := range fn { + f(subRouter) + } rg.Mount(pattern, subRouter) return subRouter } @@ -427,15 +428,15 @@ func (rg *routerGroup) Mount(pattern string, handler http.Handler) { } mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - webCtx := FromContext(r.Context()) + ctx := FromRouteContext(r.Context()) // shift the url path past the previous subrouter - webCtx.routePath = rg.nextRoutePath(webCtx) + ctx.RoutePath = rg.nextRoutePath(ctx) // reset the wildcard URLParam which connects the subrouter - n := len(webCtx.urlParams.Keys) - 1 - if n >= 0 && webCtx.urlParams.Keys[n] == "*" && len(webCtx.urlParams.Values) > n { - webCtx.urlParams.Values[n] = "" + n := len(ctx.URLParams.Keys) - 1 + if n >= 0 && ctx.URLParams.Keys[n] == "*" && len(ctx.URLParams.Values) > n { + ctx.URLParams.Values[n] = "" } handler.ServeHTTP(w, r) @@ -584,7 +585,7 @@ func (rg *routerGroup) MethodNotAllowed(handler http.HandlerFunc) { } // Routes returns a slice of routing information from the tree, -// useful for traversing available routes of a router. +// useful for traversing available Routes of a router. func (rg *routerGroup) Routes() []Route { return rg.tree.routes() } @@ -597,17 +598,17 @@ func (rg *routerGroup) Middlewares() Middlewares { // Match searches the routing tree for a handler that matches the method/path. // It's similar to routing a http request, but without executing the handler // thereafter. -func (rg *routerGroup) Match(webCtx *Context, method, path string) bool { +func (rg *routerGroup) Match(ctx *RouteContext, method, path string) bool { m, ok := methodMap[method] if !ok { return false } - node, _, h := rg.tree.FindRoute(webCtx, m, path) + node, _, h := rg.tree.FindRoute(ctx, m, path) if node != nil && node.subroutes != nil { - webCtx.routePath = rg.nextRoutePath(webCtx) - return node.subroutes.Match(webCtx, method, webCtx.routePath) + ctx.RoutePath = rg.nextRoutePath(ctx) + return node.subroutes.Match(ctx, method, ctx.RoutePath) } return h != nil diff --git a/web/router_test.go b/web/router_test.go new file mode 100644 index 00000000..14217710 --- /dev/null +++ b/web/router_test.go @@ -0,0 +1,1315 @@ +package web + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +func URLParam(r *http.Request, name string) string { + if ctx := FromRouteContext(r.Context()); nil != ctx { + v, _ := ctx.URLParams.Get(name) + return v + } + return "" +} + +func TestMuxBasic(t *testing.T) { + var count uint64 + countermw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + next.ServeHTTP(w, r) + }) + } + + usermw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = context.WithValue(ctx, ctxKey{"user"}, "peter") + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } + + exmw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a") + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } + + logbuf := bytes.NewBufferString("") + logmsg := "logmw test" + logmw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logbuf.WriteString(logmsg) + next.ServeHTTP(w, r) + }) + } + + cxindex := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + user := ctx.Value(ctxKey{"user"}).(string) + w.WriteHeader(200) + w.Write([]byte(fmt.Sprintf("hi %s", user))) + } + + headPing := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Ping", "1") + w.WriteHeader(200) + } + + createPing := func(w http.ResponseWriter, r *http.Request) { + // create .... + w.WriteHeader(201) + } + + pingAll2 := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("ping all2")) + } + + pingOne := func(w http.ResponseWriter, r *http.Request) { + idParam := URLParam(r, "id") + w.WriteHeader(200) + w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam))) + } + + pingWoop := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("woop." + URLParam(r, "iidd"))) + } + + catchAll := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("catchall")) + } + + m := NewRouter() + m.Use(countermw) + m.Use(usermw) + m.Use(exmw) + m.Use(logmw) + m.Get("/", cxindex) + m.Get("/ping/all2", pingAll2) + + m.Head("/ping", headPing) + m.Post("/ping", createPing) + m.Get("/ping/{id}", pingWoop) + m.Get("/ping/{id}", pingOne) // expected to overwrite to pingOne handler + m.Get("/ping/{iidd}/woop", pingWoop) + m.HandleFunc("/admin/*", catchAll) + // m.Post("/admin/*", catchAll) + + ts := httptest.NewServer(m) + defer ts.Close() + + // GET / + if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" { + t.Fatalf(body) + } + tlogmsg, _ := logbuf.ReadString(0) + if tlogmsg != logmsg { + t.Error("expecting log message from middleware:", logmsg) + } + + // GET /ping/all2 + if _, body := testRequest(t, ts, "GET", "/ping/all2", nil); body != "ping all2" { + t.Fatalf(body) + } + + // GET /ping/123 + if _, body := testRequest(t, ts, "GET", "/ping/123", nil); body != "ping one id: 123" { + t.Fatalf(body) + } + + // GET /ping/allan + if _, body := testRequest(t, ts, "GET", "/ping/allan", nil); body != "ping one id: allan" { + t.Fatalf(body) + } + + // GET /ping/1/woop + if _, body := testRequest(t, ts, "GET", "/ping/1/woop", nil); body != "woop.1" { + t.Fatalf(body) + } + + // HEAD /ping + resp, err := http.Head(ts.URL + "/ping") + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Error("head failed, should be 200") + } + if resp.Header.Get("X-Ping") == "" { + t.Error("expecting X-Ping header") + } + + // GET /admin/catch-this + if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" { + t.Fatalf(body) + } + + // POST /admin/catch-this + resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{})) + if err != nil { + t.Fatal(err) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Error("POST failed, should be 200") + } + + if string(body) != "catchall" { + t.Error("expecting response body: 'catchall'") + } + + // Custom http method DIE /ping/1/woop + if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "405 method not allowed\n" || resp.StatusCode != 405 { + t.Fatalf(fmt.Sprintf("expecting 405 status and empty body, got %d '%s'", resp.StatusCode, body)) + } +} + +func TestMuxMounts(t *testing.T) { + r := NewRouter() + + r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) { + v := URLParam(r, "hash") + w.Write([]byte(fmt.Sprintf("/%s", v))) + }) + + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + v := URLParam(r, "hash") + w.Write([]byte(fmt.Sprintf("/%s/share", v))) + }) + r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) { + v := URLParam(r, "hash") + n := URLParam(r, "network") + w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n))) + }) + })(r.Group("/{hash}/share")) + + m := NewRouter().(*routerGroup) + m.Mount("/sharing", r) + + ts := httptest.NewServer(m) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/sharing/aBc", nil); body != "/aBc" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share", nil); body != "/aBc/share" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share/twitter", nil); body != "/aBc/share/twitter" { + t.Fatalf(body) + } +} + +func TestMuxPlain(t *testing.T) { + r := NewRouter() + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("bye")) + }) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte("nothing here")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" { + t.Fatalf(body) + } +} + +func TestMuxEmptyRoutes(t *testing.T) { + mux := NewRouter() + + apiRouter := NewRouter() + // oops, we forgot to declare any route handlers + + mux.Handle("/api*", apiRouter) + + if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" { + t.Fatalf(body) + } + + if _, body := testHandler(t, apiRouter, "GET", "/", nil); body != "404 page not found\n" { + t.Fatalf(body) + } +} + +// Test a mux that routes a trailing slash, see also middleware/strip_test.go +// for an example of using a middleware to handle trailing slashes. +func TestMuxTrailingSlash(t *testing.T) { + r := NewRouter().(*routerGroup) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte("nothing here")) + }) + + subRoutes := NewRouter() + indexHandler := func(w http.ResponseWriter, r *http.Request) { + accountID := URLParam(r, "accountID") + w.Write([]byte(accountID)) + } + subRoutes.Get("/", indexHandler) + + r.Mount("/accounts/{accountID}", subRoutes) + r.Get("/accounts/{accountID}/", indexHandler) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" { + t.Fatalf(body) + } +} + +func TestMethodNotAllowed(t *testing.T) { + r := NewRouter() + + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi, get")) + }) + + r.Head("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi, head")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + t.Run("Registered Method", func(t *testing.T) { + resp, _ := testRequest(t, ts, "GET", "/hi", nil) + if resp.StatusCode != 200 { + t.Fatal(resp.Status) + } + if resp.Header.Values("Allow") != nil { + t.Fatal("allow should be empty when method is registered") + } + }) + + t.Run("Unregistered Method", func(t *testing.T) { + resp, _ := testRequest(t, ts, "POST", "/hi", nil) + if resp.StatusCode != 405 { + t.Fatal(resp.Status) + } + }) +} + +func TestMuxNestedMethodNotAllowed(t *testing.T) { + r := NewRouter().(*routerGroup) + r.Get("/root", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("root")) + }) + r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(405) + w.Write([]byte("root 405")) + }) + + sr1 := NewRouter() + sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("sub1")) + }) + sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(405) + w.Write([]byte("sub1 405")) + }) + + sr2 := NewRouter() + sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("sub2")) + }) + + pathVar := NewRouter() + pathVar.Get("/{var}", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("pv")) + }) + pathVar.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(405) + w.Write([]byte("pv 405")) + }) + + r.Mount("/prefix1", sr1) + r.Mount("/prefix2", sr2) + r.Mount("/pathVar", pathVar) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/root", nil); body != "root" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "PUT", "/root", nil); body != "root 405" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/prefix1/sub1", nil); body != "sub1" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "PUT", "/prefix1/sub1", nil); body != "sub1 405" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/prefix2/sub2", nil); body != "sub2" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "PUT", "/prefix2/sub2", nil); body != "root 405" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/pathVar/myvar", nil); body != "pv" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/pathVar/myvar", nil); body != "pv 405" { + t.Fatalf(body) + } +} + +func TestMuxComplicatedNotFound(t *testing.T) { + decorateRouter := func(r *routerGroup) { + // Root router with groups + r.Get("/auth", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("auth get")) + }) + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("public get")) + }) + })(r.Group("/public")) + + // sub router with groups + sub0 := NewRouter() + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("private get")) + }) + })(sub0.Group("/resource")) + r.Mount("/private", sub0) + + // sub router with groups + sub1 := NewRouter() + (func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("private get")) + }) + })(sub1.Group("/resource")) + } + + testNotFound := func(t *testing.T, r *routerGroup) { + ts := httptest.NewServer(r) + defer ts.Close() + + // check that we didn't break correct routes + if _, body := testRequest(t, ts, "GET", "/auth", nil); body != "auth get" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/public", nil); body != "public get" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/public/", nil); body != "public get" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private/resource", nil); body != "private get" { + t.Fatalf(body) + } + // check custom not-found on all levels + if _, body := testRequest(t, ts, "GET", "/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/public/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private/resource/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private_mw/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/private_mw/resource/nope", nil); body != "custom not-found" { + t.Fatalf(body) + } + // check custom not-found on trailing slash routes + if _, body := testRequest(t, ts, "GET", "/auth/", nil); body != "custom not-found" { + t.Fatalf(body) + } + } + + t.Run("pre", func(t *testing.T) { + r := NewRouter().(*routerGroup) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("custom not-found")) + }) + decorateRouter(r) + testNotFound(t, r) + }) + + t.Run("post", func(t *testing.T) { + r := NewRouter().(*routerGroup) + decorateRouter(r) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("custom not-found")) + }) + testNotFound(t, r) + }) +} + +func TestMuxMiddlewareStack(t *testing.T) { + var stdmwInit, stdmwHandler uint64 + stdmw := func(next http.Handler) http.Handler { + stdmwInit++ + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + stdmwHandler++ + next.ServeHTTP(w, r) + }) + } + _ = stdmw + + var ctxmwInit, ctxmwHandler uint64 + ctxmw := func(next http.Handler) http.Handler { + ctxmwInit++ + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctxmwHandler++ + ctx := r.Context() + ctx = context.WithValue(ctx, ctxKey{"count.ctxmwHandler"}, ctxmwHandler) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } + + r := NewRouter() + r.Use(stdmw) + r.Use(ctxmw) + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ping" { + w.Write([]byte("pong")) + return + } + next.ServeHTTP(w, r) + }) + }) + + var handlerCount uint64 + + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + handlerCount++ + ctx := r.Context() + ctxmwHandlerCount := ctx.Value(ctxKey{"count.ctxmwHandler"}).(uint64) + w.Write([]byte(fmt.Sprintf("inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, ctxmwHandlerCount))) + }) + + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("wooot")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + testRequest(t, ts, "GET", "/", nil) + testRequest(t, ts, "GET", "/", nil) + var body string + _, body = testRequest(t, ts, "GET", "/", nil) + if body != "inits:1 reqs:3 ctxValue:3" { + t.Fatalf("got: '%s'", body) + } + + _, body = testRequest(t, ts, "GET", "/ping", nil) + if body != "pong" { + t.Fatalf("got: '%s'", body) + } +} + +func TestMuxSubroutesBasic(t *testing.T) { + hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("index")) + }) + hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("articles-list")) + }) + hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("search-articles")) + }) + hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id")))) + }) + hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id")))) + }) + + r := NewRouter() + // var rr1, rr2 *Mux + r.Get("/", hIndex) + (func(r Router) { + // rr1 = r.(*Mux) + r.Get("/", hArticlesList) + r.Get("/search", hSearchArticles) + (func(r Router) { + // rr2 = r.(*Mux) + r.Get("/", hGetArticle) + r.Get("/sync", hSyncArticle) + })(r.Group("/{id}")) + })(r.Group("/articles")) + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, r.tree, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, rr1.tree, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + // debugPrintTree(0, 0, rr2.tree, 0) + // log.Println("~~~~~~~~~") + // log.Println("~~~~~~~~~") + + ts := httptest.NewServer(r) + defer ts.Close() + + var body, expected string + + _, body = testRequest(t, ts, "GET", "/", nil) + expected = "index" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles", nil) + expected = "articles-list" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles/search", nil) + expected = "search-articles" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles/123", nil) + expected = "get-article:123" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/articles/123/sync", nil) + expected = "sync-article:123" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } +} + +func TestMuxSubroutes(t *testing.T) { + hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub1")) + }) + hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub2")) + }) + hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub3")) + }) + hAccountView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("account1")) + }) + hAccountView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("account2")) + }) + + r := NewRouter().(*routerGroup) + r.Get("/hubs/{hubID}/view", hHubView1) + r.Get("/hubs/{hubID}/view/*", hHubView2) + + sr := NewRouter().(*routerGroup) + sr.Get("/", hHubView3) + r.Mount("/hubs/{hubID}/users", sr) + r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hub3 override")) + }) + + sr3 := NewRouter() + sr3.Get("/", hAccountView1) + sr3.Get("/hi", hAccountView2) + + // var sr2 *Mux + (func(r Router) { + rg := r.(*routerGroup) // sr2 + // r.Get("/", hAccountView1) + rg.Mount("/", sr3) + })(r.Group("/accounts/{accountID}")) + + // This is the same as the r.Route() call mounted on sr2 + // sr2 := NewRouter() + // sr2.Mount("/", sr3) + // r.Mount("/accounts/{accountID}", sr2) + + ts := httptest.NewServer(r) + defer ts.Close() + + var body, expected string + + _, body = testRequest(t, ts, "GET", "/hubs/123/view", nil) + expected = "hub1" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/hubs/123/view/index.html", nil) + expected = "hub2" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/hubs/123/users", nil) + expected = "hub3" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/hubs/123/users/", nil) + expected = "hub3 override" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + _, body = testRequest(t, ts, "GET", "/accounts/44", nil) + expected = "account1" + if body != expected { + t.Fatalf("request:%s expected:%s got:%s", "GET /accounts/44", expected, body) + } + _, body = testRequest(t, ts, "GET", "/accounts/44/hi", nil) + expected = "account2" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + + // Test that we're building the routingPatterns properly + router := r + req, _ := http.NewRequest("GET", "/accounts/44/hi", nil) + + rctx := &RouteContext{} + req = req.WithContext(context.WithValue(req.Context(), routeContextKey{}, rctx)) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + body = w.Body.String() + expected = "account2" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } + + routePatterns := rctx.routePatterns + if len(rctx.routePatterns) != 3 { + t.Fatalf("expected 3 routing patterns, got:%d", len(rctx.routePatterns)) + } + expected = "/accounts/{accountID}/*" + if routePatterns[0] != expected { + t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[0]) + } + expected = "/*" + if routePatterns[1] != expected { + t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[1]) + } + expected = "/hi" + if routePatterns[2] != expected { + t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[2]) + } + +} + +func TestSingleHandler(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + name := URLParam(r, "name") + w.Write([]byte("hi " + name)) + }) + + r, _ := http.NewRequest("GET", "/", nil) + rctx := &RouteContext{} + r = r.WithContext(context.WithValue(r.Context(), routeContextKey{}, rctx)) + rctx.URLParams.Add("name", "joe") + + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + body := w.Body.String() + expected := "hi joe" + if body != expected { + t.Fatalf("expected:%s got:%s", expected, body) + } +} + +// TODO: a Router wrapper test.. +// +// type ACLMux struct { +// *Mux +// XX string +// } +// +// func NewACLMux() *ACLMux { +// return &ACLMux{Mux: NewRouter(), XX: "hihi"} +// } +// +// // TODO: this should be supported... +// func TestWoot(t *testing.T) { +// var r Router = NewRouter() +// +// var r2 Router = NewACLMux() //NewRouter() +// r2.Get("/hi", func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("hi")) +// }) +// +// r.Mount("/", r2) +// } + +func TestServeHTTPExistingContext(t *testing.T) { + r := NewRouter() + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + s, _ := r.Context().Value(ctxKey{"testCtx"}).(string) + w.Write([]byte(s)) + }) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + s, _ := r.Context().Value(ctxKey{"testCtx"}).(string) + w.WriteHeader(404) + w.Write([]byte(s)) + }) + + testcases := []struct { + Ctx context.Context + Method string + Path string + ExpectedBody string + ExpectedStatus int + }{ + { + Method: "GET", + Path: "/hi", + Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"), + ExpectedStatus: 200, + ExpectedBody: "hi ctx", + }, + { + Method: "GET", + Path: "/hello", + Ctx: context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"), + ExpectedStatus: 404, + ExpectedBody: "nothing here ctx", + }, + } + + for _, tc := range testcases { + resp := httptest.NewRecorder() + req, err := http.NewRequest(tc.Method, tc.Path, nil) + if err != nil { + t.Fatalf("%v", err) + } + req = req.WithContext(tc.Ctx) + r.ServeHTTP(resp, req) + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("%v", err) + } + if resp.Code != tc.ExpectedStatus { + t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code) + } + if string(b) != tc.ExpectedBody { + t.Fatalf("%s != %s", tc.ExpectedBody, b) + } + } +} + +func TestMiddlewarePanicOnLateUse(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello\n")) + } + + mw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter() + r.Get("/", handler) + r.Use(mw) // Too late to apply middleware, we're expecting panic(). +} + +func TestMountingExistingPath(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter().(*routerGroup) + r.Get("/", handler) + r.Mount("/hi", http.HandlerFunc(handler)) + r.Mount("/hi", http.HandlerFunc(handler)) +} + +func TestMountingSimilarPattern(t *testing.T) { + r := NewRouter().(*routerGroup) + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("bye")) + }) + + r2 := NewRouter() + r2.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foobar")) + }) + + r3 := NewRouter() + r3.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("foo")) + }) + + r.Mount("/foobar", r2) + r.Mount("/foo", r3) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" { + t.Fatalf(body) + } +} + +func TestMuxEmptyParams(t *testing.T) { + r := NewRouter() + r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) { + x := URLParam(r, "x") + y := URLParam(r, "y") + z := URLParam(r, "z") + w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z))) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != "a-b-c" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != "--c" { + t.Fatalf(body) + } +} + +func TestMuxMissingParams(t *testing.T) { + r := NewRouter() + r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) { + userID := URLParam(r, "userId") + w.Write([]byte(fmt.Sprintf("userId = '%s'", userID))) + }) + r.NotFound(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte("nothing here")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != "userId = '123'" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/user/", nil); body != "nothing here" { + t.Fatalf(body) + } +} + +func TestMuxWildcardRoute(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter() + r.Get("/*/wildcard/must/be/at/end", handler) +} + +func TestMuxWildcardRouteCheckTwo(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) {} + + defer func() { + if recover() == nil { + t.Error("expected panic()") + } + }() + + r := NewRouter() + r.Get("/*/wildcard/{must}/be/at/end", handler) +} + +func TestMuxRegexp(t *testing.T) { + r := NewRouter() + r.Group("/{param:[0-9]*}/test", func(r Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf("Hi: %s", URLParam(r, "param")))) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "//test", nil); body != "Hi: " { + t.Fatalf(body) + } +} + +func TestMuxRegexp2(t *testing.T) { + r := NewRouter() + r.Get("/foo-{suffix:[a-z]{2,3}}.json", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(URLParam(r, "suffix"))) + }) + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/foo-.json", nil); body != "" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/foo-abc.json", nil); body != "abc" { + t.Fatalf(body) + } +} + +func TestMuxRegexp3(t *testing.T) { + r := NewRouter() + r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("first")) + }) + r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("second")) + }) + r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("third")) + }) + + (func(r Router) { + r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("_")) + }) + r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("_")) + }) + r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) { + writer.Write([]byte("forth")) + }) + })(r.Group("/one")) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/one/hello/peter/first", nil); body != "first" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/one/hithere/123/second", nil); body != "second" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/one/hithere/123/second", nil); body != "third" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "DELETE", "/one/123", nil); body != "forth" { + t.Fatalf(body) + } +} + +func TestMuxSubrouterWildcardParam(t *testing.T) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "param:%v *:%v", URLParam(r, "param"), URLParam(r, "*")) + }) + + r := NewRouter() + + r.Get("/bare/{param}", h) + r.Get("/bare/{param}/*", h) + + (func(r Router) { + r.Get("/{param}", h) + r.Get("/{param}/*", h) + })(r.Group("/case0")) + + ts := httptest.NewServer(r) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/bare/hi", nil); body != "param:hi *:" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/bare/hi/yes", nil); body != "param:hi *:yes" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/case0/hi", nil); body != "param:hi *:" { + t.Fatalf(body) + } + if _, body := testRequest(t, ts, "GET", "/case0/hi/yes", nil); body != "param:hi *:yes" { + t.Fatalf(body) + } +} + +func TestMuxContextIsThreadSafe(t *testing.T) { + router := NewRouter() + router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond) + defer cancel() + + <-ctx.Done() + }) + + wg := sync.WaitGroup{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10000; j++ { + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/ok", nil) + if err != nil { + t.Error(err) + return + } + + ctx, cancel := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + + go func() { + cancel() + }() + router.ServeHTTP(w, r) + } + }() + } + wg.Wait() +} + +func TestEscapedURLParams(t *testing.T) { + m := NewRouter() + m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + rctx := FromRouteContext(r.Context()) + if rctx == nil { + t.Error("no context") + return + } + identifier := URLParam(r, "identifier") + if identifier != "http:%2f%2fexample.com%2fimage.png" { + t.Errorf("identifier path parameter incorrect %s", identifier) + return + } + region := URLParam(r, "region") + if region != "full" { + t.Errorf("region path parameter incorrect %s", region) + return + } + size := URLParam(r, "size") + if size != "max" { + t.Errorf("size path parameter incorrect %s", size) + return + } + rotation := URLParam(r, "rotation") + if rotation != "0" { + t.Errorf("rotation path parameter incorrect %s", rotation) + return + } + w.Write([]byte("success")) + }) + + ts := httptest.NewServer(m) + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" { + t.Fatalf(body) + } +} + +func TestMuxMatch(t *testing.T) { + r := NewRouter() + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "yes") + w.Write([]byte("bye")) + }) + (func(r Router) { + r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { + id := URLParam(r, "id") + w.Header().Set("X-Article", id) + w.Write([]byte("article:" + id)) + }) + })(r.Group("/articles")) + (func(r Router) { + r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-User", "-") + w.Write([]byte("user")) + }) + r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) { + id := URLParam(r, "id") + w.Header().Set("X-User", id) + w.Write([]byte("user:" + id)) + }) + })(r.Group("/users")) + + tctx := &RouteContext{} + + tctx.Reset() + if r.(Routes).Match(tctx, "GET", "/users/1") == false { + t.Fatal("expecting to find match for route:", "GET", "/users/1") + } + + tctx.Reset() + if r.(Routes).Match(tctx, "HEAD", "/articles/10") == true { + t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10") + } +} + +func TestServerBaseContext(t *testing.T) { + r := NewRouter() + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + baseYes := r.Context().Value(ctxKey{"base"}).(string) + if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok { + panic("missing server context") + } + if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok { + panic("missing local addr context") + } + w.Write([]byte(baseYes)) + }) + + // Setup http Server with a base context + ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes") + ts := httptest.NewUnstartedServer(r) + ts.Config.BaseContext = func(_ net.Listener) context.Context { + return ctx + } + ts.Start() + + defer ts.Close() + + if _, body := testRequest(t, ts, "GET", "/", nil); body != "yes" { + t.Fatalf(body) + } +} + +func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) { + req, err := http.NewRequest(method, ts.URL+path, body) + if err != nil { + t.Fatal(err) + return nil, "" + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + return nil, "" + } + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + return nil, "" + } + defer resp.Body.Close() + + return resp, string(respBody) +} + +func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) { + r, _ := http.NewRequest(method, path, body) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + return w.Result(), w.Body.String() +} + +type ctxKey struct { + name string +} + +func (k ctxKey) String() string { + return "context value " + k.name +} + +func BenchmarkMux(b *testing.B) { + h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + h6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + mx := NewRouter() + mx.Get("/", h1) + mx.Get("/hi", h2) + mx.Post("/hi-post", h2) // used to benchmark 405 responses + mx.Get("/sup/{id}/and/{this}", h3) + mx.Get("/sup/{id}/{bar:foo}/{this}", h3) + + mx.Group("/sharing/{x}/{hash}", func(mx Router) { + mx.Get("/", h4) // subrouter-1 + mx.Get("/{network}", h5) // subrouter-1 + mx.Get("/twitter", h5) + mx.Group("/direct", func(mx Router) { + mx.Get("/", h6) // subrouter-2 + mx.Get("/download", h6) + }) + }) + + routes := []string{ + "/", + "/hi", + "/hi-post", + "/sup/123/and/this", + "/sup/123/foo/this", + "/sharing/z/aBc", // subrouter-1 + "/sharing/z/aBc/twitter", // subrouter-1 + "/sharing/z/aBc/direct", // subrouter-2 + "/sharing/z/aBc/direct/download", // subrouter-2 + } + + for _, path := range routes { + b.Run("route:"+path, func(b *testing.B) { + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", path, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + mx.ServeHTTP(w, r) + } + }) + } +} diff --git a/web/tree.go b/web/tree.go index b29ec0a7..7f4bd027 100644 --- a/web/tree.go +++ b/web/tree.go @@ -349,9 +349,9 @@ func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern strin } } -func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) { +func (n *node) FindRoute(rctx *RouteContext, method methodTyp, path string) (*node, endpoints, http.Handler) { // Reset the context routing pattern and params - rctx.routePattern = "" + rctx.RoutePattern = "" rctx.routeParams.Keys = rctx.routeParams.Keys[:0] rctx.routeParams.Values = rctx.routeParams.Values[:0] @@ -362,13 +362,13 @@ func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, e } // Record the routing params in the request lifecycle - rctx.urlParams.Keys = append(rctx.urlParams.Keys, rctx.routeParams.Keys...) - rctx.urlParams.Values = append(rctx.urlParams.Values, rctx.routeParams.Values...) + rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...) + rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...) // Record the routing pattern in the request lifecycle if rn.endpoints[method].pattern != "" { - rctx.routePattern = rn.endpoints[method].pattern - rctx.routePatterns = append(rctx.routePatterns, rctx.routePattern) + rctx.RoutePattern = rn.endpoints[method].pattern + rctx.routePatterns = append(rctx.routePatterns, rctx.RoutePattern) } return rn, rn.endpoints, rn.endpoints[method].handler @@ -376,7 +376,7 @@ func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, e // Recursive edge traversal by checking all nodeTyp groups along the way. // It's like searching through a multi-dimensional radix trie. -func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { +func (n *node) findRoute(rctx *RouteContext, method methodTyp, path string) *node { nn := n search := path diff --git a/web/tree_test.go b/web/tree_test.go index edf698eb..c350c9cb 100644 --- a/web/tree_test.go +++ b/web/tree_test.go @@ -72,7 +72,7 @@ func TestTree(t *testing.T) { tr.InsertRoute(mGET, "/hubs/{hubID}/view", hHubView1) tr.InsertRoute(mGET, "/hubs/{hubID}/view/*", hHubView2) - sr := &routerGroup{tree: &node{}} + sr := NewRouter() sr.Get("/users", hHubView3) tr.InsertRoute(mGET, "/hubs/{hubID}/*", sr) tr.InsertRoute(mGET, "/hubs/{hubID}/users", hHubView3) @@ -127,7 +127,7 @@ func TestTree(t *testing.T) { // log.Println("~~~~~~~~~") for i, tt := range tests { - rctx := &Context{} + rctx := &RouteContext{} _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) @@ -243,7 +243,7 @@ func TestTreeMoar(t *testing.T) { // log.Println("~~~~~~~~~") for i, tt := range tests { - rctx := &Context{} + rctx := &RouteContext{} _, handlers, _ := tr.FindRoute(rctx, tt.m, tt.r) @@ -309,7 +309,7 @@ func TestTreeRegexp(t *testing.T) { } for i, tt := range tests { - rctx := &Context{} + rctx := &RouteContext{} _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) @@ -360,7 +360,7 @@ func TestTreeRegexpRecursive(t *testing.T) { } for i, tt := range tests { - rctx := &Context{} + rctx := &RouteContext{} _, handlers, _ := tr.FindRoute(rctx, mGET, tt.r) @@ -387,7 +387,7 @@ func TestTreeRegexpRecursive(t *testing.T) { func TestTreeRegexMatchWholeParam(t *testing.T) { hStub1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - rctx := &Context{} + rctx := &RouteContext{} tr := &node{} tr.InsertRoute(mGET, "/{id:[0-9]+}", hStub1) tr.InsertRoute(mGET, "/{x:.+}/foo", hStub1) @@ -498,11 +498,12 @@ func BenchmarkTreeGet(b *testing.B) { tr.InsertRoute(mGET, "/pinggggg", h2) tr.InsertRoute(mGET, "/hello", h1) + mctx := &RouteContext{} b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - mctx := &Context{} + mctx.Reset() tr.FindRoute(mctx, mGET, "/ping/123/456") } }