diff --git a/hazelcast/connection.py b/hazelcast/connection.py index 5a0ce48da4..be289030a0 100644 --- a/hazelcast/connection.py +++ b/hazelcast/connection.py @@ -367,7 +367,8 @@ def _get_or_connect(self, address): translated, self._client.config, self._invocation_service.handle_client_message) except IOError: - return ImmediateExceptionFuture(sys.exc_info()[1], sys.exc_info()[2]) + error = sys.exc_info() + return ImmediateExceptionFuture(error[1], error[2]) future = self._authenticate(connection).continue_with(self._on_auth, connection, address) self._pending_connections[address] = future diff --git a/hazelcast/discovery.py b/hazelcast/discovery.py index ba6082ee9c..1a928094b2 100644 --- a/hazelcast/discovery.py +++ b/hazelcast/discovery.py @@ -1,15 +1,11 @@ import json import logging +import ssl from hazelcast.errors import HazelcastCertificationError from hazelcast.core import AddressHelper from hazelcast.six.moves import http_client -try: - import ssl -except ImportError: - ssl = None - _logger = logging.getLogger(__name__) @@ -93,8 +89,8 @@ def discover_nodes(self): context=self._ctx) https_connection.request(method="GET", url=self._url, headers={"Accept-Charset": "UTF-8"}) https_response = https_connection.getresponse() - except ssl.SSLError as ex: - raise HazelcastCertificationError(str(ex)) + except ssl.SSLError as err: + raise HazelcastCertificationError(str(err)) self._check_error(https_response) return self._parse_response(https_response) diff --git a/hazelcast/reactor.py b/hazelcast/reactor.py index c995533dbe..6e59561296 100644 --- a/hazelcast/reactor.py +++ b/hazelcast/reactor.py @@ -5,6 +5,7 @@ import os import select import socket +import ssl import sys import threading import time @@ -13,18 +14,12 @@ from functools import total_ordering from heapq import heappush, heappop -from hazelcast import six from hazelcast.config import SSLProtocol from hazelcast.connection import Connection from hazelcast.core import Address from hazelcast.errors import HazelcastError from hazelcast.future import Future -try: - import ssl -except ImportError: - ssl = None - try: import fcntl except ImportError: @@ -38,6 +33,23 @@ _logger = logging.getLogger(__name__) +# We should retry receiving/sending the message in case of these errors +# EAGAIN: Resource temporarily unavailable +# EWOULDBLOCK: The read/write would block +# EDEADLK: Was added before, retrying it just to make sure that +# client behaves the same on some edge cases. +# SSL_ERROR_WANT_READ/WRITE: The socket could not satisfy the +# needs of the SSL_read/write. During the negotiation process +# SSL_read/write may also want to write/read data, hence may also +# raise SSL_ERROR_WANT_WRITE/READ. +_RETRYABLE_ERROR_CODES = ( + errno.EAGAIN, + errno.EWOULDBLOCK, + errno.EDEADLK, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_WANT_READ +) + def _set_nonblocking(fd): if not fcntl: @@ -354,6 +366,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): sent_protocol_bytes = False receive_buffer_size = _BUFFER_SIZE send_buffer_size = _BUFFER_SIZE + _close_timer = None def __init__(self, reactor, connection_manager, connection_id, address, config, message_callback): @@ -364,87 +377,44 @@ def __init__(self, reactor, connection_manager, connection_id, address, self.connected_address = address self._write_queue = deque() self._write_buf = io.BytesIO() - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) - - timeout = config.connection_timeout - if not timeout: - timeout = six.MAXSIZE - - self.socket.settimeout(timeout) - - # set tcp no delay - self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - # set socket buffer - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, _BUFFER_SIZE) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, _BUFFER_SIZE) - - for level, option_name, value in config.socket_options: - if option_name is socket.SO_RCVBUF: - self.receive_buffer_size = value - elif option_name is socket.SO_SNDBUF: - self.send_buffer_size = value - self.socket.setsockopt(level, option_name, value) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + # set the socket timeout to 0 explicitly + self.socket.settimeout(0) + self._set_socket_options(config) + if config.ssl_enabled: + self._wrap_as_ssl_socket(config) self.connect((address.host, address.port)) - if ssl and config.ssl_enabled: - ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - - protocol = config.ssl_protocol - - # Use only the configured protocol - try: - if protocol != SSLProtocol.SSLv2: - ssl_context.options |= ssl.OP_NO_SSLv2 - if protocol != SSLProtocol.SSLv3: - ssl_context.options |= ssl.OP_NO_SSLv3 - if protocol != SSLProtocol.TLSv1: - ssl_context.options |= ssl.OP_NO_TLSv1 - if protocol != SSLProtocol.TLSv1_1: - ssl_context.options |= ssl.OP_NO_TLSv1_1 - if protocol != SSLProtocol.TLSv1_2: - ssl_context.options |= ssl.OP_NO_TLSv1_2 - if protocol != SSLProtocol.TLSv1_3: - ssl_context.options |= ssl.OP_NO_TLSv1_3 - except AttributeError: - pass - - ssl_context.verify_mode = ssl.CERT_REQUIRED - - if config.ssl_cafile: - ssl_context.load_verify_locations(config.ssl_cafile) - else: - ssl_context.load_default_certs() - - if config.ssl_certfile: - ssl_context.load_cert_chain(config.ssl_certfile, config.ssl_keyfile, config.ssl_password) - - if config.ssl_ciphers: - ssl_context.set_ciphers(config.ssl_ciphers) - - self.socket = ssl_context.wrap_socket(self.socket) - - # the socket should be non-blocking from now on - self.socket.settimeout(0) + timeout = config.connection_timeout + if timeout > 0: + self._close_timer = reactor.add_timer(timeout, self._close_timer_cb) self.local_address = Address(*self.socket.getsockname()) - self._write_queue.append(b"CP2") def handle_connect(self): + if self._close_timer: + self._close_timer.cancel() + self.start_time = time.time() _logger.debug("Connected to %s", self.connected_address) def handle_read(self): reader = self._reader receive_buffer_size = self.receive_buffer_size - while True: - data = self.recv(receive_buffer_size) - reader.read(data) - self.last_read_time = time.time() - if len(data) < receive_buffer_size: - break + try: + while True: + data = self.recv(receive_buffer_size) + reader.read(data) + self.last_read_time = time.time() + if len(data) < receive_buffer_size: + break + except socket.error as err: + if err.args[0] not in _RETRYABLE_ERROR_CODES: + # Other error codes are fatal, should close the connection + self.close(None, err) if reader.length: reader.process() @@ -476,25 +446,34 @@ def handle_write(self): bytes_ = buf.getvalue() buf.truncate(0) - sent = self.send(bytes_) - self.last_write_time = time.time() - self.sent_protocol_bytes = True + try: + sent = self.send(bytes_) + except socket.error as err: + if err.args[0] in _RETRYABLE_ERROR_CODES: + # Couldn't write the bytes but we should + # retry it. + self._write_queue.appendleft(bytes_) + else: + # Other error codes are fatal, should close the connection + self.close(None, err) + else: + # No exception is thrown during the send + self.last_write_time = time.time() + self.sent_protocol_bytes = True - if sent < len(bytes_): - write_queue.appendleft(bytes_[sent:]) + if sent < len(bytes_): + write_queue.appendleft(bytes_[sent:]) def handle_close(self): _logger.warning("Connection closed by server") self.close(None, IOError("Connection closed by server")) def handle_error(self): + # We handle retryable error codes inside the + # handle_read/write. Anything else should be fatal. error = sys.exc_info()[1] - if sys.exc_info()[0] is socket.error: - if error.errno != errno.EAGAIN and error.errno != errno.EDEADLK: - _logger.exception("Received error") - self.close(None, IOError(error)) - else: - _logger.exception("Received unexpected error: %s", error) + _logger.exception("Received error") + self.close(None, error) def readable(self): return self.live and self.sent_protocol_bytes @@ -507,9 +486,72 @@ def writable(self): return len(self._write_queue) > 0 def _inner_close(self): + if self._close_timer: + # It might be the case that connection + # is closed before the timer. If we are + # closing via the timer, this call has + # no effects. + self._close_timer.cancel() + asyncore.dispatcher.close(self) self._write_buf.close() + def _close_timer_cb(self): + if not self.connected: + self.close(None, IOError("Connection timed out")) + + def _set_socket_options(self, config): + # set tcp no delay + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # set socket buffer + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, _BUFFER_SIZE) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, _BUFFER_SIZE) + + for level, option_name, value in config.socket_options: + if option_name is socket.SO_RCVBUF: + self.receive_buffer_size = value + elif option_name is socket.SO_SNDBUF: + self.send_buffer_size = value + + self.socket.setsockopt(level, option_name, value) + + def _wrap_as_ssl_socket(self, config): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + protocol = config.ssl_protocol + + # Use only the configured protocol + try: + if protocol != SSLProtocol.SSLv2: + ssl_context.options |= ssl.OP_NO_SSLv2 + if protocol != SSLProtocol.SSLv3: + ssl_context.options |= ssl.OP_NO_SSLv3 + if protocol != SSLProtocol.TLSv1: + ssl_context.options |= ssl.OP_NO_TLSv1 + if protocol != SSLProtocol.TLSv1_1: + ssl_context.options |= ssl.OP_NO_TLSv1_1 + if protocol != SSLProtocol.TLSv1_2: + ssl_context.options |= ssl.OP_NO_TLSv1_2 + if protocol != SSLProtocol.TLSv1_3: + ssl_context.options |= ssl.OP_NO_TLSv1_3 + except AttributeError: + pass + + ssl_context.verify_mode = ssl.CERT_REQUIRED + + if config.ssl_cafile: + ssl_context.load_verify_locations(config.ssl_cafile) + else: + ssl_context.load_default_certs() + + if config.ssl_certfile: + ssl_context.load_cert_chain(config.ssl_certfile, config.ssl_keyfile, config.ssl_password) + + if config.ssl_ciphers: + ssl_context.set_ciphers(config.ssl_ciphers) + + self.socket = ssl_context.wrap_socket(self.socket) + def __repr__(self): return "Connection(id=%s, live=%s, remote_address=%s)" % (self._id, self.live, self.remote_address) diff --git a/tests/reactor_test.py b/tests/reactor_test.py index ce7fdb8025..869b4348e0 100644 --- a/tests/reactor_test.py +++ b/tests/reactor_test.py @@ -9,6 +9,7 @@ from hazelcast import six from hazelcast.config import _Config +from hazelcast.core import Address from hazelcast.reactor import AsyncoreReactor, _WakeableLoop, _SocketedWaker, _PipedWaker, _BasicLoop, \ AsyncoreConnection from hazelcast.util import AtomicInteger @@ -319,3 +320,15 @@ def test_send_buffer_size(self): self.assertEqual(size, conn.send_buffer_size) finally: conn._inner_close() + + def test_constructor_with_unreachable_addresses(self): + addr = Address("192.168.0.1", 5701) + config = _Config() + start = time.time() + conn = AsyncoreConnection(MagicMock(map=dict()), MagicMock(), None, addr, config, None) + try: + # Server is unreachable, but this call should return + # before connection timeout + self.assertLess(time.time() - start, config.connection_timeout) + finally: + conn.close(None, None) diff --git a/tests/ssl/README.md b/tests/ssl_tests/README.md similarity index 100% rename from tests/ssl/README.md rename to tests/ssl_tests/README.md diff --git a/tests/ssl/__init__.py b/tests/ssl_tests/__init__.py similarity index 100% rename from tests/ssl/__init__.py rename to tests/ssl_tests/__init__.py diff --git a/tests/ssl/client1-cert.pem b/tests/ssl_tests/client1-cert.pem similarity index 100% rename from tests/ssl/client1-cert.pem rename to tests/ssl_tests/client1-cert.pem diff --git a/tests/ssl/client1-key.pem b/tests/ssl_tests/client1-key.pem similarity index 100% rename from tests/ssl/client1-key.pem rename to tests/ssl_tests/client1-key.pem diff --git a/tests/ssl/client2-cert.pem b/tests/ssl_tests/client2-cert.pem similarity index 100% rename from tests/ssl/client2-cert.pem rename to tests/ssl_tests/client2-cert.pem diff --git a/tests/ssl/client2-key.pem b/tests/ssl_tests/client2-key.pem similarity index 100% rename from tests/ssl/client2-key.pem rename to tests/ssl_tests/client2-key.pem diff --git a/tests/ssl/hazelcast-default-ca.xml b/tests/ssl_tests/hazelcast-default-ca.xml similarity index 100% rename from tests/ssl/hazelcast-default-ca.xml rename to tests/ssl_tests/hazelcast-default-ca.xml diff --git a/tests/ssl/hazelcast-ma-optional.xml b/tests/ssl_tests/hazelcast-ma-optional.xml similarity index 100% rename from tests/ssl/hazelcast-ma-optional.xml rename to tests/ssl_tests/hazelcast-ma-optional.xml diff --git a/tests/ssl/hazelcast-ma-required.xml b/tests/ssl_tests/hazelcast-ma-required.xml similarity index 100% rename from tests/ssl/hazelcast-ma-required.xml rename to tests/ssl_tests/hazelcast-ma-required.xml diff --git a/tests/ssl/hazelcast-ssl.xml b/tests/ssl_tests/hazelcast-ssl.xml similarity index 100% rename from tests/ssl/hazelcast-ssl.xml rename to tests/ssl_tests/hazelcast-ssl.xml diff --git a/tests/ssl/mutual_authentication_test.py b/tests/ssl_tests/mutual_authentication_test.py similarity index 100% rename from tests/ssl/mutual_authentication_test.py rename to tests/ssl_tests/mutual_authentication_test.py diff --git a/tests/ssl/server1-cert.pem b/tests/ssl_tests/server1-cert.pem similarity index 100% rename from tests/ssl/server1-cert.pem rename to tests/ssl_tests/server1-cert.pem diff --git a/tests/ssl/server2-cert.pem b/tests/ssl_tests/server2-cert.pem similarity index 100% rename from tests/ssl/server2-cert.pem rename to tests/ssl_tests/server2-cert.pem diff --git a/tests/ssl/ssl_test.py b/tests/ssl_tests/ssl_test.py similarity index 100% rename from tests/ssl/ssl_test.py rename to tests/ssl_tests/ssl_test.py