Skip to content

Commit

Permalink
Merge pull request #345 from dvonthenen/fix-issue-344
Browse files Browse the repository at this point in the history
Fix Starting KeepAlive Always, Switch for Exceptions
  • Loading branch information
dvonthenen committed Mar 20, 2024
2 parents 2d0ddf1 + 163f46d commit 4fa26b3
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 199 deletions.
195 changes: 132 additions & 63 deletions deepgram/clients/live/v1/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(self, config: DeepgramClientOptions):
self.config = config
self.endpoint = "v1/listen"
self._socket = None
self._exit_event = None
self._event_handlers = {event: [] for event in LiveTranscriptionEvents}
self.websocket_url = convert_to_websocket_url(self.config.url, self.endpoint)
self.exit_event = None

# starts the WebSocket connection for live transcription
async def start(
Expand All @@ -61,6 +61,9 @@ async def start(
members: Optional[Dict] = None,
**kwargs,
) -> bool:
"""
Starts the WebSocket connection for live transcription.
"""
self.logger.debug("AsyncLiveClient.start ENTER")
self.logger.info("options: %s", options)
self.logger.info("addons: %s", addons)
Expand Down Expand Up @@ -102,13 +105,25 @@ async def start(
self.logger.debug("combined_options: %s", combined_options)

url_with_params = append_query_params(self.websocket_url, combined_options)
self.exit_event = asyncio.Event()

try:
self._socket = await _socket_connect(url_with_params, self.config.headers)
self._socket = await websockets.connect(
url_with_params,
extra_headers=self.config.headers,
ping_interval=PING_INTERVAL,
)
self._exit_event = asyncio.Event()

# listen thread
self._listen_thread = asyncio.create_task(self._listening())
self._keep_alive_thread = asyncio.create_task(self._keep_alive())

# keepalive thread
if self.config.options.get("keepalive") == "true":
self.logger.notice("keepalive is disabled")
self._keep_alive_thread = asyncio.create_task(self._keep_alive())
else:
self.logger.notice("keepalive is disabled")
self._keep_alive_thread = None

# push open event
await self._emit(
Expand All @@ -120,12 +135,30 @@ async def start(
self.logger.debug("AsyncLiveClient.start LEAVE")
return True
except websockets.ConnectionClosed as e:
await self._emit(LiveTranscriptionEvents.Close, e.code)
self.logger.notice("exception: websockets.ConnectionClosed")
self.logger.error("exception: websockets.ConnectionClosed")
self.logger.debug("AsyncLiveClient.start LEAVE")
if self.config.options.get("termination_exception_connect") == "true":
raise
return False
except websockets.exceptions.WebSocketException as e:
self.logger.error("WebSocketException in AsyncLiveClient.start: %s", e)
self.logger.debug("AsyncLiveClient.start LEAVE")
if self.config.options.get("termination_exception_connect") == "true":
raise
return False
except Exception as e:
self.logger.error("WebSocketException in AsyncLiveClient.start: %s", e)
self.logger.debug("AsyncLiveClient.start LEAVE")
if self.config.options.get("termination_exception_connect") == "true":
raise
return False

# registers event handlers for specific events
def on(self, event: LiveTranscriptionEvents, handler) -> None:
"""
Registers event handlers for specific events.
"""
self.logger.info("event fired: %s", event)
if event in LiveTranscriptionEvents and callable(handler):
self._event_handlers[event].append(handler)

Expand All @@ -140,7 +173,7 @@ async def _listening(self) -> None:

while True:
try:
if self.exit_event.is_set():
if self._exit_event.is_set():
self.logger.notice("_listening exiting gracefully")
self.logger.debug("AsyncLiveClient._listening LEAVE")
return
Expand Down Expand Up @@ -218,6 +251,11 @@ async def _listening(self) -> None:
**dict(self.kwargs),
)
case _:
self.logger.warning(
"Unknown Message: response_type: %s, data: %s",
response_type,
data,
)
error = ErrorResponse(
type="UnhandledMessage",
description="Unknown message type",
Expand All @@ -231,6 +269,9 @@ async def _listening(self) -> None:
return

except websockets.exceptions.WebSocketException as e:
self.logger.error(
"WebSocketException in AsyncLiveClient._listening: %s", e
)
error: ErrorResponse = {
"type": "Exception",
"description": "WebSocketException in AsyncLiveClient._listening",
Expand All @@ -242,15 +283,17 @@ async def _listening(self) -> None:
)
await self._emit(LiveTranscriptionEvents.Error, error)

# signal exit and close
await self._signal_exit()

self.logger.debug("AsyncLiveClient._listening LEAVE")

if (
"termination_exception" in self.options
and self.options["termination_exception"] == "true"
):
if self.config.options.get("termination_exception") == "true":
raise
return

except Exception as e:
self.logger.error("Exception in AsyncLiveClient._listening: %s", e)
error: ErrorResponse = {
"type": "Exception",
"description": "Exception in AsyncLiveClient._listening",
Expand All @@ -260,13 +303,14 @@ async def _listening(self) -> None:
self.logger.error("Exception in AsyncLiveClient._listening: %s", str(e))
await self._emit(LiveTranscriptionEvents.Error, error)

# signal exit and close
await self._signal_exit()

self.logger.debug("AsyncLiveClient._listening LEAVE")

if (
"termination_exception" in self.options
and self.options["termination_exception"] == "true"
):
if self.config.options.get("termination_exception") == "true":
raise
return

# keep the connection alive by sending keepalive messages
async def _keep_alive(self) -> None:
Expand All @@ -278,21 +322,18 @@ async def _keep_alive(self) -> None:
counter += 1
await asyncio.sleep(ONE_SECOND)

if self.exit_event.is_set():
if self._exit_event.is_set():
self.logger.notice("_keep_alive exiting gracefully")
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
return

if self._socket is None:
self.logger.notice("socket is None, exiting keep_alive")
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
break
return

# deepgram keepalive
if (
counter % DEEPGRAM_INTERVAL == 0
and self.config.options.get("keepalive") == "true"
):
if counter % DEEPGRAM_INTERVAL == 0:
self.logger.verbose("Sending KeepAlive...")
await self.send(json.dumps({"type": "KeepAlive"}))

Expand All @@ -302,6 +343,9 @@ async def _keep_alive(self) -> None:
return

except websockets.exceptions.WebSocketException as e:
self.logger.error(
"WebSocketException in AsyncLiveClient._keep_alive: %s", e
)
error: ErrorResponse = {
"type": "Exception",
"description": "WebSocketException in AsyncLiveClient._keep_alive",
Expand All @@ -313,16 +357,17 @@ async def _keep_alive(self) -> None:
)
await self._emit(LiveTranscriptionEvents.Error, error)

# signal exit and close
await self._signal_exit()

self.logger.debug("AsyncLiveClient._keep_alive LEAVE")

if (
"termination_exception" in self.options
and self.options["termination_exception"] == "true"
):
if self.config.options.get("termination_exception") == "true":
raise
return

except Exception as e:
self.logger.error("Exception in AsyncLiveClient._keep_alive: %s", e)
error: ErrorResponse = {
"type": "Exception",
"description": "Exception in _keep_alive",
Expand All @@ -334,54 +379,99 @@ async def _keep_alive(self) -> None:
)
await self._emit(LiveTranscriptionEvents.Error, error)

# signal exit and close
await self._signal_exit()

self.logger.debug("AsyncLiveClient._keep_alive LEAVE")

if (
"termination_exception" in self.options
and self.options["termination_exception"] == "true"
):
if self.config.options.get("termination_exception") == "true":
raise
return

self.logger.debug("AsyncLiveClient._keep_alive LEAVE")

# sends data over the WebSocket connection
async def send(self, data: Union[str, bytes]) -> bool:
"""
Sends data over the WebSocket connection.
"""
self.logger.spam("AsyncLiveClient.send ENTER")

if self._exit_event.is_set():
self.logger.notice("send exiting gracefully")
self.logger.debug("AsyncLiveClient.send LEAVE")
return False

if self._socket is not None:
try:
await self._socket.send(data)
except websockets.exceptions.ConnectionClosedOK as e:
self.logger.notice(f"send() exiting gracefully: {e.code}")
self.logger.debug("AsyncLiveClient._keep_alive LEAVE")
if self.config.options.get("termination_exception_send") == "true":
raise
return True
except websockets.exceptions.WebSocketException as e:
self.logger.error("send() failed - WebSocketException: %s", str(e))
self.logger.spam("AsyncLiveClient.send LEAVE")
if self.config.options.get("termination_exception_send") == "true":
raise
return False
except Exception as e:
self.logger.error("send() failed - Exception: %s", str(e))
self.logger.spam("AsyncLiveClient.send LEAVE")
if self.config.options.get("termination_exception_send") == "true":
raise
return False

self.logger.spam(f"send() succeeded")
self.logger.spam("AsyncLiveClient.send LEAVE")
return True

self.logger.error("send() failed. socket is None")
self.logger.spam("send() failed. socket is None")
self.logger.spam("AsyncLiveClient.send LEAVE")
return False

# closes the WebSocket connection gracefully
async def finish(self) -> bool:
"""
Closes the WebSocket connection gracefully.
"""
self.logger.debug("AsyncLiveClient.finish ENTER")

# signal exit
self.exit_event.set()
await self._signal_exit()

# close the stream
# stop the threads
self.logger.verbose("cancelling tasks...")
try:
# Before cancelling, check if the tasks were created
tasks = []
if self._keep_alive_thread is not None:
self._keep_alive_thread.cancel()
tasks.append(self._keep_alive_thread)
if self._listen_thread is not None:
self._listen_thread.cancel()
tasks.append(self._listen_thread)

# Use asyncio.gather to wait for tasks to be cancelled
await asyncio.gather(*filter(None, tasks), return_exceptions=True)
self.logger.notice("threads joined")
self._listen_thread = None
self._keep_alive_thread = None

self._socket = None

self.logger.notice("finish succeeded")
self.logger.spam("AsyncLiveClient.finish LEAVE")
return True

except asyncio.CancelledError as e:
self.logger.error("tasks cancelled error: %s", e)
self.logger.debug("AsyncLiveClient.finish LEAVE")
return False

# signals the WebSocket connection to exit
async def _signal_exit(self) -> None:
# send close event
self.logger.verbose("closing socket...")
if self._socket is not None:
self.logger.verbose("send CloseStream...")
Expand All @@ -395,36 +485,15 @@ async def finish(self) -> bool:
CloseResponse(type=LiveTranscriptionEvents.Close.value),
)

# signal exit
self._exit_event.set()

# closes the WebSocket connection gracefully
self.logger.verbose("clean up socket...")
if self._socket is not None:
self.logger.verbose("socket.wait_closed...")
try:
await self._socket.wait_closed()
await self._socket.close()
self._socket = None
except websockets.exceptions.WebSocketException as e:
self.logger.error("socket.wait_closed failed: %s", e)
self._socket = None

self.logger.verbose("cancelling tasks...")
try:
# Before cancelling, check if the tasks were created
if self._listen_thread is not None:
self._listen_thread.cancel()
if self._keep_alive_thread is not None:
self._keep_alive_thread.cancel()

# Use asyncio.gather to wait for tasks to be cancelled
tasks = [self._listen_thread, self._keep_alive_thread]
await asyncio.gather(*filter(None, tasks), return_exceptions=True)

except asyncio.CancelledError as e:
self.logger.error("tasks cancelled error: %s", e)

self.logger.info("finish succeeded")
self.logger.debug("AsyncLiveClient.finish LEAVE")
return True


async def _socket_connect(websocket_url, headers) -> websockets.WebSocketClientProtocol:
destination = websocket_url
updated_headers = headers
return await websockets.connect(
destination, extra_headers=updated_headers, ping_interval=PING_INTERVAL
)
Loading

0 comments on commit 4fa26b3

Please sign in to comment.