From 913933e2a73b88e9ff858548b059fd4bfd43c944 Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Mon, 23 Oct 2023 02:55:46 -0700 Subject: [PATCH 1/4] Add test for tls handshake option Signed-off-by: Waldemar Quevedo --- tests/conf/tls_handshake_first.conf | 6 +++ tests/test_client.py | 73 +++++++++++++++++++++++++++++ tests/utils.py | 24 ++++++++++ 3 files changed, 103 insertions(+) create mode 100644 tests/conf/tls_handshake_first.conf diff --git a/tests/conf/tls_handshake_first.conf b/tests/conf/tls_handshake_first.conf new file mode 100644 index 00000000..f6214181 --- /dev/null +++ b/tests/conf/tls_handshake_first.conf @@ -0,0 +1,6 @@ +tls { + cert_file: "./tests/certs/server-cert.pem" + key_file: "./tests/certs/server-key.pem" + ca_file: "./tests/certs/ca.pem" + handshake_first: true +} diff --git a/tests/test_client.py b/tests/test_client.py index 96833021..a61d7749 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,6 +21,7 @@ MultiTLSServerAuthTestCase, SingleServerTestCase, TLSServerTestCase, + TLSServerHandshakeFirstTestCase, NoAuthUserServerTestCase, async_test, ) @@ -1796,6 +1797,78 @@ async def worker_handler(msg): self.assertEqual(1, reconnected_count) self.assertEqual(1, err_count) +class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase): + + @async_test + async def test_connect(self): + nc = await nats.connect('nats://127.0.0.1:4224', tls=self.ssl_ctx) + self.assertEqual(nc._server_info['max_payload'], nc.max_payload) + self.assertTrue(nc._server_info['tls_required']) + self.assertTrue(nc._server_info['tls_verify']) + self.assertTrue(nc.max_payload > 0) + self.assertTrue(nc.is_connected) + await nc.close() + self.assertTrue(nc.is_closed) + self.assertFalse(nc.is_connected) + + @async_test + async def test_default_connect_using_tls_scheme(self): + nc = NATS() + + # Will attempt to connect using TLS with default certs. + with self.assertRaises(ssl.SSLError): + await nc.connect( + servers=['tls://127.0.0.1:4224'], allow_reconnect=False + ) + + @async_test + async def test_default_connect_using_tls_scheme_in_url(self): + nc = NATS() + + # Will attempt to connect using TLS with default certs. + with self.assertRaises(ssl.SSLError): + await nc.connect('tls://127.0.0.1:4224', allow_reconnect=False) + + @async_test + async def test_connect_tls_with_custom_hostname(self): + nc = NATS() + + # Will attempt to connect using TLS with an invalid hostname. + with self.assertRaises(ssl.SSLError): + await nc.connect( + servers=['nats://127.0.0.1:4224'], + tls=self.ssl_ctx, + tls_hostname="nats.example", + allow_reconnect=False, + ) + + @async_test + async def test_subscribe(self): + nc = NATS() + msgs = [] + + async def subscription_handler(msg): + msgs.append(msg) + + payload = b'hello world' + await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx) + sub = await nc.subscribe("foo", cb=subscription_handler) + await nc.publish("foo", payload) + await nc.publish("bar", payload) + + with self.assertRaises(nats.errors.BadSubjectError): + await nc.publish("", b'') + + # Wait a bit for message to be received. + await asyncio.sleep(0.2) + + self.assertEqual(1, len(msgs)) + msg = msgs[0] + self.assertEqual('foo', msg.subject) + self.assertEqual('', msg.reply) + self.assertEqual(payload, msg.data) + self.assertEqual(1, sub._received) + await nc.close() class ClusterDiscoveryTest(ClusteringTestCase): diff --git a/tests/utils.py b/tests/utils.py index f026b40b..2fb56d0a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -253,6 +253,30 @@ def tearDown(self): self.loop.close() +class TLSServerHandshakeFirstTestCase(unittest.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + + self.natsd = NATSD(port=4224, config_file=get_config_file('conf/tls_handshake_first.conf')) + start_natsd(self.natsd) + + self.ssl_ctx = ssl.create_default_context( + purpose=ssl.Purpose.SERVER_AUTH + ) + # self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2 + self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem')) + self.ssl_ctx.load_cert_chain( + certfile=get_config_file('certs/client-cert.pem'), + keyfile=get_config_file('certs/client-key.pem') + ) + + def tearDown(self): + self.natsd.stop() + self.loop.close() + + class MultiTLSServerAuthTestCase(unittest.TestCase): def setUp(self): From 41252eeba47585298b2cde4fa9523f6eec5accb9 Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Mon, 23 Oct 2023 04:35:51 -0700 Subject: [PATCH 2/4] Add tls_handshake_option for connect Signed-off-by: Waldemar Quevedo --- nats/aio/client.py | 48 ++++++++++++++++++----------- tests/conf/tls_handshake_first.conf | 2 ++ tests/test_client.py | 25 ++++++++++++--- 3 files changed, 53 insertions(+), 22 deletions(-) diff --git a/nats/aio/client.py b/nats/aio/client.py index 9ef0d3f0..b092bf6f 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -305,6 +305,7 @@ async def connect( no_echo: bool = False, tls: Optional[ssl.SSLContext] = None, tls_hostname: Optional[str] = None, + tls_handshake_first: bool = False, user: Optional[str] = None, password: Optional[str] = None, token: Optional[str] = None, @@ -448,6 +449,7 @@ async def subscribe_handler(msg): self.options["token"] = token self.options["connect_timeout"] = connect_timeout self.options["drain_timeout"] = drain_timeout + self.options['tls_handshake_first'] = tls_handshake_first if tls: self.options['tls'] = tls @@ -1886,6 +1888,24 @@ async def _process_connect_init(self) -> None: assert self._current_server, "must be called only from Client.connect" self._status = Client.CONNECTING + # Check whether to reuse the original hostname for an implicit route. + hostname = None + if "tls_hostname" in self.options: + hostname = self.options["tls_hostname"] + elif self._current_server.tls_name is not None: + hostname = self._current_server.tls_name + else: + hostname = self._current_server.uri.hostname + + handshake_first = self.options['tls_handshake_first'] + if handshake_first: + await self._transport.connect_tls( + hostname, + self.ssl_context, + DEFAULT_BUFFER_SIZE, + self.options['connect_timeout'], + ) + connection_completed = self._transport.readline() info_line = await asyncio.wait_for( connection_completed, self.options["connect_timeout"] @@ -1921,24 +1941,16 @@ async def _process_connect_init(self) -> None: if 'tls_required' in self._server_info and self._server_info[ 'tls_required'] and self._current_server.uri.scheme != "ws": - # Check whether to reuse the original hostname for an implicit route. - hostname = None - if "tls_hostname" in self.options: - hostname = self.options["tls_hostname"] - elif self._current_server.tls_name is not None: - hostname = self._current_server.tls_name - else: - hostname = self._current_server.uri.hostname - - await self._transport.drain() # just in case something is left - - # connect to transport via tls - await self._transport.connect_tls( - hostname, - self.ssl_context, - DEFAULT_BUFFER_SIZE, - self.options['connect_timeout'], - ) + if not handshake_first: + await self._transport.drain() # just in case something is left + + # connect to transport via tls + await self._transport.connect_tls( + hostname, + self.ssl_context, + DEFAULT_BUFFER_SIZE, + self.options['connect_timeout'], + ) # Refresh state of parser upon reconnect. if self.is_reconnecting: diff --git a/tests/conf/tls_handshake_first.conf b/tests/conf/tls_handshake_first.conf index f6214181..5afc4b28 100644 --- a/tests/conf/tls_handshake_first.conf +++ b/tests/conf/tls_handshake_first.conf @@ -1,6 +1,8 @@ +port: 4224 tls { cert_file: "./tests/certs/server-cert.pem" key_file: "./tests/certs/server-key.pem" ca_file: "./tests/certs/ca.pem" handshake_first: true + verify: true } diff --git a/tests/test_client.py b/tests/test_client.py index a61d7749..52c7bd22 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,6 +2,7 @@ import http.client import json import ssl +import os import time import unittest import urllib @@ -1801,7 +1802,10 @@ class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase): @async_test async def test_connect(self): - nc = await nats.connect('nats://127.0.0.1:4224', tls=self.ssl_ctx) + if os.environ['NATS_SERVER_VERSION'] != 'main': + pytest.skip("test requires nats-server@main") + + nc = await nats.connect('nats://127.0.0.1:4224', tls=self.ssl_ctx, tls_handshake_first=True) self.assertEqual(nc._server_info['max_payload'], nc.max_payload) self.assertTrue(nc._server_info['tls_required']) self.assertTrue(nc._server_info['tls_verify']) @@ -1813,24 +1817,33 @@ async def test_connect(self): @async_test async def test_default_connect_using_tls_scheme(self): + if os.environ['NATS_SERVER_VERSION'] != 'main': + pytest.skip("test requires nats-server@main") + nc = NATS() # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['tls://127.0.0.1:4224'], allow_reconnect=False + servers=['tls://127.0.0.1:4224'], allow_reconnect=False, tls_handshake_first=True, ) @async_test async def test_default_connect_using_tls_scheme_in_url(self): + if os.environ['NATS_SERVER_VERSION'] != 'main': + pytest.skip("test requires nats-server@main") + nc = NATS() # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect('tls://127.0.0.1:4224', allow_reconnect=False) + await nc.connect('tls://127.0.0.1:4224', allow_reconnect=False, tls_handshake_first=True) @async_test async def test_connect_tls_with_custom_hostname(self): + if os.environ['NATS_SERVER_VERSION'] != 'main': + pytest.skip("test requires nats-server@main") + nc = NATS() # Will attempt to connect using TLS with an invalid hostname. @@ -1839,11 +1852,15 @@ async def test_connect_tls_with_custom_hostname(self): servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx, tls_hostname="nats.example", + tls_handshake_first=True, allow_reconnect=False, ) @async_test async def test_subscribe(self): + if os.environ['NATS_SERVER_VERSION'] != 'main': + pytest.skip("test requires nats-server@main") + nc = NATS() msgs = [] @@ -1851,7 +1868,7 @@ async def subscription_handler(msg): msgs.append(msg) payload = b'hello world' - await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx) + await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx, tls_handshake_first=True) sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) From 5a324c1d2334194c41ca1028f87e2a5fa173a642 Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Mon, 23 Oct 2023 04:36:09 -0700 Subject: [PATCH 3/4] Bump version Signed-off-by: Waldemar Quevedo --- nats/aio/client.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nats/aio/client.py b/nats/aio/client.py index b092bf6f..dfc1a1e1 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -64,7 +64,7 @@ ) from .transport import TcpTransport, Transport, WebSocketTransport -__version__ = '2.4.0' +__version__ = '2.5.0' __lang__ = 'python3' _logger = logging.getLogger(__name__) PROTOCOL = 1 diff --git a/setup.py b/setup.py index bf1d8e8c..4ec0d413 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ # These are here for GitHub's dependency graph and help with setuptools support in some environments. setup( name="nats-py", - version='2.4.0', + version='2.5.0', license='Apache 2 License', extras_require={ 'nkeys': ['nkeys'], From 6127b18f832bb6651ec4dffbe9a502339868cd4a Mon Sep 17 00:00:00 2001 From: Waldemar Quevedo Date: Mon, 23 Oct 2023 04:53:14 -0700 Subject: [PATCH 4/4] Formatting Signed-off-by: Waldemar Quevedo --- tests/test_client.py | 34 +++++++++++++++++++++++++--------- tests/utils.py | 5 ++++- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 52c7bd22..905effb8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1798,14 +1798,19 @@ async def worker_handler(msg): self.assertEqual(1, reconnected_count) self.assertEqual(1, err_count) + class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase): @async_test async def test_connect(self): - if os.environ['NATS_SERVER_VERSION'] != 'main': + if os.environ.get('NATS_SERVER_VERSION') != 'main': pytest.skip("test requires nats-server@main") - nc = await nats.connect('nats://127.0.0.1:4224', tls=self.ssl_ctx, tls_handshake_first=True) + nc = await nats.connect( + 'nats://127.0.0.1:4224', + tls=self.ssl_ctx, + tls_handshake_first=True + ) self.assertEqual(nc._server_info['max_payload'], nc.max_payload) self.assertTrue(nc._server_info['tls_required']) self.assertTrue(nc._server_info['tls_verify']) @@ -1817,7 +1822,7 @@ async def test_connect(self): @async_test async def test_default_connect_using_tls_scheme(self): - if os.environ['NATS_SERVER_VERSION'] != 'main': + if os.environ.get('NATS_SERVER_VERSION') != 'main': pytest.skip("test requires nats-server@main") nc = NATS() @@ -1825,23 +1830,29 @@ async def test_default_connect_using_tls_scheme(self): # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): await nc.connect( - servers=['tls://127.0.0.1:4224'], allow_reconnect=False, tls_handshake_first=True, + servers=['tls://127.0.0.1:4224'], + allow_reconnect=False, + tls_handshake_first=True, ) @async_test async def test_default_connect_using_tls_scheme_in_url(self): - if os.environ['NATS_SERVER_VERSION'] != 'main': + if os.environ.get('NATS_SERVER_VERSION') != 'main': pytest.skip("test requires nats-server@main") nc = NATS() # Will attempt to connect using TLS with default certs. with self.assertRaises(ssl.SSLError): - await nc.connect('tls://127.0.0.1:4224', allow_reconnect=False, tls_handshake_first=True) + await nc.connect( + 'tls://127.0.0.1:4224', + allow_reconnect=False, + tls_handshake_first=True + ) @async_test async def test_connect_tls_with_custom_hostname(self): - if os.environ['NATS_SERVER_VERSION'] != 'main': + if os.environ.get('NATS_SERVER_VERSION') != 'main': pytest.skip("test requires nats-server@main") nc = NATS() @@ -1858,7 +1869,7 @@ async def test_connect_tls_with_custom_hostname(self): @async_test async def test_subscribe(self): - if os.environ['NATS_SERVER_VERSION'] != 'main': + if os.environ.get('NATS_SERVER_VERSION') != 'main': pytest.skip("test requires nats-server@main") nc = NATS() @@ -1868,7 +1879,11 @@ async def subscription_handler(msg): msgs.append(msg) payload = b'hello world' - await nc.connect(servers=['nats://127.0.0.1:4224'], tls=self.ssl_ctx, tls_handshake_first=True) + await nc.connect( + servers=['nats://127.0.0.1:4224'], + tls=self.ssl_ctx, + tls_handshake_first=True + ) sub = await nc.subscribe("foo", cb=subscription_handler) await nc.publish("foo", payload) await nc.publish("bar", payload) @@ -1887,6 +1902,7 @@ async def subscription_handler(msg): self.assertEqual(1, sub._received) await nc.close() + class ClusterDiscoveryTest(ClusteringTestCase): @async_test diff --git a/tests/utils.py b/tests/utils.py index 2fb56d0a..e900ecaf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -259,7 +259,10 @@ def setUp(self): super().setUp() self.loop = asyncio.new_event_loop() - self.natsd = NATSD(port=4224, config_file=get_config_file('conf/tls_handshake_first.conf')) + self.natsd = NATSD( + port=4224, + config_file=get_config_file('conf/tls_handshake_first.conf') + ) start_natsd(self.natsd) self.ssl_ctx = ssl.create_default_context(