/
expectm.go
219 lines (188 loc) · 5.35 KB
/
expectm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
package expectm
import (
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"github.com/tidwall/gjson"
mtime "github.com/monstercat/golib/time"
)
// ExpectedM maps gjson paths to the values expected at those paths. It is
// consumed by CheckExpectedM and the CheckJSON* helpers in this package.
//
// This custom type is so you can't accidentally pass in the wrong
// map to a function that compares two maps.
// EG: checkJSON(&sentJSON, &sentJSON) won't compile if the function
// expects *ExpectedM as the 2nd parameter.
//
// Note: a plain map[string]interface{} VALUE is still assignable to an
// ExpectedM parameter (identical underlying types, one side unnamed), so the
// compile-time protection only holds for the *ExpectedM pointer parameters
// used by the functions in this file.
type ExpectedM map[string]interface{}
// CheckJSONBytes parses the raw JSON document js with gjson and verifies it
// against every entry of expected via CheckExpectedM, returning the first
// mismatch (or nil).
func CheckJSONBytes(js []byte, expected *ExpectedM) error {
	return CheckExpectedM(gjson.ParseBytes(js), expected)
}
// CheckJSONString is the string counterpart of CheckJSONBytes: it parses js
// with gjson and hands the result to CheckExpectedM.
func CheckJSONString(js string, expected *ExpectedM) error {
	return CheckExpectedM(gjson.Parse(js), expected)
}
func CheckJSON(obj interface{}, expected *ExpectedM) error {
bytes, err := json.Marshal(obj)
if err != nil {
return err
}
return CheckJSONBytes(bytes, expected)
}
func MustInt(val interface{}) (int, error) {
switch v := val.(type) {
case float64:
return int(v), nil
case float32:
return int(v), nil
case int32:
return int(v), nil
case int64:
return int(v), nil
case int:
return v, nil
default:
return 0, errors.New("Length must be of type int")
}
}
// CheckGJSONLength checks a number read from a JSON document against a length
// expectation. Callers use it for "field.#" keys (the length of an array, per
// gjson's # syntax) and for "field(#)" keys (a plain numeric field).
//
// expectedValue is either a number accepted by MustInt, meaning "exactly this
// many", or a string:
//
//	"12"  exactly 12 (a leading digit means there is no comparator)
//	"<N"  actual must be less than N
//	">N"  actual must be greater than N
//	"!N"  actual must not equal N
//
// Any other leading byte (for example "=N") is treated as an exact match.
//
// actualValue is the value found in the document. nil (field absent or null)
// counts as a length of 0, and every comparator is applied to it the same way.
// A non-numeric actualValue is reported as an error rather than a panic.
//
// k is only used to label the returned error.
func CheckGJSONLength(k string, expectedValue, actualValue interface{}) error {
	test, comparator, err := parseLengthExpectation(expectedValue)
	if err != nil {
		return err
	}

	// A missing/null value has no elements; comparing it as 0 keeps every
	// comparator consistent (e.g. "!5" passes, ">0" fails).
	actVal := 0
	if actualValue != nil {
		n, convErr := MustInt(actualValue)
		if convErr != nil {
			return fmt.Errorf("for test `%s`, expected a numeric length but found %v (%T)", k, actualValue, actualValue)
		}
		actVal = n
	}

	var ok bool
	switch comparator {
	case '<':
		ok = actVal < test
	case '>':
		ok = actVal > test
	case '!':
		ok = actVal != test
	default:
		ok = actVal == test
	}
	if ok {
		return nil
	}

	// Only print the comparator when one was given, so plain counts don't
	// render a NUL character in the message.
	op := ""
	if comparator != 0 {
		op = string(comparator)
	}
	return fmt.Errorf("for test `%s`, expect length: %s%d; got %d", k, op, test, actVal)
}

