Skip to content

Commit

Permalink
Middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
jwilner committed Mar 5, 2019
1 parent d2b4d44 commit ee7e6ca
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 37 deletions.
4 changes: 2 additions & 2 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,13 @@ func ExampleDefaultMethod() {
// PRETEND /foo not allowed
}

func ExampleGlobalMiddleware() {
func ExampleWrap() {
// applied to the one
m1 := stringMW("and this is m1")
// applied to both
m2 := stringMW("this is m2")

tbl := rte.Must(rte.GlobalMiddleware(m2, rte.Routes(
tbl := rte.Must(rte.Wrap(m2, rte.Routes(
"GET /", func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "handling GET /\n")
}, m1,
Expand Down
58 changes: 44 additions & 14 deletions helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
// - Route
// - []Route
// - "PATH", []Route (identical to rte.Prefix("PATH", routes))
// - "PATH", []Route, middleware (identical to rte.GlobalMiddleware(rte.Prefix("PATH", routes), middleware))
// - "PATH", []Route, middleware (identical to rte.Wrap(rte.Prefix("PATH", routes), middleware))
func Routes(is ...interface{}) []Route {
var routes []Route

Expand Down Expand Up @@ -94,7 +94,7 @@ func Routes(is ...interface{}) []Route {

if idxMW := idxHandler + 1; idxMW < len(is) {
if mw, ok := is[idxMW].(Middleware); ok {
routes = append(routes, GlobalMiddleware(mw, newRoutes)...)
routes = append(routes, Wrap(mw, newRoutes)...)
idxReqLine = idxMW + 1
continue
}
Expand Down Expand Up @@ -203,27 +203,57 @@ func DefaultMethod(hndlr interface{}, routes []Route) []Route {
return copied
}

// GlobalMiddleware registers a middleware across all provide routes. If a middleware is already set,
// that middleware will be invoked second.
func GlobalMiddleware(mw Middleware, routes []Route) []Route {
// Wrap registers a middleware across all provide routes. If a middleware is already set, that middleware will be
// invoked second.
func Wrap(mw Middleware, routes []Route) []Route {
var copied []Route
for _, r := range routes {
r.Middleware = composeMiddleware(mw, r.Middleware)
if r.Middleware != nil {
r.Middleware = Compose(mw, r.Middleware)
} else {
r.Middleware = mw
}
copied = append(copied, r)
}
return copied
}

func composeMiddleware(mw1, mw2 Middleware) Middleware {
if mw1 == nil {
return mw2
// Compose combines one or more middlewares into a single middleware. The composed middleware will proceed left to right
// through the middleware (and exit right to left).
func Compose(mw Middleware, mws ...Middleware) Middleware {
mws = append([]Middleware{mw}, mws...)
mw = mws[len(mws)-1]
for i := len(mws) - 2; i >= 0; i-- {
mw1, mw2 := mws[i], mw
mw = MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
mw1.Handle(w, r, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mw2.Handle(w, r, next)
}))
})
}
if mw2 == nil {
return mw1
return mw
}

// RecoveryMiddleware returns a middleware which converts any panics into 500 status http errors and stops the panic. If
// a non-nil log is provided, any panic will be logged.
func RecoveryMiddleware(log interface{ Println(...interface{}) }) Middleware {
if log == nil {
return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
defer func() {
if p := recover(); p != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
return MiddlewareFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
mw1.Handle(w, r, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mw2.Handle(w, r, next)
}))
defer func() {
if p := recover(); p != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Println(p)
}
}()
next.ServeHTTP(w, r)
})
}
99 changes: 78 additions & 21 deletions helper_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package rte_test

import (
"bytes"
"fmt"
"github.com/jwilner/rte"
"log"
"net/http"
"net/http/httptest"
"reflect"
Expand Down Expand Up @@ -136,31 +138,13 @@ func (s stringMW) Handle(w http.ResponseWriter, r *http.Request, next http.Handl
func TestGlobalMiddleware(t *testing.T) {
mw1 := mockMW(true)
t.Run("empty", func(t *testing.T) {
rts := rte.GlobalMiddleware(nil, nil)
rts := rte.Wrap(mw1, nil)
if len(rts) != 0 {
t.Errorf("Wanted no routes returned")
}
})
t.Run("nilPassed", func(t *testing.T) {
rts := rte.GlobalMiddleware(nil, []rte.Route{
{Method: "GET", Path: "/"},
})
want := []rte.Route{{Method: "GET", Path: "/"}}
if !reflect.DeepEqual(rts, want) {
t.Errorf("Wanted %v but got %v", want, rts)
}
})
t.Run("nilPassedMwPresent", func(t *testing.T) {
rts := rte.GlobalMiddleware(nil, []rte.Route{
{Method: "GET", Path: "/", Middleware: mw1},
})
want := []rte.Route{{Method: "GET", Path: "/", Middleware: mw1}}
if !reflect.DeepEqual(rts, want) {
t.Errorf("Wanted %v but got %v", want, rts)
}
})
t.Run("setsMW", func(t *testing.T) {
rts := rte.GlobalMiddleware(mw1, []rte.Route{
rts := rte.Wrap(mw1, []rte.Route{
{Method: "GET", Path: "/"},
})
want := []rte.Route{{Method: "GET", Path: "/", Middleware: mw1}}
Expand All @@ -169,7 +153,7 @@ func TestGlobalMiddleware(t *testing.T) {
}
})
t.Run("composes", func(t *testing.T) {
tbl := rte.Must(rte.GlobalMiddleware(stringMW("hi"), []rte.Route{
tbl := rte.Must(rte.Wrap(stringMW("hi"), []rte.Route{
{
Method: "GET",
Path: "/",
Expand Down Expand Up @@ -374,3 +358,76 @@ func TestRoutes(t *testing.T) {
})
}
}

func TestCompose(t *testing.T) {
getBody := func(mw rte.Middleware) string {
w := httptest.NewRecorder()
mw.Handle(
w,
httptest.NewRequest("GET", "/", nil),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
return w.Body.String()
}

t.Run("one", func(t *testing.T) {
if r := getBody(rte.Compose(stringMW("1"))); r != "1\n" {
t.Fatalf("Wanted \"1\n\" but got %v", r)
}
})
t.Run("two", func(t *testing.T) {
if r := getBody(rte.Compose(stringMW("1"), stringMW("2"))); r != "1\n2\n" {
t.Fatalf("Wanted \"1\n2\n\" but got %v", r)
}
})
t.Run("three", func(t *testing.T) {
if r := getBody(rte.Compose(stringMW("1"), stringMW("2"), stringMW("3"))); r != "1\n2\n3\n" {
t.Fatalf("Wanted \"1\n2\n3\n\" but got %v", r)
}
})
}

func TestRecoveryMiddleware(t *testing.T) {
panicky := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
panic("whoa")
})
noPanic := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})

getCode := func(mw rte.Middleware, h http.Handler) int {
w := httptest.NewRecorder()
mw.Handle(w, httptest.NewRequest("GET", "/", nil), h)
return w.Code
}

t.Run("nilNone", func(t *testing.T) {
if code := getCode(rte.RecoveryMiddleware(nil), noPanic); code != 200 {
t.Fatalf("Expected 200 but got %v", code)
}
})

t.Run("nilPanic", func(t *testing.T) {
if code := getCode(rte.RecoveryMiddleware(nil), panicky); code != 500 {
t.Fatalf("Expected 500 but got %v", code)
}
})

t.Run("logNone", func(t *testing.T) {
var buf bytes.Buffer
if code := getCode(rte.RecoveryMiddleware(log.New(&buf, "", 0)), noPanic); code != 200 {
t.Fatalf("Expected 200 but got %v", code)
}
if buf.Len() != 0 {
t.Fatalf("Expected no bytes written but got %q", buf.String())
}
})

t.Run("logPanic", func(t *testing.T) {
var buf bytes.Buffer
if code := getCode(rte.RecoveryMiddleware(log.New(&buf, "", 0)), panicky); code != 500 {
t.Fatalf("Expected 500 but got %v", code)
}
if buf.String() != "whoa\n" {
t.Fatalf("Expected \"whoa\n\" written but got %q", buf.String())
}
})
}

0 comments on commit ee7e6ca

Please sign in to comment.