/
aws_test_helpers.go
128 lines (107 loc) · 3.17 KB
/
aws_test_helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package aws
import (
"bytes"
"fmt"
"log"
"net/http"
"net/http/httptest"
"strings"
awsSDK "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/elb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sts"
)
// MockedAWSInput carries the settings MockedAWS uses to build an AWS value:
// the region plus one optional session per service. A nil session means
// MockedAWS leaves the corresponding service connection unset.
type MockedAWSInput struct {
// Region is stored on the resulting AWS value as its Region.
Region string
// AutoscalingSess, when non-nil, backs the Auto Scaling client.
AutoscalingSess *session.Session
// Ec2Sess, when non-nil, backs the EC2 client.
Ec2Sess *session.Session
// ElbSess, when non-nil, backs the ELB client.
ElbSess *session.Session
// StsSess, when non-nil, backs the STS client.
StsSess *session.Session
// S3Sess, when non-nil, backs the S3 client.
S3Sess *session.Session
}
// MockedAWS assembles an AWS value from input. The region is always set;
// each service connection is created only when the matching session in
// input is non-nil, and is otherwise left at its zero value.
func MockedAWS(input *MockedAWSInput) *AWS {
	mocked := &AWS{Region: awsSDK.String(input.Region)}

	if sess := input.AutoscalingSess; sess != nil {
		mocked.autoscalingConn = autoscaling.New(sess)
	}
	if sess := input.Ec2Sess; sess != nil {
		mocked.ec2Conn = ec2.New(sess)
	}
	if sess := input.ElbSess; sess != nil {
		mocked.elbConn = elb.New(sess)
	}
	if sess := input.StsSess; sess != nil {
		mocked.stsConn = sts.New(sess)
	}
	if sess := input.S3Sess; sess != nil {
		mocked.s3Conn = s3.New(sess)
	}

	return mocked
}
// GetMockedAwsSession starts a mock HTTP server that serves apiRoutes and
// returns an AWS session for region whose endpoint is that server, together
// with the function that shuts the server down. The session uses static
// dummy credentials.
func GetMockedAwsSession(apiRoutes []*MockRoute, region string) (*session.Session, func()) {
	endpoint, shutdown := mockServer(apiRoutes)

	baseCfg := &awsSDK.Config{
		Region:      awsSDK.String(region),
		Credentials: credentials.NewStaticCredentials("dummy", "dummy", ""),
	}
	// Applied on top of the base session so every request goes to the mock.
	overrides := &awsSDK.Config{
		Endpoint:         awsSDK.String(endpoint),
		S3ForcePathStyle: awsSDK.Bool(true), // This makes mocking easier
	}

	mocked := session.New(baseCfg).Copy(overrides)
	return mocked, shutdown
}
// mockServer starts an httptest server that answers requests according to
// routes. It returns the server's base URL and a function that shuts the
// server down.
//
// Each request is matched with lookupMockRoute on its URI, body and method.
// A request that matches no route, or whose body cannot be read, receives
// HTTP 400 with an AWS-style XML error body. For a matched route the
// configured headers are applied before the status code is written, because
// net/http ignores header changes made after WriteHeader. A zero
// Response.Code is sent as 200, since WriteHeader rejects codes below 100.
func mockServer(routes []*MockRoute) (string, func()) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		// Some parts of AWS SDK (S3) make RequestURI full URL (strangely)
		uri := strings.TrimPrefix(r.RequestURI, "http://"+r.Host)
		log.Printf("[DEBUG] Mocked server received request to %q", uri)

		var buf bytes.Buffer
		if _, err := buf.ReadFrom(r.Body); err != nil {
			// A partially read body could never match a route reliably,
			// so report the failure instead of matching on truncated data.
			w.WriteHeader(http.StatusBadRequest)
			log.Printf("[DEBUG] Responding HTTP 400: reading request body: %s", err)
			fmt.Fprintln(w, "<ErrorResponse><Error><Code>UnreadableRequestBody</Code></Error></ErrorResponse>")
			return
		}
		reqBody := buf.String()
		log.Printf("[DEBUG] Mocked server received body: %q", reqBody)

		mr, err := lookupMockRoute(routes, uri, reqBody, r)
		if err != nil {
			w.WriteHeader(400)
			log.Printf("[DEBUG] Responding HTTP 400: %s, known: %#v", err, routes)
			fmt.Fprintln(w, "<ErrorResponse><Error><Code>UnknownRequest</Code></Error></ErrorResponse>")
			return
		}

		resp := mr.Response

		// Headers must be in place before WriteHeader; anything set
		// afterwards never reaches the client.
		for k, v := range resp.HeaderMap {
			w.Header().Set(k, v)
		}

		code := resp.Code
		if code == 0 {
			// An unset Code would make WriteHeader panic; treat it as OK.
			code = http.StatusOK
		}
		w.WriteHeader(code)

		if resp.Body != "" {
			fmt.Fprintln(w, resp.Body)
		}
	}

	ts := httptest.NewServer(http.HandlerFunc(handler))
	log.Printf("[DEBUG] Created new mock server: %s", ts.URL)
	return ts.URL, ts.Close
}
// lookupMockRoute returns the first route in routes whose expected URI and
// request body equal uri and body exactly, and whose expected method is
// either empty (any method) or equal to req.Method. It returns an error when
// no route matches.
func lookupMockRoute(routes []*MockRoute, uri, body string, req *http.Request) (*MockRoute, error) {
	for i := range routes {
		candidate := routes[i]
		if candidate.ExpectedURI != uri || candidate.ExpectedRequestBody != body {
			continue
		}
		// An empty ExpectedMethod acts as a wildcard.
		if candidate.ExpectedMethod != "" && candidate.ExpectedMethod != req.Method {
			continue
		}
		log.Printf("[DEBUG] Mocked server matched...")
		return candidate, nil
	}
	return nil, fmt.Errorf("Mock route not found")
}
// MockRoute pairs the request the mock server should recognise with the
// response it sends back for that request.
type MockRoute struct {
// ExpectedURI must equal the request URI exactly for the route to match.
ExpectedURI string
// ExpectedMethod must equal the request method; an empty value matches
// any method.
ExpectedMethod string
// ExpectedRequestBody must equal the full request body exactly.
ExpectedRequestBody string
// Response is what the mock server replies with when the route matches.
Response MockResponse
}
// MockResponse describes the HTTP response the mock server produces for a
// matched MockRoute.
type MockResponse struct {
// Code is the HTTP status code passed to WriteHeader.
Code int
// HeaderMap holds header name/value pairs applied with Header().Set.
HeaderMap map[string]string
// Body is written followed by a newline; an empty Body writes nothing.
Body string
}