-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(http): add http server middleware (#68)
- Loading branch information
Showing
9 changed files
with
523 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package middleware | ||
|
||
import ( | ||
"log/slog" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
// RequestLogger logs incoming HTTP requests. | ||
func RequestLogger(logger *slog.Logger, logLevel slog.Level, formatter RequestLogFormatter) func(http.Handler) http.Handler { | ||
return func(next http.Handler) http.Handler { | ||
fn := func(w http.ResponseWriter, r *http.Request) { | ||
start := time.Now() | ||
lrw := &loggingResponseWriter{ResponseWriter: w} | ||
|
||
next.ServeHTTP(lrw, r) | ||
|
||
logger.LogAttrs(r.Context(), logLevel, "http request", formatter.FormatRequest(r, lrw.statusCode, time.Since(start))...) | ||
} | ||
|
||
return http.HandlerFunc(fn) | ||
} | ||
} | ||
|
||
// A RequestLogFormatter takes the HTTP request, the resulting HTTP status code and latency and formats the log entry. | ||
type RequestLogFormatter interface { | ||
FormatRequest(*http.Request, int, time.Duration) []slog.Attr | ||
} | ||
|
||
// DefaultRequestLogFormatter is the default RequestLogFormatter. It logs the request's HTTP method and the path. | ||
var DefaultRequestLogFormatter RequestLogFormatter = &defaultRequestLogFormatter{} | ||
|
||
type defaultRequestLogFormatter struct{} | ||
|
||
func (d defaultRequestLogFormatter) FormatRequest(r *http.Request, statusCode int, latency time.Duration) []slog.Attr { | ||
return []slog.Attr{slog.String("path", r.URL.Path), slog.String("method", r.Method), | ||
slog.Int("code", statusCode), slog.Duration("latency", latency), | ||
} | ||
} | ||
|
||
// The RequestLogFormatterFunc type is an adapter that allows an ordinary function to be used as a RequestLogFormatter. | ||
type RequestLogFormatterFunc func(r *http.Request, statusCode int, latency time.Duration) []slog.Attr | ||
|
||
// FormatRequest calls f(r, statusCode, latency) | ||
func (f RequestLogFormatterFunc) FormatRequest(r *http.Request, statusCode int, latency time.Duration) []slog.Attr { | ||
return f(r, statusCode, latency) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
package middleware_test | ||
|
||
import ( | ||
"bytes" | ||
"github.com/clambin/go-common/http/middleware" | ||
"github.com/stretchr/testify/assert" | ||
"log/slog" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestLogger(t *testing.T) { | ||
out := bytes.NewBufferString("") | ||
opt := slog.HandlerOptions{Level: slog.LevelDebug, ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { | ||
// Remove time from the output for predictable test output. | ||
if a.Key == slog.TimeKey { | ||
return slog.Attr{} | ||
} | ||
return a | ||
}} | ||
l := slog.New(slog.NewTextHandler(out, &opt)) | ||
slog.SetDefault(l) | ||
|
||
testCases := []struct { | ||
name string | ||
level slog.Level | ||
logger middleware.RequestLogFormatter | ||
want string | ||
}{ | ||
{ | ||
name: "default", | ||
logger: middleware.DefaultRequestLogFormatter, | ||
want: `level=INFO msg="http request" path=/ method=GET code=200 latency=`, | ||
}, | ||
{ | ||
name: "none", | ||
level: slog.LevelDebug, | ||
logger: middleware.DefaultRequestLogFormatter, | ||
}, | ||
{ | ||
name: "custom", | ||
logger: middleware.RequestLogFormatterFunc(func(r *http.Request, code int, latency time.Duration) []slog.Attr { | ||
return []slog.Attr{ | ||
slog.String("client", r.RemoteAddr), | ||
slog.String("path", r.URL.Path), slog.String("method", r.Method), | ||
slog.Int("code", code), slog.Duration("latency", latency), | ||
} | ||
}), | ||
want: `level=INFO msg="http request" client=127.0.0.1:5000 path=/ method=GET code=200 latency=`, | ||
}, | ||
} | ||
|
||
for _, tt := range testCases { | ||
t.Run(tt.name, func(t *testing.T) { | ||
out.Reset() | ||
|
||
r := http.NewServeMux() | ||
r.Handle("/", middleware.RequestLogger(l, tt.level, tt.logger)(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { | ||
_, _ = writer.Write([]byte("hello")) | ||
}))) | ||
|
||
req, _ := http.NewRequest(http.MethodGet, "/", nil) | ||
req.RemoteAddr = "127.0.0.1:5000" | ||
w := httptest.NewRecorder() | ||
r.ServeHTTP(w, req) | ||
|
||
assert.Equal(t, http.StatusOK, w.Code) | ||
assert.Contains(t, out.String(), tt.want) | ||
|
||
}) | ||
} | ||
} | ||
|
||
/* | ||
func TestDefaultLogger(t *testing.T) { | ||
out := bytes.NewBufferString("") | ||
opt := slog.HandlerOptions{Level: slog.LevelDebug, ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { | ||
// Remove time from the output for predictable test output. | ||
if a.Key == slog.TimeKey { | ||
return slog.Attr{} | ||
} | ||
return a | ||
}} | ||
l := slog.New(slog.NewTextHandler(out, &opt)) | ||
slog.SetDefault(l) | ||
r := http.NewServeMux() | ||
r.Handle("/", middleware.RequestLogger(middleware.DefaultRequestLogger{})(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { | ||
_, _ = writer.Write([]byte("hello")) | ||
}))) | ||
req, _ := http.NewRequest(http.MethodGet, "/", nil) | ||
w := httptest.NewRecorder() | ||
r.ServeHTTP(w, req) | ||
assert.Equal(t, http.StatusOK, w.Code) | ||
assert.Contains(t, out.String(), `level=INFO msg=request path=/ method=GET code=200 latency=`) | ||
} | ||
func BenchmarkLogger(b *testing.B) { | ||
out := bytes.NewBufferString("") | ||
opt := slog.HandlerOptions{Level: slog.LevelInfo, ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { | ||
// Remove time from the output for predictable test output. | ||
if a.Key == slog.TimeKey { | ||
return slog.Attr{} | ||
} | ||
return a | ||
}} | ||
l := slog.New(slog.NewTextHandler(out, &opt)) | ||
slog.SetDefault(l) | ||
r := http.NewServeMux() | ||
r.Handle("/", middleware.RequestLogger(middleware.DefaultRequestLogger{})(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { | ||
_, _ = writer.Write([]byte("hello")) | ||
}))) | ||
req, _ := http.NewRequest(http.MethodGet, "/", nil) | ||
w := httptest.NewRecorder() | ||
for i := 0; i < b.N; i++ { | ||
r.ServeHTTP(w, req) | ||
if w.Code != http.StatusOK { | ||
b.Fail() | ||
} | ||
} | ||
} | ||
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package middleware | ||
|
||
import ( | ||
"github.com/prometheus/client_golang/prometheus" | ||
"net/http" | ||
"strconv" | ||
"time" | ||
) | ||
|
||
var _ http.Handler = serverMetricsHandler{} | ||
|
||
type serverMetricsHandler struct { | ||
next http.Handler | ||
metrics ServerMetrics | ||
} | ||
|
||
func WithServerMetrics(m ServerMetrics) func(next http.Handler) http.Handler { | ||
return func(next http.Handler) http.Handler { | ||
return serverMetricsHandler{next: next, metrics: m} | ||
} | ||
} | ||
|
||
func (s serverMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||
lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK} | ||
start := time.Now() | ||
s.next.ServeHTTP(lrw, r) | ||
s.metrics.Measure(r, lrw.statusCode, time.Since(start)) | ||
} | ||
|
||
///////////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
type ServerMetrics interface { | ||
Measure(req *http.Request, statusCode int, duration time.Duration) | ||
prometheus.Collector | ||
} | ||
|
||
///////////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
var _ ServerMetrics = &DefaultServerSummaryMetrics{} | ||
|
||
type DefaultServerSummaryMetrics struct { | ||
requests *prometheus.CounterVec | ||
duration *prometheus.SummaryVec | ||
} | ||
|
||
func NewDefaultServerSummaryMetrics(namespace, subsystem, application string) *DefaultServerSummaryMetrics { | ||
var constLabels map[string]string | ||
if application != "" { | ||
constLabels = map[string]string{"application": application} | ||
} | ||
return &DefaultServerSummaryMetrics{ | ||
requests: prometheus.NewCounterVec(prometheus.CounterOpts{ | ||
Namespace: namespace, | ||
Subsystem: subsystem, | ||
Name: "http_server_requests_total", | ||
Help: "total number of http server requests", | ||
ConstLabels: constLabels, | ||
}, | ||
[]string{"method", "path", "code"}, | ||
), | ||
duration: prometheus.NewSummaryVec(prometheus.SummaryOpts{ | ||
Namespace: namespace, | ||
Subsystem: subsystem, | ||
Name: "http_server_request_duration_seconds", | ||
Help: "total number of http server requests", | ||
ConstLabels: constLabels, | ||
}, | ||
[]string{"method", "path", "code"}, | ||
), | ||
} | ||
} | ||
|
||
func (d DefaultServerSummaryMetrics) Measure(req *http.Request, statusCode int, duration time.Duration) { | ||
code := strconv.Itoa(statusCode) | ||
d.requests.WithLabelValues(req.Method, req.URL.Path, code).Inc() | ||
d.duration.WithLabelValues(req.Method, req.URL.Path, code).Observe(duration.Seconds()) | ||
} | ||
|
||
func (d DefaultServerSummaryMetrics) Describe(ch chan<- *prometheus.Desc) { | ||
d.requests.Describe(ch) | ||
d.duration.Describe(ch) | ||
} | ||
|
||
func (d DefaultServerSummaryMetrics) Collect(ch chan<- prometheus.Metric) { | ||
d.requests.Collect(ch) | ||
d.duration.Collect(ch) | ||
} | ||
|
||
///////////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
var _ ServerMetrics = &DefaultServerHistogramMetrics{} | ||
|
||
type DefaultServerHistogramMetrics struct { | ||
requests *prometheus.CounterVec | ||
duration *prometheus.HistogramVec | ||
} | ||
|
||
func NewDefaultServerHistogramMetrics(namespace, subsystem, application string, buckets ...float64) *DefaultServerHistogramMetrics { | ||
var constLabels map[string]string | ||
if application != "" { | ||
constLabels = map[string]string{"application": application} | ||
} | ||
if len(buckets) == 0 { | ||
buckets = prometheus.DefBuckets | ||
} | ||
return &DefaultServerHistogramMetrics{ | ||
requests: prometheus.NewCounterVec(prometheus.CounterOpts{ | ||
Namespace: namespace, | ||
Subsystem: subsystem, | ||
Name: "http_server_requests_total", | ||
Help: "total number of http server requests", | ||
ConstLabels: constLabels, | ||
}, | ||
[]string{"method", "path", "code"}, | ||
), | ||
duration: prometheus.NewHistogramVec(prometheus.HistogramOpts{ | ||
Namespace: namespace, | ||
Subsystem: subsystem, | ||
Name: "http_server_request_duration_seconds", | ||
Help: "total number of http server requests", | ||
ConstLabels: constLabels, | ||
Buckets: buckets, | ||
}, | ||
[]string{"method", "path", "code"}, | ||
), | ||
} | ||
} | ||
|
||
func (d DefaultServerHistogramMetrics) Measure(req *http.Request, statusCode int, duration time.Duration) { | ||
code := strconv.Itoa(statusCode) | ||
d.requests.WithLabelValues(req.Method, req.URL.Path, code).Inc() | ||
d.duration.WithLabelValues(req.Method, req.URL.Path, code).Observe(duration.Seconds()) | ||
} | ||
|
||
func (d DefaultServerHistogramMetrics) Describe(ch chan<- *prometheus.Desc) { | ||
d.requests.Describe(ch) | ||
d.duration.Describe(ch) | ||
} | ||
|
||
func (d DefaultServerHistogramMetrics) Collect(ch chan<- prometheus.Metric) { | ||
d.requests.Collect(ch) | ||
d.duration.Collect(ch) | ||
} | ||
|
||
///////////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
var _ ServerMetrics = NoMetrics{} | ||
|
||
type NoMetrics struct{} | ||
|
||
func (n NoMetrics) Measure(_ *http.Request, _ int, _ time.Duration) { | ||
} | ||
|
||
func (n NoMetrics) Describe(_ chan<- *prometheus.Desc) { | ||
} | ||
|
||
func (n NoMetrics) Collect(_ chan<- prometheus.Metric) { | ||
} | ||
|
||
//////////////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
type loggingResponseWriter struct { | ||
http.ResponseWriter | ||
wroteHeader bool | ||
statusCode int | ||
} | ||
|
||
// WriteHeader implements the http.ResponseWriter interface. | ||
func (w *loggingResponseWriter) WriteHeader(code int) { | ||
w.ResponseWriter.WriteHeader(code) | ||
w.statusCode = code | ||
w.wroteHeader = true | ||
} | ||
|
||
// Write implements the http.ResponseWriter interface. | ||
func (w *loggingResponseWriter) Write(body []byte) (int, error) { | ||
if !w.wroteHeader { | ||
w.WriteHeader(http.StatusOK) | ||
} | ||
return w.ResponseWriter.Write(body) | ||
} |
Oops, something went wrong.