diff --git a/changelog/17640.txt b/changelog/17640.txt new file mode 100644 index 0000000000000..6db136a0638f2 --- /dev/null +++ b/changelog/17640.txt @@ -0,0 +1,3 @@ +```release-note:improvement +sdk/ldap: Added support for paging when searching for groups using group filters +``` \ No newline at end of file diff --git a/sdk/helper/ldaputil/client.go b/sdk/helper/ldaputil/client.go index f3946c8269e2c..8a7ac4822c34f 100644 --- a/sdk/helper/ldaputil/client.go +++ b/sdk/helper/ldaputil/client.go @@ -14,7 +14,6 @@ import ( "time" "github.com/go-ldap/ldap/v3" - "github.com/hashicorp/errwrap" hclog "github.com/hashicorp/go-hclog" multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/tlsutil" @@ -32,7 +31,7 @@ func (c *Client) DialLDAP(cfg *ConfigEntry) (Connection, error) { for _, uut := range urls { u, err := url.Parse(uut) if err != nil { - retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error parsing url %q: {{err}}", uut), err)) + retErr = multierror.Append(retErr, fmt.Errorf(fmt.Sprintf("error parsing url %q: {{err}}", uut), err)) continue } host, port, err := net.SplitHostPort(u.Host) @@ -83,7 +82,7 @@ func (c *Client) DialLDAP(cfg *ConfigEntry) (Connection, error) { retErr = nil break } - retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error connecting to host %q: {{err}}", uut), err)) + retErr = multierror.Append(retErr, fmt.Errorf(fmt.Sprintf("error connecting to host %q: {{err}}", uut), err)) } if retErr != nil { return nil, retErr @@ -158,7 +157,7 @@ func (c *Client) GetUserBindDN(cfg *ConfigEntry, conn Connection, username strin result, err := c.makeLdapSearchRequest(cfg, conn, username) if err != nil { - return bindDN, errwrap.Wrapf("LDAP search for binddn failed: {{err}}", err) + return bindDN, fmt.Errorf("LDAP search for binddn failed %w", err) } if len(result.Entries) != 1 { return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique") @@ -194,7 +193,7 @@ func (c *Client) RenderUserSearchFilter(cfg *ConfigEntry, username string) (stri // Example template "({{.UserAttr}}={{.Username}})" t, err := template.New("queryTemplate").Parse(cfg.UserFilter) if err != nil { - return "", errwrap.Wrapf("LDAP search failed due to template compilation error: {{err}}", err) + return "", fmt.Errorf("LDAP search failed due to template compilation error: %w", err) } // Build context to pass to template - we will be exposing UserDn and Username. @@ -212,7 +211,7 @@ func (c *Client) RenderUserSearchFilter(cfg *ConfigEntry, username string) (stri var renderedFilter bytes.Buffer if err := t.Execute(&renderedFilter, context); err != nil { - return "", errwrap.Wrapf("LDAP search failed due to template parsing error: {{err}}", err) + return "", fmt.Errorf("LDAP search failed due to template parsing error: %w", err) } return renderedFilter.String(), nil @@ -237,7 +236,7 @@ func (c *Client) GetUserAliasAttributeValue(cfg *ConfigEntry, conn Connection, u result, err := c.makeLdapSearchRequest(cfg, conn, username) if err != nil { - return aliasAttributeValue, errwrap.Wrapf("LDAP search for entity alias attribute failed: {{err}}", err) + return aliasAttributeValue, fmt.Errorf("LDAP search for entity alias attribute failed: %w", err) } if len(result.Entries) != 1 { return aliasAttributeValue, fmt.Errorf("LDAP search for entity alias attribute 0 or not unique") @@ -281,7 +280,7 @@ func (c *Client) GetUserDN(cfg *ConfigEntry, conn Connection, bindDN, username s SizeLimit: math.MaxInt32, }) if err != nil { - return userDN, errwrap.Wrapf("LDAP search failed for detecting user: {{err}}", err) + return userDN, fmt.Errorf("LDAP search failed for detecting user: %w", err) } for _, e := range result.Entries { userDN = e.DN @@ -314,7 +313,7 @@ func (c *Client) performLdapFilterGroupsSearch(cfg *ConfigEntry, conn Connection // Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))" t, err := template.New("queryTemplate").Parse(cfg.GroupFilter) if err != nil { - return nil, errwrap.Wrapf("LDAP search failed due to template compilation error: {{err}}", err) + return nil, fmt.Errorf("LDAP search failed due to template compilation error: %w", err) } // Build context to pass to template - we will be exposing UserDn and Username. @@ -328,7 +327,7 @@ func (c *Client) performLdapFilterGroupsSearch(cfg *ConfigEntry, conn Connection var renderedQuery bytes.Buffer if err := t.Execute(&renderedQuery, context); err != nil { - return nil, errwrap.Wrapf("LDAP search failed due to template parsing error: {{err}}", err) + return nil, fmt.Errorf("LDAP search failed due to template parsing error: %w", err) } if c.Logger.IsDebug() { @@ -345,7 +344,65 @@ func (c *Client) performLdapFilterGroupsSearch(cfg *ConfigEntry, conn Connection SizeLimit: math.MaxInt32, }) if err != nil { - return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err) + return nil, fmt.Errorf("LDAP search failed: %w", err) + } + + return result.Entries, nil +} + +func (c *Client) performLdapFilterGroupsSearchPaging(cfg *ConfigEntry, conn PagingConnection, userDN string, username string) ([]*ldap.Entry, error) { + if cfg.GroupFilter == "" { + c.Logger.Warn("groupfilter is empty, will not query server") + return make([]*ldap.Entry, 0), nil + } + + if cfg.GroupDN == "" { + c.Logger.Warn("groupdn is empty, will not query server") + return make([]*ldap.Entry, 0), nil + } + + // If groupfilter was defined, resolve it as a Go template and use the query for + // returning the user's groups + if c.Logger.IsDebug() { + c.Logger.Debug("compiling group filter", "group_filter", cfg.GroupFilter) + } + + // Parse the configuration as a template. + // Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))" + t, err := template.New("queryTemplate").Parse(cfg.GroupFilter) + if err != nil { + return nil, fmt.Errorf("LDAP search failed due to template compilation error: %w", err) + } + + // Build context to pass to template - we will be exposing UserDn and Username. + context := struct { + UserDN string + Username string + }{ + ldap.EscapeFilter(userDN), + ldap.EscapeFilter(username), + } + + var renderedQuery bytes.Buffer + if err := t.Execute(&renderedQuery, context); err != nil { + return nil, fmt.Errorf("LDAP search failed due to template parsing error: %w", err) + } + + if c.Logger.IsDebug() { + c.Logger.Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String()) + } + + result, err := conn.SearchWithPaging(&ldap.SearchRequest{ + BaseDN: cfg.GroupDN, + Scope: ldap.ScopeWholeSubtree, + Filter: renderedQuery.String(), + Attributes: []string{ + cfg.GroupAttr, + }, + SizeLimit: math.MaxInt32, + }, math.MaxInt32) + if err != nil { + return nil, fmt.Errorf("LDAP search failed: %w", err) } return result.Entries, nil @@ -358,21 +415,21 @@ func sidBytesToString(b []byte) (string, error) { var identifierAuthorityParts [3]uint16 if err := binary.Read(reader, binary.LittleEndian, &revision); err != nil { - return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading Revision: {{err}}", b), err) + return "", fmt.Errorf(fmt.Sprintf("SID %#v convert failed reading Revision: {{err}}", b), err) } if err := binary.Read(reader, binary.LittleEndian, &subAuthorityCount); err != nil { - return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthorityCount: {{err}}", b), err) + return "", fmt.Errorf(fmt.Sprintf("SID %#v convert failed reading SubAuthorityCount: {{err}}", b), err) } if err := binary.Read(reader, binary.BigEndian, &identifierAuthorityParts); err != nil { - return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading IdentifierAuthority: {{err}}", b), err) + return "", fmt.Errorf(fmt.Sprintf("SID %#v convert failed reading IdentifierAuthority: {{err}}", b), err) } identifierAuthority := (uint64(identifierAuthorityParts[0]) << 32) + (uint64(identifierAuthorityParts[1]) << 16) + uint64(identifierAuthorityParts[2]) subAuthority := make([]uint32, subAuthorityCount) if err := binary.Read(reader, binary.LittleEndian, &subAuthority); err != nil { - return "", errwrap.Wrapf(fmt.Sprintf("SID %#v convert failed reading SubAuthority: {{err}}", b), err) + return "", fmt.Errorf(fmt.Sprintf("SID %#v convert failed reading SubAuthority: {{err}}", b), err) } result := fmt.Sprintf("S-%d-%d", revision, identifierAuthority) @@ -394,7 +451,7 @@ func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection, SizeLimit: 1, }) if err != nil { - return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err) + return nil, fmt.Errorf("LDAP search failed: %w", err) } if len(result.Entries) == 0 { c.Logger.Warn("unable to read object for group attributes", "userdn", userDN, "groupattr", cfg.GroupAttr) @@ -462,7 +519,11 @@ func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, if cfg.UseTokenGroups { entries, err = c.performLdapTokenGroupsSearch(cfg, conn, userDN) } else { - entries, err = c.performLdapFilterGroupsSearch(cfg, conn, userDN, username) + if paging, ok := conn.(PagingConnection); ok { + entries, err = c.performLdapFilterGroupsSearchPaging(cfg, paging, userDN, username) + } else { + entries, err = c.performLdapFilterGroupsSearch(cfg, conn, userDN, username) + } } if err != nil { return nil, err @@ -603,7 +664,7 @@ func getTLSConfig(cfg *ConfigEntry, host string) (*tls.Config, error) { if cfg.ClientTLSCert != "" && cfg.ClientTLSKey != "" { certificate, err := tls.X509KeyPair([]byte(cfg.ClientTLSCert), []byte(cfg.ClientTLSKey)) if err != nil { - return nil, errwrap.Wrapf("failed to parse client X509 key pair: {{err}}", err) + return nil, fmt.Errorf("failed to parse client X509 key pair: %w", err) } tlsConfig.Certificates = append(tlsConfig.Certificates, certificate) } else if cfg.ClientTLSCert != "" || cfg.ClientTLSKey != "" { diff --git a/sdk/helper/ldaputil/connection.go b/sdk/helper/ldaputil/connection.go index ba984e052eaa0..71c83f2f9b3a9 100644 --- a/sdk/helper/ldaputil/connection.go +++ b/sdk/helper/ldaputil/connection.go @@ -20,3 +20,8 @@ type Connection interface { SetTimeout(timeout time.Duration) UnauthenticatedBind(username string) error } + +type PagingConnection interface { + Connection + SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) +}