forked from xmidt-org/webpa-common
-
Notifications
You must be signed in to change notification settings - Fork 0
/
transactor.go
200 lines (167 loc) · 5.99 KB
/
transactor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
package xhttptest
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"net/textproto"
"net/url"
"strings"
"github.com/stretchr/testify/mock"
)
// ExpectedResponse is a tuple of the expected return values from transactor.Do. This struct provides
// a simple unit to build table-driven tests from.
//
// It is consumed by TransactCall.RespondWith. When Err is nil, StatusCode and Body are passed to
// NewResponse and the Header entries are copied onto the resulting *http.Response. When Err is
// non-nil, the other fields are ignored and the mock returns a nil *http.Response together with Err.
type ExpectedResponse struct {
	// StatusCode is the status code given to NewResponse for the mocked response.
	StatusCode int

	// Body is the content given to NewResponse for the mocked response body.
	Body []byte

	// Header entries are assigned onto the mocked response's Header map key-by-key, as-is:
	// keys are not canonicalized and the value slices are shared rather than copied.
	Header http.Header

	// Err, when non-nil, is returned instead of a response.
	Err error
}
// TransactCall is a stretchr mock Call with some extra behavior to make mocking out HTTP client behavior easier.
//
// It embeds *mock.Call, so the embedded call's methods are promoted and remain available for
// configuring the expectation. Values of this type are returned by MockTransactor.OnDo and
// MockTransactor.OnRoundTrip.
type TransactCall struct {
	*mock.Call
}
// RespondWith configures this call's return values from an ExpectedResponse and returns the
// receiver for chaining.
//
// A non-nil Err short-circuits: the call returns a nil *http.Response along with that error.
// Otherwise a response is built by NewResponse from StatusCode and Body, every Header entry is
// assigned onto that response's Header map, and the call returns the response with a nil error.
func (dc *TransactCall) RespondWith(er ExpectedResponse) *TransactCall {
	var response *http.Response
	if er.Err != nil {
		return dc.Respond(response, er.Err)
	}

	response = NewResponse(er.StatusCode, er.Body)
	for name, values := range er.Header {
		response.Header[name] = values
	}

	return dc.Respond(response, nil)
}
// Respond is a convenience for setting Return(response, err) on the underlying mock call.
// The receiver is returned so that further configuration can be chained.
func (dc *TransactCall) Respond(response *http.Response, err error) *TransactCall {
	dc.Call.Return(response, err)
	return dc
}
// MockTransactor is a stretchr mock for the Do method of an HTTP client or round tripper.
// This mock extends the behavior of a stretchr mock in a few ways that make clientside
// HTTP behavior easier to mock.
//
// This type implements the http.RoundTripper interface, and provides a Do method that can
// implement a subset interface of http.Client.
//
// Expectations are registered on the embedded mock.Mock, either directly with On or through
// the OnDo and OnRoundTrip helpers, which combine request matchers.
type MockTransactor struct {
	mock.Mock
}
// Do is a mocked HTTP transaction call. Use On or OnRequest to setup behaviors for this method.
func (mt *MockTransactor) Do(request *http.Request) (*http.Response, error) {
// HACK: Because of the way Called works, there is a race condition involving the http.Request's Context object.
// Called performs a printf, which bypasses the context's mutex to produce the string. We have to replace
// the context with a known, immutable value so that no race conditions occur.
arguments := mt.Called(request.WithContext(context.Background()))
response, _ := arguments.Get(0).(*http.Response)
return response, arguments.Error(1)
}
// RoundTrip is a mocked HTTP transaction call. Use On or OnRoundTrip to setup behaviors for this method.
func (mt *MockTransactor) RoundTrip(request *http.Request) (*http.Response, error) {
// HACK: Because of the way Called works, there is a race condition involving the http.Request's Context object.
// Called performs a printf, which bypasses the context's mutex to produce the string. We have to replace
// the context with a known, immutable value so that no race conditions occur.
arguments := mt.Called(request.WithContext(context.Background()))
response, _ := arguments.Get(0).(*http.Response)
return response, arguments.Error(1)
}
// OnDo registers an On("Do", ...) expectation whose argument matcher accepts a request only if
// every supplied matcher accepts it. Matchers run in the order given and evaluation stops at the
// first rejection; with no matchers, every request is accepted. The returned TransactCall wraps
// the underlying mock call and adds helpers for setting responses.
func (mt *MockTransactor) OnDo(matchers ...func(*http.Request) bool) *TransactCall {
	allMatch := func(candidate *http.Request) bool {
		for _, accepts := range matchers {
			if !accepts(candidate) {
				return false
			}
		}
		return true
	}

	return &TransactCall{Call: mt.On("Do", mock.MatchedBy(allMatch))}
}
// OnRoundTrip sets an On("RoundTrip", ...) expectation that matches a request only when every
// one of the given matchers returns true for it. Matchers are evaluated in order and evaluation
// stops at the first one that returns false; with no matchers, every request matches.
//
// The request the matchers see is the copy that RoundTrip hands to the mock (the caller's request
// with its context replaced by context.Background()). The returned TransactCall has some
// augmented behavior for setting responses.
func (mt *MockTransactor) OnRoundTrip(matchers ...func(*http.Request) bool) *TransactCall {
	call := mt.On("RoundTrip", mock.MatchedBy(func(candidate *http.Request) bool {
		for _, matcher := range matchers {
			if !matcher(candidate) {
				return false
			}
		}
		return true
	}))

	return &TransactCall{call}
}
// MatchMethod returns a request matcher that accepts a request whose Method equals expected,
// ignoring case (so "get" matches "GET").
func MatchMethod(expected string) func(*http.Request) bool {
	matcher := func(candidate *http.Request) bool {
		return strings.EqualFold(candidate.Method, expected)
	}
	return matcher
}
// MatchURL returns a request matcher that verifies each request has an exact URL.
//
// Two URLs match when they are the same pointer, or when both are non-nil and all of their
// fields are equal. The User field is compared by the Userinfo value it points to rather than by
// pointer identity, so independently parsed URLs carrying the same credentials still match.
// A nil expected URL matches only a request whose URL is also nil.
func MatchURL(expected *url.URL) func(*http.Request) bool {
	return func(r *http.Request) bool {
		actual := r.URL
		if expected == actual {
			return true
		}
		if expected == nil || actual == nil {
			return false
		}

		// url.URL stores User as a *Userinfo, so comparing the structs directly would compare
		// pointer identity. Compare every other field with User cleared on local copies, then
		// compare the Userinfo values separately.
		want, got := *expected, *actual
		wantUser, gotUser := want.User, got.User
		want.User, got.User = nil, nil
		if want != got {
			return false
		}

		switch {
		case wantUser == gotUser:
			return true
		case wantUser == nil || gotUser == nil:
			return false
		default:
			return *wantUser == *gotUser
		}
	}
}
// MatchURLString returns a request matcher that verifies the request's URL translates to the given string.
func MatchURLString(expected string) func(*http.Request) bool {
return func(r *http.Request) bool {
if r.URL == nil {
return len(expected) == 0
}
return expected == r.URL.String()
}
}
// MatchBody returns a request matcher that verifies each request has exactly the expected body.
//
// A request with a nil Body matches only when expected is empty. Otherwise the body is read to
// completion and the original Body is closed, mirroring what a real transport does with a request
// body. It is then replaced with an in-memory copy so that downstream code, including other
// matchers, can still read it. A failure while reading the body panics, because a matcher has no
// way to report an error.
func MatchBody(expected []byte) func(*http.Request) bool {
	return func(r *http.Request) bool {
		if r.Body == nil {
			return len(expected) == 0
		}

		actual, err := ioutil.ReadAll(r.Body)

		// The original body has been drained and is about to be replaced, so release it.
		// A close failure cannot affect the match result, so it is deliberately ignored.
		_ = r.Body.Close()

		if err != nil {
			panic(fmt.Errorf("error while reading request body for matching: %w", err))
		}

		// replace the body so other test code can reread it
		r.Body = ioutil.NopCloser(bytes.NewReader(actual))

		return bytes.Equal(actual, expected)
	}
}
// MatchBodyString returns a request matcher that verifies each request has exactly the given
// body text. It is the string counterpart of MatchBody and delegates to it, so the request body
// is likewise consumed and then replaced for downstream readers.
func MatchBodyString(expected string) func(*http.Request) bool {
	expectedBytes := []byte(expected)
	return MatchBody(expectedBytes)
}
// MatchHeader returns a request matcher that matches against a request header.
//
// The header name is canonicalized, so lookups are case-insensitive. The matcher accepts a
// request when any of the header's values equals expected exactly. When the request carries no
// values for the header, it matches only if expected is empty. A request whose Header map is nil,
// for example one built directly as a struct literal in test code, is treated the same as a
// request with an empty Header.
func MatchHeader(name, expected string) func(*http.Request) bool {
	key := textproto.CanonicalMIMEHeaderKey(name)

	return func(r *http.Request) bool {
		// Indexing a nil map is safe in Go and yields nil, so no separate nil-Header check is
		// needed; this keeps nil and empty headers consistent.
		values := r.Header[key]
		if len(values) == 0 {
			return len(expected) == 0
		}

		for _, actual := range values {
			if actual == expected {
				return true
			}
		}

		return false
	}
}