Skip to content

Commit

Permalink
Merge pull request #68 from HotelsDotCom/master
Browse files Browse the repository at this point in the history
added interceptor option
  • Loading branch information
husobee committed May 25, 2017
2 parents 8ad6f81 + 1d7d0b6 commit bc768a4
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 26 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,30 @@ func GeneralHandler(w http.ResponseWriter, r *http.Request) {

```

## Interceptors

Router supports optional interceptors (vestigo provides only interceptor interface, it is up to the user to create one).
These can be either set at global level (all requests go through these):

```go

router := vestigo.NewRouter(authInterceptor, accessLogInterceptor)

```

Or per route:

```go

router.Get("/welcome", GetWelcomeHandler, accessLogInterceptor)

```

Interceptor interface has three methods, `Before() bool`, `After() bool` (which specify if the interceptor should be
called before or after handler call) and `Intercept(w http.ResponseWriter, r *http.Request) bool`. This method returns
true, if the execution (of either handler, or chained interceptors) should continue.


## App Performance with net/http/pprof

It is often very helpful to view profiling information from your web application.
Expand Down
13 changes: 13 additions & 0 deletions interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package vestigo

import "net/http"

type Interceptor interface {
// returns true if the interceptor should run before handler
Before() bool
// returns true if the interceptor should run after handler
After() bool
// the actual intercept function, returns true if the request should continue to handler and/or
// chained interceptors, false if the execution should terminate
Intercept(w http.ResponseWriter, r *http.Request) bool
}
27 changes: 27 additions & 0 deletions interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package vestigo

import "net/http"

type MockInterceptor struct {
before bool
intercept bool
after bool
CalledIntercept int
}

func (m *MockInterceptor) Before() bool {
return m.before
}

func (m *MockInterceptor) After() bool {
return m.after
}

func (m *MockInterceptor) Intercept(w http.ResponseWriter, r *http.Request) bool {
m.CalledIntercept += 1
return m.intercept
}

func NewMockInterceptor(before, intercept, after bool) *MockInterceptor {
return &MockInterceptor{before: before, intercept: intercept, after: after}
}
80 changes: 55 additions & 25 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ type (

// Router - The main vestigo router data structure
type Router struct {
root *node
globalCors *CorsAccessControl
root *node
globalCors *CorsAccessControl
interceptors []Interceptor
}

// NewRouter - Create a new vestigo router
func NewRouter() *Router {
// NewRouter - Create a new vestigo router with optional global interceptors
func NewRouter(interceptors ...Interceptor) *Router {
return &Router{
root: &node{
resource: newResource(),
},
interceptors: interceptors,
}
}

Expand Down Expand Up @@ -64,57 +66,57 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

// Get - Helper method to add HTTP GET Method to router
func (r *Router) Get(path string, handler http.HandlerFunc) {
r.Add(http.MethodGet, path, handler)
func (r *Router) Get(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
r.Add(http.MethodGet, path, handler, interceptors...)
}

// Post - Helper method to add HTTP POST Method to router
func (r *Router) Post(path string, handler http.HandlerFunc) {
r.Add(http.MethodPost, path, handler)
func (r *Router) Post(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
r.Add(http.MethodPost, path, handler, interceptors...)
}

// Connect - Helper method to add HTTP CONNECT Method to router
func (r *Router) Connect(path string, handler http.HandlerFunc) {
r.Add(http.MethodConnect, path, handler)
func (r *Router) Connect(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
r.Add(http.MethodConnect, path, handler, interceptors...)
}

// Delete - Helper method to add HTTP DELETE Method to router
func (r *Router) Delete(path string, handler http.HandlerFunc) {
r.Add(http.MethodDelete, path, handler)
func (r *Router) Delete(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
r.Add(http.MethodDelete, path, handler, interceptors...)
}

// Patch - Helper method to add HTTP PATCH Method to router
func (r *Router) Patch(path string, handler http.HandlerFunc) {
r.Add(http.MethodPatch, path, handler)
func (r *Router) Patch(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
r.Add(http.MethodPatch, path, handler, interceptors...)
}

// Put - Helper method to add HTTP PUT Method to router
func (r *Router) Put(path string, handler http.HandlerFunc) {
r.Add(http.MethodPut, path, handler)
func (r *Router) Put(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
r.Add(http.MethodPut, path, handler, interceptors...)
}

// Trace - Helper method to add HTTP TRACE Method to router
func (r *Router) Trace(path string, handler http.HandlerFunc) {
r.Add(http.MethodTrace, path, handler)
func (r *Router) Trace(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
r.Add(http.MethodTrace, path, handler, interceptors...)
}

// Handle - Helper method to add all HTTP Methods to router
func (r *Router) Handle(path string, handler http.Handler) {
func (r *Router) Handle(path string, handler http.Handler, interceptors ...Interceptor) {
for k := range methods {
if k == http.MethodHead || k == http.MethodOptions || k == http.MethodTrace {
continue
}
r.Add(k, path, handler.ServeHTTP)
r.Add(k, path, handler.ServeHTTP, interceptors...)
}
}

// HandleFunc - Helper method to add all HTTP Methods to router
func (r *Router) HandleFunc(path string, handler http.HandlerFunc) {
func (r *Router) HandleFunc(path string, handler http.HandlerFunc, interceptors ...Interceptor) {
for k := range methods {
if k == http.MethodHead || k == http.MethodOptions || k == http.MethodTrace {
continue
}
r.Add(k, path, handler.ServeHTTP)
r.Add(k, path, handler.ServeHTTP, interceptors...)
}
}

Expand All @@ -124,12 +126,13 @@ func (r *Router) addWithCors(method, path string, h http.HandlerFunc, cors *Cors
}

// Add - Add a method/handler combination to the router
func (r *Router) Add(method, path string, h http.HandlerFunc) {
r.add(method, path, h, nil)
func (r *Router) Add(method, path string, h http.HandlerFunc, interceptors ...Interceptor) {
r.add(method, path, h,nil, interceptors...)
}

// Add - Add a method/handler combination to the router
func (r *Router) add(method, path string, h http.HandlerFunc, cors *CorsAccessControl) {
func (r *Router) add(method, path string, h http.HandlerFunc, cors *CorsAccessControl, interceptors ...Interceptor) {
h = r.interceptHandlerFunc(h, interceptors)
pnames := make(pNames)
pnames[method] = []string{}

Expand Down Expand Up @@ -469,3 +472,30 @@ func (r *Router) insert(method, path string, h http.HandlerFunc, t ntype, pnames
return
}
}

func (r *Router) interceptHandlerFunc(handler http.HandlerFunc, interceptors []Interceptor) http.HandlerFunc {

return func(w http.ResponseWriter, req *http.Request) {
// merge global and handler specific interceptors - handler specific run first
is := append(interceptors, r.interceptors...)
// before
for _, v := range is {
if v.Before() {
if !v.Intercept(w, req) {
return
}
}
}

handler(w, req)

// after
for _, v := range is {
if v.After() {
if !v.Intercept(w, req) {
return
}
}
}
}
}
123 changes: 122 additions & 1 deletion router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,127 @@ func TestRouterParam(t *testing.T) {
}
}

func TestRouter_AddWithGlobalInterceptor(t *testing.T) {

cases := []struct{
interceptor *MockInterceptor
expectedIntercept int
handlerCalled bool
}{
// interceptor gets called twice, before and after handler. handler gets called
{NewMockInterceptor(true, true, true), 2, true},
// interceptor gets called once, after handler. handler gets called
{NewMockInterceptor(false, true, true), 1, true},
// interceptor gets called once, before handler, breaks execution and does not get executed after
{NewMockInterceptor(true, false, true), 1, false},
// interceptor gets called once, before handler
{NewMockInterceptor(true, true, false), 1, true},
// interceptor gets called once, after handler
{NewMockInterceptor(false, false, true), 1, true},
}

for index, c := range cases {

i := c.interceptor
r := NewRouter(i)
handlerCalled := false
r.Add("GET", "/", func(w http.ResponseWriter, r *http.Request){handlerCalled = true})

h := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
r.ServeHTTP(h, req)
msg := fmt.Sprintf("global interceptor - case index: %d", index)
assert.Equal(t, c.expectedIntercept, i.CalledIntercept, msg)
assert.Equal(t, c.handlerCalled, handlerCalled, msg)
}
}

func TestRouter_AddWithInterceptor(t *testing.T) {

cases := []struct{
interceptor *MockInterceptor
expectedIntercept int
handlerCalled bool
}{
// interceptor gets called twice, before and after handler. handler gets called
{NewMockInterceptor(true, true, true), 2, true},
// interceptor gets called once, after handler. handler gets called
{NewMockInterceptor(false, true, true), 1, true},
// interceptor gets called once, before handler, breaks execution and does not get executed after
{NewMockInterceptor(true, false, true), 1, false},
// interceptor gets called once, before handler
{NewMockInterceptor(true, true, false), 1, true},
// interceptor gets called once, after handler
{NewMockInterceptor(false, false, true), 1, true},
}

for index, c := range cases {

i := c.interceptor
r := NewRouter()
handlerCalled := false
r.Get("/", func(w http.ResponseWriter, r *http.Request){handlerCalled = true}, i)

h := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
r.ServeHTTP(h, req)
msg := fmt.Sprintf("interceptor - case index: %d", index)
assert.Equal(t, c.expectedIntercept, i.CalledIntercept, msg)
assert.Equal(t, c.handlerCalled, handlerCalled, msg)
}
}

func TestRouter_AddWithGlobalAndPerRouteInterceptor(t *testing.T) {

cases := []struct{
globalInterceptor *MockInterceptor
interceptor *MockInterceptor
handlerCalled bool
}{
{
NewMockInterceptor(true, true, true),
NewMockInterceptor(true, true, true),
true,
},
{
NewMockInterceptor(true, false, true),
NewMockInterceptor(true, true, true),
false,
},
{
NewMockInterceptor(true, true, true),
NewMockInterceptor(true, false, true),
false,
},
{
NewMockInterceptor(true, false, true),
NewMockInterceptor(true, false, true),
false,
},
{
NewMockInterceptor(false, false, true),
NewMockInterceptor(false, false, true),
true,
},
}

for index, c := range cases {

for _, method := range []string{"GET", "POST", "PUT", "PATCH", "DELETE", "CONNECT", "TRACE"} {

r := NewRouter(c.globalInterceptor)
handlerCalled := false
r.Add(method,"/", func(w http.ResponseWriter, r *http.Request){handlerCalled = true}, c.interceptor)

h := httptest.NewRecorder()
req, _ := http.NewRequest(method, "/", nil)
r.ServeHTTP(h, req)
msg := fmt.Sprintf("interceptor method %s - case index: %d", method, index)
assert.Equal(t, c.handlerCalled, handlerCalled, msg)
}
}
}

func TestRouterTwoParam(t *testing.T) {
r := NewRouter()
r.Add("GET", "/users/:uid/files/:fid", func(w http.ResponseWriter, r *http.Request) {})
Expand Down Expand Up @@ -573,7 +694,7 @@ func TestRouterAddInvalidMethod(t *testing.T) {

func TestMethodSpecificAddRoute(t *testing.T) {
router := NewRouter()
m := map[string]func(path string, handler http.HandlerFunc){
m := map[string]func(path string, handler http.HandlerFunc, interceptors ...Interceptor){
"GET": router.Get,
"POST": router.Post,
"CONNECT": router.Connect,
Expand Down

0 comments on commit bc768a4

Please sign in to comment.