From c61a770e448c3ce5c3ff5cf0a1407e05cc69fb18 Mon Sep 17 00:00:00 2001 From: Nick Palmer Date: Wed, 22 Jun 2016 16:41:16 -0700 Subject: [PATCH] Add a trie based HandlerFactory. --- trie/factory.go | 321 ++++++++++++++++++++++++++++++++++ trie/factory_test.go | 400 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 721 insertions(+) create mode 100644 trie/factory.go create mode 100644 trie/factory_test.go diff --git a/trie/factory.go b/trie/factory.go new file mode 100644 index 0000000..eaf7722 --- /dev/null +++ b/trie/factory.go @@ -0,0 +1,321 @@ +package trie + +import ( + "fmt" + "net/http" + "strings" + + "github.com/nick-codes/routem" + + "golang.org/x/net/context" +) + +// Compile time type assertions +var _ http.Handler = &rootNode{} +var _ routem.HandlerFactory = &factory{} + +type routeInfo struct { + route routem.Route + params map[int]string + handler routem.HandlerFunc +} + +type rootNode struct { + ctx context.Context + errorHandler routem.ErrorHandlerFunc + node +} + +type node struct { + path string + routes map[routem.Method]*routeInfo + children map[string]*node +} + +type factory struct { + ctx context.Context + errorHandler routem.ErrorHandlerFunc +} + +// Constructs a new handler factory which uses a trie data structure +// to quickly look up routes. +// +// All routes will be passed a context +// derived from the context passed to the factory. If no context is +// passed then context.Background() is used as the root context. +// +// If an ErrorHandlerFunc is provided and the route does not have a route +// specific error handler that handler will be called if a route +// returns an error. Otherwise a 500 error will be returned to the client. +func NewHandlerFactory(ctx context.Context, errorHandler routem.ErrorHandlerFunc) routem.HandlerFactory { + if ctx == nil { + ctx = context.Background() + } + return &factory{ + ctx: ctx, + errorHandler: errorHandler, + } +} + +func (f *factory) Handler(routes []routem.Route) (http.Handler, error) { + + if routes == nil || len(routes) == 0 { + return nil, fmt.Errorf("Received no routes") + } + + root := &rootNode{ + ctx: f.ctx, + errorHandler: f.errorHandler, + node: node{ + path: "", + children: make(map[string]*node), + routes: make(map[routem.Method]*routeInfo), + }, + } + + for _, route := range routes { + if route == nil { + return nil, fmt.Errorf("Received a nil route.") + } + + if len(route.Path()) == 0 { + return nil, fmt.Errorf("Received a zero length path.") + } + + if strings.Contains(route.Path(), "//") { + return nil, fmt.Errorf("Route contains an invalid path: %s", route.Path()) + } + + if !strings.HasPrefix(route.Path(), "/") { + return nil, fmt.Errorf("Route does not begin with a slash: %s", route.Path()) + } + + parts := strings.Split(route.Path(), "/") + params := make(map[string]struct{}, len(parts)) + + for _, part := range parts { + if strings.HasPrefix(part, ":") { + _, exists := params[part] + if exists { + return nil, fmt.Errorf("Route has duplicate parameter: %s", part) + } + params[part] = struct{}{} + } + } + + inserted, err := root.insert(parts, route, 0, nil) + + if err != nil { + return nil, err + } + + // This should never happen + if !inserted { + return nil, fmt.Errorf("An unknown error occured.") + } + } + + return root, nil +} + +func newNode(path string) (*node, error) { + var ret *node + var err error + + if strings.HasPrefix(path, ":") { + paramName := strings.TrimPrefix(path, ":") + + if len(paramName) == 0 { + err = fmt.Errorf("Found an un-named parameter: %s", path) + } else { + ret = &node{ + path: ":", + routes: make(map[routem.Method]*routeInfo), + children: make(map[string]*node), + } + } + } else { + ret = &node{ + path: path, + routes: make(map[routem.Method]*routeInfo), + children: make(map[string]*node), + } + } + + return ret, err +} + +func (n *node) insert(parts []string, route routem.Route, depth int, params map[int]string) (bool, error) { + + inserted := false + var err error + + // Is this a parameter segment? + thisPath := parts[0] + if strings.HasPrefix(thisPath, ":") { + if params == nil { + params = make(map[int]string, len(parts)) + } + params[depth] = strings.TrimPrefix(thisPath, ":") + thisPath = ":" + } + + // Does this belong in this sub-tree? + if thisPath == n.path { + + // Is this the path leaf? + if len(parts) == 1 { + + // Do we already have a route here? + for _, method := range route.Methods() { + if n.routes[method] != nil { + err = fmt.Errorf("Duplicate route: %s - %s", route.Path(), n.routes[method].route.Path()) + } + } + + // No? Then do the insert + if err == nil { + // Build the middleware stack + handler := route.Handler() + middlewares := route.Middlewares() + for i := len(middlewares) - 1; i >= 0; i-- { + handler = middlewares[i](handler) + } + + // Remember all the info for the route + info := &routeInfo{ + route: route, + params: params, + handler: handler, + } + + // Set it on the various methods + for _, method := range route.Methods() { + n.routes[method] = info + } + + inserted = true + } + } else { + + // Check if we can insert in any existing children + for _, child := range n.children { + inserted, err = child.insert(parts[1:], route, depth+1, params) + if inserted || err != nil { + break + } + } + + // Okay, then make a new child and insert + if !inserted && err == nil { + newChild, err := newNode(parts[1]) + + if err == nil { + n.children[newChild.path] = newChild + inserted, err = newChild.insert(parts[1:], route, depth+1, params) + } + } + + } + } + + return inserted, err +} + +var routeNotFoundError = routem.NewHTTPError(http.StatusNotFound, fmt.Errorf("No Such Route")) + +func (n *node) find(parts []string, method routem.Method) (*routeInfo, routem.HTTPError) { + var info *routeInfo = nil + var err routem.HTTPError = nil + + // Did we fish our wish? + if n.path == ":" || parts[0] == n.path { + + // Did we run out of parts? + if len(parts) == 1 { + info = n.routes[method] + } else { + + // Search all the children + subParts := parts[1:] + for _, child := range n.children { + info, err = child.find(subParts, method) + + // If we found something return it up the stack + if info != nil { + break + } + } + + } + } + + if info == nil { + err = routeNotFoundError + } + + return info, err +} + +func routeParams(route *routeInfo, parts []string) routem.Params { + var params routem.Params + + if route != nil && route.params != nil { + params = make(routem.Params, len(parts)) + + for index, param := range route.params { + // Route cannot match unless it is of sufficient length, so we are sure + // index < len(parts) at this point. + params[param] = parts[index] + } + } + + return params +} + +func (root *rootNode) ServeHTTP(response http.ResponseWriter, request *http.Request) { + parts := strings.Split(request.URL.Path, "/") + routeInfo, err := root.find(parts, routem.Method(request.Method)) + + timeout := routem.DefaultTimeout + if err == nil { + timeout = routeInfo.route.Timeout() + } + + ctx, cancel := routem.NewRequestContext( + root.ctx, timeout, request, response, + routeParams(routeInfo, parts)) + + defer cancel() + + if err == nil { + complete := make(chan routem.HTTPError) + go func() { + complete <- routeInfo.handler(ctx) + }() + + select { + case <-ctx.Done(): + err = routem.NewHTTPError(408, fmt.Errorf("Request Timed Out!")) + case err = <-complete: + } + } + + if err != nil { + var errErr error + + if routeInfo != nil && routeInfo.route.ErrorHandler() != nil { + errErr = routeInfo.route.ErrorHandler()(err, ctx) + } else if root.errorHandler != nil { + errErr = root.errorHandler(err, ctx) + } else if err == routeNotFoundError { + http.Error(response, fmt.Sprintf("Route Not Found: %s", errErr), http.StatusNotFound) + } else { + errErr = err + } + + if errErr != nil { + http.Error(response, fmt.Sprintf("Internal Server Error: %s", errErr), http.StatusInternalServerError) + } + } +} diff --git a/trie/factory_test.go b/trie/factory_test.go new file mode 100644 index 0000000..84620a0 --- /dev/null +++ b/trie/factory_test.go @@ -0,0 +1,400 @@ +package trie + +import ( + "fmt" + "net/http" + "net/http/httptest" + "time" + + "github.com/nick-codes/routem" + + "golang.org/x/net/context" + + "testing" + + "github.com/stretchr/testify/assert" +) + +type ( + keyType int +) + +const ( + zerothKey keyType = 0 + firstKey keyType = 1 + secondKey keyType = 2 +) + +type testRoute struct { + path string + method []routem.Method + handler routem.HandlerFunc + errorHandler routem.ErrorHandlerFunc +} + +func (t *testRoute) Handler() routem.HandlerFunc { + if t.handler != nil { + return t.handler + } + return func(ctx context.Context) routem.HTTPError { + + if 2 == ctx.Value(secondKey) { + return nil + } + + return routem.NewHTTPError(500, fmt.Errorf("Value 2 not in context.")) + } +} + +func (t *testRoute) Methods() []routem.Method { + if len(t.method) > 0 { + return t.method + } + return routem.GetMethod +} + +func (t *testRoute) Prefix(prefix string) routem.Route { + return &testRoute{path: prefix + t.path} +} + +func (t *testRoute) WithTimeout(time.Duration) routem.RouteConfigurator { + return t +} + +func (t *testRoute) WithErrorHandler(routem.ErrorHandlerFunc) routem.RouteConfigurator { + return t +} + +func (t *testRoute) WithMiddleware(routem.MiddlewareFunc) routem.RouteConfigurator { + return t +} + +func (t *testRoute) WithMiddlewares([]routem.MiddlewareFunc) routem.RouteConfigurator { + return t +} + +func (t *testRoute) Middlewares() []routem.MiddlewareFunc { + return []routem.MiddlewareFunc{ + func(next routem.HandlerFunc) routem.HandlerFunc { + return func(ctx context.Context) routem.HTTPError { + ctx = context.WithValue(ctx, firstKey, 1) + return next(ctx) + } + }, + func(next routem.HandlerFunc) routem.HandlerFunc { + return func(ctx context.Context) routem.HTTPError { + if 1 == ctx.Value(firstKey) { + ctx = context.WithValue(ctx, secondKey, 2) + return next(ctx) + } + return routem.NewHTTPError(500, fmt.Errorf("First value not in context.")) + } + }, + } +} + +func (*testRoute) Timeout() time.Duration { + return routem.DefaultTimeout +} + +func (t *testRoute) ErrorHandler() routem.ErrorHandlerFunc { + return t.errorHandler +} + +func (r *testRoute) Path() string { + return r.path +} + +// Helpers +func assertError(t *testing.T, routes []routem.Route) { + factory := NewHandlerFactory(nil, nil) + + handler, err := factory.Handler(routes) + + t.Logf("%s", err) + + assert.Nil(t, handler) + assert.NotNil(t, err) +} + +func assertSuccess(t *testing.T, routes []routem.Route) { + factory := NewHandlerFactory(nil, nil) + + handler, err := factory.Handler(routes) + + t.Logf("%#v", handler) + + assert.NotNil(t, handler) + assert.Nil(t, err) +} + +// Real tests begin here + +func TestNewHandlerFactory(t *testing.T) { + factory := NewHandlerFactory(nil, nil) + + assert.NotNil(t, factory) +} + +func TestErrorWithNilRoutes(t *testing.T) { + assertError(t, nil) +} + +func TestErrorWithEmptyRoutes(t *testing.T) { + routes := make([]routem.Route, 0) + assertError(t, routes) +} + +func TestErrorWithNilRoute(t *testing.T) { + routes := make([]routem.Route, 1, 1) + assertError(t, routes) +} + +func TestErrorWithDuplicateRoutes(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test"}, + &testRoute{path: "/test"}, + } + assertError(t, routes) +} + +func TestErrorWithDuplicateParamName(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/:name/:name"}, + } + assertError(t, routes) +} + +func TestErrorWithUnnamedParamter(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/:"}, + } + assertError(t, routes) +} + +func TestErrorWithZeroLengthRoute(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: ""}, + } + assertError(t, routes) +} + +func TestErrorWithNonSlashRoute(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "no/slash"}, + } + assertError(t, routes) +} + +func TestErrorWithDuplicateParamRoutesDifferentNames(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/:name"}, + &testRoute{path: "/:id"}, + } + assertError(t, routes) +} + +func TestErrorWithDuplicateParamRoutes(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test/:name"}, + &testRoute{path: "/test/:name"}, + } + assertError(t, routes) +} + +func TestErrorWithDoubleSlashRoute(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/:name//"}, + } + assertError(t, routes) +} + +func TestShallowSuccess(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/name"}, + } + assertSuccess(t, routes) +} + +func TestSuccessRoot(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/"}, + } + assertSuccess(t, routes) +} + +func TestDuplicateRouteDifferentMethods(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test"}, + &testRoute{path: "/test", method: routem.PutMethod}, + } + assertSuccess(t, routes) +} + +func assertServer(t *testing.T, routes []routem.Route, method routem.Method, url string) *httptest.ResponseRecorder { + factory := NewHandlerFactory(context.Background(), nil) + return assertServerFactory(t, factory, routes, method, url) +} + +func assertServerFactory(t *testing.T, factory routem.HandlerFactory, routes []routem.Route, method routem.Method, url string) *httptest.ResponseRecorder { + + handler, err := factory.Handler(routes) + + assert.NotNil(t, handler) + assert.Nil(t, err) + + response := httptest.NewRecorder() + request, err := http.NewRequest(string(method), url, nil) + assert.Nil(t, err) + handler.ServeHTTP(response, request) + + return response +} + +func TestSimpleRouting(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test"}, + } + + assertServer(t, routes, routem.Get, "http://localhost/test") +} + +func TestParamRouting(t *testing.T) { + called := false + routes := []routem.Route{ + &testRoute{ + path: "/test/:id/:name", + handler: func(ctx context.Context) routem.HTTPError { + params := routem.ParamsFromContext(ctx) + assert.NotNil(t, params) + assert.Equal(t, "5", params["id"]) + assert.Equal(t, "bill", params["name"]) + called = true + return nil + }, + }, + } + + response := assertServer(t, routes, routem.Get, "http://localhost/test/5/bill") + + assert.True(t, called) + assert.Equal(t, 200, response.Code) +} + +func TestTimeout(t *testing.T) { + done := make(chan struct{}) + routes := []routem.Route{ + &testRoute{ + path: "/test", + handler: func(ctx context.Context) routem.HTTPError { + time.Sleep(3 * time.Second) + assert.NotNil(t, ctx.Err()) + close(done) + return nil + }, + }, + } + + response := assertServer(t, routes, routem.Get, "http://localhost/test") + assert.Equal(t, 500, response.Code) + // Wait for the check for context error + <-done +} + +func TestRouteNotFoundMethod(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test"}, + } + + response := assertServer(t, routes, routem.Put, "http://localhost/test") + assert.Equal(t, 404, response.Code) +} + +func TestDeepRouteNotFoundMethod(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test/a"}, + &testRoute{path: "/test/b"}, + } + + response := assertServer(t, routes, routem.Put, "http://localhost/test/c") + assert.Equal(t, 404, response.Code) +} + +func TestFactoryErrorHandler(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test"}, + } + factory := NewHandlerFactory(context.Background(), func(err routem.HTTPError, ctx context.Context) error { + response := routem.ResponseWriterFromContext(ctx) + http.Error(response, "Converted to 400", 400) + return nil + }) + response := assertServerFactory(t, factory, routes, routem.Put, "http://localhost/test/c") + assert.Equal(t, 400, response.Code) +} + +func TestErrorWithFactoryErrorHandler(t *testing.T) { + routes := []routem.Route{ + &testRoute{path: "/test"}, + } + factory := NewHandlerFactory(context.Background(), func(err routem.HTTPError, ctx context.Context) error { + return fmt.Errorf("Error handling an error: %v", err) + }) + response := assertServerFactory(t, factory, routes, routem.Put, "http://localhost/test/c") + assert.Equal(t, 500, response.Code) +} + +func TestRouteErrorHandler(t *testing.T) { + routes := []routem.Route{ + &testRoute{ + path: "/test", + handler: func(ctx context.Context) routem.HTTPError { + return routem.NewHTTPError(304, fmt.Errorf("Moved Permanently!")) + }, + errorHandler: func(err routem.HTTPError, ctx context.Context) error { + response := routem.ResponseWriterFromContext(ctx) + http.Error(response, "Converted to 400", 400) + return nil + }, + }, + } + + response := assertServer(t, routes, routem.Get, "http://localhost/test") + assert.Equal(t, 400, response.Code) +} + +func TestErrorRouteErrorHandler(t *testing.T) { + routes := []routem.Route{ + &testRoute{ + path: "/test", + handler: func(ctx context.Context) routem.HTTPError { + return routem.NewHTTPError(304, fmt.Errorf("Moved Permanently!")) + }, + errorHandler: func(err routem.HTTPError, ctx context.Context) error { + return fmt.Errorf("Error handling an error: %v", err) + }, + }, + } + + response := assertServer(t, routes, routem.Get, "http://localhost/test") + assert.Equal(t, 500, response.Code) +} + +func TestRootContextPassed(t *testing.T) { + called := false + routes := []routem.Route{ + &testRoute{ + path: "/test", + handler: func(ctx context.Context) routem.HTTPError { + assert.Equal(t, 0, ctx.Value(zerothKey)) + called = true + return nil + }, + }, + } + + ctx := context.WithValue(context.Background(), zerothKey, 0) + factory := NewHandlerFactory(ctx, nil) + assertServerFactory(t, factory, routes, routem.Get, "http://localhost/test") + assert.True(t, called) +}