Skip to content
Permalink
Browse files Browse the repository at this point in the history
Merge pull request from GHSA-ch68-7cf4-35vr
* Validate audience restrictions when validating SAML auth reqs

* EntityID is usually the audience

* Add coverage for failures on audience conditions
  • Loading branch information
chiiph authored and zwass committed Feb 2, 2022
1 parent 6e706bf commit 35d5a7b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 18 deletions.
1 change: 1 addition & 0 deletions changes/sso-check-audience
@@ -0,0 +1 @@
* Check for audience restrictions when validating a SAML request.
6 changes: 5 additions & 1 deletion server/service/service_sessions.go
Expand Up @@ -156,7 +156,11 @@ func (svc *Service) CallbackSSO(ctx context.Context, auth fleet.Auth) (*fleet.SS
}

// Validate response
validator, err := sso.NewValidator(*metadata)
validator, err := sso.NewValidator(*metadata, sso.WithExpectedAudience(
appConfig.SSOSettings.EntityID,
appConfig.ServerSettings.ServerURL,
appConfig.ServerSettings.ServerURL+svc.config.Server.URLPrefix+"/api/v1/fleet/sso/callback", // ACS
))
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "create validator from metadata")
}
Expand Down
11 changes: 8 additions & 3 deletions server/sso/types.go
Expand Up @@ -205,9 +205,14 @@ type Assertion struct {
}

type Conditions struct {
XMLName xml.Name
NotBefore string `xml:",attr"`
NotOnOrAfter string `xml:",attr"`
XMLName xml.Name
NotBefore string `xml:",attr"`
NotOnOrAfter string `xml:",attr"`
AudienceRestriction AudienceRestriction `xml:"AudienceRestriction"`
}

type AudienceRestriction struct {
Audience string `xml:"Audience"`
}

type Subject struct {
Expand Down
25 changes: 22 additions & 3 deletions server/sso/validate.go
Expand Up @@ -24,9 +24,10 @@ type Validator interface {
}

type validator struct {
context *dsig.ValidationContext
clock *dsig.Clock
metadata Metadata
context *dsig.ValidationContext
clock *dsig.Clock
metadata Metadata
expectedAudiences []string
}

func Clock(clock *dsig.Clock) func(v *validator) {
Expand All @@ -35,6 +36,12 @@ func Clock(clock *dsig.Clock) func(v *validator) {
}
}

func WithExpectedAudience(audiences ...string) func(v *validator) {
return func(v *validator) {
v.expectedAudiences = audiences
}
}

// NewValidator is used to validate the response to an auth request.
// metadata is from the IDP.
func NewValidator(metadata Metadata, opts ...func(v *validator)) (Validator, error) {
Expand Down Expand Up @@ -86,6 +93,18 @@ func (v *validator) ValidateResponse(auth fleet.Auth) error {
if currentTime.Before(notBefore) {
return errors.New("response too early")
}

verifiesAudience := false
for _, audience := range v.expectedAudiences {
if info.response.Assertion.Conditions.AudienceRestriction.Audience == audience {
verifiesAudience = true
break
}
}
if !verifiesAudience {
return errors.New("wrong audience:" + info.response.Assertion.Conditions.AudienceRestriction.Audience)
}

if auth.UserID() == "" {
return errors.New("missing user id")
}
Expand Down
38 changes: 27 additions & 11 deletions server/sso/validate_test.go
Expand Up @@ -49,19 +49,35 @@ func TestValidate(t *testing.T) {
require.Nil(t, err)

clock := dsig.NewFakeClockAt(tm)
validator, err := NewValidator(testMetadata(), Clock(clock))
require.Nil(t, err)
require.NotNil(t, validator)

auth, err := DecodeAuthResponse(testResponse)
require.Nil(t, err)
testCases := []struct {
audiences []string
shouldFail bool
}{
{audiences: []string{"kolide"}, shouldFail: false},
{audiences: []string{"someotheraudience"}, shouldFail: true},
{audiences: nil, shouldFail: true},
}

signed, err := validator.ValidateSignature(auth)
require.Nil(t, err)
require.NotNil(t, signed)
for _, tt := range testCases {
validator, err := NewValidator(testMetadata(), Clock(clock), WithExpectedAudience(tt.audiences...))
require.Nil(t, err)
require.NotNil(t, validator)

err = validator.ValidateResponse(auth)
assert.Nil(t, err)
auth, err := DecodeAuthResponse(testResponse)
require.Nil(t, err)

signed, err := validator.ValidateSignature(auth)
require.Nil(t, err)
require.NotNil(t, signed)

err = validator.ValidateResponse(auth)
if tt.shouldFail {
require.Error(t, err)
} else {
require.NoError(t, err)
}
}
}

func tamperedResponse(original string) (string, error) {
Expand Down Expand Up @@ -169,7 +185,7 @@ func TestVerifyValidGoogleResponse(t *testing.T) {
tm, err := time.Parse(time.RFC3339, "2017-07-18T14:47:08.035Z")
require.Nil(t, err)
clock := dsig.NewFakeClockAt(tm)
validator, err := NewValidator(testGoogleMetadata(), Clock(clock))
validator, err := NewValidator(testGoogleMetadata(), Clock(clock), WithExpectedAudience("kolide.edilok.net"))
require.Nil(t, err)
require.NotNil(t, validator)
auth, err := DecodeAuthResponse(samlResponse)
Expand Down

0 comments on commit 35d5a7b

Please sign in to comment.