Skip to content

Commit

Permalink
Update the WebSocketTransport to detect the attempt to upgrade the co…
Browse files Browse the repository at this point in the history
…nnection to TLS.

If the connection is already using TLS then continue. Otherwise raise a ProtocolError.

Fixes #442
  • Loading branch information
allanbank committed Apr 6, 2023
1 parent 0ecfb22 commit 9ea5d76
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions nats/aio/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
except ImportError:
aiohttp = None # type: ignore[assignment]

from nats.errors import ProtocolError

class Transport(abc.ABC):

Expand Down Expand Up @@ -197,6 +198,7 @@ def __init__(self):
self._client: aiohttp.ClientSession = aiohttp.ClientSession()
self._pending = asyncio.Queue()
self._close_task = asyncio.Future()
self._using_tls: bool | None = None

async def connect(
self, uri: ParseResult, buffer_size: int, connect_timeout: int
Expand All @@ -205,6 +207,7 @@ async def connect(
self._ws = await self._client.ws_connect(
uri.geturl(), timeout=connect_timeout
)
self._using_tls = False

async def connect_tls(
self,
Expand All @@ -213,11 +216,17 @@ async def connect_tls(
buffer_size: int,
connect_timeout: int,
):
if self._ws and not self._ws.closed:
if self._using_tls:
return
raise ProtocolError("ws: cannot upgrade to TLS")

self._ws = await self._client.ws_connect(
uri if isinstance(uri, str) else uri.geturl(),
ssl=ssl_context,
timeout=connect_timeout
)
self._using_tls = True

def write(self, payload):
self._pending.put_nowait(payload)
Expand Down

0 comments on commit 9ea5d76

Please sign in to comment.