-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
middleware.go
362 lines (337 loc) · 12.2 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package auth
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/net/http2"
)
// TLSServerConfig is the configuration for the TLS auth server built by
// NewTLSServer.
type TLSServerConfig struct {
	// TLS is the base TLS configuration. CheckAndSetDefaults requires
	// ClientCAs, RootCAs and Certificates to be set and forces ClientAuth to
	// tls.VerifyClientCertIfGiven. NewTLSServer additionally sets NextProtos
	// and GetConfigForClient on this same *tls.Config, so the caller's value
	// is mutated.
	TLS *tls.Config
	// APIConfig is the API server configuration passed to NewGRPCServer.
	APIConfig
	// LimiterConfig configures the request/connection limiter placed in
	// front of the auth middleware.
	LimiterConfig limiter.LimiterConfig
	// AccessPoint is a caching access point used to look up certificate
	// authorities and the local cluster name. Required.
	AccessPoint AccessCache
	// Component is used for debugging purposes: it is attached to the
	// server's log entries and defaults to teleport.ComponentAuth.
	Component string
	// AcceptedUsage restricts authentication
	// to a subset of certificates based on the metadata;
	// it is handed to AuthMiddleware.AcceptedUsage.
	AcceptedUsage []string
}
// CheckAndSetDefaults validates the configuration and fills in defaults.
//
// It returns a BadParameter error when TLS is nil, when TLS has no ClientCAs,
// RootCAs or Certificates, or when AccessPoint is nil. As a side effect it
// forces TLS.ClientAuth to tls.VerifyClientCertIfGiven (before the remaining
// TLS checks run) and defaults Component to teleport.ComponentAuth.
func (c *TLSServerConfig) CheckAndSetDefaults() error {
	if c.TLS == nil {
		return trace.BadParameter("missing parameter TLS")
	}
	// Ask for a client certificate during the handshake without requiring
	// one; a certificate that is sent must still verify.
	c.TLS.ClientAuth = tls.VerifyClientCertIfGiven

	switch {
	case c.TLS.ClientCAs == nil:
		return trace.BadParameter("missing parameter TLS.ClientCAs")
	case c.TLS.RootCAs == nil:
		return trace.BadParameter("missing parameter TLS.RootCAs")
	case len(c.TLS.Certificates) == 0:
		return trace.BadParameter("missing parameter TLS.Certificates")
	case c.AccessPoint == nil:
		return trace.BadParameter("missing parameter AccessPoint")
	}

	if c.Component == "" {
		c.Component = teleport.ComponentAuth
	}
	return nil
}
// TLSServer is the TLS auth server: an HTTP server whose requests pass
// through a limiter and AuthMiddleware before reaching the gRPC API server.
type TLSServer struct {
	// Server is the underlying HTTP server; Serve wraps its listener in TLS.
	*http.Server
	// TLSServerConfig is TLS server configuration used for auth server
	TLSServerConfig
	// Entry is the TLS server logging entry, tagged with the configured
	// Component.
	*logrus.Entry
}
// NewTLSServer validates cfg and returns a TLS server that is not yet
// listening. The request chain is: limiter -> AuthMiddleware -> gRPC API
// server. The server installs its own GetConfigForClient on cfg.TLS so the
// trusted client CA pool is rebuilt for every incoming connection.
func NewTLSServer(cfg TLSServerConfig) (*TLSServer, error) {
	if err := cfg.CheckAndSetDefaults(); err != nil {
		return nil, trace.Wrap(err)
	}

	// The outermost handler limits request frequency and the number of
	// simultaneous connections per client. The local is named rateLimiter
	// so it does not shadow the limiter package.
	rateLimiter, err := limiter.NewLimiter(cfg.LimiterConfig)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// The auth middleware resolves the caller identity from the TLS client
	// certificate, stores it in the request context, and then hands the
	// request to the API server.
	authMiddleware := &AuthMiddleware{
		AccessPoint:   cfg.AccessPoint,
		AcceptedUsage: cfg.AcceptedUsage,
	}
	authMiddleware.Wrap(NewGRPCServer(cfg.APIConfig))

	// Chain: limiter first, then authentication.
	rateLimiter.WrapHandle(authMiddleware)

	// Request (but do not require) client certificates, and speak HTTP/2
	// only, as required by gRPC.
	cfg.TLS.ClientAuth = tls.VerifyClientCertIfGiven
	cfg.TLS.NextProtos = []string{http2.NextProtoTLS}

	server := &TLSServer{
		TLSServerConfig: cfg,
		Server: &http.Server{
			Handler:           rateLimiter,
			ReadHeaderTimeout: defaults.DefaultDialTimeout,
		},
		Entry: logrus.WithFields(logrus.Fields{
			trace.Component: cfg.Component,
		}),
	}
	server.TLS.GetConfigForClient = server.GetConfigForClient
	return server, nil
}
// Serve wraps the plain TCP listener so every accepted connection performs a
// TLS handshake with t.TLS. It then serves HTTP over those connections and
// returns whatever error http.Server.Serve returns.
func (t *TLSServer) Serve(listener net.Listener) error {
	tlsListener := tls.NewListener(listener, t.TLS)
	return t.Server.Serve(tlsListener)
}
// GetConfigForClient is called by crypto/tls for every incoming connection.
// It rebuilds the client CA pool from the currently trusted local and remote
// certificate authorities, so CA changes apply to new connections without a
// restart.
//
// If the client's SNI decodes to a cluster name, only that cluster's CAs are
// trusted; otherwise all known host and user CAs are. An SNI that fails to
// decode for any reason other than NotFound is rejected with AccessDenied.
// If the pool cannot be built, the failure is logged and (nil, nil) is
// returned, which makes crypto/tls keep using the base t.TLS config.
func (t *TLSServer) GetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, error) {
	var clusterName string
	if info.ServerName != "" {
		var decodeErr error
		clusterName, decodeErr = DecodeClusterName(info.ServerName)
		if decodeErr != nil && !trace.IsNotFound(decodeErr) {
			t.Warningf("Client sent unsupported cluster name %q, what resulted in error %v.", info.ServerName, decodeErr)
			return nil, trace.AccessDenied("access is denied")
		}
	}

	// TODO(klizhentas) drop connections of the TLS cert authorities
	// that are not trusted
	pool, poolErr := ClientCertPool(t.AccessPoint, clusterName)
	if poolErr != nil {
		// Best effort: include the local cluster name in the log message
		// when it can be looked up.
		ourClusterName := ""
		if localName, nameErr := t.AccessPoint.GetClusterName(); nameErr == nil {
			ourClusterName = localName.GetClusterName()
		}
		t.Errorf("Failed to retrieve client pool. Client cluster %v, target cluster %v, error: %v.", clusterName, ourClusterName, trace.DebugReport(poolErr))
		// A nil config tells crypto/tls to fall back to the default config.
		return nil, nil
	}

	perConn := t.TLS.Clone()
	perConn.ClientCAs = pool
	for _, serverCert := range perConn.Certificates {
		t.Debugf("Server certificate %v.", TLSCertInfo(&serverCert))
	}
	return perConn, nil
}
// AuthMiddleware is authentication middleware checking every request: it
// derives the caller identity from the TLS client certificate (see GetUser)
// and stores it in the request context under ContextUser before calling
// Handler.
type AuthMiddleware struct {
	// AccessPoint is a caching access point for auth server, used to look up
	// the local cluster name and cluster config.
	AccessPoint AccessCache
	// Handler is HTTP handler called after the middleware checks requests
	Handler http.Handler
	// AcceptedUsage restricts authentication
	// to a subset of certificates based on certificate metadata,
	// for example middleware can reject certificates with mismatching usage.
	// If empty, will only accept certificates with non-limited usage,
	// if set, will accept certificates with non-limited usage,
	// and usage exactly matching the specified values
	// (compared with utils.StringSlicesEqual — presumably order-sensitive; verify).
	AcceptedUsage []string
}
// Wrap installs next as the handler that ServeHTTP invokes once a request
// has been authenticated.
func (a *AuthMiddleware) Wrap(next http.Handler) {
	a.Handler = next
}
// GetUser returns the authenticated identity of the caller of r. The identity
// is derived from the TLS connection state recorded by the HTTP server.
//
// Outcomes:
//   - no TLS connection state (request not received over TLS): AccessDenied;
//   - more than one peer certificate: AccessDenied (intermediates unsupported);
//   - no client certificate: the unprivileged Nop builtin role;
//   - certificate with a restricted usage not matching AcceptedUsage:
//     AccessDenied;
//   - certificate issued by another cluster: RemoteBuiltinRole if the identity
//     carries a system role, RemoteUser otherwise;
//   - certificate issued by the local cluster: BuiltinRole if the identity
//     carries a system role, LocalUser otherwise.
//
// Errors from looking up the local cluster name or decoding the certificate
// subject are returned wrapped.
func (a *AuthMiddleware) GetUser(r *http.Request) (IdentityGetter, error) {
	// r.TLS is nil when the request was not received over TLS (for example
	// if this middleware is mounted behind a plain HTTP listener).
	// Dereferencing it would panic, so reject such requests explicitly.
	if r.TLS == nil {
		return nil, trace.AccessDenied("access denied: missing TLS connection state")
	}
	peers := r.TLS.PeerCertificates
	if len(peers) > 1 {
		// when turning intermediaries on, don't forget to verify
		// https://github.com/kubernetes/kubernetes/pull/34524/files#diff-2b283dde198c92424df5355f39544aa4R59
		return nil, trace.AccessDenied("access denied: intermediaries are not supported")
	}
	localClusterName, err := a.AccessPoint.GetClusterName()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	// With no client authentication in place, the middleware
	// assumes the non-privileged Nop role.
	// It is theoretically possible to use bearer token auth even
	// for connections without client certificates, but this is not an
	// active use-case, so it is not allowed, to reduce scope.
	if len(peers) == 0 {
		return BuiltinRole{
			GetClusterConfig: a.AccessPoint.GetClusterConfig,
			Role:             teleport.RoleNop,
			Username:         string(teleport.RoleNop),
			ClusterName:      localClusterName.GetClusterName(),
			Identity:         tlsca.Identity{},
		}, nil
	}
	clientCert := peers[0]
	certClusterName, err := tlsca.ClusterName(clientCert.Issuer)
	if err != nil {
		log.Warnf("Failed to parse client certificate %v.", err)
		return nil, trace.AccessDenied("access denied: invalid client certificate")
	}
	identity, err := tlsca.FromSubject(clientCert.Subject)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	// If there is any restriction on the certificate usage
	// reject the API server request. This is done so some classes
	// of certificates issued for kubernetes usage by proxy, can not be used
	// against auth server. Later on we can extend more
	// advanced cert usage, but for now this is the safest option.
	if len(identity.Usage) != 0 && !utils.StringSlicesEqual(a.AcceptedUsage, identity.Usage) {
		log.Warningf("Restricted certificate of user %q with usage %v rejected while accessing the auth endpoint with acceptable usage %v.",
			identity.Username, identity.Usage, a.AcceptedUsage)
		return nil, trace.AccessDenied("access denied: invalid client certificate")
	}
	// This block assumes an interactive user from a remote cluster,
	// based on the remote certificate authority cluster name encoded in
	// the x509 issuer. This is a safe check because:
	// 1. Trust and verification is established during TLS handshake
	// by creating a cert pool constructed of trusted certificate authorities
	// 2. Remote CAs are not allowed to have the same cluster name
	// as the local certificate authority
	if certClusterName != localClusterName.GetClusterName() {
		// A remote identity carrying a system role is mapped to
		// RemoteBuiltinRole rather than BuiltinRole: the local auth server
		// can not trust remote servers to issue certificates with system
		// roles (e.g. Admin) that would grant unrestricted access to the
		// local cluster.
		systemRole := findSystemRole(identity.Groups)
		if systemRole != nil {
			return RemoteBuiltinRole{
				Role:        *systemRole,
				Username:    identity.Username,
				ClusterName: certClusterName,
				Identity:    *identity,
			}, nil
		}
		return RemoteUser{
			ClusterName:      certClusterName,
			Username:         identity.Username,
			Principals:       identity.Principals,
			KubernetesGroups: identity.KubernetesGroups,
			RemoteRoles:      identity.Groups,
			Identity:         *identity,
		}, nil
	}
	// The code below expects a user or service from the local cluster. To
	// distinguish between interactive users and services (e.g. proxies), it
	// checks for the presence of system roles in the certificate identity.
	systemRole := findSystemRole(identity.Groups)
	// If a system role is present, assume this is a service
	// agent, e.g. Proxy, connecting to the cluster.
	if systemRole != nil {
		return BuiltinRole{
			GetClusterConfig: a.AccessPoint.GetClusterConfig,
			Role:             *systemRole,
			Username:         identity.Username,
			ClusterName:      localClusterName.GetClusterName(),
			Identity:         *identity,
		}, nil
	}
	// Otherwise assume this is a local user. There is no need to pass the
	// roles, as they are fetched from the local database.
	return LocalUser{
		Username: identity.Username,
		Identity: *identity,
	}, nil
}
// findSystemRole scans roles in order and returns a pointer to the first
// entry that passes teleport.Role.Check. Callers treat such an entry as a
// system role. It returns nil when no entry passes.
func findSystemRole(roles []string) *teleport.Role {
	for i := range roles {
		candidate := teleport.Role(roles[i])
		if candidate.Check() != nil {
			continue
		}
		return &candidate
	}
	return nil
}
// ServeHTTP authenticates the request with GetUser and stores the resulting
// identity in the request context under ContextUser. It then passes the
// request on to the wrapped Handler. On an authentication failure, it writes
// the error to w and does not call Handler.
func (a *AuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// determine authenticated user based on the request parameters
	user, err := a.GetUser(r)
	if err != nil {
		trace.WriteError(w, err)
		return
	}
	// http.Request.Context never returns nil (it falls back to
	// context.Background), so it can be used directly as the parent.
	requestWithContext := r.WithContext(context.WithValue(r.Context(), ContextUser, user))
	a.Handler.ServeHTTP(w, requestWithContext)
}
// ClientCertPool returns a pool of trusted x509 certificate authorities. The
// pool is built from the TLS key pairs of the host and user CAs known to
// client.
//
// When clusterName is empty, every host and user CA is included. Otherwise
// only the host and user CA of clusterName are included. Any CA lookup or
// PEM parsing error aborts the call.
func ClientCertPool(client AccessCache, clusterName string) (*x509.CertPool, error) {
	var authorities []services.CertAuthority
	switch clusterName {
	case "":
		// No specific cluster requested: trust all known host and user CAs.
		hostCAs, err := client.GetCertAuthorities(services.HostCA, false, services.SkipValidation())
		if err != nil {
			return nil, trace.Wrap(err)
		}
		userCAs, err := client.GetCertAuthorities(services.UserCA, false, services.SkipValidation())
		if err != nil {
			return nil, trace.Wrap(err)
		}
		// Append onto the fresh local slice so the returned CA slices are
		// never written to.
		authorities = append(append(authorities, hostCAs...), userCAs...)
	default:
		// Trust only the CAs of the requested cluster.
		hostCA, err := client.GetCertAuthority(
			services.CertAuthID{Type: services.HostCA, DomainName: clusterName},
			false, services.SkipValidation())
		if err != nil {
			return nil, trace.Wrap(err)
		}
		userCA, err := client.GetCertAuthority(
			services.CertAuthID{Type: services.UserCA, DomainName: clusterName},
			false, services.SkipValidation())
		if err != nil {
			return nil, trace.Wrap(err)
		}
		authorities = append(authorities, hostCA, userCA)
	}

	pool := x509.NewCertPool()
	for _, ca := range authorities {
		for _, pair := range ca.GetTLSKeyPairs() {
			parsed, err := tlsca.ParseCertificatePEM(pair.Cert)
			if err != nil {
				return nil, trace.Wrap(err)
			}
			log.Debugf("ClientCertPool -> %v", CertInfo(parsed))
			pool.AddCert(parsed)
		}
	}
	return pool, nil
}