-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
replay.go
217 lines (203 loc) · 5.13 KB
/
replay.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"reflect"
"sync"
"github.com/google/martian/martianlog"
)
// ForReplaying returns a Proxy configured to replay the calls recorded in the
// log file filename. The proxy answers each request with the response of a
// matching recorded call instead of forwarding it, and it is started on port
// before being returned. An error is returned if the proxy cannot be created
// or started, or if the log cannot be read or paired into calls.
func ForReplaying(filename string, port int) (*Proxy, error) {
	p, err := newProxy(filename)
	if err != nil {
		return nil, err
	}
	lg, err := readLog(filename)
	if err != nil {
		return nil, err
	}
	calls, err := constructCalls(lg)
	if err != nil {
		return nil, err
	}
	p.Initial = lg.Initial

	replayer := &replayRoundTripper{
		calls:         calls,
		ignoreHeaders: p.ignoreHeaders,
		conv:          lg.Converter,
	}
	p.mproxy.SetRoundTripper(replayer)

	// Debug logging: install a martianlog logger (with SetDecode(true)) as both
	// the request and the response modifier.
	// TODO(jba): factor out from here and ForRecording.
	debugLog := martianlog.NewLogger()
	debugLog.SetDecode(true)
	p.mproxy.SetRequestModifier(debugLog)
	p.mproxy.SetResponseModifier(debugLog)

	if err = p.start(port); err != nil {
		return nil, err
	}
	return p, nil
}
// readLog reads the replay log stored in filename and JSON-decodes it into a
// Log.
//
// It returns an error if the file cannot be read, if its contents are not
// valid JSON for a Log (the error is prefixed with filename), or if the log's
// Version differs from LogVersion, in which case the log must be re-recorded.
func readLog(filename string) (*Log, error) {
	// Named data rather than bytes so the local does not shadow the imported
	// bytes package for the rest of the function.
	data, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	var lg Log
	if err := json.Unmarshal(data, &lg); err != nil {
		return nil, fmt.Errorf("%s: %v", filename, err)
	}
	if lg.Version != LogVersion {
		return nil, fmt.Errorf(
			"httpreplay: read log version %s but current version is %s; re-record the log",
			lg.Version, LogVersion)
	}
	return &lg, nil
}
// A call is an HTTP request and its matching response, assembled by
// constructCalls from the log entries that share an ID.
type call struct {
	req *Request  // the recorded request
	res *Response // the recorded response; constructCalls rejects calls where it is still nil
}
// constructCalls pairs up the request and response entries of lg by entry ID.
// It returns one call per ID, in the order in which the IDs first appear.
//
// The first entry for an ID must carry a request; it may also carry the
// response. A later entry for the same ID supplies either the request or the
// response, never both. IDs whose first request uses the CONNECT method are
// skipped entirely, including all of their later entries. Every remaining
// call must end up with both a request and a response, otherwise an error
// naming the offending ID is returned.
func constructCalls(lg *Log) ([]*call, error) {
	ignoreIDs := map[string]bool{} // IDs of requests to ignore
	callsByID := map[string]*call{}
	var (
		calls []*call
		ids   []string // ids[i] is the entry ID of calls[i], used in error messages
	)
	for _, e := range lg.Entries {
		if ignoreIDs[e.ID] {
			continue
		}
		c, ok := callsByID[e.ID]
		switch {
		case !ok:
			if e.Request == nil {
				return nil, fmt.Errorf("first entry for ID %s does not have a request", e.ID)
			}
			if e.Request.Method == "CONNECT" {
				// Ignore CONNECT methods, along with every later entry for this ID.
				ignoreIDs[e.ID] = true
				continue
			}
			nc := &call{req: e.Request, res: e.Response}
			calls = append(calls, nc)
			ids = append(ids, e.ID)
			callsByID[e.ID] = nc
		case e.Request != nil:
			if e.Response != nil {
				return nil, errors.New("entry has both request and response")
			}
			c.req = e.Request
		case e.Response != nil:
			c.res = e.Response
		default:
			return nil, errors.New("entry has neither request nor response")
		}
	}
	for i, c := range calls {
		if c.req == nil || c.res == nil {
			// Report the ID and which half is missing; formatting c itself
			// would only print pointer addresses.
			return nil, fmt.Errorf("missing request or response for ID %s (has request: %t, has response: %t)",
				ids[i], c.req != nil, c.res != nil)
		}
	}
	return calls, nil
}
// replayRoundTripper answers HTTP requests from a fixed set of recorded calls;
// its RoundTrip method returns a recorded response and never sends the
// request anywhere.
type replayRoundTripper struct {
	mu            sync.Mutex      // guards calls during RoundTrip
	calls         []*call         // recorded calls; an entry is set to nil once it has been replayed
	ignoreHeaders map[string]bool // header names skipped when comparing requests
	conv          *Converter      // converts an incoming *http.Request into the recorded Request form for matching
}
// RoundTrip implements http.RoundTripper by replaying a recorded response.
//
// The incoming request is converted with r.conv and compared, in recording
// order, against every call that has not been replayed yet. The first match is
// consumed, so each recorded call is used at most once, and its response is
// returned. If nothing matches, an error naming the request's method and URL
// is returned. The request body, if any, is always closed, as the
// http.RoundTripper contract requires.
func (r *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Body != nil {
		defer req.Body.Close()
	}
	creq, err := r.conv.convertRequest(req)
	if err != nil {
		return nil, err
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	for i, c := range r.calls {
		if c == nil {
			continue // already replayed
		}
		if requestsMatch(creq, c.req, r.ignoreHeaders) {
			r.calls[i] = nil // nil out this call so we don't reuse it
			return toHTTPResponse(c.res, req), nil
		}
	}
	// Report only method and URL. Formatting the whole *http.Request with %+v
	// would copy its headers (including credentials such as Authorization)
	// and internal fields into the error text.
	return nil, fmt.Errorf("no matching request for %s %s", req.Method, req.URL)
}
// requestsMatch reports whether the incoming request in matches the candidate
// recorded request cand. Method, URL, media type and every body part must be
// identical; headers are then compared by headersMatch, which skips any
// header named in ignoreHeaders.
func requestsMatch(in, cand *Request, ignoreHeaders map[string]bool) bool {
	sameScalars := in.Method == cand.Method &&
		in.URL == cand.URL &&
		in.MediaType == cand.MediaType
	if !sameScalars || len(in.BodyParts) != len(cand.BodyParts) {
		return false
	}
	for i := range in.BodyParts {
		if !bytes.Equal(in.BodyParts[i], cand.BodyParts[i]) {
			return false
		}
	}
	// Headers are checked last so that, with DebugHeaders set, a header
	// mismatch is only logged for requests that agree on everything else.
	return headersMatch(in.Header, cand.Header, ignoreHeaders)
}
// DebugHeaders helps to decide whether a header should be ignored.
// When it is true and two requests agree on method, URL and body but differ
// in a header, the first mismatched header encountered is logged.
// It defaults to false.
var DebugHeaders bool
// headersMatch reports whether the header sets in and cand agree on every
// header whose name is not in ignores. A header matches only if it is present
// (non-nil) on both sides with deeply equal value slices, so value order
// matters. When DebugHeaders is set, the reason for the first mismatch found
// is logged.
func headersMatch(in, cand http.Header, ignores map[string]bool) bool {
	for name, inVals := range in {
		if ignores[name] {
			continue
		}
		candVals := cand[name]
		switch {
		case candVals == nil:
			if DebugHeaders {
				log.Printf("header %s: present in incoming request but not candidate", name)
			}
			return false
		case !reflect.DeepEqual(inVals, candVals):
			if DebugHeaders {
				log.Printf("header %s: incoming %v, candidate %v", name, inVals, candVals)
			}
			return false
		}
	}
	// Headers the candidate has but the incoming request lacks.
	for name := range cand {
		if !ignores[name] && in[name] == nil {
			if DebugHeaders {
				log.Printf("header %s: not in incoming request but present in candidate", name)
			}
			return false
		}
	}
	return true
}