Skip to content

Commit

Permalink
Add missing types and format function. Switch to f-strings for logging
Browse files Browse the repository at this point in the history
  • Loading branch information
fkantelberg committed May 4, 2023
1 parent fd6a990 commit 5f0b1bd
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/socket_proxy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def parse_args(args: Tuple[str] = None) -> None:
def run_client(no_curses: bool) -> None:
for arg in ["ca", "connect", "dst"]:
if not getattr(base.config, arg, False):
_logger.critical("Missing --%s argument", arg)
_logger.critical(f"Missing --{arg} argument")
sys.exit(1)

cls = TunnelClient if no_curses else GUIClient
Expand All @@ -430,7 +430,7 @@ def run_client(no_curses: bool) -> None:
def run_server() -> None:
for arg in ["cert", "key"]:
if not getattr(base.config, arg, False):
_logger.critical("Missing --%s argument", arg)
_logger.critical(f"Missing --{arg} argument")
sys.exit(1)

server = ProxyServer(
Expand Down
2 changes: 1 addition & 1 deletion src/socket_proxy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def start_api(self) -> None:
extras = sorted(filter(None, extras))
extras = f"[{','.join(extras)}]" if extras else ""

_logger.info("Starting API on %s:%s %s", self.api_host, self.api_port, extras)
_logger.info(f"Starting API on {self.api_host}:{self.api_port} {extras}")
self.api = web.Application()
self.api.add_routes(
[
Expand Down
2 changes: 1 addition & 1 deletion src/socket_proxy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ProtocolType(enum.IntEnum):
TCP = 0x01
HTTP = 0x02

def __str__(self):
def __str__(self) -> str:
return {
ProtocolType.TCP: "TCP",
ProtocolType.HTTP: "HTTP",
Expand Down
7 changes: 2 additions & 5 deletions src/socket_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,8 @@ def generate_token(self, hotp: bool = False) -> str:

token = str(uuid.uuid4())
self.tokens[token] = None if hotp else datetime.now()
_logger.info(
"Generated authentication token %s [%s]",
token,
"hotp" if hotp else "totp",
)
ttype = "hotp" if hotp else "totp"
_logger.info(f"Generated authentication token {token} [{ttype}]")
self.event.send_nowait(msg="token_generate", token=token, hotp=bool(hotp))
self._persist_state()
return token
Expand Down
8 changes: 4 additions & 4 deletions src/socket_proxy/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def token(self) -> str:
def uuid(self) -> str:
return self.tunnel.uuid

def get_config_dict(self):
def get_config_dict(self) -> dict:
"""Return the configuration as a dictionary used in the API or client"""
return {
"bantime": self.bantime or None,
Expand Down Expand Up @@ -104,10 +104,10 @@ def get_state_dict(self) -> dict:
}

def info(self, msg: str, *args) -> None:
_logger.info("Tunnel %s " + msg, self.uuid, *args)
_logger.info(f"Tunnel {self.uuid} {msg}", *args)

def error(self, msg: str, *args) -> None:
_logger.error("Tunnel %s " + msg, self.uuid, *args)
_logger.error(f"Tunnel {self.uuid} {msg}", *args)

def add(self, client: Connection) -> None:
if client.token in self.clients:
Expand Down Expand Up @@ -147,7 +147,7 @@ async def _disconnect_client(self, token: bytes) -> None:
# Store the traffic information from the disconnecting clients
self.bytes_in += client.bytes_in
self.bytes_out += client.bytes_out
_logger.info("Client %s disconnected", token.hex())
_logger.info(f"Client {token.hex()} disconnected")
await client.close()

async def idle(self) -> None:
Expand Down
12 changes: 6 additions & 6 deletions src/socket_proxy/tunnel_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def idle(self) -> None:
self.last_ping = time.time()
await self.tunnel.tun_write(package.PingPackage(self.last_ping))

def _check_alive(self):
def _check_alive(self) -> bool:
"""Check if the connection is alive using the last ping"""

if self.last_ping is None or self.last_pong is None:
Expand All @@ -76,7 +76,7 @@ def _check_alive(self):

async def _client_loop(self, client: Connection) -> None:
"""This is the main client loop"""
_logger.info("Client %s connected", client.token.hex())
_logger.info(f"Client {client.token.hex()} connected")
while True:
data = await client.read(self.chunk_size)
if not data:
Expand Down Expand Up @@ -168,7 +168,7 @@ async def _handle(self) -> bool:

# Something unexpected happened
if pkg is not None:
self.error("invalid package: %s", pkg)
self.error(f"invalid package: {pkg}")
return await super()._handle()

return await super()._handle()
Expand All @@ -190,8 +190,8 @@ async def loop(self) -> None:
self.tunnel = await Connection.connect(self.host, self.port, ssl=self.sc)
ssl_obj = self.tunnel.writer.get_extra_info("ssl_object")
extra = f" [{ssl_obj.version()}]" if ssl_obj else ""
_logger.info("Tunnel %s:%s connected%s", self.host, self.port, extra)
_logger.info("Forwarding to %s:%s", self.dst_host, self.dst_port)
_logger.info(f"Tunnel {self.host}:{self.port} connected{extra}")
_logger.info(f"Forwarding to {self.dst_host}:{self.dst_port}")

if self.api_port:
asyncio.create_task(self.start_api())
Expand All @@ -214,7 +214,7 @@ async def loop(self) -> None:
finally:
self.running = False
await self.stop()
_logger.info("Tunnel %s:%s closed", self.host, self.port)
_logger.info(f"Tunnel {self.host}:{self.port} closed")

def start(self) -> None:
"""Start the client and the event loop"""
Expand Down
12 changes: 9 additions & 3 deletions src/socket_proxy/tunnel_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from logging.handlers import QueueHandler
from typing import List

from .base import LOG_FORMAT, InternetType
from .base import LOG_FORMAT, InternetType, IPvXAddress
from .tunnel_client import TunnelClient
from .utils import format_transfer

Expand Down Expand Up @@ -33,6 +33,11 @@ def get_dimension(self) -> None:
"""Get the dimensions of the current window"""
self.height, self.width = self.scr.getmaxyx()

# pylint: disable=W0613,R0201
def fmt_port(self, ip_type: InternetType, ip: IPvXAddress, port: int) -> str:
"""Format an address"""
return f"{ip}:{port}" if ip else str(port)

def _draw(self) -> None:
"""Draw all GUI elements"""
self.scr.clear()
Expand Down Expand Up @@ -115,7 +120,8 @@ def _draw_log(self) -> None:
win.refresh()
return win

def _draw_lines(self, win, lines: List[str]) -> None: # pylint: disable=R0201
# disable: pylint=R0201
def _draw_lines(self, win: curses.window, lines: List[str]) -> None:
"""Draw multiple lines in a window with some border"""
h, w = [k - 2 for k in win.getmaxyx()]
for y, line in enumerate(lines[:h]):
Expand All @@ -127,7 +133,7 @@ async def _handle(self) -> bool:
self._draw()
return await super()._handle()

def _gui(self, scr) -> None:
def _gui(self, scr: curses.window) -> None:
"""Configure the main screen"""
self.scr = scr
curses.noecho()
Expand Down
18 changes: 9 additions & 9 deletions src/socket_proxy/tunnel_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def idle(self) -> None:
for ip, ban in list(self.connections.items()):
if ban.first < dt:
self.connections.pop(ip)
_logger.info("Connection number of %s resetted", ip)
_logger.info(f"Connection number of {ip} resetted")

async def _client_accept(
self,
Expand All @@ -76,7 +76,7 @@ async def _client_accept(
writer.close()
await writer.wait_closed()

_logger.info("Connection from %s blocked", ip)
_logger.info(f"Connection from {ip} blocked")
await self.event.send(msg="client_blocked", tunnel=self.uuid, ip=str(ip))
return

Expand All @@ -86,7 +86,7 @@ async def _client_accept(
client = Connection(reader, writer, self.protocol, utils.generate_token())
self.add(client)

_logger.info("Client %s connected on %s:%s", client.uuid, host, port)
_logger.info(f"Client {client.uuid} connected on {host}:{port}")
await self.event.send(
msg="client_connect",
tunnel=self.uuid,
Expand Down Expand Up @@ -149,7 +149,7 @@ async def _client_loop(self) -> None:

# Initialize the tunnel by sending the appropiate data
out = " ".join(sorted(f"{host}:{port}" for host, port in self.addr))
self.info("Listen on %s", out)
self.info(f"Listen on {out}")

addr = [(base.InternetType.from_ip(ip), ip, port) for ip, port in self.addr]
pkg = package.InitPackage(self.token, addr, self.domain)
Expand All @@ -168,10 +168,10 @@ async def _handle(self) -> bool:
self.error(f"Disabled protocol {self.protocol.name}")
return False

self.info("Using protocol: %s", self.protocol.name)
self.info(f"Using protocol: {self.protocol.name}")

if self.protocol != base.ProtocolType.TCP:
self.info("Reachable with domain: %s", self.domain)
self.info(f"Reachable with domain: {self.domain}")
pkg = package.InitPackage(self.token, [], self.domain)
await self.tunnel.tun_write(pkg)
elif not await self._open_server():
Expand Down Expand Up @@ -199,7 +199,7 @@ async def _handle(self) -> bool:
if isinstance(pkg, package.ClientDataPackage):
# Check for valid tokens
if pkg.token not in self:
self.error("Invalid client token: %s", pkg.token)
self.error(f"Invalid client token: {pkg.token}")
return False

conn = self[pkg.token]
Expand All @@ -208,7 +208,7 @@ async def _handle(self) -> bool:

# Invalid package means to close the connection
if pkg is not None:
self.error("Invalid package: %s", pkg)
self.error(f"Invalid package: {pkg}")
return await super()._handle()

return await super()._handle()
Expand All @@ -227,7 +227,7 @@ async def loop(self) -> None:
"""Main loop of the proxy tunnel"""
ssl_obj = self.tunnel.writer.get_extra_info("ssl_object")
extra = f" [{ssl_obj.version()}]" if ssl_obj else ""
self.info("Connected %s:%s%s", self.host, self.port, extra)
self.info(f"Connected {self.host}:{self.port}{extra}")

try:
await self._serve()
Expand Down
10 changes: 5 additions & 5 deletions src/socket_proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ def generate_ssl_context(
ctx.set_ciphers(ciphers)

# Output debugging
_logger.info("CA usage: %s", bool(ca))
_logger.info("Certificate: %s", bool(cert))
_logger.info("Hostname verification: %s", bool(check_hostname))
_logger.info(f"CA usage: {bool(ca)}")
_logger.info(f"Certificate: {bool(cert)}")
_logger.info(f"Hostname verification: {bool(check_hostname)}")
# pylint: disable=no-member
_logger.info("Minimal TLS Version: %s", ctx.minimum_version.name)
_logger.info(f"Minimal TLS Version: {ctx.minimum_version.name}")

ciphers = sorted(c["name"] for c in ctx.get_ciphers())
_logger.info("Ciphers: %s", ", ".join(ciphers))
_logger.info(f"Ciphers: {', '.join(ciphers)}")

return ctx

Expand Down

0 comments on commit 5f0b1bd

Please sign in to comment.