forked from hashicorp/consul
-
Notifications
You must be signed in to change notification settings - Fork 0
/
listener.go
295 lines (253 loc) · 8.37 KB
/
listener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package proxy
import (
"context"
"crypto/tls"
"errors"
"net"
"sync"
"sync/atomic"
"time"
metrics "github.com/armon/go-metrics"
"github.com/hashicorp/go-hclog"
"github.com/hernad/consul/api"
"github.com/hernad/consul/connect"
"github.com/hernad/consul/ipaddr"
)
// Prefixes used in this file both to name each listener's logger and as the
// leading element of the metric keys that listener emits.
const (
// publicListenerPrefix is used by listeners built with NewPublicListener.
publicListenerPrefix = "inbound"
// upstreamListenerPrefix is used by listeners built for an upstream.
upstreamListenerPrefix = "upstream"
)
// Listener is the implementation of a specific proxy listener. It has pluggable
// Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
// the lifecycle of the listener and all connections opened through it.
type Listener struct {
// Service is the connect service instance to use.
Service *connect.Service
// listenFunc, dialFunc, and bindAddr are set by type-specific constructors.
// listenFunc is called by Serve to open the listening socket; dialFunc is
// called once per accepted connection to open the destination side.
listenFunc func() (net.Listener, error)
dialFunc func() (net.Conn, error)
bindAddr string
// stopFlag is accessed atomically. It is 0 until Close is called and 1
// afterwards; Serve refuses to start once it is set.
stopFlag int32
// stopChan is closed by Close to tell every connection handler to exit.
stopChan chan struct{}
// listeningChan is closed when listener is opened successfully. It's really
// only for use in tests where we need to coordinate waiting for the Serve
// goroutine to be running before we proceed trying to connect. On my laptop
// this always works out anyway but on constrained VMs and especially docker
// containers (e.g. in CI) we often see the Dial routine win the race and get
// `connection refused`. Retry loops and sleeps are unpleasant workarounds and
// this is cheap and correct.
listeningChan chan struct{}
// listenerLock guards access to the listener field.
listenerLock sync.Mutex
// listener is the socket opened by Serve; it is nil until Serve has called
// listenFunc successfully.
listener net.Listener
// logger receives dial and connection-copy failures from handleConn.
logger hclog.Logger
// Gauge to track current open connections. Updated atomically by trackConn.
activeConns int32
// connWG counts running connection handlers so that Close can wait for all
// of them to finish.
connWG sync.WaitGroup
// metricPrefix is the first element of every metric key this listener emits
// and metricLabels are the labels attached to each of those metrics.
metricPrefix string
metricLabels []metrics.Label
}
// NewPublicListener builds the Listener for the public side of the proxy: it
// accepts mTLS connections on the configured bind address and forwards each one
// to the configured local application over TCP.
func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
	logger hclog.Logger) *Listener {
	addr := ipaddr.FormatAddressPort(cfg.BindAddress, cfg.BindPort)

	// Inbound side: a TLS listener using the service's server TLS config, which
	// is fetched each time the listener is opened.
	listen := func() (net.Listener, error) {
		return tls.Listen("tcp", addr, svc.ServerTLSConfig())
	}

	// Outbound side: a TCP dial to the local application, bounded by the
	// configured connect timeout.
	dial := func() (net.Conn, error) {
		timeout := time.Duration(cfg.LocalConnectTimeoutMs) * time.Millisecond
		return net.DialTimeout("tcp", cfg.LocalServiceAddress, timeout)
	}

	return &Listener{
		Service:       svc,
		listenFunc:    listen,
		dialFunc:      dial,
		bindAddr:      addr,
		stopChan:      make(chan struct{}),
		listeningChan: make(chan struct{}),
		logger:        logger.Named(publicListenerPrefix),
		metricPrefix:  publicListenerPrefix,
		// Only the destination (this service) is labelled for now. The source
		// service could be read from the client cert on each connection and used
		// as an extra label, but that significantly complicates the active
		// connection tracking here and its value is unclear: in aggregate, the
		// _outbound_ connection metrics across all proxies already give a full
		// picture of src->dst traffic. This may be expanded later to help debug
		// which clients are abusing a particular service instance, if that proves
		// worth the extra complication of tracking many gauges.
		metricLabels: []metrics.Label{{Name: "dst", Value: svc.Name()}},
	}
}
// NewUpstreamListener builds a Listener that accepts local TCP connections and
// proxies them to a discovered Connect service instance. The resolver used to
// find that instance is derived from the given API client.
func NewUpstreamListener(svc *connect.Service, client *api.Client,
	cfg UpstreamConfig, logger hclog.Logger) *Listener {
	resolverFromClient := UpstreamResolverFuncFromClient(client)
	return newUpstreamListenerWithResolver(svc, cfg, resolverFromClient, logger)
}
// newUpstreamListenerWithResolver builds an upstream Listener whose outbound
// dials obtain their connect.Resolver from resolverFunc. resolverFunc is
// invoked with cfg on every dial, so a resolver error fails only that
// connection attempt.
func newUpstreamListenerWithResolver(svc *connect.Service, cfg UpstreamConfig,
	resolverFunc func(UpstreamConfig) (connect.Resolver, error),
	logger hclog.Logger) *Listener {
	addr := ipaddr.FormatAddressPort(cfg.LocalBindAddress, cfg.LocalBindPort)

	// Inbound side: a plain TCP listener on the local bind address.
	listen := func() (net.Listener, error) {
		return net.Listen("tcp", addr)
	}

	// Outbound side: resolve the destination, then dial it through the service
	// with the configured connect timeout applied via the context.
	dial := func() (net.Conn, error) {
		resolver, err := resolverFunc(cfg)
		if err != nil {
			return nil, err
		}

		ctx, cancel := context.WithTimeout(context.Background(),
			cfg.ConnectTimeout())
		defer cancel()

		return svc.Dial(ctx, resolver)
	}

	return &Listener{
		Service:       svc,
		listenFunc:    listen,
		dialFunc:      dial,
		bindAddr:      addr,
		stopChan:      make(chan struct{}),
		listeningChan: make(chan struct{}),
		logger:        logger.Named(upstreamListenerPrefix),
		metricPrefix:  upstreamListenerPrefix,
		metricLabels: []metrics.Label{
			{Name: "src", Value: svc.Name()},
			// TODO(banks): namespace support
			{Name: "dst_type", Value: string(cfg.DestinationType)},
			{Name: "dst", Value: cfg.DestinationName},
		},
	}
}
// Serve opens the listening socket via listenFunc and accepts connections
// until the listener is stopped, handing each accepted connection to its own
// handleConn goroutine. It is an error to call Serve more than once for any
// given Listener instance.
//
// It returns nil when serving ended because Close was called, an error if the
// listener had already been closed before Serve started, and otherwise the
// error returned by listenFunc or Accept. The Listener is always marked closed
// by the time Serve returns.
func (l *Listener) Serve() error {
	// Ensure we mark state closed if we fail before Close is called externally.
	defer l.Close()

	if atomic.LoadInt32(&l.stopFlag) != 0 {
		return errors.New("serve called on a closed listener")
	}

	listener, err := l.listenFunc()
	if err != nil {
		return err
	}
	l.setListener(listener)

	close(l.listeningChan)

	// Close may have run between the stopFlag check above and setListener. In
	// that case it found no listener to close, so nothing would ever unblock
	// Accept below and the socket would stay open. Now that the listener is
	// published, any later Close is guaranteed to see it, so a single re-check
	// here covers the window.
	if atomic.LoadInt32(&l.stopFlag) != 0 {
		// Best effort: we are shutting down and have no use for the error.
		_ = listener.Close()
		return nil
	}

	for {
		conn, err := listener.Accept()
		if err != nil {
			// An Accept failure after Close is the expected shutdown signal, not
			// an error to report.
			if atomic.LoadInt32(&l.stopFlag) == 1 {
				return nil
			}
			return err
		}

		// Register the handler before starting it so Close waits for it.
		l.connWG.Add(1)
		go l.handleConn(conn)
	}
}
// handleConn is the per-connection goroutine started by Serve. It dials the
// destination, copies bytes between the two sides in a helper goroutine, and
// periodically publishes tx/rx byte counters until either the copy finishes or
// the listener is stopped. The caller must have already added this handler to
// connWG.
func (l *Listener) handleConn(src net.Conn) {
	// Runs last: release the inbound conn and let Listener.Close stop waiting
	// for this handler.
	defer func() {
		src.Close()
		l.connWG.Done()
	}()

	dst, err := l.dialFunc()
	if err != nil {
		l.logger.Error("failed to dial", "error", err)
		return
	}

	// Count the conn as active now (the outer call) and un-count it on exit (the
	// deferred inner call).
	defer l.trackConn()()

	// Note no need to defer dst.Close() since conn handles that for us.
	conn := NewConn(src, dst)
	defer conn.Close()

	// The copy runs in its own goroutine and signals completion by closing
	// copyDone. It keeps its error local rather than writing to this function's
	// err variable.
	copyDone := make(chan struct{})
	go func() {
		if copyErr := conn.CopyBytes(); copyErr != nil {
			l.logger.Error("connection failed", "error", copyErr)
		}
		close(copyDone)
	}()

	// Stats are copied from conn to metrics on a timer to keep metrics calls out
	// of the path of every single packet copy. 5 seconds is probably good enough
	// resolution - statsd and most others tend to summarize with lower
	// resolution anyway and this amortizes the cost more.
	var lastTx, lastRx uint64
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	// emit publishes one byte counter, skipping it when nothing changed.
	emit := func(name string, delta uint64) {
		if delta > 0 {
			metrics.IncrCounterWithLabels([]string{l.metricPrefix, name},
				float32(delta), l.metricLabels)
		}
	}
	// flushStats publishes the bytes moved since the previous flush.
	flushStats := func() {
		curTx, curRx := conn.Stats()
		emit("tx_bytes", curTx-lastTx)
		emit("rx_bytes", curRx-lastRx)
		lastTx, lastRx = curTx, curRx
	}

	// Always report final stats for the conn. Being the last defer registered,
	// this runs before conn.Close.
	defer flushStats()

	// Block until the copy ends or the listener is closed, flushing stats on
	// every tick in the meantime.
	for {
		select {
		case <-copyDone:
			return
		case <-l.stopChan:
			return
		case <-ticker.C:
			flushStats()
		}
	}
}
// trackConn bumps the active-connection count, publishes the new value as the
// "conns" gauge, and hands back a func() that undoes the bump (and republishes
// the gauge). The returned func is meant to be deferred until the connection
// closes.
func (l *Listener) trackConn() func() {
	// adjust applies delta atomically and publishes the resulting count.
	adjust := func(delta int32) {
		current := atomic.AddInt32(&l.activeConns, delta)
		metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"},
			float32(current), l.metricLabels)
	}

	adjust(1)
	return func() {
		adjust(-1)
	}
}
// Close shuts the listener down: it stops accepting new connections, signals
// every active connection handler to stop, and blocks until all handlers have
// finished. Only the first call does any work; later calls return immediately.
// The returned error is always nil.
func (l *Listener) Close() error {
	// Setting the flag both prevents a later Serve from starting and makes
	// repeated Close calls a no-op.
	if atomic.SwapInt32(&l.stopFlag, 1) != 0 {
		return nil
	}

	// Close the listening socket, if Serve has opened one, so no new
	// connections are accepted.
	l.listenerLock.Lock()
	ln := l.listener
	l.listenerLock.Unlock()
	if ln != nil {
		ln.Close()
	}

	// Tell outstanding connection handlers to stop.
	close(l.stopChan)

	// Wait for every handler to finish cleaning up.
	l.connWG.Wait()

	return nil
}
// Wait blocks until the listener is ready to accept connections, i.e. until
// Serve has opened the listener successfully and closed listeningChan. Within
// this file nothing else closes that channel, so Wait does not return if Serve
// is never called or fails to open the listener.
func (l *Listener) Wait() {
<-l.listeningChan
}
// BindAddr returns the address the listener is bound to, as formatted by the
// constructor that created it. The value is fixed at construction time.
func (l *Listener) BindAddr() string {
return l.bindAddr
}
// setListener stores the active net.Listener while holding listenerLock so
// that a concurrent getListener observes it safely.
func (l *Listener) setListener(listener net.Listener) {
	l.listenerLock.Lock()
	defer l.listenerLock.Unlock()

	l.listener = listener
}
// getListener returns the net.Listener stored by setListener, read under
// listenerLock. The result is nil if no listener has been stored.
func (l *Listener) getListener() net.Listener {
	l.listenerLock.Lock()
	current := l.listener
	l.listenerLock.Unlock()

	return current
}