Skip to content

Commit

Permalink
Added before hook to martini.ResponseWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
codegangsta committed Dec 4, 2013
1 parent f86ef05 commit 789c3d8
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
26 changes: 22 additions & 4 deletions response_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,36 @@ type ResponseWriter interface {
Written() bool
// Size returns the size of the response body.
Size() int
// Before allows for a function to be called before the ResponseWriter has been written to. This is
// useful for setting headers or any other operations that must happen before a response has been written.
Before(BeforeFunc)
}

// BeforeFunc is a function that is called before the ResponseWriter has been written to.
type BeforeFunc func(ResponseWriter)

// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter
func NewResponseWriter(rw http.ResponseWriter) ResponseWriter {
return &responseWriter{rw, 0, 0}
return &responseWriter{rw, 0, 0, nil}
}

type responseWriter struct {
http.ResponseWriter
status int
size int
status int
size int
beforeFuncs []BeforeFunc
}

func (rw *responseWriter) WriteHeader(s int) {
rw.callBefore()
rw.ResponseWriter.WriteHeader(s)
rw.status = s
}

func (rw *responseWriter) Write(b []byte) (int, error) {
if !rw.Written() {
// The status will be StatusOK if WriteHeader has not been called yet
rw.status = http.StatusOK
rw.WriteHeader(http.StatusOK)
}
size, err := rw.ResponseWriter.Write(b)
rw.size += size
Expand All @@ -58,10 +66,20 @@ func (rw *responseWriter) Written() bool {
return rw.status != 0
}

func (rw *responseWriter) Before(before BeforeFunc) {
rw.beforeFuncs = append(rw.beforeFuncs, before)
}

func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("ResponseWriter doesn't support Hijacker interface")
}
return hijacker.Hijack()
}

func (rw *responseWriter) callBefore() {
for i := len(rw.beforeFuncs) - 1; i >= 0; i-- {
rw.beforeFuncs[i](rw)
}
}
21 changes: 21 additions & 0 deletions response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,27 @@ func Test_ResponseWriter_WritingHeader(t *testing.T) {
expect(t, rw.Size(), 0)
}

func Test_ResponseWriter_Before(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
result := ""

rw.Before(func(ResponseWriter) {
result += "foo"
})
rw.Before(func(ResponseWriter) {
result += "bar"
})

rw.WriteHeader(http.StatusNotFound)

expect(t, rec.Code, rw.Status())
expect(t, rec.Body.String(), "")
expect(t, rw.Status(), http.StatusNotFound)
expect(t, rw.Size(), 0)
expect(t, result, "barfoo")
}

func Test_ResponseWriter_Hijack(t *testing.T) {
hijackable := newHijackableResponse()
rw := NewResponseWriter(hijackable)
Expand Down

0 comments on commit 789c3d8

Please sign in to comment.