Skip to content

Commit

Permalink
threadsafe sync strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
cannatag committed Jul 14, 2020
1 parent d4aaff1 commit 72edc5e
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 259 deletions.
1 change: 1 addition & 0 deletions ldap3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@

# client strategies
SYNC = 'SYNC'
SAFE_SYNC = 'SAFE_SYNC'
ASYNC = 'ASYNC'
LDIF = 'LDIF'
RESTARTABLE = 'RESTARTABLE'
Expand Down
14 changes: 10 additions & 4 deletions ldap3/abstract/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,11 @@ def _execute_query(self, query_scope, attributes):
if not self.connection.strategy.sync:
response, result, request = self.connection.get_response(result, get_request=True)
else:
response = self.connection.response
result = self.connection.result
if self.connection.strategy.thread_safe:
_, result, response = result
else:
response = self.connection.response
result = self.connection.result
request = self.connection.request

self._store_operation_in_history(request, result, response)
Expand Down Expand Up @@ -820,8 +823,11 @@ def _refresh_object(self, entry_dn, attributes=None, tries=4, seconds=2, control
if not self.connection.strategy.sync:
response, result, request = self.connection.get_response(result, get_request=True)
else:
response = self.connection.response
result = self.connection.result
if self.connection.strategy.thread_safe:
_, result, response = result
else:
response = self.connection.response
result = self.connection.result
request = self.connection.request

if result['result'] in [RESULT_SUCCESS]:
Expand Down
47 changes: 28 additions & 19 deletions ldap3/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from .. import ANONYMOUS, SIMPLE, SASL, MODIFY_ADD, MODIFY_DELETE, MODIFY_REPLACE, get_config_parameter, DEREF_ALWAYS, \
SUBTREE, ASYNC, SYNC, NO_ATTRIBUTES, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, MODIFY_INCREMENT, LDIF, ASYNC_STREAM, \
RESTARTABLE, ROUND_ROBIN, REUSABLE, AUTO_BIND_DEFAULT, AUTO_BIND_NONE, AUTO_BIND_TLS_BEFORE_BIND,\
RESTARTABLE, ROUND_ROBIN, REUSABLE, AUTO_BIND_DEFAULT, AUTO_BIND_NONE, AUTO_BIND_TLS_BEFORE_BIND, SAFE_SYNC, \
AUTO_BIND_TLS_AFTER_BIND, AUTO_BIND_NO_TLS, STRING_TYPES, SEQUENCE_TYPES, MOCK_SYNC, MOCK_ASYNC, NTLM, EXTERNAL,\
DIGEST_MD5, GSSAPI, PLAIN, DSA, SCHEMA, ALL

Expand All @@ -52,6 +52,7 @@
from ..protocol.sasl.external import sasl_external
from ..protocol.sasl.plain import sasl_plain
from ..strategy.sync import SyncStrategy
from ..strategy.safeSync import SafeSyncStrategy
from ..strategy.mockAsync import MockAsyncStrategy
from ..strategy.asynchronous import AsyncStrategy
from ..strategy.reusable import ReusableStrategy
Expand Down Expand Up @@ -80,6 +81,7 @@
PLAIN]

