Skip to content

Commit

Permalink
Merge pull request #25 from icgood/dnsbl
Browse files Browse the repository at this point in the history
Add DNSBL result to extension TLV
  • Loading branch information
icgood committed May 21, 2021
2 parents 7a3a7ba + 0f322c0 commit 68e6663
Show file tree
Hide file tree
Showing 16 changed files with 378 additions and 50 deletions.
4 changes: 3 additions & 1 deletion proxyprotocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily,
protocol: Optional[SocketKind] = None,
ssl: Union[None, SSLObject, SSLSocket] = None,
unique_id: Optional[bytes] = None,
proxied: bool = True) -> bytes:
proxied: bool = True,
dnsbl: Optional[str] = None) -> bytes:
"""Builds a PROXY protocol v1 header that may be sent at the beginning
of an outbound, client-side connection to indicate the original
information about the connection.
Expand All @@ -160,6 +161,7 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily,
ssl: The original socket SSL information.
unique_id: The original connection unique identifier.
proxied: True if the connection should be considered proxied.
dnsbl: The DNSBL lookup result, if any.
Raises:
:exc:`KeyError`: This PROXY protocol header format does not support
Expand Down
6 changes: 4 additions & 2 deletions proxyprotocol/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily,
protocol: Optional[SocketKind] = None,
ssl: Union[None, SSLSocket, SSLObject] = None,
unique_id: Optional[bytes] = None,
proxied: bool = True) -> bytes:
proxied: bool = True,
dnsbl: Optional[str] = None) -> bytes:
for version in self.versions:
try:
return version.build(source, dest, family=family,
protocol=protocol, ssl=ssl,
unique_id=unique_id, proxied=proxied)
unique_id=unique_id, proxied=proxied,
dnsbl=dnsbl)
except (KeyError, ValueError):
pass
else:
Expand Down
146 changes: 146 additions & 0 deletions proxyprotocol/dnsbl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@

from __future__ import annotations

import asyncio
from abc import abstractmethod, ABCMeta
from asyncio import AbstractEventLoop, TimeoutError
from ipaddress import IPv4Address, IPv4Network
from socket import AF_INET, SOCK_STREAM
from typing import Optional, Sequence
from typing_extensions import Final

from .sock import SocketInfo

__all__ = ['Dnsbl', 'NoopDnsbl', 'BasicDnsbl', 'SpamhausDnsbl']


class Dnsbl(metaclass=ABCMeta):
"""Manages the optional lookup of the connecting IP address against a
trusted `DNSBL
<https://en.wikipedia.org/wiki/Domain_Name_System-based_blackhole_list>`_.
"""

__slots__: Sequence[str] = []

@abstractmethod
async def lookup(self, sock_info: SocketInfo, *,
loop: Optional[AbstractEventLoop] = None) \
-> Optional[str]:
"""Looks up the connecting IP address and returns the DNSBL hostname
and the lookup result. Any timeout or misconfiguration is treated as an
empty result.
Args:
sock_info: The connection socket info.
"""
...

@classmethod
def load(cls, host: Optional[str], *,
timeout: Optional[float] = None) -> Dnsbl:
"""Given a DNSBL hostname, returns a :class:`Dnsbl` implementation that
best suits the given *host*.
Args:
host: The DNSBL hostname, if any.
timeout: The time to wait for a response, in seconds, or None for
indefinite.
Raises:
ValueError: The *host* is invalid for this :class:`Dnsbl`.
"""
if host is None:
return NoopDnsbl()
elif host.endswith('.spamhaus.org'):
return SpamhausDnsbl(host, timeout=timeout)
else:
return BasicDnsbl(host, timeout=timeout)


class NoopDnsbl(Dnsbl):
"""Disables DNSBL lookup altogether, :meth:`.lookup` always returns
``None``.
"""

__slots__: Sequence[str] = []

async def lookup(self, sock_info: SocketInfo, *,
loop: Optional[AbstractEventLoop] = None) -> None:
return None


class BasicDnsbl(Dnsbl):
"""A basic :class:`Dnsbl` implementation that simply returns the DNSBL
hostname if the DNS lookup returns any IP addresses.
"""

__slots__ = ['host', 'timeout']

def __init__(self, host: str, timeout: Optional[float]) -> None:
super().__init__()
self.host: Final = host
self.timeout: Final = timeout

def map_results(self, addresses: Sequence[IPv4Address]) -> Optional[str]:
"""Given a list of IP address results from a DNSBL lookup, return a
single string categorizing the results or ``None`` to discard them.
Args:
addresses: The list of IP address results.
"""
if addresses:
result = self.host
assert result is not None
return result
else:
return None

async def lookup(self, sock_info: SocketInfo, *,
loop: Optional[AbstractEventLoop] = None) \
-> Optional[str]:
host = self.host
peername_ip = sock_info.peername_ip
if not isinstance(peername_ip, IPv4Address):
return self.map_results([])
loop = loop or asyncio.get_running_loop()
lookup = '.'.join(peername_ip.reverse_pointer.split('.')[0:4] + [host])
try:
addrinfo = await asyncio.wait_for(
loop.getaddrinfo(lookup, 0, family=AF_INET, type=SOCK_STREAM),
self.timeout)
except (OSError, TimeoutError):
pass
else:
if addrinfo:
addresses = [IPv4Address(res[4][0]) for res in addrinfo]
return self.map_results(addresses)
return self.map_results([])


class SpamhausDnsbl(BasicDnsbl):
"""A :class:`Dnsbl` designed for querying `Spamhaus
<https://www.spamhaus.org/>`_ DNSBLs.
"""

__slots__: Sequence[str] = []

_mapping = [(IPv4Network('127.0.0.2/32'), 'sbl.spamhaus.org'),
(IPv4Network('127.0.0.3/32'), 'css.spamhaus.org'),
(IPv4Network('127.0.0.4/30'), 'xbl.spamhaus.org'),
(IPv4Network('127.0.0.10/31'), 'pbl.spamhaus.org')]

def map_results(self, addresses: Sequence[IPv4Address]) -> Optional[str]:
if not addresses:
return None
result = addresses[0]
for network, host in self._mapping:
if result in network:
return host
return self.host
3 changes: 2 additions & 1 deletion proxyprotocol/noop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ def build(self, source: Address, dest: Address, *, family: AddressFamily,
protocol: Optional[SocketKind] = None,
ssl: Union[None, SSLSocket, SSLObject] = None,
unique_id: Optional[bytes] = None,
proxied: bool = True) -> bytes:
proxied: bool = True,
dnsbl: Optional[str] = None) -> bytes:
return b''
4 changes: 4 additions & 0 deletions proxyprotocol/server/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ async def run_conn(pp: ProxyProtocol, reader: StreamReader,
sock_info = SocketInfo(writer, result)
_log.info('[%s] Connection received: %s',
sock_info.unique_id.hex(), sock_info)
if sock_info.dnsbl is not None:
_log.error('[%s] Connection rejected: %s',
sock_info.unique_id.hex(), sock_info.dnsbl)
return
try:
while True:
line = await reader.readline()
Expand Down
10 changes: 8 additions & 2 deletions proxyprotocol/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import logging
import signal
import sys
from argparse import Namespace, ArgumentParser, ArgumentDefaultsHelpFormatter
from argparse import Namespace, ArgumentParser, \
ArgumentDefaultsHelpFormatter, SUPPRESS
from asyncio import CancelledError
from contextlib import AsyncExitStack
from functools import partial

from . import Address
from .protocol import DownstreamProtocol, UpstreamProtocol
from ..dnsbl import Dnsbl

__all__ = ['main']

Expand All @@ -31,6 +33,9 @@ def main() -> int:
help='size of the read buffer')
parser.add_argument('-q', '--quiet', action='store_true',
help='show only upstream connection errors')
parser.add_argument('--dnsbl', metavar='HOST', default=None,
help='the DNSBL lookup hostname')
parser.add_argument('--dnsbl-timeout', type=float, help=SUPPRESS)
args = parser.parse_args()

if not args.services:
Expand All @@ -48,8 +53,9 @@ async def run(args: Namespace) -> int:
services = [(Address(source, server=True), Address(dest))
for (source, dest) in args.services]
buf_len: int = args.buf_len
dnsbl = Dnsbl.load(args.dnsbl, timeout=args.dnsbl_timeout)
new_server = partial(DownstreamProtocol, UpstreamProtocol,
loop, buf_len)
loop, buf_len, dnsbl)
servers = [
await loop.create_server(partial(new_server, dest),
source.host, source.port or 0,
Expand Down
40 changes: 25 additions & 15 deletions proxyprotocol/server/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from collections import deque
from functools import partial
from socket import AddressFamily, SocketKind
from typing import Any, Type, Optional, Tuple, Deque
from typing import Any, Type, Optional, Tuple, Deque, Set
from typing_extensions import Final
from uuid import uuid4

from . import Address
from .. import ProxyProtocol
from ..dnsbl import Dnsbl
from ..sock import SocketInfo
from . import Address

__all__ = ['DownstreamProtocol', 'UpstreamProtocol']

Expand Down Expand Up @@ -93,32 +94,31 @@ def proxy_data(self, data: memoryview) -> None:
class DownstreamProtocol(_Base):

def __init__(self, upstream_protocol: Type[UpstreamProtocol],
loop: AbstractEventLoop, buf_len: int,
loop: AbstractEventLoop, buf_len: int, dnsbl: Dnsbl,
upstream: Address) -> None:
super().__init__(buf_len)
self.loop: Final = loop
self.dnsbl: Final = dnsbl
self.upstream: Final = upstream
self.id: Final = uuid4().bytes
self._waiting: Deque[memoryview] = deque()
self._connect: Optional[Task[Any]] = None
self._tasks: Set[Task[Any]] = set()
self._upstream: Optional[UpstreamProtocol] = None
self._upstream_factory = partial(upstream_protocol, self, buf_len,
upstream.pp)

def close(self) -> None:
super().close()
if self._connect is not None:
self._connect.cancel()
self._connect = None
for task in self._tasks:
task.cancel()
if self._upstream is not None:
upstream = self._upstream
self._upstream = None
upstream.close()

def _set_client(self, connect: Task[_Connect]) -> None:
self._connect = None
def _set_client(self, connect_task: Task[_Connect]) -> None:
try:
_, upstream = connect.result()
_, upstream = connect_task.result()
except CancelledError:
pass # Connection was never established
except OSError:
Expand All @@ -138,11 +138,16 @@ def connection_made(self, transport: BaseTransport) -> None:
_log.info('[%s] Downstream connection received: %s',
self.id.hex(), self.sock_info)
loop = self.loop
self._connect = connect = loop.create_task(
loop.create_connection(self._upstream_factory,
dnsbl_task = loop.create_task(self.dnsbl.lookup(self.sock_info))
self._tasks.add(dnsbl_task)
dnsbl_task.add_done_callback(self._tasks.discard)
connect_task = loop.create_task(
loop.create_connection(partial(self._upstream_factory, dnsbl_task),
self.upstream.host, self.upstream.port or 0,
ssl=self.upstream.ssl))
connect.add_done_callback(self._set_client)
self._tasks.add(connect_task)
connect_task.add_done_callback(self._tasks.discard)
connect_task.add_done_callback(self._set_client)

def connection_lost(self, exc: Optional[Exception]) -> None:
super().connection_lost(exc)
Expand All @@ -159,10 +164,11 @@ def proxy_data(self, data: memoryview) -> None:
class UpstreamProtocol(_Base):

def __init__(self, downstream: DownstreamProtocol, buf_len: int,
pp: ProxyProtocol) -> None:
pp: ProxyProtocol, dnsbl_task: Task[Optional[str]]) -> None:
super().__init__(buf_len)
self.pp: Final = pp
self.downstream: Final = downstream
self.dnsbl_task: Final = dnsbl_task

def close(self) -> None:
super().close()
Expand All @@ -176,13 +182,17 @@ def build_pp_header(self) -> bytes:
protocol: Optional[SocketKind] = SocketKind(sock.proto)
except ValueError:
protocol = None
dnsbl = self.dnsbl_task.result()
return self.pp.build(sock.getpeername(), sock.getsockname(),
family=AddressFamily(sock.family),
protocol=protocol, unique_id=self.downstream.id,
ssl=ssl_object)
ssl=ssl_object, dnsbl=dnsbl)

def connection_made(self, transport: BaseTransport) -> None:
super().connection_made(transport)
self.dnsbl_task.add_done_callback(self._write_header)

def _write_header(self, task: Task[Any]) -> None:
header = self.build_pp_header()
self.write(memoryview(header))

Expand Down
19 changes: 17 additions & 2 deletions proxyprotocol/sock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@ class SocketInfo:
result: The PROXY protocol result.
unique_id: A unique ID to associate with the connection, unless
overridden by the PROXY protocol result.
dnsbl: The DNSBL lookup result, if any.
"""

