/
wrap_response_writer.go
153 lines (140 loc) · 5.56 KB
/
wrap_response_writer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package middleware
import (
"io"
"net/http"
"sync"
"github.com/felixge/httpsnoop"
)
var (
	// CollectHeaderHook is the default WriteHeader hook installed by
	// WrapResponseWriter. It first forwards the status code to the wrapped
	// WriteHeader and then, unless a header was already recorded (by an
	// earlier WriteHeader or Write), stores the code in the collector.
	CollectHeaderHook = func(mc *WriterMetricsCollector) func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
		return func(inner httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
			return func(code int) {
				inner(code)

				mc.locker.Lock()
				defer mc.locker.Unlock()
				if mc.wroteHeader {
					return
				}
				mc.wroteHeader = true
				mc.Code = code
			}
		}
	}

	// HijackWriteHeaderHook records the status code in the collector (the
	// first recorded header wins) but never calls the wrapped WriteHeader, so
	// the code is not sent to the client by this hook.
	HijackWriteHeaderHook = func(mc *WriterMetricsCollector) func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
		return func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
			return func(code int) {
				mc.locker.Lock()
				defer mc.locker.Unlock()
				if mc.wroteHeader {
					return
				}
				mc.wroteHeader = true
				mc.Code = code
			}
		}
	}

	// CollectBytesHook is the default Write hook installed by
	// WrapResponseWriter. It forwards each Write to the wrapped writer, adds
	// the reported byte count to the collector's Bytes, and marks the header
	// as written so that later WriteHeader calls no longer change Code.
	CollectBytesHook = func(mc *WriterMetricsCollector) func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
		return func(inner httpsnoop.WriteFunc) httpsnoop.WriteFunc {
			return func(p []byte) (int, error) {
				written, err := inner(p)

				mc.locker.Lock()
				defer mc.locker.Unlock()
				mc.wroteHeader = true
				mc.Bytes += int64(written)
				return written, err
			}
		}
	}
)
// CopyWriterHook returns a Write hook (to be passed to WriteHook) that
// forwards each Write to the wrapped ResponseWriter and then mirrors the bytes
// that were actually written into w. The forwarded byte count is added to the
// collector's Bytes and the header is marked as written.
//
// Copying into w is best-effort: an error from w.Write is ignored so that a
// failing copy destination never alters the result returned to the handler.
// The copy is performed while holding the collector's lock.
func CopyWriterHook(w io.Writer) func(collector *WriterMetricsCollector) func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
	return func(collector *WriterMetricsCollector) func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
		return func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
			return func(p []byte) (int, error) {
				n, err := next(p)

				collector.locker.Lock()
				defer collector.locker.Unlock()
				// Mirror only what the wrapped writer reported as written: on a
				// short write, p[n:] never reached the client and must not
				// appear in the copy either.
				if n > 0 {
					_, _ = w.Write(p[:n]) // best-effort copy, see doc comment
				}
				collector.Bytes += int64(n)
				collector.wroteHeader = true
				return n, err
			}
		}
	}
}
// HijackWriteHook returns a Write hook (to be passed to WriteHook) that sends
// every Write to w instead of the wrapped ResponseWriter; the wrapped Write is
// never called. The number of bytes w reports as written is added to the
// collector's Bytes, and the header is marked as written.
func HijackWriteHook(w io.Writer) func(collector *WriterMetricsCollector) func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
	hook := func(mc *WriterMetricsCollector) func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
		return func(httpsnoop.WriteFunc) httpsnoop.WriteFunc {
			return func(p []byte) (int, error) {
				mc.locker.Lock()
				defer mc.locker.Unlock()

				written, err := w.Write(p)
				mc.wroteHeader = true
				mc.Bytes += int64(written)
				return written, err
			}
		}
	}
	return hook
}
// WriterMetricsCollector holds the metrics captured from a wrapped
// ResponseWriter. The hooks in this package update it while holding its
// internal lock.
type WriterMetricsCollector struct {
	// Code is the status code from the first WriteHeader call made before
	// any Write. If no such call happens, Code keeps its initial value, which
	// WrapResponseWriter sets to http.StatusOK (200).
	Code int

	// Bytes is the number of bytes reported as written by the configured
	// Write hook. Data the ResponseWriter writes to its connection on its own
	// (e.g. headers) is not tracked, so Bytes usually equals the size of the
	// response body.
	Bytes int64

	// wroteHeader is set once a status code has been recorded or a Write has
	// happened; after that, further WriteHeader calls no longer change Code.
	wroteHeader bool

	// locker guards the fields above against concurrent updates by the hooks.
	locker sync.Mutex
}
// WriteHeaderHook returns a WrapResponseWriter option that replaces the hook
// intercepting WriteHeader calls (CollectHeaderHook by default) with hook.
func WriteHeaderHook(hook func(*WriterMetricsCollector) func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc) func(*wrapWriterOpts) {
	setHook := func(o *wrapWriterOpts) { o.writeHeaderHook = hook }
	return setHook
}
// WriteHook returns a WrapResponseWriter option that replaces the hook
// intercepting Write calls (CollectBytesHook by default) with hook.
func WriteHook(hook func(*WriterMetricsCollector) func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc) func(*wrapWriterOpts) {
	setHook := func(o *wrapWriterOpts) { o.writeHook = hook }
	return setHook
}
// wrapWriterOpts holds the hook factories WrapResponseWriter installs on the
// wrapped ResponseWriter. It is filled in by wrapWriterSetupCfg and modified
// through the WriteHeaderHook and WriteHook options.
type wrapWriterOpts struct {
	// writeHeaderHook builds the interceptor for WriteHeader calls.
	writeHeaderHook func(*WriterMetricsCollector) func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc

	// writeHook builds the interceptor for Write calls.
	writeHook func(*WriterMetricsCollector) func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc
}
// wrapWriterSetupCfg builds the configuration used by WrapResponseWriter. It
// starts from the default hooks (CollectHeaderHook and CollectBytesHook) and
// applies each option in order, so later options override earlier ones.
func wrapWriterSetupCfg(opts ...func(*wrapWriterOpts)) *wrapWriterOpts {
	var cfg wrapWriterOpts
	cfg.writeHeaderHook = CollectHeaderHook
	cfg.writeHook = CollectBytesHook

	for i := range opts {
		opts[i](&cfg)
	}
	return &cfg
}
// WrapResponseWriter wraps w with httpsnoop so that WriteHeader and Write calls
// are routed through the configured hooks. The defaults are CollectHeaderHook
// and CollectBytesHook; override them with the WriteHeaderHook and WriteHook
// options.
//
// It returns the WriterMetricsCollector the hooks record into, with Code
// initialised to http.StatusOK, together with the wrapped ResponseWriter that
// handlers should write to.
//
// If w implements io.ReaderFrom, the wrapped writer does too. Its ReadFrom
// (used, for example, by io.Copy) is served by copying through the wrapper's
// own Write method rather than w's ReadFrom. Otherwise those bytes would bypass
// the Write hook entirely: they would not be counted, the header would not be
// marked as written, and hooks such as HijackWriteHook or CopyWriterHook would
// never see them. The trade-off is that any fast path offered by w's own
// ReadFrom is not used.
func WrapResponseWriter(w http.ResponseWriter, opts ...func(*wrapWriterOpts)) (*WriterMetricsCollector, http.ResponseWriter) {
	cfg := wrapWriterSetupCfg(opts...)
	collector := &WriterMetricsCollector{Code: http.StatusOK}

	// Declared before the hooks so the ReadFrom hook can reach the final
	// wrapper; the hook only runs after Wrap has returned and snoop is set.
	var snoop http.ResponseWriter
	hooks := httpsnoop.Hooks{
		WriteHeader: cfg.writeHeaderHook(collector),
		Write:       cfg.writeHook(collector),
		ReadFrom: func(httpsnoop.ReadFromFunc) httpsnoop.ReadFromFunc {
			return func(src io.Reader) (int64, error) {
				// Hide snoop's ReadFrom behind a plain io.Writer so io.Copy
				// falls back to Write (and thus the Write hook) instead of
				// recursing into this function.
				return io.Copy(struct{ io.Writer }{snoop}, src)
			}
		},
	}
	snoop = httpsnoop.Wrap(w, hooks)
	return collector, snoop
}