Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SAML IdP attribute mapping types and config handler #35584

Merged
merged 11 commits into from
Dec 14, 2023
14 changes: 14 additions & 0 deletions api/proto/teleport/legacy/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5764,6 +5764,20 @@ message SAMLIdPServiceProviderSpecV1 {
string EntityID = 2 [(gogoproto.jsontag) = "entity_id"];
// ACSURL is the endpoint where SAML authentication response will be redirected.
string ACSURL = 3 [(gogoproto.jsontag) = "acs_url"];
// AttributeMapping is used to map Service Provider requested attributes to
// username, role and traits in Teleport.
repeated SAMLAttributeMapping AttributeMapping = 4 [(gogoproto.jsontag) = "attribute_mapping"];
}

// SAMLAttributeMapping represents SAML Service Provider requested attribute
// name, format and its values.
message SAMLAttributeMapping {
// name is an attribute name.
string name = 1 [(gogoproto.jsontag) = "name"];
// name_format is an attribute name format.
string name_format = 2 [(gogoproto.jsontag) = "name_format"];
// value is an attribute value definable with predicate expression.
string value = 3 [(gogoproto.jsontag) = "value"];
}

// IdPOptions specify options related to access Teleport IdPs.
Expand Down
61 changes: 57 additions & 4 deletions api/types/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,20 @@ import (
"github.com/gravitational/teleport/api/utils"
)

const (
unspecifiedNameFormat = "urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified"
flyinghermit marked this conversation as resolved.
Show resolved Hide resolved
uriNameFormat = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"
basicNameFormat = "urn:oasis:names:tc:SAML:2.0:attrname-format:basic"
)

var (
// ErrMissingEntityDescriptorAndEntityID is returned when both entity descriptor and entity ID is empty.
ErrEmptyEntityDescriptorAndEntityID = trace.BadParameter("either entity_descriptor or entity_id must be provided")
ErrEmptyEntityDescriptorAndEntityID = &trace.BadParameterError{Message: "either entity_descriptor or entity_id must be provided"}
flyinghermit marked this conversation as resolved.
Show resolved Hide resolved
// ErrMissingEntityDescriptorAndACSURL is returned when both entity descriptor and ACS URL is empty.
ErrEmptyEntityDescriptorAndACSURL = trace.BadParameter("either entity_descriptor or acs_url must be provided")
ErrEmptyEntityDescriptorAndACSURL = &trace.BadParameterError{Message: "either entity_descriptor or acs_url must be provided"}
// ErrDuplicateAttributeName is returned when attribute mapping declares two or more
// attributes with the same name.
ErrDuplicateAttributeName = &trace.BadParameterError{Message: "duplicate attribute name not allowed"}
)

