|
1 | | -// Package assert implements common assertions used in go-httbin's unit tests. |
| 1 | +// Package assert implements a set of basic test helpers. |
| 2 | +// |
| 3 | +// Lightly adapted from this blog post: https://antonz.org/do-not-testify/ |
2 | 4 | package assert |
3 | 5 |
|
4 | 6 | import ( |
| 7 | + "bytes" |
5 | 8 | "errors" |
6 | 9 | "fmt" |
7 | | - "net/http" |
8 | 10 | "reflect" |
9 | 11 | "strings" |
10 | 12 | "testing" |
11 | | - "time" |
12 | | - |
13 | | - "github.com/mccutchen/websocket/internal/testing/must" |
14 | 13 | ) |
15 | 14 |
|
16 | | -// Equal asserts that two values are equal. |
17 | | -func Equal[T comparable](t testing.TB, got, want T, msg string, arg ...any) { |
18 | | - t.Helper() |
19 | | - if got != want { |
20 | | - if msg == "" { |
21 | | - msg = "expected values to match" |
22 | | - } |
23 | | - msg = fmt.Sprintf(msg, arg...) |
24 | | - t.Fatalf("%s:\nwant: %v\n got: %v", msg, want, got) |
| 15 | +// Equal asserts that got is equal to want. |
| 16 | +func Equal[T any](tb testing.TB, got T, want T, customMsg ...any) { |
| 17 | + tb.Helper() |
| 18 | + if areEqual(got, want) { |
| 19 | + return |
25 | 20 | } |
| 21 | + msg := formatMsg("expected values to be equal", customMsg) |
| 22 | + tb.Errorf("%s:\ngot: %#v\nwant: %#v", msg, got, want) |
26 | 23 | } |
27 | 24 |
|
28 | | -// DeepEqual asserts that two values are deeply equal. |
29 | | -func DeepEqual[T any](t testing.TB, got, want T, msg string, arg ...any) { |
30 | | - t.Helper() |
31 | | - if !reflect.DeepEqual(got, want) { |
32 | | - if msg == "" { |
33 | | - msg = "expected values to match" |
34 | | - } |
35 | | - msg = fmt.Sprintf(msg, arg...) |
36 | | - t.Fatalf("%s:\nwant: %#v\n got: %#v", msg, want, got) |
| 25 | +// True asserts that got is true. |
| 26 | +func True(tb testing.TB, got bool, customMsg ...any) { |
| 27 | + tb.Helper() |
| 28 | + if !got { |
| 29 | + tb.Error(formatMsg("expected value to be true", customMsg)) |
37 | 30 | } |
38 | 31 | } |
39 | 32 |
|
40 | | -// NilError asserts that an error is nil. |
41 | | -func NilError(t testing.TB, err error) { |
42 | | - t.Helper() |
43 | | - if err != nil { |
44 | | - t.Fatalf("expected nil error, got %q (%T)", err, err) |
| 33 | +// Error asserts that got matches want, which may be an error, an error type, |
| 34 | +// an error string, or nil. If want is a string, it is considered a match if |
| 35 | +// it is ia substring of got's string value. |
| 36 | +func Error(tb testing.TB, got error, want any) { |
| 37 | + tb.Helper() |
| 38 | + |
| 39 | + if want != nil && got == nil { |
| 40 | + tb.Errorf("errors do not match:\ngot: <nil>\nwant: %v", want) |
| 41 | + return |
45 | 42 | } |
46 | | -} |
47 | 43 |
|
48 | | -// Error asserts that an error matches an expected error or any one of a list |
49 | | -// of expected errors. |
50 | | -func Error(t testing.TB, got, expected error, alternates ...error) { |
51 | | - t.Helper() |
52 | | - matched := false |
53 | | - wantAny := append([]error{expected}, alternates...) |
54 | | - for _, want := range wantAny { |
55 | | - if errorsMatch(t, got, want) { |
56 | | - matched = true |
57 | | - break |
| 44 | + switch w := want.(type) { |
| 45 | + case nil: |
| 46 | + NilError(tb, got) |
| 47 | + case error: |
| 48 | + if !errors.Is(got, w) { |
| 49 | + tb.Errorf("errors do not match:\ngot: %T(%v)\nwant: %T(%v)", got, got, w, w) |
58 | 50 | } |
59 | | - } |
60 | | - if !matched { |
61 | | - if len(wantAny) == 1 { |
62 | | - t.Fatalf("expected error %q, got %q (%T vs %T)", expected, got, expected, got) |
63 | | - } else { |
64 | | - t.Fatalf("expected one of %v, got %q (%T)", wantAny, got, got) |
| 51 | + case string: |
| 52 | + if !strings.Contains(got.Error(), w) { |
| 53 | + tb.Errorf("error string does not match:\ngot: %q\nwant: %q", got.Error(), w) |
| 54 | + } |
| 55 | + case reflect.Type: |
| 56 | + target := reflect.New(w).Interface() |
| 57 | + if !errors.As(got, target) { |
| 58 | + tb.Errorf("error type does not match:\ngot: %T\nwant: %s", got, w) |
65 | 59 | } |
66 | | - } |
67 | | -} |
68 | | - |
69 | | -func errorsMatch(t testing.TB, got, expected error) bool { |
70 | | - t.Helper() |
71 | | - switch { |
72 | | - case got == expected: |
73 | | - return true |
74 | | - case errors.Is(got, expected): |
75 | | - return true |
76 | | - case got != nil && expected != nil: |
77 | | - return got.Error() == expected.Error() |
78 | 60 | default: |
79 | | - return false |
| 61 | + tb.Errorf("unsupported want type: %T", want) |
80 | 62 | } |
81 | 63 | } |
82 | 64 |
|
83 | | -// StatusCode asserts that a response has a specific status code. |
84 | | -func StatusCode(t testing.TB, resp *http.Response, code int) { |
85 | | - t.Helper() |
86 | | - if resp.StatusCode != code { |
87 | | - t.Fatalf("expected status code %d, got %d", code, resp.StatusCode) |
88 | | - } |
89 | | - if resp.StatusCode >= 400 { |
90 | | - // Ensure our error responses are never served as HTML, so that we do |
91 | | - // not need to worry about XSS or other attacks in error responses. |
92 | | - if ct := resp.Header.Get("Content-Type"); !isSafeContentType(ct) { |
93 | | - t.Errorf("HTTP %s error served with dangerous content type: %s", resp.Status, ct) |
94 | | - } |
| 65 | +// NilError asserts that got is nil. |
| 66 | +func NilError(tb testing.TB, got error) { |
| 67 | + tb.Helper() |
| 68 | + if got != nil { |
| 69 | + tb.Fatalf("expected nil error, got %q (%T)", got, got) |
95 | 70 | } |
96 | 71 | } |
97 | 72 |
|
98 | | -func isSafeContentType(ct string) bool { |
99 | | - return strings.HasPrefix(ct, "application/json") || strings.HasPrefix(ct, "text/plain") || strings.HasPrefix(ct, "application/octet-stream") |
| 73 | +type equaler[T any] interface { |
| 74 | + Equal(T) bool |
100 | 75 | } |
101 | 76 |
|
102 | | -// Header asserts that a header key has a specific value in a response. |
103 | | -func Header(t testing.TB, resp *http.Response, key, want string) { |
104 | | - t.Helper() |
105 | | - got := resp.Header.Get(key) |
106 | | - if want != got { |
107 | | - t.Fatalf("expected header %s=%#v, got %#v", key, want, got) |
| 77 | +func areEqual[T any](a, b T) bool { |
| 78 | + if isNil(a) && isNil(b) { |
| 79 | + return true |
108 | 80 | } |
109 | | -} |
110 | | - |
111 | | -// ContentType asserts that a response has a specific Content-Type header |
112 | | -// value. |
113 | | -func ContentType(t testing.TB, resp *http.Response, contentType string) { |
114 | | - t.Helper() |
115 | | - Header(t, resp, "Content-Type", contentType) |
116 | | -} |
117 | | - |
118 | | -// Contains asserts that needle is found in the given string. |
119 | | -func Contains(t testing.TB, s string, needle string, description string) { |
120 | | - t.Helper() |
121 | | - if !strings.Contains(s, needle) { |
122 | | - t.Fatalf("expected string %q in %s %q", needle, description, s) |
| 81 | + // special case types with an Equal method |
| 82 | + if eq, ok := any(a).(equaler[T]); ok { |
| 83 | + return eq.Equal(b) |
123 | 84 | } |
124 | | -} |
125 | | - |
126 | | -// BodyContains asserts that a response body contains a specific substring. |
127 | | -func BodyContains(t testing.TB, resp *http.Response, needle string) { |
128 | | - t.Helper() |
129 | | - body := must.ReadAll(t, resp.Body) |
130 | | - Contains(t, body, needle, "body") |
131 | | -} |
132 | | - |
133 | | -// BodyEquals asserts that a response body is equal to a specific string. |
134 | | -func BodyEquals(t testing.TB, resp *http.Response, want string) { |
135 | | - t.Helper() |
136 | | - got := must.ReadAll(t, resp.Body) |
137 | | - Equal(t, got, want, "incorrect response body") |
138 | | -} |
139 | | - |
140 | | -// BodySize asserts that a response body is a specific size. |
141 | | -func BodySize(t testing.TB, resp *http.Response, want int) { |
142 | | - t.Helper() |
143 | | - got := must.ReadAll(t, resp.Body) |
144 | | - Equal(t, len(got), want, "incorrect response body size") |
145 | | -} |
146 | | - |
147 | | -// DurationRange asserts that a duration is within a specific range. |
148 | | -func DurationRange(t testing.TB, got, minVal, maxVal time.Duration) { |
149 | | - t.Helper() |
150 | | - if got < minVal || got > maxVal { |
151 | | - t.Fatalf("expected duration between %s and %s, got %s", minVal, maxVal, got) |
| 85 | + // special case byte slices |
| 86 | + if aBytes, ok := any(a).([]byte); ok { |
| 87 | + bBytes := any(b).([]byte) |
| 88 | + return bytes.Equal(aBytes, bBytes) |
152 | 89 | } |
| 90 | + return reflect.DeepEqual(a, b) |
153 | 91 | } |
154 | 92 |
|
155 | | -type number interface { |
156 | | - ~int64 | ~float64 |
| 93 | +func isNil(v any) bool { |
| 94 | + if v == nil { |
| 95 | + return true |
| 96 | + } |
| 97 | + // A non-nil interface can still hold a nil value, so we check the |
| 98 | + // underlying value. |
| 99 | + rv := reflect.ValueOf(v) |
| 100 | + switch rv.Kind() { |
| 101 | + case reflect.Chan, |
| 102 | + reflect.Func, |
| 103 | + reflect.Interface, |
| 104 | + reflect.Map, |
| 105 | + reflect.Pointer, |
| 106 | + reflect.Slice, |
| 107 | + reflect.UnsafePointer: |
| 108 | + return rv.IsNil() |
| 109 | + default: |
| 110 | + return false |
| 111 | + } |
157 | 112 | } |
158 | 113 |
|
159 | | -// RoughlyEqual asserts that a numeric value is within a certain tolerance. |
160 | | -func RoughlyEqual[T number](t testing.TB, got, want T, epsilon T) { |
161 | | - t.Helper() |
162 | | - if got < want-epsilon || got > want+epsilon { |
163 | | - t.Fatalf("expected value between %v and %v, got %v", want-epsilon, want+epsilon, got) |
| 114 | +func formatMsg(defaultMsg string, customMsg []any) string { |
| 115 | + msg := defaultMsg |
| 116 | + if len(customMsg) > 0 { |
| 117 | + tmpl, ok := customMsg[0].(string) |
| 118 | + if !ok { |
| 119 | + tmpl = fmt.Sprintf("%v", customMsg[0]) |
| 120 | + } |
| 121 | + msg = fmt.Sprintf(tmpl, customMsg[1:]...) |
164 | 122 | } |
| 123 | + return msg |
165 | 124 | } |
0 commit comments