diff --git a/hazelcast/internal/asyncio_connection.py b/hazelcast/internal/asyncio_connection.py index f89bd39db7..f747a53110 100644 --- a/hazelcast/internal/asyncio_connection.py +++ b/hazelcast/internal/asyncio_connection.py @@ -340,7 +340,6 @@ async def on_connection_close(self, closed_connection): if removed: async with asyncio.TaskGroup() as tg: - # TODO: see on_connection_open for _, on_connection_closed in self._connection_listeners: if on_connection_closed: try: @@ -395,13 +394,12 @@ async def _get_or_connect_to_member(self, member): translated = self._translate_member_address(member) connection = await self._create_connection(translated) - response = await self._authenticate(connection) # .continue_with(self._on_auth, connection) + response = await self._authenticate(connection) await self._on_auth(response, connection) return connection async def _create_connection(self, address): - factory = self._reactor.connection_factory - return await factory( + return await self._reactor.connection_factory( self, self._connection_id_generator.get_and_increment(), address, @@ -473,7 +471,6 @@ async def run(): connecting_uuids.add(member_uuid) if not self._lifecycle_service.running: break - # TODO: ERROR:asyncio:Task was destroyed but it is pending! tg.create_task(self._get_or_connect_to_member(member)) member_uuids.append(member_uuid) @@ -706,8 +703,6 @@ async def _handle_successful_auth(self, response, connection): for on_connection_opened, _ in self._connection_listeners: if on_connection_opened: try: - # TODO: creating the task may not throw the exception - # TODO: protect the loop against exceptions, so all handlers run maybe_coro = on_connection_opened(connection) if isinstance(maybe_coro, Coroutine): tg.create_task(maybe_coro) diff --git a/hazelcast/internal/asyncio_reactor.py b/hazelcast/internal/asyncio_reactor.py index 92888c103e..a44d656449 100644 --- a/hazelcast/internal/asyncio_reactor.py +++ b/hazelcast/internal/asyncio_reactor.py @@ -1,9 +1,11 @@ import asyncio import io import logging +import ssl import time from asyncio import AbstractEventLoop, transports +from hazelcast.config import Config, SSLProtocol from hazelcast.internal.asyncio_connection import Connection from hazelcast.core import Address @@ -83,25 +85,28 @@ async def create_and_connect( this = cls( loop, reactor, connection_manager, connection_id, address, config, message_callback ) - if this._config.ssl_enabled: - await this._create_ssl_connection() - else: - await this._create_connection() + await this._create_connection(config, address) return this def _create_protocol(self): return HazelcastProtocol(self) - async def _create_connection(self): - loop = self._loop - res = await loop.create_connection( - self._create_protocol, host=self._address.host, port=self._address.port + async def _create_connection(self, config, address): + ssl_context = None + if config.ssl_enabled: + ssl_context = self._create_ssl_context(config) + server_hostname = None + if config.ssl_check_hostname: + server_hostname = address.host + res = await self._loop.create_connection( + self._create_protocol, + host=self._address.host, + port=self._address.port, + ssl=ssl_context, + server_hostname=server_hostname, ) _sock, self._proto = res - async def _create_ssl_connection(self): - raise NotImplementedError - def _write(self, buf): self._proto.write(buf) @@ -120,6 +125,42 @@ def _update_sent(self, sent): def _update_received(self, received): self._reactor.update_bytes_received(received) + def _create_ssl_context(self, config: Config): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + protocol = config.ssl_protocol + # Use only the configured protocol + try: + if protocol != SSLProtocol.SSLv2: + ssl_context.options |= ssl.OP_NO_SSLv2 + if protocol != SSLProtocol.SSLv3: + ssl_context.options |= ssl.OP_NO_SSLv3 + if protocol != SSLProtocol.TLSv1: + ssl_context.options |= ssl.OP_NO_TLSv1 + if protocol != SSLProtocol.TLSv1_1: + ssl_context.options |= ssl.OP_NO_TLSv1_1 + if protocol != SSLProtocol.TLSv1_2: + ssl_context.options |= ssl.OP_NO_TLSv1_2 + if protocol != SSLProtocol.TLSv1_3: + ssl_context.options |= ssl.OP_NO_TLSv1_3 + except AttributeError: + pass + + ssl_context.verify_mode = ssl.CERT_REQUIRED + if config.ssl_cafile: + ssl_context.load_verify_locations(config.ssl_cafile) + else: + ssl_context.load_default_certs() + if config.ssl_certfile: + ssl_context.load_cert_chain( + config.ssl_certfile, config.ssl_keyfile, config.ssl_password + ) + if config.ssl_ciphers: + ssl_context.set_ciphers(config.ssl_ciphers) + if config.ssl_check_hostname: + ssl_context.check_hostname = True + + return ssl_context + class HazelcastProtocol(asyncio.BufferedProtocol): diff --git a/tests/integration/asyncio/ssl_tests/__init__.py b/tests/integration/asyncio/ssl_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/ssl_tests/hostname_verification/__init__.py b/tests/integration/asyncio/ssl_tests/hostname_verification/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/ssl_tests/hostname_verification/ssl_hostname_verification_test.py b/tests/integration/asyncio/ssl_tests/hostname_verification/ssl_hostname_verification_test.py new file mode 100644 index 0000000000..87fff817ec --- /dev/null +++ b/tests/integration/asyncio/ssl_tests/hostname_verification/ssl_hostname_verification_test.py @@ -0,0 +1,135 @@ +import os +import sys +import unittest + +import pytest + +from hazelcast.asyncio.client import HazelcastClient +from hazelcast.config import SSLProtocol +from hazelcast.errors import IllegalStateError +from tests.integration.asyncio.base import HazelcastTestCase +from tests.util import compare_client_version, get_abs_path + +current_directory = os.path.abspath( + os.path.join( + os.path.dirname(__file__), "../../../backward_compatible/ssl_tests/hostname_verification" + ) +) + +MEMBER_CONFIG = """ + + + + + com.hazelcast.nio.ssl.BasicSSLContextFactory + + + %s + 123456 + PKCS12 + TLSv1.2 + + + + +""" + + +@unittest.skipIf( + sys.version_info < (3, 7), + "Hostname verification feature requires Python 3.7+", +) +@unittest.skipIf( + compare_client_version("5.1") < 0, + "Tests the features added in 5.1 version of the client", +) +@pytest.mark.enterprise +class SslHostnameVerificationTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + def setUp(self): + self.rc = self.create_rc() + self.cluster = None + + async def asyncTearDown(self): + await self.shutdown_all_clients() + self.rc.terminateCluster(self.cluster.id) + self.rc.exit() + + async def test_hostname_verification_with_loopback_san(self): + # SAN entry is present with different possible values + file_name = "tls-host-loopback-san" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + + async def test_hostname_verification_with_loopback_dns_san(self): + # SAN entry is present, but only with `dns:localhost` + file_name = "tls-host-loopback-san-dns" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_with_different_san(self): + # There is a valid entry, but it does not match with the address of the member. + file_name = "tls-host-not-our-san" + self.start_member_with(f"{file_name}.p12") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_with_loopback_cn(self): + # No entry in SAN but an entry in CN which checked as a fallback + # when no entry in SAN is present. + file_name = "tls-host-loopback-cn" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + # See https://stackoverflow.com/a/8444863/12394291. IP addresses in CN + # are not supported. So, we don't have a test for it. + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_with_no_entry(self): + # No entry either in the SAN or CN. No way to verify hostname. + file_name = "tls-host-no-entry" + self.start_member_with(f"{file_name}.p12") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_disabled(self): + # When hostname verification is disabled, the scenarious that + # would fail in `test_hostname_verification_with_no_entry` will + # no longer fail, showing that it is working as expected. + file_name = "tls-host-no-entry" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "localhost:5701", check_hostname=False) + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701", check_hostname=False) + + async def start_client_with( + self, + truststore_name: str, + member_address: str, + *, + check_hostname=True, + ) -> HazelcastClient: + return await self.create_client( + { + "cluster_name": self.cluster.id, + "cluster_members": [member_address], + "ssl_enabled": True, + "ssl_protocol": SSLProtocol.TLSv1_2, + "ssl_cafile": get_abs_path(current_directory, truststore_name), + "ssl_check_hostname": check_hostname, + "cluster_connect_timeout": 0, + } + ) + + def start_member_with(self, keystore_name: str) -> None: + config = MEMBER_CONFIG % get_abs_path(current_directory, keystore_name) + self.cluster = self.create_cluster(self.rc, config) + self.cluster.start_member() diff --git a/tests/integration/asyncio/ssl_tests/mutual_authentication_test.py b/tests/integration/asyncio/ssl_tests/mutual_authentication_test.py new file mode 100644 index 0000000000..2d392278d2 --- /dev/null +++ b/tests/integration/asyncio/ssl_tests/mutual_authentication_test.py @@ -0,0 +1,169 @@ +import os +import unittest + +import pytest + +from tests.integration.asyncio.base import HazelcastTestCase +from hazelcast.asyncio.client import HazelcastClient +from hazelcast.errors import HazelcastError +from tests.util import get_ssl_config, get_abs_path + + +@pytest.mark.enterprise +class MutualAuthenticationTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + current_directory = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../backward_compatible/ssl_tests") + ) + rc = None + mutual_auth = True + ma_req_xml = get_abs_path(current_directory, "hazelcast-ma-required.xml") + ma_opt_xml = get_abs_path(current_directory, "hazelcast-ma-optional.xml") + + def setUp(self): + self.rc = self.create_rc() + + def tearDown(self): + self.rc.exit() + + async def test_ma_required_client_and_server_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ma_required_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + + async def test_ma_required_client_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + get_abs_path(self.current_directory, "client2-cert.pem"), + get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_required_client_and_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client2-cert.pem"), + get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_optional_client_and_server_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ma_optional_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + + async def test_ma_optional_client_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + get_abs_path(self.current_directory, "client2-cert.pem"), + get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_optional_client_and_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client2-cert.pem"), + get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_required_with_no_cert_file(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + + async def test_ma_optional_with_no_cert_file(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + def read_config(self, is_ma_required): + file_path = self.ma_req_xml if is_ma_required else self.ma_opt_xml + with open(file_path, "r") as f: + xml_config = f.read() + keystore_path = get_abs_path(self.current_directory, "server1.keystore") + truststore_path = get_abs_path(self.current_directory, "server1.truststore") + return xml_config % (keystore_path, truststore_path) diff --git a/tests/integration/asyncio/ssl_tests/ssl_test.py b/tests/integration/asyncio/ssl_tests/ssl_test.py new file mode 100644 index 0000000000..6190c7571a --- /dev/null +++ b/tests/integration/asyncio/ssl_tests/ssl_test.py @@ -0,0 +1,134 @@ +import os +import unittest + +import pytest + +from tests.integration.asyncio.base import HazelcastTestCase +from hazelcast.asyncio.client import HazelcastClient +from hazelcast.errors import HazelcastError +from hazelcast.config import SSLProtocol +from tests.util import get_ssl_config, get_abs_path +from tests.integration.asyncio.util import fill_map + + +@pytest.mark.enterprise +class SSLTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + current_directory = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../backward_compatible/ssl_tests") + ) + rc = None + hazelcast_ssl_xml = get_abs_path(current_directory, "hazelcast-ssl.xml") + default_ca_xml = get_abs_path(current_directory, "hazelcast-default-ca.xml") + + def setUp(self): + self.rc = self.create_rc() + + def tearDown(self): + self.rc.exit() + + async def test_ssl_disabled(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start(**get_ssl_config(cluster.id, False)) + + async def test_ssl_enabled_is_client_live(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ssl_enabled_trust_default_certificates(self): + cluster = self.create_cluster(self.rc, self.read_default_ca_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start(**get_ssl_config(cluster.id, True)) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ssl_enabled_dont_trust_self_signed_certificates(self): + # Member started with self-signed certificate + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start(**get_ssl_config(cluster.id, True)) + + async def test_ssl_enabled_map_size(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + test_map = await client.get_map("test_map") + await fill_map(test_map, 10) + self.assertEqual(await test_map.size(), 10) + await client.shutdown() + + async def test_ssl_enabled_with_custom_ciphers(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + ciphers="ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-GCM-SHA384", + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ssl_enabled_with_invalid_ciphers(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + ciphers="INVALID-CIPHER1:INVALID_CIPHER2", + ) + ) + + async def test_ssl_enabled_with_protocol_mismatch(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + # Member configured with TLSv1 + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + protocol=SSLProtocol.SSLv3, + ) + ) + + def read_default_ca_config(self): + with open(self.default_ca_xml, "r") as f: + xml_config = f.read() + + keystore_path = get_abs_path(self.current_directory, "keystore.jks") + return xml_config % (keystore_path, keystore_path) + + def read_ssl_config(self): + with open(self.hazelcast_ssl_xml, "r") as f: + xml_config = f.read() + + keystore_path = get_abs_path(self.current_directory, "server1.keystore") + return xml_config % keystore_path