Skip to content

Commit

Permalink
Update to pass the flake8 hook
Browse files Browse the repository at this point in the history
Make the necessary changes to pass the `flake8` pre-commit hook. This
is almost exclusively to satisfy the flake8-docstring plugin.
  • Loading branch information
mcdonnnj committed Jan 23, 2023
1 parent a8fb546 commit fa93631
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 22 deletions.
4 changes: 4 additions & 0 deletions src/trustymail/cli.py
@@ -1,4 +1,5 @@
"""trustymail: A tool for scanning DNS mail records for evaluating security.
Usage:
trustymail (INPUT ...) [options]
trustymail (INPUT ...) [--output=OUTFILE] [--timeout=TIMEOUT] [--smtp-timeout=TIMEOUT] [--smtp-localhost=HOSTNAME] [--smtp-ports=PORTS] [--no-smtp-cache] [--mx] [--starttls] [--spf] [--dmarc] [--debug] [--json] [--dns=HOSTNAMES] [--psl-filename=FILENAME] [--psl-read-only]
Expand Down Expand Up @@ -65,6 +66,7 @@


def main():
"""Perform a trustymail scan using the provided options."""
args = docopt.docopt(__doc__, version=__version__)

# Monkey patching trustymail to make it cache the PSL where we want
Expand Down Expand Up @@ -163,6 +165,7 @@ def main():


def write(content, out_file):
"""Write the provided content to a file after ensuring all intermediate directories exist."""
parent = os.path.dirname(out_file)
if parent != "":
mkdir_p(parent)
Expand All @@ -175,6 +178,7 @@ def write(content, out_file):
# mkdir -p in python, from:
# http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python
def mkdir_p(path):
"""Make a directory and all intermediate directories in its path."""
try:
os.makedirs(path)
except OSError as exc: # Python >2.5
Expand Down
39 changes: 27 additions & 12 deletions src/trustymail/domain.py
@@ -1,3 +1,5 @@
"""Provide a data model for domains and some utility functions."""

# Standard Python Libraries
from collections import OrderedDict
from datetime import datetime, timedelta
Expand All @@ -12,7 +14,7 @@

def get_psl():
"""
Gets the Public Suffix List - either new, or cached in the CWD for 24 hours
Get the Public Suffix List - either new, or cached in the CWD for 24 hours.
Returns
-------
Expand Down Expand Up @@ -42,14 +44,14 @@ def download_psl():


def get_public_suffix(domain):
"""Returns the public suffix of a given domain"""
"""Return the public suffix of a given domain."""
public_list = get_psl()

return public_list.get_public_suffix(domain)


def format_list(record_list):
"""Format a list into a string to increase readability in CSV"""
"""Format a list into a string to increase readability in CSV."""
# record_list should only be a list, not an integer, None, or
# anything else. Thus this if clause handles only empty
# lists. This makes a "null" appear in the JSON output for
Expand All @@ -61,7 +63,9 @@ def format_list(record_list):


class Domain:
base_domains: Dict[str, Domain] = {}
"""Store information about a domain."""

base_domains: Dict[str, "Domain"] = {}