CLIENT_STRATEGIES = [SYNC,
SAFE_SYNC,
ASYNC,
LDIF,
RESTARTABLE,
Expand Down Expand Up @@ -116,7 +118,6 @@ def _format_socket_endpoints(sock):
return '<no socket>'


# noinspection PyProtectedMember
class Connection(object):
"""Main ldap connection class.
Expand Down Expand Up @@ -181,7 +182,6 @@ class Connection(object):
:param source_port_list: a list of source ports to choose from when opening the connection to the server. Cannot be specified with source_port
:type source_port_list: list
"""

def __init__(self,
server,
user=None,
Expand Down Expand Up @@ -323,6 +323,8 @@ def __init__(self,

if self.strategy_type == SYNC:
self.strategy = SyncStrategy(self)
elif self.strategy_type == SAFE_SYNC:
self.strategy = SafeSyncStrategy(self)
elif self.strategy_type == ASYNC:
self.strategy = AsyncStrategy(self)
elif self.strategy_type == LDIF:
Expand Down Expand Up @@ -352,7 +354,7 @@ def __init__(self,
self.post_send_search = self.strategy.post_send_search

if not self.strategy.no_real_dsa:
self.do_auto_bind()
self._do_auto_bind()
# else: # for strategies with a fake server set get_info to NONE if server hasn't a schema
# if self.server and not self.server.schema:
# self.server.get_info = NONE
Expand All @@ -362,7 +364,14 @@ def __init__(self,
else:
log(BASIC, 'instantiated Connection: <%r>', self)

def do_auto_bind(self):
def _prepare_return_value(self, status, response=False):
if self.strategy.thread_safe:
temp_response = self.response
self.response = None
return status, deepcopy(self.result), deepcopy(temp_response) if response else None
return status

def _do_auto_bind(self):
if self.auto_bind and self.auto_bind not in [AUTO_BIND_NONE, AUTO_BIND_DEFAULT]:
if log_enabled(BASIC):
log(BASIC, 'performing automatic bind for <%s>', self)
Expand Down Expand Up @@ -513,7 +522,6 @@ def __enter__(self):

return self

# noinspection PyUnusedLocal
def __exit__(self, exc_type, exc_val, exc_tb):
with self.connection_lock:
context_bound, context_closed = self._context_state.pop()
Expand Down Expand Up @@ -646,7 +654,7 @@ def bind(self,
if log_enabled(BASIC):
log(BASIC, 'done BIND operation, result <%s>', self.bound)

return self.bound
return self._prepare_return_value(self.bound, self.result)

def rebind(self,
user=None,
Expand Down Expand Up @@ -692,7 +700,7 @@ def rebind(self,
raise LDAPBindError('Unable to rebind as a different user, furthermore the server abruptly closed the connection')
else:
self.strategy.pool.rebind_pool()
return True
return self._prepare_return_value(True, self.result)

def unbind(self,
controls=None):
Expand Down Expand Up @@ -724,7 +732,7 @@ def unbind(self,
if log_enabled(BASIC):
log(BASIC, 'done UNBIND operation, result <%s>', True)

return True
return self._prepare_return_value(True)

def search(self,
search_base,
Expand Down Expand Up @@ -838,7 +846,7 @@ def search(self,
if log_enabled(BASIC):
log(BASIC, 'done SEARCH operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value, response=True)

def compare(self,
dn,
Expand Down Expand Up @@ -892,7 +900,7 @@ def compare(self,
if log_enabled(BASIC):
log(BASIC, 'done COMPARE operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value)

def add(self,
dn,
Expand Down Expand Up @@ -981,7 +989,7 @@ def add(self,
if log_enabled(BASIC):
log(BASIC, 'done ADD operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value)

def delete(self,
dn,
Expand Down Expand Up @@ -1025,7 +1033,7 @@ def delete(self,
if log_enabled(BASIC):
log(BASIC, 'done DELETE operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value)

def modify(self,
dn,
Expand Down Expand Up @@ -1115,7 +1123,7 @@ def modify(self,
if log_enabled(BASIC):
log(BASIC, 'done MODIFY operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value)

def modify_dn(self,
dn,
Expand Down Expand Up @@ -1172,7 +1180,7 @@ def modify_dn(self,
if log_enabled(BASIC):
log(BASIC, 'done MODIFY DN operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value)

def abandon(self,
message_id,
Expand Down Expand Up @@ -1205,7 +1213,7 @@ def abandon(self,
if log_enabled(BASIC):
log(BASIC, 'done ABANDON operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value)

def extended(self,
request_name,
Expand Down Expand Up @@ -1239,15 +1247,16 @@ def extended(self,
if log_enabled(BASIC):
log(BASIC, 'done EXTENDED operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value, response=True)

def start_tls(self, read_server_info=True): # as per RFC4511. Removal of TLS is defined as MAY in RFC4511 so the client can't implement a generic stop_tls method0

if log_enabled(BASIC):
log(BASIC, 'start START TLS operation via <%s>', self)

with self.connection_lock:
return_value = False
self.result = None

if not self.server.tls:
self.server.tls = Tls()

Expand All @@ -1271,7 +1280,7 @@ def start_tls(self, read_server_info=True): # as per RFC4511. Removal of TLS is
if log_enabled(BASIC):
log(BASIC, 'done START TLS operation, result <%s>', return_value)

return return_value
return self._prepare_return_value(return_value)

def do_sasl_bind(self,
controls):
Expand Down
45 changes: 32 additions & 13 deletions ldap3/core/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,18 @@ def _get_dsa_info(self, connection):
'+'], # requests all remaining attributes (other),
get_operational_attributes=True)

if connection.strategy.thread_safe:
status, result, response = result
else:
status = result
result = connection.result
response = connection.response

with self.dit_lock:
if isinstance(result, bool): # sync request
self._dsa_info = DsaInfo(connection.response[0]['attributes'], connection.response[0]['raw_attributes']) if result else self._dsa_info
elif result: # asynchronous request, must check if attributes in response
results, _ = connection.get_response(result)
if connection.strategy.sync: # sync request
self._dsa_info = DsaInfo(response[0]['attributes'], response[0]['raw_attributes']) if status else self._dsa_info
elif status: # asynchronous request, must check if attributes in response
results, _ = connection.get_response(status)
if len(results) == 1 and 'attributes' in results[0] and 'raw_attributes' in results[0]:
self._dsa_info = DsaInfo(results[0]['attributes'], results[0]['raw_attributes'])

Expand All @@ -446,12 +453,18 @@ def _get_schema_info(self, connection, entry=''):
schema_entry = self._dsa_info.schema_entry if self._dsa_info.schema_entry else None
else:
result = connection.search(entry, '(objectClass=*)', BASE, attributes=['subschemaSubentry'], get_operational_attributes=True)
if isinstance(result, bool): # sync request
if result and 'subschemaSubentry' in connection.response[0]['raw_attributes']:
if len(connection.response[0]['raw_attributes']['subschemaSubentry']) > 0:
schema_entry = connection.response[0]['raw_attributes']['subschemaSubentry'][0]
if connection.strategy.thread_safe:
status, result, response = result
else:
status = result
result = connection.result
response = connection.response
if connection.strategy.sync: # sync request
if status and 'subschemaSubentry' in response[0]['raw_attributes']:
if len(response[0]['raw_attributes']['subschemaSubentry']) > 0:
schema_entry = response[0]['raw_attributes']['subschemaSubentry'][0]
else: # asynchronous request, must check if subschemaSubentry in attributes
results, _ = connection.get_response(result)
results, _ = connection.get_response(status)
if len(results) == 1 and 'raw_attributes' in results[0] and 'subschemaSubentry' in results[0]['attributes']:
if len(results[0]['raw_attributes']['subschemaSubentry']) > 0:
schema_entry = results[0]['raw_attributes']['subschemaSubentry'][0]
Expand All @@ -475,13 +488,19 @@ def _get_schema_info(self, connection, entry=''):
'*'], # requests all remaining attributes (other)
get_operational_attributes=True
)
if connection.strategy.thread_safe:
status, result, response = result
else:
status = result
result = connection.result
response = connection.response
with self.dit_lock:
self._schema_info = None
if result:
if isinstance(result, bool): # sync request
self._schema_info = SchemaInfo(schema_entry, connection.response[0]['attributes'], connection.response[0]['raw_attributes']) if result else None
if status:
if connection.strategy.sync: # sync request
self._schema_info = SchemaInfo(schema_entry, response[0]['attributes'], response[0]['raw_attributes'])
else: # asynchronous request, must check if attributes in response
results, result = connection.get_response(result)
results, result = connection.get_response(status)
if len(results) == 1 and 'attributes' in results[0] and 'raw_attributes' in results[0]:
self._schema_info = SchemaInfo(schema_entry, results[0]['attributes'], results[0]['raw_attributes'])
if self._schema_info and not self._schema_info.is_valid(): # flaky servers can return an empty schema, checks if it is so and set schema to None
Expand Down
19 changes: 12 additions & 7 deletions ldap3/strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, ldap_connection):
self.pooled = None # Indicates a connection with a connection pool
self.can_stream = None # indicates if a strategy keeps a stream of responses (i.e. LdifProducer can accumulate responses with a single header). Stream must be initialized and closed in _start_listen() and _stop_listen()
self.referral_cache = {}
self.thread_safe = False # Indicates that connection can be used in a multithread application
if log_enabled(BASIC):
log(BASIC, 'instantiated <%s>: <%s>', self.__class__.__name__, self)

Expand Down Expand Up @@ -141,9 +142,6 @@ def open(self, reset_usage=True, read_server_info=True):
if log_enabled(ERROR):
log(ERROR, 'unable to open socket for <%s>', self.connection)
raise LDAPSocketOpenError('unable to open socket', exception_history)
if log_enabled(ERROR):
log(ERROR, 'unable to open socket for <%s>', self.connection)
raise LDAPSocketOpenError('unable to open socket', exception_history)
elif not self.connection.server.current_address:
if log_enabled(ERROR):
log(ERROR, 'invalid server address for <%s>', self.connection)
Expand Down Expand Up @@ -693,13 +691,20 @@ def do_next_range_search(self, request, response, attr_name):
search_scope=BASE,
dereference_aliases=request['dereferenceAlias'],
attributes=[attr_type + ';range=' + str(int(high_range) + 1) + '-*'])
if isinstance(result, bool):
if result:
current_response = self.connection.response[0]
if self.connection.strategy.thread_safe:
status, result, response = result
else:
status = result
result = self.connection.result
response = self.connection.response

if self.connection.strategy.sync:
if status:
current_response = response[0]
else:
done = True
else:
current_response, _ = self.get_response(result)
current_response, _ = self.get_response(status)
current_response = current_response[0]

if not done:
Expand Down
32 changes: 32 additions & 0 deletions ldap3/strategy/safeSync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
"""

# Created on 2020.07.12
#
# Author: Giovanni Cannata
#
# Copyright 2013 - 2020 Giovanni Cannata
#
# This file is part of ldap3.
#
# ldap3 is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ldap3 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with ldap3 in the COPYING and COPYING.LESSER files.
# If not, see <http://www.gnu.org/licenses/>.

from .sync import SyncStrategy


class SafeSyncStrategy(SyncStrategy):
def __init__(self, ldap_connection):
SyncStrategy.__init__(self, ldap_connection)
self.thread_safe = True

0 comments on commit 72edc5e

Please sign in to comment.