From 8894dd417594fc9684f209d951a2de2160e5495a Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Thu, 7 Sep 2023 15:08:30 -0400 Subject: [PATCH] Validate SAMLIdPServiceProviders ACS endpoints Enforces that all ACS endpoints are HTTPS to prevent any XSS attacks. To allow admins to interogate any existing resources which may be impacted validation only happens on create and update but not get. All usages of SAMLIdPServiceProviders within teleport follow all internal retrievals with a call to services.ValidateAssertionConsumerServicesEndpoint in order to subvert invalid ACS endpoints. --- .../local/saml_idp_service_provider.go | 30 ++++--- .../local/saml_idp_service_provider_test.go | 81 +++++++++++++++++++ lib/services/saml_idp_service_provider.go | 17 +++- .../saml_idp_service_provider_test.go | 26 ++++++ 4 files changed, 142 insertions(+), 12 deletions(-) diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index cbe090ce7202d..c6045e0e8e2a7 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -72,12 +72,12 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co // CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { - item, err := s.svc.MakeBackendItem(sp, sp.GetName()) - if err != nil { + if err := validateSAMLIdPServiceProvider(sp); err != nil { return trace.Wrap(err) } - if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil { + item, err := s.svc.MakeBackendItem(sp, sp.GetName()) + if err != nil { return trace.Wrap(err) } @@ -87,7 +87,7 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context return trace.Wrap(err) } - _, err = backend.Create(ctx, item) + _, err := backend.Create(ctx, item) if trace.IsAlreadyExists(err) { return trace.AlreadyExists("%s %q already exists", types.KindSAMLIdPServiceProvider, sp.GetName()) } @@ -97,12 +97,12 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context // UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { - item, err := s.svc.MakeBackendItem(sp, sp.GetName()) - if err != nil { + if err := validateSAMLIdPServiceProvider(sp); err != nil { return trace.Wrap(err) } - if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil { + item, err := s.svc.MakeBackendItem(sp, sp.GetName()) + if err != nil { return trace.Wrap(err) } @@ -112,7 +112,7 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context return trace.Wrap(err) } - _, err = backend.Update(ctx, item) + _, err := backend.Update(ctx, item) if trace.IsNotFound(err) { return trace.NotFound("%s %q doesn't exist", types.KindSAMLIdPServiceProvider, sp.GetName()) } @@ -159,9 +159,9 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte return nil } -// ensureEntityDescriptorMatchesEntityID ensures that the entity ID in the entity descriptor is the same as the entity ID -// in the SAMLIdPServiceProvider object. -func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp types.SAMLIdPServiceProvider) error { +// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID +// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints. +func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error { ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) if err != nil { return trace.Wrap(err) @@ -171,5 +171,13 @@ func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") } + for _, descriptor := range ed.SPSSODescriptors { + for _, acs := range descriptor.AssertionConsumerServices { + if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil { + return trace.Wrap(err) + } + } + } + return nil } diff --git a/lib/services/local/saml_idp_service_provider_test.go b/lib/services/local/saml_idp_service_provider_test.go index 86a66cbdda2c9..4a90318e8bd77 100644 --- a/lib/services/local/saml_idp_service_provider_test.go +++ b/lib/services/local/saml_idp_service_provider_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -64,6 +65,20 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { }) require.NoError(t, err) + // Try to create an invalid service provider with an invalid acs. + sp3, err := types.NewSAMLIdPServiceProvider( + types.Metadata{ + Name: "sp3", + }, + types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: newInvalidACSEntityDescriptor("sp1"), + EntityID: "sp1", + }) + require.NoError(t, err) + err = service.CreateSAMLIdPServiceProvider(ctx, sp3) + require.Error(t, err) + require.True(t, trace.IsBadParameter(err)) + // Initially we expect no service providers. out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "") require.NoError(t, err) @@ -163,6 +178,14 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { err = service.UpdateSAMLIdPServiceProvider(ctx, sp) require.Error(t, err) + // Update a service provider with an invalid acs. + sp, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName()) + require.NoError(t, err) + sp.SetEntityDescriptor(newInvalidACSEntityDescriptor(sp1.GetEntityID())) + err = service.UpdateSAMLIdPServiceProvider(ctx, sp) + require.Error(t, err) + require.True(t, trace.IsBadParameter(err)) + // Delete a service provider. err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName()) require.NoError(t, err) @@ -186,6 +209,49 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { require.Empty(t, out) } +func TestValidateSAMLIdPServiceProvider(t *testing.T) { + descriptor := newEntityDescriptor("IAMShowcase") + + cases := []struct { + name string + spec types.SAMLIdPServiceProviderSpecV1 + errAssertion require.ErrorAssertionFunc + }{ + { + name: "valid provider", + spec: types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: descriptor, + EntityID: "IAMShowcase", + }, + errAssertion: require.NoError, + }, + { + name: "invalid entity id", + spec: types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: descriptor, + EntityID: uuid.NewString(), + }, + errAssertion: require.Error, + }, + { + name: "invalid acs", + spec: types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"), + EntityID: "IAMShowcase", + }, + errAssertion: require.Error, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec) + require.NoError(t, err) + test.errAssertion(t, validateSAMLIdPServiceProvider(sp)) + }) + } +} + func newEntityDescriptor(entityID string) string { return fmt.Sprintf(testEntityDescriptor, entityID) } @@ -200,3 +266,18 @@ const testEntityDescriptor = ` ` + +func newInvalidACSEntityDescriptor(entityID string) string { + return fmt.Sprintf(invalidEntityDescriptor, entityID) +} + +// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations. +const invalidEntityDescriptor = ` + + + urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + +` diff --git a/lib/services/saml_idp_service_provider.go b/lib/services/saml_idp_service_provider.go index 31f76648688b0..dcd750a5008f6 100644 --- a/lib/services/saml_idp_service_provider.go +++ b/lib/services/saml_idp_service_provider.go @@ -18,6 +18,7 @@ package services import ( "context" + "net/url" "github.com/gravitational/trace" @@ -25,7 +26,7 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -// SAMLIdPServiceProvider defines an interface for managing SAML IdP service providers. +// SAMLIdPServiceProviders defines an interface for managing SAML IdP service providers. type SAMLIdPServiceProviders interface { // ListSAMLIdPServiceProviders returns a paginated list of all SAML IdP service provider resources. ListSAMLIdPServiceProviders(context.Context, int, string) ([]types.SAMLIdPServiceProvider, string, error) @@ -115,3 +116,17 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string) } return &s, nil } + +// ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location +// is a valid HTTPS endpoint. +func ValidateAssertionConsumerServicesEndpoint(acs string) error { + endpoint, err := url.Parse(acs) + switch { + case err != nil: + return trace.Wrap(err) + case endpoint.Scheme != "https": + return trace.BadParameter("the assertion consumer services location must be an https endpoint") + } + + return nil +} diff --git a/lib/services/saml_idp_service_provider_test.go b/lib/services/saml_idp_service_provider_test.go index 25f83c60d95ce..51987205e1c3d 100644 --- a/lib/services/saml_idp_service_provider_test.go +++ b/lib/services/saml_idp_service_provider_test.go @@ -61,6 +61,32 @@ func TestSAMLIdPServiceProviderMarshal(t *testing.T) { require.Equal(t, expected, actual) } +func TestValidateAssertionConsumerServicesEndpoint(t *testing.T) { + cases := []struct { + location string + assertion require.ErrorAssertionFunc + }{ + { + location: "https://sptest.iamshowcase.com/acs", + assertion: require.NoError, + }, + { + location: "http://sptest.iamshowcase.com/acs", + assertion: require.Error, + }, + { + location: "javascript://sptest.iamshowcase.com/acs", + assertion: require.Error, + }, + } + + for _, test := range cases { + t.Run(test.location, func(t *testing.T) { + test.assertion(t, ValidateAssertionConsumerServicesEndpoint(test.location)) + }) + } +} + var samlIDPServiceProviderYAML = `--- kind: saml_idp_service_provider version: v1