Skip to content

Commit

Permalink
Use setters and getters to avoid race condition
Browse files Browse the repository at this point in the history
when accessing ghttp.AllowUnhandledRequests and ghttp.UnhandledRequestsStatusCode

Solves #173

Signed-off-by: Derik Evangelista <devangelista@pivotal.io>
  • Loading branch information
edwardecook authored and Derik Evangelista committed Feb 1, 2018
1 parent 4fc1762 commit 13057c3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
40 changes: 37 additions & 3 deletions ghttp/test_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,13 @@ type Server struct {
HTTPTestServer *httptest.Server

//Defaults to false. If set to true, the Server will allow more requests than there are registered handlers.
//Direct use of this property is deprecated and is likely to be removed, use GetAllowUnhandledRequests and SetAllowUnhandledRequests instead.
AllowUnhandledRequests bool

//The status code returned when receiving an unhandled request.
//Defaults to http.StatusInternalServerError.
//Only applies if AllowUnhandledRequests is true
//Direct use of this property is deprecated and is likely to be removed, use GetUnhandledRequestStatusCode and SetUnhandledRequestStatusCode instead.
UnhandledRequestStatusCode int

//If provided, ghttp will log about each request received to the provided io.Writer
Expand Down Expand Up @@ -213,7 +215,7 @@ func (s *Server) Close() {
//1. If the request matches a handler registered with RouteToHandler, that handler is called.
//2. Otherwise, if there are handlers registered via AppendHandlers, those handlers are called in order.
//3. If all registered handlers have been called then:
// a) If AllowUnhandledRequests is true, the request will be handled with response code of UnhandledRequestStatusCode
// a) If AllowUnhandledRequests is set to true, the request will be handled with response code of UnhandledRequestStatusCode
// b) If AllowUnhandledRequests is false, the request will not be handled and the current test will be marked as failed.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.writeLock.Lock()
Expand Down Expand Up @@ -258,10 +260,10 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
h(w, req)
} else {
s.writeLock.Unlock()
if s.AllowUnhandledRequests {
if s.GetAllowUnhandledRequests() {
ioutil.ReadAll(req.Body)
req.Body.Close()
w.WriteHeader(s.UnhandledRequestStatusCode)
w.WriteHeader(s.GetUnhandledRequestStatusCode())
} else {
Ω(req).Should(BeNil(), "Received Unhandled Request")
}
Expand Down Expand Up @@ -379,3 +381,35 @@ func (s *Server) CloseClientConnections() {

s.HTTPTestServer.CloseClientConnections()
}

//SetAllowUnhandledRequests enables the server to accept unhandled requests.
func (s *Server) SetAllowUnhandledRequests(allowUnhandledRequests bool) {
s.writeLock.Lock()
defer s.writeLock.Unlock()

s.AllowUnhandledRequests = allowUnhandledRequests
}

//GetAllowUnhandledRequests returns true if the server accepts unhandled requests.
func (s *Server) GetAllowUnhandledRequests() bool {
s.writeLock.Lock()
defer s.writeLock.Unlock()

return s.AllowUnhandledRequests
}

//SetUnhandledRequestStatusCode status code to be returned when the server receives unhandled requests
func (s *Server) SetUnhandledRequestStatusCode(statusCode int) {
s.writeLock.Lock()
defer s.writeLock.Unlock()

s.UnhandledRequestStatusCode = statusCode
}

//GetUnhandledRequestStatusCode returns the current status code being returned for unhandled requests
func (s *Server) GetUnhandledRequestStatusCode() int {
s.writeLock.Lock()
defer s.writeLock.Unlock()

return s.UnhandledRequestStatusCode
}
10 changes: 7 additions & 3 deletions ghttp/test_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ var _ = Describe("TestServer", func() {
})

Describe("allowing unhandled requests", func() {
It("is not permitted by default", func() {
Expect(s.GetAllowUnhandledRequests()).To(BeFalse())
})

Context("when true", func() {
BeforeEach(func() {
s.AllowUnhandledRequests = true
s.UnhandledRequestStatusCode = http.StatusForbidden
s.SetAllowUnhandledRequests(true)
s.SetUnhandledRequestStatusCode(http.StatusForbidden)
resp, err = http.Get(s.URL() + "/foo")
Ω(err).ShouldNot(HaveOccurred())
})
Expand Down Expand Up @@ -226,7 +230,7 @@ var _ = Describe("TestServer", func() {

Describe("When a handler fails", func() {
BeforeEach(func() {
s.UnhandledRequestStatusCode = http.StatusForbidden //just to be clear that 500s aren't coming from unhandled requests
s.SetUnhandledRequestStatusCode(http.StatusForbidden) //just to be clear that 500s aren't coming from unhandled requests
})

Context("because the handler has panicked", func() {
Expand Down

0 comments on commit 13057c3

Please sign in to comment.