Skip to content

Commit

Permalink
Merge branch 'feature/clean' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
nbari committed Oct 22, 2015
2 parents 89ea9fb + 15ec9b8 commit 3cfb333
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 56 deletions.
3 changes: 3 additions & 0 deletions middleware/README.md
@@ -0,0 +1,3 @@
# middleware

Check out [this blog post](http://justinas.org/alice-painless-middleware-chaining-for-go/)
98 changes: 98 additions & 0 deletions middleware/middleware.go
@@ -0,0 +1,98 @@
// https://github.com/justinas/alice
package middleware

import "net/http"

// Constructor pattern for all middleware
type Constructor func(http.Handler) http.Handler

// Chain acts as a list of http.Handler constructors.
type Chain struct {
constructors []Constructor
}

// New creates a new chain
func New(constructors ...Constructor) Chain {
c := Chain{}
c.constructors = append(c.constructors, constructors...)

return c
}

// Then chains the middleware and returns the final http.Handler.
// New(m1, m2, m3).Then(h)
// is equivalent to:
// m1(m2(m3(h)))
// Then() treats nil as http.DefaultServeMux.
func (c Chain) Then(h http.Handler) http.Handler {
var final http.Handler
if h != nil {
final = h
} else {
final = http.DefaultServeMux
}

for i := len(c.constructors) - 1; i >= 0; i-- {
final = c.constructors[i](final)
}

return final
}

// ThenFunc works identically to Then, but takes
// a HandlerFunc instead of a Handler.
//
// The following two statements are equivalent:
// c.Then(http.HandlerFunc(fn))
// c.ThenFunc(fn)
//
// ThenFunc provides all the guarantees of Then.
func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler {
if fn == nil {
return c.Then(nil)
}
return c.Then(http.HandlerFunc(fn))
}

// Append extends a chain, adding the specified constructors
// as the last ones in the request flow.
//
// Append returns a new chain, leaving the original one untouched.
//
// stdChain := middleware.New(m1, m2)
// extChain := stdChain.Append(m3, m4)
// // requests in stdChain go m1 -> m2
// // requests in extChain go m1 -> m2 -> m3 -> m4
func (c Chain) Append(constructors ...Constructor) Chain {
newCons := make([]Constructor, len(c.constructors)+len(constructors))
copy(newCons, c.constructors)
copy(newCons[len(c.constructors):], constructors)

newChain := New(newCons...)
return newChain
}

// Extend extends a chain by adding the specified chain
// as the last one in the request flow.
//
// Extend returns a new chain, leaving the original one untouched.
//
// stdChain := middleware.New(m1, m2)
// ext1Chain := middleware.New(m3, m4)
// ext2Chain := stdChain.Extend(ext1Chain)
// // requests in stdChain go m1 -> m2
// // requests in ext1Chain go m3 -> m4
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
//
// Another example:
// aHtmlAfterNosurf := middleware.New(m2)
// aHtml := middleware.New(m1, func(h http.Handler) http.Handler {
// csrf := nosurf.New(h)
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
// return csrf
// }).Extend(aHtmlAfterNosurf)
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
func (c Chain) Extend(chain Chain) Chain {
return c.Append(chain.constructors...)
}
144 changes: 144 additions & 0 deletions middleware/middleware_test.go
@@ -0,0 +1,144 @@
// Package middleware implements a middleware chaining solution.
package middleware

import (
"net/http"
"net/http/httptest"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

// A constructor for middleware
// that writes its own "tag" into the RW and does nothing else.
// Useful in checking if a chain is behaving in the right order.
func tagMiddleware(tag string) Constructor {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tag))
h.ServeHTTP(w, r)
})
}
}

// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
// but the best we can do.
func funcsEqual(f1, f2 interface{}) bool {
val1 := reflect.ValueOf(f1)
val2 := reflect.ValueOf(f2)
return val1.Pointer() == val2.Pointer()
}

var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("app\n"))
})

// Tests creating a new chain
func TestNew(t *testing.T) {
c1 := func(h http.Handler) http.Handler {
return nil
}
c2 := func(h http.Handler) http.Handler {
return http.StripPrefix("potato", nil)
}

slice := []Constructor{c1, c2}

chain := New(slice...)
assert.True(t, funcsEqual(chain.constructors[0], slice[0]))
assert.True(t, funcsEqual(chain.constructors[1], slice[1]))
}

func TestThenWorksWithNoMiddleware(t *testing.T) {
assert.NotPanics(t, func() {
chain := New()
final := chain.Then(testApp)

assert.True(t, funcsEqual(final, testApp))
})
}

func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
chained := New().Then(nil)
assert.Equal(t, chained, http.DefaultServeMux)
}

func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
chained := New().ThenFunc(nil)
assert.Equal(t, chained, http.DefaultServeMux)
}

func TestThenOrdersHandlersRight(t *testing.T) {
t1 := tagMiddleware("t1\n")
t2 := tagMiddleware("t2\n")
t3 := tagMiddleware("t3\n")

chained := New(t1, t2, t3).Then(testApp)

w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.ServeHTTP(w, r)

assert.Equal(t, w.Body.String(), "t1\nt2\nt3\napp\n")
}

func TestAppendAddsHandlersCorrectly(t *testing.T) {
chain := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))

assert.Equal(t, len(chain.constructors), 2)
assert.Equal(t, len(newChain.constructors), 4)

chained := newChain.Then(testApp)

w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.ServeHTTP(w, r)

