Skip to content

Commit

Permalink
refactor: use Annotated to handle struct packing
Browse files Browse the repository at this point in the history
  • Loading branch information
leoslf committed Jun 9, 2024
1 parent 8cbf208 commit 3508f7a
Show file tree
Hide file tree
Showing 5 changed files with 625 additions and 211 deletions.
209 changes: 137 additions & 72 deletions socks_router/models.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,62 @@
from __future__ import annotations
from typing import Literal, Optional
from subprocess import Popen

from typing import Any, Annotated, Final, Literal, Optional, Type, Self, Protocol, runtime_checkable, overload, assert_never
from abc import abstractmethod
from collections.abc import Callable, Mapping, MutableMapping
from collections.abc import Mapping, MutableMapping
from enum import IntEnum, StrEnum, auto
from dataclasses import dataclass, field
from ipaddress import IPv4Address, IPv6Address
from subprocess import Popen

import struct
import threading
import ipaddress

SOCKS_VERSION: Literal[5] = 5

type PackingSequence = str | tuple[str, str]

type RecursiveMapping[K, V] = Mapping[K, V | RecursiveMapping[K, V]]

PACKABLE_DEFERRED_FORMAT: Final[str] = "&"
PACKABLE_VARIABLE_LENGTH_DECLARATION_FORMAT: Final[str] = "%*"


@runtime_checkable
class Packable(Protocol):
@classmethod
def __pack_format__(cls) -> str: ...


@runtime_checkable
class SupportsUnbytes(Protocol):
@classmethod
@abstractmethod
def __unbytes__(cls, input: bytes) -> Self: ...


# class BytePackable(metaclass=ABCMeta):
# @classmethod
# def __pack_format__(cls) -> str:
# return "!B"

# class ByteEnum(IntEnum):
# def __new__(cls, value: int) -> Self:
# if not 0 <= value < 256:
# raise ValueError("value must be 0 <= value < 256")
# return super().__new__(cls, value)
#
# @classmethod
# def __pack_format__(cls) -> str:
# return "!B"


@dataclass(frozen=True)
class SocketAddress:
address: str
port: Optional[int] = None
address: Any = field()
port: Annotated[Optional[int], "!H"] = None

def __str__(self):
if self.port is None:
return self.address
return f"{self.address}"
return f"{self.address}:{self.port}"

@property
Expand All @@ -29,44 +65,51 @@ def pattern(self) -> str:

@property
def sockaddr(self) -> tuple[str, int]:
return self.address, self.port or 0
return f"{self.address}", self.port or 0

@property
@abstractmethod
def type(self) -> Socks5AddressType: ...

@property
def packed_type(self) -> bytes:
return struct.pack("!B", self.type)
@dataclass(frozen=True)
class IPv4(SocketAddress):
address: IPv4.IPv4Address

@property
@abstractmethod
def packed_address(self) -> bytes: ...
class IPv4Address(ipaddress.IPv4Address):
@classmethod
def __pack_format__(cls) -> str:
return "!4B"

@property
def packed_port(self) -> bytes:
return struct.pack("!H", self.port or 0)
def __bytes__(self) -> bytes:
return self.packed

@classmethod
def __unbytes__(cls, input: bytes) -> Self:
return cls(input)

def __bytes__(self) -> bytes:
return self.packed_type + self.packed_address + self.packed_port
def __init__(self, address: str | IPv4.IPv4Address, *argv, **kwargs):
if isinstance(address, str):
address = IPv4.IPv4Address(address)
super().__init__(address, *argv, **kwargs)


@dataclass(frozen=True)
class IPv4(SocketAddress):
@property
def type(self) -> Socks5AddressType:
return Socks5AddressType.IPv4
class IPv6(SocketAddress):
address: IPv6.IPv6Address

@property
def packed_address(self) -> bytes:
return IPv4Address(self.address).packed
class IPv6Address(ipaddress.IPv6Address):
@classmethod
def __pack_format__(cls) -> str:
return "!6B"

def __bytes__(self) -> bytes:
return self.packed

@dataclass(frozen=True)
class IPv6(SocketAddress):
@property
def type(self) -> Socks5AddressType:
return Socks5AddressType.IPv6
@classmethod
def __unbytes__(cls, input: bytes) -> Self:
return cls(input)

def __init__(self, address: str | IPv6.IPv6Address, *argv, **kwargs):
if isinstance(address, str):
address = IPv6.IPv6Address(address)
super().__init__(address, *argv, **kwargs)

def __str__(self):
if self.port is None:
Expand All @@ -77,22 +120,10 @@ def __str__(self):
def pattern(self):
return f"[{self.address}]:{self.port or '*'}"

@property
def packed_address(self) -> bytes:
return IPv6Address(self.address).packed


@dataclass(frozen=True)
class Host(SocketAddress):
@property
def type(self) -> Socks5AddressType:
return Socks5AddressType.DOMAINNAME

@property
def packed_address(self) -> bytes:
encoded = self.address.encode("utf-8")
assert len(encoded) < 255, "can only carry less than 255 bytes for host"
return struct.pack("!B", len(encoded)) + encoded
address: Annotated[str, "!B%*s"]


