This repository has been archived by the owner on Mar 14, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from hellofresh/add-middleware
Add middleware
- Loading branch information
Showing
7 changed files
with
221 additions
and
1 deletion.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package context | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
type loggerKeyType int | ||
|
||
const loggerKey loggerKeyType = iota | ||
|
||
// New returns a context that has a logrus logger | ||
func New(ctx context.Context) context.Context { | ||
return context.WithValue(ctx, loggerKey, WithContext(ctx)) | ||
} | ||
|
||
// WithContext returns a logrus logger from the context | ||
func WithContext(ctx context.Context) *logrus.Logger { | ||
if ctx == nil { | ||
return logrus.StandardLogger() | ||
} | ||
|
||
if ctxLogger, ok := ctx.Value(loggerKey).(*logrus.Logger); ok { | ||
return ctxLogger | ||
} | ||
|
||
return logrus.StandardLogger() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package context | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestContext(t *testing.T) { | ||
t.Parallel() | ||
|
||
tests := []struct { | ||
scenario string | ||
function func(*testing.T) | ||
}{ | ||
{ | ||
scenario: "when set the context", | ||
function: testSetContext, | ||
}, | ||
{ | ||
scenario: "when the context is nil", | ||
function: testGetLoggerWhenContextIsNil, | ||
}, | ||
{ | ||
scenario: "when the logger is not the context", | ||
function: testGetLoggerWhenNoLogIsOnContext, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.scenario, func(t *testing.T) { | ||
test.function(t) | ||
}) | ||
} | ||
} | ||
|
||
func testSetContext(t *testing.T) { | ||
ctx := context.Background() | ||
ctx = New(ctx) | ||
|
||
logger := WithContext(ctx) | ||
require.NotNil(t, logger) | ||
} | ||
|
||
func testGetLoggerWhenContextIsNil(t *testing.T) { | ||
client := WithContext(nil) | ||
require.NotNil(t, client) | ||
} | ||
|
||
func testGetLoggerWhenNoLogIsOnContext(t *testing.T) { | ||
client := WithContext(context.Background()) | ||
require.NotNil(t, client) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package middleware | ||
|
||
import ( | ||
"net/http" | ||
"net/url" | ||
"time" | ||
|
||
"github.com/felixge/httpsnoop" | ||
"github.com/hellofresh/logging-go/context" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
// New creates a new stats middleware | ||
func New() func(http.Handler) http.Handler { | ||
return func(handler http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
r = r.WithContext(context.New(r.Context())) | ||
|
||
logger := context.WithContext(r.Context()) | ||
logger.WithFields(logrus.Fields{"method": r.Method, "path": r.URL.Path}).Debug("Started request") | ||
|
||
// reverse proxy replaces original request with target request, so keep original one | ||
originalURL := &url.URL{} | ||
*originalURL = *r.URL | ||
|
||
fields := logrus.Fields{ | ||
"method": r.Method, | ||
"host": r.Host, | ||
"request": r.RequestURI, | ||
"remote-addr": r.RemoteAddr, | ||
"referer": r.Referer(), | ||
"user-agent": r.UserAgent(), | ||
} | ||
|
||
m := httpsnoop.CaptureMetrics(handler, w, r) | ||
|
||
fields["code"] = m.Code | ||
fields["duration"] = int(m.Duration / time.Millisecond) | ||
fields["duration-fmt"] = m.Duration.String() | ||
|
||
if originalURL.String() != r.URL.String() { | ||
fields["upstream-host"] = r.URL.Host | ||
fields["upstream-request"] = r.URL.RequestURI() | ||
} | ||
|
||
logger.WithFields(fields).Info("Completed handling request") | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package middleware | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestMiddleware(t *testing.T) { | ||
t.Parallel() | ||
|
||
tests := []struct { | ||
scenario string | ||
function func(*testing.T, *http.Request, *httptest.ResponseRecorder) | ||
}{ | ||
{ | ||
scenario: "when a request is successful", | ||
function: testRecorded, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.scenario, func(t *testing.T) { | ||
r := httptest.NewRequest(http.MethodGet, "/", nil) | ||
w := httptest.NewRecorder() | ||
test.function(t, r, w) | ||
}) | ||
} | ||
} | ||
|
||
func testRecorded(t *testing.T, r *http.Request, w *httptest.ResponseRecorder) { | ||
mw := New() | ||
mw(http.HandlerFunc(ping)).ServeHTTP(w, r) | ||
assert.Equal(t, http.StatusOK, w.Code) | ||
} | ||
|
||
// ping is a test handler | ||
func ping(w http.ResponseWriter, r *http.Request) { | ||
w.Header().Add("Content-Type", "application/json") | ||
w.Write([]byte("OK\n")) | ||
} |