diff --git a/src/pkg/net/http/httptest/recorder.go b/src/pkg/net/http/httptest/recorder.go index 9aa0d510bd4d5..5451f54234c9e 100644 --- a/src/pkg/net/http/httptest/recorder.go +++ b/src/pkg/net/http/httptest/recorder.go @@ -17,6 +17,8 @@ type ResponseRecorder struct { HeaderMap http.Header // the HTTP response headers Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool + + wroteHeader bool } // NewRecorder returns an initialized ResponseRecorder. @@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ HeaderMap: make(http.Header), Body: new(bytes.Buffer), + Code: 200, } } @@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4" // Header returns the response headers. func (rw *ResponseRecorder) Header() http.Header { - return rw.HeaderMap + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m } // Write always succeeds and writes to rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(200) + } if rw.Body != nil { rw.Body.Write(buf) } - if rw.Code == 0 { - rw.Code = http.StatusOK - } return len(buf), nil } // WriteHeader sets rw.Code. func (rw *ResponseRecorder) WriteHeader(code int) { - rw.Code = code + if !rw.wroteHeader { + rw.Code = code + } + rw.wroteHeader = true } // Flush sets rw.Flushed to true. func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } rw.Flushed = true } diff --git a/src/pkg/net/http/httptest/recorder_test.go b/src/pkg/net/http/httptest/recorder_test.go new file mode 100644 index 0000000000000..2b563260c769a --- /dev/null +++ b/src/pkg/net/http/httptest/recorder_test.go @@ -0,0 +1,90 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + + tests := []struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true)), + }, + } + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + for _, tt := range tests { + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } + } +}