Skip to content

Commit

Permalink
Firewall / Aliases - improve resolve performance by implementing asyn…
Browse files Browse the repository at this point in the history
…c dns lookups. ref #5117

This will need a new version of py-dnspython (py-dnspython2 in ports) for dns.asyncresolver support. Some additional log messages have been added to gain more insights into the resolving process via the general log.
Intermediate results aren't saved to disk anymore, which also simplifies the resolve() function in the Alias class. An address parser can queue hostname lookups for later retrieval (see _parse_address()) so we can batch process the list of hostnames to be collected.
  • Loading branch information
AdSchellevis committed Aug 19, 2021
1 parent 2872298 commit 76b8ae4
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ CORE_DEPENDS?= ca_root_nss \
php${CORE_PHP}-zlib \
pkg \
py${CORE_PYTHON}-Jinja2 \
py${CORE_PYTHON}-dnspython \
py${CORE_PYTHON}-dnspython2 \
py${CORE_PYTHON}-netaddr \
py${CORE_PYTHON}-requests \
py${CORE_PYTHON}-sqlite3 \
Expand Down
83 changes: 83 additions & 0 deletions src/opnsense/scripts/filter/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""
import asyncio
import dns.resolver
import ipaddress
import itertools
import syslog
import time
from dns.rdatatype import RdataType
from dns.asyncresolver import Resolver


def net_wildcard_iterator(network: str):
Expand Down Expand Up @@ -62,3 +68,80 @@ def net_wildcard_iterator(network: str):
yield ipaddress.IPv6Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False)
else:
yield ipaddress.IPv4Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False)


class AsyncDNSResolver:
    """ Asynchronous DNS resolver, collect addresses for hostnames collected in request queue.
        simple example usecase collecting addresses associated with two domains:
            asyncresolver = AsyncDNSResolver()
            asyncresolver.add('www.example.com')
            asyncresolver.add('mail.example.com')
            asyncresolver.collect()
            print(asyncresolver.addresses())
    """
    # number of resolve tasks dispatched concurrently per asyncio.gather() round
    batch_size = 100
    # emit a progress notice to syslog roughly every report_size queued hostnames
    report_size = 10000

    def __init__(self, origin="<unknown>"):
        """ :param origin: label (e.g. alias name) used in syslog messages """
        self._request_queue = list()   # hostnames waiting to be resolved
        self._requested = set()        # hostnames already dispatched, for deduplication
        self._response = set()         # addresses collected so far
        self._origin = origin
        self._domains_queued = 0       # total unique hostnames dispatched

    def add(self, hostname):
        """ queue a hostname for later batched resolution (see collect())
            :param hostname: hostname to resolve
        """
        self._request_queue.append(hostname)

    async def request_ittr(self, loop=None):
        """ drain the request queue in batches, resolving A and AAAA records and
            collecting the resulting addresses. CNAME answers are re-queued so their
            targets get resolved as well (recursion via the outer while loop).
            :param loop: unused, kept for backwards compatibility
        """
        dns_resolver = Resolver()
        dns_resolver.timeout = 2
        collected_errors = set()
        while self._request_queue:
            tasks = []
            while len(tasks) < self.batch_size and self._request_queue:
                hostname = self._request_queue.pop()
                if hostname not in self._requested:
                    # make sure we only request a host once
                    self._requested.add(hostname)
                    self._domains_queued += 1
                    for record_type in ['A', 'AAAA']:
                        tasks.append(dns_resolver.resolve(hostname, record_type))
            if tasks:
                responses = await asyncio.gather(*tasks, return_exceptions=True)
                for response in responses:
                    if isinstance(response, dns.resolver.Answer):
                        for item in response.response.answer:
                            if isinstance(item, dns.rrset.RRset):
                                for addr in item.items:
                                    if addr.rdtype is RdataType.CNAME:
                                        # queue cname target (recursion); cast to str so
                                        # deduplication against self._requested (strings) works
                                        self._request_queue.append(str(addr.target))
                                    else:
                                        self._response.add(addr.address)
                    elif isinstance(response, (
                            dns.resolver.NoAnswer,
                            dns.resolver.NXDOMAIN,
                            dns.exception.Timeout,
                            dns.resolver.NoNameservers,
                            dns.name.EmptyLabel)):
                        # isinstance (not exact type) so subclasses such as dnspython2's
                        # LifetimeTimeout are logged too; log each distinct error once
                        if str(response) not in collected_errors:
                            syslog.syslog(syslog.LOG_ERR, '%s [for %s]' % (response, self._origin))
                            collected_errors.add(str(response))
            if self._domains_queued and self._domains_queued % self.report_size == 0:
                syslog.syslog(syslog.LOG_NOTICE, 'requested %d hostnames for %s' % (self._domains_queued, self._origin))

    def collect(self):
        """ resolve all queued hostnames (no-op when the queue is empty) and log timing.
            :return: self, so calls can be chained as collect().addresses()
        """
        if self._request_queue:
            start_time = time.time()
            # asyncio.run() creates and closes its own event loop; the previous manual
            # new_event_loop()/set_event_loop()/close() sequence was dead code (the
            # manually created loop was never the one asyncio.run() executed on).
            asyncio.run(self.request_ittr())
            syslog.syslog(syslog.LOG_NOTICE, 'resolving %d hostnames (%d addresses) for %s took %.2f seconds' % (
                self._domains_queued, len(self._response), self._origin, time.time() - start_time
            ))
        return self

    def addresses(self):
        """ :return: set of addresses collected by collect() """
        return self._response
