Skip to content

Commit

Permalink
tests adapted to threadsafe strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
cannatag committed Jul 18, 2020
1 parent 1d09d90 commit ba8933e
Show file tree
Hide file tree
Showing 17 changed files with 155 additions and 287 deletions.
3 changes: 2 additions & 1 deletion _changelog.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# 2.7.1 - not yet released
# 2.8 - not yet released
- new feature: SafeSync strategy (SAFE_SYNC) for using a synchronous Connection object in a multi-threading program
- fixed requirements for pyasn1
- fixed issue with lazy connection requesting server info on every operation
- fixed searching by objectGUID in hex format (thanks Matt)
Expand Down
4 changes: 2 additions & 2 deletions ldap3/abstract/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,11 +824,11 @@ def _refresh_object(self, entry_dn, attributes=None, tries=4, seconds=2, control
response, result, request = self.connection.get_response(result, get_request=True)
else:
if self.connection.strategy.thread_safe:
_, result, response, _ = result
_, result, response, request = result
else:
response = self.connection.response
result = self.connection.result
request = self.connection.request
request = self.connection.request

if result['result'] in [RESULT_SUCCESS]:
break
Expand Down
36 changes: 24 additions & 12 deletions ldap3/abstract/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,12 @@ def entry_commit_changes(self, refresh=True, controls=None, clear_history=True):
if not self.entry_cursor.connection.strategy.sync:
response, result, request = self.entry_cursor.connection.get_response(result, get_request=True)
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
if self.entry_cursor.connection.strategy.thread_safe:
_, result, response, request = result
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
self.entry_cursor._store_operation_in_history(request, result, response)
if result['result'] == RESULT_SUCCESS:
dn = self.entry_dn
Expand All @@ -557,9 +560,12 @@ def entry_commit_changes(self, refresh=True, controls=None, clear_history=True):
if not self.entry_cursor.connection.strategy.sync:
response, result, request = self.entry_cursor.connection.get_response(result, get_request=True)
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
if self.entry_cursor.connection.strategy.thread_safe:
_, result, response, request = result
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
self.entry_cursor._store_operation_in_history(request, result, response)
if result['result'] == RESULT_SUCCESS:
self._state.dn = safe_dn('+'.join(safe_rdn(self.entry_dn)) + ',' + self._state._to)
Expand All @@ -577,9 +583,12 @@ def entry_commit_changes(self, refresh=True, controls=None, clear_history=True):
if not self.entry_cursor.connection.strategy.sync:
response, result, request = self.entry_cursor.connection.get_response(result, get_request=True)
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
if self.entry_cursor.connection.strategy.thread_safe:
_, result, response, request = result
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
self.entry_cursor._store_operation_in_history(request, result, response)
if result['result'] == RESULT_SUCCESS:
self._state.dn = rdn + ',' + ','.join(to_dn(self.entry_dn)[1:])
Expand Down Expand Up @@ -628,9 +637,12 @@ def entry_commit_changes(self, refresh=True, controls=None, clear_history=True):
if not self.entry_cursor.connection.strategy.sync: # asynchronous request
response, result, request = self.entry_cursor.connection.get_response(result, get_request=True)
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
if self.entry_cursor.connection.strategy.thread_safe:
_, result, response, request = result
else:
response = self.entry_cursor.connection.response
result = self.entry_cursor.connection.result
request = self.entry_cursor.connection.request
self.entry_cursor._store_operation_in_history(request, result, response)

if result['result'] == RESULT_SUCCESS:
Expand Down
6 changes: 3 additions & 3 deletions ldap3/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,10 +1513,10 @@ def _fire_deferred(self, read_info=None):
def entries(self):
if self.response:
if not self._entries:
self._entries = self._get_entries(self.response)
self._entries = self._get_entries(self.response, self.request)
return self._entries

def _get_entries(self, search_response):
def _get_entries(self, search_response, search_request):
with self.connection_lock:
from .. import ObjectDef, Reader

Expand All @@ -1541,7 +1541,7 @@ def _get_entries(self, search_response):
object_def += list(attr_set) # converts the set in a list to be added to the object definition
object_defs.append((attr_set,
object_def,
Reader(self, object_def, self.request['base'], self.request['filter'], attributes=attr_set) if self.strategy.sync else Reader(self, object_def, '', '', attributes=attr_set))
Reader(self, object_def, search_request['base'], search_request['filter'], attributes=attr_set) if self.strategy.sync else Reader(self, object_def, '', '', attributes=attr_set))
) # objects_defs contains a tuple with the set, the ObjectDef and a cursor

