Skip to content

Commit 91b0005

Browse files
authored
test: simplify test helpers (#70)
Simplify our internal `assert` package, now lightly adapted from this blog post[1]. Primary benefits: - Unified `assert.Equal()` and `assert.DeepEqual()` implementation - Better `assert.Error()` implementation, which handles error types using `errors.As()` - Plus, just deleting a bunch of unused test helper code [1]: https://antonz.org/do-not-testify/
1 parent d008872 commit 91b0005

File tree

6 files changed

+126
-212
lines changed

6 files changed

+126
-212
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
COVERAGE_PATH ?= coverage.out
33
COVERAGE_ARGS ?= -covermode=atomic -coverprofile=$(COVERAGE_PATH)
44
TEST_ARGS ?= -race -count=1 -timeout=5s
5-
CI_TEST_ARGS ?= -timeout=60s
5+
CI_TEST_ARGS ?= -timeout=120s
66
AUTOBAHN_ARGS ?= -race -count=1 -timeout=120s
77
BENCH_COUNT ?= 10
88
BENCH_ARGS ?= -bench=. -benchmem -count=$(BENCH_COUNT) -run=^$$

internal/testing/assert/assert.go

Lines changed: 86 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,165 +1,124 @@
1-
// Package assert implements common assertions used in go-httbin's unit tests.
1+
// Package assert implements a set of basic test helpers.
2+
//
3+
// Lightly adapted from this blog post: https://antonz.org/do-not-testify/
24
package assert
35

46
import (
7+
"bytes"
58
"errors"
69
"fmt"
7-
"net/http"
810
"reflect"
911
"strings"
1012
"testing"
11-
"time"
12-
13-
"github.com/mccutchen/websocket/internal/testing/must"
1413
)
1514

16-
// Equal asserts that two values are equal.
17-
func Equal[T comparable](t testing.TB, got, want T, msg string, arg ...any) {
18-
t.Helper()
19-
if got != want {
20-
if msg == "" {
21-
msg = "expected values to match"
22-
}
23-
msg = fmt.Sprintf(msg, arg...)
24-
t.Fatalf("%s:\nwant: %v\n got: %v", msg, want, got)
15+
// Equal asserts that got is equal to want.
16+
func Equal[T any](tb testing.TB, got T, want T, customMsg ...any) {
17+
tb.Helper()
18+
if areEqual(got, want) {
19+
return
2520
}
21+
msg := formatMsg("expected values to be equal", customMsg)
22+
tb.Errorf("%s:\ngot: %#v\nwant: %#v", msg, got, want)
2623
}
2724

28-
// DeepEqual asserts that two values are deeply equal.
29-
func DeepEqual[T any](t testing.TB, got, want T, msg string, arg ...any) {
30-
t.Helper()
31-
if !reflect.DeepEqual(got, want) {
32-
if msg == "" {
33-
msg = "expected values to match"
34-
}
35-
msg = fmt.Sprintf(msg, arg...)
36-
t.Fatalf("%s:\nwant: %#v\n got: %#v", msg, want, got)
25+
// True asserts that got is true.
26+
func True(tb testing.TB, got bool, customMsg ...any) {
27+
tb.Helper()
28+
if !got {
29+
tb.Error(formatMsg("expected value to be true", customMsg))
3730
}
3831
}
3932

40-
// NilError asserts that an error is nil.
41-
func NilError(t testing.TB, err error) {
42-
t.Helper()
43-
if err != nil {
44-
t.Fatalf("expected nil error, got %q (%T)", err, err)
33+
// Error asserts that got matches want, which may be an error, an error type,
34+
// an error string, or nil. If want is a string, it is considered a match if
35+
// it is ia substring of got's string value.
36+
func Error(tb testing.TB, got error, want any) {
37+
tb.Helper()
38+
39+
if want != nil && got == nil {
40+
tb.Errorf("errors do not match:\ngot: <nil>\nwant: %v", want)
41+
return
4542
}
46-
}
4743

48-
// Error asserts that an error matches an expected error or any one of a list
49-
// of expected errors.
50-
func Error(t testing.TB, got, expected error, alternates ...error) {
51-
t.Helper()
52-
matched := false
53-
wantAny := append([]error{expected}, alternates...)
54-
for _, want := range wantAny {
55-
if errorsMatch(t, got, want) {
56-
matched = true
57-
break
44+
switch w := want.(type) {
45+
case nil:
46+
NilError(tb, got)
47+
case error:
48+
if !errors.Is(got, w) {
49+
tb.Errorf("errors do not match:\ngot: %T(%v)\nwant: %T(%v)", got, got, w, w)
5850
}
59-
}
60-
if !matched {
61-
if len(wantAny) == 1 {
62-
t.Fatalf("expected error %q, got %q (%T vs %T)", expected, got, expected, got)
63-
} else {
64-
t.Fatalf("expected one of %v, got %q (%T)", wantAny, got, got)
51+
case string:
52+
if !strings.Contains(got.Error(), w) {
53+
tb.Errorf("error string does not match:\ngot: %q\nwant: %q", got.Error(), w)
54+
}
55+
case reflect.Type:
56+
target := reflect.New(w).Interface()
57+
if !errors.As(got, target) {
58+
tb.Errorf("error type does not match:\ngot: %T\nwant: %s", got, w)
6559
}
66-
}
67-
}
68-
69-
func errorsMatch(t testing.TB, got, expected error) bool {
70-
t.Helper()
71-
switch {
72-
case got == expected:
73-
return true
74-
case errors.Is(got, expected):
75-
return true
76-
case got != nil && expected != nil:
77-
return got.Error() == expected.Error()
7860
default:
79-
return false
61+
tb.Errorf("unsupported want type: %T", want)
8062
}
8163
}
8264

83-
// StatusCode asserts that a response has a specific status code.
84-
func StatusCode(t testing.TB, resp *http.Response, code int) {
85-
t.Helper()
86-
if resp.StatusCode != code {
87-
t.Fatalf("expected status code %d, got %d", code, resp.StatusCode)
88-
}
89-
if resp.StatusCode >= 400 {
90-
// Ensure our error responses are never served as HTML, so that we do
91-
// not need to worry about XSS or other attacks in error responses.
92-
if ct := resp.Header.Get("Content-Type"); !isSafeContentType(ct) {
93-
t.Errorf("HTTP %s error served with dangerous content type: %s", resp.Status, ct)
94-
}
65+
// NilError asserts that got is nil.
66+
func NilError(tb testing.TB, got error) {
67+
tb.Helper()
68+
if got != nil {
69+
tb.Fatalf("expected nil error, got %q (%T)", got, got)
9570
}
9671
}
9772

98-
func isSafeContentType(ct string) bool {
99-
return strings.HasPrefix(ct, "application/json") || strings.HasPrefix(ct, "text/plain") || strings.HasPrefix(ct, "application/octet-stream")
73+
type equaler[T any] interface {
74+
Equal(T) bool
10075
}
10176

102-
// Header asserts that a header key has a specific value in a response.
103-
func Header(t testing.TB, resp *http.Response, key, want string) {
104-
t.Helper()
105-
got := resp.Header.Get(key)
106-
if want != got {
107-
t.Fatalf("expected header %s=%#v, got %#v", key, want, got)
77+
func areEqual[T any](a, b T) bool {
78+
if isNil(a) && isNil(b) {
79+
return true
10880
}
109-
}
110-
111-
// ContentType asserts that a response has a specific Content-Type header
112-
// value.
113-
func ContentType(t testing.TB, resp *http.Response, contentType string) {
114-
t.Helper()
115-
Header(t, resp, "Content-Type", contentType)
116-
}
117-
118-
// Contains asserts that needle is found in the given string.
119-
func Contains(t testing.TB, s string, needle string, description string) {
120-
t.Helper()
121-
if !strings.Contains(s, needle) {
122-
t.Fatalf("expected string %q in %s %q", needle, description, s)
81+
// special case types with an Equal method
82+
if eq, ok := any(a).(equaler[T]); ok {
83+
return eq.Equal(b)
12384
}
124-
}
125-
126-
// BodyContains asserts that a response body contains a specific substring.
127-
func BodyContains(t testing.TB, resp *http.Response, needle string) {
128-
t.Helper()
129-
body := must.ReadAll(t, resp.Body)
130-
Contains(t, body, needle, "body")
131-
}
132-
133-
// BodyEquals asserts that a response body is equal to a specific string.
134-
func BodyEquals(t testing.TB, resp *http.Response, want string) {
135-
t.Helper()
136-
got := must.ReadAll(t, resp.Body)
137-
Equal(t, got, want, "incorrect response body")
138-
}
139-
140-
// BodySize asserts that a response body is a specific size.
141-
func BodySize(t testing.TB, resp *http.Response, want int) {
142-
t.Helper()
143-
got := must.ReadAll(t, resp.Body)
144-
Equal(t, len(got), want, "incorrect response body size")
145-
}
146-
147-
// DurationRange asserts that a duration is within a specific range.
148-
func DurationRange(t testing.TB, got, minVal, maxVal time.Duration) {
149-
t.Helper()
150-
if got < minVal || got > maxVal {
151-
t.Fatalf("expected duration between %s and %s, got %s", minVal, maxVal, got)
85+
// special case byte slices
86+
if aBytes, ok := any(a).([]byte); ok {
87+
bBytes := any(b).([]byte)
88+
return bytes.Equal(aBytes, bBytes)
15289
}
90+
return reflect.DeepEqual(a, b)
15391
}
15492

155-
type number interface {
156-
~int64 | ~float64
93+
func isNil(v any) bool {
94+
if v == nil {
95+
return true
96+
}
97+
// A non-nil interface can still hold a nil value, so we check the
98+
// underlying value.
99+
rv := reflect.ValueOf(v)
100+
switch rv.Kind() {
101+
case reflect.Chan,
102+
reflect.Func,
103+
reflect.Interface,
104+
reflect.Map,
105+
reflect.Pointer,
106+
reflect.Slice,
107+
reflect.UnsafePointer:
108+
return rv.IsNil()
109+
default:
110+
return false
111+
}
157112
}
158113

159-
// RoughlyEqual asserts that a numeric value is within a certain tolerance.
160-
func RoughlyEqual[T number](t testing.TB, got, want T, epsilon T) {
161-
t.Helper()
162-
if got < want-epsilon || got > want+epsilon {
163-
t.Fatalf("expected value between %v and %v, got %v", want-epsilon, want+epsilon, got)
114+
func formatMsg(defaultMsg string, customMsg []any) string {
115+
msg := defaultMsg
116+
if len(customMsg) > 0 {
117+
tmpl, ok := customMsg[0].(string)
118+
if !ok {
119+
tmpl = fmt.Sprintf("%v", customMsg[0])
120+
}
121+
msg = fmt.Sprintf(tmpl, customMsg[1:]...)
164122
}
123+
return msg
165124
}

internal/testing/must/must.go

Lines changed: 0 additions & 50 deletions
This file was deleted.

proto_test.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package websocket_test
22

33
import (
44
"bytes"
5-
"errors"
65
"fmt"
76
"testing"
87

@@ -25,7 +24,7 @@ func TestFrameRoundTrip(t *testing.T) {
2524
assert.NilError(t, err)
2625

2726
// ensure client and server frame match
28-
assert.DeepEqual(t, serverFrame, clientFrame, "server and client frame mismatch")
27+
assert.Equal(t, serverFrame, clientFrame, "server and client frame mismatch")
2928
}
3029

3130
func TestMaxFrameSize(t *testing.T) {
@@ -148,31 +147,31 @@ func TestExampleFramesFromRFC(t *testing.T) {
148147
t.Parallel()
149148
buf := bytes.NewReader(tc.rawBytes)
150149
got := mustReadFrame(t, buf, len(tc.rawBytes))
151-
assert.DeepEqual(t, got, tc.wantFrame, "frames do not match")
150+
assert.Equal(t, got, tc.wantFrame, "frames do not match")
152151
})
153152
}
154153
}
155154

156155
func TestIncompleteFrames(t *testing.T) {
157156
testCases := map[string]struct {
158157
rawBytes []byte
159-
wantErr error
158+
wantErr string
160159
}{
161160
"2-byte extended payload can't be read": {
162161
rawBytes: []byte{0x82, 0x7E},
163-
wantErr: errors.New("error reading 2-byte extended payload length: EOF"),
162+
wantErr: "error reading 2-byte extended payload length: EOF",
164163
},
165164
"8-byte extended payload can't be read": {
166165
rawBytes: []byte{0x82, 0x7F},
167-
wantErr: errors.New("error reading 8-byte extended payload length: EOF"),
166+
wantErr: "error reading 8-byte extended payload length: EOF",
168167
},
169168
"mask can't be read": {
170169
rawBytes: []byte{0x81, 0x85},
171-
wantErr: errors.New("error reading mask key: EOF"),
170+
wantErr: "error reading mask key: EOF",
172171
},
173172
"payload can't be read": {
174173
rawBytes: []byte{0x81, 0x05},
175-
wantErr: errors.New("error reading 5 byte payload: EOF"),
174+
wantErr: "error reading 5 byte payload: EOF",
176175
},
177176
}
178177

websocket_internal_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ func TestDefaults(t *testing.T) {
4444
assert.Equal(t, ws.mode, ServerMode, "incorrect mode value")
4545
assert.Equal(t, ws.hooks.OnCloseHandshakeStart != nil, true, "OnCloseHandshakeStart hook is nil")
4646
assert.Equal(t, ws.hooks.OnCloseHandshakeDone != nil, true, "OnCloseHandshakeDone hook is nil")
47-
assert.Equal(t, ws.hooks.OnReadError != nil, true, "OnReadError hook is nil")
48-
assert.Equal(t, ws.hooks.OnReadFrame != nil, true, "OnReadFrame hook is nil")
49-
assert.Equal(t, ws.hooks.OnReadMessage != nil, true, "OnReadMessage hook is nil")
50-
assert.Equal(t, ws.hooks.OnWriteError != nil, true, "OnWriteError hook is nil")
51-
assert.Equal(t, ws.hooks.OnWriteFrame != nil, true, "OnWriteFrame hook is nil")
52-
assert.Equal(t, ws.hooks.OnWriteMessage != nil, true, "OnWriteMessage hook is nil")
47+
assert.True(t, ws.hooks.OnReadError != nil, "OnReadError hook is nil")
48+
assert.True(t, ws.hooks.OnReadFrame != nil, "OnReadFrame hook is nil")
49+
assert.True(t, ws.hooks.OnReadMessage != nil, "OnReadMessage hook is nil")
50+
assert.True(t, ws.hooks.OnWriteError != nil, "OnWriteError hook is nil")
51+
assert.True(t, ws.hooks.OnWriteFrame != nil, "OnWriteFrame hook is nil")
52+
assert.True(t, ws.hooks.OnWriteMessage != nil, "OnWriteMessage hook is nil")
5353

5454
t.Run("CloseTimeout defaults to ReadTimeout if set", func(t *testing.T) {
5555
var (

0 commit comments

Comments
 (0)