-
Notifications
You must be signed in to change notification settings - Fork 6
/
middleware.go
146 lines (128 loc) · 4.48 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package server
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"runtime/debug"

	opentracing "github.com/opentracing/opentracing-go"
	tags "github.com/opentracing/opentracing-go/ext"
	opentracinglog "github.com/opentracing/opentracing-go/log"
	"gopkg.in/Clever/kayvee-go.v6/logger"
)
// PanicMiddleware logs any panic raised while h serves a request, together with
// the stack trace captured inside the deferred recover, and then re-panics with
// the original value. The panic therefore keeps propagating up the stack, so
// this may still crash the process unless something further up recovers it.
func PanicMiddleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			panicErr := recover()
			if panicErr == nil {
				return
			}
			var err error
			switch p := panicErr.(type) {
			case string:
				// Use a constant format: the panic text is arbitrary and may contain
				// '%' verbs that fmt.Errorf(p) would misinterpret.
				err = fmt.Errorf("%s", p)
			case error:
				err = p
			default:
				err = fmt.Errorf("unknown panic %#v of type %T", p, p)
			}
			logger.FromContext(r.Context()).ErrorD("panic",
				logger.M{"err": err, "stacktrace": string(debug.Stack())})
			// Re-panic with the original value (not err) so any outer recover sees
			// exactly what was thrown.
			panic(panicErr)
		}()
		h.ServeHTTP(w, r)
	})
}
// statusResponseWriter wraps an http.ResponseWriter and remembers the status code
// of the response that was actually sent, so middleware can inspect it after the
// handler returns. Callers should initialize status to the code they want
// reported when the handler never writes anything (TracingMiddleware uses 200).
type statusResponseWriter struct {
	http.ResponseWriter
	status int
	// wroteHeader records that the final (non-informational) header has been
	// sent; after that, further WriteHeader calls cannot change the real status.
	wroteHeader bool
}

// WriteHeader records code as the response status unless a final status has
// already been sent, then forwards the call to the wrapped writer.
func (s *statusResponseWriter) WriteHeader(code int) {
	if !s.wroteHeader {
		s.status = code
		// 1xx informational codes (other than 101 Switching Protocols) may be
		// followed by the real final status, so they do not lock it in.
		if code < 100 || code > 199 || code == http.StatusSwitchingProtocols {
			s.wroteHeader = true
		}
	}
	s.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. A Write before any final WriteHeader
// implicitly sends 200 OK, so that is the status recorded.
func (s *statusResponseWriter) Write(b []byte) (int, error) {
	if !s.wroteHeader {
		s.status = http.StatusOK
		s.wroteHeader = true
	}
	return s.ResponseWriter.Write(b)
}
// tracingOpName is the context key under which TracingMiddleware stores a
// *string that layers further down the stack can fill in with an operation name.
type tracingOpName struct{}

// WithTracingOpName sets opName as the operation name for the tracing span of the
// current request. It is called below TracingMiddleware in the stack, so it passes
// the name back up by writing through the *string pointer the middleware placed in
// the context. ctx itself is not modified; it is returned unchanged only to keep
// the familiar "ctx = WithX(ctx, ...)" shape. If ctx carries no such pointer (for
// example, the request did not go through TracingMiddleware), the call is a no-op.
func WithTracingOpName(ctx context.Context, opName string) context.Context {
	// Comma-ok form: a bare type assertion panics when the key is absent.
	if strPtr, ok := ctx.Value(tracingOpName{}).(*string); ok && strPtr != nil {
		*strPtr = opName
	}
	return ctx
}
// TracingMiddleware starts a span for each request and places it in the request
// context, for use by other handlers via opentracing.SpanFromContext(). If span
// context can be extracted from the request headers, the new span is a child of
// that span; otherwise it is a new root span.
//
// The span is initially named after the URL path, because this middleware runs
// before the router. Handlers further down may rename it by calling
// WithTracingOpName with the request context; the name is applied after the
// handler returns, and only if a non-empty name was supplied. The span is tagged
// with the HTTP method, span kind, URL path and response status (and flagged as
// an error for 5xx responses), and the raw query string is logged on it.
func TracingMiddleware(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Attempt to join a span by getting trace info from the headers.
		// Start with the URL as the op name: the router, which knows the real
		// op names, has not run yet.
		opName := r.URL.Path
		var sp opentracing.Span
		if sc, err := opentracing.GlobalTracer().
			Extract(opentracing.HTTPHeaders,
				opentracing.HTTPHeadersCarrier(r.Header)); err != nil {
			sp = opentracing.StartSpan(opName)
		} else {
			sp = opentracing.StartSpan(opName, opentracing.ChildOf(sc))
		}
		defer sp.Finish()

		// Inject the span ID into logs to aid in request debugging.
		carrier := make(map[string]string)
		if err := sp.Tracer().Inject(sp.Context(), opentracing.TextMap,
			opentracing.TextMapCarrier(carrier)); err == nil {
			if spanid, ok := carrier["ot-tracer-spanid"]; ok {
				logger.FromContext(r.Context()).AddContext("ot-tracer-spanid", spanid)
			}
		}

		sp.LogEvent("request_received")
		defer func() {
			sp.LogEvent("request_finished")
		}()

		newCtx := opentracing.ContextWithSpan(r.Context(), sp)
		// Layers below write into this via WithTracingOpName; it stays empty if
		// nobody names the operation.
		routedOpName := ""
		newCtx = context.WithValue(newCtx, tracingOpName{}, &routedOpName)

		srw := &statusResponseWriter{
			status:         200,
			ResponseWriter: w,
		}

		tags.HTTPMethod.Set(sp, r.Method)
		tags.SpanKind.Set(sp, tags.SpanKindRPCServerEnum)
		tags.HTTPUrl.Set(sp, r.URL.Path)
		sp.LogFields(opentracinglog.String("url-query", r.URL.RawQuery))

		defer func() {
			tags.HTTPStatusCode.Set(sp, uint16(srw.status))
			if srw.status >= 500 {
				tags.Error.Set(sp, true)
			}
			// Only rename when a lower layer supplied a name; otherwise keep the
			// URL-path name rather than overwriting it with an empty string.
			if routedOpName != "" {
				sp.SetOperationName(routedOpName)
			}
		}()

		h.ServeHTTP(srw, r.WithContext(newCtx))
	})
}
// VersionRange reports whether a client version string is acceptable.
type VersionRange func(version string) bool

// ClientVersionCheckMiddleware gates requests on the X-Client-Version header.
// The header value (empty if absent) is added to the request logger's context as
// "client-version". If rng rejects it, the client gets 400 Bad Request with a
// JSON body {"message": "..."} and h is not called; otherwise the request is
// passed to h unchanged. rng must be non-nil: calling a nil VersionRange panics.
func ClientVersionCheckMiddleware(h http.Handler, rng VersionRange) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		version := r.Header.Get("X-Client-Version")
		logger.FromContext(r.Context()).AddContext("client-version", version)
		if !rng(version) {
			// The version is client-controlled, so JSON-encode the message rather
			// than splicing the raw header into the body. This stops quotes,
			// backslashes or control characters from breaking or injecting into
			// the JSON. Marshalling a plain string cannot fail.
			msg, _ := json.Marshal("client version '" + version + "' not accepted, please upgrade")
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusBadRequest)
			// Best effort: the status is already sent, so a write error cannot be
			// reported to the client.
			_, _ = w.Write([]byte(`{"message": ` + string(msg) + `}`))
			return
		}
		h.ServeHTTP(w, r)
	})
}