Skip to content

Commit

Permalink
before filter are standard HttpHandleFunc, they don't need to return …
Browse files Browse the repository at this point in the history
…a bool anymore. The handlers stack stops when one of them writes in the responsewriter.
  • Loading branch information
gravityblast committed Nov 17, 2013
1 parent 1d40e2b commit 0b60169
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 44 deletions.
27 changes: 21 additions & 6 deletions app_response_writer.go
Expand Up @@ -10,31 +10,42 @@ type ResponseWriter interface {
GetVar(string) interface{}
AddBeforeWriteHandler(handler func())
StatusCode() int
Written() bool
}

type AppResponseWriter struct {
http.ResponseWriter
WroteBody bool
written bool
statusCode int
env map[string]interface{}
routerEnv *map[string]interface{}
beforeWriteHandlers []func()
}


func (w *AppResponseWriter) beforeWrite() {
for _, handler := range w.beforeWriteHandlers {
handler()
}
}

func (w *AppResponseWriter) Write(data []byte) (n int, err error) {
if !w.WroteBody {
for _, handler := range w.beforeWriteHandlers {
handler()
}
w.WroteBody = true
if !w.written {
w.beforeWrite()
w.written = true
}

return w.ResponseWriter.Write(data)
}

func (w *AppResponseWriter) WriteHeader(statusCode int) {
if !w.written {
w.beforeWrite()
}

w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
w.written = true
}

func (w *AppResponseWriter) StatusCode() int {
Expand All @@ -45,6 +56,10 @@ func (w *AppResponseWriter) SetVar(key string, value interface{}) {
w.env[key] = value
}

func (w *AppResponseWriter) Written() bool {
return w.written
}

func (w *AppResponseWriter) GetVar(key string) interface{} {
// local env
value := w.env[key]
Expand Down
4 changes: 2 additions & 2 deletions app_response_writer_test.go
Expand Up @@ -86,13 +86,13 @@ func TestAppResponseWriter_Write(t *testing.T) {
arw.AddBeforeWriteHandler(handler_1.handler)
arw.AddBeforeWriteHandler(handler_2.handler)

assert.False(t, arw.WroteBody)
assert.False(t, arw.Written())
assert.Equal(t, 0, handler_1.calls)
assert.Equal(t, 0, handler_2.calls)

arw.Write([]byte("foo"))

assert.True(t, arw.WroteBody)
assert.True(t, arw.Written())
assert.Equal(t, 1, handler_1.calls)
assert.Equal(t, 1, handler_2.calls)
assert.Equal(t, []byte("foo"), recorder.Body.Bytes())
Expand Down
18 changes: 4 additions & 14 deletions examples/before-filter/main.go
Expand Up @@ -21,37 +21,27 @@ func pageHandler(w traffic.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Page ID: %s\n", params.Get("id"))
}

func checkApiKey(w traffic.ResponseWriter, r *http.Request) bool {
func checkApiKey(w traffic.ResponseWriter, r *http.Request) {
params := r.URL.Query()
if params.Get("api_key") != "foo" {
w.WriteHeader(http.StatusUnauthorized)
return false
}

return true
}

func checkPrivatePageApiKey(w traffic.ResponseWriter, r *http.Request) bool {
func checkPrivatePageApiKey(w traffic.ResponseWriter, r *http.Request) {
params := r.URL.Query()
if params.Get("private_api_key") != "bar" {
w.WriteHeader(http.StatusUnauthorized)
return false
}

return true
}

func addAppNameHeader(w traffic.ResponseWriter, r *http.Request) bool {
func addAppNameHeader(w traffic.ResponseWriter, r *http.Request) {
w.Header().Add("X-APP-NAME", "My App")

return true
}

func addTimeHeader(w traffic.ResponseWriter, r *http.Request) bool {
func addTimeHeader(w traffic.ResponseWriter, r *http.Request) {
t := fmt.Sprintf("%s", time.Now())
w.Header().Add("X-APP-TIME", t)

return true
}

func main() {
Expand Down
4 changes: 2 additions & 2 deletions route.go
Expand Up @@ -12,10 +12,10 @@ type Route struct {
Path string
PathRegexp *regexp.Regexp
Handler HttpHandleFunc
beforeFilters []BeforeFilterFunc
beforeFilters []HttpHandleFunc
}

func (route *Route) AddBeforeFilter(beforeFilter BeforeFilterFunc) *Route {
func (route *Route) AddBeforeFilter(beforeFilter HttpHandleFunc) *Route {
route.beforeFilters = append(route.beforeFilters, beforeFilter)

return route
Expand Down
4 changes: 2 additions & 2 deletions route_test.go
Expand Up @@ -101,8 +101,8 @@ func TestRoute_Match_WithOptionalSegments(t *testing.T) {
func TestRoute_AddBeforeFilterToRoute(t *testing.T) {
route := NewRoute("/", httpHandlerExample)
assert.Equal(t, 0, len(route.beforeFilters))
filterA := BeforeFilterFunc(func(w ResponseWriter, r *http.Request) bool { return true })
filterB := BeforeFilterFunc(func(w ResponseWriter, r *http.Request) bool { return true })
filterA := HttpHandleFunc(func(w ResponseWriter, r *http.Request) {})
filterB := HttpHandleFunc(func(w ResponseWriter, r *http.Request) {})

route.AddBeforeFilter(filterA)
assert.Equal(t, 1, len(route.beforeFilters))
Expand Down
9 changes: 4 additions & 5 deletions router.go
Expand Up @@ -11,7 +11,6 @@ import (

type HttpMethod string

type BeforeFilterFunc func(ResponseWriter, *http.Request) bool
type ErrorHandlerFunc func(ResponseWriter, *http.Request, interface{})

type NextMiddlewareFunc func() Middleware
Expand All @@ -24,7 +23,7 @@ type Router struct {
NotFoundHandler HttpHandleFunc
ErrorHandler ErrorHandlerFunc
routes map[HttpMethod][]*Route
beforeFilters []BeforeFilterFunc
beforeFilters []HttpHandleFunc
middlewares []Middleware
env map[string]interface{}
}
Expand Down Expand Up @@ -78,7 +77,7 @@ func (router *Router) Patch(path string, handler HttpHandleFunc) *Route {
return router.Add(HttpMethod("PATCH"), path, handler)
}

func (router *Router) AddBeforeFilter(beforeFilter BeforeFilterFunc) *Router {
func (router *Router) AddBeforeFilter(beforeFilter HttpHandleFunc) *Router {
router.beforeFilters = append(router.beforeFilters, beforeFilter)

return router
Expand Down Expand Up @@ -123,7 +122,7 @@ func (router *Router) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
nextMiddleware.ServeHTTP(w, r, nextMiddlewareFunc)
}

if w.StatusCode() == http.StatusNotFound && !w.WroteBody {
if w.StatusCode() == http.StatusNotFound && !w.Written() {
router.handleNotFound(w, r)
}
}
Expand Down Expand Up @@ -198,7 +197,7 @@ func init() {
func New() *Router {
router := &Router{
routes: make(map[HttpMethod][]*Route),
beforeFilters: make([]BeforeFilterFunc, 0),
beforeFilters: make([]HttpHandleFunc, 0),
middlewares: make([]Middleware, 0),
env: make(map[string]interface{}),
}
Expand Down
15 changes: 5 additions & 10 deletions router_middleware.go
Expand Up @@ -19,21 +19,16 @@ func (routerMiddleware *RouterMiddleware) ServeHTTP(w ResponseWriter, r *http.Re

r.URL.RawQuery = newValues.Encode()

continueAfterBeforeFilter := true
handlers := append(routerMiddleware.router.beforeFilters, route.beforeFilters...)
handlers = append(handlers, route.Handler)

filters := append(routerMiddleware.router.beforeFilters, route.beforeFilters...)

for _, beforeFilter := range filters {
continueAfterBeforeFilter = beforeFilter(w, r)
if !continueAfterBeforeFilter {
for _, handler := range handlers {
handler(w, r)
if w.Written() {
break
}
}

if continueAfterBeforeFilter {
route.Handler(w, r)
}

return w, r
}
}
Expand Down
2 changes: 1 addition & 1 deletion router_middleware_test.go
Expand Up @@ -21,7 +21,7 @@ func newTestRequest(method, path string) (ResponseWriter, *httptest.ResponseReco
func newTestRouterMiddleware() *RouterMiddleware {
router := &Router{}
router.routes = make(map[HttpMethod][]*Route)
router.beforeFilters = make([]BeforeFilterFunc, 0)
router.beforeFilters = make([]HttpHandleFunc, 0)
router.middlewares = make([]Middleware, 0)
routerMiddleware := &RouterMiddleware{ router }

Expand Down
4 changes: 2 additions & 2 deletions router_test.go
Expand Up @@ -79,8 +79,8 @@ func TestRouter_AddBeforeFilter(t *testing.T) {
router := New()
assert.Equal(t, 0, len(router.beforeFilters))

filterA := BeforeFilterFunc(func(w ResponseWriter, r *http.Request) bool { return true })
filterB := BeforeFilterFunc(func(w ResponseWriter, r *http.Request) bool { return true })
filterA := HttpHandleFunc(func(w ResponseWriter, r *http.Request) {})
filterB := HttpHandleFunc(func(w ResponseWriter, r *http.Request) {})

router.AddBeforeFilter(filterA)
assert.Equal(t, 1, len(router.beforeFilters))
Expand Down

0 comments on commit 0b60169

Please sign in to comment.