diff --git a/.changeset/curly-pumpkins-kick.md b/.changeset/curly-pumpkins-kick.md new file mode 100644 index 00000000..f8b47d56 --- /dev/null +++ b/.changeset/curly-pumpkins-kick.md @@ -0,0 +1,5 @@ +--- +'@e2b/code-interpreter-template': patch +--- + +Add retry diff --git a/.changeset/wicked-mirrors-punch.md b/.changeset/wicked-mirrors-punch.md new file mode 100644 index 00000000..2050d45d --- /dev/null +++ b/.changeset/wicked-mirrors-punch.md @@ -0,0 +1,5 @@ +--- +'@e2b/code-interpreter-python': patch +--- + +Fix issue with secure False diff --git a/python/e2b_code_interpreter/code_interpreter_async.py b/python/e2b_code_interpreter/code_interpreter_async.py index 02c3e99a..b8129354 100644 --- a/python/e2b_code_interpreter/code_interpreter_async.py +++ b/python/e2b_code_interpreter/code_interpreter_async.py @@ -191,6 +191,10 @@ async def run_code( request_timeout = request_timeout or self.connection_config.request_timeout context_id = context.id if context else None + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + try: async with self._client.stream( "POST", @@ -201,7 +205,7 @@ async def run_code( "language": language, "env_vars": envs, }, - headers={"X-Access-Token": self._envd_access_token}, + headers=headers, timeout=(request_timeout, timeout, request_timeout, request_timeout), ) as response: err = await aextract_exception(response) @@ -249,10 +253,14 @@ async def create_code_context( if cwd: data["cwd"] = cwd + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + try: response = await self._client.post( f"{self._jupyter_url}/contexts", - headers={"X-Access-Token": self._envd_access_token}, + headers=headers, json=data, timeout=request_timeout or self.connection_config.request_timeout, ) diff --git a/python/e2b_code_interpreter/code_interpreter_sync.py b/python/e2b_code_interpreter/code_interpreter_sync.py index 
978c6dc5..6cf56c11 100644 --- a/python/e2b_code_interpreter/code_interpreter_sync.py +++ b/python/e2b_code_interpreter/code_interpreter_sync.py @@ -188,6 +188,10 @@ def run_code( request_timeout = request_timeout or self.connection_config.request_timeout context_id = context.id if context else None + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + try: with self._client.stream( "POST", @@ -198,7 +202,7 @@ def run_code( "language": language, "env_vars": envs, }, - headers={"X-Access-Token": self._envd_access_token}, + headers=headers, timeout=(request_timeout, timeout, request_timeout, request_timeout), ) as response: err = extract_exception(response) @@ -246,11 +250,15 @@ def create_code_context( if cwd: data["cwd"] = cwd + headers: Dict[str, str] = {} + if self._envd_access_token: + headers = {"X-Access-Token": self._envd_access_token} + try: response = self._client.post( f"{self._jupyter_url}/contexts", json=data, - headers={"X-Access-Token": self._envd_access_token}, + headers=headers, timeout=request_timeout or self.connection_config.request_timeout, ) diff --git a/template/server/messaging.py b/template/server/messaging.py index e541351f..0b151f8e 100644 --- a/template/server/messaging.py +++ b/template/server/messaging.py @@ -12,6 +12,10 @@ ) from pydantic import StrictStr from websockets.client import WebSocketClientProtocol, connect +from websockets.exceptions import ( + ConnectionClosedError, + WebSocketException, +) from api.models.error import Error from api.models.logs import Stdout, Stderr @@ -27,6 +31,9 @@ logger = logging.getLogger(__name__) +MAX_RECONNECT_RETRIES = 3 +PING_TIMEOUT = 30 + class Execution: def __init__(self, in_background: bool = False): @@ -61,6 +68,15 @@ def __init__(self, context_id: str, session_id: str, language: str, cwd: str): self._executions: Dict[str, Execution] = {} self._lock = asyncio.Lock() + async def reconnect(self): + if self._ws is not None: + await 
self._ws.close(reason="Reconnecting") + + if self._receive_task is not None: + await self._receive_task + + await self.connect() + async def connect(self): logger.debug(f"WebSocket connecting to {self.url}") @@ -69,6 +85,7 @@ async def connect(self): self._ws = await connect( self.url, + ping_timeout=PING_TIMEOUT, max_size=None, max_queue=None, logger=ws_logger, @@ -274,9 +291,6 @@ async def execute( env_vars: Dict[StrictStr, str], access_token: str, ): - message_id = str(uuid.uuid4()) - self._executions[message_id] = Execution() - if self._ws is None: raise Exception("WebSocket not connected") @@ -313,13 +327,40 @@ async def execute( ) complete_code = f"{indented_env_code}\n{complete_code}" - logger.info( - f"Sending code for the execution ({message_id}): {complete_code}" - ) - request = self._get_execute_request(message_id, complete_code, False) + message_id = str(uuid.uuid4()) + execution = Execution() + self._executions[message_id] = execution # Send the code for execution - await self._ws.send(request) + # Initial request and retries + for i in range(1 + MAX_RECONNECT_RETRIES): + try: + logger.info( + f"Sending code for the execution ({message_id}): {complete_code}" + ) + request = self._get_execute_request( + message_id, complete_code, False + ) + await self._ws.send(request) + break + except (ConnectionClosedError, WebSocketException) as e: + # Keep the last result, even if error + if i < MAX_RECONNECT_RETRIES: + logger.warning( + f"WebSocket connection lost while sending execution request, {i + 1}. 
reconnecting...: {str(e)}" + ) + await self.reconnect() + else: + # The retry didn't help, request wasn't sent successfully + logger.error("Failed to send execution request") + await execution.queue.put( + Error( + name="WebSocketError", + value="Failed to send execution request", + traceback="", + ) + ) + await execution.queue.put(UnexpectedEndOfExecution()) # Stream the results async for item in self._wait_for_result(message_id): @@ -343,6 +384,18 @@ async def _receive_message(self): await self._process_message(json.loads(message)) except Exception as e: logger.error(f"WebSocket received error while receiving messages: {str(e)}") + finally: + # To prevent infinite hang, we need to cancel all ongoing executions as we could lose results during the reconnect + # Thanks to the locking, there can be either no ongoing execution or just one. + for key, execution in self._executions.items(): + await execution.queue.put( + Error( + name="WebSocketError", + value="The connections was lost, rerun the code to get the results", + traceback="", + ) + ) + await execution.queue.put(UnexpectedEndOfExecution()) async def _process_message(self, data: dict): """