Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for paged requests #74

Merged
merged 3 commits into from
Oct 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 85 additions & 7 deletions nss_cache/sources/ldapsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ldap.sasl
import urllib
import re
from distutils.version import StrictVersion

from nss_cache import error
from nss_cache.maps import automount
Expand All @@ -36,9 +37,35 @@
from nss_cache.maps import sshkey
from nss_cache.sources import source

# True when the installed python-ldap is 2.4 or newer. The paged-results
# helper functions in this module branch on this flag because the
# SimplePagedResultsControl API differs on older python-ldap versions.
# NOTE(review): distutils is deprecated in recent Pythons and StrictVersion
# only accepts two- or three-component version strings -- confirm this still
# holds for every python-ldap release this module must support.
IS_LDAP24_OR_NEWER = StrictVersion(ldap.__version__) >= StrictVersion('2.4')

# ldap.LDAP_CONTROL_PAGE_OID is unavailable on some systems, so we define it here.
# It is the control type used to build (old API) and to recognise the simple
# paged results control.
LDAP_CONTROL_PAGE_OID = '1.2.840.113556.1.4.319'

def RegisterImplementation(registration_callback):
    """Register the LDAP source implementation.

    Args:
      registration_callback: callable that is invoked exactly once with the
        LdapSource class as its only argument.
    """
    registration_callback(LdapSource)

def makeSimplePagedResultsControl(page_size):
    """Build a critical simple-paged-results control with an empty cookie.

    Args:
      page_size: number of entries requested per page.

    Returns:
      An ldap.controls.SimplePagedResultsControl instance, constructed with
      whichever call signature matches the installed python-ldap version.
    """
    # The API for this is different on older versions of python-ldap, so we need
    # to handle this case.
    if IS_LDAP24_OR_NEWER:
        return ldap.controls.SimplePagedResultsControl(
            True, size=page_size, cookie='')
    # Older API: (control type OID, criticality, (size, cookie)).
    return ldap.controls.SimplePagedResultsControl(
        LDAP_CONTROL_PAGE_OID, True, (page_size, ''))

def getCookieFromControl(pctrl):
    """Extract the paging cookie from a paged-results control.

    Args:
      pctrl: a paged-results control object.

    Returns:
      The cookie stored on the control: the `cookie` attribute on
      python-ldap >= 2.4, otherwise the second element of `controlValue`.
    """
    if IS_LDAP24_OR_NEWER:
        return pctrl.cookie
    # Older API keeps (size, cookie) packed in controlValue.
    return pctrl.controlValue[1]

def setCookieOnControl(control, cookie, page_size):
    """Store a paging cookie on a paged-results control, in place.

    Args:
      control: the paged-results control object to mutate.
      cookie: the cookie value to store.
      page_size: page size; only used on python-ldap < 2.4, where size and
        cookie are written together as the `controlValue` tuple.

    Returns:
      The cookie that was passed in.
    """
    if IS_LDAP24_OR_NEWER:
        control.cookie = cookie
    else:
        control.controlValue = (page_size, cookie)

    return cookie

class LdapSource(source.Source):
"""Source for data in LDAP.
Expand All @@ -65,6 +92,10 @@ class LdapSource(source.Source):
# for registration
name = 'ldap'

# Page size for paged LDAP requests
# Value chosen based on default Active Directory MaxPageSize
PAGE_SIZE = 1000

def __init__(self, conf, conn=None):
"""Initialise the LDAP Data Source.

Expand All @@ -77,6 +108,10 @@ def __init__(self, conf, conn=None):

self._SetDefaults(conf)
self._conf = conf
self.ldap_controls = makeSimplePagedResultsControl(self.PAGE_SIZE)

# Used by _ReSearch:
self._last_search_params = None

if conn is None:
# ReconnectLDAPObject should handle interrupted ldap transactions.
Expand Down Expand Up @@ -156,6 +191,9 @@ def _SetDefaults(self, configuration):
ldap.set_option(ldap.OPT_X_TLS_CACERTFILE, configuration['tls_cacertfile'])
ldap.version = ldap.VERSION3 # this is hard-coded, we only support V3

