/
middleware.go
84 lines (67 loc) · 2.03 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package main
import (
"context"
"log/slog"
"net/http"
"time"
"github.com/google/uuid"
"github.com/lstoll/cookiesession"
)
// requestIDCtxKey is the context key under which baseMiddleware stores the
// per-request ID string. Being an unexported empty struct type, it cannot
// collide with context keys defined by other packages.
type requestIDCtxKey struct{}
// baseMiddleware should wrap all requests to the service.
//
// For every request it:
//   - uses the Fly-Request-ID header as the request ID, or a freshly
//     generated UUID when that header is empty, and stores it in the request
//     context under requestIDCtxKey{};
//   - serves the request through wnsessmgr's session handling around wrapped;
//   - logs one "http request" line carrying the request ID, method, path,
//     response status, remote address chain and handling duration.
//
// If the handler never sets a status explicitly, the logged status is 200.
func baseMiddleware(wrapped http.Handler, wnsessmgr *cookiesession.Manager[webSession, *webSession]) http.Handler {
	// The session-wrapped handler depends only on wrapped, which is fixed for
	// the lifetime of this middleware, so build it once instead of on every
	// request.
	// NOTE(review): assumes Manager.Wrap has no per-call side effects —
	// confirm against the cookiesession docs.
	sessHandler := wnsessmgr.Wrap(wrapped)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		st := time.Now()

		// Presumably populated by the Fly.io proxy in front of the service;
		// fall back to a locally generated ID so every request has one.
		rid := r.Header.Get("Fly-Request-ID")
		if rid == "" {
			rid = uuid.NewString()
		}
		r = r.WithContext(context.WithValue(r.Context(), requestIDCtxKey{}, rid))
		logger := slog.With(slog.String("request_id", rid))

		ww := &wrapResponseWriter{
			ResponseWriter: w,
		}
		sessHandler.ServeHTTP(ww, r)
		if ww.st == 0 {
			// WriteHeader is not guaranteed to be called, so we need to set a
			// default.
			ww.st = http.StatusOK
		}

		remoteAddr := r.RemoteAddr
		if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
			// X-Forwarded-For is client-controllable and may be spoofed, so
			// keep the directly connected peer address alongside it.
			remoteAddr = xff + ", " + remoteAddr
		}

		logger.Info("http request",
			slog.String("method", r.Method),
			slog.String("path", r.URL.Path),
			slog.Int("status", ww.st),
			slog.String("remote-address", remoteAddr),
			slog.Duration("duration", time.Since(st)))
	})
}
// httpErrHandler turns handler failures into plain-text HTTP error responses
// written with http.Error. It holds no state.
type httpErrHandler struct{}

// Error logs err (using the request's context) and replies 500 with the
// fixed body "Internal Error" rather than the error's own text.
func (*httpErrHandler) Error(w http.ResponseWriter, r *http.Request, err error) {
	ctx := r.Context()
	// logErr is defined elsewhere in this package; presumably it converts err
	// into a slog attribute — confirm.
	slog.ErrorContext(ctx, "http error", logErr(err))
	http.Error(w, "Internal Error", http.StatusInternalServerError)
}

// BadRequest replies 400 with message as the response body.
func (*httpErrHandler) BadRequest(w http.ResponseWriter, _ *http.Request, message string) {
	const code = http.StatusBadRequest
	http.Error(w, message, code)
}

// Forbidden replies 403 with message as the response body.
func (*httpErrHandler) Forbidden(w http.ResponseWriter, _ *http.Request, message string) {
	const code = http.StatusForbidden
	http.Error(w, message, code)
}
// wrapResponseWriter wraps an http.ResponseWriter to remember the final
// status code sent to the client, so the request can be logged afterwards.
// st is 0 until a final status has been written, explicitly or implicitly.
type wrapResponseWriter struct {
	http.ResponseWriter
	st int
}

// WriteHeader records code as the response status and forwards it to the
// underlying writer.
//
// net/http only honours the first final status; later calls are superfluous
// and ignored by the server, so only the first one is recorded here.
// Informational 1xx codes (other than 101 Switching Protocols) may precede
// the final status, so they are forwarded but not recorded.
func (w *wrapResponseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if w.st == 0 && !informational {
		w.st = code
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the underlying writer. A Write before any WriteHeader
// makes net/http send an implicit 200 OK, so that status is recorded here.
// This stops a later, ignored WriteHeader call from being logged instead.
func (w *wrapResponseWriter) Write(b []byte) (int, error) {
	if w.st == 0 {
		w.st = http.StatusOK
	}
	return w.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter. http.ResponseController uses it
// to reach optional capabilities (Flush, Hijack, deadlines) of the underlying
// writer that this wrapper does not itself implement.
func (w *wrapResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}