generated from morningconsult/.github
/
wait.go
203 lines (175 loc) · 4.81 KB
/
wait.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
package grace
import (
"context"
"log/slog"
"net"
"net/http"
"regexp"
"time"
"github.com/hashicorp/go-cleanhttp"
"github.com/morningconsult/serrors"
"golang.org/x/sync/errgroup"
)
type (
	// Waiter is implemented by anything that can block until some resource
	// it watches is "ready", reporting failure through the returned error.
	Waiter interface {
		Wait(ctx context.Context) error
	}

	// WaiterFunc adapts an ordinary function to the Waiter interface.
	WaiterFunc func(context.Context) error
)

// Compile-time proof that WaiterFunc satisfies Waiter.
var _ Waiter = WaiterFunc(nil)

// Wait satisfies Waiter by calling fn with ctx and returning its result
// unchanged.
func (fn WaiterFunc) Wait(ctx context.Context) error {
	return fn(ctx)
}
// Wait blocks until every configured [Waiter] succeeds, or until timeout
// elapses (or ctx is done), whichever happens first.
//
// The waiters run concurrently, each in its own goroutine. Wait returns nil
// once all of them have succeeded; otherwise it returns the first error
// produced by any waiter. As soon as one waiter fails, the context shared by
// the remaining waiters is canceled.
//
// A typical use is holding off application startup until dependent services,
// such as sidecar upstreams, are reachable.
func Wait(ctx context.Context, timeout time.Duration, opts ...WaitOption) error {
	deadlineCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	cfg := waitConfig{logger: slog.Default()}
	for i := range opts {
		cfg = opts[i](cfg)
	}

	group, groupCtx := errgroup.WithContext(deadlineCtx)
	for i := range cfg.waiters {
		current := cfg.waiters[i]
		group.Go(func() error {
			return current.Wait(groupCtx)
		})
	}

	return serrors.WithStack(group.Wait())
}
type (
	// WaitOption is a configurable option for [Wait]. Each option receives the
	// configuration built so far and returns an updated copy of it.
	WaitOption func(cfg waitConfig) waitConfig

	// waitConfig accumulates the settings used by [Wait].
	waitConfig struct {
		// logger is handed to the built-in TCP and HTTP waiters.
		logger *slog.Logger
		// waiters are run concurrently by Wait.
		waiters []Waiter
	}
)
// WithWaitLogger configures the logger to use when calling [Wait].
//
// The logger is stored in the config, so TCP and HTTP waiters added by later
// options pick it up. It is also applied retroactively to any TCP or HTTP
// waiters that earlier options already added. Waiters supplied through
// [WithWaiter] or [WithWaiterFunc] are left untouched.
//
// A nil logger falls back to [slog.Default]: the built-in waiters log
// unconditionally, and logging through a nil *slog.Logger panics.
func WithWaitLogger(logger *slog.Logger) WaitOption {
	return func(cfg waitConfig) waitConfig {
		l := logger
		if l == nil {
			l = slog.Default()
		}
		cfg.logger = l
		for i, waiter := range cfg.waiters {
			switch waiter := waiter.(type) {
			case httpWaiter:
				waiter.logger = l
				cfg.waiters[i] = waiter
			case tcpWaiter:
				waiter.logger = l
				cfg.waiters[i] = waiter
			}
		}
		return cfg
	}
}
// WithWaiter registers w so that [Wait] runs it alongside every other
// configured waiter.
func WithWaiter(w Waiter) WaitOption {
	add := func(c waitConfig) waitConfig {
		c.waiters = append(c.waiters, w)
		return c
	}
	return add
}
// WithWaiterFunc registers the function w as a waiter for [Wait]. It is a
// convenience equivalent to calling [WithWaiter] with w.
func WithWaiterFunc(w WaiterFunc) WaitOption {
	return WithWaiter(w)
}
// urlRegexp is used to remove any protocol or path that might be present
// when creating a tcp waiter, reducing the input to "host:port" via the named
// groups host and port. Inputs that do not match (e.g. no port) are left
// unchanged by ReplaceAllString.
//
// The host group excludes '/', '?' and '#'. Without that restriction, a colon
// followed by digits later in a path or query (e.g. "/api/v1:2") would be
// mistaken for the port. Bracketed IPv6 literals such as "[::1]:9000" still
// match, because the greedy host backtracks to the last ":<digits>" before the
// path.
var urlRegexp = regexp.MustCompile(`^(https?://)?(?P<host>[^/?#]+):(?P<port>[0-9]+)(.*)?`)
// WithWaitForTCP adds a TCP waiter that repeatedly dials addr until a
// connection succeeds. A leading "http://" or "https://" and anything after
// the port are stripped from addr before dialing.
func WithWaitForTCP(addr string) WaitOption {
	// Normalizing the address is pure, so do it once up front.
	hostPort := urlRegexp.ReplaceAllString(addr, "$host:$port")
	return func(cfg waitConfig) waitConfig {
		w := tcpWaiter{logger: cfg.logger, addr: hostPort}
		cfg.waiters = append(cfg.waiters, w)
		return cfg
	}
}
// tcpWaiter repeatedly dials a TCP address until a connection succeeds.
type tcpWaiter struct {
	addr   string
	logger *slog.Logger
}

