Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'feature/clean' into develop
- Loading branch information
Showing
6 changed files
with
294 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# middleware | ||
|
||
Check out [this blog post](http://justinas.org/alice-painless-middleware-chaining-for-go/) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
// https://github.com/justinas/alice | ||
package middleware | ||
|
||
import "net/http" | ||
|
||
// Constructor pattern for all middleware | ||
type Constructor func(http.Handler) http.Handler | ||
|
||
// Chain acts as a list of http.Handler constructors. | ||
type Chain struct { | ||
constructors []Constructor | ||
} | ||
|
||
// New creates a new chain | ||
func New(constructors ...Constructor) Chain { | ||
c := Chain{} | ||
c.constructors = append(c.constructors, constructors...) | ||
|
||
return c | ||
} | ||
|
||
// Then chains the middleware and returns the final http.Handler. | ||
// New(m1, m2, m3).Then(h) | ||
// is equivalent to: | ||
// m1(m2(m3(h))) | ||
// Then() treats nil as http.DefaultServeMux. | ||
func (c Chain) Then(h http.Handler) http.Handler { | ||
var final http.Handler | ||
if h != nil { | ||
final = h | ||
} else { | ||
final = http.DefaultServeMux | ||
} | ||
|
||
for i := len(c.constructors) - 1; i >= 0; i-- { | ||
final = c.constructors[i](final) | ||
} | ||
|
||
return final | ||
} | ||
|
||
// ThenFunc works identically to Then, but takes | ||
// a HandlerFunc instead of a Handler. | ||
// | ||
// The following two statements are equivalent: | ||
// c.Then(http.HandlerFunc(fn)) | ||
// c.ThenFunc(fn) | ||
// | ||
// ThenFunc provides all the guarantees of Then. | ||
func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler { | ||
if fn == nil { | ||
return c.Then(nil) | ||
} | ||
return c.Then(http.HandlerFunc(fn)) | ||
} | ||
|
||
// Append extends a chain, adding the specified constructors | ||
// as the last ones in the request flow. | ||
// | ||
// Append returns a new chain, leaving the original one untouched. | ||
// | ||
// stdChain := middleware.New(m1, m2) | ||
// extChain := stdChain.Append(m3, m4) | ||
// // requests in stdChain go m1 -> m2 | ||
// // requests in extChain go m1 -> m2 -> m3 -> m4 | ||
func (c Chain) Append(constructors ...Constructor) Chain { | ||
newCons := make([]Constructor, len(c.constructors)+len(constructors)) | ||
copy(newCons, c.constructors) | ||
copy(newCons[len(c.constructors):], constructors) | ||
|
||
newChain := New(newCons...) | ||
return newChain | ||
} | ||
|
||
// Extend extends a chain by adding the specified chain | ||
// as the last one in the request flow. | ||
// | ||
// Extend returns a new chain, leaving the original one untouched. | ||
// | ||
// stdChain := middleware.New(m1, m2) | ||
// ext1Chain := middleware.New(m3, m4) | ||
// ext2Chain := stdChain.Extend(ext1Chain) | ||
// // requests in stdChain go m1 -> m2 | ||
// // requests in ext1Chain go m3 -> m4 | ||
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4 | ||
// | ||
// Another example: | ||
// aHtmlAfterNosurf := middleware.New(m2) | ||
// aHtml := middleware.New(m1, func(h http.Handler) http.Handler { | ||
// csrf := nosurf.New(h) | ||
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail)) | ||
// return csrf | ||
// }).Extend(aHtmlAfterNosurf) | ||
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler | ||
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail | ||
func (c Chain) Extend(chain Chain) Chain { | ||
return c.Append(chain.constructors...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
// Package middleware implements a middleware chaining solution. | ||
package middleware | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"reflect" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
// A constructor for middleware | ||
// that writes its own "tag" into the RW and does nothing else. | ||
// Useful in checking if a chain is behaving in the right order. | ||
func tagMiddleware(tag string) Constructor { | ||
return func(h http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.Write([]byte(tag)) | ||
h.ServeHTTP(w, r) | ||
}) | ||
} | ||
} | ||
|
||
// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer), | ||
// but the best we can do. | ||
func funcsEqual(f1, f2 interface{}) bool { | ||
val1 := reflect.ValueOf(f1) | ||
val2 := reflect.ValueOf(f2) | ||
return val1.Pointer() == val2.Pointer() | ||
} | ||
|
||
var testApp = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.Write([]byte("app\n")) | ||
}) | ||
|
||
// Tests creating a new chain | ||
func TestNew(t *testing.T) { | ||
c1 := func(h http.Handler) http.Handler { | ||
return nil | ||
} | ||
c2 := func(h http.Handler) http.Handler { | ||
return http.StripPrefix("potato", nil) | ||
} | ||
|
||
slice := []Constructor{c1, c2} | ||
|
||
chain := New(slice...) | ||
assert.True(t, funcsEqual(chain.constructors[0], slice[0])) | ||
assert.True(t, funcsEqual(chain.constructors[1], slice[1])) | ||
} | ||
|
||
func TestThenWorksWithNoMiddleware(t *testing.T) { | ||
assert.NotPanics(t, func() { | ||
chain := New() | ||
final := chain.Then(testApp) | ||
|
||
assert.True(t, funcsEqual(final, testApp)) | ||
}) | ||
} | ||
|
||
func TestThenTreatsNilAsDefaultServeMux(t *testing.T) { | ||
chained := New().Then(nil) | ||
assert.Equal(t, chained, http.DefaultServeMux) | ||
} | ||
|
||
func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) { | ||
chained := New().ThenFunc(nil) | ||
assert.Equal(t, chained, http.DefaultServeMux) | ||
} | ||
|
||
func TestThenOrdersHandlersRight(t *testing.T) { | ||
t1 := tagMiddleware("t1\n") | ||
t2 := tagMiddleware("t2\n") | ||
t3 := tagMiddleware("t3\n") | ||
|
||
chained := New(t1, t2, t3).Then(testApp) | ||
|
||
w := httptest.NewRecorder() | ||
r, err := http.NewRequest("GET", "/", nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
chained.ServeHTTP(w, r) | ||
|
||
assert.Equal(t, w.Body.String(), "t1\nt2\nt3\napp\n") | ||
} | ||
|
||
func TestAppendAddsHandlersCorrectly(t *testing.T) { | ||
chain := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) | ||
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n")) | ||
|
||
assert.Equal(t, len(chain.constructors), 2) | ||
assert.Equal(t, len(newChain.constructors), 4) | ||
|
||
chained := newChain.Then(testApp) | ||
|
||
w := httptest.NewRecorder() | ||
r, err := http.NewRequest("GET", "/", nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
chained.ServeHTTP(w, r) | ||
|
||
assert.Equal(t, w.Body.String(), "t1\nt2\nt3\nt4\napp\n") | ||
} | ||
|
||
func TestAppendRespectsImmutability(t *testing.T) { | ||
chain := New(tagMiddleware("")) | ||
newChain := chain.Append(tagMiddleware("")) | ||
|
||
assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0]) | ||
} | ||
|
||
func TestExtendAddsHandlersCorrectly(t *testing.T) { | ||
chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n")) | ||
chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")) | ||
newChain := chain1.Extend(chain2) | ||
|
||
assert.Equal(t, len(chain1.constructors), 2) | ||
assert.Equal(t, len(chain2.constructors), 2) | ||
assert.Equal(t, len(newChain.constructors), 4) | ||
|
||
chained := newChain.Then(testApp) | ||
|
||
w := httptest.NewRecorder() | ||
r, err := http.NewRequest("GET", "/", nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
chained.ServeHTTP(w, r) | ||
|
||
assert.Equal(t, w.Body.String(), "t1\nt2\nt3\nt4\napp\n") | ||
} | ||
|
||
func TestExtendRespectsImmutability(t *testing.T) { | ||
chain := New(tagMiddleware("")) | ||
newChain := chain.Extend(New(tagMiddleware(""))) | ||
|
||
assert.NotEqual(t, &chain.constructors[0], &newChain.constructors[0]) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.