Skip to content

Commit

Permalink
Validate SAMLIdPServiceProviders ACS endpoints
Browse files Browse the repository at this point in the history
Enforces that all ACS endpoints are HTTPS to prevent any
XSS attacks. To allow admins to interogate any existing resources
which may be impacted validation only happens on create and update
but not get. All usages of SAMLIdPServiceProviders within teleport
follow all internal retrievals with a call to
services.ValidateAssertionConsumerServicesEndpoint in order to
subvert invalid ACS endpoints.
  • Loading branch information
rosstimothy committed Sep 20, 2023
1 parent fd2171c commit 8894dd4
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 12 deletions.
30 changes: 19 additions & 11 deletions lib/services/local/saml_idp_service_provider.go
Expand Up @@ -72,12 +72,12 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co

// CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
}

if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
return trace.Wrap(err)
}

Expand All @@ -87,7 +87,7 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
return trace.Wrap(err)
}

_, err = backend.Create(ctx, item)
_, err := backend.Create(ctx, item)
if trace.IsAlreadyExists(err) {
return trace.AlreadyExists("%s %q already exists", types.KindSAMLIdPServiceProvider, sp.GetName())
}
Expand All @@ -97,12 +97,12 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context

// UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
}

if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
return trace.Wrap(err)
}

Expand All @@ -112,7 +112,7 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context
return trace.Wrap(err)
}

_, err = backend.Update(ctx, item)
_, err := backend.Update(ctx, item)
if trace.IsNotFound(err) {
return trace.NotFound("%s %q doesn't exist", types.KindSAMLIdPServiceProvider, sp.GetName())
}
Expand Down Expand Up @@ -159,9 +159,9 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte
return nil
}

// ensureEntityDescriptorMatchesEntityID ensures that the entity ID in the entity descriptor is the same as the entity ID
// in the SAMLIdPServiceProvider object.
func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp types.SAMLIdPServiceProvider) error {
// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID
// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints.
func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error {
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.Wrap(err)
Expand All @@ -171,5 +171,13 @@ func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

for _, descriptor := range ed.SPSSODescriptors {
for _, acs := range descriptor.AssertionConsumerServices {
if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil {
return trace.Wrap(err)
}
}
}

return nil
}
81 changes: 81 additions & 0 deletions lib/services/local/saml_idp_service_provider_test.go
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -64,6 +65,20 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
})
require.NoError(t, err)

// Try to create an invalid service provider with an invalid acs.
sp3, err := types.NewSAMLIdPServiceProvider(
types.Metadata{
Name: "sp3",
},
types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("sp1"),
EntityID: "sp1",
})
require.NoError(t, err)
err = service.CreateSAMLIdPServiceProvider(ctx, sp3)
require.Error(t, err)
require.True(t, trace.IsBadParameter(err))

// Initially we expect no service providers.
out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "")
require.NoError(t, err)
Expand Down Expand Up @@ -163,6 +178,14 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
err = service.UpdateSAMLIdPServiceProvider(ctx, sp)
require.Error(t, err)

// Update a service provider with an invalid acs.
sp, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
sp.SetEntityDescriptor(newInvalidACSEntityDescriptor(sp1.GetEntityID()))
err = service.UpdateSAMLIdPServiceProvider(ctx, sp)
require.Error(t, err)
require.True(t, trace.IsBadParameter(err))

// Delete a service provider.
err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
Expand All @@ -186,6 +209,49 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
require.Empty(t, out)
}

func TestValidateSAMLIdPServiceProvider(t *testing.T) {
descriptor := newEntityDescriptor("IAMShowcase")

cases := []struct {
name string
spec types.SAMLIdPServiceProviderSpecV1
errAssertion require.ErrorAssertionFunc
}{
{
name: "valid provider",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: "IAMShowcase",
},
errAssertion: require.NoError,
},
{
name: "invalid entity id",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: uuid.NewString(),
},
errAssertion: require.Error,
},
{
name: "invalid acs",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"),
EntityID: "IAMShowcase",
},
errAssertion: require.Error,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec)
require.NoError(t, err)
test.errAssertion(t, validateSAMLIdPServiceProvider(sp))
})
}
}

func newEntityDescriptor(entityID string) string {
return fmt.Sprintf(testEntityDescriptor, entityID)
}
Expand All @@ -200,3 +266,18 @@ const testEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`

func newInvalidACSEntityDescriptor(entityID string) string {
return fmt.Sprintf(invalidEntityDescriptor, entityID)
}

// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations.
const invalidEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" xmlns:ds="http://www.w3.org/2000/09/xmldsig#" entityID="%s" validUntil="2025-12-09T09:13:31.006Z">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="javascript://sptest.iamshowcase.com/acs" index="0" isDefault="true"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`
17 changes: 16 additions & 1 deletion lib/services/saml_idp_service_provider.go
Expand Up @@ -18,14 +18,15 @@ package services

import (
"context"
"net/url"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
)

// SAMLIdPServiceProvider defines an interface for managing SAML IdP service providers.
// SAMLIdPServiceProviders defines an interface for managing SAML IdP service providers.
type SAMLIdPServiceProviders interface {
// ListSAMLIdPServiceProviders returns a paginated list of all SAML IdP service provider resources.
ListSAMLIdPServiceProviders(context.Context, int, string) ([]types.SAMLIdPServiceProvider, string, error)
Expand Down Expand Up @@ -115,3 +116,17 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string)
}
return &s, nil
}

// ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location
// is a valid HTTPS endpoint.
func ValidateAssertionConsumerServicesEndpoint(acs string) error {
endpoint, err := url.Parse(acs)
switch {
case err != nil:
return trace.Wrap(err)
case endpoint.Scheme != "https":
return trace.BadParameter("the assertion consumer services location must be an https endpoint")
}

return nil
}
26 changes: 26 additions & 0 deletions lib/services/saml_idp_service_provider_test.go
Expand Up @@ -61,6 +61,32 @@ func TestSAMLIdPServiceProviderMarshal(t *testing.T) {
require.Equal(t, expected, actual)
}

func TestValidateAssertionConsumerServicesEndpoint(t *testing.T) {
cases := []struct {
location string
assertion require.ErrorAssertionFunc
}{
{
location: "https://sptest.iamshowcase.com/acs",
assertion: require.NoError,
},
{
location: "http://sptest.iamshowcase.com/acs",
assertion: require.Error,
},
{
location: "javascript://sptest.iamshowcase.com/acs",
assertion: require.Error,
},
}

for _, test := range cases {
t.Run(test.location, func(t *testing.T) {
test.assertion(t, ValidateAssertionConsumerServicesEndpoint(test.location))
})
}
}

var samlIDPServiceProviderYAML = `---
kind: saml_idp_service_provider
version: v1
Expand Down

0 comments on commit 8894dd4

Please sign in to comment.