-
Notifications
You must be signed in to change notification settings - Fork 7
/
middlewares.go
126 lines (104 loc) · 2.92 KB
/
middlewares.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package server
import (
"fmt"
"net/http"
"strings"
"time"
gorillamux "github.com/gorilla/mux"
"github.com/rs/xid"
"go.opencensus.io/plugin/ochttp"
"go.opencensus.io/tag"
"go.opencensus.io/trace"
"go.uber.org/zap"
)
// Shared constants used by the HTTP middlewares in this file.
const (
	// headerRequestID is the HTTP header that carries the request identifier
	// on both incoming requests and outgoing responses.
	headerRequestID = "X-Request-Id"

	// grpcGatewayPrefix is the path prefix used to recognise requests whose
	// route is reported under a single collapsed "/api/" label.
	grpcGatewayPrefix = "/api"
)
// wrappedWriter decorates an http.ResponseWriter and remembers the status
// code of the response so it can be inspected after the handler returns.
type wrappedWriter struct {
	http.ResponseWriter

	// Status is the status code most recently passed to WriteHeader, or
	// http.StatusOK once a body has been written without an explicit
	// WriteHeader call. It is zero until one of those happens.
	Status int
}

// WriteHeader records statusCode and forwards the call to the wrapped writer.
func (wr *wrappedWriter) WriteHeader(statusCode int) {
	wr.Status = statusCode
	wr.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards b to the wrapped writer and returns its result.
//
// net/http implicitly sends 200 OK when a handler writes a body without
// calling WriteHeader first, so that status is recorded here; otherwise
// Status would remain zero for such responses.
func (wr *wrappedWriter) Write(b []byte) (int, error) {
	if wr.Status == 0 {
		wr.Status = http.StatusOK
	}
	return wr.ResponseWriter.Write(b)
}
// withOpenCensus returns a middleware that serves requests through an
// ochttp.Handler and tags the request context with the server route and the
// HTTP method before handing the request over.
//
// The route tag is the matched gorilla/mux path template when one is
// available and the raw URL path otherwise. Any value that starts with
// grpcGatewayPrefix is collapsed to "/api/".
func withOpenCensus() gorillamux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		// Built once per middleware instance and shared by all requests.
		oc := &ochttp.Handler{
			Handler:          next,
			FormatSpanName:   formatSpanName,
			IsPublicEndpoint: false,
		}

		return http.HandlerFunc(func(wr http.ResponseWriter, req *http.Request) {
			pathTpl := req.URL.Path
			if route := gorillamux.CurrentRoute(req); route != nil {
				// Keep the URL path as the fallback when the matched route
				// cannot report a path template, instead of tagging "".
				if tpl, err := route.GetPathTemplate(); err == nil {
					pathTpl = tpl
				}
			}

			if strings.HasPrefix(pathTpl, grpcGatewayPrefix) {
				// TODO: figure out a way to extract path-pattern from gateway requests.
				pathTpl = "/api/"
			}

			// Tagging is best-effort: if the tags cannot be applied, the
			// request is still served with its original context.
			ctx := req.Context()
			if tagged, err := tag.New(ctx,
				tag.Insert(ochttp.KeyServerRoute, pathTpl),
				tag.Insert(ochttp.Method, req.Method),
			); err == nil {
				ctx = tagged
			}

			oc.ServeHTTP(wr, req.WithContext(ctx))
		})
	}
}
// requestID returns a middleware that makes sure every request carries a
// request identifier in the X-Request-Id header. A non-blank incoming value
// is reused after trimming surrounding whitespace; otherwise a new xid is
// generated. The identifier is set on the response headers and on a cloned
// copy of the request headers before the next handler runs.
func requestID() gorillamux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		handle := func(wr http.ResponseWriter, req *http.Request) {
			id := strings.TrimSpace(req.Header.Get(headerRequestID))
			if id == "" {
				id = xid.New().String()
			}

			// Echo the identifier back to the client.
			wr.Header().Set(headerRequestID, id)

			// Work on a clone so the original header map is left unmodified.
			cloned := req.Header.Clone()
			cloned.Set(headerRequestID, id)
			req.Header = cloned

			next.ServeHTTP(wr, req)
		}
		return http.HandlerFunc(handle)
	}
}
// requestLogger returns a middleware that logs one entry per handled request
// to lg. Each entry carries the path, method, request id header, basic-auth
// user name (as client_id), trace id, response time and response status.
// Responses with a 2xx status are logged at Info level, all others at Warn.
func requestLogger(lg *zap.Logger) gorillamux.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(wr http.ResponseWriter, req *http.Request) {
			start := time.Now()
			span := trace.FromContext(req.Context())

			// Only the user name is of interest; without basic-auth
			// credentials clientID is simply logged as an empty string.
			clientID, _, _ := req.BasicAuth()

			fields := []zap.Field{
				zap.String("path", req.URL.Path),
				zap.String("method", req.Method),
				zap.String("request_id", req.Header.Get(headerRequestID)),
				zap.String("client_id", clientID),
				zap.String("trace_id", span.SpanContext().TraceID.String()),
			}

			// Seed the status with 200: net/http sends 200 OK implicitly when
			// a handler never calls WriteHeader, and such responses must not
			// be reported as status 0 / non-2xx.
			wrapped := &wrappedWriter{ResponseWriter: wr, Status: http.StatusOK}
			next.ServeHTTP(wrapped, req)

			fields = append(fields,
				zap.Duration("response_time", time.Since(start)),
				zap.Int("status", wrapped.Status),
			)

			if is2xx(wrapped.Status) {
				lg.Info("request handled", fields...)
				return
			}
			lg.Warn("request handled with non-2xx response", fields...)
		})
	}
}
// formatSpanName builds the span name for req as "<METHOD> <path>", where
// path is the matched gorilla/mux path template when one is available and
// the raw URL path otherwise.
func formatSpanName(req *http.Request) string {
	pathTpl := req.URL.Path
	if route := gorillamux.CurrentRoute(req); route != nil {
		// Keep the URL path as the fallback when the matched route cannot
		// report a path template, instead of producing a name like "GET ".
		if tpl, err := route.GetPathTemplate(); err == nil {
			pathTpl = tpl
		}
	}
	return fmt.Sprintf("%s %s", req.Method, pathTpl)
}
// is2xx reports whether status is a successful HTTP status code, i.e. lies
// in the inclusive range 200-299.
func is2xx(status int) bool {
	// http.StatusMultipleChoices (300) is the first code past the 2xx class,
	// so it serves as the exclusive upper bound and 299 is included.
	return status >= http.StatusOK && status < http.StatusMultipleChoices
}