def _SetCookie(self, cookie):
    """Store `cookie` on this source's paged-results control.

    Args:
      cookie: the paging cookie to store on self.ldap_controls.

    Returns:
      The value returned by setCookieOnControl (the cookie passed in).
    """
    return setCookieOnControl(self.ldap_controls, cookie, self.PAGE_SIZE)

def Bind(self, configuration):
"""Bind to LDAP, retrying if necessary."""
# If the server is unavailable, we are going to find out now, as this
Expand Down Expand Up @@ -189,6 +227,15 @@ def Bind(self, configuration):
self.log.debug('sleeping %d seconds', configuration['retry_delay'])
time.sleep(configuration['retry_delay'])

def _ReSearch(self):
    """Perform self.Search again with the previously used parameters.

    The parameters are read from self._last_search_params, which
    self.Search records on every call; they are passed positionally.

    Returns:
      self.Search result.
    """
    # Propagate the result so the behaviour matches the documented contract
    # (previously the value returned by Search was silently discarded).
    return self.Search(*self._last_search_params)

def Search(self, search_base, search_filter, search_scope, attrs):
"""Search the data source.

Expand All @@ -204,12 +251,16 @@ def Search(self, search_base, search_filter, search_scope, attrs):
Returns:
nothing.
"""
self._last_search_params = (search_base, search_filter, search_scope, attrs)

self.log.debug('searching for base=%r, filter=%r, scope=%r, attrs=%r',
search_base, search_filter, search_scope, attrs)
if 'dn' in attrs: self._dn_requested = True # special cased attribute
self.message_id = self.conn.search(base=search_base,
filterstr=search_filter,
scope=search_scope, attrlist=attrs)
self.message_id = self.conn.search_ext(base=search_base,
filterstr=search_filter,
scope=search_scope,
attrlist=attrs,
serverctrls=[self.ldap_controls])

def __iter__(self):
"""Iterate over the data from the last search.
Expand All @@ -219,15 +270,42 @@ def __iter__(self):
Yields:
Search results from the prior call to self.Search()
"""
# Acquire data to yield:
while True:
result_type, data = None, None

timeout_retries = 0
while timeout_retries < self._conf['retry_max']:
try:
result_type, data = self.conn.result(self.message_id, all=0,
timeout=self.conf['timelimit'])
result_type, data, _, serverctrls = self.conn.result3(
self.message_id, all=0, timeout=self.conf['timelimit'])

# Paged requests return a new cookie in serverctrls at the end of a page,
# so we search for the cookie and perform another search if needed.
if len(serverctrls) > 0:
# Search for appropriate control
simple_paged_results_controls = [
control
for control in serverctrls
if control.controlType == LDAP_CONTROL_PAGE_OID
]
if simple_paged_results_controls:
# We only expect one control; just take the first in the list.
cookie = getCookieFromControl(simple_paged_results_controls[0])

if len(cookie) > 0:
# If cookie is non-empty, call search_ext and result3 again
self._SetCookie(cookie)
self._ReSearch()
result_type, data, _, serverctrls = self.conn.result3(
self.message_id, all=0, timeout=self.conf['timelimit'])
# else: An empty cookie means we are done.

# break loop once result3 doesn't time out
break
except ldap.SIZELIMIT_EXCEEDED:
self.log.warning('LDAP server size limit exceeded; using page size {0}.'.format(self.PAGE_SIZE))
return
except ldap.NO_SUCH_OBJECT:
self.log.debug('Returning due to ldap.NO_SUCH_OBJECT')
return
Expand Down Expand Up @@ -530,7 +608,7 @@ def GetUpdates(self, source, search_base, search_filter,
data_map.SetModifyTimestamp(max_ts)

return data_map

def PostProcess(self, data_map, source, search_filter, search_scope):
    """Perform some post-process of the data.

    This implementation is a no-op: it ignores all arguments and returns
    None.

    Args:
      data_map: unused here.
      source: unused here.
      search_filter: unused here.
      search_scope: unused here.
    """
Expand Down Expand Up @@ -645,7 +723,7 @@ def Transform(self, obj):
gr.members = members

return gr

def PostProcess(self, data_map, source, search_filter, search_scope):
"""Perform some post-process of the data."""
if 'uniqueMember' in self.attrs:
Expand Down
Loading