Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v13] Validate SAMLIdPServiceProviders ACS endpoints #32220

Merged
merged 1 commit into from Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 19 additions & 11 deletions lib/services/local/saml_idp_service_provider.go
Expand Up @@ -72,12 +72,12 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co

// CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
}

if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
return trace.Wrap(err)
}

Expand All @@ -87,7 +87,7 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
return trace.Wrap(err)
}

_, err = backend.Create(ctx, item)
_, err := backend.Create(ctx, item)
if trace.IsAlreadyExists(err) {
return trace.AlreadyExists("%s %q already exists", types.KindSAMLIdPServiceProvider, sp.GetName())
}
Expand All @@ -97,12 +97,12 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context

// UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
}

if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil {
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
if err != nil {
return trace.Wrap(err)
}

Expand All @@ -112,7 +112,7 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context
return trace.Wrap(err)
}

_, err = backend.Update(ctx, item)
_, err := backend.Update(ctx, item)
if trace.IsNotFound(err) {
return trace.NotFound("%s %q doesn't exist", types.KindSAMLIdPServiceProvider, sp.GetName())
}
Expand Down Expand Up @@ -159,9 +159,9 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte
return nil
}

// ensureEntityDescriptorMatchesEntityID ensures that the entity ID in the entity descriptor is the same as the entity ID
// in the SAMLIdPServiceProvider object.
func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp types.SAMLIdPServiceProvider) error {
// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID
// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints.
func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error {
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.Wrap(err)
Expand All @@ -171,5 +171,13 @@ func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

for _, descriptor := range ed.SPSSODescriptors {
for _, acs := range descriptor.AssertionConsumerServices {
if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil {
return trace.Wrap(err)
}
}
}

return nil
}
81 changes: 81 additions & 0 deletions lib/services/local/saml_idp_service_provider_test.go
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -64,6 +65,20 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
})
require.NoError(t, err)

// Try to create an invalid service provider with an invalid acs.
sp3, err := types.NewSAMLIdPServiceProvider(
types.Metadata{
Name: "sp3",
},
types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("sp1"),
EntityID: "sp1",
})
require.NoError(t, err)
err = service.CreateSAMLIdPServiceProvider(ctx, sp3)
require.Error(t, err)
require.True(t, trace.IsBadParameter(err))

// Initially we expect no service providers.
out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "")
require.NoError(t, err)
Expand Down Expand Up @@ -163,6 +178,14 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
err = service.UpdateSAMLIdPServiceProvider(ctx, sp)
require.Error(t, err)

// Update a service provider with an invalid acs.
sp, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
sp.SetEntityDescriptor(newInvalidACSEntityDescriptor(sp1.GetEntityID()))
err = service.UpdateSAMLIdPServiceProvider(ctx, sp)
require.Error(t, err)
require.True(t, trace.IsBadParameter(err))

// Delete a service provider.
err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
Expand All @@ -186,6 +209,49 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
require.Empty(t, out)
}

func TestValidateSAMLIdPServiceProvider(t *testing.T) {
descriptor := newEntityDescriptor("IAMShowcase")

cases := []struct {
name string
spec types.SAMLIdPServiceProviderSpecV1
errAssertion require.ErrorAssertionFunc
}{
{
name: "valid provider",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: "IAMShowcase",
},
errAssertion: require.NoError,
},
{
name: "invalid entity id",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: uuid.NewString(),
},
errAssertion: require.Error,
},
{
name: "invalid acs",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"),
EntityID: "IAMShowcase",
},
errAssertion: require.Error,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec)
require.NoError(t, err)
test.errAssertion(t, validateSAMLIdPServiceProvider(sp))
})
}
}

func newEntityDescriptor(entityID string) string {
return fmt.Sprintf(testEntityDescriptor, entityID)
}
Expand All @@ -200,3 +266,18 @@ const testEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`

func newInvalidACSEntityDescriptor(entityID string) string {
return fmt.Sprintf(invalidEntityDescriptor, entityID)
}

// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations.
const invalidEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" xmlns:ds="http://www.w3.org/2000/09/xmldsig#" entityID="%s" validUntil="2025-12-09T09:13:31.006Z">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="javascript://sptest.iamshowcase.com/acs" index="0" isDefault="true"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`
17 changes: 16 additions & 1 deletion lib/services/saml_idp_service_provider.go
Expand Up @@ -18,14 +18,15 @@ package services

import (
"context"
"net/url"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
)

// SAMLIdPServiceProvider defines an interface for managing SAML IdP service providers.
// SAMLIdPServiceProviders defines an interface for managing SAML IdP service providers.
type SAMLIdPServiceProviders interface {
// ListSAMLIdPServiceProviders returns a paginated list of all SAML IdP service provider resources.
ListSAMLIdPServiceProviders(context.Context, int, string) ([]types.SAMLIdPServiceProvider, string, error)
Expand Down Expand Up @@ -115,3 +116,17 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string)
}
return &s, nil
}

// ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location
// is a valid HTTPS endpoint.
func ValidateAssertionConsumerServicesEndpoint(acs string) error {
endpoint, err := url.Parse(acs)
switch {
case err != nil:
return trace.Wrap(err)
case endpoint.Scheme != "https":
return trace.BadParameter("the assertion consumer services location must be an https endpoint")
}

return nil
}
26 changes: 26 additions & 0 deletions lib/services/saml_idp_service_provider_test.go
Expand Up @@ -61,6 +61,32 @@ func TestSAMLIdPServiceProviderMarshal(t *testing.T) {
require.Equal(t, expected, actual)
}

func TestValidateAssertionConsumerServicesEndpoint(t *testing.T) {
cases := []struct {
location string
assertion require.ErrorAssertionFunc
}{
{
location: "https://sptest.iamshowcase.com/acs",
assertion: require.NoError,
},
{
location: "http://sptest.iamshowcase.com/acs",
assertion: require.Error,
},
{
location: "javascript://sptest.iamshowcase.com/acs",
assertion: require.Error,
},
}

for _, test := range cases {
t.Run(test.location, func(t *testing.T) {
test.assertion(t, ValidateAssertionConsumerServicesEndpoint(test.location))
})
}
}

var samlIDPServiceProviderYAML = `---
kind: saml_idp_service_provider
version: v1
Expand Down