Skip to content

Commit

Permalink
fix bug in responsewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
nbari committed Oct 2, 2017
1 parent 27e3d0c commit a225aa9
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 7 deletions.
6 changes: 2 additions & 4 deletions responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func NewResponseWriter(w http.ResponseWriter, rid string) *ResponseWriter {
ResponseWriter: w,
requestID: rid,
start: time.Now(),
status: http.StatusOK,
}
}

Expand All @@ -39,15 +40,12 @@ func (w *ResponseWriter) RequestTime() string {

// RequestID retrieve the Request ID
func (w *ResponseWriter) RequestID() string {
return w.Header().Get(w.requestID)
return w.requestID
}

// Write satisfies the http.ResponseWriter interface and
// captures data written, in bytes
func (w *ResponseWriter) Write(data []byte) (int, error) {
if w.status == 0 {
w.WriteHeader(http.StatusOK)
}
size, err := w.ResponseWriter.Write(data)
w.size += size
return size, err
Expand Down
98 changes: 97 additions & 1 deletion responsewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestResponseWriterStatus(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec, "")

expect(t, rw.Status(), 0)
expect(t, rw.Status(), 200)

rw.Write([]byte(""))
expect(t, rw.Status(), http.StatusOK)
Expand Down Expand Up @@ -60,3 +60,99 @@ func TestResponseWriterWriteHeader(t *testing.T) {
expect(t, rw.Status(), http.StatusNotFound)
expect(t, rw.Size(), 0)
}

func TestResponseWriterLogger(t *testing.T) {
mylogger := func(w *ResponseWriter, r *http.Request) {
expect(t, r.URL.String(), "/test")
expect(t, w.RequestID(), "123")
expect(t, w.Size(), 11)
expect(t, w.Status(), 200)
}
router := New()
router.LogRequests = true
router.RequestID = "rid"
router.Logger = mylogger
router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
expect(t, w.Header().Get("rid"), "123")
w.Write([]byte("hello world"))
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("rid", "123")
router.ServeHTTP(w, req)
expect(t, w.Code, 200)
expect(t, w.HeaderMap.Get("rid"), "123")
}

func TestResponseWriterLoggerStatus200(t *testing.T) {
mylogger := func(w *ResponseWriter, r *http.Request) {
expect(t, r.URL.String(), "/test")
expect(t, w.RequestID(), "123")
expect(t, w.Size(), 0)
expect(t, w.Status(), 200)
}
router := New()
router.LogRequests = true
router.RequestID = "rid"
router.Logger = mylogger
router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
expect(t, w.Header().Get("rid"), "123")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("rid", "123")
router.ServeHTTP(w, req)
expect(t, w.Code, 200)
expect(t, w.HeaderMap.Get("rid"), "123")
}

func TestResponseWriterLoggerStatus405(t *testing.T) {
mylogger := func(w *ResponseWriter, r *http.Request) {
expect(t, r.URL.String(), "/test")
expect(t, w.RequestID(), "123")
expect(t, w.Status(), 405)
}
router := New()
router.LogRequests = true
router.RequestID = "rid"
router.Logger = mylogger
router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
expect(t, w.Header().Get("rid"), "123")
}, "POST")
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("rid", "123")
router.ServeHTTP(w, req)
expect(t, w.Code, 405)
expect(t, w.HeaderMap.Get("rid"), "123")
}

func TestResponseWriterNoLogger(t *testing.T) {
router := New()
router.LogRequests = false
router.RequestID = "rid"
router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
expect(t, w.Header().Get("rid"), "123")
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("rid", "123")
router.ServeHTTP(w, req)
expect(t, w.Code, 200)
expect(t, w.HeaderMap.Get("rid"), "123")
}

func TestResponseWriterNoLogger455(t *testing.T) {
router := New()
router.LogRequests = false
router.RequestID = "rid"
router.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
expect(t, w.Header().Get("rid"), "123")
}, "POST")
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("rid", "123")
router.ServeHTTP(w, req)
expect(t, w.Code, 405)
expect(t, w.HeaderMap.Get("rid"), "123")
}
5 changes: 3 additions & 2 deletions violetear.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,17 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}()

// Request-ID
var rid string
if r.RequestID != "" {
if rid := req.Header.Get(r.RequestID); rid != "" {
if rid = req.Header.Get(r.RequestID); rid != "" {
w.Header().Set(r.RequestID, rid)
}
}

// wrap ResponseWriter
var ww *ResponseWriter
if r.LogRequests {
ww = NewResponseWriter(w, r.RequestID)
ww = NewResponseWriter(w, rid)
}

// set version based on the value of "Accept: application/vnd.*"
Expand Down

0 comments on commit a225aa9

Please sign in to comment.