entries = []
Expand Down
15 changes: 10 additions & 5 deletions ldap3/extend/microsoft/addMembersToGroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def ad_add_members_to_groups(connection,
error = False
for group in groups_dn:
if fix: # checks for existance of group and for already assigned members
result = connection.search(group, '(objectclass=*)', BASE, dereference_aliases=DEREF_NEVER,
attributes=['member'])

result = connection.search(group, '(objectclass=*)', BASE, dereference_aliases=DEREF_NEVER, attributes=['member'])
if not connection.strategy.sync:
response, result = connection.get_response(result)
else:
response, result = connection.response, connection.result
if connection.strategy.thread_safe:
_, result, response, _ = result
else:
response = connection.response
result = connection.result

if not result['description'] == 'success':
raise LDAPInvalidDnError(group + ' not found')
Expand All @@ -82,7 +84,10 @@ def ad_add_members_to_groups(connection,
if not connection.strategy.sync:
_, result = connection.get_response(result)
else:
result = connection.result
if connection.strategy.thread_safe:
_, result, _, _ = result
else:
result = connection.result
if result['description'] != 'success':
error = True
result_error_params = ['result', 'description', 'dn', 'message']
Expand Down
7 changes: 5 additions & 2 deletions ldap3/extend/microsoft/dirSync.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ def loop(self):
if not self.connection.strategy.sync:
response, result = self.connection.get_response(result)
else:
response = self.connection.response
result = self.connection.result
if self.connection.strategy.thread_safe:
_, result, response, _ = result
else:
response = self.connection.response
result = self.connection.result

if result['description'] == 'success' and 'controls' in result and '1.2.840.113556.1.4.841' in result['controls']:
self.more_results = result['controls']['1.2.840.113556.1.4.841']['value']['more_results']
Expand Down
5 changes: 4 additions & 1 deletion ldap3/extend/microsoft/modifyPassword.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def ad_modify_password(connection, user_dn, new_password, old_password, controls
if not connection.strategy.sync:
_, result = connection.get_response(result)
else:
result = connection.result
if connection.strategy.thread_safe:
_, result, _, _ = result
else:
result = connection.result

# change successful, returns True
if result['result'] == RESULT_SUCCESS:
Expand Down
11 changes: 9 additions & 2 deletions ldap3/extend/microsoft/removeMembersFromGroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ def ad_remove_members_from_groups(connection,
if not connection.strategy.sync:
response, result = connection.get_response(result)
else:
response, result = connection.response, connection.result
if connection.strategy.thread_safe:
_, result, response, _ = result
else:
response = connection.response
result = connection.result

if not result['description'] == 'success':
raise LDAPInvalidDnError(group + ' not found')
Expand All @@ -81,7 +85,10 @@ def ad_remove_members_from_groups(connection,
if not connection.strategy.sync:
_, result = connection.get_response(result)
else:
result = connection.result
if connection.strategy.thread_safe:
_, result, _, _ = result
else:
result = connection.result
if result['description'] != 'success':
error = True
result_error_params = ['result', 'description', 'dn', 'message']
Expand Down
9 changes: 5 additions & 4 deletions ldap3/extend/microsoft/unlockAccount.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@
def ad_unlock_account(connection, user_dn, controls=None):
if connection.check_names:
user_dn = safe_dn(user_dn)
result = connection.modify(user_dn,
{'lockoutTime': [(MODIFY_REPLACE, ['0'])]},
controls)
result = connection.modify(user_dn, {'lockoutTime': [(MODIFY_REPLACE, ['0'])]}, controls)

if not connection.strategy.sync:
_, result = connection.get_response(result)
else:
result = connection.result
if connection.strategy.thread_safe:
_, result, _, _ = result
else:
result = connection.result

# change successful, returns True
if result['result'] == RESULT_SUCCESS:
Expand Down
31 changes: 19 additions & 12 deletions ldap3/extend/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,40 +52,47 @@ def send(self):

resp = self.connection.extended(self.request_name, self.request_value, self.controls)
if not self.connection.strategy.sync:
_, self.result = self.connection.get_response(resp)
_, result = self.connection.get_response(resp)
else:
self.result = self.connection.result
self.decode_response()
if self.connection.strategy.thread_safe:
_, result, _, _ = resp
else:
result = self.connection.result
self.result = result
self.decode_response(result)
self.populate_result()
self.set_response()
return self.response_value

def populate_result(self):
pass

def decode_response(self):
if not self.result:
def decode_response(self, response=None):
if not response:
response = self.result
if not response:
return None
if self.result['result'] not in [RESULT_SUCCESS]:
if response['result'] not in [RESULT_SUCCESS]:
if self.connection.raise_exceptions:
raise LDAPExtensionError('extended operation error: ' + self.result['description'] + ' - ' + self.result['message'])
raise LDAPExtensionError('extended operation error: ' + response['description'] + ' - ' + response['message'])
else:
return None
if not self.response_name or self.result['responseName'] == self.response_name:
if self.result['responseValue']:
if not self.response_name or response['responseName'] == self.response_name:
if response['responseValue']:
if self.asn1_spec is not None:
decoded, unprocessed = decoder.decode(self.result['responseValue'], asn1Spec=self.asn1_spec)
decoded, unprocessed = decoder.decode(response['responseValue'], asn1Spec=self.asn1_spec)
if unprocessed:
raise LDAPExtensionError('error decoding extended response value')
self.decoded_response = decoded
else:
self.decoded_response = self.result['responseValue']
self.decoded_response = response['responseValue']
else:
raise LDAPExtensionError('invalid response name received')

def set_response(self):
self.response_value = self.result[self.response_attribute] if self.result and self.response_attribute in self.result else None
self.connection.response = self.response_value
if not self.connection.strategy.thread_safe:
self.connection.response = self.response_value

def config(self):
pass
2 changes: 1 addition & 1 deletion ldap3/extend/standard/PagedSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def paged_search_generator(connection,
response, result = connection.get_response(result)
else:
if connection.strategy.thread_safe:
status, result, response, _ = result
_, result, response, _ = result
else:
response = connection.response
result = connection.result
Expand Down
6 changes: 4 additions & 2 deletions test/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,11 @@ def get_response_values(result, connection):
if isinstance(result, tuple):
return result # result already contains a tuple with status, result, response, request, request (for thread safe connection)
if not connection.strategy.sync:
if isinstance(result, bool): # abandon returns a boolean even with async strategy
return result, None, None, None
status = result
response, result, request = connection.get_response(result, get_request=True)
return status, result, response, request, request
response, result, request = connection.get_response(status, get_request=True)
return status, result, response, request
return result, connection.result, connection.response, connection.request


Expand Down
20 changes: 5 additions & 15 deletions test/testSearchAndModifyEntries.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ldap3.abstract import STATUS_WRITABLE, STATUS_COMMITTED, STATUS_DELETED, STATUS_INIT, STATUS_MANDATORY_MISSING, STATUS_VIRTUAL, STATUS_PENDING_CHANGES, STATUS_READ, STATUS_READY_FOR_DELETION
from ldap3.core.results import RESULT_CONSTRAINT_VIOLATION, RESULT_ATTRIBUTE_OR_VALUE_EXISTS
from test.config import test_base, test_name_attr, random_id, get_connection, add_user, drop_connection, test_server_type, test_int_attr, test_strategy,\
test_multivalued_attribute, test_singlevalued_attribute
test_multivalued_attribute, test_singlevalued_attribute, get_response_values


testcase_id = ''
Expand All @@ -58,13 +58,8 @@ def tearDown(self):
self.assertFalse(self.connection.bound)

def get_entry(self, entry_name):
result = self.connection.search(search_base=test_base, search_filter='(' + test_name_attr + '=' + testcase_id + entry_name + ')', attributes=[test_name_attr, 'givenName', test_multivalued_attribute, test_singlevalued_attribute])
if not self.connection.strategy.sync:
response, result = self.connection.get_response(result)
entries = self.connection._get_entries(response)
else:
result = self.connection.result
entries = self.connection.entries
status, result, response, request = get_response_values(self.connection.search(search_base=test_base, search_filter='(' + test_name_attr + '=' + testcase_id + entry_name + ')', attributes=[test_name_attr, 'givenName', test_multivalued_attribute, test_singlevalued_attribute]), self.connection)
entries = self.connection._get_entries(response, request)
self.assertEqual(result['description'], 'success')
self.assertEqual(len(entries), 1)
return entries[0]
Expand Down Expand Up @@ -120,13 +115,8 @@ def test_search_and_delete_entry(self):
self.assertTrue(result)
counter = 20
while counter > 0: # waits for at maximum 20 times - delete operation can take some time to complete
result = self.connection.search(search_base=test_base, search_filter='(' + test_name_attr + '=' + testcase_id + 'del1)', attributes=[test_name_attr, 'givenName'])
if not self.connection.strategy.sync:
response, result = self.connection.get_response(result)
entries = self.connection._get_entries(response)
else:
result = self.connection.result
entries = self.connection.entries
status, result, response, request = get_response_values(self.connection.search(search_base=test_base, search_filter='(' + test_name_attr + '=' + testcase_id + 'del1)', attributes=[test_name_attr, 'givenName']), self.connection)
entries = self.connection._get_entries(response, request)
if len(entries) == 0:
break
sleep(3)
Expand Down

0 comments on commit ba8933e

Please sign in to comment.