Skip to content

Commit

Permalink
Added unit tests around nameserver selection
Browse files Browse the repository at this point in the history
  • Loading branch information
mschwager committed Sep 1, 2016
1 parent 30b7fcc commit f201e1f
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 8 deletions.
32 changes: 24 additions & 8 deletions fierce.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,33 @@ def find_nearby(resolver, ips, filter_func=None):
pprint.pprint({k: v[0].to_text() for k, v in reversed_ips.items() if v})


def update_resolver_nameservers(resolver, nameservers, nameserver_filename):
"""
Update a resolver's nameservers. The following priority is taken:
1. Nameservers list provided as an argument
2. A filename containing a list of nameservers
3. The original nameservers associated with the resolver
"""
if nameservers:
resolver.nameservers = nameservers
elif nameserver_filename:
nameservers = [ns.strip() for ns in open(nameserver_filename).readlines()]
resolver.nameservers = nameservers
else:
# Use original nameservers
pass

return resolver


def fierce(**kwargs):
resolver = dns.resolver.Resolver()

nameservers = None
if kwargs.get('dns_servers'):
nameservers = kwargs['dns_servers']
elif kwargs.get('dns_file'):
nameservers = [ns.strip() for ns in open(kwargs["dns_file"]).readlines()]

if nameservers:
resolver.nameservers = nameservers
resolver = update_resolver_nameservers(
resolver,
kwargs['dns_servers'],
kwargs['dns_file']
)

if kwargs.get("range"):
internal_range = ipaddress.IPv4Network(kwargs.get("range"))
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
nose2
cov-core
coveralls
pyfakefs
116 changes: 116 additions & 0 deletions tests/test_filesystem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python

import os
import unittest

import dns.resolver

from pyfakefs import fake_filesystem_unittest

import fierce


CONTENTS = """nameserver1
nameserver2
nameserver3
"""


class TestFilesystem(fake_filesystem_unittest.TestCase):

def setUp(self):
self.setUpPyfakefs()

def tearDown(self):
# It is no longer necessary to add self.tearDownPyfakefs()
pass

def test_update_resolver_nameservers_empty_no_file(self):
nameserver_filename = None
nameservers = []

resolver = dns.resolver.Resolver()

expected = resolver.nameservers

result = fierce.update_resolver_nameservers(
resolver,
nameservers,
nameserver_filename
)

self.assertEqual(expected, result.nameservers)

def test_update_resolver_nameservers_single_nameserver_no_file(self):
nameserver_filename = None
nameservers = ['192.168.1.1']

resolver = dns.resolver.Resolver()

result = fierce.update_resolver_nameservers(
resolver,
nameservers,
nameserver_filename
)

expected = nameservers
self.assertEqual(expected, result.nameservers)

def test_update_resolver_nameservers_multiple_nameservers_no_file(self):
nameserver_filename = None
nameservers = ['192.168.1.1', '192.168.1.2']

resolver = dns.resolver.Resolver()

result = fierce.update_resolver_nameservers(
resolver,
nameservers,
nameserver_filename
)

expected = nameservers
self.assertEqual(expected, result.nameservers)

def test_update_resolver_nameservers_no_nameserver_use_file(self):
nameserver_filename = os.path.join("directory", "nameservers")
nameservers = []

self.fs.CreateFile(
nameserver_filename,
contents=CONTENTS
)

resolver = dns.resolver.Resolver()

result = fierce.update_resolver_nameservers(
resolver,
nameservers,
nameserver_filename
)

expected = CONTENTS.split()
self.assertEqual(expected, result.nameservers)

def test_update_resolver_nameservers_prefer_nameservers_over_file(self):
nameserver_filename = os.path.join("directory", "nameservers")
nameservers = ['192.168.1.1', '192.168.1.2']

self.fs.CreateFile(
nameserver_filename,
contents=CONTENTS
)

resolver = dns.resolver.Resolver()

result = fierce.update_resolver_nameservers(
resolver,
nameservers,
nameserver_filename
)

expected = nameservers
self.assertEqual(expected, result.nameservers)


if __name__ == "__main__":
unittest.main()

0 comments on commit f201e1f

Please sign in to comment.