// Wait waits for something to be listening on the given TCP address.
//
// Each attempt dials w.addr with a 300ms per-attempt timeout. A successful
// connection is closed immediately, because only reachability matters, and
// Wait returns nil. Failed attempts are retried after a short pause. Once ctx
// is done, the error from checkContextDone is returned.
func (w tcpWaiter) Wait(ctx context.Context) error {
	// Pause between attempts. A refused dial (nothing listening yet) fails
	// immediately, and without a pause this loop would spin a CPU core until
	// the deadline.
	const retryInterval = 100 * time.Millisecond

	// The dialer is the same for every attempt, so build it once.
	dialer := net.Dialer{Timeout: 300 * time.Millisecond}

	for {
		if err := checkContextDone(ctx, w.logger, w.addr); err != nil {
			return err
		}

		conn, err := dialer.DialContext(ctx, "tcp", w.addr)
		if err == nil {
			_ = conn.Close() // the connection was only a reachability probe
			w.logger.DebugContext(ctx, "established connection to address",
				"address", w.addr,
			)
			return nil
		}

		timer := time.NewTimer(retryInterval)
		select {
		case <-ctx.Done():
			// The next checkContextDone call reports the failure.
			timer.Stop()
		case <-timer.C:
		}
	}
}
// WithWaitForHTTP adds an HTTP waiter that issues GET requests to url until
// the response status code is below 500. Any status under 500 (including 401
// or 404) means the dependency is accepting requests, even if this particular
// check is unauthorized or invalid. Each application of the option creates
// its own client.
func WithWaitForHTTP(url string) WaitOption {
	return func(cfg waitConfig) waitConfig {
		hw := httpWaiter{
			url:    url,
			logger: cfg.logger,
			client: cleanhttp.DefaultClient(),
		}
		cfg.waiters = append(cfg.waiters, hw)
		return cfg
	}
}
// httpWaiter polls a URL with GET requests until it answers with a status
// code below 500.
type httpWaiter struct {
	client *http.Client
	logger *slog.Logger
	url    string
}

// Wait waits for something to be accepting HTTP requests.
//
// Each attempt is a GET request bound to ctx, so an in-flight request is
// aborted when ctx is done rather than outliving the caller's deadline.
// Transport errors and 5xx responses are retried after a short pause, and any
// status below 500 counts as ready.
//
// Wait returns an error immediately if w.url cannot be turned into a request,
// since a malformed URL can never succeed. Otherwise, once ctx is done, it
// returns the error from checkContextDone.
func (w httpWaiter) Wait(ctx context.Context) error {
	// Pause between attempts so refused connections and 5xx responses do not
	// turn this loop into a busy spin.
	const retryInterval = 100 * time.Millisecond

	for {
		if err := checkContextDone(ctx, w.logger, w.url); err != nil {
			return err
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, w.url, nil)
		if err != nil {
			return serrors.WithStack(err)
		}

		// Do's error is deliberately not inspected: every transport failure is
		// simply retried. A non-nil response can accompany an error only when
		// CheckRedirect fails (its body is then already closed). Such a
		// response is still judged by its status code.
		res, _ := w.client.Do(req)
		if res != nil {
			_ = res.Body.Close() // only the status code matters
			if res.StatusCode < http.StatusInternalServerError {
				w.logger.DebugContext(ctx, "established connection to address",
					"address", w.url,
				)
				return nil
			}
		}

		timer := time.NewTimer(retryInterval)
		select {
		case <-ctx.Done():
			// The next checkContextDone call reports the failure.
			timer.Stop()
		case <-timer.C:
		}
	}
}
// checkContextDone returns nil while ctx is still live. Once ctx is done, it
// logs a debug message for addr and returns an error saying the connection
// to addr timed out.
//
// The message says "timed out" even when ctx ended for another reason, such
// as cancellation after a sibling waiter in [Wait] failed.
func checkContextDone(ctx context.Context, logger *slog.Logger, addr string) error {
	select {
	case <-ctx.Done():
		// Fall through to report the failure below.
	default:
		return nil
	}

	logger.DebugContext(ctx, "failed to establish connection to address",
		"address", addr,
	)
	return serrors.Errorf("timed out connecting to %q", addr)
}