diff --git a/factorio_rcon/_factorio_rcon.py b/factorio_rcon/_factorio_rcon.py index 5820b5b..9f6545b 100644 --- a/factorio_rcon/_factorio_rcon.py +++ b/factorio_rcon/_factorio_rcon.py @@ -3,6 +3,7 @@ import functools import socket import struct +from types import TracebackType from typing import Any, Callable, Dict, NamedTuple, Optional, TypeVar, cast try: @@ -403,6 +404,20 @@ def send_commands(self, commands: Dict[T, str]) -> Dict[T, Optional[str]]: results[id_map[response.id]] = response.body.rstrip() return results + @handle_socket_errors(alive_socket_required=False) + def __enter__(self) -> "RCONClient": + if self.rcon_socket is None or self.rcon_failure: + self.connect() + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() + class AsyncRCONClient(RCONSharedBase): """Asynchronous RCON client for Factorio servers @@ -617,6 +632,20 @@ async def send_commands(self, commands: Dict[T, str]) -> Dict[T, Optional[str]]: results[id_map[response.id]] = response.body.rstrip() return results + @async_handle_socket_errors(alive_socket_required=False) + async def __aenter__(self) -> "AsyncRCONClient": + if self.rcon_socket is None or self.rcon_failure: + await self.connect() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + await self.close() + INVALID_PASS = "The RCON password is incorrect" INVALID_ID = (