diff --git a/Makefile b/Makefile index e1032e5525..f783524c4c 100644 --- a/Makefile +++ b/Makefile @@ -10,5 +10,8 @@ test: test-enterprise: pytest --verbose +test-asyncio: + pytest --verbose -k asyncio + test-cover: pytest --cov=hazelcast --cov-report=xml diff --git a/examples/asyncio/cloud-discovery/hazelcast_cloud_discovery_example.py b/examples/asyncio/cloud-discovery/hazelcast_cloud_discovery_example.py new file mode 100644 index 0000000000..e5098dfa84 --- /dev/null +++ b/examples/asyncio/cloud-discovery/hazelcast_cloud_discovery_example.py @@ -0,0 +1,28 @@ +import asyncio + +from hazelcast.asyncio import HazelcastClient + + +async def amain(): + client = await HazelcastClient.create_and_start( + # Set up cluster name for authentication + cluster_name="asyncio", + # Set the token of your cloud cluster + cloud_discovery_token="wE1w1USF6zOnaLVjLZwbZHxEoZJhw43yyViTbe6UBTvz4tZniA", + ssl_enabled=True, + ssl_cafile="/path/to/ca.pem", + ssl_certfile="/path/to/cert.pem", + ssl_keyfile="/path/to/key.pem", + ssl_password="05dd4498c3f", + ) + my_map = await client.get_map("map-on-the-cloud") + await my_map.put("key", "value") + + value = await my_map.get("key") + print(value) + + await client.shutdown() + + +if __name__ == "__main__": + asyncio.run(amain()) diff --git a/examples/cloud-discovery/hazelcast_cloud_discovery_example.py b/examples/cloud-discovery/hazelcast_cloud_discovery_example.py index de7eca0460..2f478fc000 100644 --- a/examples/cloud-discovery/hazelcast_cloud_discovery_example.py +++ b/examples/cloud-discovery/hazelcast_cloud_discovery_example.py @@ -5,8 +5,6 @@ cluster_name="YOUR_CLUSTER_NAME", # Set the token of your cloud cluster cloud_discovery_token="YOUR_CLUSTER_DISCOVERY_TOKEN", - # If you have enabled encryption for your cluster, also configure TLS/SSL for the client. - # Otherwise, skip options below. ssl_enabled=True, ssl_cafile="/path/to/ca.pem", ssl_certfile="/path/to/cert.pem", diff --git a/hazelcast/__init__.py b/hazelcast/__init__.py index d40f2659da..a60e982958 100644 --- a/hazelcast/__init__.py +++ b/hazelcast/__init__.py @@ -1,4 +1,4 @@ -__version__ = "6.0.0" +__version__ = "5.6.0" # Set the default handler to "hazelcast" loggers # to avoid "No handlers could be found" warnings. diff --git a/hazelcast/asyncio/client.py b/hazelcast/asyncio/client.py index 1e08bab92f..db7106cdfd 100644 --- a/hazelcast/asyncio/client.py +++ b/hazelcast/asyncio/client.py @@ -6,7 +6,7 @@ from hazelcast.internal.asyncio_cluster import ClusterService, _InternalClusterService from hazelcast.internal.asyncio_compact import CompactSchemaService from hazelcast.config import Config, IndexConfig -from hazelcast.internal.asyncio_connection import ConnectionManager, DefaultAddressProvider +from hazelcast.internal.asyncio_connection import ConnectionManager, DefaultAsyncioAddressProvider from hazelcast.core import DistributedObjectEvent, DistributedObjectInfo from hazelcast.discovery import HazelcastCloudAddressProvider from hazelcast.errors import IllegalStateError, InvalidConfigurationError @@ -313,7 +313,7 @@ def _create_address_provider(self): connection_timeout = self._get_connection_timeout(config) return HazelcastCloudAddressProvider(cloud_discovery_token, connection_timeout) - return DefaultAddressProvider(cluster_members) + return DefaultAsyncioAddressProvider(cluster_members) def _create_client_name(self, client_id): client_name = self._config.client_name diff --git a/hazelcast/internal/asyncio_connection.py b/hazelcast/internal/asyncio_connection.py index d008011477..d1d4aa9a90 100644 --- a/hazelcast/internal/asyncio_connection.py +++ b/hazelcast/internal/asyncio_connection.py @@ -183,7 +183,7 @@ def __init__( self._cluster_id = None self._load_balancer = None self._use_public_ip = ( - isinstance(address_provider, DefaultAddressProvider) and config.use_public_ip + isinstance(address_provider, DefaultAsyncioAddressProvider) and config.use_public_ip ) # asyncio tasks are weakly referenced # storing tasks here in order not to lose them midway @@ -385,9 +385,10 @@ async def _get_or_connect_to_address(self, address): for connection in list(self.active_connections.values()): if connection.remote_address == address: return connection - translated = self._translate(address) - connection = await self._create_connection(translated) - response = await self._authenticate(connection) + translated = await self._translate(address) + connection = self._create_connection(translated) + await connection._create_task + response = self._authenticate(connection) await self._on_auth(response, connection) return connection @@ -396,14 +397,15 @@ async def _get_or_connect_to_member(self, member): if connection: return connection - translated = self._translate_member_address(member) - connection = await self._create_connection(translated) - response = await self._authenticate(connection) + translated = await self._translate_member_address(member) + connection = self._create_connection(translated) + await connection._create_task + response = self._authenticate(connection) await self._on_auth(response, connection) return connection - async def _create_connection(self, address): - return await self._reactor.connection_factory( + def _create_connection(self, address): + return self._reactor.connection_factory( self, self._connection_id_generator.get_and_increment(), address, @@ -411,8 +413,8 @@ async def _create_connection(self, address): self._invocation_service.handle_client_message, ) - def _translate(self, address): - translated = self._address_provider.translate(address) + async def _translate(self, address): + translated = await self._address_provider.translate(address) if not translated: raise ValueError( "Address provider %s could not translate address %s" @@ -421,7 +423,7 @@ def _translate(self, address): return translated - def _translate_member_address(self, member): + async def _translate_member_address(self, member): if self._use_public_ip: public_address = member.address_map.get(_CLIENT_PUBLIC_ENDPOINT_QUALIFIER, None) if public_address: @@ -429,7 +431,7 @@ def _translate_member_address(self, member): return member.address - return self._translate(member.address) + return await self._translate(member.address) async def _trigger_cluster_reconnection(self): if self._reconnect_mode == ReconnectMode.OFF: @@ -529,7 +531,8 @@ async def _sync_connect_to_cluster(self): if connection: return - for address in self._get_possible_addresses(): + addresses = await self._get_possible_addresses() + for address in addresses: self._check_client_active() if address in tried_addresses_per_attempt: # We already tried this address on from the member list @@ -614,6 +617,7 @@ def _authenticate(self, connection) -> asyncio.Future: async def _on_auth(self, response, connection): try: + response = await response response = client_authentication_codec.decode_response(response) except Exception as e: await connection.close_connection("Failed to authenticate connection", e) @@ -790,8 +794,8 @@ def _check_client_active(self): if not self._lifecycle_service.running: raise HazelcastClientNotActiveError() - def _get_possible_addresses(self): - primaries, secondaries = self._address_provider.load_addresses() + async def _get_possible_addresses(self): + primaries, secondaries = await self._address_provider.load_addresses() if self._shuffle_member_list: # The relative order between primary and secondary addresses should # not be changed. So we shuffle the lists separately and then add @@ -1028,17 +1032,13 @@ def __hash__(self): return self._id -class DefaultAddressProvider: - """Provides initial addresses for client to find and connect to a node. - - It also provides a no-op translator. - """ - +class DefaultAsyncioAddressProvider: def __init__(self, addresses): self._addresses = addresses - def load_addresses(self): + async def load_addresses(self): """Returns the possible primary and secondary member addresses to connect to.""" + # NOTE: This method is marked with async since the caller assumes that. configured_addresses = self._addresses if not configured_addresses: @@ -1053,9 +1053,10 @@ def load_addresses(self): return primaries, secondaries - def translate(self, address): + async def translate(self, address): """No-op address translator. It is there to provide the same API with other address providers. """ + # NOTE: This method is marked with async since the caller assumes that. return address diff --git a/hazelcast/internal/asyncio_discovery.py b/hazelcast/internal/asyncio_discovery.py new file mode 100644 index 0000000000..dcd890e8fb --- /dev/null +++ b/hazelcast/internal/asyncio_discovery.py @@ -0,0 +1,58 @@ +import asyncio +import logging + +from hazelcast.discovery import HazelcastCloudDiscovery + +_logger = logging.getLogger(__name__) + + +class HazelcastCloudAddressProvider: + """Provides initial addresses for client to find and connect to a node + and resolves private IP addresses of Hazelcast Cloud service. + """ + + def __init__(self, token, connection_timeout): + self.cloud_discovery = HazelcastCloudDiscovery(token, connection_timeout) + self._private_to_public = dict() + + async def load_addresses(self): + """Loads member addresses from Hazelcast Cloud endpoint. + + Returns: + tuple[list[hazelcast.core.Address], list[hazelcast.core.Address]]: The possible member addresses + as primary addresses to connect to. + """ + try: + nodes = await asyncio.to_thread(self.cloud_discovery.discover_nodes) + # Every private address is primary + return list(nodes.keys()), [] + except Exception as e: + _logger.warning("Failed to load addresses from Hazelcast Cloud: %s", e) + return [], [] + + async def translate(self, address): + """Translates the given address to another address specific to network or service. + + Args: + address (hazelcast.core.Address): Private address to be translated + + Returns: + hazelcast.core.Address: New address if given address is known, otherwise returns None + """ + if address is None: + return None + + public_address = self._private_to_public.get(address, None) + if public_address: + return public_address + + await self.refresh() + + return self._private_to_public.get(address, None) + + async def refresh(self): + """Refreshes the internal lookup table if necessary.""" + try: + self._private_to_public = self.cloud_discovery.discover_nodes() + except Exception as e: + _logger.warning("Failed to load addresses from Hazelcast Cloud: %s", e) diff --git a/hazelcast/internal/asyncio_reactor.py b/hazelcast/internal/asyncio_reactor.py index 7311f10bfc..0802088e60 100644 --- a/hazelcast/internal/asyncio_reactor.py +++ b/hazelcast/internal/asyncio_reactor.py @@ -1,8 +1,12 @@ import asyncio +import errno import io import logging +import os +import socket import ssl import time +from errno import errorcode from asyncio import AbstractEventLoop, transports from hazelcast.config import Config, SSLProtocol @@ -24,10 +28,10 @@ def __init__(self, loop: AbstractEventLoop | None = None): def add_timer(self, delay, callback): return self._loop.call_later(delay, callback) - async def connection_factory( + def connection_factory( self, connection_manager, connection_id, address: Address, network_config, message_callback ): - return await AsyncioConnection.create_and_connect( + return AsyncioConnection.create_and_connect( self._loop, self, connection_manager, @@ -62,9 +66,15 @@ def __init__( self._config = config self._proto = None self.connected_address = address + self._preconn_buffers: list = [] + self._create_task: asyncio.Task | None = None + self._close_task: asyncio.Task | None = None + self._connected = False + self._receive_buffer_size = _BUFFER_SIZE + self._sock = None @classmethod - async def create_and_connect( + def create_and_connect( cls, loop, reactor: AsyncioReactor, @@ -77,26 +87,48 @@ async def create_and_connect( this = cls( loop, reactor, connection_manager, connection_id, address, config, message_callback ) - await this._create_connection(config, address) + this._create_task = asyncio.create_task(this._create_connection(config, address)) + if config.connection_timeout > 0: + this._close_task = asyncio.create_task(this._close_timer_cb(config.connection_timeout)) return this def _create_protocol(self): return HazelcastProtocol(self) async def _create_connection(self, config, address): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + sock.settimeout(0) + self._set_socket_options(sock, config) + server_hostname = None ssl_context = None if config.ssl_enabled: - ssl_context = self._create_ssl_context(config) - server_hostname = None - if config.ssl_check_hostname: server_hostname = address.host + ssl_context = self._create_ssl_context(config) + + try: + self.connect(sock, (address.host, address.port)) + except socket.error as e: + self._inner_close() + raise e + + self._sock = sock + res = await self._loop.create_connection( self._create_protocol, - host=self._address.host, - port=self._address.port, ssl=ssl_context, server_hostname=server_hostname, + sock=sock, ) + try: + sock.getpeername() + except OSError as err: + if err.errno not in (errno.ENOTCONN, errno.EINVAL): + raise + self._connected = False + else: + self._connected = True + sock, self._proto = res if hasattr(sock, "_ssl_protocol"): sock = sock._ssl_protocol._transport._sock @@ -106,11 +138,62 @@ async def _create_connection(self, config, address): host, port = sockname[0], sockname[1] self.local_address = Address(host, port) + def connect(self, sock, address): + self._connected = False + err = sock.connect_ex(address) + if ( + err in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK) + or err == errno.EINVAL + and os.name == "nt" + ): + return + if err in (0, errno.EISCONN): + self.handle_connect_event(sock) + else: + raise OSError(err, errorcode[err]) + + def handle_connect_event(self, sock): + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise OSError(err, _strerror(err)) + self.handle_connect() + + def handle_connect(self): + self._connected = True + # write any data that were buffered before the socket is available + if self._preconn_buffers: + for b in self._preconn_buffers: + self._proto.write(b) + self._preconn_buffers.clear() + if self._close_task: + self._close_task.cancel() + + self.start_time = time.time() + _logger.debug("Connected to %s", self.connected_address) + + async def _close_timer_cb(self, timeout): + await asyncio.sleep(timeout) + if not self._connected: + await self.close_connection(None, IOError("Connection timed out")) + def _write(self, buf): + if not self._proto: + self._preconn_buffers.append(buf) + return self._proto.write(buf) def _inner_close(self): - self._proto.close() + if self._close_task: + self._close_task.cancel() + if self._proto: + self._proto.close() + self._connected = False + if self._sock: + try: + self._sock.close() + except OSError as why: + if why.errno not in (errno.ENOTCONN, errno.EBADF): + raise def _update_read_time(self, time): self.last_read_time = time @@ -124,6 +207,16 @@ def _update_sent(self, sent): def _update_received(self, received): self._reactor.update_bytes_received(received) + def _set_socket_options(self, sock, config): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, _BUFFER_SIZE) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, _BUFFER_SIZE) + for level, option_name, value in config.socket_options: + if option_name is socket.SO_RCVBUF: + self._receive_buffer_size = value + + sock.setsockopt(level, option_name, value) + def _create_ssl_context(self, config: Config): ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) protocol = config.ssl_protocol @@ -202,7 +295,7 @@ def write(self, buf): def get_buffer(self, sizehint): if self._recv_buf is None: - buf_size = max(sizehint, _BUFFER_SIZE) + buf_size = max(sizehint, self._conn._receive_buffer_size) self._recv_buf = memoryview(bytearray(buf_size)) return self._recv_buf @@ -227,3 +320,12 @@ def _do_write(self): def _write_loop(self): self._do_write() return self._conn._loop.call_later(0.01, self._write_loop) + + +def _strerror(err): + try: + return os.strerror(err) + except (ValueError, OverflowError, NameError): + if err in errorcode: + return errorcode[err] + return "Unknown error %s" % err diff --git a/tests/integration/asyncio/cluster_test.py b/tests/integration/asyncio/cluster_test.py index 681a4b4f6f..206ace480c 100644 --- a/tests/integration/asyncio/cluster_test.py +++ b/tests/integration/asyncio/cluster_test.py @@ -1,3 +1,4 @@ +import asyncio import os import tempfile import unittest @@ -200,17 +201,19 @@ async def asyncTearDown(self): self.rc.exit() async def test_when_member_started_with_another_port_and_the_same_uuid(self): - member = self.cluster.start_member() + member = await asyncio.to_thread(self.cluster.start_member) self.client = await HazelcastClient.create_and_start(cluster_name=self.cluster.id) added_listener = event_collector() removed_listener = event_collector() self.client.cluster_service.add_listener( member_added=added_listener, member_removed=removed_listener ) - self.rc.shutdownCluster(self.cluster.id) + await asyncio.to_thread(self.rc.shutdownCluster, self.cluster.id) # now stop cluster, restart it with the same name and then start member with port 5702 - self.cluster = self.create_cluster_keep_cluster_name(self.rc, self.get_config(5702)) - self.cluster.start_member() + self.cluster = await asyncio.to_thread( + self.create_cluster_keep_cluster_name, self.rc, self.get_config(5702) + ) + await asyncio.to_thread(self.cluster.start_member) def assertion(): self.assertEqual(1, len(added_listener.events)) @@ -223,15 +226,15 @@ def assertion(): async def test_when_member_started_with_the_same_address(self): skip_if_client_version_older_than(self, "4.2") - old_member = self.cluster.start_member() + old_member = await asyncio.to_thread(self.cluster.start_member) self.client = await HazelcastClient.create_and_start(cluster_name=self.cluster.id) members_added = [] members_removed = [] self.client.cluster_service.add_listener( lambda m: members_added.append(m), lambda m: members_removed.append(m) ) - self.rc.shutdownMember(self.cluster.id, old_member.uuid) - new_member = self.cluster.start_member() + await asyncio.to_thread(self.rc.shutdownMember, self.cluster.id, old_member.uuid) + new_member = await asyncio.to_thread(self.cluster.start_member) def assertion(): self.assertEqual(1, len(members_added)) diff --git a/tests/integration/asyncio/connection_manager_test.py b/tests/integration/asyncio/connection_manager_test.py new file mode 100644 index 0000000000..fd89cf18c8 --- /dev/null +++ b/tests/integration/asyncio/connection_manager_test.py @@ -0,0 +1,238 @@ +import asyncio +import unittest +import uuid + +from mock import patch + +from hazelcast.asyncio import HazelcastClient +from hazelcast.core import Address, MemberInfo, MemberVersion, EndpointQualifier, ProtocolType +from hazelcast.errors import IllegalStateError, TargetDisconnectedError +from hazelcast.lifecycle import LifecycleState +from hazelcast.util import AtomicInteger +from tests.integration.asyncio.base import HazelcastTestCase, SingleMemberTestCase +from tests.util import random_string + +# 198.51.100.0/24 is assigned as TEST-NET-2 and should be unreachable +# See: https://en.wikipedia.org/wiki/Reserved_IP_addresses +_UNREACHABLE_ADDRESS = Address("198.51.100.1", 5701) +_MEMBER_VERSION = MemberVersion(5, 0, 0) +_CLIENT_PUBLIC_ENDPOINT_QUALIFIER = EndpointQualifier(ProtocolType.CLIENT, "public") + + +class ConnectionManagerTranslateTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + + rc = None + cluster = None + member = None + + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.create_cluster(cls.rc, None) + cls.member = cls.cluster.start_member() + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + def setUp(self): + self.client = None + + async def asyncTearDown(self): + if self.client: + await self.client.shutdown() + + async def test_translate_is_used(self): + # It shouldn't be able to connect to cluster using unreachable + # public address. + with self.assertRaises(IllegalStateError): + with patch.object( + HazelcastClient, + "_create_address_provider", + return_value=StaticAddressProvider(True, self.member.address), + ): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + cluster_connect_timeout=1.0, + connection_timeout=1.0, + ) + + async def test_translate_is_not_used_when_getting_existing_connection(self): + provider = StaticAddressProvider(False, self.member.address) + with patch.object( + HazelcastClient, + "_create_address_provider", + return_value=provider, + ): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + ) + # If the translate is used for this, it would return + # the unreachable address and the connection attempt + # would fail. + provider.should_translate = True + conn_manager = self.client._connection_manager + conn = await conn_manager._get_or_connect_to_address(self.member.address) + self.assertIsNotNone(conn) + + async def test_translate_is_used_when_member_has_public_client_address(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + use_public_ip=True, + ) + + member = MemberInfo( + _UNREACHABLE_ADDRESS, + uuid.uuid4(), + {}, + False, + _MEMBER_VERSION, + None, + { + _CLIENT_PUBLIC_ENDPOINT_QUALIFIER: self.member.address, + }, + ) + conn_manager = self.client._connection_manager + conn = await conn_manager._get_or_connect_to_member(member) + self.assertIsNotNone(conn) + + async def test_translate_is_not_used_when_member_has_public_client_address_but_option_is_disabled( + self, + ): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + connection_timeout=1.0, + use_public_ip=False, + ) + + member = MemberInfo( + _UNREACHABLE_ADDRESS, + uuid.uuid4(), + {}, + False, + _MEMBER_VERSION, + None, + { + _CLIENT_PUBLIC_ENDPOINT_QUALIFIER: self.member.address, + }, + ) + conn_manager = self.client._connection_manager + + with self.assertRaises(TargetDisconnectedError): + await conn_manager._get_or_connect_to_member(member) + + +class ConnectionManagerOnClusterRestartTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def test_client_state_is_sent_once_if_send_operation_is_successful(self): + conn_manager = self.client._connection_manager + counter = AtomicInteger() + + async def send_state_to_cluster_fn(): + counter.add(1) + return None + + conn_manager._send_state_to_cluster_fn = send_state_to_cluster_fn + await self._restart_cluster() + self.assertEqual(1, counter.get()) + + async def test_sending_client_state_is_retried_if_send_operation_is_failed(self): + conn_manager = self.client._connection_manager + counter = AtomicInteger() + + async def send_state_to_cluster_fn(): + counter.add(1) + if counter.get() == 5: + # Let's pretend it succeeds at some point + return None + + raise RuntimeError("expected") + + conn_manager._send_state_to_cluster_fn = send_state_to_cluster_fn + await self._restart_cluster() + self.assertEqual(5, counter.get()) + + async def test_sending_client_state_is_retried_if_send_operation_is_failed_synchronously(self): + conn_manager = self.client._connection_manager + counter = AtomicInteger() + + async def send_state_to_cluster_fn(): + counter.add(1) + if counter.get() == 5: + # Let's pretend it succeeds at some point + return None + + raise RuntimeError("expected") + + conn_manager._send_state_to_cluster_fn = send_state_to_cluster_fn + await self._restart_cluster() + self.assertEqual(5, counter.get()) + + async def test_client_state_is_sent_on_reconnection_when_the_cluster_id_is_same(self): + disconnected = asyncio.Event() + reconnected = asyncio.Event() + + def listener(state): + if state == LifecycleState.DISCONNECTED: + disconnected.set() + elif disconnected.is_set() and state == LifecycleState.CONNECTED: + reconnected.set() + + self.client.lifecycle_service.add_listener(listener) + conn_manager = self.client._connection_manager + counter = AtomicInteger() + + async def send_state_to_cluster_fn(): + counter.add(1) + return None + + conn_manager._send_state_to_cluster_fn = send_state_to_cluster_fn + + # Keep the cluster alive, but close the connection + # to simulate re-connection to a cluster with + # the same cluster id. + connection = conn_manager.get_random_connection() + await connection.close_connection("expected", None) + + await disconnected.wait() + await reconnected.wait() + + await self._wait_until_state_is_sent() + + self.assertEqual(1, counter.get()) + + async def _restart_cluster(self): + await asyncio.to_thread(self.rc.terminateMember, self.cluster.id, self.member.uuid) + ConnectionManagerOnClusterRestartTest.member = await asyncio.to_thread( + self.cluster.start_member + ) + await self._wait_until_state_is_sent() + + async def _wait_until_state_is_sent(self): + # Perform an invocation to wait until the client state is sent + m = await self.client.get_map(random_string()) + await m.set(1, 1) + self.assertEqual(1, await m.get(1)) + + +class StaticAddressProvider: + def __init__(self, should_translate, member_address): + self.should_translate = should_translate + self.member_address = member_address + + async def load_addresses(self): + return [self.member_address], [] + + async def translate(self, address): + if not self.should_translate: + return address + + if address == self.member_address: + return _UNREACHABLE_ADDRESS + + return None diff --git a/tests/integration/asyncio/heartbeat_test.py b/tests/integration/asyncio/heartbeat_test.py index 5fbdd72b3f..8f4c6ad5c7 100644 --- a/tests/integration/asyncio/heartbeat_test.py +++ b/tests/integration/asyncio/heartbeat_test.py @@ -74,7 +74,7 @@ async def run(): await asyncio.sleep((i + 1) * 0.1) - asyncio.create_task(run()) + task = asyncio.create_task(run()) async def assert_heartbeat_stopped_and_restored(): nonlocal assertion_succeeded @@ -95,3 +95,4 @@ async def assert_heartbeat_stopped_and_restored(): assertion_succeeded = True await self.assertTrueEventually(assert_heartbeat_stopped_and_restored) + await task diff --git a/tests/integration/asyncio/listener_test.py b/tests/integration/asyncio/listener_test.py index 9d132a8ad4..626040adb5 100644 --- a/tests/integration/asyncio/listener_test.py +++ b/tests/integration/asyncio/listener_test.py @@ -83,7 +83,7 @@ async def test_add_member_smart(self): await self._add_member_test(True) async def test_add_member_unisocket(self): - await self._add_member_test(True) + await self._add_member_test(False) async def _add_member_test(self, is_smart): self.client_config["smart_routing"] = is_smart @@ -111,7 +111,7 @@ async def run(): await random_map.put(key_m2, f"value-{i}") await asyncio.sleep((i + 1) * 0.1) - asyncio.create_task(run()) + task = asyncio.create_task(run()) def assert_event(): nonlocal assertion_succeeded @@ -119,3 +119,4 @@ def assert_event(): assertion_succeeded = True await self.assertTrueEventually(assert_event) + await task diff --git a/tests/integration/asyncio/proxy/vector_collection_test.py b/tests/integration/asyncio/proxy/vector_collection_test.py index d6db835782..8cc4090f0e 100644 --- a/tests/integration/asyncio/proxy/vector_collection_test.py +++ b/tests/integration/asyncio/proxy/vector_collection_test.py @@ -178,7 +178,7 @@ async def test_size(self): self.assertEqual(await self.vector_collection.size(), 0) async def test_backup_count_valid_values_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() await self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=2, async_backup_count=2 @@ -186,7 +186,7 @@ async def test_backup_count_valid_values_pass(self): await self.client.get_vector_collection(name) async def test_backup_count_max_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() await self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=6 @@ -194,7 +194,7 @@ async def test_backup_count_max_value_pass(self): await self.client.get_vector_collection(name) async def test_backup_count_min_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() await self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=0 @@ -202,7 +202,7 @@ async def test_backup_count_min_value_pass(self): await self.client.get_vector_collection(name) async def test_backup_count_more_than_max_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() # check that the parameter is used by ensuring that it is validated on server side # there is no simple way to check number of backups @@ -215,7 +215,7 @@ async def test_backup_count_more_than_max_value_fail(self): ) async def test_backup_count_less_than_min_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.IllegalArgumentError): await self.client.create_vector_collection_config( @@ -223,7 +223,7 @@ async def test_backup_count_less_than_min_value_fail(self): ) async def test_async_backup_count_max_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() await self.client.create_vector_collection_config( name, @@ -234,7 +234,7 @@ async def test_async_backup_count_max_value_pass(self): await self.client.get_vector_collection(name) async def test_async_backup_count_min_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() await self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], async_backup_count=0 @@ -242,7 +242,7 @@ async def test_async_backup_count_min_value_pass(self): await self.client.get_vector_collection(name) async def test_async_backup_count_more_than_max_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() # check that the parameter is used by ensuring that it is validated on server side # there is no simple way to check number of backups @@ -255,7 +255,7 @@ async def test_async_backup_count_more_than_max_value_fail(self): ) async def test_async_backup_count_less_than_min_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.IllegalArgumentError): await self.client.create_vector_collection_config( @@ -265,7 +265,7 @@ async def test_async_backup_count_less_than_min_value_fail(self): ) async def test_sync_and_async_backup_count_more_than_max_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.IllegalArgumentError): await self.client.create_vector_collection_config( @@ -276,7 +276,7 @@ async def test_sync_and_async_backup_count_more_than_max_value_fail(self): ) async def test_merge_policy_can_be_sent(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() await self.client.create_vector_collection_config( name, @@ -288,8 +288,8 @@ async def test_merge_policy_can_be_sent(self): await self.client.get_vector_collection(name) async def test_wrong_merge_policy_fails(self): - skip_if_client_version_older_than(self, "6.0") - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_client_version_older_than(self, "5.6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.InvalidConfigurationError): await self.client.create_vector_collection_config( @@ -299,7 +299,7 @@ async def test_wrong_merge_policy_fails(self): await self.client.get_vector_collection(name) async def test_split_brain_name_can_be_sent(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() await self.client.create_vector_collection_config( name, diff --git a/tests/integration/backward_compatible/proxy/vector_collection_test.py b/tests/integration/backward_compatible/proxy/vector_collection_test.py index 046919f139..ff82b512ad 100644 --- a/tests/integration/backward_compatible/proxy/vector_collection_test.py +++ b/tests/integration/backward_compatible/proxy/vector_collection_test.py @@ -175,7 +175,7 @@ def test_size(self): self.assertEqual(self.vector_collection.size(), 0) def test_backup_count_valid_values_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=2, async_backup_count=2 @@ -183,7 +183,7 @@ def test_backup_count_valid_values_pass(self): self.client.get_vector_collection(name).blocking() def test_backup_count_max_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=6 @@ -191,7 +191,7 @@ def test_backup_count_max_value_pass(self): self.client.get_vector_collection(name).blocking() def test_backup_count_min_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=0 @@ -199,7 +199,7 @@ def test_backup_count_min_value_pass(self): self.client.get_vector_collection(name).blocking() def test_backup_count_more_than_max_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() # check that the parameter is used by ensuring that it is validated on server side # there is no simple way to check number of backups @@ -212,7 +212,7 @@ def test_backup_count_more_than_max_value_fail(self): ) def test_backup_count_less_than_min_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.IllegalArgumentError): self.client.create_vector_collection_config( @@ -220,7 +220,7 @@ def test_backup_count_less_than_min_value_fail(self): ) def test_async_backup_count_max_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() self.client.create_vector_collection_config( name, @@ -231,7 +231,7 @@ def test_async_backup_count_max_value_pass(self): self.client.get_vector_collection(name).blocking() def test_async_backup_count_min_value_pass(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() self.client.create_vector_collection_config( name, [IndexConfig("vector", Metric.COSINE, 3)], async_backup_count=0 @@ -239,7 +239,7 @@ def test_async_backup_count_min_value_pass(self): self.client.get_vector_collection(name).blocking() def test_async_backup_count_more_than_max_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() # check that the parameter is used by ensuring that it is validated on server side # there is no simple way to check number of backups @@ -252,7 +252,7 @@ def test_async_backup_count_more_than_max_value_fail(self): ) def test_async_backup_count_less_than_min_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.IllegalArgumentError): self.client.create_vector_collection_config( @@ -262,7 +262,7 @@ def test_async_backup_count_less_than_min_value_fail(self): ) def test_sync_and_async_backup_count_more_than_max_value_fail(self): - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.IllegalArgumentError): self.client.create_vector_collection_config( @@ -273,7 +273,7 @@ def test_sync_and_async_backup_count_more_than_max_value_fail(self): ) def test_merge_policy_can_be_sent(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() self.client.create_vector_collection_config( name, @@ -285,8 +285,8 @@ def test_merge_policy_can_be_sent(self): self.client.get_vector_collection(name) def test_wrong_merge_policy_fails(self): - skip_if_client_version_older_than(self, "6.0") - skip_if_server_version_older_than(self, self.client, "6.0") + skip_if_client_version_older_than(self, "5.6.0") + skip_if_server_version_older_than(self, self.client, "5.6.0") name = random_string() with self.assertRaises(hazelcast.errors.InvalidConfigurationError): self.client.create_vector_collection_config( @@ -296,7 +296,7 @@ def test_wrong_merge_policy_fails(self): self.client.get_vector_collection(name) def test_split_brain_name_can_be_sent(self): - skip_if_client_version_older_than(self, "6.0") + skip_if_client_version_older_than(self, "5.6.0") name = random_string() self.client.create_vector_collection_config( name,