// parseLengthExpectation splits a CheckGJSONLength expectation into the
// integer to compare against and an optional comparator (0 when none).
func parseLengthExpectation(expectedValue interface{}) (int, rune, error) {
	str, isString := expectedValue.(string)
	if !isString {
		n, err := MustInt(expectedValue)
		return n, 0, err
	}
	if str == "" {
		return 0, 0, errors.New("length expectation must not be an empty string")
	}
	// "12" is an exact count, not comparator '1' applied to 2.
	if str[0] >= '0' && str[0] <= '9' {
		n, err := strconv.Atoi(str)
		return n, 0, err
	}
	n, err := strconv.Atoi(str[1:])
	if err != nil {
		return 0, 0, err
	}
	return n, rune(str[0]), nil
}
// CheckExpectedM checks every entry of expected against the parsed JSON
// document result. It returns the first mismatch found, or nil when all match.
//
// Each key is a gjson path. How an entry is checked depends on the key and on
// the expected value:
//
//   - keys ending in "#" (e.g. "items.#") check the length of the array that
//     was returned, using CheckGJSONLength;
//   - keys ending in "(#)" (e.g. "total(#)": ">30") check the number stored
//     at the path without the "(#)" suffix, using CheckGJSONLength;
//   - an expected value of type func(interface{}) error (such as those built
//     by CheckDate or CheckDateClose) is called with the value found at the
//     path. Its error is wrapped, so errors.Is/As still reach it;
//   - any other expected value must be reflect.DeepEqual to the found value,
//     or print identically with %v.
//
// Map iteration order is random. When several entries mismatch, the one that
// gets reported can therefore vary between runs. A nil expected pointer
// checks nothing.
func CheckExpectedM(result gjson.Result, expected *ExpectedM) error {
	if expected == nil {
		return nil
	}
	for k, expectedValue := range *expected {
		switch {
		case len(k) > 0 && k[len(k)-1] == '#':
			// Special for "field.#" when checking length of array that was returned.
			if err := CheckGJSONLength(k, expectedValue, result.Get(k).Value()); err != nil {
				return err
			}
		case len(k) > 3 && k[len(k)-3:] == "(#)":
			// Special for "field(#)" where field itself is a number value.
			key := k[:len(k)-3]
			if err := CheckGJSONLength(key, expectedValue, result.Get(key).Value()); err != nil {
				return err
			}
		default:
			actualValue := result.Get(k).Value()
			if check, ok := expectedValue.(func(json interface{}) error); ok {
				if err := check(actualValue); err != nil {
					return fmt.Errorf("Error at \"%v\" %w", k, err)
				}
				continue
			}
			if reflect.DeepEqual(actualValue, expectedValue) {
				continue
			}
			// We don't really care about Go types for checking JSON.
			// This lets us do {"numberField": 0} without having to do float64(0).
			if fmt.Sprintf("%v", actualValue) == fmt.Sprintf("%v", expectedValue) {
				continue
			}
			return fmt.Errorf("unexpected JSON value at \"%v\". \n Expected: \"%v\" \n Found: \"%v\"\n JSON: %v", k, expectedValue, actualValue, result)
		}
	}
	return nil
}
// CheckDate returns a checker for use as an ExpectedM value.
//
// The checker parses expectedStr and the JSON string it receives using the
// layout format. It converts both times to the Europe/London time zone,
// re-formats them with format, and reports an error unless the two strings
// are identical. Converting to one zone lets instants written with different
// UTC offsets compare equal.
//
// The checker returns an error instead of panicking in these cases:
//   - the value is nil or is not a string;
//   - either string does not match format;
//   - the Europe/London zone data cannot be loaded on this machine.
func CheckDate(expectedStr string, format string) func(json interface{}) error {
	// Load the zone once, when the checker is built, rather than on every call.
	// A failure is reported when the checker runs; previously it was ignored,
	// and Time.In(nil) then panicked.
	loc, locErr := time.LoadLocation("Europe/London")
	return func(val interface{}) error {
		if val == nil {
			return fmt.Errorf("expected date %s but it was nil", expectedStr)
		}
		if locErr != nil {
			return fmt.Errorf("loading Europe/London time zone: %w", locErr)
		}
		expectedDate, err := time.Parse(format, expectedStr)
		if err != nil {
			return err
		}
		foundStr, ok := val.(string)
		if !ok {
			return fmt.Errorf("expected date %s but got non-string %v (%T)", expectedStr, val, val)
		}
		foundDate, err := time.Parse(format, foundStr)
		if err != nil {
			return err
		}
		expected := expectedDate.In(loc).Format(format)
		found := foundDate.In(loc).Format(format)
		if expected != found {
			return fmt.Errorf("expected date %s but got %s", expected, found)
		}
		return nil
	}
}
// CheckDateClose returns a handler function that can be used in an ExpectedM
// object to compare a date value against target. The value must be an RFC 3339
// string, as found in JSON.
//
// The leeway duration allows the dates to be off by that much time. This is
// useful when comparing a time taken before some code runs with one recorded
// by that code. An example is the PostedDate of a blog post you are creating
// versus the one actually stored in the database; without leeway, the two
// would differ by milliseconds and not be equal.
//
// Parsing and the closeness check are delegated to mtime.ParseTimeCheckNear,
// and its error is returned unchanged. A nil value or a non-string value is
// reported as an error rather than causing a panic.
//
// Example:
//
//	test := testCase{
//		bodyShouldHave: ExpectedM{
//			"created": CheckDateClose(time.Now(), time.Second),
//		},
//	}
func CheckDateClose(target time.Time, leeway time.Duration) func(json interface{}) error {
	return func(val interface{}) error {
		if val == nil {
			return fmt.Errorf("expected date %s +/- %s", target, leeway)
		}
		valS, ok := val.(string)
		if !ok {
			return fmt.Errorf("expected date %s +/- %s but got non-string %v (%T)", target, leeway, val, val)
		}
		_, err := mtime.ParseTimeCheckNear(valS, time.RFC3339, target, leeway)
		return err
	}
}