Skip to content

Commit

Permalink
feat(http): add http server middleware (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
clambin committed Mar 19, 2024
1 parent 4c53181 commit 99e2edd
Show file tree
Hide file tree
Showing 9 changed files with 523 additions and 0 deletions.
47 changes: 47 additions & 0 deletions http/middleware/logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package middleware

import (
"log/slog"
"net/http"
"time"
)

// RequestLogger logs incoming HTTP requests.
func RequestLogger(logger *slog.Logger, logLevel slog.Level, formatter RequestLogFormatter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
lrw := &loggingResponseWriter{ResponseWriter: w}

next.ServeHTTP(lrw, r)

logger.LogAttrs(r.Context(), logLevel, "http request", formatter.FormatRequest(r, lrw.statusCode, time.Since(start))...)
}

return http.HandlerFunc(fn)
}
}

// A RequestLogFormatter takes the HTTP request, the resulting HTTP status code and latency and formats the log entry.
type RequestLogFormatter interface {
FormatRequest(*http.Request, int, time.Duration) []slog.Attr
}

// DefaultRequestLogFormatter is the default RequestLogFormatter. It logs the request's HTTP method and the path.
var DefaultRequestLogFormatter RequestLogFormatter = &defaultRequestLogFormatter{}

type defaultRequestLogFormatter struct{}

func (d defaultRequestLogFormatter) FormatRequest(r *http.Request, statusCode int, latency time.Duration) []slog.Attr {
return []slog.Attr{slog.String("path", r.URL.Path), slog.String("method", r.Method),
slog.Int("code", statusCode), slog.Duration("latency", latency),
}
}

// The RequestLogFormatterFunc type is an adapter that allows an ordinary function to be used as a RequestLogFormatter.
type RequestLogFormatterFunc func(r *http.Request, statusCode int, latency time.Duration) []slog.Attr

// FormatRequest calls f(r, statusCode, latency)
func (f RequestLogFormatterFunc) FormatRequest(r *http.Request, statusCode int, latency time.Duration) []slog.Attr {
return f(r, statusCode, latency)
}
130 changes: 130 additions & 0 deletions http/middleware/logger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package middleware_test

import (
"bytes"
"github.com/clambin/go-common/http/middleware"
"github.com/stretchr/testify/assert"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
)

func TestLogger(t *testing.T) {
out := bytes.NewBufferString("")
opt := slog.HandlerOptions{Level: slog.LevelDebug, ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
// Remove time from the output for predictable test output.
if a.Key == slog.TimeKey {
return slog.Attr{}
}
return a
}}
l := slog.New(slog.NewTextHandler(out, &opt))
slog.SetDefault(l)

testCases := []struct {
name string
level slog.Level
logger middleware.RequestLogFormatter
want string
}{
{
name: "default",
logger: middleware.DefaultRequestLogFormatter,
want: `level=INFO msg="http request" path=/ method=GET code=200 latency=`,
},
{
name: "none",
level: slog.LevelDebug,
logger: middleware.DefaultRequestLogFormatter,
},
{
name: "custom",
logger: middleware.RequestLogFormatterFunc(func(r *http.Request, code int, latency time.Duration) []slog.Attr {
return []slog.Attr{
slog.String("client", r.RemoteAddr),
slog.String("path", r.URL.Path), slog.String("method", r.Method),
slog.Int("code", code), slog.Duration("latency", latency),
}
}),
want: `level=INFO msg="http request" client=127.0.0.1:5000 path=/ method=GET code=200 latency=`,
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
out.Reset()

r := http.NewServeMux()
r.Handle("/", middleware.RequestLogger(l, tt.level, tt.logger)(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
_, _ = writer.Write([]byte("hello"))
})))

req, _ := http.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "127.0.0.1:5000"
w := httptest.NewRecorder()
r.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, out.String(), tt.want)

})
}
}

/*
func TestDefaultLogger(t *testing.T) {
out := bytes.NewBufferString("")
opt := slog.HandlerOptions{Level: slog.LevelDebug, ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
// Remove time from the output for predictable test output.
if a.Key == slog.TimeKey {
return slog.Attr{}
}
return a
}}
l := slog.New(slog.NewTextHandler(out, &opt))
slog.SetDefault(l)
r := http.NewServeMux()
r.Handle("/", middleware.RequestLogger(middleware.DefaultRequestLogger{})(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
_, _ = writer.Write([]byte("hello"))
})))
req, _ := http.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, out.String(), `level=INFO msg=request path=/ method=GET code=200 latency=`)
}
func BenchmarkLogger(b *testing.B) {
out := bytes.NewBufferString("")
opt := slog.HandlerOptions{Level: slog.LevelInfo, ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
// Remove time from the output for predictable test output.
if a.Key == slog.TimeKey {
return slog.Attr{}
}
return a
}}
l := slog.New(slog.NewTextHandler(out, &opt))
slog.SetDefault(l)
r := http.NewServeMux()
r.Handle("/", middleware.RequestLogger(middleware.DefaultRequestLogger{})(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
_, _ = writer.Write([]byte("hello"))
})))
req, _ := http.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
for i := 0; i < b.N; i++ {
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
b.Fail()
}
}
}
*/
181 changes: 181 additions & 0 deletions http/middleware/prometheus.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package middleware

