diff --git a/ghttp/test_server.go b/ghttp/test_server.go index 40d92dea2..c7cbf8ee5 100644 --- a/ghttp/test_server.go +++ b/ghttp/test_server.go @@ -160,11 +160,13 @@ type Server struct { HTTPTestServer *httptest.Server //Defaults to false. If set to true, the Server will allow more requests than there are registered handlers. + //Direct use of this property is deprecated and is likely to be removed, use GetAllowUnhandledRequests and SetAllowUnhandledRequests instead. AllowUnhandledRequests bool //The status code returned when receiving an unhandled request. //Defaults to http.StatusInternalServerError. //Only applies if AllowUnhandledRequests is true + //Direct use of this property is deprecated and is likely to be removed, use GetUnhandledRequestStatusCode and SetUnhandledRequestStatusCode instead. UnhandledRequestStatusCode int //If provided, ghttp will log about each request received to the provided io.Writer @@ -213,7 +215,7 @@ func (s *Server) Close() { //1. If the request matches a handler registered with RouteToHandler, that handler is called. //2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order. //3. If all registered handlers have been called then: -// a) If AllowUnhandledRequests is true, the request will be handled with response code of UnhandledRequestStatusCode +// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode // b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed. func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { s.writeLock.Lock() @@ -258,10 +260,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { h(w, req) } else { s.writeLock.Unlock() - if s.AllowUnhandledRequests { + if s.GetAllowUnhandledRequests() { ioutil.ReadAll(req.Body) req.Body.Close() - w.WriteHeader(s.UnhandledRequestStatusCode) + w.WriteHeader(s.GetUnhandledRequestStatusCode()) } else { Ω(req).Should(BeNil(), "Received Unhandled Request") } @@ -379,3 +381,35 @@ func (s *Server) CloseClientConnections() { s.HTTPTestServer.CloseClientConnections() } + +//SetAllowUnhandledRequests enables the server to accept unhandled requests. +func (s *Server) SetAllowUnhandledRequests(allowUnhandledRequests bool) { + s.writeLock.Lock() + defer s.writeLock.Unlock() + + s.AllowUnhandledRequests = allowUnhandledRequests +} + +//GetAllowUnhandledRequests returns true if the server accepts unhandled requests. +func (s *Server) GetAllowUnhandledRequests() bool { + s.writeLock.Lock() + defer s.writeLock.Unlock() + + return s.AllowUnhandledRequests +} + +//SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests +func (s *Server) SetUnhandledRequestStatusCode(statusCode int) { + s.writeLock.Lock() + defer s.writeLock.Unlock() + + s.UnhandledRequestStatusCode = statusCode +} + +//GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests +func (s *Server) GetUnhandledRequestStatusCode() int { + s.writeLock.Lock() + defer s.writeLock.Unlock() + + return s.UnhandledRequestStatusCode +} diff --git a/ghttp/test_server_test.go b/ghttp/test_server_test.go index 88b324654..3324093c9 100644 --- a/ghttp/test_server_test.go +++ b/ghttp/test_server_test.go @@ -86,10 +86,14 @@ var _ = Describe("TestServer", func() { }) Describe("allowing unhandled requests", func() { + It("is not permitted by default", func() { + Expect(s.GetAllowUnhandledRequests()).To(BeFalse()) + }) + Context("when true", func() { BeforeEach(func() { - s.AllowUnhandledRequests = true - s.UnhandledRequestStatusCode = http.StatusForbidden + s.SetAllowUnhandledRequests(true) + s.SetUnhandledRequestStatusCode(http.StatusForbidden) resp, err = http.Get(s.URL() + "/foo") Ω(err).ShouldNot(HaveOccurred()) }) @@ -226,7 +230,7 @@ var _ = Describe("TestServer", func() { Describe("When a handler fails", func() { BeforeEach(func() { - s.UnhandledRequestStatusCode = http.StatusForbidden //just to be clear that 500s aren't coming from unhandled requests + s.SetUnhandledRequestStatusCode(http.StatusForbidden) //just to be clear that 500s aren't coming from unhandled requests }) Context("because the handler has panicked", func() {