diff --git a/hazelcast/connection.py b/hazelcast/connection.py index 6822fb1598..4f408167bb 100644 --- a/hazelcast/connection.py +++ b/hazelcast/connection.py @@ -75,7 +75,7 @@ def __init__(self, client, reactor, address_provider, lifecycle_service, partition_service, cluster_service, invocation_service, near_cache_manager): self.live = False - self.active_connections = dict() + self.active_connections = dict() # uuid to connection, must be modified under the _lock self.client_uuid = uuid.uuid4() self._client = client @@ -95,7 +95,8 @@ def __init__(self, client, reactor, address_provider, lifecycle_service, self._connect_all_members_timer = None self._async_start = config.async_start self._connect_to_cluster_thread_running = False - self._pending_connections = dict() + self._pending_connections = dict() # must be modified under the _lock + self._addresses_to_connections = dict() # address to connection, must be modified under the _lock self._shuffle_member_list = config.shuffle_member_list self._lock = threading.RLock() self._connection_id_generator = AtomicInteger() @@ -118,10 +119,7 @@ def get_connection(self, member_uuid): return self.active_connections.get(member_uuid, None) def get_connection_from_address(self, address): - for connection in six.itervalues(self.active_connections): - if address == connection.remote_address: - return connection - return None + return self._addresses_to_connections.get(address, None) def get_random_connection(self): if self._smart_routing_enabled: @@ -131,7 +129,9 @@ def get_random_connection(self): if connection: return connection - for connection in six.itervalues(self.active_connections): + # We should not get to this point under normal circumstances. + # Therefore, copying the list should be OK. + for connection in list(six.itervalues(self.active_connections)): return connection return None @@ -156,16 +156,20 @@ def shutdown(self): self._connect_all_members_timer.cancel() self._heartbeat_manager.shutdown() - for connection_future in six.itervalues(self._pending_connections): - connection_future.set_exception(HazelcastClientNotActiveError("Hazelcast client is shutting down")) - # Need to create copy of connection values to avoid modification errors on runtime - for connection in list(six.itervalues(self.active_connections)): - connection.close("Hazelcast client is shutting down", None) + with self._lock: + for connection_future in six.itervalues(self._pending_connections): + connection_future.set_exception(HazelcastClientNotActiveError("Hazelcast client is shutting down")) - self._connection_listeners = [] - self.active_connections.clear() - self._pending_connections.clear() + # Need to create copy of connection values to avoid modification errors on runtime + for connection in list(six.itervalues(self.active_connections)): + connection.close("Hazelcast client is shutting down", None) + + self.active_connections.clear() + self._addresses_to_connections.clear() + self._pending_connections.clear() + + del self._connection_listeners[:] def connect_to_all_cluster_members(self): if not self._smart_routing_enabled: @@ -180,6 +184,7 @@ def connect_to_all_cluster_members(self): def on_connection_close(self, closed_connection, cause): connected_address = closed_connection.connected_address remote_uuid = closed_connection.remote_uuid + remote_address = closed_connection.remote_address if not connected_address: _logger.debug("Destroying %s, but it has no remote address, hence nothing is " @@ -188,6 +193,7 @@ def on_connection_close(self, closed_connection, cause): with self._lock: pending = self._pending_connections.pop(connected_address, None) connection = self.active_connections.pop(remote_uuid, None) + self._addresses_to_connections.pop(remote_address, None) if pending: pending.set_exception(cause) @@ -395,8 +401,8 @@ def _on_auth(self, response, connection, address): raise err else: e = response.exception() + # This will set the exception for the pending connection future connection.close("Failed to authenticate connection", e) - self._pending_connections.pop(address, None) six.reraise(e.__class__, e, response.traceback()) def _handle_successful_auth(self, response, connection, address): @@ -420,7 +426,8 @@ def _handle_successful_auth(self, response, connection, address): self._on_cluster_restart() with self._lock: - self.active_connections[response["member_uuid"]] = connection + self.active_connections[remote_uuid] = connection + self._addresses_to_connections[remote_address] = connection self._pending_connections.pop(address, None) if is_initial_connection: @@ -494,11 +501,12 @@ def start(self): """Starts sending periodic HeartBeat operations.""" def _heartbeat(): - if not self._connection_manager.live: + conn_manager = self._connection_manager + if not conn_manager.live: return now = time.time() - for connection in list(self._connection_manager.active_connections.values()): + for connection in list(six.itervalues(conn_manager.active_connections)): self._check_connection(now, connection) self._heartbeat_timer = self._reactor.add_timer(self._heartbeat_interval, _heartbeat) diff --git a/hazelcast/listener.py b/hazelcast/listener.py index 6e10418d5a..7575ff573f 100644 --- a/hazelcast/listener.py +++ b/hazelcast/listener.py @@ -53,7 +53,7 @@ def register_listener(self, registration_request, decode_register_response, enco self._active_registrations[registration_id] = registration futures = [] - for connection in six.itervalues(self._connection_manager.active_connections): + for connection in list(six.itervalues(self._connection_manager.active_connections)): future = self._register_on_connection_async(registration_id, registration, connection) futures.append(future) diff --git a/hazelcast/reactor.py b/hazelcast/reactor.py index b2faf652c7..04189d3484 100644 --- a/hazelcast/reactor.py +++ b/hazelcast/reactor.py @@ -1,6 +1,7 @@ import asyncore import errno import logging +import os import select import socket import sys @@ -23,32 +24,151 @@ except ImportError: ssl = None +try: + import fcntl +except ImportError: + fcntl = None + +try: + from _thread import get_ident +except ImportError: + # Python2 + from thread import get_ident + _logger = logging.getLogger(__name__) -class AsyncoreReactor(object): - _thread = None - _is_live = False +def _set_nonblocking(fd): + if not fcntl: + return - def __init__(self): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + +class _SocketAdapter(object): + def __init__(self, fd): + self._fd = fd + + def fileno(self): + return self._fd + + def close(self): + os.close(self._fd) + + def getsockopt(self, level, optname, buflen=None): + if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: + return 0 + raise NotImplementedError("Only asyncore specific behaviour is implemented.") + + +class _AbstractWaker(asyncore.dispatcher): + def __init__(self, map): + asyncore.dispatcher.__init__(self, map=map) + self.awake = False + + def writable(self): + return False + + def wake(self): + raise NotImplementedError("wake") + + +class _PipedWaker(_AbstractWaker): + def __init__(self, map): + _AbstractWaker.__init__(self, map) + self._read_fd, self._write_fd = os.pipe() + self.set_socket(_SocketAdapter(self._read_fd)) + _set_nonblocking(self._read_fd) + _set_nonblocking(self._write_fd) + + def wake(self): + if not self.awake: + self.awake = True + try: + os.write(self._write_fd, b"x") + except (IOError, ValueError): + pass + + def handle_read(self): + self.awake = False + try: + while len(os.read(self._read_fd, 4096)) == 4096: + pass + except (IOError, OSError): + pass + + def close(self): + _AbstractWaker.close(self) # Will close the reader + os.close(self._write_fd) + + +class _SocketedWaker(_AbstractWaker): + def __init__(self, map): + _AbstractWaker.__init__(self, map) + self._writer = socket.socket() + self._writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + a = socket.socket() + a.bind(("127.0.0.1", 0)) + a.listen(1) + addr = a.getsockname() + + try: + self._writer.connect(addr) + self._reader, _ = a.accept() + finally: + a.close() + + self.set_socket(self._reader) + self._writer.settimeout(0) + self._reader.settimeout(0) + + def wake(self): + if not self.awake: + self.awake = True + try: + self._writer.send(b"x") + except (IOError, socket.error, ValueError): + pass + + def handle_read(self): + self.awake = False + try: + while len(self._reader.recv(4096)) == 4096: + pass + except (IOError, socket.error): + pass + + def close(self): + _AbstractWaker.close(self) # Will close the reader + self._writer.close() + + +class _AbstractLoop(object): + def __init__(self, map): + self._map = map self._timers = [] # Accessed only from the reactor thread self._new_timers = deque() # Popped only from the reactor thread - self._map = {} + self._is_live = False + self._thread = None + self._ident = -1 def start(self): self._is_live = True self._thread = threading.Thread(target=self._loop, name="hazelcast-reactor") self._thread.daemon = True self._thread.start() + self._ident = self._thread.ident def _loop(self): _logger.debug("Starting Reactor Thread") Future._threading_locals.is_reactor_thread = True while self._is_live: try: - asyncore.loop(count=1, timeout=0.01, map=self._map) + self.run_loop() self._check_timers() - except select.error as err: + except select.error: # TODO: parse error type to catch only error "9" _logger.warning("Connection closed by server") pass @@ -59,6 +179,11 @@ def _loop(self): _logger.debug("Reactor Thread exited") self._cleanup_all_timers() + def add_timer(self, delay, callback): + timer = Timer(delay + time.time(), callback) + self._new_timers.append((timer.end, timer)) + return timer + def _check_timers(self): timers = self._timers @@ -83,13 +208,54 @@ def _check_timers(self): # timers in the heap. return - def add_timer_absolute(self, timeout, callback): - timer = Timer(timeout, callback) - self._new_timers.append((timer.end, timer)) - return timer + def _cleanup_all_timers(self): + timers = self._timers + new_timers = self._new_timers - def add_timer(self, delay, callback): - return self.add_timer_absolute(delay + time.time(), callback) + while timers: + _, timer = timers.pop() + timer.timer_ended_cb() + + # Although it is not the case with the current code base, + # the timers ended above may add new timers. So, the order + # is important. + while new_timers: + _, timer = new_timers.popleft() + timer.timer_ended_cb() + + def check_loop(self): + raise NotImplementedError("check_loop") + + def run_loop(self): + raise NotImplementedError("run_loop") + + def wake_loop(self): + raise NotImplementedError("wake_loop") + + def shutdown(self): + raise NotImplementedError("shutdown") + + +class _WakeableLoop(_AbstractLoop): + _waker_class = _PipedWaker if os.name != 'nt' else _SocketedWaker + + def __init__(self, map): + _AbstractLoop.__init__(self, map) + self.waker = self._waker_class(map) + + def check_loop(self): + assert not self.waker.awake + self.wake_loop() + assert self.waker.awake + self.run_loop() + assert not self.waker.awake + + def run_loop(self): + asyncore.loop(timeout=0.01, use_poll=True, map=self._map, count=1) + + def wake_loop(self): + if self._ident != get_ident(): + self.waker.wake() def shutdown(self): if not self._is_live: @@ -97,10 +263,13 @@ def shutdown(self): self._is_live = False - if self._thread is not threading.current_thread(): + if self._ident != get_ident(): self._thread.join() for connection in list(self._map.values()): + if connection is self.waker: + continue + try: connection.close(None, HazelcastError("Client is shutting down")) except OSError as connection: @@ -108,26 +277,73 @@ def shutdown(self): pass else: raise + + self.waker.close() self._map.clear() - def connection_factory(self, connection_manager, connection_id, address, network_config, message_callback): - return AsyncoreConnection(self._map, connection_manager, connection_id, - address, network_config, message_callback) - def _cleanup_all_timers(self): - timers = self._timers - new_timers = self._new_timers +class _BasicLoop(_AbstractLoop): + def check_loop(self): + pass - while timers: - _, timer = timers.pop() - timer.timer_ended_cb() + def run_loop(self): + asyncore.loop(timeout=0.001, use_poll=True, map=self._map, count=1) - # Although it is not the case with the current code base, - # the timers ended above may add new timers. So, the order - # is important. - while new_timers: - _, timer = new_timers.popleft() - timer.timer_ended_cb() + def wake_loop(self): + pass + + def shutdown(self): + if not self._is_live: + return + + self._is_live = False + + if self._ident != get_ident(): + self._thread.join() + + for connection in list(self._map.values()): + try: + connection.close(None, HazelcastError("Client is shutting down")) + except OSError as connection: + if connection.args[0] == socket.EBADF: + pass + else: + raise + + self._map.clear() + + +class AsyncoreReactor(object): + def __init__(self): + self.map = {} + loop = None + try: + loop = _WakeableLoop(self.map) + loop.check_loop() + except: + _logger.exception("Failed to initialize the wakeable loop. " + "Using the basic loop instead. " + "When used in the blocking mode, client" + "may have sub-optimal performance.") + if loop: + loop.shutdown() + loop = _BasicLoop(self.map) + self._loop = loop + + def start(self): + self._loop.start() + + def add_timer(self, delay, callback): + return self._loop.add_timer(delay, callback) + + def wake_loop(self): + self._loop.wake_loop() + + def shutdown(self): + self._loop.shutdown() + + def connection_factory(self, connection_manager, connection_id, address, network_config, message_callback): + return AsyncoreConnection(self, connection_manager, connection_id, address, network_config, message_callback) _BUFFER_SIZE = 128000 @@ -137,13 +353,13 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): sent_protocol_bytes = False read_buffer_size = _BUFFER_SIZE - def __init__(self, dispatcher_map, connection_manager, connection_id, address, + def __init__(self, reactor, connection_manager, connection_id, address, config, message_callback): - asyncore.dispatcher.__init__(self, map=dispatcher_map) + asyncore.dispatcher.__init__(self, map=reactor.map) Connection.__init__(self, connection_manager, connection_id, message_callback) - self.connected_address = address - self._write_lock = threading.Lock() + self._reactor = reactor + self.connected_address = address self._write_queue = deque() self.create_socket(socket.AF_INET, socket.SOCK_STREAM) @@ -228,17 +444,21 @@ def handle_read(self): reader.process() def handle_write(self): - with self._write_lock: + while True: try: data = self._write_queue.popleft() except IndexError: return + sent = self.send(data) self.last_write_time = time.time() self.sent_protocol_bytes = True if sent < len(data): self._write_queue.appendleft(data[sent:]) + if sent == 0: + return + def handle_close(self): _logger.warning("Connection closed by server") self.close(None, IOError("Connection closed by server")) @@ -256,18 +476,8 @@ def readable(self): return self.live and self.sent_protocol_bytes def _write(self, buf): - # if write queue is empty, send the data right away, otherwise add to queue - if len(self._write_queue) == 0 and self._write_lock.acquire(False): - try: - sent = self.send(buf) - self.last_write_time = time.time() - if sent < len(buf): - _logger.info("Adding to queue") - self._write_queue.appendleft(buf[sent:]) - finally: - self._write_lock.release() - else: - self._write_queue.append(buf) + self._write_queue.append(buf) + self._reactor.wake_loop() def writable(self): return len(self._write_queue) > 0 diff --git a/run-tests.ps1 b/run-tests.ps1 index 5b796d6b57..c8cac5380a 100644 --- a/run-tests.ps1 +++ b/run-tests.ps1 @@ -1,4 +1,4 @@ -$serverVersion = "4.1-SNAPSHOT" +$serverVersion = "4.1" $hazelcastTestVersion=$serverVersion $hazelcastEnterpriseTestVersion=$serverVersion diff --git a/run-tests.sh b/run-tests.sh index 244e6f4e85..1ad58b34e6 100644 --- a/run-tests.sh +++ b/run-tests.sh @@ -20,7 +20,7 @@ else USER="" fi -HZ_VERSION="4.1-SNAPSHOT" +HZ_VERSION="4.1" HAZELCAST_TEST_VERSION=${HZ_VERSION} HAZELCAST_ENTERPRISE_TEST_VERSION=${HZ_VERSION} diff --git a/start-rc.sh b/start-rc.sh index 02998f7654..368b0116a6 100644 --- a/start-rc.sh +++ b/start-rc.sh @@ -20,7 +20,7 @@ else USER="" fi -HZ_VERSION="4.1-SNAPSHOT" +HZ_VERSION="4.1" HAZELCAST_TEST_VERSION=${HZ_VERSION} HAZELCAST_ENTERPRISE_TEST_VERSION=${HZ_VERSION} diff --git a/tests/reactor_test.py b/tests/reactor_test.py new file mode 100644 index 0000000000..f9ceebb17f --- /dev/null +++ b/tests/reactor_test.py @@ -0,0 +1,261 @@ +import os +import socket +import threading +import time +from collections import OrderedDict + +from mock import MagicMock +from parameterized import parameterized + +from hazelcast import six +from hazelcast.reactor import AsyncoreReactor, _WakeableLoop, _SocketedWaker, _PipedWaker, _BasicLoop +from hazelcast.util import AtomicInteger +from tests.base import HazelcastTestCase + + +class ReactorTest(HazelcastTestCase): + def test_default_loop_is_wakeable(self): + reactor = AsyncoreReactor() + self.assertIsInstance(reactor._loop, _WakeableLoop) + + def test_reactor_lifetime(self): + t_count = threading.active_count() + reactor = AsyncoreReactor() + reactor.start() + try: + self.assertEqual(t_count + 1, threading.active_count()) # reactor thread + finally: + reactor.shutdown() + self.assertEqual(t_count, threading.active_count()) + + +LOOP_CLASSES = [ + ("wakeable", _WakeableLoop,), + ("basic", _BasicLoop,), +] + + +class LoopTest(HazelcastTestCase): + def test_wakeable_loop_default_waker(self): + loop = _WakeableLoop({}) + try: + if os.name == "nt": + self.assertIsInstance(loop.waker, _SocketedWaker) + else: + self.assertIsInstance(loop.waker, _PipedWaker) + finally: + loop.waker.close() + + def test_wakeable_loop_waker_closes_last(self): + dispatchers = OrderedDict() + loop = _WakeableLoop(dispatchers) # Waker comes first in the dict + + mock_dispatcher = MagicMock(readable=lambda: False, writeable=lambda: False) + dispatchers[loop.waker._fileno + 1] = mock_dispatcher + + original_close = loop.waker.close + + def assertion(): + mock_dispatcher.close.assert_called() + original_close() + + loop.waker.close = assertion + + loop.shutdown() + + @parameterized.expand(LOOP_CLASSES) + def test_check_loop(self, _, cls): + loop = cls({}) + # For the WakeableLoop, we are checking that + # the loop can be waken up, and once the reactor + # handles the written bytes, it is not awake + # anymore. Assertions are in the method + # implementation. For, the BasicLoop, this should + # be no-op, just checking it is not raising any + # error. + loop.check_loop() + + @parameterized.expand(LOOP_CLASSES) + def test_add_timer(self, _, cls): + call_count = AtomicInteger() + + def callback(): + call_count.add(1) + + loop = cls({}) + loop.start() + loop.add_timer(0, callback) # already expired, should be run immediately + + def assertion(): + self.assertEqual(1, call_count.get()) + + try: + self.assertTrueEventually(assertion) + finally: + loop.shutdown() + + @parameterized.expand(LOOP_CLASSES) + def test_timer_cleanup(self, _, cls): + call_count = AtomicInteger() + + def callback(): + call_count.add(1) + + loop = cls({}) + loop.start() + loop.add_timer(float('inf'), callback) # never expired, must be cleaned up + time.sleep(1) + try: + self.assertEqual(0, call_count.get()) + finally: + loop.shutdown() + + def assertion(): + self.assertEqual(1, call_count.get()) + + self.assertTrueEventually(assertion) + + @parameterized.expand(LOOP_CLASSES) + def test_timer_that_adds_another_timer(self, _, cls): + loop = cls({}) + loop.start() + + call_count = AtomicInteger() + + def callback(): + if call_count.get() == 0: + loop.add_timer(0, callback) + call_count.add(1) + + loop.add_timer(float('inf'), callback) + + loop.shutdown() + + def assertion(): + self.assertEqual(2, call_count.get()) # newly added timer must also be cleaned up + + self.assertTrueEventually(assertion) + + @parameterized.expand(LOOP_CLASSES) + def test_timer_that_shuts_down_loop(self, _, cls): + # It may be the case that, we want to shutdown the client(hence, the loop) in timers + loop = cls({}) + loop.start() + + loop.add_timer(0, lambda: loop.shutdown()) + + def assertion(): + self.assertFalse(loop._is_live) + + try: + self.assertTrueEventually(assertion) + finally: + loop.shutdown() # Should be no op + + +class SocketedWakerTest(HazelcastTestCase): + def setUp(self): + self.waker = _SocketedWaker({}) + + def tearDown(self): + try: + self.waker.close() + except: + pass + + def test_wake(self): + waker = self.waker + self.assertFalse(waker.awake) + waker.wake() + self.assertTrue(waker.awake) + self.assertEqual(b"x", waker._reader.recv(1)) + + def test_wake_while_awake(self): + waker = self.waker + waker.wake() + waker.wake() + self.assertTrue(waker.awake) + self.assertEqual(b"x", waker._reader.recv(2)) # only the first one should write + + def test_handle_read(self): + waker = self.waker + waker.wake() + self.assertTrue(waker.awake) + waker.handle_read() + self.assertFalse(waker.awake) + + with self.assertRaises((IOError, socket.error)): # BlockingIOError on Py3, socket.error on Py2 + waker._reader.recv(1) # handle_read should consume the socket, there should be nothing + + def test_close(self): + waker = self.waker + writer = waker._writer + reader = waker._reader + self.assertNotEqual(-1, writer.fileno()) + self.assertNotEqual(-1, reader.fileno()) + + waker.close() + + if six.PY3: + self.assertEqual(-1, writer.fileno()) + self.assertEqual(-1, reader.fileno()) + else: + # Closed sockets raise socket.error with EBADF error code in Python2 + with self.assertRaises(socket.error): + writer.fileno() + + with self.assertRaises(socket.error): + reader.fileno() + + +class PipedWakerTest(HazelcastTestCase): + def setUp(self): + self.waker = _PipedWaker({}) + + def tearDown(self): + try: + self.waker.close() + except: + pass + + def test_wake(self): + waker = self.waker + self.assertFalse(waker.awake) + waker.wake() + self.assertTrue(waker.awake) + self.assertEqual(b"x", os.read(waker._read_fd, 1)) + + def test_wake_while_awake(self): + waker = self.waker + waker.wake() + waker.wake() + self.assertTrue(waker.awake) + self.assertEqual(b"x", os.read(waker._read_fd, 2)) # only the first one should write + + def test_handle_read(self): + waker = self.waker + waker.wake() + self.assertTrue(waker.awake) + waker.handle_read() + self.assertFalse(waker.awake) + + if os.name == "nt": + return # pipes are not non-blocking on Windows, assertion below blocks forever on Windows + + with self.assertRaises((IOError, OSError)): # BlockingIOError on Py3, OSError on Py2 + os.read(waker._read_fd, 1) # handle_read should consume the pipe, there should be nothing + + def test_close(self): + waker = self.waker + w_fd = waker._write_fd + r_fd = waker._read_fd + self.assertEqual(1, os.write(w_fd, b"x")) + self.assertEqual(b"x", os.read(r_fd, 1)) + + waker.close() + + with self.assertRaises(OSError): + os.write(w_fd, b"x") + + with self.assertRaises(OSError): + os.read(r_fd, 1) diff --git a/tests/reconnect_test.py b/tests/reconnect_test.py index 08b25031f2..4a73157b0e 100644 --- a/tests/reconnect_test.py +++ b/tests/reconnect_test.py @@ -1,6 +1,5 @@ import time from threading import Thread -from time import sleep from hazelcast.errors import HazelcastError, TargetDisconnectedError from hazelcast.lifecycle import LifecycleState @@ -76,7 +75,6 @@ def test_listener_re_register(self): reg_id = map.add_entry_listener(added_func=collector) self.logger.info("Registered listener with id %s", reg_id) member.shutdown() - sleep(3) self.cluster.start_member() count = AtomicInteger()