-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
mutual_authhandler.go
316 lines (258 loc) · 9.21 KB
/
mutual_authhandler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium
package auth
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"strconv"
"github.com/sirupsen/logrus"
"github.com/spf13/pflag"
"github.com/cilium/cilium/api/v1/models"
"github.com/cilium/cilium/pkg/auth/certs"
"github.com/cilium/cilium/pkg/endpoint"
"github.com/cilium/cilium/pkg/endpointmanager"
"github.com/cilium/cilium/pkg/hive/cell"
"github.com/cilium/cilium/pkg/identity"
"github.com/cilium/cilium/pkg/logging/logfields"
"github.com/cilium/cilium/pkg/policy"
"github.com/cilium/cilium/pkg/time"
)
// endpointGetter is the subset of the endpoint manager used by the mutual auth
// handler: it lists the endpoints present on the local node so that incoming
// handshakes can be checked against locally hosted security identities.
type endpointGetter interface {
	// GetEndpoints returns the endpoints currently known on this node.
	GetEndpoints() []*endpoint.Endpoint
}
// mutualAuthParams groups the dependencies of newMutualAuthHandler. It embeds
// cell.In, so presumably it is filled in by the hive's dependency injection.
type mutualAuthParams struct {
	cell.In
	// CertificateProvider supplies certificates and the trust bundle; it is
	// required whenever the handler is enabled (see newMutualAuthHandler).
	CertificateProvider certs.CertificateProvider
	// EndpointManager is used to check which identities have local endpoints.
	EndpointManager endpointmanager.EndpointManager
}
// newMutualAuthHandler constructs the mutual authentication handler.
//
// When no listener port is configured it logs that the handler is disabled and
// returns a zero authHandlerResult. Otherwise a certificate provider is
// mandatory (its absence is fatal). The handler's onStart/onStop methods are
// registered on the lifecycle, and the handler is returned as AuthHandler.
func newMutualAuthHandler(logger logrus.FieldLogger, lc cell.Lifecycle, cfg MutualAuthConfig, params mutualAuthParams) authHandlerResult {
	disabled := cfg.MutualAuthListenerPort == 0
	if disabled {
		logger.Info("Mutual authentication handler is disabled as no port is configured")
		return authHandlerResult{}
	}

	// Certificates are required as soon as the handler is enabled.
	if params.CertificateProvider == nil {
		logger.Fatal("No certificate provider configured, but one is required. Please check if the spire flags are configured.")
	}

	handler := &mutualAuthHandler{
		cfg:             cfg,
		log:             logger,
		cert:            params.CertificateProvider,
		endpointManager: params.EndpointManager,
	}

	hook := cell.Hook{
		OnStart: handler.onStart,
		OnStop:  handler.onStop,
	}
	lc.Append(hook)

	var result authHandlerResult
	result.AuthHandler = handler
	return result
}
// MutualAuthConfig holds the configuration of the mutual authentication
// handler. The mapstructure tags name the corresponding flags registered in
// Flags.
type MutualAuthConfig struct {
	// MutualAuthListenerPort is the TCP port the local listener binds to and
	// the port dialed on remote nodes. A value of 0 disables the handler.
	MutualAuthListenerPort int `mapstructure:"mesh-auth-mutual-listener-port"`
	// MutualAuthConnectTimeout is the timeout passed to net.DialTimeout when
	// connecting to a remote node's listener (see authenticate).
	MutualAuthConnectTimeout time.Duration `mapstructure:"mesh-auth-mutual-connect-timeout"`
}
// Flags registers the command-line flags that back MutualAuthConfig, using a
// listener port of 0 (disabled) and a connect timeout of 5s as defaults.
//
// The receiver is a value, so the flag variables point into a copy of the
// config. Presumably the effective values reach the config through the
// mapstructure tags rather than through these pointers; confirm against the
// hive config machinery.
func (def MutualAuthConfig) Flags(flags *pflag.FlagSet) {
	const (
		listenerPortFlag   = "mesh-auth-mutual-listener-port"
		connectTimeoutFlag = "mesh-auth-mutual-connect-timeout"
	)

	flags.IntVar(&def.MutualAuthListenerPort, listenerPortFlag, 0,
		"Port on which the Cilium Agent will perform mutual authentication handshakes between other Agents")
	flags.DurationVar(&def.MutualAuthConnectTimeout, connectTimeoutFlag, 5*time.Second,
		"Timeout for connecting to the remote node TCP socket")
}
// mutualAuthHandler performs mutual TLS authentication between agents.
// authenticate acts as the TLS client towards a remote node. The listener
// started in onStart serves the TLS server side for incoming handshakes.
type mutualAuthHandler struct {
	// NOTE(review): embedding cell.In looks unnecessary, since this struct is
	// built by newMutualAuthHandler rather than injected. Confirm before removing.
	cell.In
	// cfg holds the listener port and connect timeout.
	cfg MutualAuthConfig
	log logrus.FieldLogger
	// cert supplies local certificates, the trust bundle and the
	// SNI <-> numeric identity mapping.
	cert certs.CertificateProvider
	// cancelSocketListen is set by onStart and invoked by onStop to cancel the
	// listener's context.
	cancelSocketListen context.CancelFunc
	// endpointManager is used to check that an identity requested via SNI has
	// an endpoint on this node.
	endpointManager endpointGetter
}
// authenticate performs mutual TLS authentication against the auth listener of
// the node at ar.remoteNodeIP.
//
// The certificate of ar.localIdentity is presented as the client certificate.
// InsecureSkipVerify is set, so the server certificate is instead verified in
// VerifyPeerCertificate against the provider's trust bundle and
// ar.remoteIdentity (via verifyPeerCertificate).
//
// On success the returned authResponse carries the earlier of the local and
// the peer leaf certificates' NotAfter. An error is returned in these cases:
//   - ar is nil;
//   - the local certificate or the trust bundle cannot be obtained;
//   - the dial, the handshake or the peer verification fails.
//
// Both the TCP dial and the TLS handshake are bounded by
// MutualAuthConnectTimeout when it is positive. Without the handshake bound, a
// peer that accepts the TCP connection but never answers would block the
// caller indefinitely.
func (m *mutualAuthHandler) authenticate(ar *authRequest) (*authResponse, error) {
	if ar == nil {
		return nil, errors.New("authRequest is nil")
	}

	clientCert, err := m.cert.GetCertificateForIdentity(ar.localIdentity)
	if err != nil {
		return nil, fmt.Errorf("failed to get certificate for local identity %s: %w", ar.localIdentity.String(), err)
	}
	if clientCert == nil {
		return nil, fmt.Errorf("no certificate returned for local identity %s", ar.localIdentity.String())
	}
	// tls.Certificate.Leaf is optional. Parse the DER leaf when the provider did
	// not populate it, instead of dereferencing a nil pointer.
	clientLeaf := clientCert.Leaf
	if clientLeaf == nil {
		if len(clientCert.Certificate) == 0 {
			return nil, fmt.Errorf("certificate for local identity %s has no leaf", ar.localIdentity.String())
		}
		if clientLeaf, err = x509.ParseCertificate(clientCert.Certificate[0]); err != nil {
			return nil, fmt.Errorf("failed to parse leaf certificate for local identity %s: %w", ar.localIdentity.String(), err)
		}
	}

	caBundle, err := m.cert.GetTrustBundle()
	if err != nil {
		return nil, fmt.Errorf("failed to get CA bundle: %w", err)
	}

	// set up TCP connection
	conn, err := net.DialTimeout("tcp",
		net.JoinHostPort(ar.remoteNodeIP, strconv.Itoa(m.cfg.MutualAuthListenerPort)),
		m.cfg.MutualAuthConnectTimeout)
	if err != nil {
		return nil, fmt.Errorf("failed to dial %s:%d: %w", ar.remoteNodeIP, m.cfg.MutualAuthListenerPort, err)
	}
	defer conn.Close()

	// Start from the local certificate's expiry. VerifyPeerCertificate lowers
	// it if the peer certificate expires earlier.
	expirationTime := clientLeaf.NotAfter
	// peerVerified is set once the peer certificate was verified and its
	// expiry has been taken into account.
	peerVerified := false

	// set up TLS socket
	//nolint:gosec // InsecureSkipVerify is not insecure as we do the verification in VerifyPeerCertificate
	tlsConn := tls.Client(conn, &tls.Config{
		ServerName: m.cert.NumericIdentityToSNI(ar.remoteIdentity),
		GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
			return clientCert, nil
		},
		MinVersion:         tls.VersionTLS13,
		InsecureSkipVerify: true, // not insecure as we do the verification in VerifyPeerCertificate
		VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
			// verifiedChains will be nil as we set InsecureSkipVerify to true
			chain := make([]*x509.Certificate, len(rawCerts))
			for i, rawCert := range rawCerts {
				cert, err := x509.ParseCertificate(rawCert)
				if err != nil {
					return fmt.Errorf("failed to parse certificate: %w", err)
				}
				chain[i] = cert
			}
			peerExpirationTime, err := m.verifyPeerCertificate(&ar.remoteIdentity, caBundle, [][]*x509.Certificate{chain})
			if err != nil {
				return err
			}
			if peerExpirationTime != nil {
				peerVerified = true
				if peerExpirationTime.Before(expirationTime) {
					expirationTime = *peerExpirationTime // send down the lowest expiration time of the two certificates
				}
			}
			return nil
		},
		ClientCAs: caBundle,
		RootCAs:   caBundle,
	})
	defer tlsConn.Close()

	// Bound the handshake with the same timeout as the dial (0 = unbounded,
	// matching net.DialTimeout's semantics).
	handshakeCtx := context.Background()
	if timeout := m.cfg.MutualAuthConnectTimeout; timeout > 0 {
		var cancel context.CancelFunc
		handshakeCtx, cancel = context.WithTimeout(handshakeCtx, timeout)
		defer cancel()
	}
	if err := tlsConn.HandshakeContext(handshakeCtx); err != nil {
		return nil, fmt.Errorf("failed to perform TLS handshake: %w", err)
	}

	// Defensive: a successful handshake should always have gone through
	// VerifyPeerCertificate, but never report success without a peer expiry.
	if !peerVerified {
		return nil, fmt.Errorf("failed to get expiration time of peer certificate")
	}

	return &authResponse{
		expirationTime: expirationTime,
	}, nil
}
// authType reports the policy authentication type served by this handler,
// which is always policy.AuthTypeSpire.
func (m *mutualAuthHandler) authType() policy.AuthType {
	return policy.AuthTypeSpire
}
// listenForConnections runs the server side of mutual authentication.
//
// It behaves as follows:
//   - Opens a TCP listener on MutualAuthListenerPort; failing to do so is fatal.
//   - Sends on ready once the listener is up.
//   - Serves every accepted connection in its own goroutine via handleConnection.
//
// Cancelling upstreamCtx closes the listening socket, which makes Accept fail
// with net.ErrClosed and ends the loop. That is the normal shutdown path and is
// logged at Info level rather than as an accept error.
func (m *mutualAuthHandler) listenForConnections(upstreamCtx context.Context, ready chan<- struct{}) {
	// set up TCP listener
	ctx, cancel := context.WithCancel(upstreamCtx)
	defer cancel()

	var lc net.ListenConfig
	l, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", m.cfg.MutualAuthListenerPort))
	if err != nil {
		m.log.WithError(err).Fatal("Failed to start mutual auth listener")
	}

	go func() { // shutdown socket goroutine
		<-ctx.Done()
		l.Close()
	}()

	m.log.WithField(logfields.Port, m.cfg.MutualAuthListenerPort).Info("Started mutual auth listener")
	ready <- struct{}{} // signal to hive that we are ready to accept connections

	for {
		conn, err := l.Accept()
		if err != nil {
			// Check for shutdown first so a deliberate close is not reported
			// as a failure.
			if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
				m.log.Info("Mutual auth listener socket got closed")
				return
			}
			m.log.WithError(err).Error("Failed to accept connection")
			continue
		}
		go m.handleConnection(ctx, conn)
	}
}
// handleConnection serves one accepted connection by running the server side of
// the TLS 1.3 handshake, then closes the connection. During the handshake:
//   - a client certificate is required and must verify against the provider's
//     trust bundle;
//   - the server certificate is chosen by GetCertificateForIncomingConnection
//     from the SNI the client requested.
//
// Handshake failures are logged and not returned. The handshake is bounded by
// MutualAuthConnectTimeout when it is positive. Otherwise a remote peer could
// open TCP connections and stall, pinning this goroutine and its socket until
// the listener shuts down.
func (m *mutualAuthHandler) handleConnection(ctx context.Context, conn net.Conn) {
	defer conn.Close()

	caBundle, err := m.cert.GetTrustBundle()
	if err != nil {
		m.log.WithError(err).Error("failed to get CA bundle")
		return
	}

	tlsConn := tls.Server(conn, &tls.Config{
		ClientAuth:     tls.RequireAndVerifyClientCert,
		GetCertificate: m.GetCertificateForIncomingConnection,
		MinVersion:     tls.VersionTLS13,
		ClientCAs:      caBundle,
	})
	defer tlsConn.Close()

	// Per-connection bound on top of the listener's lifetime context.
	if timeout := m.cfg.MutualAuthConnectTimeout; timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		m.log.WithError(err).Error("failed to perform TLS handshake")
	}
}
// GetCertificateForIncomingConnection is installed as tls.Config.GetCertificate
// on the server side.
//
// It translates the SNI sent by the client into a numeric identity. It returns
// that identity's certificate only if at least one endpoint on this node
// currently carries the identity. An error is returned in these cases:
//   - the SNI cannot be mapped to an identity;
//   - no endpoint manager is available;
//   - no local endpoint matches the identity.
func (m *mutualAuthHandler) GetCertificateForIncomingConnection(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
	sni := info.ServerName
	m.log.WithField("SNI", sni).Debug("Got new TLS connection")

	id, err := m.cert.SNIToNumericIdentity(sni)
	if err != nil {
		return nil, fmt.Errorf("failed to get identity for SNI %s: %w", sni, err)
	}

	if m.endpointManager == nil {
		return nil, errors.New("endpoint manager is not loaded")
	}

	// Only hand out certificates for identities that are present locally.
	hasLocalEndpoint := func() bool {
		for _, ep := range m.endpointManager.GetEndpoints() {
			if ep.SecurityIdentity == nil {
				continue
			}
			if ep.SecurityIdentity.ID == id {
				return true
			}
		}
		return false
	}
	if !hasLocalEndpoint() {
		return nil, fmt.Errorf("no local endpoint present for identity %s", id.String())
	}

	return m.cert.GetCertificateForIdentity(id)
}
// onStart is the lifecycle start hook. It launches listenForConnections in a
// background goroutine and blocks until that goroutine signals that the
// listener is up. It always returns nil.
func (m *mutualAuthHandler) onStart(ctx cell.HookContext) error {
	m.log.Info("Starting mutual auth handler")

	// The listener's context is derived from context.Background rather than
	// the hook context; its cancel function is kept for onStop.
	var listenCtx context.Context
	listenCtx, m.cancelSocketListen = context.WithCancel(context.Background())

	listening := make(chan struct{})
	go m.listenForConnections(listenCtx, listening)
	<-listening // wait for the socket to be ready

	return nil
}
// onStop is the lifecycle stop hook. It cancels the listener context created by
// onStart, which closes the listening socket and ends the accept loop.
// It always returns nil. When onStart never ran, cancelSocketListen is nil; the
// call is then skipped instead of panicking.
func (m *mutualAuthHandler) onStop(ctx cell.HookContext) error {
	m.log.Info("Stopping mutual auth handler")
	if m.cancelSocketListen != nil {
		m.cancelSocketListen()
	}
	return nil
}
// verifyPeerCertificate is used for Go's TLS library to verify certificates.
//
// Each chain in certChains is verified against caBundle. Within a chain:
//   - certificates with IsCA set are used as intermediates;
//   - the last certificate without IsCA is treated as the leaf.
//
// When id is non-nil, the certificate provider must also confirm that the leaf
// belongs to that identity.
//
// It returns the NotAfter of the leaf of the last chain processed. If any chain
// fails, it returns an error for the first failing chain instead.
func (m *mutualAuthHandler) verifyPeerCertificate(id *identity.NumericIdentity, caBundle *x509.CertPool, certChains [][]*x509.Certificate) (*time.Time, error) {
	if len(certChains) == 0 {
		return nil, fmt.Errorf("no certificate chains found")
	}

	var notAfter *time.Time
	for i := range certChains {
		intermediates := x509.NewCertPool()
		var leaf *x509.Certificate
		for _, c := range certChains[i] {
			switch {
			case c.IsCA:
				intermediates.AddCert(c)
			default:
				leaf = c
			}
		}
		if leaf == nil {
			return nil, fmt.Errorf("no leaf certificate found")
		}

		verifyOpts := x509.VerifyOptions{
			Roots:         caBundle,
			Intermediates: intermediates,
		}
		if _, err := leaf.Verify(verifyOpts); err != nil {
			return nil, fmt.Errorf("failed to verify certificate: %w", err)
		}

		// id is nil when no expected identity is supplied (e.g. peer side).
		if id != nil {
			m.log.WithField("SNI ID", id.String()).Debug("Validating Server SNI")
			valid, err := m.cert.ValidateIdentity(*id, leaf)
			if err != nil {
				return nil, fmt.Errorf("failed to validate SAN: %w", err)
			}
			if !valid {
				return nil, fmt.Errorf("unable to validate SAN")
			}
		}

		notAfter = &leaf.NotAfter
		m.log.WithField("uri-san", leaf.URIs).Debug("Validated certificate")
	}

	return notAfter, nil
}
// subscribeToRotatedIdentities forwards to the certificate provider and returns
// its channel of certificate rotation events.
func (m *mutualAuthHandler) subscribeToRotatedIdentities() <-chan certs.CertificateRotationEvent {
	return m.cert.SubscribeToRotatedIdentities()
}
// certProviderStatus returns the status reported by the certificate provider.
func (m *mutualAuthHandler) certProviderStatus() *models.Status {
	return m.cert.Status()
}