import (
"github.com/prometheus/client_golang/prometheus"
"net/http"
"strconv"
"time"
)

var _ http.Handler = serverMetricsHandler{}

type serverMetricsHandler struct {
next http.Handler
metrics ServerMetrics
}

func WithServerMetrics(m ServerMetrics) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return serverMetricsHandler{next: next, metrics: m}
}
}

func (s serverMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
start := time.Now()
s.next.ServeHTTP(lrw, r)
s.metrics.Measure(r, lrw.statusCode, time.Since(start))
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////

type ServerMetrics interface {
Measure(req *http.Request, statusCode int, duration time.Duration)
prometheus.Collector
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////

var _ ServerMetrics = &DefaultServerSummaryMetrics{}

type DefaultServerSummaryMetrics struct {
requests *prometheus.CounterVec
duration *prometheus.SummaryVec
}

func NewDefaultServerSummaryMetrics(namespace, subsystem, application string) *DefaultServerSummaryMetrics {
var constLabels map[string]string
if application != "" {
constLabels = map[string]string{"application": application}
}
return &DefaultServerSummaryMetrics{
requests: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "http_server_requests_total",
Help: "total number of http server requests",
ConstLabels: constLabels,
},
[]string{"method", "path", "code"},
),
duration: prometheus.NewSummaryVec(prometheus.SummaryOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "http_server_request_duration_seconds",
Help: "total number of http server requests",
ConstLabels: constLabels,
},
[]string{"method", "path", "code"},
),
}
}

func (d DefaultServerSummaryMetrics) Measure(req *http.Request, statusCode int, duration time.Duration) {
code := strconv.Itoa(statusCode)
d.requests.WithLabelValues(req.Method, req.URL.Path, code).Inc()
d.duration.WithLabelValues(req.Method, req.URL.Path, code).Observe(duration.Seconds())
}

func (d DefaultServerSummaryMetrics) Describe(ch chan<- *prometheus.Desc) {
d.requests.Describe(ch)
d.duration.Describe(ch)
}

func (d DefaultServerSummaryMetrics) Collect(ch chan<- prometheus.Metric) {
d.requests.Collect(ch)
d.duration.Collect(ch)
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////

var _ ServerMetrics = &DefaultServerHistogramMetrics{}

type DefaultServerHistogramMetrics struct {
requests *prometheus.CounterVec
duration *prometheus.HistogramVec
}

func NewDefaultServerHistogramMetrics(namespace, subsystem, application string, buckets ...float64) *DefaultServerHistogramMetrics {
var constLabels map[string]string
if application != "" {
constLabels = map[string]string{"application": application}
}
if len(buckets) == 0 {
buckets = prometheus.DefBuckets
}
return &DefaultServerHistogramMetrics{
requests: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "http_server_requests_total",
Help: "total number of http server requests",
ConstLabels: constLabels,
},
[]string{"method", "path", "code"},
),
duration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "http_server_request_duration_seconds",
Help: "total number of http server requests",
ConstLabels: constLabels,
Buckets: buckets,
},
[]string{"method", "path", "code"},
),
}
}

func (d DefaultServerHistogramMetrics) Measure(req *http.Request, statusCode int, duration time.Duration) {
code := strconv.Itoa(statusCode)
d.requests.WithLabelValues(req.Method, req.URL.Path, code).Inc()
d.duration.WithLabelValues(req.Method, req.URL.Path, code).Observe(duration.Seconds())
}

func (d DefaultServerHistogramMetrics) Describe(ch chan<- *prometheus.Desc) {
d.requests.Describe(ch)
d.duration.Describe(ch)
}

func (d DefaultServerHistogramMetrics) Collect(ch chan<- prometheus.Metric) {
d.requests.Collect(ch)
d.duration.Collect(ch)
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////

var _ ServerMetrics = NoMetrics{}

type NoMetrics struct{}

func (n NoMetrics) Measure(_ *http.Request, _ int, _ time.Duration) {
}

func (n NoMetrics) Describe(_ chan<- *prometheus.Desc) {
}

func (n NoMetrics) Collect(_ chan<- prometheus.Metric) {
}

////////////////////////////////////////////////////////////////////////////////////////////////////////

type loggingResponseWriter struct {
http.ResponseWriter
wroteHeader bool
statusCode int
}

// WriteHeader implements the http.ResponseWriter interface.
func (w *loggingResponseWriter) WriteHeader(code int) {
w.ResponseWriter.WriteHeader(code)
w.statusCode = code
w.wroteHeader = true
}

// Write implements the http.ResponseWriter interface.
func (w *loggingResponseWriter) Write(body []byte) (int, error) {
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
return w.ResponseWriter.Write(body)
}
Loading

0 comments on commit 99e2edd

Please sign in to comment.