assert.Equal(t, w.Body.String(), "t1\nt2\nt3\nt4\napp\n")
}

func TestAppendRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware(""))
newChain := chain.Append(tagMiddleware(""))

assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0])
}

func TestExtendAddsHandlersCorrectly(t *testing.T) {
chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
newChain := chain1.Extend(chain2)

assert.Equal(t, len(chain1.constructors), 2)
assert.Equal(t, len(chain2.constructors), 2)
assert.Equal(t, len(newChain.constructors), 4)

chained := newChain.Then(testApp)

w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.ServeHTTP(w, r)

assert.Equal(t, w.Body.String(), "t1\nt2\nt3\nt4\napp\n")
}

func TestExtendRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware(""))
newChain := chain.Extend(New(tagMiddleware("")))

assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0])
}
2 changes: 1 addition & 1 deletion trie.go
Expand Up @@ -22,7 +22,7 @@ func NewTrie() *Trie {
}

// Set adds a node (url part) to the Trie
func (t *Trie) Set(path []string, handler http.HandlerFunc, method string) error {
func (t *Trie) Set(path []string, handler http.Handler, method string) error {

if len(path) == 0 {
return errors.New("path cannot be empty")
Expand Down
66 changes: 26 additions & 40 deletions violetear.go
Expand Up @@ -25,7 +25,7 @@
//
// func main() {
// router := violetear.New()
// router.LogRequests = true
// router.LogRequests = true
// router.Request_ID = "REQUEST_LOG_ID"
//
// router.AddRegex(":uuid", `[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}`)
Expand All @@ -34,8 +34,6 @@
// router.HandleFunc("/hello/", helloWorld, "GET,HEAD")
// router.HandleFunc("/root/:uuid/item", handleUUID, "POST,PUT")
//
// router.SetHeader("X-app-version", "1.1")
//
// log.Fatal(http.ListenAndServe(":8080", router))
// }
//
Expand All @@ -58,7 +56,7 @@ type Router struct {
// dynamicRoutes map of dynamic routes and regular expresions
dynamicRoutes dynamicSet

// logRequests yes or no
// LogRequests yes or no
LogRequests bool

// NotFoundHandler configurable http.Handler which is called when no matching
Expand All @@ -68,17 +66,14 @@ type Router struct {
// NotAllowedHandler configurable http.Handler which is called when method not allowed.
NotAllowedHandler http.Handler

// PanicHandler function to handle panics.
PanicHandler http.HandlerFunc

// request-id to use
Request_ID string

// extraHeaders adds exta headers to the response
extraHeaders map[string]string

// count counter for hits
count int64

// Verbose
Verbose bool
}

var split_path_rx = regexp.MustCompile(`[^/ ]+`)
Expand All @@ -88,18 +83,11 @@ func New() *Router {
return &Router{
routes: NewTrie(),
dynamicRoutes: make(dynamicSet),
extraHeaders: make(map[string]string),
Verbose: true,
}
}

// SetHeader adds extra headers to the response
func (v *Router) SetHeader(key, value string) {
v.extraHeaders[key] = value
}

// HandleFunc add a route to the router (path, HandlerFunc, methods)
func (v *Router) HandleFunc(path string, handler http.HandlerFunc, http_methods ...string) error {
// Handle registers the handler for the given pattern (path, http.Handler, methods).
func (v *Router) Handle(path string, handler http.Handler, http_methods ...string) error {
path_parts := v.splitPath(path)

// search for dynamic routes
Expand All @@ -117,15 +105,19 @@ func (v *Router) HandleFunc(path string, handler http.HandlerFunc, http_methods
methods = http_methods[0]
}

if v.Verbose {
log.Printf("Adding path: %s [%s]", path, methods)
}
log.Printf("Adding path: %s [%s]", path, methods)

if err := v.routes.Set(path_parts, handler, methods); err != nil {
return err
}
return nil
}

// HandleFunc add a route to the router (path, http.HandlerFunc, methods)
func (v *Router) HandleFunc(path string, handler http.HandlerFunc, http_methods ...string) error {
return v.Handle(path, handler, http_methods...)
}

// AddRegex adds a ":named" regular expression to the dynamicRoutes
func (v *Router) AddRegex(name string, regex string) error {
return v.dynamicRoutes.Set(name, regex)
Expand All @@ -147,6 +139,18 @@ func (v *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&v.count, 1)
lw := NewResponseWriter(w)

// panic handler
defer func() {
if err := recover(); err != nil {
if v.PanicHandler != nil {
v.PanicHandler(w, r)
} else {
http.Error(w, http.StatusText(500), http.StatusInternalServerError)
}
}
}()

// _ path never empty, defaults to ("/")
node, path, leaf, _ := v.routes.Get(v.splitPath(r.URL.Path))

// checkMethod check if method is allowed or not
Expand Down Expand Up @@ -202,27 +206,9 @@ func (v *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Request-ID", rid)
}

// set extra headers
for k, v := range v.extraHeaders {
w.Header().Set(k, v)
}

//h http.Handler
h := match(node, path, leaf)

// panicHandler
defer func() {
if err := recover(); err != nil {
log.Printf("panic: %+v - %s [%s] %v %s",
err,
r.RemoteAddr,
r.URL,
time.Since(start),
rid)
http.Error(w, http.StatusText(500), http.StatusInternalServerError)
}
}()

// dispatch request
h.ServeHTTP(lw, r)

Expand Down

0 comments on commit 3cfb333

Please sign in to comment.