Skip to content

Commit

Permalink
Merge pull request #511 from nats-io/tls-handshake-first
Browse files Browse the repository at this point in the history
Add tls_handshake_first option.
  • Loading branch information
wallyqs committed Oct 23, 2023
2 parents cf86a09 + 6127b18 commit 12fe022
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 20 deletions.
50 changes: 31 additions & 19 deletions nats/aio/client.py
Expand Up @@ -64,7 +64,7 @@
)
from .transport import TcpTransport, Transport, WebSocketTransport

__version__ = '2.4.0'
__version__ = '2.5.0'
__lang__ = 'python3'
_logger = logging.getLogger(__name__)
PROTOCOL = 1
Expand Down Expand Up @@ -305,6 +305,7 @@ async def connect(
no_echo: bool = False,
tls: Optional[ssl.SSLContext] = None,
tls_hostname: Optional[str] = None,
tls_handshake_first: bool = False,
user: Optional[str] = None,
password: Optional[str] = None,
token: Optional[str] = None,
Expand Down Expand Up @@ -448,6 +449,7 @@ async def subscribe_handler(msg):
self.options["token"] = token
self.options["connect_timeout"] = connect_timeout
self.options["drain_timeout"] = drain_timeout
self.options['tls_handshake_first'] = tls_handshake_first

if tls:
self.options['tls'] = tls
Expand Down Expand Up @@ -1886,6 +1888,24 @@ async def _process_connect_init(self) -> None:
assert self._current_server, "must be called only from Client.connect"
self._status = Client.CONNECTING

# Check whether to reuse the original hostname for an implicit route.
hostname = None
if "tls_hostname" in self.options:
hostname = self.options["tls_hostname"]
elif self._current_server.tls_name is not None:
hostname = self._current_server.tls_name
else:
hostname = self._current_server.uri.hostname

handshake_first = self.options['tls_handshake_first']
if handshake_first:
await self._transport.connect_tls(
hostname,
self.ssl_context,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
)

connection_completed = self._transport.readline()
info_line = await asyncio.wait_for(
connection_completed, self.options["connect_timeout"]
Expand Down Expand Up @@ -1921,24 +1941,16 @@ async def _process_connect_init(self) -> None:

if 'tls_required' in self._server_info and self._server_info[
'tls_required'] and self._current_server.uri.scheme != "ws":
# Check whether to reuse the original hostname for an implicit route.
hostname = None
if "tls_hostname" in self.options:
hostname = self.options["tls_hostname"]
elif self._current_server.tls_name is not None:
hostname = self._current_server.tls_name
else:
hostname = self._current_server.uri.hostname

await self._transport.drain() # just in case something is left

# connect to transport via tls
await self._transport.connect_tls(
hostname,
self.ssl_context,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
)
if not handshake_first:
await self._transport.drain() # just in case something is left

# connect to transport via tls
await self._transport.connect_tls(
hostname,
self.ssl_context,
DEFAULT_BUFFER_SIZE,
self.options['connect_timeout'],
)

# Refresh state of parser upon reconnect.
if self.is_reconnecting:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -4,7 +4,7 @@
# These are here for GitHub's dependency graph and help with setuptools support in some environments.
setup(
name="nats-py",
version='2.4.0',
version='2.5.0',
license='Apache 2 License',
extras_require={
'nkeys': ['nkeys'],
Expand Down
8 changes: 8 additions & 0 deletions tests/conf/tls_handshake_first.conf
@@ -0,0 +1,8 @@
port: 4224
tls {
cert_file: "./tests/certs/server-cert.pem"
key_file: "./tests/certs/server-key.pem"
ca_file: "./tests/certs/ca.pem"
handshake_first: true
verify: true
}
106 changes: 106 additions & 0 deletions tests/test_client.py
Expand Up @@ -2,6 +2,7 @@
import http.client
import json
import ssl
import os
import time
import unittest
import urllib
Expand All @@ -21,6 +22,7 @@
MultiTLSServerAuthTestCase,
SingleServerTestCase,
TLSServerTestCase,
TLSServerHandshakeFirstTestCase,
NoAuthUserServerTestCase,
async_test,
)
Expand Down Expand Up @@ -1797,6 +1799,110 @@ async def worker_handler(msg):
self.assertEqual(1, err_count)


class ClientTLSHandshakeFirstTest(TLSServerHandshakeFirstTestCase):

@async_test
async def test_connect(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = await nats.connect(
'nats://127.0.0.1:4224',
tls=self.ssl_ctx,
tls_handshake_first=True
)
self.assertEqual(nc._server_info['max_payload'], nc.max_payload)
self.assertTrue(nc._server_info['tls_required'])
self.assertTrue(nc._server_info['tls_verify'])
self.assertTrue(nc.max_payload > 0)
self.assertTrue(nc.is_connected)
await nc.close()
self.assertTrue(nc.is_closed)
self.assertFalse(nc.is_connected)

@async_test
async def test_default_connect_using_tls_scheme(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()

# Will attempt to connect using TLS with default certs.
with self.assertRaises(ssl.SSLError):
await nc.connect(
servers=['tls://127.0.0.1:4224'],
allow_reconnect=False,
tls_handshake_first=True,
)

@async_test
async def test_default_connect_using_tls_scheme_in_url(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()

# Will attempt to connect using TLS with default certs.
with self.assertRaises(ssl.SSLError):
await nc.connect(
'tls://127.0.0.1:4224',
allow_reconnect=False,
tls_handshake_first=True
)

@async_test
async def test_connect_tls_with_custom_hostname(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()

# Will attempt to connect using TLS with an invalid hostname.
with self.assertRaises(ssl.SSLError):
await nc.connect(
servers=['nats://127.0.0.1:4224'],
tls=self.ssl_ctx,
tls_hostname="nats.example",
tls_handshake_first=True,
allow_reconnect=False,
)

@async_test
async def test_subscribe(self):
if os.environ.get('NATS_SERVER_VERSION') != 'main':
pytest.skip("test requires nats-server@main")

nc = NATS()
msgs = []

async def subscription_handler(msg):
msgs.append(msg)

payload = b'hello world'
await nc.connect(
servers=['nats://127.0.0.1:4224'],
tls=self.ssl_ctx,
tls_handshake_first=True
)
sub = await nc.subscribe("foo", cb=subscription_handler)
await nc.publish("foo", payload)
await nc.publish("bar", payload)

with self.assertRaises(nats.errors.BadSubjectError):
await nc.publish("", b'')

# Wait a bit for message to be received.
await asyncio.sleep(0.2)

self.assertEqual(1, len(msgs))
msg = msgs[0]
self.assertEqual('foo', msg.subject)
self.assertEqual('', msg.reply)
self.assertEqual(payload, msg.data)
self.assertEqual(1, sub._received)
await nc.close()


class ClusterDiscoveryTest(ClusteringTestCase):

@async_test
Expand Down
27 changes: 27 additions & 0 deletions tests/utils.py
Expand Up @@ -253,6 +253,33 @@ def tearDown(self):
self.loop.close()


class TLSServerHandshakeFirstTestCase(unittest.TestCase):

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()

self.natsd = NATSD(
port=4224,
config_file=get_config_file('conf/tls_handshake_first.conf')
)
start_natsd(self.natsd)

self.ssl_ctx = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH
)
# self.ssl_ctx.protocol = ssl.PROTOCOL_TLSv1_2
self.ssl_ctx.load_verify_locations(get_config_file('certs/ca.pem'))
self.ssl_ctx.load_cert_chain(
certfile=get_config_file('certs/client-cert.pem'),
keyfile=get_config_file('certs/client-key.pem')
)

def tearDown(self):
self.natsd.stop()
self.loop.close()


class MultiTLSServerAuthTestCase(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 12fe022

Please sign in to comment.