39 changes: 12 additions & 27 deletions src/opnsense/scripts/filter/lib/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import syslog
from hashlib import md5
from . import geoip
from . import net_wildcard_iterator
from . import net_wildcard_iterator, AsyncDNSResolver
from .arpcache import ArpCache

class Alias(object):
Expand All @@ -49,8 +49,6 @@ def __init__(self, elem, known_aliases=[], ttl=-1, ssl_no_verify=False, timeout=
:return: None
"""
self._known_aliases = known_aliases
self._dnsResolver = dns.resolver.Resolver()
self._dnsResolver.timeout = 2
self._is_changed = None
self._has_expired = None
self._ttl = ttl
Expand Down Expand Up @@ -85,6 +83,7 @@ def __init__(self, elem, known_aliases=[], ttl=-1, ssl_no_verify=False, timeout=
self._filename_alias_hash = '/var/db/aliastables/%s.md5.txt' % self._name
# the generated alias contents, without dependencies
self._filename_alias_content = '/var/db/aliastables/%s.self.txt' % self._name
self._dnsResolver = AsyncDNSResolver(self._name)

def _parse_address(self, address):
""" parse addresses and hostnames, yield only valid addresses and networks
Expand Down Expand Up @@ -125,19 +124,8 @@ def _parse_address(self, address):
except (ipaddress.AddressValueError, ValueError):
pass

# try to resolve provided address
could_resolve = False
for record_type in ['A', 'AAAA']:
try:
for rdata in self._dnsResolver.query(address, record_type):
yield str(rdata)
could_resolve = True
except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.exception.Timeout, dns.resolver.NoNameservers, dns.name.EmptyLabel):
pass

if not could_resolve:
# log when none could be found
syslog.syslog(syslog.LOG_ERR, 'unable to resolve %s for alias %s' % (address, self._name))
# try to resolve provided address (queue for retrieval)
self._dnsResolver.add(address)

def _fetch_url(self, url):
""" return unparsed (raw) alias entries without dependencies
Expand Down Expand Up @@ -244,18 +232,15 @@ def resolve(self, force=False):
else:
undo_content = ""
try:
address_parser = self.get_parser()
for item in self.items():
if address_parser:
for address in address_parser(item):
self._resolve_content.add(address)
# resolve hostnames (async) if there are any in the collected set
self._resolve_content = self._resolve_content.union(self._dnsResolver.collect().addresses())
with open(self._filename_alias_content, 'w') as f_out:
for item in self.items():
address_parser = self.get_parser()
if address_parser:
for address in address_parser(item):
if address not in self._resolve_content:
# flush new alias content (without dependencies) to disk, so progress can easliy
# be followed, large lists of domain names can take quite some resolve time.
f_out.write('%s\n' % address)
f_out.flush()
# preserve addresses
self._resolve_content.add(address)
f_out.write('\n'.join(self._resolve_content))
except IOError:
# parse issue, keep data as-is, flush previous content to disk
with open(self._filename_alias_content, 'w') as f_out:
Expand Down

0 comments on commit 76b8ae4

Please sign in to comment.