From 0f322c02576989442e5eca44e26c3dca2914b943 Mon Sep 17 00:00:00 2001 From: Ian Good Date: Thu, 20 May 2021 22:12:20 -0400 Subject: [PATCH] Add DNSBL result to extension TLV --- proxyprotocol/__init__.py | 4 +- proxyprotocol/detect.py | 6 +- proxyprotocol/dnsbl.py | 146 +++++++++++++++++++++++++++++++ proxyprotocol/noop.py | 3 +- proxyprotocol/server/echo.py | 4 + proxyprotocol/server/main.py | 10 ++- proxyprotocol/server/protocol.py | 40 +++++---- proxyprotocol/sock.py | 19 +++- proxyprotocol/tlv.py | 39 +++++++-- proxyprotocol/v1.py | 3 +- proxyprotocol/v2.py | 23 ++--- setup.py | 2 +- test/test_dnsbl.py | 95 ++++++++++++++++++++ test/test_sock.py | 13 +++ test/test_tlv.py | 10 ++- test/test_v2.py | 11 ++- 16 files changed, 378 insertions(+), 50 deletions(-) create mode 100644 proxyprotocol/dnsbl.py create mode 100644 test/test_dnsbl.py diff --git a/proxyprotocol/__init__.py b/proxyprotocol/__init__.py index b92c6ab..5740ecd 100644 --- a/proxyprotocol/__init__.py +++ b/proxyprotocol/__init__.py @@ -147,7 +147,8 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily, protocol: Optional[SocketKind] = None, ssl: Union[None, SSLObject, SSLSocket] = None, unique_id: Optional[bytes] = None, - proxied: bool = True) -> bytes: + proxied: bool = True, + dnsbl: Optional[str] = None) -> bytes: """Builds a PROXY protocol v1 header that may be sent at the beginning of an outbound, client-side connection to indicate the original information about the connection. @@ -160,6 +161,7 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily, ssl: The original socket SSL information. unique_id: The original connection unique identifier. proxied: True if the connection should be considered proxied. + dnsbl: The DNSBL lookup result, if any. Raises: :exc:`KeyError`: This PROXY protocol header format does not support diff --git a/proxyprotocol/detect.py b/proxyprotocol/detect.py index 262f8c8..866688e 100644 --- a/proxyprotocol/detect.py +++ b/proxyprotocol/detect.py @@ -54,12 +54,14 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily, protocol: Optional[SocketKind] = None, ssl: Union[None, SSLSocket, SSLObject] = None, unique_id: Optional[bytes] = None, - proxied: bool = True) -> bytes: + proxied: bool = True, + dnsbl: Optional[str] = None) -> bytes: for version in self.versions: try: return version.build(source, dest, family=family, protocol=protocol, ssl=ssl, - unique_id=unique_id, proxied=proxied) + unique_id=unique_id, proxied=proxied, + dnsbl=dnsbl) except (KeyError, ValueError): pass else: diff --git a/proxyprotocol/dnsbl.py b/proxyprotocol/dnsbl.py new file mode 100644 index 0000000..08a4a05 --- /dev/null +++ b/proxyprotocol/dnsbl.py @@ -0,0 +1,146 @@ + +from __future__ import annotations + +import asyncio +from abc import abstractmethod, ABCMeta +from asyncio import AbstractEventLoop, TimeoutError +from ipaddress import IPv4Address, IPv4Network +from socket import AF_INET, SOCK_STREAM +from typing import Optional, Sequence +from typing_extensions import Final + +from .sock import SocketInfo + +__all__ = ['Dnsbl', 'NoopDnsbl', 'BasicDnsbl', 'SpamhausDnsbl'] + + +class Dnsbl(metaclass=ABCMeta): + """Manages the optional lookup of the connecting IP address against a + trusted `DNSBL + `_. + + """ + + __slots__: Sequence[str] = [] + + @abstractmethod + async def lookup(self, sock_info: SocketInfo, *, + loop: Optional[AbstractEventLoop] = None) \ + -> Optional[str]: + """Looks up the connecting IP address and returns the DNSBL hostname + and the lookup result. Any timeout or misconfiguration is treated as an + empty result. + + Args: + sock_info: The connection socket info. + + """ + ... + + @classmethod + def load(cls, host: Optional[str], *, + timeout: Optional[float] = None) -> Dnsbl: + """Given a DNSBL hostname, returns a :class:`Dnsbl` implementation that + best suits the given *host*. + + Args: + host: The DNSBL hostname, if any. + timeout: The time to wait for a response, in seconds, or None for + indefinite. + + Raises: + ValueError: The *host* is invalid for this :class:`Dnsbl`. + + """ + if host is None: + return NoopDnsbl() + elif host.endswith('.spamhaus.org'): + return SpamhausDnsbl(host, timeout=timeout) + else: + return BasicDnsbl(host, timeout=timeout) + + +class NoopDnsbl(Dnsbl): + """Disables DNSBL lookup altogether, :meth:`.lookup` always returns + ``None``. + + """ + + __slots__: Sequence[str] = [] + + async def lookup(self, sock_info: SocketInfo, *, + loop: Optional[AbstractEventLoop] = None) -> None: + return None + + +class BasicDnsbl(Dnsbl): + """A basic :class:`Dnsbl` implementation that simply returns the DNSBL + hostname if the DNS lookup returns any IP addresses. + + """ + + __slots__ = ['host', 'timeout'] + + def __init__(self, host: str, timeout: Optional[float]) -> None: + super().__init__() + self.host: Final = host + self.timeout: Final = timeout + + def map_results(self, addresses: Sequence[IPv4Address]) -> Optional[str]: + """Given a list of IP address results from a DNSBL lookup, return a + single string categorizing the results or ``None`` to discard them. + + Args: + addresses: The list of IP address results. + + """ + if addresses: + result = self.host + assert result is not None + return result + else: + return None + + async def lookup(self, sock_info: SocketInfo, *, + loop: Optional[AbstractEventLoop] = None) \ + -> Optional[str]: + host = self.host + peername_ip = sock_info.peername_ip + if not isinstance(peername_ip, IPv4Address): + return self.map_results([]) + loop = loop or asyncio.get_running_loop() + lookup = '.'.join(peername_ip.reverse_pointer.split('.')[0:4] + [host]) + try: + addrinfo = await asyncio.wait_for( + loop.getaddrinfo(lookup, 0, family=AF_INET, type=SOCK_STREAM), + self.timeout) + except (OSError, TimeoutError): + pass + else: + if addrinfo: + addresses = [IPv4Address(res[4][0]) for res in addrinfo] + return self.map_results(addresses) + return self.map_results([]) + + +class SpamhausDnsbl(BasicDnsbl): + """A :class:`Dnsbl` designed for querying `Spamhaus + `_ DNSBLs. + + """ + + __slots__: Sequence[str] = [] + + _mapping = [(IPv4Network('127.0.0.2/32'), 'sbl.spamhaus.org'), + (IPv4Network('127.0.0.3/32'), 'css.spamhaus.org'), + (IPv4Network('127.0.0.4/30'), 'xbl.spamhaus.org'), + (IPv4Network('127.0.0.10/31'), 'pbl.spamhaus.org')] + + def map_results(self, addresses: Sequence[IPv4Address]) -> Optional[str]: + if not addresses: + return None + result = addresses[0] + for network, host in self._mapping: + if result in network: + return host + return self.host diff --git a/proxyprotocol/noop.py b/proxyprotocol/noop.py index 94fb177..c21d1fb 100644 --- a/proxyprotocol/noop.py +++ b/proxyprotocol/noop.py @@ -31,5 +31,6 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily, protocol: Optional[SocketKind] = None, ssl: Union[None, SSLSocket, SSLObject] = None, unique_id: Optional[bytes] = None, - proxied: bool = True) -> bytes: + proxied: bool = True, + dnsbl: Optional[str] = None) -> bytes: return b'' diff --git a/proxyprotocol/server/echo.py b/proxyprotocol/server/echo.py index bb69fb1..599a34a 100644 --- a/proxyprotocol/server/echo.py +++ b/proxyprotocol/server/echo.py @@ -61,6 +61,10 @@ async def run_conn(pp: ProxyProtocol, reader: StreamReader, sock_info = SocketInfo(writer, result) _log.info('[%s] Connection received: %s', sock_info.unique_id.hex(), sock_info) + if sock_info.dnsbl is not None: + _log.error('[%s] Connection rejected: %s', + sock_info.unique_id.hex(), sock_info.dnsbl) + return try: while True: line = await reader.readline() diff --git a/proxyprotocol/server/main.py b/proxyprotocol/server/main.py index 6acd0da..ef934e4 100644 --- a/proxyprotocol/server/main.py +++ b/proxyprotocol/server/main.py @@ -9,13 +9,15 @@ import logging import signal import sys -from argparse import Namespace, ArgumentParser, ArgumentDefaultsHelpFormatter +from argparse import Namespace, ArgumentParser, \ + ArgumentDefaultsHelpFormatter, SUPPRESS from asyncio import CancelledError from contextlib import AsyncExitStack from functools import partial from . import Address from .protocol import DownstreamProtocol, UpstreamProtocol +from ..dnsbl import Dnsbl __all__ = ['main'] @@ -31,6 +33,9 @@ def main() -> int: help='size of the read buffer') parser.add_argument('-q', '--quiet', action='store_true', help='show only upstream connection errors') + parser.add_argument('--dnsbl', metavar='HOST', default=None, + help='the DNSBL lookup hostname') + parser.add_argument('--dnsbl-timeout', type=float, help=SUPPRESS) args = parser.parse_args() if not args.services: @@ -48,8 +53,9 @@ async def run(args: Namespace) -> int: services = [(Address(source, server=True), Address(dest)) for (source, dest) in args.services] buf_len: int = args.buf_len + dnsbl = Dnsbl.load(args.dnsbl, timeout=args.dnsbl_timeout) new_server = partial(DownstreamProtocol, UpstreamProtocol, - loop, buf_len) + loop, buf_len, dnsbl) servers = [ await loop.create_server(partial(new_server, dest), source.host, source.port or 0, diff --git a/proxyprotocol/server/protocol.py b/proxyprotocol/server/protocol.py index 79c8b2b..3512fbb 100644 --- a/proxyprotocol/server/protocol.py +++ b/proxyprotocol/server/protocol.py @@ -9,13 +9,14 @@ from collections import deque from functools import partial from socket import AddressFamily, SocketKind -from typing import Any, Type, Optional, Tuple, Deque +from typing import Any, Type, Optional, Tuple, Deque, Set from typing_extensions import Final from uuid import uuid4 +from . import Address from .. import ProxyProtocol +from ..dnsbl import Dnsbl from ..sock import SocketInfo -from . import Address __all__ = ['DownstreamProtocol', 'UpstreamProtocol'] @@ -93,32 +94,31 @@ def proxy_data(self, data: memoryview) -> None: class DownstreamProtocol(_Base): def __init__(self, upstream_protocol: Type[UpstreamProtocol], - loop: AbstractEventLoop, buf_len: int, + loop: AbstractEventLoop, buf_len: int, dnsbl: Dnsbl, upstream: Address) -> None: super().__init__(buf_len) self.loop: Final = loop + self.dnsbl: Final = dnsbl self.upstream: Final = upstream self.id: Final = uuid4().bytes self._waiting: Deque[memoryview] = deque() - self._connect: Optional[Task[Any]] = None + self._tasks: Set[Task[Any]] = set() self._upstream: Optional[UpstreamProtocol] = None self._upstream_factory = partial(upstream_protocol, self, buf_len, upstream.pp) def close(self) -> None: super().close() - if self._connect is not None: - self._connect.cancel() - self._connect = None + for task in self._tasks: + task.cancel() if self._upstream is not None: upstream = self._upstream self._upstream = None upstream.close() - def _set_client(self, connect: Task[_Connect]) -> None: - self._connect = None + def _set_client(self, connect_task: Task[_Connect]) -> None: try: - _, upstream = connect.result() + _, upstream = connect_task.result() except CancelledError: pass # Connection was never established except OSError: @@ -138,11 +138,16 @@ def connection_made(self, transport: BaseTransport) -> None: _log.info('[%s] Downstream connection received: %s', self.id.hex(), self.sock_info) loop = self.loop - self._connect = connect = loop.create_task( - loop.create_connection(self._upstream_factory, + dnsbl_task = loop.create_task(self.dnsbl.lookup(self.sock_info)) + self._tasks.add(dnsbl_task) + dnsbl_task.add_done_callback(self._tasks.discard) + connect_task = loop.create_task( + loop.create_connection(partial(self._upstream_factory, dnsbl_task), self.upstream.host, self.upstream.port or 0, ssl=self.upstream.ssl)) - connect.add_done_callback(self._set_client) + self._tasks.add(connect_task) + connect_task.add_done_callback(self._tasks.discard) + connect_task.add_done_callback(self._set_client) def connection_lost(self, exc: Optional[Exception]) -> None: super().connection_lost(exc) @@ -159,10 +164,11 @@ def proxy_data(self, data: memoryview) -> None: class UpstreamProtocol(_Base): def __init__(self, downstream: DownstreamProtocol, buf_len: int, - pp: ProxyProtocol) -> None: + pp: ProxyProtocol, dnsbl_task: Task[Optional[str]]) -> None: super().__init__(buf_len) self.pp: Final = pp self.downstream: Final = downstream + self.dnsbl_task: Final = dnsbl_task def close(self) -> None: super().close() @@ -176,13 +182,17 @@ def build_pp_header(self) -> bytes: protocol: Optional[SocketKind] = SocketKind(sock.proto) except ValueError: protocol = None + dnsbl = self.dnsbl_task.result() return self.pp.build(sock.getpeername(), sock.getsockname(), family=AddressFamily(sock.family), protocol=protocol, unique_id=self.downstream.id, - ssl=ssl_object) + ssl=ssl_object, dnsbl=dnsbl) def connection_made(self, transport: BaseTransport) -> None: super().connection_made(transport) + self.dnsbl_task.add_done_callback(self._write_header) + + def _write_header(self, task: Task[Any]) -> None: header = self.build_pp_header() self.write(memoryview(header)) diff --git a/proxyprotocol/sock.py b/proxyprotocol/sock.py index c6e1eb5..154c1b5 100644 --- a/proxyprotocol/sock.py +++ b/proxyprotocol/sock.py @@ -24,18 +24,20 @@ class SocketInfo: result: The PROXY protocol result. unique_id: A unique ID to associate with the connection, unless overridden by the PROXY protocol result. + dnsbl: The DNSBL lookup result, if any. """ - __slots__ = ['transport', 'pp_result', '_unique_id'] + __slots__ = ['transport', 'pp_result', '_unique_id', '_dnsbl'] def __init__(self, transport: TransportProtocol, result: Optional[ProxyProtocolResult] = None, *, - unique_id: bytes = b'') -> None: + unique_id: bytes = b'', dnsbl: Optional[str] = None) -> None: super().__init__() self.transport: Final = transport self.pp_result: Final = result or ProxyProtocolResultLocal() self._unique_id = unique_id + self._dnsbl = dnsbl @property def socket(self) -> socket.socket: @@ -266,6 +268,19 @@ def from_localhost(self) -> bool: return False return ip.is_loopback + @property + def dnsbl(self) -> Optional[str]: + """The DNSBL lookup result of the connecting IP address, if any. + + This value is contextual to the DNSBL in use, but generally any value + here other than ``None`` indicates the IP address should be blocked. + + """ + if self.pp_result.proxied: + return self.pp_result.tlv.ext.dnsbl + else: + return self._dnsbl + def __str__(self) -> str: proxied = ' proxied=True' if self.pp_result.proxied else '' return '' \ diff --git a/proxyprotocol/tlv.py b/proxyprotocol/tlv.py index 439f6b7..94a7d8d 100644 --- a/proxyprotocol/tlv.py +++ b/proxyprotocol/tlv.py @@ -40,6 +40,7 @@ class Type(IntEnum): PP2_SUBTYPE_EXT_COMPRESSION = 0x01 PP2_SUBTYPE_EXT_SECRET_BITS = 0x02 PP2_SUBTYPE_EXT_PEERCERT = 0x03 + PP2_SUBTYPE_EXT_DNSBL = 0x04 class SSLClient(IntFlag): @@ -65,10 +66,11 @@ class TLV(Mapping[int, bytes], Hashable): _fmt = Struct('!BH') def __init__(self, data: bytes = b'', - init: Mapping[int, bytes] = {}) -> None: + init: Optional[Mapping[int, bytes]] = None) -> None: super().__init__() self._tlv = self._unpack(data) - self._tlv.update(init) + if init is not None: + self._tlv.update(init) self._frozen = self._freeze() def _freeze(self) -> Hashable: @@ -137,7 +139,8 @@ class ProxyProtocolTLV(TLV): _crc32c_fmt = Struct('!L') - def __init__(self, data: bytes = b'', init: Mapping[int, bytes] = {}, *, + def __init__(self, data: bytes = b'', + init: Optional[Mapping[int, bytes]] = None, *, alpn: Optional[bytes] = None, authority: Optional[str] = None, crc32c: Optional[int] = None, @@ -145,7 +148,7 @@ def __init__(self, data: bytes = b'', init: Mapping[int, bytes] = {}, *, ssl: Optional[ProxyProtocolSSLTLV] = None, netns: Optional[str] = None, ext: Optional[ProxyProtocolExtTLV] = None) -> None: - results = dict(init) + results = dict(init or {}) if alpn is not None: results[Type.PP2_TYPE_ALPN] = alpn if authority is not None: @@ -233,7 +236,8 @@ class ProxyProtocolSSLTLV(TLV): _prefix_fmt = Struct('!BL') - def __init__(self, data: bytes = b'', init: Mapping[int, bytes] = {}, *, + def __init__(self, data: bytes = b'', + init: Optional[Mapping[int, bytes]] = None, *, has_ssl: Optional[bool] = None, has_cert_conn: Optional[bool] = None, has_cert_sess: Optional[bool] = None, @@ -245,7 +249,7 @@ def __init__(self, data: bytes = b'', init: Mapping[int, bytes] = {}, *, key_alg: Optional[str] = None) -> None: self._client = 0 self._verify = 1 - results = dict(init) + results = dict(init or {}) if version is not None: results[Type.PP2_SUBTYPE_SSL_VERSION] = version.encode('ascii') if cn is not None: @@ -386,11 +390,13 @@ class ProxyProtocolExtTLV(TLV): _secret_bits_fmt = Struct('!H') - def __init__(self, data: bytes = b'', init: Mapping[int, bytes] = {}, *, + def __init__(self, data: bytes = b'', + init: Optional[Mapping[int, bytes]] = None, *, compression: Optional[str] = None, secret_bits: Optional[int] = None, - peercert: Optional[PeerCert] = None) -> None: - results = dict(init) + peercert: Optional[PeerCert] = None, + dnsbl: Optional[str] = None) -> None: + results = dict(init or {}) if compression is not None: val = compression.encode('ascii') results[Type.PP2_SUBTYPE_EXT_COMPRESSION] = val @@ -400,6 +406,9 @@ def __init__(self, data: bytes = b'', init: Mapping[int, bytes] = {}, *, if peercert is not None: val = zlib.compress(json.dumps(peercert).encode('ascii')) results[Type.PP2_SUBTYPE_EXT_PEERCERT] = val + if dnsbl is not None: + val = dnsbl.encode('utf-8') + results[Type.PP2_SUBTYPE_EXT_DNSBL] = val super().__init__(data, results) def _unpack(self, data: bytes) -> Dict[int, bytes]: @@ -449,3 +458,15 @@ def peercert(self) -> Optional[PeerCert]: ret: PeerCert = json.loads(decompressed) return ret return None + + @property + def dnsbl(self) -> Optional[str]: + """The ``PP2_SUBTYPE_EXT_DNSBL`` value. This is the hostname or other + identifier that reports a status or reputation of the connecting IP + address. + + """ + val = self.get(Type.PP2_SUBTYPE_EXT_DNSBL) + if val is not None: + return str(val, 'utf-8') + return None diff --git a/proxyprotocol/v1.py b/proxyprotocol/v1.py index fd6cddf..2a22753 100644 --- a/proxyprotocol/v1.py +++ b/proxyprotocol/v1.py @@ -68,7 +68,8 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily, protocol: Optional[SocketKind] = None, ssl: Union[None, SSLSocket, SSLObject] = None, unique_id: Optional[bytes] = None, - proxied: bool = True) -> bytes: + proxied: bool = True, + dnsbl: Optional[str] = None) -> bytes: if not proxied: raise ValueError('proxied must be True in v1') family_b = self._build_family(family) diff --git a/proxyprotocol/v2.py b/proxyprotocol/v2.py index 2311078..0a419dc 100644 --- a/proxyprotocol/v2.py +++ b/proxyprotocol/v2.py @@ -140,9 +140,10 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily, protocol: Optional[SocketKind] = None, ssl: Union[None, SSLSocket, SSLObject] = None, unique_id: Optional[bytes] = None, - proxied: bool = True) -> bytes: + proxied: bool = True, + dnsbl: Optional[str] = None) -> bytes: addresses = self.build_addresses(source, dest, family=family) - tlv = self.build_tlv(ssl, unique_id) + tlv = self.build_tlv(ssl, unique_id, dnsbl) data_len = len(addresses) + len(tlv) header = self.build_header(data_len, family=family, protocol=protocol, proxied=proxied) @@ -205,27 +206,29 @@ def build_addresses(self, source: Address, dest: Address, *, return b'' def build_tlv(self, ssl: Union[None, SSLSocket, SSLObject], - unique_id: Optional[bytes]) -> bytes: + unique_id: Optional[bytes], dnsbl: Optional[str]) -> bytes: """Builds the TLV data written after the PROXY protocol v2 address data. Args: ssl: The SSL information for the connection. unique_id: The unique ID of the connection. + dnsbl: The DNSBL lookup result, if any. """ + ssl_tlv: Optional[ProxyProtocolSSLTLV] = None + ext_tlv: Optional[ProxyProtocolExtTLV] = None + if dnsbl is not None: + ext_tlv = ProxyProtocolExtTLV(init=ext_tlv, dnsbl=dnsbl) if ssl is not None: cipher, version, secret_bits = ssl.cipher() or (None, None, None) peercert: Optional[PeerCert] = ssl.getpeercert() - ssl_tlv: Optional[ProxyProtocolSSLTLV] = ProxyProtocolSSLTLV( + ssl_tlv = ProxyProtocolSSLTLV( has_ssl=True, verify=True, has_cert_conn=(peercert is not None), cipher=cipher, version=version) - ext_tlv: Optional[ProxyProtocolExtTLV] = ProxyProtocolExtTLV( - compression=ssl.compression(), - secret_bits=secret_bits, peercert=peercert) - else: - ssl_tlv = None - ext_tlv = None + ext_tlv = ProxyProtocolExtTLV( + init=ext_tlv, compression=ssl.compression(), + secret_bits=secret_bits, peercert=peercert, dnsbl=dnsbl) tlv = ProxyProtocolTLV(unique_id=unique_id, ssl=ssl_tlv, ext=ext_tlv) return bytes(tlv) diff --git a/setup.py b/setup.py index 0c6828f..37067c3 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license = f.read() setup(name='proxy-protocol', - version='0.6.1', + version='0.7.0', author='Ian Good', author_email='ian@icgood.net', description='PROXY protocol library with asyncio server implementation', diff --git a/test/test_dnsbl.py b/test/test_dnsbl.py new file mode 100644 index 0000000..bd990e8 --- /dev/null +++ b/test/test_dnsbl.py @@ -0,0 +1,95 @@ + +from asyncio import AbstractEventLoop +from ipaddress import IPv4Address, IPv6Address +from unittest.mock import MagicMock + +try: + from unittest import IsolatedAsyncioTestCase + from unittest.mock import AsyncMock +except ImportError as exc: # Python < 3.8 + from unittest import SkipTest + raise SkipTest('Missing unittest asyncio imports') from exc + +from proxyprotocol.dnsbl import Dnsbl, NoopDnsbl, BasicDnsbl, SpamhausDnsbl +from proxyprotocol.sock import SocketInfo + + +class TestDnsbl(IsolatedAsyncioTestCase): + + def test_load(self) -> None: + dnsbl = Dnsbl.load('test.spamhaus.org', timeout=1.3) + self.assertIsInstance(dnsbl, SpamhausDnsbl) + dnsbl = Dnsbl.load('test.example.com', timeout=1.3) + self.assertIsInstance(dnsbl, BasicDnsbl) + dnsbl = Dnsbl.load(None, timeout=1.3) + self.assertIsInstance(dnsbl, NoopDnsbl) + + async def test_noop_lookup(self) -> None: + dnsbl = NoopDnsbl() + sock_info = MagicMock(SocketInfo) + result = await dnsbl.lookup(sock_info) + self.assertIsNone(result) + + async def test_basic_lookup_ipv6(self) -> None: + dnsbl = BasicDnsbl('test.example.com', None) + sock_info = MagicMock(SocketInfo) + sock_info.peername_ip = IPv6Address('::1') + result = await dnsbl.lookup(sock_info) + self.assertIsNone(result) + + async def test_basic_lookup_oserror(self) -> None: + dnsbl = BasicDnsbl('test.example.com', None) + sock_info = MagicMock(SocketInfo) + sock_info.peername_ip = IPv4Address('1.2.3.4') + loop = MagicMock(AbstractEventLoop) + loop.getaddrinfo = AsyncMock(side_effect=OSError) + result = await dnsbl.lookup(sock_info, loop=loop) + self.assertIsNone(result) + + async def test_basic_lookup_empty(self) -> None: + dnsbl = BasicDnsbl('test.example.com', None) + sock_info = MagicMock(SocketInfo) + sock_info.peername_ip = IPv4Address('1.2.3.4') + loop = MagicMock(AbstractEventLoop) + loop.getaddrinfo = AsyncMock(return_value=[]) + result = await dnsbl.lookup(sock_info, loop=loop) + self.assertIsNone(result) + + async def test_basic_lookup(self) -> None: + dnsbl = BasicDnsbl('test.example.com', None) + sock_info = MagicMock(SocketInfo) + sock_info.peername_ip = IPv4Address('1.2.3.4') + loop = MagicMock(AbstractEventLoop) + loop.getaddrinfo = AsyncMock(return_value=[ + (None, None, None, None, ('0.0.0.0', 0))]) + result = await dnsbl.lookup(sock_info, loop=loop) + self.assertEqual('test.example.com', result) + + async def test_spamhaus_lookup_empty(self) -> None: + dnsbl = SpamhausDnsbl('test.spamhaus.org', None) + sock_info = MagicMock(SocketInfo) + sock_info.peername_ip = IPv4Address('1.2.3.4') + loop = MagicMock(AbstractEventLoop) + loop.getaddrinfo = AsyncMock(return_value=[]) + result = await dnsbl.lookup(sock_info, loop=loop) + self.assertIsNone(result) + + async def test_spamhaus_lookup(self) -> None: + dnsbl = SpamhausDnsbl('test.spamhaus.org', None) + sock_info = MagicMock(SocketInfo) + sock_info.peername_ip = IPv4Address('1.2.3.4') + loop = MagicMock(AbstractEventLoop) + loop.getaddrinfo = AsyncMock(return_value=[ + (None, None, None, None, ('127.0.0.4', 0))]) + result = await dnsbl.lookup(sock_info, loop=loop) + self.assertEqual('xbl.spamhaus.org', result) + + async def test_spamhaus_lookup_unmapped(self) -> None: + dnsbl = SpamhausDnsbl('test.spamhaus.org', None) + sock_info = MagicMock(SocketInfo) + sock_info.peername_ip = IPv4Address('1.2.3.4') + loop = MagicMock(AbstractEventLoop) + loop.getaddrinfo = AsyncMock(return_value=[ + (None, None, None, None, ('127.0.0.100', 0))]) + result = await dnsbl.lookup(sock_info, loop=loop) + self.assertEqual('test.spamhaus.org', result) diff --git a/test/test_sock.py b/test/test_sock.py index 547e1c6..38ad633 100644 --- a/test/test_sock.py +++ b/test/test_sock.py @@ -167,6 +167,19 @@ def test_peercert_override(self) -> None: info = SocketInfo(self.transport, result) self.assertEqual({'subject': 'test'}, info.peercert) + def test_dnsbl_socket(self) -> None: + result = ProxyProtocolResultLocal() + info = SocketInfo(self.transport, result, dnsbl='abc') + self.assertEqual('abc', info.dnsbl) + + def test_dnsbl_override(self) -> None: + tlv = ProxyProtocolTLV(ext=ProxyProtocolExtTLV(dnsbl='test_dnsbl')) + result = ProxyProtocolResultIPv6((IPv6Address('::1'), 10), + (IPv6Address('::FFFF:1.2.3.4'), 20), + tlv=tlv) + info = SocketInfo(self.transport, result, dnsbl='abc') + self.assertEqual('test_dnsbl', info.dnsbl) + def test_unique_id_socket(self) -> None: result = ProxyProtocolResultLocal() info = SocketInfo(self.transport, result, unique_id=b'abc') diff --git a/test/test_tlv.py b/test/test_tlv.py index f95a9d6..953de6a 100644 --- a/test/test_tlv.py +++ b/test/test_tlv.py @@ -19,7 +19,8 @@ ext_data = ProxyProtocolExtTLV.MAGIC_PREFIX + \ pack('!BH', Type.PP2_SUBTYPE_EXT_COMPRESSION, 16) + b'test_compression' + \ pack('!BHH', Type.PP2_SUBTYPE_EXT_SECRET_BITS, 2, 2048) + \ - pack('!BH', Type.PP2_SUBTYPE_EXT_PEERCERT, len(peercert)) + peercert + pack('!BH', Type.PP2_SUBTYPE_EXT_PEERCERT, len(peercert)) + peercert + \ + pack('!BH', Type.PP2_SUBTYPE_EXT_DNSBL, 10) + b'test_dnsbl' tlv_data = \ pack('!BH', Type.PP2_TYPE_ALPN, 5) + b'test1' + \ pack('!BH', Type.PP2_TYPE_AUTHORITY, 7) + b'test\xe2\x91\xa1' + \ @@ -116,6 +117,10 @@ def test_ext_peercert(self) -> None: self.assertIsNone(self.empty.ext.peercert) self.assertEqual({'test': 'peercert'}, self.tlv.ext.peercert) + def test_ext_dnsbl(self) -> None: + self.assertIsNone(self.empty.ext.dnsbl) + self.assertEqual('test_dnsbl', self.tlv.ext.dnsbl) + def test_iter(self) -> None: self.assertEqual({Type.PP2_TYPE_ALPN, Type.PP2_TYPE_AUTHORITY, Type.PP2_TYPE_CRC32C, Type.PP2_TYPE_NOOP, @@ -157,7 +162,8 @@ def test_kwargs(self) -> None: key_alg='test_key_alg') ext_tlv = ProxyProtocolExtTLV(compression='test_compression', secret_bits=2048, - peercert={'test': 'peercert'}) + peercert={'test': 'peercert'}, + dnsbl='test_dnsbl') custom_type = Type.PP2_TYPE_MIN_CUSTOM + 2 unique_id = b'\x00\x00\x00\x12W\xbb\x1d3\x00\x00\x00\x009\xe9\xdbv' init_tlv = TLV(init={custom_type: memoryview(b'test4')}) diff --git a/test/test_v2.py b/test/test_v2.py index 872c865..83206af 100644 --- a/test/test_v2.py +++ b/test/test_v2.py @@ -181,10 +181,13 @@ def test_build_tlv(self) -> None: ssl_object.cipher.return_value = ('cipher_name', 'ssl_version', 123) ssl_object.getpeercert.return_value = None header = pp.build(None, None, family=socket.AF_UNSPEC, - ssl=ssl_object, unique_id=b'connection_id') - self.assertEqual(b'\r\n\r\n\x00\r\nQUIT\n!\x00\x00W' - b'\x04\x00 \x88\x1by\xc1\xce\x96\x85\xb0\x01\x00\x10' - b'compression_name\x02\x00\x02\x00{\x05\x00\r' + ssl=ssl_object, unique_id=b'connection_id', + dnsbl='dnsbl_result') + print(repr(header)) + self.assertEqual(b'\r\n\r\n\x00\r\nQUIT\n!\x00\x00f' + b'\x04\x00/\x88\x1by\xc1\xce\x96\x85\xb0\x01\x00\x10' + b'compression_name\x02\x00\x02\x00{\x04\x00\x0c' + b'dnsbl_result\x05\x00\r' b'connection_id \x00!\x01\x00\x00\x00\x01!\x00\x0b' b'ssl_version#\x00\x0bcipher_name', header)