-
Notifications
You must be signed in to change notification settings - Fork 13
/
clerktest.go
116 lines (103 loc) · 2.88 KB
/
clerktest.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Package clerktest provides test helpers: a mock HTTP RoundTripper, an RS256 JWT generator, and a manually advanced Clock.
package clerktest
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"net/http"
"net/url"
"sync"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/require"
)
// RoundTripper is a stub http.RoundTripper meant to be installed as the
// Transport of an http.Client in tests.
//
// Two groups of fields control it. The response fields (Status, Out)
// decide what RoundTrip returns. The expectation fields (Method, Path,
// Query, In) are checked against the incoming request, but only when they
// hold a non-zero value.
type RoundTripper struct {
	// T is the test that assertion failures are reported to.
	T *testing.T

	// Status is the response status code; zero is treated as 200 OK.
	Status int
	// Out is returned verbatim as the response body.
	Out json.RawMessage

	// Method, when non-empty, must equal the request method.
	Method string
	// Path, when non-empty, must equal the request URL path.
	Path string
	// Query, when non-nil, must encode identically to the request query string.
	Query *url.Values
	// In, when non-nil, must be JSON-equivalent to the request body.
	In json.RawMessage
}
// RoundTrip implements http.RoundTripper.
//
// It first asserts that the request matches every expectation field that
// is set (Method, Path, Query, In). It then returns a response built from
// Status and Out. A zero Status is treated as http.StatusOK. Assertion
// failures are reported through require on rt.T.
//
// The http.RoundTripper contract is followed:
//   - the request body, if present, is always closed;
//   - the RoundTripper itself is never modified, so one value can be
//     shared by concurrent requests.
func (rt *RoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	if r.Body != nil {
		// The body must be closed even when no body assertion is configured.
		defer r.Body.Close()
	}

	if rt.Method != "" {
		require.Equal(rt.T, rt.Method, r.Method)
	}
	if rt.Path != "" {
		require.Equal(rt.T, rt.Path, r.URL.Path)
	}
	if rt.Query != nil {
		require.Equal(rt.T, rt.Query.Encode(), r.URL.Query().Encode())
	}
	if rt.In != nil {
		// A nil Body means the request has no body. Compare against empty
		// content instead of passing nil to io.ReadAll, which would panic.
		var body []byte
		if r.Body != nil {
			var err error
			body, err = io.ReadAll(r.Body)
			if err != nil {
				return nil, err
			}
		}
		require.JSONEq(rt.T, string(rt.In), string(body))
	}

	// Resolve the default status locally. Writing it back to rt.Status
	// would race when the RoundTripper is shared.
	status := rt.Status
	if status == 0 {
		status = http.StatusOK
	}
	return &http.Response{
		StatusCode: status,
		Header:     make(http.Header),
		Body:       io.NopCloser(bytes.NewReader(rt.Out)),
		Request:    r,
	}, nil
}
// GenerateJWT signs claims as a compact-serialized JWT.
//
// The signing key is a freshly generated 2048-bit RSA key, and the
// algorithm is RS256. The token's "typ" header is always "JWT". When kid
// is non-empty it is also set as the "kid" header.
//
// It returns the serialized token together with the public half of the
// signing key. Any failure aborts the test through require.
func GenerateJWT(t *testing.T, claims any, kid string) (string, crypto.PublicKey) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	opts := (&jose.SignerOptions{}).WithType("JWT")
	if kid != "" {
		opts = opts.WithHeader("kid", kid)
	}

	signingKey := jose.SigningKey{Algorithm: jose.RS256, Key: key}
	signer, err := jose.NewSigner(signingKey, opts)
	require.NoError(t, err)

	token, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
	require.NoError(t, err)

	return token, key.Public()
}
// Clock is a test clock whose time changes only when Advance is called.
//
// Now and Advance are safe for concurrent use. The zero value is a clock
// at the zero time.Time. A Clock must not be copied after first use,
// because it contains a mutex.
type Clock struct {
	mu sync.RWMutex
	// The current time of this test clock; guarded by mu.
	time time.Time
}

// NewClockAt returns a Clock initialized at the given time.
func NewClockAt(t time.Time) *Clock {
	return &Clock{time: t}
}

// Now returns the clock's current time.
//
// It takes the read lock so that it can run alongside Advance without a
// data race on the time field.
func (c *Clock) Now() time.Time {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.time
}

// Advance moves the test clock by d. A negative d moves it backwards.
func (c *Clock) Advance(d time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.time = c.time.Add(d)
}