-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
saml.go
364 lines (320 loc) · 11.3 KB
/
saml.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
package auth
import (
"bytes"
"compress/flate"
"encoding/base64"
"io/ioutil"
"time"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
"github.com/beevik/etree"
"github.com/gravitational/trace"
saml2 "github.com/russellhaering/gosaml2"
)
// UpsertSAMLConnector stores the given SAML connector by delegating to
// s.Identity.UpsertSAMLConnector. Whether an existing connector of the same
// name is overwritten is decided by the Identity backend.
func (s *AuthServer) UpsertSAMLConnector(connector services.SAMLConnector) error {
	err := s.Identity.UpsertSAMLConnector(connector)
	return err
}
// DeleteSAMLConnector removes the SAML connector called connectorName by
// delegating to s.Identity.DeleteSAMLConnector and returns its error as-is.
func (s *AuthServer) DeleteSAMLConnector(connectorName string) error {
	err := s.Identity.DeleteSAMLConnector(connectorName)
	return err
}
// CreateSAMLAuthRequest starts a SAML login for the connector named by
// req.ConnectorID. It builds an AuthnRequest document, records the
// document's ID and the identity provider redirect URL on req, persists req
// with a lifetime of defaults.SAMLAuthRequestTTL, and returns the populated
// request.
func (s *AuthServer) CreateSAMLAuthRequest(req services.SAMLAuthRequest) (*services.SAMLAuthRequest, error) {
	connector, err := s.Identity.GetSAMLConnector(req.ConnectorID, true)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	sp, err := s.getSAMLProvider(connector)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	authnDoc, err := sp.BuildAuthRequestDocument()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// The AuthnRequest ID is what the IdP echoes back as InResponseTo, and
	// validateSAMLResponse uses it to look the stored request up again.
	idAttr := authnDoc.Root().SelectAttr("ID")
	if idAttr == nil || idAttr.Value == "" {
		return nil, trace.BadParameter("missing auth request ID")
	}
	req.ID = idAttr.Value

	redirectURL, err := sp.BuildAuthURLFromDocument("", authnDoc)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	req.RedirectURL = redirectURL

	if err := s.Identity.CreateSAMLAuthRequest(req, defaults.SAMLAuthRequestTTL); err != nil {
		return nil, trace.Wrap(err)
	}
	return &req, nil
}
// getSAMLProvider returns the SAML service provider for conn. It reuses the
// provider cached in s.samlProviders under conn's name when the cached
// connector still Equals conn. Otherwise it builds a new provider with
// conn.GetServiceProvider and caches it. All cache access happens while
// holding s.lock.
func (s *AuthServer) getSAMLProvider(conn services.SAMLConnector) (*saml2.SAMLServiceProvider, error) {
	s.lock.Lock()
	defer s.lock.Unlock()

	name := conn.GetName()
	if cached, found := s.samlProviders[name]; found && cached.connector.Equals(conn) {
		return cached.provider, nil
	}

	// Drop the stale entry first so that a failed rebuild leaves nothing cached.
	delete(s.samlProviders, name)

	sp, err := conn.GetServiceProvider(s.clock)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	s.samlProviders[name] = &samlProvider{connector: conn, provider: sp}
	return sp, nil
}
// buildSAMLRoles maps the attributes in assertionInfo to Teleport role names
// using connector.MapAttributes. If no role is produced it returns an
// AccessDenied error naming the connector.
func (a *AuthServer) buildSAMLRoles(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo) ([]string, error) {
	if mapped := connector.MapAttributes(assertionInfo); len(mapped) != 0 {
		return mapped, nil
	}
	return nil, trace.AccessDenied("unable to map attributes to role for connector: %v", connector.GetName())
}
// assertionsToTraitMap flattens the attribute statements of a SAML assertion
// into a map from attribute name to that attribute's string values. The
// resulting traits can be used to populate role variables.
func assertionsToTraitMap(assertionInfo saml2.AssertionInfo) map[string][]string {
	traits := make(map[string][]string, len(assertionInfo.Values))
	for _, attr := range assertionInfo.Values {
		var values []string
		for i := range attr.Values {
			values = append(values, attr.Values[i].Value)
		}
		traits[attr.Name] = values
	}
	return traits
}
// createSAMLUser creates or updates the Teleport user described by a SAML
// assertion.
//
// The user's roles come from the connector's attribute mappings
// (buildSAMLRoles) and its traits from the raw assertion attributes
// (assertionsToTraitMap). The user expires at expiresAt and is marked as
// created by this SAML connector.
//
// If a user with the same name already exists, it is overwritten only when it
// was created by this same SAML connector. Any other existing user (a local
// user, or one created by a different connector) causes an AlreadyExists
// error and is left untouched.
func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, expiresAt time.Time) error {
	roles, err := a.buildSAMLRoles(connector, assertionInfo)
	if err != nil {
		return trace.Wrap(err)
	}
	traits := assertionsToTraitMap(assertionInfo)

	log.Debugf("[SAML] Generating dynamic identity %v/%v with roles: %v", connector.GetName(), assertionInfo.NameID, roles)

	user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{
		Kind:    services.KindUser,
		Version: services.V2,
		Metadata: services.Metadata{
			Name:      assertionInfo.NameID,
			Namespace: defaults.Namespace,
		},
		Spec: services.UserSpecV2{
			Roles:          roles,
			Traits:         traits,
			Expires:        expiresAt,
			SAMLIdentities: []services.ExternalIdentity{{ConnectorID: connector.GetName(), Username: assertionInfo.NameID}},
			CreatedBy: services.CreatedBy{
				User: services.UserRef{Name: "system"},
				// Use the server's clock, not the wall clock. validateSAMLResponse
				// computes expiresAt from the same clock, so the creation time and
				// the expiry stay consistent.
				Time: a.clock.Now().UTC(),
				Connector: &services.ConnectorRef{
					Type:     teleport.ConnectorSAML,
					ID:       connector.GetName(),
					Identity: assertionInfo.NameID,
				},
			},
		},
	})
	if err != nil {
		return trace.Wrap(err)
	}

	// Look up any existing user with this name. NotFound just means there is
	// no prior user.
	existingUser, err := a.GetUser(assertionInfo.NameID)
	if err != nil && !trace.IsNotFound(err) {
		return trace.Wrap(err)
	}

	// If an existing user was not created by this SAML connector, refuse to
	// overwrite it.
	if existingUser != nil {
		ref := existingUser.GetCreatedBy().Connector
		if ref == nil || ref.Type != teleport.ConnectorSAML || ref.ID != connector.GetName() {
			return trace.AlreadyExists("user %q already exists and is not SAML user, remove local user and try again.",
				existingUser.GetName())
		}
	}

	// Either no user exists yet, or the existing one belongs to this
	// connector: create or update it.
	if err := a.UpsertUser(user); err != nil {
		return trace.Wrap(err)
	}
	return nil
}
func parseSAMLInResponseTo(response string) (string, error) {
raw, _ := base64.StdEncoding.DecodeString(response)
doc := etree.NewDocument()
err := doc.ReadFromBytes(raw)
if err != nil {
// Attempt to inflate the response in case it happens to be compressed (as with one case at saml.oktadev.com)
buf, err := ioutil.ReadAll(flate.NewReader(bytes.NewReader(raw)))
if err != nil {
return "", trace.Wrap(err)
}
doc = etree.NewDocument()
err = doc.ReadFromBytes(buf)
if err != nil {
return "", trace.Wrap(err)
}
}
if doc.Root() == nil {
return "", trace.BadParameter("unable to parse response")
}
// teleport only supports sending party initiated flows (Teleport sends an
// AuthnRequest to the IdP and gets a SAMLResponse from the IdP). identity
// provider initiated flows (where Teleport gets an unsolicited SAMLResponse
// from the IdP) are not supported.
el := doc.Root()
responseTo := el.SelectAttr("InResponseTo")
if responseTo == nil {
log.Errorf("[SAML] Teleport does not support initiating login from an identity provider, login must be initiated from either the Teleport Web UI or CLI.")
return "", trace.BadParameter("identity provider initiated flows are not supported")
}
if responseTo.Value == "" {
return "", trace.BadParameter("InResponseTo can not be empty")
}
return responseTo.Value, nil
}
// SAMLAuthResponse is returned by the auth server after it has validated the
// callback parameters (the SAML response) sent back by a SAML identity
// provider. Each field comment below describes how validateSAMLResponse
// populates it.
type SAMLAuthResponse struct {
// Username is the authenticated Teleport username.
Username string `json:"username"`
// Identity is the validated external SAML identity: the connector ID of the
// original request and the assertion's NameID.
Identity services.ExternalIdentity `json:"identity"`
// Session is a web session generated by the auth server. It is set only
// when the originating SAMLAuthRequest had CreateWebSession set.
Session services.WebSession `json:"session,omitempty"`
// Cert is the SSH certificate generated by the certificate authority. It is
// set only when the originating request carried a public key.
Cert []byte `json:"cert,omitempty"`
// TLSCert is a PEM-encoded TLS certificate. It is set only when the
// originating request carried a public key.
TLSCert []byte `json:"tls_cert,omitempty"`
// Req is the original SAML auth request that this response answers.
Req services.SAMLAuthRequest `json:"req"`
// HostSigners is a list of signing host public keys trusted by the proxy,
// used in console login. It is filled only alongside Cert/TLSCert and
// contains just this cluster's host CA.
HostSigners []services.CertAuthority `json:"host_signers"`
}
// ValidateSAMLResponse consumes the attribute statements of a SAML identity
// provider response via validateSAMLResponse and records the outcome as a
// user-login audit event. On failure the event carries the error text; on
// success it carries the resulting username. The result and error from
// validateSAMLResponse are returned unchanged.
func (a *AuthServer) ValidateSAMLResponse(samlResponse string) (*SAMLAuthResponse, error) {
	resp, err := a.validateSAMLResponse(samlResponse)
	if err != nil {
		a.EmitAuditEvent(events.UserLoginEvent, events.EventFields{
			events.LoginMethod:        events.LoginMethodSAML,
			events.AuthAttemptSuccess: false,
			events.AuthAttemptErr:     err.Error(),
		})
		return resp, err
	}

	a.EmitAuditEvent(events.UserLoginEvent, events.EventFields{
		events.EventUser:          resp.Username,
		events.AuthAttemptSuccess: true,
		events.LoginMethod:        events.LoginMethodSAML,
	})
	return resp, nil
}
// validateSAMLResponse verifies a SAML response against the stored auth
// request it answers. On success it provisions the user and issues the
// requested credentials.
//
// Steps, in order:
//   - look up the stored SAMLAuthRequest by the response's InResponseTo ID;
//   - load that request's connector and service provider, and have the
//     provider retrieve the assertion. The response is rejected if retrieval
//     fails or the provider reports an invalid time or an audience mismatch;
//   - require the connector to define attribute-to-role mappings, then create
//     or update the SAML user (createSAMLUser);
//   - derive the session TTL from the user's roles and, depending on the
//     request, create a web session and/or sign the supplied public key.
//
// Assertion verification failures are logged with detail and returned as a
// generic AccessDenied("bad SAML response").
func (a *AuthServer) validateSAMLResponse(samlResponse string) (*SAMLAuthResponse, error) {
	requestID, err := parseSAMLInResponseTo(samlResponse)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	authReq, err := a.Identity.GetSAMLAuthRequest(requestID)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	conn, err := a.Identity.GetSAMLConnector(authReq.ConnectorID, true)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	sp, err := a.getSAMLProvider(conn)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	info, err := sp.RetrieveAssertionInfo(samlResponse)
	if err != nil {
		log.Warningf("SAML error: %v", err)
		return nil, trace.AccessDenied("bad SAML response")
	}
	switch {
	case info.WarningInfo.InvalidTime:
		log.Warningf("SAML error, invalid time")
		return nil, trace.AccessDenied("bad SAML response")
	case info.WarningInfo.NotInAudience:
		log.Warningf("SAML error, not in audience")
		return nil, trace.AccessDenied("bad SAML response")
	}

	log.Debugf("[SAML] Obtained Assertions for %q", info.NameID)
	for name, attr := range info.Values {
		var attrValues []string
		for _, v := range attr.Values {
			attrValues = append(attrValues, v.Value)
		}
		log.Debugf("[SAML] Assertion: %q: %q", name, attrValues)
	}
	log.Debugf("[SAML] Assertion Warnings: %+v", info.WarningInfo)

	mappings := conn.GetAttributesToRoles()
	log.Debugf("[SAML] Applying %v claims to roles mappings", len(mappings))
	if len(mappings) == 0 {
		return nil, trace.BadParameter("SAML does not support binding to local users")
	}

	// TODO(klizhentas) use SessionNotOnOrAfter to calculate expiration time
	expiresAt := a.clock.Now().Add(defaults.CertDuration)
	if err := a.createSAMLUser(conn, *info, expiresAt); err != nil {
		return nil, trace.Wrap(err)
	}

	ident := services.ExternalIdentity{
		ConnectorID: authReq.ConnectorID,
		Username:    info.NameID,
	}
	user, err := a.Identity.GetUserBySAMLIdentity(ident)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	result := &SAMLAuthResponse{
		Req:      *authReq,
		Identity: ident,
		Username: user.GetName(),
	}

	var roles services.RoleSet
	if roles, err = services.FetchRoles(user.GetRoles(), a.Access, user.GetTraits()); err != nil {
		return nil, trace.Wrap(err)
	}

	// Start from the identity's remaining lifetime and let the user's roles
	// adjust it. The bearer token is additionally capped at BearerTokenTTL.
	sessionTTL := roles.AdjustSessionTTL(utils.ToTTL(a.clock, expiresAt))
	bearerTokenTTL := utils.MinTTL(BearerTokenTTL, sessionTTL)

	if authReq.CreateWebSession {
		sess, err := a.NewWebSession(user.GetName())
		if err != nil {
			return nil, trace.Wrap(err)
		}
		// session will expire based on identity TTL and allowed session TTL
		sess.SetExpiryTime(a.clock.Now().UTC().Add(sessionTTL))
		// bearer token will expire based on the expected session renewal
		sess.SetBearerTokenExpiryTime(a.clock.Now().UTC().Add(bearerTokenTTL))
		if err := a.UpsertWebSession(user.GetName(), sess); err != nil {
			return nil, trace.Wrap(err)
		}
		result.Session = sess
	}

	if len(authReq.PublicKey) != 0 {
		certs, err := a.generateUserCert(certRequest{
			user:          user,
			roles:         roles,
			ttl:           utils.MinTTL(sessionTTL, authReq.CertTTL),
			publicKey:     authReq.PublicKey,
			compatibility: authReq.Compatibility,
		})
		if err != nil {
			return nil, trace.Wrap(err)
		}
		result.Cert = certs.ssh
		result.TLSCert = certs.tls

		// Return the host CA for this cluster only.
		hostCA, err := a.GetCertAuthority(services.CertAuthID{
			Type:       services.HostCA,
			DomainName: a.clusterName.GetClusterName(),
		}, false)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		result.HostSigners = append(result.HostSigners, hostCA)
	}

	return result, nil
}