diff --git a/src/trustymail/cli.py b/src/trustymail/cli.py index 5331147..71847f4 100644 --- a/src/trustymail/cli.py +++ b/src/trustymail/cli.py @@ -1,4 +1,5 @@ """trustymail: A tool for scanning DNS mail records for evaluating security. + Usage: trustymail (INPUT ...) [options] trustymail (INPUT ...) [--output=OUTFILE] [--timeout=TIMEOUT] [--smtp-timeout=TIMEOUT] [--smtp-localhost=HOSTNAME] [--smtp-ports=PORTS] [--no-smtp-cache] [--mx] [--starttls] [--spf] [--dmarc] [--debug] [--json] [--dns=HOSTNAMES] [--psl-filename=FILENAME] [--psl-read-only] @@ -65,6 +66,7 @@ def main(): + """Perform a trustymail scan using the provided options.""" args = docopt.docopt(__doc__, version=__version__) # Monkey patching trustymail to make it cache the PSL where we want @@ -163,6 +165,7 @@ def main(): def write(content, out_file): + """Write the provided content to a file after ensuring all intermediate directories exist.""" parent = os.path.dirname(out_file) if parent != "": mkdir_p(parent) @@ -175,6 +178,7 @@ def write(content, out_file): # mkdir -p in python, from: # http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python def mkdir_p(path): + """Make a directory and all intermediate directories in its path.""" try: os.makedirs(path) except OSError as exc: # Python >2.5 diff --git a/src/trustymail/domain.py b/src/trustymail/domain.py index 8d797e0..a8067ca 100644 --- a/src/trustymail/domain.py +++ b/src/trustymail/domain.py @@ -1,3 +1,5 @@ +"""Provide a data model for domains and some utility functions.""" + # Standard Python Libraries from collections import OrderedDict from datetime import datetime, timedelta @@ -12,7 +14,7 @@ def get_psl(): """ - Gets the Public Suffix List - either new, or cached in the CWD for 24 hours + Get the Public Suffix List - either new, or cached in the CWD for 24 hours. Returns ------- @@ -42,14 +44,14 @@ def download_psl(): def get_public_suffix(domain): - """Returns the public suffix of a given domain""" + """Return the public suffix of a given domain.""" public_list = get_psl() return public_list.get_public_suffix(domain) def format_list(record_list): - """Format a list into a string to increase readability in CSV""" + """Format a list into a string to increase readability in CSV.""" # record_list should only be a list, not an integer, None, or # anything else. Thus this if clause handles only empty # lists. This makes a "null" appear in the JSON output for @@ -61,7 +63,9 @@ def format_list(record_list): class Domain: - base_domains: Dict[str, Domain] = {} + """Store information about a domain.""" + + base_domains: Dict[str, "Domain"] = {} def __init__( self, @@ -73,6 +77,7 @@ def __init__( smtp_cache, dns_hostnames, ): + """Retrieve information about a given domain name.""" self.domain_name = domain_name.lower() self.base_domain_name = get_public_suffix(self.domain_name) @@ -137,15 +142,13 @@ def __init__( self.ports_tested = set() def has_mail(self): + """Check if there are any mail servers associated with this domain.""" if self.mail_servers is not None: return len(self.mail_servers) > 0 return None def has_supports_smtp(self): - """ - Returns True if any of the mail servers associated with this - domain are listening and support SMTP. - """ + """Check if any of the mail servers associated with this domain are listening and support SMTP.""" result = None if len(self.starttls_results) > 0: result = ( @@ -160,10 +163,7 @@ def has_supports_smtp(self): return result def has_starttls(self): - """ - Returns True if any of the mail servers associated with this - domain are listening and support STARTTLS. - """ + """Check if any of the mail servers associated with this domain are listening and support STARTTLS.""" result = None if len(self.starttls_results) > 0: result = ( @@ -178,16 +178,19 @@ def has_starttls(self): return result def has_spf(self): + """Check if this domain has any Sender Policy Framework records.""" if self.spf is not None: return len(self.spf) > 0 return None def has_dmarc(self): + """Check if this domain has a Domain-based Message Authentication, Reporting, and Conformance record.""" if self.dmarc is not None: return len(self.dmarc) > 0 return None def add_mx_record(self, record): + """Add a mail server record for this domain.""" if self.mx_records is None: self.mx_records = [] self.mx_records.append(record) @@ -198,30 +201,35 @@ def add_mx_record(self, record): self.mail_servers.append(record.exchange.to_text().rstrip(".").lower()) def parent_has_dmarc(self): + """Check if a domain or its parent has a Domain-based Message Authentication, Reporting, and Conformance record.""" ans = self.has_dmarc() if self.base_domain: ans = self.base_domain.has_dmarc() return ans def parent_dmarc_dnssec(self): + """Get this domain or its parent's DMARC DNSSEC information.""" ans = self.dmarc_dnssec if self.base_domain: ans = self.base_domain.dmarc_dnssec return ans def parent_valid_dmarc(self): + """Check if this domain or its parent have a valid DMARC record.""" ans = self.valid_dmarc if self.base_domain: return self.base_domain.valid_dmarc return ans def parent_dmarc_results(self): + """Get this domain or its parent's DMARC information.""" ans = format_list(self.dmarc) if self.base_domain: ans = format_list(self.base_domain.dmarc) return ans def get_dmarc_policy(self): + """Get this domain or its parent's DMARC policy.""" ans = self.dmarc_policy # If the policy was never set, or isn't in the list of valid # policies, check the parents. @@ -239,6 +247,7 @@ def get_dmarc_policy(self): return ans def get_dmarc_subdomain_policy(self): + """Get this domain or its parent's DMARC subdomain policy.""" ans = self.dmarc_subdomain_policy # If the policy was never set, or isn't in the list of valid # policies, check the parents. @@ -250,6 +259,7 @@ def get_dmarc_subdomain_policy(self): return ans def get_dmarc_pct(self): + """Get this domain or its parent's DMARC percentage information.""" ans = self.dmarc_pct if not ans and self.base_domain: # Check the parents @@ -257,6 +267,7 @@ def get_dmarc_pct(self): return ans def get_dmarc_has_aggregate_uri(self): + """Get this domain or its parent's DMARC aggregate URI.""" ans = self.dmarc_has_aggregate_uri # If there are no aggregate URIs then check the parents. if not ans and self.base_domain: @@ -264,6 +275,7 @@ def get_dmarc_has_aggregate_uri(self): return ans def get_dmarc_has_forensic_uri(self): + """Check if this domain or its parent have a DMARC forensic URI.""" ans = self.dmarc_has_forensic_uri # If there are no forensic URIs then check the parents. if not ans and self.base_domain: @@ -271,6 +283,7 @@ def get_dmarc_has_forensic_uri(self): return ans def get_dmarc_aggregate_uris(self): + """Get this domain or its parent's DMARC aggregate URIs.""" ans = self.dmarc_aggregate_uris # If there are no aggregate URIs then check the parents. if not ans and self.base_domain: @@ -278,6 +291,7 @@ def get_dmarc_aggregate_uris(self): return ans def get_dmarc_forensic_uris(self): + """Get this domain or its parent's DMARC forensic URIs.""" ans = self.dmarc_forensic_uris # If there are no forensic URIs then check the parents. if not ans and self.base_domain: @@ -285,6 +299,7 @@ def get_dmarc_forensic_uris(self): return ans def generate_results(self): + """Generate the results for this domain.""" if len(self.starttls_results.keys()) == 0: domain_supports_smtp = None domain_supports_starttls = None diff --git a/src/trustymail/trustymail.py b/src/trustymail/trustymail.py index 6f07697..e1506c7 100644 --- a/src/trustymail/trustymail.py +++ b/src/trustymail/trustymail.py @@ -1,3 +1,5 @@ +"""Functions to check a domain's configuration for trustworthy mail.""" + # Standard Python Libraries from collections import OrderedDict import csv @@ -27,6 +29,7 @@ def domain_list_from_url(url): + """Get a list of domains from a provided URL.""" if not url: return [] @@ -38,6 +41,7 @@ def domain_list_from_url(url): def domain_list_from_csv(csv_file): + """Get a list of domains from a provided CSV file.""" domain_list = list(csv.reader(csv_file, delimiter=",")) # Check the headers for the word domain - use that column. @@ -61,7 +65,9 @@ def domain_list_from_csv(csv_file): def check_dnssec(domain, domain_name, record_type): - """Checks whether the domain has a record of type that is protected + """Test to see if a DNSSEC record is valid and correct. + + Checks a domain for DNSSEC whether the domain has a record of type that is protected by DNSSEC or NXDOMAIN or NoAnswer that is protected by DNSSEC. TODO: Probably does not follow redirects (CNAMEs). Should work on @@ -82,6 +88,7 @@ def check_dnssec(domain, domain_name, record_type): def mx_scan(resolver, domain): + """Scan a domain to see if it has any mail servers.""" try: if domain.mx_records is None: domain.mx_records = [] @@ -426,9 +433,10 @@ def get_spf_record_text(resolver, domain_name, domain, follow_redirect=False): def spf_scan(resolver, domain): - """Scan a domain to see if it supports SPF. If the domain has an SPF - record, verify that it properly handles mail sent from an IP known - not to be listed in an MX record for ANY domain. + """Scan a domain to see if it supports SPF. + + If the domain has an SPF record, verify that it properly handles mail sent from + an IP known not to be listed in an MX record for ANY domain. Parameters ---------- @@ -460,7 +468,7 @@ def spf_scan(resolver, domain): def parse_dmarc_report_uri(uri): """ - Parses a DMARC Reporting (i.e. ``rua``/``ruf)`` URI + Parse a DMARC Reporting (i.e. ``rua``/``ruf)`` URI. Notes ----- @@ -492,6 +500,7 @@ def parse_dmarc_report_uri(uri): def dmarc_scan(resolver, domain): + """Scan a domain to see if it supports DMARC.""" # dmarc records are kept in TXT records for _dmarc.domain_name. try: if domain.dmarc is None: @@ -773,6 +782,7 @@ def dmarc_scan(resolver, domain): def find_host_from_ip(resolver, ip_addr): + """Find the host name for a given IP address.""" # Use TCP, since we care about the content and correctness of the records # more than whether their records fit in a single UDP packet. hostname, _ = resolver.query(dns.reversename.from_address(ip_addr), "PTR", tcp=True) @@ -789,6 +799,7 @@ def scan( scan_types, dns_hostnames, ): + """Parse a domain's DNS information for mail related records.""" # # Configure the dnspython library # @@ -878,9 +889,10 @@ def scan( def handle_error(prefix, domain, error, syntax_error=False): - """Handle an error by logging via the Python logging library and - recording it in the debug_info or syntax_error members of the - trustymail.Domain object. + """Handle the provided error by logging a message and storing it in the Domain object. + + Logging is performed via the Python logging library and recording it in the + debug_info or syntax_error members of the trustymail.Domain object. Since the "Debug Info" and "Syntax Error" fields in the CSV output of trustymail come directly from the debug_info and syntax_error @@ -946,11 +958,12 @@ def handle_error(prefix, domain, error, syntax_error=False): def handle_syntax_error(prefix, domain, error): - """Convenience method for handle_error""" + """Handle a syntax error by passing it to handle_error().""" handle_error(prefix, domain, error, syntax_error=True) def generate_csv(domains, file_name): + """Generate a CSV file with the given domain information.""" with open(file_name, "w", encoding="utf-8", newline="\n") as output_file: writer = csv.DictWriter( output_file, fieldnames=domains[0].generate_results().keys() @@ -965,6 +978,7 @@ def generate_csv(domains, file_name): def generate_json(domains): + """Generate a JSON string with the given domain information.""" output = [] for domain in domains: output.append(domain.generate_results()) @@ -974,6 +988,7 @@ def generate_json(domains): # Taken from pshtt to keep formatting similar def format_datetime(obj): + """Format the provided datetime information.""" if isinstance(obj, datetime.date): return obj.isoformat() elif isinstance(obj, str): @@ -983,7 +998,7 @@ def format_datetime(obj): def remove_quotes(txt_record): - """Remove double quotes and contatenate strings in a DNS TXT record + """Remove double quotes and contatenate strings in a DNS TXT record. A DNS TXT record can contain multiple double-quoted strings, and in that case the client has to remove the quotes and concatenate the diff --git a/tests/test_trustymail.py b/tests/test_trustymail.py index c0c2c99..90f3794 100644 --- a/tests/test_trustymail.py +++ b/tests/test_trustymail.py @@ -1,7 +1,11 @@ +"""Tests for the trustymail module.""" # Standard Python Libraries import unittest class TestLiveliness(unittest.TestCase): + """Test the liveliness of a domain.""" + def test_domain_list_parsing(self): + """Test that a domain list is correctly parsed.""" pass