def __init__(
self,
Expand All @@ -73,6 +77,7 @@ def __init__(
smtp_cache,
dns_hostnames,
):
"""Retrieve information about a given domain name."""
self.domain_name = domain_name.lower()

self.base_domain_name = get_public_suffix(self.domain_name)
Expand Down Expand Up @@ -137,15 +142,13 @@ def __init__(
self.ports_tested = set()

def has_mail(self):
"""Check if there are any mail servers associated with this domain."""
if self.mail_servers is not None:
return len(self.mail_servers) > 0
return None

def has_supports_smtp(self):
"""
Returns True if any of the mail servers associated with this
domain are listening and support SMTP.
"""
"""Check if any of the mail servers associated with this domain are listening and support SMTP."""
result = None
if len(self.starttls_results) > 0:
result = (
Expand All @@ -160,10 +163,7 @@ def has_supports_smtp(self):
return result

def has_starttls(self):
"""
Returns True if any of the mail servers associated with this
domain are listening and support STARTTLS.
"""
"""Check if any of the mail servers associated with this domain are listening and support STARTTLS."""
result = None
if len(self.starttls_results) > 0:
result = (
Expand All @@ -178,16 +178,19 @@ def has_starttls(self):
return result

def has_spf(self):
"""Check if this domain has any Sender Policy Framework records."""
if self.spf is not None:
return len(self.spf) > 0
return None

def has_dmarc(self):
"""Check if this domain has a Domain-based Message Authentication, Reporting, and Conformance record."""
if self.dmarc is not None:
return len(self.dmarc) > 0
return None

def add_mx_record(self, record):
"""Add a mail server record for this domain."""
if self.mx_records is None:
self.mx_records = []
self.mx_records.append(record)
Expand All @@ -198,30 +201,35 @@ def add_mx_record(self, record):
self.mail_servers.append(record.exchange.to_text().rstrip(".").lower())

def parent_has_dmarc(self):
"""Check if a domain or its parent has a Domain-based Message Authentication, Reporting, and Conformance record."""
ans = self.has_dmarc()
if self.base_domain:
ans = self.base_domain.has_dmarc()
return ans

def parent_dmarc_dnssec(self):
"""Get this domain or its parent's DMARC DNSSEC information."""
ans = self.dmarc_dnssec
if self.base_domain:
ans = self.base_domain.dmarc_dnssec
return ans

def parent_valid_dmarc(self):
"""Check if this domain or its parent have a valid DMARC record."""
ans = self.valid_dmarc
if self.base_domain:
return self.base_domain.valid_dmarc
return ans

def parent_dmarc_results(self):
"""Get this domain or its parent's DMARC information."""
ans = format_list(self.dmarc)
if self.base_domain:
ans = format_list(self.base_domain.dmarc)
return ans

def get_dmarc_policy(self):
"""Get this domain or its parent's DMARC policy."""
ans = self.dmarc_policy
# If the policy was never set, or isn't in the list of valid
# policies, check the parents.
Expand All @@ -239,6 +247,7 @@ def get_dmarc_policy(self):
return ans

def get_dmarc_subdomain_policy(self):
"""Get this domain or its parent's DMARC subdomain policy."""
ans = self.dmarc_subdomain_policy
# If the policy was never set, or isn't in the list of valid
# policies, check the parents.
Expand All @@ -250,41 +259,47 @@ def get_dmarc_subdomain_policy(self):
return ans

def get_dmarc_pct(self):
"""Get this domain or its parent's DMARC percentage information."""
ans = self.dmarc_pct
if not ans and self.base_domain:
# Check the parents
ans = self.base_domain.get_dmarc_pct()
return ans

def get_dmarc_has_aggregate_uri(self):
"""Get this domain or its parent's DMARC aggregate URI."""
ans = self.dmarc_has_aggregate_uri
# If there are no aggregate URIs then check the parents.
if not ans and self.base_domain:
ans = self.base_domain.get_dmarc_has_aggregate_uri()
return ans

def get_dmarc_has_forensic_uri(self):
"""Check if this domain or its parent have a DMARC forensic URI."""
ans = self.dmarc_has_forensic_uri
# If there are no forensic URIs then check the parents.
if not ans and self.base_domain:
ans = self.base_domain.get_dmarc_has_forensic_uri()
return ans

def get_dmarc_aggregate_uris(self):
"""Get this domain or its parent's DMARC aggregate URIs."""
ans = self.dmarc_aggregate_uris
# If there are no aggregate URIs then check the parents.
if not ans and self.base_domain:
ans = self.base_domain.get_dmarc_aggregate_uris()
return ans

def get_dmarc_forensic_uris(self):
"""Get this domain or its parent's DMARC forensic URIs."""
ans = self.dmarc_forensic_uris
# If there are no forensic URIs then check the parents.
if not ans and self.base_domain:
ans = self.base_domain.get_dmarc_forensic_uris()
return ans

def generate_results(self):
"""Generate the results for this domain."""
if len(self.starttls_results.keys()) == 0:
domain_supports_smtp = None
domain_supports_starttls = None
Expand Down
35 changes: 25 additions & 10 deletions src/trustymail/trustymail.py
@@ -1,3 +1,5 @@
"""Functions to check a domain's configuration for trustworthy mail."""

# Standard Python Libraries
from collections import OrderedDict
import csv
Expand Down Expand Up @@ -27,6 +29,7 @@


def domain_list_from_url(url):
"""Get a list of domains from a provided URL."""
if not url:
return []

Expand All @@ -38,6 +41,7 @@ def domain_list_from_url(url):


def domain_list_from_csv(csv_file):
"""Get a list of domains from a provided CSV file."""
domain_list = list(csv.reader(csv_file, delimiter=","))

# Check the headers for the word domain - use that column.
Expand All @@ -61,7 +65,9 @@ def domain_list_from_csv(csv_file):


def check_dnssec(domain, domain_name, record_type):
"""Checks whether the domain has a record of type that is protected
"""Test to see if a DNSSEC record is valid and correct.
Checks a domain for DNSSEC whether the domain has a record of type that is protected
by DNSSEC or NXDOMAIN or NoAnswer that is protected by DNSSEC.
TODO: Probably does not follow redirects (CNAMEs). Should work on
Expand All @@ -82,6 +88,7 @@ def check_dnssec(domain, domain_name, record_type):


def mx_scan(resolver, domain):
"""Scan a domain to see if it has any mail servers."""
try:
if domain.mx_records is None:
domain.mx_records = []
Expand Down Expand Up @@ -426,9 +433,10 @@ def get_spf_record_text(resolver, domain_name, domain, follow_redirect=False):


def spf_scan(resolver, domain):
"""Scan a domain to see if it supports SPF. If the domain has an SPF
record, verify that it properly handles mail sent from an IP known
not to be listed in an MX record for ANY domain.
"""Scan a domain to see if it supports SPF.
If the domain has an SPF record, verify that it properly handles mail sent from
an IP known not to be listed in an MX record for ANY domain.
Parameters
----------
Expand Down Expand Up @@ -460,7 +468,7 @@ def spf_scan(resolver, domain):

def parse_dmarc_report_uri(uri):
"""
Parses a DMARC Reporting (i.e. ``rua``/``ruf)`` URI
Parse a DMARC Reporting (i.e. ``rua``/``ruf)`` URI.
Notes
-----
Expand Down Expand Up @@ -492,6 +500,7 @@ def parse_dmarc_report_uri(uri):


def dmarc_scan(resolver, domain):
"""Scan a domain to see if it supports DMARC."""
# dmarc records are kept in TXT records for _dmarc.domain_name.
try:
if domain.dmarc is None:
Expand Down Expand Up @@ -773,6 +782,7 @@ def dmarc_scan(resolver, domain):


def find_host_from_ip(resolver, ip_addr):
"""Find the host name for a given IP address."""
# Use TCP, since we care about the content and correctness of the records
# more than whether their records fit in a single UDP packet.
hostname, _ = resolver.query(dns.reversename.from_address(ip_addr), "PTR", tcp=True)
Expand All @@ -789,6 +799,7 @@ def scan(
scan_types,
dns_hostnames,
):
"""Parse a domain's DNS information for mail related records."""
#
# Configure the dnspython library
#
Expand Down Expand Up @@ -878,9 +889,10 @@ def scan(


def handle_error(prefix, domain, error, syntax_error=False):
"""Handle an error by logging via the Python logging library and
recording it in the debug_info or syntax_error members of the
trustymail.Domain object.
"""Handle the provided error by logging a message and storing it in the Domain object.
Logging is performed via the Python logging library and recording it in the
debug_info or syntax_error members of the trustymail.Domain object.
Since the "Debug Info" and "Syntax Error" fields in the CSV output
of trustymail come directly from the debug_info and syntax_error
Expand Down Expand Up @@ -946,11 +958,12 @@ def handle_error(prefix, domain, error, syntax_error=False):


def handle_syntax_error(prefix, domain, error):
"""Convenience method for handle_error"""
"""Handle a syntax error by passing it to handle_error()."""
handle_error(prefix, domain, error, syntax_error=True)


def generate_csv(domains, file_name):
"""Generate a CSV file with the given domain information."""
with open(file_name, "w", encoding="utf-8", newline="\n") as output_file:
writer = csv.DictWriter(
output_file, fieldnames=domains[0].generate_results().keys()
Expand All @@ -965,6 +978,7 @@ def generate_csv(domains, file_name):


def generate_json(domains):
"""Generate a JSON string with the given domain information."""
output = []
for domain in domains:
output.append(domain.generate_results())
Expand All @@ -974,6 +988,7 @@ def generate_json(domains):

# Taken from pshtt to keep formatting similar
def format_datetime(obj):
"""Format the provided datetime information."""
if isinstance(obj, datetime.date):
return obj.isoformat()
elif isinstance(obj, str):
Expand All @@ -983,7 +998,7 @@ def format_datetime(obj):


def remove_quotes(txt_record):
"""Remove double quotes and contatenate strings in a DNS TXT record
"""Remove double quotes and contatenate strings in a DNS TXT record.
A DNS TXT record can contain multiple double-quoted strings, and
in that case the client has to remove the quotes and concatenate the
Expand Down
4 changes: 4 additions & 0 deletions tests/test_trustymail.py
@@ -1,7 +1,11 @@
"""Tests for the trustymail module."""
# Standard Python Libraries
import unittest


class TestLiveliness(unittest.TestCase):
"""Test the liveliness of a domain."""

def test_domain_list_parsing(self):
"""Test that a domain list is correctly parsed."""
pass

0 comments on commit fa93631

Please sign in to comment.