Skip to content

Commit

Permalink
add httpbuffer middleware (#403)
Browse files Browse the repository at this point in the history
* add httpbuffer middleware

* fix deps
  • Loading branch information
matthewmueller committed Apr 17, 2023
1 parent 3f67df9 commit a7c8911
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/cespare/xxhash v1.1.0
github.com/evanw/esbuild v0.14.11
github.com/fatih/structtag v1.2.0
github.com/felixge/httpsnoop v1.0.3
github.com/fsnotify/fsnotify v1.5.1
github.com/gitchander/permutation v0.0.0-20201214100618-1f3e7285f953
github.com/go-logfmt/logfmt v0.5.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ github.com/evanw/esbuild v0.14.11 h1:bw50N4v70Dqf/B6Wn+3BM6BVttz4A6tHn8m8Ydj9vxk
github.com/evanw/esbuild v0.14.11/go.mod h1:GG+zjdi59yh3ehDn4ZWfPcATxjPDUH53iU4ZJbp7dkY=
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.0 h1:+cqqvzZV87b4adx/5ayVOaYZ2CrvM4ejQvUdBzPPUss=
github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og=
github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI=
Expand Down
68 changes: 68 additions & 0 deletions package/middleware/httpbuffer/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package httpbuffer

import (
"bytes"
"net/http"

"github.com/felixge/httpsnoop"
"github.com/livebud/bud/package/log"
)

func New(log log.Log) *Middleware {
return &Middleware{log}
}

type Middleware struct {
log log.Log
}

func (m *Middleware) Middleware(next http.Handler) http.Handler {
rw := &responseWriter{
code: 0,
body: new(bytes.Buffer),
}
return http.HandlerFunc(func(original http.ResponseWriter, r *http.Request) {
w := httpsnoop.Wrap(original, httpsnoop.Hooks{
WriteHeader: func(_ httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
return rw.WriteHeader
},
Write: func(_ httpsnoop.WriteFunc) httpsnoop.WriteFunc {
return rw.Write
},
Flush: func(flush httpsnoop.FlushFunc) httpsnoop.FlushFunc {
rw.writeTo(original)
return flush
},
})
next.ServeHTTP(w, r)
rw.writeTo(original)
})
}

type responseWriter struct {
body *bytes.Buffer
code int
wrote bool
}

func (rw *responseWriter) WriteHeader(statusCode int) {
rw.code = statusCode
}

func (rw *responseWriter) Write(b []byte) (int, error) {
return rw.body.Write(b)
}

func (rw *responseWriter) writeTo(w http.ResponseWriter) {
// Only write status code once to avoid:
// "http: superfluous response.WriteHeader"
// Not concurrency safe.
if !rw.wrote {
if rw.code == 0 {
rw.code = http.StatusOK
}
w.WriteHeader(rw.code)
rw.wrote = true
}
rw.body.WriteTo(w)
}
175 changes: 175 additions & 0 deletions package/middleware/httpbuffer/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package httpbuffer_test

import (
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/livebud/bud/internal/is"
"github.com/livebud/bud/package/log/testlog"
"github.com/livebud/bud/package/middleware/httpbuffer"
)

func TestHeadersNormal(t *testing.T) {
is := is.New(t)
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-A", "A")
w.Write([]byte("Hello, world!"))
w.Header().Add("X-B", "B")
})
h.ServeHTTP(rec, req)
res := rec.Result()
is.Equal(res.StatusCode, 200)
is.Equal(res.Header.Get("X-A"), "A")
is.Equal(res.Header.Get("X-B"), "")
}

func TestHeadersWrapped(t *testing.T) {
is := is.New(t)
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-A", "A")
w.Write([]byte("Hello, world!"))
w.Header().Add("X-B", "B")
}))
h.ServeHTTP(rec, req)
res := rec.Result()
is.Equal(res.StatusCode, 200)
is.Equal(res.Header.Get("X-A"), "A")
is.Equal(res.Header.Get("X-B"), "B")
}

func TestWriteStatusNormal(t *testing.T) {
is := is.New(t)
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-A", "A")
w.WriteHeader(201)
w.Write([]byte("Hello, world!"))
w.Header().Add("X-B", "B")
})
h.ServeHTTP(rec, req)
res := rec.Result()
is.Equal(res.StatusCode, 201)
is.Equal(res.Header.Get("X-A"), "A")
is.Equal(res.Header.Get("X-B"), "")
body, err := io.ReadAll(res.Body)
is.NoErr(err)
is.Equal(string(body), "Hello, world!")
}

func TestWriteStatusWrapped(t *testing.T) {
is := is.New(t)
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-A", "A")
w.WriteHeader(201)
w.Write([]byte("Hello, world!"))
w.Header().Add("X-B", "B")
}))
h.ServeHTTP(rec, req)
res := rec.Result()
is.Equal(res.StatusCode, 201)
is.Equal(res.Header.Get("X-A"), "A")
is.Equal(res.Header.Get("X-B"), "B")
body, err := io.ReadAll(res.Body)
is.NoErr(err)
is.Equal(string(body), "Hello, world!")
}

func TestFlushNormal(t *testing.T) {
is := is.New(t)
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-A", "A")
w.WriteHeader(201)
w.Write([]byte("Hello, world!"))
flush, ok := w.(http.Flusher)
if ok {
flush.Flush()
flush.Flush()
}
w.Header().Add("X-B", "B")
})
h.ServeHTTP(rec, req)
res := rec.Result()
is.Equal(res.StatusCode, 201)
is.Equal(res.Header.Get("X-A"), "A")
is.Equal(res.Header.Get("X-B"), "")
body, err := io.ReadAll(res.Body)
is.NoErr(err)
is.Equal(string(body), "Hello, world!")
}

func TestFlushWrapped(t *testing.T) {
is := is.New(t)
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, world!"))
w.Header().Add("X-A", "A")
flush, ok := w.(http.Flusher)
if ok {
flush.Flush()
w.Write([]byte("yoyo"))
flush.Flush()
}
w.Header().Add("X-B", "B")
}))
h.ServeHTTP(rec, req)
res := rec.Result()
is.Equal(res.StatusCode, 200)
is.Equal(res.Header.Get("X-A"), "A")
is.Equal(res.Header.Get("X-B"), "")
body, err := io.ReadAll(res.Body)
is.NoErr(err)
is.Equal(string(body), "Hello, world!yoyo")
}

func TestFlushStatusWrapped(t *testing.T) {
is := is.New(t)
rec := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, world!"))
w.WriteHeader(201)
w.Header().Add("X-A", "A")
flush, ok := w.(http.Flusher)
if ok {
flush.Flush()
w.Write([]byte("yoyo"))
flush.Flush()
}
w.Header().Add("X-B", "B")
}))
h.ServeHTTP(rec, req)
res := rec.Result()
is.Equal(res.StatusCode, 201)
is.Equal(res.Header.Get("X-A"), "A")
is.Equal(res.Header.Get("X-B"), "")
body, err := io.ReadAll(res.Body)
is.NoErr(err)
is.Equal(string(body), "Hello, world!yoyo")
}

0 comments on commit a7c8911

Please sign in to comment.