Skip to content

Commit

Permalink
Update the WebSocketTransport to detect upgrade to TLS (#443)
Browse files Browse the repository at this point in the history
Update the WebSocketTransport to detect the attempt to upgrade the connection to TLS.

If the connection is already using TLS then continue. Otherwise raise a ProtocolError.

Fixes #442
  • Loading branch information
allanbank committed May 5, 2023
1 parent f7bdf29 commit 990d30f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
10 changes: 10 additions & 0 deletions nats/aio/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
except ImportError:
aiohttp = None # type: ignore[assignment]

from nats.errors import ProtocolError


class Transport(abc.ABC):

Expand Down Expand Up @@ -197,6 +199,7 @@ def __init__(self):
self._client: aiohttp.ClientSession = aiohttp.ClientSession()
self._pending = asyncio.Queue()
self._close_task = asyncio.Future()
self._using_tls: bool | None = None

async def connect(
self, uri: ParseResult, buffer_size: int, connect_timeout: int
Expand All @@ -205,6 +208,7 @@ async def connect(
self._ws = await self._client.ws_connect(
uri.geturl(), timeout=connect_timeout
)
self._using_tls = False

async def connect_tls(
self,
Expand All @@ -213,11 +217,17 @@ async def connect_tls(
buffer_size: int,
connect_timeout: int,
):
if self._ws and not self._ws.closed:
if self._using_tls:
return
raise ProtocolError("ws: cannot upgrade to TLS")

self._ws = await self._client.ws_connect(
uri if isinstance(uri, str) else uri.geturl(),
ssl=ssl_context,
timeout=connect_timeout
)
self._using_tls = True

def write(self, payload):
self._pending.put_nowait(payload)
Expand Down
4 changes: 3 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ def setUp(self):
)

server = NATSD(
port=4222, config_file=get_config_file("conf/ws_tls.conf")
port=4222,
tls=True,
config_file=get_config_file("conf/ws_tls.conf")
)
self.server_pool.append(server)
for natsd in self.server_pool:
Expand Down

0 comments on commit 990d30f

Please sign in to comment.