Skip to content
This repository has been archived by the owner on Aug 3, 2021. It is now read-only.

Implement RFC 7871 EDNS Client Subnet (ECS) #88

Merged
merged 1 commit into from May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion dohproxy/httpproxy.py
Expand Up @@ -71,12 +71,15 @@ def set_upstream_resolver(self, upstream_resolver, upstream_port):
self.upstream_resolver = upstream_resolver
self.upstream_port = upstream_port

def set_ecs(self, ecs):
self.ecs = ecs

async def resolve(self, request, dnsq):
self.time_stamp = time.time()
clientip = request.remote
dnsclient = DNSClient(self.upstream_resolver, self.upstream_port,
logger=self.logger)
dnsr = await dnsclient.query(dnsq, clientip)
dnsr = await dnsclient.query(dnsq, clientip, ecs=self.ecs)

if dnsr is None:
return self.on_answer(request, dnsq=dnsq)
Expand Down Expand Up @@ -128,6 +131,7 @@ def get_app(args):
logger = utils.configure_logger('doh-httpproxy', args.level)
app = DOHApplication(logger=logger, debug=args.debug)
app.set_upstream_resolver(args.upstream_resolver, args.upstream_port)
app.set_ecs(args.ecs)
app.router.add_get(args.uri, doh1handler)
app.router.add_post(args.uri, doh1handler)

Expand Down
8 changes: 5 additions & 3 deletions dohproxy/proxy.py
Expand Up @@ -41,14 +41,15 @@ def parse_args():

class H2Protocol(asyncio.Protocol):
def __init__(self, upstream_resolver=None, upstream_port=None,
uri=None, logger=None, debug=False):
uri=None, logger=None, debug=False, ecs=False):
config = H2Configuration(client_side=False, header_encoding='utf-8')
self.conn = H2Connection(config=config)
self.logger = logger
if logger is None:
self.logger = utils.configure_logger('doh-proxy', 'DEBUG')
self.transport = None
self.debug = debug
self.ecs = ecs
self.stream_data = {}
self.upstream_resolver = upstream_resolver
self.upstream_port = upstream_port
Expand Down Expand Up @@ -195,7 +196,7 @@ async def resolve(self, dnsq, stream_id):
clientip = utils.get_client_ip(self.transport)
dnsclient = DNSClient(self.upstream_resolver, self.upstream_port,
logger=self.logger)
dnsr = await dnsclient.query(dnsq, clientip)
dnsr = await dnsclient.query(dnsq, clientip, ecs=self.ecs)

if dnsr is None:
self.on_answer(stream_id, dnsq=dnsq)
Expand Down Expand Up @@ -283,7 +284,8 @@ def main():
upstream_port=args.upstream_port,
uri=args.uri,
logger=logger,
debug=args.debug),
debug=args.debug,
ecs=args.ecs),
host=addr,
port=args.port,
ssl=ssl_ctx)
Expand Down
21 changes: 18 additions & 3 deletions dohproxy/server_protocol.py
Expand Up @@ -7,6 +7,7 @@
# LICENSE file in the root directory of this source tree.
#
import asyncio
import dns.edns
import dns.entropy
import dns.message
import struct
Expand Down Expand Up @@ -42,10 +43,24 @@ def __init__(self, upstream_resolver, upstream_port, logger=None):
self.logger = logger
self.transport = None

async def query(self, dnsq, clientip, timeout=DEFAULT_TIMEOUT):
dnsr = await self.query_udp(dnsq, clientip, timeout=timeout)
async def query(self, dnsq, clientip, timeout=DEFAULT_TIMEOUT,
ecs=False):
# (Potentially) modified copy of dnsq
dnsq_mod = dns.message.from_wire(dnsq.to_wire())
chantra marked this conversation as resolved.
Show resolved Hide resolved
we_set_ecs = False
if ecs:
we_set_ecs = utils.set_dns_ecs(dnsq_mod, clientip)

dnsr = await self.query_udp(dnsq_mod, clientip, timeout=timeout)
if dnsr is None or (dnsr.flags & dns.flags.TC):
dnsr = await self.query_tcp(dnsq, clientip, timeout=timeout)
dnsr = await self.query_tcp(dnsq_mod, clientip, timeout=timeout)

if dnsr is not None and we_set_ecs:
for option in dnsr.options:
if isinstance(option, dns.edns.ECSOption):
dnsr.options.remove(option)
dnsr.edns = dnsq.edns

return dnsr

async def query_udp(self, dnsq, clientip, timeout=DEFAULT_TIMEOUT):
Expand Down
40 changes: 40 additions & 0 deletions dohproxy/utils.py
Expand Up @@ -10,9 +10,11 @@
import asyncio
import binascii
import base64
import dns.edns
import dns.exception
import dns.message
import dns.rcode
import ipaddress
import logging
import ssl
import struct
Expand Down Expand Up @@ -348,6 +350,11 @@ def proxy_parser_base(*, port: int,
action='version',
version='%(prog)s {}'.format(__version__),
)
parser.add_argument(
'--ecs',
action='store_true',
help='Enable EDNS Client Subnet (ECS)'
)
return parser


Expand Down Expand Up @@ -412,3 +419,36 @@ def handle_dns_tcp_data(data, cb):
return data
msglen = struct.unpack('!H', data[0:2])[0]
return data


def set_dns_ecs(dnsq, ip):
"""Sets RFC 7871 EDNS Client Subnet (ECS) option in a DNS packet.
An existing ECS option will not be overwritten if present.
:param dnsq: DNS packet.
:param ip: IP address. String or ipaddress object.
:return: Whether ECS was set (bool)
"""
options = []
for option in dnsq.options:
if isinstance(option, dns.edns.ECSOption):
return False
options.append(option)

if not isinstance(
ip,
(ipaddress.IPv4Address, ipaddress.IPv6Address)
):
ip = ipaddress.ip_address(ip)
ip_supernet_bits = 56 if ip.version == 6 else 24
ip_supernet = ipaddress.ip_network(ip).supernet(
new_prefix=ip_supernet_bits,
)

options.append(dns.edns.ECSOption(
address=ip_supernet.network_address.compressed,
srclen=ip_supernet_bits,
))
dnsq.edns = 0 # 0 == True
dnsq.options = options

return True
16 changes: 16 additions & 0 deletions test/test_utils.py
Expand Up @@ -529,3 +529,19 @@ def test_complete_multiple(self):
self.assertEqual(res, b'')
self.assertIsInstance(self._cb_data[0], dns.message.Message)
self.assertIsInstance(self._cb_data[1], dns.message.Message)


class TestDNSECS(unittest.TestCase):
def test_set_dns_ecs_ipv4(self):
dnsq = dns.message.Message()
utils.set_dns_ecs(dnsq, '10.0.0.242')
self.assertEqual(dnsq.edns, 0)
self.assertEqual(dnsq.options[0].address, '10.0.0.0')
self.assertEqual(dnsq.options[0].srclen, 24)

def test_set_dns_ecs_ipv6(self):
dnsq = dns.message.Message()
utils.set_dns_ecs(dnsq, '2000::aa')
self.assertEqual(dnsq.edns, 0)
self.assertEqual(dnsq.options[0].address, '2000::')
self.assertEqual(dnsq.options[0].srclen, 56)
rfinnie marked this conversation as resolved.
Show resolved Hide resolved