-
Notifications
You must be signed in to change notification settings - Fork 86
/
middleware.go
148 lines (125 loc) · 4.12 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package route
import (
"context"
"errors"
"fmt"
"math/rand"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/honeycombio/refinery/types"
)
// init seeds the package-level math/rand source from the current wall-clock
// time. randStringBytes draws from that source when generating request IDs.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
// queryTokenChecker is middleware guarding the /query endpoint. It compares
// the value of the types.QueryTokenHeader request header with the token
// returned by r.Config.GetQueryAuthToken and calls next only when they match.
//
// When no token is configured the endpoint is treated as not authorized for
// use: the request is rejected with ErrAuthNeeded and next is never called.
// Any mismatching token is rejected with ErrAuthNeeded as well.
func (r *Router) queryTokenChecker(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		requiredToken := r.Config.GetQueryAuthToken()
		if requiredToken == "" {
			err := fmt.Errorf("/query endpoint is not authorized for use (specify QueryAuthToken in config)")
			r.handlerReturnWithError(w, ErrAuthNeeded, err)
			// Must stop here. Falling through would let a request with no
			// token header compare equal to the empty required token and
			// reach next after an error response was already written.
			return
		}

		token := req.Header.Get(types.QueryTokenHeader)
		if token != requiredToken {
			err := fmt.Errorf("token %s found in %s not authorized for query", token, types.QueryTokenHeader)
			r.handlerReturnWithError(w, ErrAuthNeeded, err)
			return
		}

		// requiredToken is non-empty at this point, so a match means the
		// caller actually presented the configured token.
		next.ServeHTTP(w, req)
	})
}
// apiKeyChecker is middleware that lets a request through only when it
// carries an API key permitted by the configuration. The key is read from the
// types.APIKeyHeader header, falling back to types.APIKeyHeaderShort when the
// first is empty. A "*" entry in the list returned by r.Config.GetAPIKeys
// admits any non-empty key.
//
// A missing key or a key not on the list is rejected with ErrAuthNeeded; a
// failure to read the key list is reported as ErrConfigReadFailed.
func (r *Router) apiKeyChecker(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		presented := req.Header.Get(types.APIKeyHeader)
		if presented == "" {
			presented = req.Header.Get(types.APIKeyHeaderShort)
		}
		if presented == "" {
			r.handlerReturnWithError(w, ErrAuthNeeded,
				errors.New("no "+types.APIKeyHeader+" header found from within authing middleware"))
			return
		}

		permitted, err := r.Config.GetAPIKeys()
		if err != nil {
			r.handlerReturnWithError(w, ErrConfigReadFailed, err)
			return
		}

		for _, candidate := range permitted {
			// "*" is the allow-everything wildcard; otherwise the presented
			// key has to match an entry exactly.
			if candidate == "*" || candidate == presented {
				next.ServeHTTP(w, req)
				return
			}
		}

		r.handlerReturnWithError(w, ErrAuthNeeded,
			fmt.Errorf("api key %s not found in list of authed keys", presented))
	})
}
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// handed to WriteHeader so it can be read back after the handler has run.
type statusRecorder struct {
	http.ResponseWriter
	// status is the last code passed to WriteHeader. It keeps whatever value
	// the recorder was constructed with if WriteHeader is never called.
	status int
}

// WriteHeader stores code on the recorder and then forwards it to the wrapped
// ResponseWriter.
func (sr *statusRecorder) WriteHeader(code int) {
	sr.status = code
	sr.ResponseWriter.WriteHeader(sr.status)
}
// panicCatcher is middleware that recovers from any panic raised while next
// handles the request and reports it through handlerReturnWithError as
// ErrCaughtPanic. A recovered value that is already an error is passed along
// unchanged; anything else is formatted into a "caught panic" error.
func (r *Router) panicCatcher(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				// Normal return, nothing to report.
				return
			}

			var err error
			switch v := recovered.(type) {
			case error:
				err = v
			default:
				err = fmt.Errorf("caught panic: %v", v)
			}
			r.handlerReturnWithError(w, ErrCaughtPanic, err)
		}()

		next.ServeHTTP(w, req)
	})
}
// requestLogger is middleware that emits one debug log line per request. The
// line contains the matched route's name, a generated request ID, the remote
// address, method, URL, handling duration in milliseconds and the response
// status code.
//
// The request ID is also stored in the request context under
// types.RequestIDContextKey{} so downstream handlers can read it.
func (r *Router) requestLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		arrivalTime := time.Now()
		remoteIP := req.RemoteAddr
		url := req.URL.String()
		method := req.Method

		// mux.CurrentRoute returns nil when the request has no matched route
		// attached; guard against that so logging cannot panic after the
		// response has already been served.
		routeName := ""
		if route := mux.CurrentRoute(req); route != nil {
			routeName = route.GetName()
		}

		// generate a request ID and put it in the context for logging
		reqID := randStringBytes(8)
		req = req.WithContext(context.WithValue(req.Context(), types.RequestIDContextKey{}, reqID))

		// Wrap the writer so the status code can be read back afterwards;
		// 200 is the starting value if the handler never calls WriteHeader.
		wrapped := statusRecorder{w, 200}
		next.ServeHTTP(&wrapped, req)

		// duration as fractional milliseconds
		dur := float64(time.Since(arrivalTime)) / float64(time.Millisecond)

		// TODO better formatted http log line
		r.Logger.Debug().Logf("handled %s request %s %s %s %s %f %d", routeName, reqID, remoteIP, method, url, dur, wrapped.status)
	})
}
// setResponseHeaders is middleware that sets the response headers shared by
// every endpoint before handing the request to next.
func (r *Router) setResponseHeaders(next http.Handler) http.Handler {
	withHeaders := func(w http.ResponseWriter, req *http.Request) {
		hdr := w.Header()
		// Content type goes in early so it is set before any WriteHeader call.
		hdr.Set("Content-Type", "application/json")
		// Permit cross-origin API use from browser JavaScript.
		hdr.Set("Access-Control-Allow-Origin", "*")

		next.ServeHTTP(w, req)
	}
	return http.HandlerFunc(withHeaders)
}
// letterBytes is the alphabet request IDs are drawn from: the ASCII lowercase
// letters, uppercase letters and digits.
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randStringBytes returns a string of n characters, each picked from
// letterBytes using the package-level math/rand source. It is used to make
// request IDs for logging and is not suitable for secrets.
func randStringBytes(n int) string {
	alphabetLen := int64(len(letterBytes))
	id := make([]byte, n)
	for pos := 0; pos < n; pos++ {
		id[pos] = letterBytes[rand.Int63()%alphabetLen]
	}
	return string(id)
}