// SAMLIdPServiceProvider specifies configuration for service providers for Teleport's built in SAML IdP.
Expand All @@ -51,6 +60,10 @@ type SAMLIdPServiceProvider interface {
GetACSURL() string
// SetACSURL sets the ACS URL.
SetACSURL(string)
// GetAttributeMapping returns Attribute Mapping.
GetAttributeMapping() []*SAMLAttributeMapping
// SetAttributeMapping sets Attribute Mapping.
SetAttributeMapping([]*SAMLAttributeMapping)
// Copy returns a copy of this saml idp service provider object.
Copy() SAMLIdPServiceProvider
// CloneResource returns a copy of the SAMLIdPServiceProvider as a ResourceWithLabels
Expand Down Expand Up @@ -103,6 +116,16 @@ func (s *SAMLIdPServiceProviderV1) SetACSURL(acsURL string) {
s.Spec.ACSURL = acsURL
}

// GetAttributeMapping returns the Attribute Mapping.
func (s *SAMLIdPServiceProviderV1) GetAttributeMapping() []*SAMLAttributeMapping {
return s.Spec.AttributeMapping
}

// SetAttributeMapping sets Attribute Mapping.
func (s *SAMLIdPServiceProviderV1) SetAttributeMapping(attrMaps []*SAMLAttributeMapping) {
s.Spec.AttributeMapping = attrMaps
}

// String returns the SAML IdP service provider string representation.
func (s *SAMLIdPServiceProviderV1) String() string {
return fmt.Sprintf("SAMLIdPServiceProviderV1(Name=%v)",
Expand Down Expand Up @@ -139,11 +162,11 @@ func (s *SAMLIdPServiceProviderV1) CheckAndSetDefaults() error {

if s.Spec.EntityDescriptor == "" {
if s.Spec.EntityID == "" {
return ErrEmptyEntityDescriptorAndEntityID
return trace.Wrap(ErrEmptyEntityDescriptorAndEntityID)
}

if s.Spec.ACSURL == "" {
return ErrEmptyEntityDescriptorAndACSURL
return trace.Wrap(ErrEmptyEntityDescriptorAndACSURL)
}

}
Expand All @@ -161,6 +184,18 @@ func (s *SAMLIdPServiceProviderV1) CheckAndSetDefaults() error {
s.Spec.EntityID = ed.EntityID
}

attrNames := make(map[string]struct{})
for _, attributeMap := range s.GetAttributeMapping() {
if err := attributeMap.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
// check for duplicate attribute names
if _, ok := attrNames[attributeMap.Name]; ok {
return trace.Wrap(ErrDuplicateAttributeName)
}
attrNames[attributeMap.Name] = struct{}{}
}

return nil
}

Expand All @@ -184,3 +219,21 @@ func (s SAMLIdPServiceProviders) Less(i, j int) bool { return s[i].GetName() < s

// Swap swaps two service providers.
func (s SAMLIdPServiceProviders) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

// CheckAndSetDefaults check and sets SAMLAttributeMapping default values
func (am *SAMLAttributeMapping) CheckAndSetDefaults() error {
// verify name format is one of the supported
// formats - unspecifiedNameFormat, basicNameFormat or uriNameFormat
// and assign it with the URN value of that format.
switch am.NameFormat {
case "", "unspecified", unspecifiedNameFormat:
am.NameFormat = unspecifiedNameFormat
case "basic", basicNameFormat:
am.NameFormat = basicNameFormat
case "uri", uriNameFormat:
am.NameFormat = uriNameFormat
default:
return trace.BadParameter("invalid name format: %s", am.NameFormat)
}
return nil
}
63 changes: 63 additions & 0 deletions api/types/saml_idp_service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func TestNewSAMLIdPServiceProvider(t *testing.T) {
acsURL string
errAssertion require.ErrorAssertionFunc
expectedEntityID string
attributeMapping []*SAMLAttributeMapping
}{
{
name: "valid entity descriptor",
Expand Down Expand Up @@ -82,6 +83,64 @@ func TestNewSAMLIdPServiceProvider(t *testing.T) {
errAssertion: require.NoError,
expectedEntityID: "IAMShowcase",
},
{
name: "duplicate attribute mapping",
entityDescriptor: testEntityDescriptor,
attributeMapping: []*SAMLAttributeMapping{
{
Name: "username",
Value: "user.traits.name",
},
{
Name: "user1",
Value: "user.traits.firstname",
},
{
Name: "username",
Value: "user.traits.givenname",
},
},
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorIs(t, err, ErrDuplicateAttributeName)
},
},
{
name: "valid attribute mapping",
entityDescriptor: testEntityDescriptor,
entityID: "IAMShowcase",
expectedEntityID: "IAMShowcase",
attributeMapping: []*SAMLAttributeMapping{
{
Name: "username",
Value: "user.traits.name",
},
{
Name: "user1",
Value: "user.traits.givenname",
},
},
errAssertion: require.NoError,
},
{
name: "invalid attribute mapping name format",
entityDescriptor: testEntityDescriptor,
entityID: "IAMShowcase",
expectedEntityID: "IAMShowcase",
attributeMapping: []*SAMLAttributeMapping{
{
Name: "username",
Value: "user.traits.name",
NameFormat: "emailAddress",
},
{
Name: "user1",
Value: "user.traits.givenname",
},
},
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid name format")
},
},
}

for _, test := range tests {
Expand All @@ -92,11 +151,15 @@ func TestNewSAMLIdPServiceProvider(t *testing.T) {
EntityDescriptor: test.entityDescriptor,
EntityID: test.entityID,
ACSURL: test.acsURL,
AttributeMapping: test.attributeMapping,
})

test.errAssertion(t, err)
if sp != nil {
require.Equal(t, test.expectedEntityID, sp.GetEntityID())
if len(sp.GetAttributeMapping()) > 0 {
require.Equal(t, test.attributeMapping, sp.GetAttributeMapping())
}
}
})
}
Expand Down