/
middleware.go
120 lines (106 loc) · 3.11 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package middleware
import (
"fmt"
"net/http"
"runtime"
"github.com/felixge/httpsnoop"
"github.com/honeycombio/beeline-go"
"github.com/honeycombio/beeline-go/wrappers/hnynethttp"
ctxdata "github.com/peterbourgon/ctxdata/v4"
"github.com/rs/zerolog"
"golang.org/x/time/rate"
)
// Wrap returns h wrapped in the service's standard middleware stack: Honeycomb
// tracing, request logging, panic recovery, authentication, rate limiting and
// CORS handling.
//
// The layers are nested below from outermost to innermost, which is also the
// order an incoming request passes through them before reaching h:
//
//	hnynethttp tracing -> observeHandler -> panicHandler -> authHandler ->
//	rateLimitHandler -> corsHandler -> h
//
// NOTE(review): because corsHandler is innermost, CORS preflight requests only
// reach it after authentication and rate limiting; confirm authHandler lets
// them through.
func Wrap(h http.Handler, authMap AuthMap, rl *rate.Limiter, l zerolog.Logger) http.Handler {
	return hnynethttp.WrapHandler(
		observeHandler(
			panicHandler(
				authHandler(
					rateLimitHandler(corsHandler(h), rl),
					authMap,
				),
			),
			l,
		),
	)
}
// observeHandler wraps next so that every request produces exactly one
// structured log line, plus Honeycomb trace fields when an error was recorded.
//
// Before calling next it installs a fresh ctxdata.Data in the request context.
// Inner layers may store "client_id", "error" and "stack" values there;
// panicHandler sets the latter two. After next returns, the metrics captured
// by httpsnoop (status, bytes written, duration) are combined with those values
// into a logRecord.
//
// The record is logged at error level when an error was recorded or the
// response status is 4xx/5xx. All other responses, including non-200 successes
// such as 201/204 and 3xx redirects, are logged at info level.
func observeHandler(next http.Handler, l zerolog.Logger) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx, d := ctxdata.New(r.Context())
		m := httpsnoop.CaptureMetrics(next, w, r.WithContext(ctx))

		rec := logRecord{
			ClientID:   d.GetString("client_id"),
			DurationMS: m.Duration.Milliseconds(),
			Method:     r.Method,
			RemoteAddr: getRemoteAddr(r),
			Size:       m.Written,
			Status:     m.Code,
			URL:        r.URL.String(),
			UserAgent:  r.Header.Get("User-Agent"),
		}

		if err := d.GetError("error"); err != nil {
			rec.Error = err.Error()
			// Record the message text so the trace field is a plain string
			// regardless of how an error value would be serialized.
			beeline.AddField(ctx, "error", rec.Error)
			// A stack is only expected alongside an error (panicHandler
			// stores both together).
			if stack := d.GetString("stack"); stack != "" {
				rec.Stack = stack
				beeline.AddField(ctx, "stack", stack)
			}
		}

		// Pick the level before creating the event, instead of building an
		// info event only to discard it.
		var evt *zerolog.Event
		if rec.Error != "" || rec.Status >= http.StatusBadRequest {
			evt = l.Error()
		} else {
			evt = l.Info()
		}
		evt.EmbedObject(rec).Send()
	})
}
// panicHandler recovers panics raised by next and answers with
// 500 Internal Server Error. Without it, net/http's default is to log the
// panic and close the connection.
//
// On panic it stores an "error" value and the goroutine's "stack" trace in
// the request's ctxdata.Data. In Wrap, observeHandler wraps this handler and
// installs that Data.
//
// Panics with http.ErrAbortHandler are re-raised. net/http treats that sentinel
// as a deliberate request to abort the response, and recovering it here would
// turn the abort into a 500.
func panicHandler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			p := recover()
			if p == nil {
				return
			}
			if p == http.ErrAbortHandler {
				panic(p)
			}

			// Keep error values wrapped so errors.Is/As still work. Format
			// anything else with %v, which (unlike %s) renders non-string
			// panic values such as ints or structs readably.
			var perr error
			if e, ok := p.(error); ok {
				perr = fmt.Errorf("panic: %w", e)
			} else {
				perr = fmt.Errorf("panic: %v", p)
			}

			// Grow the buffer until runtime.Stack no longer fills it, so
			// deep stacks are not truncated.
			buf := make([]byte, 4096)
			for {
				n := runtime.Stack(buf, false)
				if n < len(buf) {
					buf = buf[:n]
					break
				}
				buf = make([]byte, 2*len(buf))
			}

			// Set errors are ignored on purpose: failing to annotate the log
			// must not prevent the 500 response below.
			d := ctxdata.From(r.Context())
			_ = d.Set("error", perr)
			_ = d.Set("stack", string(buf))

			// NOTE(review): if next already wrote headers before panicking,
			// this call is superfluous and the earlier status is what the
			// client sees.
			w.WriteHeader(http.StatusInternalServerError)
		}()
		next.ServeHTTP(w, r)
	})
}
// getRemoteAddr reports the client address for r. A non-empty Fly-Client-IP
// header takes precedence; otherwise r.RemoteAddr is returned unchanged.
//
// NOTE(review): request headers are client-controlled. This is only
// trustworthy if every request arrives through Fly's proxy, which sets the
// header; confirm the deployment guarantees that.
func getRemoteAddr(r *http.Request) string {
	if flyIP := r.Header.Get("Fly-Client-IP"); flyIP != "" {
		return flyIP
	}
	return r.RemoteAddr
}
// logRecord is the per-request summary that observeHandler emits as a single
// log line. The JSON tags mirror the keys written by MarshalZerologObject.
type logRecord struct {
	ClientID   string `json:"client_id"`
	DurationMS int64  `json:"duration_ms"`
	Method     string `json:"method"`
	RemoteAddr string `json:"remote_addr"`
	Size       int64  `json:"size"`
	Status     int    `json:"status"`
	URL        string `json:"url"`
	UserAgent  string `json:"user_agent"`

	// Error is empty unless the request recorded an error. Stack is only
	// filled in alongside Error, typically from a recovered panic.
	Error string `json:"error,omitempty"`
	Stack string `json:"stack,omitempty"`
}

// MarshalZerologObject writes the record's fields onto e.
//
// The request fields are always written, in a fixed order. "error" is added
// only when Error is set, and "stack" only when both Error and Stack are set.
func (lr logRecord) MarshalZerologObject(e *zerolog.Event) {
	e.Int("status", lr.Status).
		Int64("duration_ms", lr.DurationMS).
		Int64("size", lr.Size).
		Str("method", lr.Method).
		Str("remote_addr", lr.RemoteAddr).
		Str("url", lr.URL).
		Str("user_agent", lr.UserAgent).
		Str("client_id", lr.ClientID)

	if lr.Error == "" {
		return
	}
	e.Str("error", lr.Error)
	if lr.Stack != "" {
		e.Str("stack", lr.Stack)
	}
}

// Compile-time check that logRecord satisfies zerolog.LogObjectMarshaler.
var _ zerolog.LogObjectMarshaler = logRecord{}