type Address = IPv4 | IPv6 | Host
Expand All @@ -108,6 +139,10 @@ class Socks5Method(IntEnum):
# RESERVED_FOR_PRIVATE_METHODS = frozenset(range(0x80, 0xFF)) # 0x80..0xFE
NO_ACCEPTABLE_METHODS = 0xFF

@classmethod
def __pack_format__(cls) -> str:
return "!B"


class Socks5Command(IntEnum):
"""SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4"""
Expand All @@ -116,6 +151,10 @@ class Socks5Command(IntEnum):
BIND = 0x02
UDP_ASSOCIATE = 0x03

@classmethod
def __pack_format__(cls) -> str:
return "!B"


class Socks5AddressType(IntEnum):
"""SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4"""
Expand All @@ -124,6 +163,10 @@ class Socks5AddressType(IntEnum):
DOMAINNAME = 0x03
IPv6 = 0x04

@classmethod
def __pack_format__(cls) -> str:
return "!B"


@dataclass
class Socks5MethodSelectionRequest:
Expand All @@ -135,12 +178,8 @@ class Socks5MethodSelectionRequest:
| 1 byte | 1 byte | [method_count] bytes |
"""

version: int
methods: list[int]

def __bytes__(self) -> bytes:
assert len(self.methods) < 256
return struct.pack("!BB", self.version, len(self.methods)) + bytes(self.methods)
version: Annotated[int, "!B"]
methods: Annotated[list[int], "!B%*B"]


@dataclass
Expand All @@ -152,27 +191,46 @@ class Socks5MethodSelectionResponse:
| 1 byte | 1 byte |
"""

version: int
version: Annotated[int, "!B"]
method: Socks5Method

def __bytes__(self) -> bytes:
return struct.pack("!BB", self.version, self.method)

@dataclass(frozen=True)
class Socks5Address:
@overload
def address_type(cls, type: Literal[Socks5AddressType.IPv4]) -> Type[IPv4]: ...
@overload
def address_type(cls, type: Literal[Socks5AddressType.DOMAINNAME]) -> Type[Host]: ...
@overload
def address_type(cls, type: Literal[Socks5AddressType.IPv6]) -> Type[IPv6]: ...

def address_type(cls, type: Socks5AddressType) -> Type[IPv4] | Type[Host] | Type[IPv6]:
match type:
case Socks5AddressType.IPv4:
return IPv4
case Socks5AddressType.DOMAINNAME:
return Host
case Socks5AddressType.IPv6:
return IPv6
case _ as unreachable:
assert_never(unreachable)

type: Socks5AddressType
sockaddr: Annotated[Address, "&", "type", "address_type"]


@dataclass(frozen=True)
class Socks5Request:
"""SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4"""

version: int
version: Annotated[int, "!B"]
command: Socks5Command
reserved: Literal[0x00]
address_type: Socks5AddressType
destination_address: str
destination_port: int
reserved: Annotated[int, "!B"]
destination: Socks5Address

@property
def destination(self):
return Socks5Addresses[self.address_type](self.destination_address, self.destination_port)
# @property
# def destination(self):
# return Socks5Addresses[self.address_type](self.destination_address, self.destination_port)


class Socks5ReplyType(IntEnum):
Expand All @@ -189,6 +247,10 @@ class Socks5ReplyType(IntEnum):
ADDRESS_TYPE_NOT_SUPPORTED = 0x08
# UNASSIGNED = frozenset(range(0x09, 0x100)) # 0x09..0xFF

@classmethod
def __pack_format__(cls) -> str:
return "!B"


@dataclass
class Socks5Reply:
Expand All @@ -199,13 +261,10 @@ class Socks5Reply:
| 1 byte | 1 byte | 0x00 | 1 byte | 4-255 bytes | 2 bytes |
"""

version: Literal[5]
version: Annotated[int, "!B"]
reply: Socks5ReplyType
reserved: Literal[0] = 0x00
address: Address = IPv4("0.0.0.0", 0)

def __bytes__(self) -> bytes:
return struct.pack("!BBB", self.version, self.reply, self.reserved) + bytes(self.address)
reserved: Annotated[int, "!B"] = 0x00
destination: Socks5Address = Socks5Address(Socks5AddressType.IPv4, IPv4("0.0.0.0", 0))


class Socks5State(StrEnum):
Expand Down Expand Up @@ -243,12 +302,18 @@ def __str__(self):

type RoutingTable = Mapping[UpstreamAddress, RoutingEntry]

Socks5Addresses: Mapping[Socks5AddressType, Callable[[str, Optional[int]], Address]] = {
Socks5Addresses: Mapping[Socks5AddressType, type[Address]] = {
Socks5AddressType.IPv4: IPv4,
Socks5AddressType.IPv6: IPv6,
Socks5AddressType.DOMAINNAME: Host,
}

Socks5AddressTypes = {
IPv4: Socks5AddressType.IPv4,
IPv6: Socks5AddressType.IPv6,
Host: Socks5AddressType.DOMAINNAME,
}


@dataclass
class SSHUpstream:
Expand Down
Loading

0 comments on commit 3508f7a

Please sign in to comment.