Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions factorio_rcon/_factorio_rcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import socket
import struct
from types import TracebackType
from typing import Any, Callable, Dict, NamedTuple, Optional, TypeVar, cast

try:
Expand Down Expand Up @@ -403,6 +404,20 @@ def send_commands(self, commands: Dict[T, str]) -> Dict[T, Optional[str]]:
results[id_map[response.id]] = response.body.rstrip()
return results

@handle_socket_errors(alive_socket_required=False)
def __enter__(self) -> "RCONClient":
if self.rcon_socket is None or self.rcon_failure:
self.connect()
return self

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close()


class AsyncRCONClient(RCONSharedBase):
"""Asynchronous RCON client for Factorio servers
Expand Down Expand Up @@ -617,6 +632,20 @@ async def send_commands(self, commands: Dict[T, str]) -> Dict[T, Optional[str]]:
results[id_map[response.id]] = response.body.rstrip()
return results

@async_handle_socket_errors(alive_socket_required=False)
async def __aenter__(self) -> "AsyncRCONClient":
if self.rcon_socket is None or self.rcon_failure:
await self.connect()
return self

async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.close()


INVALID_PASS = "The RCON password is incorrect"
INVALID_ID = (
Expand Down