forked from earthboundkid/requests
/
recorder.go
123 lines (116 loc) · 3.4 KB
/
recorder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package requests
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/base64"
"errors"
"fmt"
"io/fs"
"net/http"
"net/http/httputil"
"os"
"path/filepath"
)
// Record returns an http.RoundTripper that writes out its
// requests and their responses to text files in basepath.
// Requests are named according to a hash of their contents.
// Responses are named according to the request that made them.
//
// If rt is nil, http.DefaultTransport performs the actual round trip.
// The request (including its body) is dumped and written before the
// round trip; the response (including its body) is dumped and written
// afterwards. If the round trip itself fails, no response file is written.
// Every error returned by the transport is wrapped with the prefix
// "problem while recording transport".
func Record(rt http.RoundTripper, basepath string) Transport {
	if rt == nil {
		rt = http.DefaultTransport
	}
	return RoundTripFunc(func(req *http.Request) (res *http.Response, err error) {
		// Add context to whatever error ends up being returned below.
		defer func() {
			if err != nil {
				err = fmt.Errorf("problem while recording transport: %w", err)
			}
		}()

		// Deliberately best-effort: if basepath is genuinely unusable, the
		// os.WriteFile call below fails and reports the problem.
		_ = os.MkdirAll(basepath, 0755)

		b, err := httputil.DumpRequest(req, true)
		if err != nil {
			return nil, err
		}
		reqname, resname := buildName(b)
		if err = os.WriteFile(filepath.Join(basepath, reqname), b, 0644); err != nil {
			return nil, err
		}

		if res, err = rt.RoundTrip(req); err != nil {
			return res, err
		}

		b, err = httputil.DumpResponse(res, true)
		if err != nil {
			// The caller never sees res on this path, so nobody else can
			// release the body; close it here to avoid leaking it. The
			// dump error is the one worth reporting, so a Close error is
			// intentionally dropped.
			if res.Body != nil {
				_ = res.Body.Close()
			}
			return nil, err
		}
		if err = os.WriteFile(filepath.Join(basepath, resname), b, 0644); err != nil {
			return nil, err
		}
		return res, nil
	})
}
// Replay returns an http.RoundTripper that serves responses from
// text files stored in the directory basepath instead of the network.
// The file for a request is located via a hash of that request.
// It is ReplayFS applied to os.DirFS(basepath).
func Replay(basepath string) Transport {
	dir := os.DirFS(basepath)
	return ReplayFS(dir)
}
// errNotFound is the sentinel error wrapped by ReplayFS when no replay
// file matches a request. Caching tests for it with errors.Is to decide
// whether to fall back to recording.
var errNotFound = errors.New("response not found")
// ReplayFS returns an http.RoundTripper that serves responses from
// text files found in fsys instead of the network.
// The request is dumped (body included) and hashed to obtain the
// response file name; any file whose name ends with that name matches,
// so response files may carry a human-readable prefix.
// Zero matches yields an error wrapping errNotFound; more than one
// match is reported as ambiguous. Every error is wrapped with the
// prefix "problem while replaying transport".
func ReplayFS(fsys fs.FS) Transport {
	// lookup does the actual work; the returned transport only adds
	// context to its errors.
	lookup := func(req *http.Request) (*http.Response, error) {
		dump, err := httputil.DumpRequest(req, true)
		if err != nil {
			return nil, err
		}
		_, resname := buildName(dump)
		pattern := "*" + resname
		found, err := fs.Glob(fsys, pattern)
		if err != nil {
			return nil, err
		}
		switch len(found) {
		case 0:
			return nil, fmt.Errorf("%w: no replay file matches %q", errNotFound, pattern)
		case 1:
			// Exactly one candidate: use it.
		default:
			return nil, fmt.Errorf("ambiguous response: multiple replay files match %q", pattern)
		}
		raw, err := fs.ReadFile(fsys, found[0])
		if err != nil {
			return nil, err
		}
		return http.ReadResponse(bufio.NewReader(bytes.NewReader(raw)), req)
	}

	return RoundTripFunc(func(req *http.Request) (*http.Response, error) {
		res, err := lookup(req)
		if err != nil {
			return res, fmt.Errorf("problem while replaying transport: %w", err)
		}
		return res, nil
	})
}
// buildName derives the request and response file names for the dumped
// request b. Both names share a prefix made of the first eight characters
// of the URL-safe base64 encoding of b's MD5 digest; the request name
// ends in ".req.txt" and the response name in ".res.txt".
func buildName(b []byte) (reqname, resname string) {
	digest := md5.Sum(b)
	prefix := base64.URLEncoding.EncodeToString(digest[:])[:8]
	reqname = prefix + ".req.txt"
	resname = prefix + ".res.txt"
	return reqname, resname
}
// Caching returns an http.RoundTripper that first tries to serve a
// response from the text files in basepath. Only when the replay
// reports that no stored response exists (an error matching
// errNotFound) is the request issued through rt and the exchange
// recorded into basepath. Any other replay result, success or
// failure, is returned unchanged.
// Requests are named according to a hash of their contents.
// Responses are named according to the request that made them.
func Caching(rt http.RoundTripper, basepath string) Transport {
	replay := Replay(basepath).RoundTrip
	record := Record(rt, basepath).RoundTrip
	return RoundTripFunc(func(req *http.Request) (*http.Response, error) {
		res, err := replay(req)
		if !errors.Is(err, errNotFound) {
			// Cache hit, or a failure other than "not found".
			return res, err
		}
		return record(req)
	})
}