Skip to content

Commit

Permalink
Add additional typing hints
Browse files Browse the repository at this point in the history
  • Loading branch information
fkantelberg committed Mar 30, 2024
1 parent 8812d3c commit 332568e
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 35 deletions.
6 changes: 3 additions & 3 deletions src/socket_proxy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

try:
from aiohttp import web
from aiohttp.web import AppRunner, Request, Response, TCPSite
from aiohttp.web import Application, AppRunner, Request, Response, TCPSite
except ImportError:
web = AppRunner = Request = Response = TCPSite = None # type: ignore
web = Application = AppRunner = Request = Response = TCPSite = None # type: ignore

_logger = logging.getLogger(__name__)

Expand All @@ -21,7 +21,7 @@ class APIType(enum.IntEnum):


async def run_app(
api,
api: Application,
host: Optional[str] = None,
port: Optional[int] = None,
ssl_context: Optional[ssl.SSLContext] = None,
Expand Down
31 changes: 16 additions & 15 deletions src/socket_proxy/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ class PingPackage(Package):

TIMESTAMP = PackageStruct("!d")

def __init__(self, timestamp, *args, **kwargs):
def __init__(self, timestamp: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.time = timestamp
self.time: int = timestamp

def to_bytes(self) -> bytes:
return super().to_bytes() + self.TIMESTAMP.pack(self.time)
Expand All @@ -179,8 +179,8 @@ class AuthPackage(Package):

def __init__(self, token: str, token_type: base.AuthType, *args, **kwargs):
super().__init__(*args, **kwargs)
self.token = token
self.token_type = token_type
self.token: str = token
self.token_type: base.AuthType = token_type

def to_bytes(self) -> bytes:
return (
Expand Down Expand Up @@ -221,9 +221,9 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
self.token = token
self.addresses = addresses[:255]
self.domain = domain
self.token: bytes = token
self.addresses: base.IPvXPorts = addresses[:255]
self.domain: str = domain

def to_bytes(self) -> bytes:
data = super().to_bytes() + self.INIT.pack(self.token, len(self.addresses))
Expand Down Expand Up @@ -281,11 +281,11 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
self.bantime = bantime
self.clients = clients
self.connects = connects
self.idle_timeout = idle_timeout
self.networks = networks
self.bantime: int = bantime
self.clients: int = clients
self.connects: int = connects
self.idle_timeout: int = idle_timeout
self.networks: base.IPvXNetworks = networks

def to_bytes(self) -> bytes:
config = self.CONFIG.pack(
Expand Down Expand Up @@ -323,7 +323,7 @@ class ClientPackage(Package):

def __init__(self, token: bytes, *args, **kwargs):
super().__init__(*args, **kwargs)
self.token = token
self.token: bytes = token

def to_bytes(self) -> bytes:
return super().to_bytes() + self.TOKEN.pack(self.token)
Expand All @@ -348,7 +348,8 @@ class ClientInitPackage(ClientPackage):

def __init__(self, ip: base.IPvXAddress, port: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ip, self.port = ip, port
self.ip: base.IPvXAddress = ip
self.port: int = port

def to_bytes(self) -> bytes:
ip_type: base.InternetType = base.InternetType.from_ip(self.ip)
Expand Down Expand Up @@ -389,7 +390,7 @@ class ClientDataPackage(ClientPackage):

def __init__(self, data: bytes, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data = data
self.data: bytes = data

def to_bytes(self) -> bytes:
data = b""
Expand Down
21 changes: 11 additions & 10 deletions src/socket_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import re
import ssl
import uuid
from asyncio import StreamReader, StreamWriter
from collections import defaultdict
Expand Down Expand Up @@ -34,12 +35,13 @@ def __init__(
**kwargs: Any,
):
super().__init__(api_type=api.APIType.Server)
self.kwargs = kwargs
self.host, self.port = host, port
self.max_tunnels = base.config.max_tunnels
self.http_ssl = base.config.http_ssl
self.kwargs: Dict[str, Any] = kwargs
self.host: Union[str, List[str]] = host
self.port: int = port
self.max_tunnels: int = base.config.max_tunnels
self.http_ssl: bool = base.config.http_ssl
self.tunnels: Dict[str, TunnelServer] = {}
self.sc = utils.generate_ssl_context(
self.sc: ssl.SSLContext = utils.generate_ssl_context(
cert=cert,
key=key,
ca=ca,
Expand All @@ -48,16 +50,17 @@ def __init__(
)

# Authentication
self.authentication = authentication
self.authentication: bool = authentication
self.tokens: Dict[base.AuthType, dict] = defaultdict(dict)
self.auth_timeout = auth_timeout
self.auth_timeout: int = auth_timeout

self.event = event.EventSystem(
self.event: event.EventSystem = event.EventSystem(
event.EventType.Server,
url=base.config.hook_url,
token=base.config.hook_token,
)

self.http_domain: str = ""
self.http_host: Optional[str] = None
self.http_port: Optional[str] = None
self.http_domain_regex: Optional[re.Pattern] = None
Expand All @@ -67,8 +70,6 @@ def __init__(
self.http_domain_regex = re.compile(
rb"^(.*)\.%s$" % self.http_domain.replace(".", r"\.").encode()
)
else:
self.http_domain = ""

self._load_persisted_state()

Expand Down
4 changes: 2 additions & 2 deletions src/socket_proxy/tunnel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def idle(self) -> None:

# Send a ping regularly
self.last_ping = time.time()
await self.tunnel.tun_write(package.PingPackage(self.last_ping))
await self.tunnel.tun_write(package.PingPackage(int(self.last_ping)))

def _check_alive(self) -> bool:
"""Check if the connection is alive using the last ping"""
Expand Down Expand Up @@ -153,7 +153,7 @@ async def _handle(self) -> bool:

# Handle a ping package and reply
if isinstance(pkg, package.PingPackage):
self.last_pong = time.time()
self.last_pong = int(time.time())
return True

# A new client connected on the other side of the tunnel
Expand Down
4 changes: 2 additions & 2 deletions src/socket_proxy/tunnel_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def __init__(
self.connections: Dict[base.IPvXAddress, utils.Ban] = collections.defaultdict(
utils.Ban
)
self.protocols = protocols or utils.protocols()
self.event = event
self.protocols: List[base.ProtocolType] = protocols or utils.protocols()
self.event: EventSystem = event

def block(self, ip: base.IPvXAddress) -> bool:
"""Decide whether the ip should be blocked"""
Expand Down
6 changes: 3 additions & 3 deletions src/socket_proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
from datetime import datetime, timedelta
from random import shuffle
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union
from urllib.parse import urlsplit

from . import base
Expand Down Expand Up @@ -306,7 +306,7 @@ def parse_networks(network: str) -> base.IPvXNetworks:
raise argparse.ArgumentTypeError("Invalid network format") from e


def protocols() -> Set[base.ProtocolType]:
def protocols() -> List[base.ProtocolType]:
result = set()
for protocol in base.ProtocolType:
name = f"no-{protocol.name.lower()}".replace("-", "_")
Expand All @@ -315,7 +315,7 @@ def protocols() -> Set[base.ProtocolType]:

if not base.config.http_domain:
result.discard(base.ProtocolType.HTTP)
return result
return list(result)


def to_bool(val: Any) -> bool:
Expand Down

0 comments on commit 332568e

Please sign in to comment.