__slots__ = ['transport', 'pp_result', '_unique_id']
__slots__ = ['transport', 'pp_result', '_unique_id', '_dnsbl']

def __init__(self, transport: TransportProtocol,
result: Optional[ProxyProtocolResult] = None, *,
unique_id: bytes = b'') -> None:
unique_id: bytes = b'', dnsbl: Optional[str] = None) -> None:
super().__init__()
self.transport: Final = transport
self.pp_result: Final = result or ProxyProtocolResultLocal()
self._unique_id = unique_id
self._dnsbl = dnsbl

@property
def socket(self) -> socket.socket:
Expand Down Expand Up @@ -266,6 +268,19 @@ def from_localhost(self) -> bool:
return False
return ip.is_loopback

@property
def dnsbl(self) -> Optional[str]:
"""The DNSBL lookup result of the connecting IP address, if any.
This value is contextual to the DNSBL in use, but generally any value
here other than ``None`` indicates the IP address should be blocked.
"""
if self.pp_result.proxied:
return self.pp_result.tlv.ext.dnsbl
else:
return self._dnsbl

def __str__(self) -> str:
proxied = ' proxied=True' if self.pp_result.proxied else ''
return '<SocketInfo peername=%r sockname=%r%s>' \
Expand Down
Loading

0 comments on commit 68e6663

Please sign in to comment.