Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions json.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,20 @@ import (
log "github.com/Sirupsen/logrus"
)

// ContextKeys is a type alias for string to namespace Context keys per-package.
type ContextKeys string
// contextKeys is a type alias for string to namespace Context keys per-package.
type contextKeys string

// CtxValueLogger is the key to extract the logrus Logger.
const CtxValueLogger = ContextKeys("logger")
// ctxValueLogger is the key to extract the logrus Logger.
const ctxValueLogger = contextKeys("logger")

// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger.
func GetLogger(ctx context.Context) *log.Entry {
l := ctx.Value(ctxValueLogger)
if l == nil {
return nil
}
return l.(*log.Entry)
}

// JSONRequestHandler represents an interface that must be satisfied in order to respond to incoming
// HTTP requests with JSON. The interface returned will be marshalled into JSON to be sent to the client,
Expand All @@ -34,12 +43,12 @@ type JSONError struct {

// Protect panicking HTTP requests from taking down the entire process, and log them using
// the correct logger, returning a 500 with a JSON response rather than abruptly closing the
// connection. The http.Request MUST have a CtxValueLogger.
// connection. The http.Request MUST have a ctxValueLogger.
func Protect(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
defer func() {
if r := recover(); r != nil {
logger := req.Context().Value(CtxValueLogger).(*log.Entry)
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
logger.WithFields(log.Fields{
"panic": r,
}).Errorf(
Expand All @@ -56,18 +65,18 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc {

// MakeJSONAPI creates an HTTP handler which always responds to incoming requests with JSON responses.
// Incoming http.Requests will have a logger (with a request ID/method/path logged) attached to the Context.
// This can be accessed via the const CtxValueLogger. The type of the logger is *log.Entry from github.com/Sirupsen/logrus
// This can be accessed via GetLogger(Context). The type of the logger is *log.Entry from github.com/Sirupsen/logrus
func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc {
return Protect(func(w http.ResponseWriter, req *http.Request) {
// Set a Logger on the context
ctx := context.WithValue(req.Context(), CtxValueLogger, log.WithFields(log.Fields{
ctx := context.WithValue(req.Context(), ctxValueLogger, log.WithFields(log.Fields{
"req.method": req.Method,
"req.path": req.URL.Path,
"req.id": RandomString(12),
}))
req = req.WithContext(ctx)

logger := req.Context().Value(CtxValueLogger).(*log.Entry)
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
logger.Print("Incoming request")

res, httpErr := handler.OnIncomingRequest(req)
Expand Down Expand Up @@ -99,7 +108,7 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc {
}

func jsonErrorResponse(w http.ResponseWriter, req *http.Request, httpErr *HTTPError) {
logger := req.Context().Value(CtxValueLogger).(*log.Entry)
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
if httpErr.Code == 302 {
logger.WithField("err", httpErr.Error()).Print("Redirecting")
http.Redirect(w, req, httpErr.Message, 302)
Expand Down
20 changes: 19 additions & 1 deletion json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,30 @@ func TestMakeJSONAPIRedirect(t *testing.T) {
}
}

func TestGetLogger(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output
entry := log.WithField("test", "yep")
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
ctx := context.WithValue(mockReq.Context(), ctxValueLogger, entry)
mockReq = mockReq.WithContext(ctx)
ctxLogger := GetLogger(mockReq.Context())
if ctxLogger != entry {
t.Errorf("TestGetLogger wanted logger '%v', got '%v'", entry, ctxLogger)
}

noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
ctxLogger = GetLogger(noLoggerInReq.Context())
if ctxLogger != nil {
t.Errorf("TestGetLogger wanted nil logger, got '%v'", ctxLogger)
}
}

func TestProtect(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output
mockWriter := httptest.NewRecorder()
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
mockReq = mockReq.WithContext(
context.WithValue(mockReq.Context(), CtxValueLogger, log.WithField("test", "yep")),
context.WithValue(mockReq.Context(), ctxValueLogger, log.WithField("test", "yep")),
)
h := Protect(func(w http.ResponseWriter, req *http.Request) {
panic("oh noes!")
Expand Down