Skip to content

Commit

Permalink
Merge pull request #611 from lobotomist/module/apmhttp/panic-propagation
Browse files Browse the repository at this point in the history
Added WithPanicPropagation function
  • Loading branch information
axw committed Aug 8, 2019
2 parents bfee654 + 12143d4 commit 7215d0a
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 6 deletions.
29 changes: 23 additions & 6 deletions module/apmhttp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ func Wrap(h http.Handler, o ...ServerOption) http.Handler {
//
// The http.Request's context will be updated with the transaction.
type handler struct {
handler http.Handler
tracer *apm.Tracer
recovery RecoveryFunc
requestName RequestNameFunc
requestIgnorer RequestIgnorerFunc
handler http.Handler
tracer *apm.Tracer
recovery RecoveryFunc
panicPropagation bool
requestName RequestNameFunc
requestIgnorer RequestIgnorerFunc
}

// ServeHTTP delegates to h.Handler, tracing the transaction with
Expand All @@ -77,7 +78,14 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
w, resp := WrapResponseWriter(w)
defer func() {
if v := recover(); v != nil {
if resp.StatusCode == 0 {
if h.panicPropagation {
defer panic(v)
// 500 status code will be set only for APM transaction
// to allow other middleware to choose a different response code
if resp.StatusCode == 0 {
resp.StatusCode = http.StatusInternalServerError
}
} else if resp.StatusCode == 0 {
w.WriteHeader(http.StatusInternalServerError)
}
h.recovery(w, req, resp, body, tx, v)
Expand Down Expand Up @@ -260,6 +268,15 @@ func WithRecovery(r RecoveryFunc) ServerOption {
}
}

// WithPanicPropagation returns a ServerOption which enable panic propagation.
// Any panic will be recovered and recorded as an error in a transaction, then
// panic will be caused again.
func WithPanicPropagation() ServerOption {
return func(h *handler) {
h.panicPropagation = true
}
}

// RequestNameFunc is the type of a function for use in
// WithServerRequestName.
type RequestNameFunc func(*http.Request) string
Expand Down
80 changes: 80 additions & 0 deletions module/apmhttp/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,70 @@ func TestHandlerRecoveryNoHeaders(t *testing.T) {
assert.Equal(t, &model.Response{StatusCode: resp.StatusCode}, error0.Context.Response)
}

func TestHandlerWithPanicPropagation(t *testing.T) {
tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()

h := apmhttp.Wrap(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
panic("foo")
}),
apmhttp.WithTracer(tracer),
apmhttp.WithPanicPropagation(),
)

recovery := recoveryMiddleware(http.StatusBadGateway)
h = recovery(h)

server := httptest.NewServer(h)
defer server.Close()

resp, err := http.Get(server.URL)
require.NoError(t, err)
resp.Body.Close()

assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
tracer.Flush(nil)

payloads := transport.Payloads()
error0 := payloads.Errors[0]
transaction := payloads.Transactions[0]

assert.Equal(t, &model.Response{StatusCode: http.StatusInternalServerError}, transaction.Context.Response)
assert.Equal(t, &model.Response{StatusCode: http.StatusInternalServerError}, error0.Context.Response)
}

func TestHandlerWithPanicPropagationResponseCodeForwarding(t *testing.T) {
tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()

h := apmhttp.Wrap(
http.HandlerFunc(panicHandler),
apmhttp.WithTracer(tracer),
apmhttp.WithPanicPropagation(),
)

recovery := recoveryMiddleware(0)
h = recovery(h)

server := httptest.NewServer(h)
defer server.Close()

resp, err := http.Get(server.URL)
require.NoError(t, err)
resp.Body.Close()

assert.Equal(t, http.StatusTeapot, resp.StatusCode)
tracer.Flush(nil)

payloads := transport.Payloads()
error0 := payloads.Errors[0]
transaction := payloads.Transactions[0]

assert.Equal(t, &model.Response{StatusCode: resp.StatusCode}, transaction.Context.Response)
assert.Equal(t, &model.Response{StatusCode: resp.StatusCode}, error0.Context.Response)
}

func TestHandlerRequestIgnorer(t *testing.T) {
tracer, transport := transporttest.NewRecorderTracer()
defer tracer.Close()
Expand Down Expand Up @@ -372,3 +436,19 @@ func panicHandler(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusTeapot)
panic("foo")
}

func recoveryMiddleware(code int) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if v := recover(); v != nil {
if code == 0 {
return
}
w.WriteHeader(code)
}
}()
next.ServeHTTP(w, req)
})
}
}

0 comments on commit 7215d0a

Please sign in to comment.