From 3877af6619f690e64fa750f83e12c7a744336a25 Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Thu, 21 Jun 2018 17:27:19 +0300 Subject: [PATCH] Gradually increasing HSTS max-age (#5912) This PR adds the functionality to enhance Apache configuration to include HTTP Strict Transport Security header with a low initial max-age value. The max-age value will get increased on every (scheduled) run of certbot renew regardless of the certificate actually getting renewed, if the last increase took place longer than ten hours ago. The increase steps are visible in constants.AUTOHSTS_STEPS. Upon the first actual renewal after reaching the maximum increase step, the max-age value will be made "permanent" and will get value of one year. To achieve accurate VirtualHost discovery on subsequent runs, a comment with unique id string will be added to each enhanced VirtualHost. * AutoHSTS code rebased on master * Fixes to match the changes in master * Make linter happy with metaclass registration * Address small review comments * Use new enhancement interfaces * New style enhancement changes * Do not allow --hsts and --auto-hsts simultaneuously * MyPy annotation fixes and added test * Change oldest requrements to point to local certbot core version * Enable new style enhancements for run and install verbs * Test refactor * New test class for main.install tests * Move a test to a correct test class --- certbot-apache/certbot_apache/apache_util.py | 6 + certbot-apache/certbot_apache/configurator.py | 312 +++++++++++++++++- certbot-apache/certbot_apache/constants.py | 13 + certbot-apache/certbot_apache/parser.py | 32 ++ .../certbot_apache/tests/autohsts_test.py | 181 ++++++++++ .../certbot_apache/tests/configurator_test.py | 15 + .../certbot_apache/tests/parser_test.py | 7 + certbot-apache/local-oldest-requirements.txt | 2 +- certbot-apache/setup.py | 2 +- certbot/cli.py | 8 + certbot/constants.py | 1 + certbot/main.py | 44 ++- certbot/plugins/enhancements.py | 159 +++++++++ certbot/plugins/enhancements_test.py | 65 ++++ certbot/tests/main_test.py | 82 ++++- certbot/tests/renewupdater_test.py | 87 +++-- certbot/updater.py | 46 +++ 17 files changed, 1027 insertions(+), 35 deletions(-) create mode 100644 certbot-apache/certbot_apache/tests/autohsts_test.py create mode 100644 certbot/plugins/enhancements.py create mode 100644 certbot/plugins/enhancements_test.py diff --git a/certbot-apache/certbot_apache/apache_util.py b/certbot-apache/certbot_apache/apache_util.py index f03c9da8793..62342004fc8 100644 --- a/certbot-apache/certbot_apache/apache_util.py +++ b/certbot-apache/certbot_apache/apache_util.py @@ -1,4 +1,5 @@ """ Utility functions for certbot-apache plugin """ +import binascii import os from certbot import util @@ -98,3 +99,8 @@ def parse_define_file(filepath, varname): var_parts = v[2:].partition("=") return_vars[var_parts[0]] = var_parts[2] return return_vars + + +def unique_id(): + """ Returns an unique id to be used as a VirtualHost identifier""" + return binascii.hexlify(os.urandom(16)).decode("utf-8") diff --git a/certbot-apache/certbot_apache/configurator.py b/certbot-apache/certbot_apache/configurator.py index 861fe4458d7..ab83a5332d9 100644 --- a/certbot-apache/certbot_apache/configurator.py +++ b/certbot-apache/certbot_apache/configurator.py @@ -13,7 +13,7 @@ import zope.interface from acme import challenges -from acme.magic_typing import DefaultDict, Dict, List, Set # pylint: disable=unused-import, no-name-in-module +from acme.magic_typing import Any, DefaultDict, Dict, List, Set, Union # pylint: disable=unused-import, no-name-in-module from certbot import errors from certbot import interfaces @@ -22,6 +22,7 @@ from certbot.achallenges import KeyAuthorizationAnnotatedChallenge # pylint: disable=unused-import from certbot.plugins import common from certbot.plugins.util import path_surgery +from certbot.plugins.enhancements import AutoHSTSEnhancement from certbot_apache import apache_util from certbot_apache import augeas_configurator @@ -160,6 +161,8 @@ def __init__(self, *args, **kwargs): self._wildcard_vhosts = dict() # type: Dict[str, List[obj.VirtualHost]] # Maps enhancements to vhosts we've enabled the enhancement for self._enhanced_vhosts = defaultdict(set) # type: DefaultDict[str, Set[obj.VirtualHost]] + # Temporary state for AutoHSTS enhancement + self._autohsts = {} # type: Dict[str, Dict[str, Union[int, float]]] # These will be set in the prepare function self.parser = None @@ -1472,6 +1475,67 @@ def _add_name_vhost_if_necessary(self, vhost): if need_to_save: self.save() + def find_vhost_by_id(self, id_str): + """ + Searches through VirtualHosts and tries to match the id in a comment + + :param str id_str: Id string for matching + + :returns: The matched VirtualHost or None + :rtype: :class:`~certbot_apache.obj.VirtualHost` or None + + :raises .errors.PluginError: If no VirtualHost is found + """ + + for vh in self.vhosts: + if self._find_vhost_id(vh) == id_str: + return vh + msg = "No VirtualHost with ID {} was found.".format(id_str) + logger.warning(msg) + raise errors.PluginError(msg) + + def _find_vhost_id(self, vhost): + """Tries to find the unique ID from the VirtualHost comments. This is + used for keeping track of VirtualHost directive over time. + + :param vhost: Virtual host to add the id + :type vhost: :class:`~certbot_apache.obj.VirtualHost` + + :returns: The unique ID or None + :rtype: str or None + """ + + # Strip the {} off from the format string + search_comment = constants.MANAGED_COMMENT_ID.format("") + + id_comment = self.parser.find_comments(search_comment, vhost.path) + if id_comment: + # Use the first value, multiple ones shouldn't exist + comment = self.parser.get_arg(id_comment[0]) + return comment.split(" ")[-1] + return None + + def add_vhost_id(self, vhost): + """Adds an unique ID to the VirtualHost as a comment for mapping back + to it on later invocations, as the config file order might have changed. + If ID already exists, returns that instead. + + :param vhost: Virtual host to add or find the id + :type vhost: :class:`~certbot_apache.obj.VirtualHost` + + :returns: The unique ID for vhost + :rtype: str or None + """ + + vh_id = self._find_vhost_id(vhost) + if vh_id: + return vh_id + + id_string = apache_util.unique_id() + comment = constants.MANAGED_COMMENT_ID.format(id_string) + self.parser.add_comment(vhost.path, comment) + return id_string + def _escape(self, fp): fp = fp.replace(",", "\\,") fp = fp.replace("[", "\\[") @@ -1531,6 +1595,78 @@ def enhance(self, domain, enhancement, options=None): logger.warning("Failed %s for %s", enhancement, domain) raise + def _autohsts_increase(self, vhost, id_str, nextstep): + """Increase the AutoHSTS max-age value + + :param vhost: Virtual host object to modify + :type vhost: :class:`~certbot_apache.obj.VirtualHost` + + :param str id_str: The unique ID string of VirtualHost + + :param int nextstep: Next AutoHSTS max-age value index + + """ + nextstep_value = constants.AUTOHSTS_STEPS[nextstep] + self._autohsts_write(vhost, nextstep_value) + self._autohsts[id_str] = {"laststep": nextstep, "timestamp": time.time()} + + def _autohsts_write(self, vhost, nextstep_value): + """ + Write the new HSTS max-age value to the VirtualHost file + """ + + hsts_dirpath = None + header_path = self.parser.find_dir("Header", None, vhost.path) + if header_path: + pat = '(?:[ "]|^)(strict-transport-security)(?:[ "]|$)' + for match in header_path: + if re.search(pat, self.aug.get(match).lower()): + hsts_dirpath = match + if not hsts_dirpath: + err_msg = ("Certbot was unable to find the existing HSTS header " + "from the VirtualHost at path {0}.").format(vhost.filep) + raise errors.PluginError(err_msg) + + # Prepare the HSTS header value + hsts_maxage = "\"max-age={0}\"".format(nextstep_value) + + # Update the header + # Our match statement was for string strict-transport-security, but + # we need to update the value instead. The next index is for the value + hsts_dirpath = hsts_dirpath.replace("arg[3]", "arg[4]") + self.aug.set(hsts_dirpath, hsts_maxage) + note_msg = ("Increasing HSTS max-age value to {0} for VirtualHost " + "in {1}\n".format(nextstep_value, vhost.filep)) + logger.debug(note_msg) + self.save_notes += note_msg + self.save(note_msg) + + def _autohsts_fetch_state(self): + """ + Populates the AutoHSTS state from the pluginstorage + """ + try: + self._autohsts = self.storage.fetch("autohsts") + except KeyError: + self._autohsts = dict() + + def _autohsts_save_state(self): + """ + Saves the state of AutoHSTS object to pluginstorage + """ + self.storage.put("autohsts", self._autohsts) + self.storage.save() + + def _autohsts_vhost_in_lineage(self, vhost, lineage): + """ + Searches AutoHSTS managed VirtualHosts that belong to the lineage. + Matches the private key path. + """ + + return bool( + self.parser.find_dir("SSLCertificateKeyFile", + lineage.key_path, vhost.path)) + def _enable_ocsp_stapling(self, ssl_vhost, unused_options): """Enables OCSP Stapling @@ -2158,3 +2294,177 @@ def install_ssl_options_conf(self, options_ssl, options_ssl_digest): # to be modified. return common.install_version_controlled_file(options_ssl, options_ssl_digest, self.constant("MOD_SSL_CONF_SRC"), constants.ALL_SSL_OPTIONS_HASHES) + + def enable_autohsts(self, _unused_lineage, domains): + """ + Enable the AutoHSTS enhancement for defined domains + + :param _unused_lineage: Certificate lineage object, unused + :type _unused_lineage: certbot.storage.RenewableCert + + :param domains: List of domains in certificate to enhance + :type domains: str + """ + + self._autohsts_fetch_state() + _enhanced_vhosts = [] + for d in domains: + matched_vhosts = self.choose_vhosts(d, create_if_no_ssl=False) + # We should be handling only SSL vhosts for AutoHSTS + vhosts = [vhost for vhost in matched_vhosts if vhost.ssl] + + if not vhosts: + msg_tmpl = ("Certbot was not able to find SSL VirtualHost for a " + "domain {0} for enabling AutoHSTS enhancement.") + msg = msg_tmpl.format(d) + logger.warning(msg) + raise errors.PluginError(msg) + for vh in vhosts: + try: + self._enable_autohsts_domain(vh) + _enhanced_vhosts.append(vh) + except errors.PluginEnhancementAlreadyPresent: + if vh in _enhanced_vhosts: + continue + msg = ("VirtualHost for domain {0} in file {1} has a " + + "String-Transport-Security header present, exiting.") + raise errors.PluginEnhancementAlreadyPresent( + msg.format(d, vh.filep)) + if _enhanced_vhosts: + note_msg = "Enabling AutoHSTS" + self.save(note_msg) + logger.info(note_msg) + self.restart() + + # Save the current state to pluginstorage + self._autohsts_save_state() + + def _enable_autohsts_domain(self, ssl_vhost): + """Do the initial AutoHSTS deployment to a vhost + + :param ssl_vhost: The VirtualHost object to deploy the AutoHSTS + :type ssl_vhost: :class:`~certbot_apache.obj.VirtualHost` or None + + :raises errors.PluginEnhancementAlreadyPresent: When already enhanced + + """ + # This raises the exception + self._verify_no_matching_http_header(ssl_vhost, + "Strict-Transport-Security") + + if "headers_module" not in self.parser.modules: + self.enable_mod("headers") + # Prepare the HSTS header value + hsts_header = constants.HEADER_ARGS["Strict-Transport-Security"][:-1] + initial_maxage = constants.AUTOHSTS_STEPS[0] + hsts_header.append("\"max-age={0}\"".format(initial_maxage)) + + # Add ID to the VirtualHost for mapping back to it later + uniq_id = self.add_vhost_id(ssl_vhost) + self.save_notes += "Adding unique ID {0} to VirtualHost in {1}\n".format( + uniq_id, ssl_vhost.filep) + # Add the actual HSTS header + self.parser.add_dir(ssl_vhost.path, "Header", hsts_header) + note_msg = ("Adding gradually increasing HSTS header with initial value " + "of {0} to VirtualHost in {1}\n".format( + initial_maxage, ssl_vhost.filep)) + self.save_notes += note_msg + + # Save the current state to pluginstorage + self._autohsts[uniq_id] = {"laststep": 0, "timestamp": time.time()} + + def update_autohsts(self, _unused_domain): + """ + Increase the AutoHSTS values of VirtualHosts that the user has enabled + this enhancement for. + + :param _unused_domain: Not currently used + :type _unused_domain: Not Available + + """ + self._autohsts_fetch_state() + if not self._autohsts: + # No AutoHSTS enabled for any domain + return + curtime = time.time() + save_and_restart = False + for id_str, config in list(self._autohsts.items()): + if config["timestamp"] + constants.AUTOHSTS_FREQ > curtime: + # Skip if last increase was < AUTOHSTS_FREQ ago + continue + nextstep = config["laststep"] + 1 + if nextstep < len(constants.AUTOHSTS_STEPS): + # Have not reached the max value yet + try: + vhost = self.find_vhost_by_id(id_str) + except errors.PluginError: + msg = ("Could not find VirtualHost with ID {0}, disabling " + "AutoHSTS for this VirtualHost").format(id_str) + logger.warning(msg) + # Remove the orphaned AutoHSTS entry from pluginstorage + self._autohsts.pop(id_str) + continue + self._autohsts_increase(vhost, id_str, nextstep) + msg = ("Increasing HSTS max-age value for VirtualHost with id " + "{0}").format(id_str) + self.save_notes += msg + save_and_restart = True + + if save_and_restart: + self.save("Increased HSTS max-age values") + self.restart() + + self._autohsts_save_state() + + def deploy_autohsts(self, lineage): + """ + Checks if autohsts vhost has reached maximum auto-increased value + and changes the HSTS max-age to a high value. + + :param lineage: Certificate lineage object + :type lineage: certbot.storage.RenewableCert + """ + self._autohsts_fetch_state() + if not self._autohsts: + # No autohsts enabled for any vhost + return + + vhosts = [] + affected_ids = [] + # Copy, as we are removing from the dict inside the loop + for id_str, config in list(self._autohsts.items()): + if config["laststep"]+1 >= len(constants.AUTOHSTS_STEPS): + # max value reached, try to make permanent + try: + vhost = self.find_vhost_by_id(id_str) + except errors.PluginError: + msg = ("VirtualHost with id {} was not found, unable to " + "make HSTS max-age permanent.").format(id_str) + logger.warning(msg) + self._autohsts.pop(id_str) + continue + if self._autohsts_vhost_in_lineage(vhost, lineage): + vhosts.append(vhost) + affected_ids.append(id_str) + + save_and_restart = False + for vhost in vhosts: + self._autohsts_write(vhost, constants.AUTOHSTS_PERMANENT) + msg = ("Strict-Transport-Security max-age value for " + "VirtualHost in {0} was made permanent.").format(vhost.filep) + logger.debug(msg) + self.save_notes += msg+"\n" + save_and_restart = True + + if save_and_restart: + self.save("Made HSTS max-age permanent") + self.restart() + + for id_str in affected_ids: + self._autohsts.pop(id_str) + + # Update AutoHSTS storage (We potentially removed vhosts from managed) + self._autohsts_save_state() + + +AutoHSTSEnhancement.register(ApacheConfigurator) # pylint: disable=no-member diff --git a/certbot-apache/certbot_apache/constants.py b/certbot-apache/certbot_apache/constants.py index fd6a9eb11fe..23a7b7afd03 100644 --- a/certbot-apache/certbot_apache/constants.py +++ b/certbot-apache/certbot_apache/constants.py @@ -48,3 +48,16 @@ HEADER_ARGS = {"Strict-Transport-Security": HSTS_ARGS, "Upgrade-Insecure-Requests": UIR_ARGS} + +AUTOHSTS_STEPS = [60, 300, 900, 3600, 21600, 43200, 86400] +"""AutoHSTS increase steps: 1min, 5min, 15min, 1h, 6h, 12h, 24h""" + +AUTOHSTS_PERMANENT = 31536000 +"""Value for the last max-age of HSTS""" + +AUTOHSTS_FREQ = 172800 +"""Minimum time since last increase to perform a new one: 48h""" + +MANAGED_COMMENT = "DO NOT REMOVE - Managed by Certbot" +MANAGED_COMMENT_ID = MANAGED_COMMENT+", VirtualHost id: {0}" +"""Managed by Certbot comments and the VirtualHost identification template""" diff --git a/certbot-apache/certbot_apache/parser.py b/certbot-apache/certbot_apache/parser.py index 43878eda227..02337f8d451 100644 --- a/certbot-apache/certbot_apache/parser.py +++ b/certbot-apache/certbot_apache/parser.py @@ -16,6 +16,7 @@ class ApacheParser(object): + # pylint: disable=too-many-public-methods """Class handles the fine details of parsing the Apache Configuration. .. todo:: Make parsing general... remove sites-available etc... @@ -350,6 +351,37 @@ def add_dir_beginning(self, aug_conf_path, dirname, args): else: self.aug.set(first_dir + "/arg", args) + def add_comment(self, aug_conf_path, comment): + """Adds the comment to the augeas path + + :param str aug_conf_path: Augeas configuration path to add directive + :param str comment: Comment content + + """ + self.aug.set(aug_conf_path + "/#comment[last() + 1]", comment) + + def find_comments(self, arg, start=None): + """Finds a comment with specified content from the provided DOM path + + :param str arg: Comment content to search + :param str start: Beginning Augeas path to begin looking + + :returns: List of augeas paths containing the comment content + :rtype: list + + """ + if not start: + start = get_aug_path(self.root) + + comments = self.aug.match("%s//*[label() = '#comment']" % start) + + results = [] + for comment in comments: + c_content = self.aug.get(comment) + if c_content and arg in c_content: + results.append(comment) + return results + def find_dir(self, directive, arg=None, start=None, exclude=True): """Finds directive in the configuration. diff --git a/certbot-apache/certbot_apache/tests/autohsts_test.py b/certbot-apache/certbot_apache/tests/autohsts_test.py new file mode 100644 index 00000000000..86d985079f5 --- /dev/null +++ b/certbot-apache/certbot_apache/tests/autohsts_test.py @@ -0,0 +1,181 @@ +# pylint: disable=too-many-public-methods,too-many-lines +"""Test for certbot_apache.configurator AutoHSTS functionality""" +import re +import unittest +import mock +# six is used in mock.patch() +import six # pylint: disable=unused-import + +from certbot import errors +from certbot_apache import constants +from certbot_apache.tests import util + + +class AutoHSTSTest(util.ApacheTest): + """Tests for AutoHSTS feature""" + # pylint: disable=protected-access + + def setUp(self): # pylint: disable=arguments-differ + super(AutoHSTSTest, self).setUp() + + self.config = util.get_apache_configurator( + self.config_path, self.vhost_path, self.config_dir, self.work_dir) + self.config.parser.modules.add("headers_module") + self.config.parser.modules.add("mod_headers.c") + self.config.parser.modules.add("ssl_module") + self.config.parser.modules.add("mod_ssl.c") + + self.vh_truth = util.get_vh_truth( + self.temp_dir, "debian_apache_2_4/multiple_vhosts") + + def get_autohsts_value(self, vh_path): + """ Get value from Strict-Transport-Security header """ + header_path = self.config.parser.find_dir("Header", None, vh_path) + if header_path: + pat = '(?:[ "]|^)(strict-transport-security)(?:[ "]|$)' + for head in header_path: + if re.search(pat, self.config.parser.aug.get(head).lower()): + return self.config.parser.aug.get(head.replace("arg[3]", + "arg[4]")) + + @mock.patch("certbot_apache.configurator.ApacheConfigurator.restart") + @mock.patch("certbot_apache.configurator.ApacheConfigurator.enable_mod") + def test_autohsts_enable_headers_mod(self, mock_enable, _restart): + self.config.parser.modules.discard("headers_module") + self.config.parser.modules.discard("mod_header.c") + self.config.enable_autohsts(mock.MagicMock(), ["ocspvhost.com"]) + self.assertTrue(mock_enable.called) + + @mock.patch("certbot_apache.configurator.ApacheConfigurator.restart") + def test_autohsts_deploy_already_exists(self, _restart): + self.config.enable_autohsts(mock.MagicMock(), ["ocspvhost.com"]) + self.assertRaises(errors.PluginEnhancementAlreadyPresent, + self.config.enable_autohsts, + mock.MagicMock(), ["ocspvhost.com"]) + + @mock.patch("certbot_apache.constants.AUTOHSTS_FREQ", 0) + @mock.patch("certbot_apache.configurator.ApacheConfigurator.restart") + def test_autohsts_increase(self, _mock_restart): + maxage = "\"max-age={0}\"" + initial_val = maxage.format(constants.AUTOHSTS_STEPS[0]) + inc_val = maxage.format(constants.AUTOHSTS_STEPS[1]) + + self.config.enable_autohsts(mock.MagicMock(), ["ocspvhost.com"]) + # Verify initial value + self.assertEquals(self.get_autohsts_value(self.vh_truth[7].path), + initial_val) + # Increase + self.config.update_autohsts(mock.MagicMock()) + # Verify increased value + self.assertEquals(self.get_autohsts_value(self.vh_truth[7].path), + inc_val) + + @mock.patch("certbot_apache.configurator.ApacheConfigurator.restart") + @mock.patch("certbot_apache.configurator.ApacheConfigurator._autohsts_increase") + def test_autohsts_increase_noop(self, mock_increase, _restart): + maxage = "\"max-age={0}\"" + initial_val = maxage.format(constants.AUTOHSTS_STEPS[0]) + self.config.enable_autohsts(mock.MagicMock(), ["ocspvhost.com"]) + # Verify initial value + self.assertEquals(self.get_autohsts_value(self.vh_truth[7].path), + initial_val) + + self.config.update_autohsts(mock.MagicMock()) + # Freq not patched, so value shouldn't increase + self.assertFalse(mock_increase.called) + + + @mock.patch("certbot_apache.configurator.ApacheConfigurator.restart") + @mock.patch("certbot_apache.constants.AUTOHSTS_FREQ", 0) + def test_autohsts_increase_no_header(self, _restart): + self.config.enable_autohsts(mock.MagicMock(), ["ocspvhost.com"]) + # Remove the header + dir_locs = self.config.parser.find_dir("Header", None, + self.vh_truth[7].path) + dir_loc = "/".join(dir_locs[0].split("/")[:-1]) + self.config.parser.aug.remove(dir_loc) + self.assertRaises(errors.PluginError, + self.config.update_autohsts, + mock.MagicMock()) + + @mock.patch("certbot_apache.constants.AUTOHSTS_FREQ", 0) + @mock.patch("certbot_apache.configurator.ApacheConfigurator.restart") + def test_autohsts_increase_and_make_permanent(self, _mock_restart): + maxage = "\"max-age={0}\"" + max_val = maxage.format(constants.AUTOHSTS_PERMANENT) + mock_lineage = mock.MagicMock() + mock_lineage.key_path = "/etc/apache2/ssl/key-certbot_15.pem" + self.config.enable_autohsts(mock.MagicMock(), ["ocspvhost.com"]) + for i in range(len(constants.AUTOHSTS_STEPS)-1): + # Ensure that value is not made permanent prematurely + self.config.deploy_autohsts(mock_lineage) + self.assertNotEquals(self.get_autohsts_value(self.vh_truth[7].path), + max_val) + self.config.update_autohsts(mock.MagicMock()) + # Value should match pre-permanent increment step + cur_val = maxage.format(constants.AUTOHSTS_STEPS[i+1]) + self.assertEquals(self.get_autohsts_value(self.vh_truth[7].path), + cur_val) + # Make permanent + self.config.deploy_autohsts(mock_lineage) + self.assertEquals(self.get_autohsts_value(self.vh_truth[7].path), + max_val) + + def test_autohsts_update_noop(self): + with mock.patch("time.time") as mock_time: + # Time mock is used to make sure that the execution does not + # continue when no autohsts entries exist in pluginstorage + self.config.update_autohsts(mock.MagicMock()) + self.assertFalse(mock_time.called) + + def test_autohsts_make_permanent_noop(self): + self.config.storage.put = mock.MagicMock() + self.config.deploy_autohsts(mock.MagicMock()) + # Make sure that the execution does not continue when no entries in store + self.assertFalse(self.config.storage.put.called) + + @mock.patch("certbot_apache.display_ops.select_vhost") + def test_autohsts_no_ssl_vhost(self, mock_select): + mock_select.return_value = self.vh_truth[0] + with mock.patch("certbot_apache.configurator.logger.warning") as mock_log: + self.assertRaises(errors.PluginError, + self.config.enable_autohsts, + mock.MagicMock(), "invalid.example.com") + self.assertTrue( + "Certbot was not able to find SSL" in mock_log.call_args[0][0]) + + @mock.patch("certbot_apache.configurator.ApacheConfigurator.restart") + @mock.patch("certbot_apache.configurator.ApacheConfigurator.add_vhost_id") + def test_autohsts_dont_enhance_twice(self, mock_id, _restart): + mock_id.return_value = "1234567" + self.config.enable_autohsts(mock.MagicMock(), + ["ocspvhost.com", "ocspvhost.com"]) + self.assertEquals(mock_id.call_count, 1) + + def test_autohsts_remove_orphaned(self): + # pylint: disable=protected-access + self.config._autohsts_fetch_state() + self.config._autohsts["orphan_id"] = {"laststep": 0, "timestamp": 0} + + self.config._autohsts_save_state() + self.config.update_autohsts(mock.MagicMock()) + self.assertFalse("orphan_id" in self.config._autohsts) + # Make sure it's removed from the pluginstorage file as well + self.config._autohsts = None + self.config._autohsts_fetch_state() + self.assertFalse(self.config._autohsts) + + def test_autohsts_make_permanent_vhost_not_found(self): + # pylint: disable=protected-access + self.config._autohsts_fetch_state() + self.config._autohsts["orphan_id"] = {"laststep": 999, "timestamp": 0} + self.config._autohsts_save_state() + with mock.patch("certbot_apache.configurator.logger.warning") as mock_log: + self.config.deploy_autohsts(mock.MagicMock()) + self.assertTrue(mock_log.called) + self.assertTrue( + "VirtualHost with id orphan_id was not" in mock_log.call_args[0][0]) + + +if __name__ == "__main__": + unittest.main() # pragma: no cover diff --git a/certbot-apache/certbot_apache/tests/configurator_test.py b/certbot-apache/certbot_apache/tests/configurator_test.py index 23c1ee82bd8..35026263449 100644 --- a/certbot-apache/certbot_apache/tests/configurator_test.py +++ b/certbot-apache/certbot_apache/tests/configurator_test.py @@ -1487,6 +1487,21 @@ def test_enhance_wildcard_no_install(self, mock_choose): "Upgrade-Insecure-Requests") self.assertTrue(mock_choose.called) + def test_add_vhost_id(self): + for vh in [self.vh_truth[0], self.vh_truth[1], self.vh_truth[2]]: + vh_id = self.config.add_vhost_id(vh) + self.assertEqual(vh, self.config.find_vhost_by_id(vh_id)) + + def test_find_vhost_by_id_404(self): + self.assertRaises(errors.PluginError, + self.config.find_vhost_by_id, + "nonexistent") + + def test_add_vhost_id_already_exists(self): + first_id = self.config.add_vhost_id(self.vh_truth[0]) + second_id = self.config.add_vhost_id(self.vh_truth[0]) + self.assertEqual(first_id, second_id) + class AugeasVhostsTest(util.ApacheTest): """Test vhosts with illegal names dependent on augeas version.""" diff --git a/certbot-apache/certbot_apache/tests/parser_test.py b/certbot-apache/certbot_apache/tests/parser_test.py index 4496781c9e6..f95f1b34680 100644 --- a/certbot-apache/certbot_apache/tests/parser_test.py +++ b/certbot-apache/certbot_apache/tests/parser_test.py @@ -299,6 +299,13 @@ def test_update_runtime_vars_bad_exit(self, mock_popen): errors.MisconfigurationError, self.parser.update_runtime_variables) + def test_add_comment(self): + from certbot_apache.parser import get_aug_path + self.parser.add_comment(get_aug_path(self.parser.loc["name"]), "123456") + comm = self.parser.find_comments("123456") + self.assertEquals(len(comm), 1) + self.assertTrue(self.parser.loc["name"] in comm[0]) + class ParserInitTest(util.ApacheTest): def setUp(self): # pylint: disable=arguments-differ diff --git a/certbot-apache/local-oldest-requirements.txt b/certbot-apache/local-oldest-requirements.txt index 4e4aadbd8ae..e70ac0c7f2d 100644 --- a/certbot-apache/local-oldest-requirements.txt +++ b/certbot-apache/local-oldest-requirements.txt @@ -1,2 +1,2 @@ acme[dev]==0.25.0 -certbot[dev]==0.21.1 +-e .[dev] diff --git a/certbot-apache/setup.py b/certbot-apache/setup.py index 7a0dc43bfc8..8d15e79713d 100644 --- a/certbot-apache/setup.py +++ b/certbot-apache/setup.py @@ -8,7 +8,7 @@ # acme/certbot version. install_requires = [ 'acme>0.24.0', - 'certbot>=0.21.1', + 'certbot>0.25.1', 'mock', 'python-augeas', 'setuptools', diff --git a/certbot/cli.py b/certbot/cli.py index 24e0bddac1d..81e926e2f4f 100644 --- a/certbot/cli.py +++ b/certbot/cli.py @@ -32,6 +32,7 @@ from certbot.display import util as display_util from certbot.plugins import disco as plugins_disco +import certbot.plugins.enhancements as enhancements import certbot.plugins.selection as plugin_selection logger = logging.getLogger(__name__) @@ -627,6 +628,10 @@ def parse_args(self): raise errors.Error("Using --allow-subset-of-names with a" " wildcard domain is not supported.") + if parsed_args.hsts and parsed_args.auto_hsts: + raise errors.Error( + "Parameters --hsts and --auto-hsts cannot be used simultaneously.") + possible_deprecation_warning(parsed_args) return parsed_args @@ -1228,6 +1233,9 @@ def prepare_and_parse_args(plugins, args, detect_defaults=False): # pylint: dis helpful.add_deprecated_argument("--agree-dev-preview", 0) helpful.add_deprecated_argument("--dialog", 0) + # Populate the command line parameters for new style enhancements + enhancements.populate_cli(helpful.add) + _create_subparsers(helpful) _paths_parser(helpful) # _plugins_parsing should be the last thing to act upon the main diff --git a/certbot/constants.py b/certbot/constants.py index 93bc269af63..7047424e5b7 100644 --- a/certbot/constants.py +++ b/certbot/constants.py @@ -58,6 +58,7 @@ rsa_key_size=2048, must_staple=False, redirect=None, + auto_hsts=False, hsts=None, uir=None, staple=None, diff --git a/certbot/main.py b/certbot/main.py index 089eab0c3dd..a457919d41c 100644 --- a/certbot/main.py +++ b/certbot/main.py @@ -36,7 +36,7 @@ from certbot.display import util as display_util, ops as display_ops from certbot.plugins import disco as plugins_disco from certbot.plugins import selection as plug_sel - +from certbot.plugins import enhancements USER_CANCELLED = ("User chose to cancel the operation and may " "reinvoke the client.") @@ -750,8 +750,8 @@ def _install_cert(config, le_client, domains, lineage=None): :param le_client: Client object :type le_client: client.Client - :param plugins: List of domains - :type plugins: `list` of `str` + :param domains: List of domains + :type domains: `list` of `str` :param lineage: Certificate lineage object. Defaults to `None` :type lineage: storage.RenewableCert @@ -789,11 +789,19 @@ def install(config, plugins): except errors.PluginSelectionError as e: return str(e) + if not enhancements.are_supported(config, installer): + raise errors.NotSupportedError("One ore more of the requested enhancements " + "are not supported by the selected installer") # If cert-path is defined, populate missing (ie. not overridden) values. # Unfortunately this can't be done in argument parser, as certificate # manager needs the access to renewal directory paths if config.certname: config = _populate_from_certname(config) + elif enhancements.are_requested(config): + # Preflight config check + raise errors.ConfigurationError("One or more of the requested enhancements " + "require --cert-name to be provided") + if config.key_path and config.cert_path: _check_certificate_and_key(config) domains, _ = _find_domains_or_certname(config, installer) @@ -804,6 +812,11 @@ def install(config, plugins): "If your certificate is managed by Certbot, please use --cert-name " "to define which certificate you would like to install.") + if enhancements.are_requested(config): + # In the case where we don't have certname, we have errored out already + lineage = cert_manager.lineage_for_certname(config, config.certname) + enhancements.enable(lineage, domains, installer, config) + def _populate_from_certname(config): """Helper function for install to populate missing config values from lineage defined by --cert-name.""" @@ -881,7 +894,8 @@ def enhance(config, plugins): """ supported_enhancements = ["hsts", "redirect", "uir", "staple"] # Check that at least one enhancement was requested on command line - if not any([getattr(config, enh) for enh in supported_enhancements]): + oldstyle_enh = any([getattr(config, enh) for enh in supported_enhancements]) + if not enhancements.are_requested(config) and not oldstyle_enh: msg = ("Please specify one or more enhancement types to configure. To list " "the available enhancement types, run:\n\n%s --help enhance\n") logger.warning(msg, sys.argv[0]) @@ -892,6 +906,10 @@ def enhance(config, plugins): except errors.PluginSelectionError as e: return str(e) + if not enhancements.are_supported(config, installer): + raise errors.NotSupportedError("One ore more of the requested enhancements " + "are not supported by the selected installer") + certname_question = ("Which certificate would you like to use to enhance " "your configuration?") config.certname = cert_manager.get_certnames( @@ -907,11 +925,15 @@ def enhance(config, plugins): if not domains: raise errors.Error("User cancelled the domain selection. No domains " "defined, exiting.") + + lineage = cert_manager.lineage_for_certname(config, config.certname) if not config.chain_path: - lineage = cert_manager.lineage_for_certname(config, config.certname) config.chain_path = lineage.chain_path - le_client = _init_le_client(config, authenticator=None, installer=installer) - le_client.enhance_config(domains, config.chain_path, ask_redirect=False) + if oldstyle_enh: + le_client = _init_le_client(config, authenticator=None, installer=installer) + le_client.enhance_config(domains, config.chain_path, ask_redirect=False) + if enhancements.are_requested(config): + enhancements.enable(lineage, domains, installer, config) def rollback(config, plugins): @@ -1073,6 +1095,11 @@ def run(config, plugins): # pylint: disable=too-many-branches,too-many-locals except errors.PluginSelectionError as e: return str(e) + # Preflight check for enhancement support by the selected installer + if not enhancements.are_supported(config, installer): + raise errors.NotSupportedError("One ore more of the requested enhancements " + "are not supported by the selected installer") + # TODO: Handle errors from _init_le_client? le_client = _init_le_client(config, authenticator, installer) @@ -1091,6 +1118,9 @@ def run(config, plugins): # pylint: disable=too-many-branches,too-many-locals _install_cert(config, le_client, domains, new_lineage) + if enhancements.are_requested(config) and new_lineage: + enhancements.enable(new_lineage, domains, installer, config) + if lineage is None or not should_get_cert: display_ops.success_installation(domains) else: diff --git a/certbot/plugins/enhancements.py b/certbot/plugins/enhancements.py new file mode 100644 index 00000000000..506abe433b8 --- /dev/null +++ b/certbot/plugins/enhancements.py @@ -0,0 +1,159 @@ +"""New interface style Certbot enhancements""" +import abc +import six + +from certbot import constants + +from acme.magic_typing import Dict, List, Any # pylint: disable=unused-import, no-name-in-module + +def enabled_enhancements(config): + """ + Generator to yield the enabled new style enhancements. + + :param config: Configuration. + :type config: :class:`certbot.interfaces.IConfig` + """ + for enh in _INDEX: + if getattr(config, enh["cli_dest"]): + yield enh + +def are_requested(config): + """ + Checks if one or more of the requested enhancements are those of the new + enhancement interfaces. + + :param config: Configuration. + :type config: :class:`certbot.interfaces.IConfig` + """ + return any(enabled_enhancements(config)) + +def are_supported(config, installer): + """ + Checks that all of the requested enhancements are supported by the + installer. + + :param config: Configuration. + :type config: :class:`certbot.interfaces.IConfig` + + :param installer: Installer object + :type installer: interfaces.IInstaller + + :returns: If all the requested enhancements are supported by the installer + :rtype: bool + """ + for enh in enabled_enhancements(config): + if not isinstance(installer, enh["class"]): + return False + return True + +def enable(lineage, domains, installer, config): + """ + Run enable method for each requested enhancement that is supported. + + :param lineage: Certificate lineage object + :type lineage: certbot.storage.RenewableCert + + :param domains: List of domains in certificate to enhance + :type domains: str + + :param installer: Installer object + :type installer: interfaces.IInstaller + + :param config: Configuration. + :type config: :class:`certbot.interfaces.IConfig` + """ + for enh in enabled_enhancements(config): + getattr(installer, enh["enable_function"])(lineage, domains) + +def populate_cli(add): + """ + Populates the command line flags for certbot.cli.HelpfulParser + + :param add: Add function of certbot.cli.HelpfulParser + :type add: func + """ + for enh in _INDEX: + add(enh["cli_groups"], enh["cli_flag"], action=enh["cli_action"], + dest=enh["cli_dest"], default=enh["cli_flag_default"], + help=enh["cli_help"]) + + +@six.add_metaclass(abc.ABCMeta) +class AutoHSTSEnhancement(object): + """ + Enhancement interface that installer plugins can implement in order to + provide functionality that configures the software to have a + 'Strict-Transport-Security' with initially low max-age value that will + increase over time. + + The plugins implementing new style enhancements are responsible of handling + the saving of configuration checkpoints as well as calling possible restarts + of managed software themselves. + + Methods: + enable_autohsts is called when the header is initially installed using a + low max-age value. + + update_autohsts is called every time when Certbot is run using 'renew' + verb. The max-age value should be increased over time using this method. + + deploy_autohsts is called for every lineage that has had its certificate + renewed. A long HSTS max-age value should be set here, as we should be + confident that the user is able to automatically renew their certificates. + + + """ + + @abc.abstractmethod + def update_autohsts(self, lineage, *args, **kwargs): + """ + Gets called for each lineage every time Certbot is run with 'renew' verb. + Implementation of this method should increase the max-age value. + + :param lineage: Certificate lineage object + :type lineage: certbot.storage.RenewableCert + """ + + @abc.abstractmethod + def deploy_autohsts(self, lineage, *args, **kwargs): + """ + Gets called for a lineage when its certificate is successfully renewed. + Long max-age value should be set in implementation of this method. + + :param lineage: Certificate lineage object + :type lineage: certbot.storage.RenewableCert + """ + + @abc.abstractmethod + def enable_autohsts(self, lineage, domains, *args, **kwargs): + """ + Enables the AutoHSTS enhancement, installing + Strict-Transport-Security header with a low initial value to be increased + over the subsequent runs of Certbot renew. + + :param lineage: Certificate lineage object + :type lineage: certbot.storage.RenewableCert + + :param domains: List of domains in certificate to enhance + :type domains: str + """ + +# This is used to configure internal new style enhancements in Certbot. These +# enhancement interfaces need to be defined in this file. Please do not modify +# this list from plugin code. +_INDEX = [ + { + "name": "AutoHSTS", + "cli_help": "Gradually increasing max-age value for HTTP Strict Transport "+ + "Security security header", + "cli_flag": "--auto-hsts", + "cli_flag_default": constants.CLI_DEFAULTS["auto_hsts"], + "cli_groups": ["security", "enhance"], + "cli_dest": "auto_hsts", + "cli_action": "store_true", + "class": AutoHSTSEnhancement, + "updater_function": "update_autohsts", + "deployer_function": "deploy_autohsts", + "enable_function": "enable_autohsts" + } +] # type: List[Dict[str, Any]] diff --git a/certbot/plugins/enhancements_test.py b/certbot/plugins/enhancements_test.py new file mode 100644 index 00000000000..b69dc98363b --- /dev/null +++ b/certbot/plugins/enhancements_test.py @@ -0,0 +1,65 @@ +"""Tests for new style enhancements""" +import unittest +import mock + +from certbot.plugins import enhancements +from certbot.plugins import null + +import certbot.tests.util as test_util + + +class EnhancementTest(test_util.ConfigTestCase): + """Tests for new style enhancements in certbot.plugins.enhancements""" + + def setUp(self): + super(EnhancementTest, self).setUp() + self.mockinstaller = mock.MagicMock(spec=enhancements.AutoHSTSEnhancement) + + + @test_util.patch_get_utility() + def test_enhancement_enabled_enhancements(self, _): + FAKEINDEX = [ + { + "name": "autohsts", + "cli_dest": "auto_hsts", + }, + { + "name": "somethingelse", + "cli_dest": "something", + } + ] + with mock.patch("certbot.plugins.enhancements._INDEX", FAKEINDEX): + self.config.auto_hsts = True + self.config.something = True + enabled = list(enhancements.enabled_enhancements(self.config)) + self.assertEqual(len(enabled), 2) + self.assertTrue([i for i in enabled if i["name"] == "autohsts"]) + self.assertTrue([i for i in enabled if i["name"] == "somethingelse"]) + + def test_are_requested(self): + self.assertEquals( + len([i for i in enhancements.enabled_enhancements(self.config)]), 0) + self.assertFalse(enhancements.are_requested(self.config)) + self.config.auto_hsts = True + self.assertEquals( + len([i for i in enhancements.enabled_enhancements(self.config)]), 1) + self.assertTrue(enhancements.are_requested(self.config)) + + def test_are_supported(self): + self.config.auto_hsts = True + unsupported = null.Installer(self.config, "null") + self.assertTrue(enhancements.are_supported(self.config, self.mockinstaller)) + self.assertFalse(enhancements.are_supported(self.config, unsupported)) + + def test_enable(self): + self.config.auto_hsts = True + domains = ["example.com", "www.example.com"] + lineage = "lineage" + enhancements.enable(lineage, domains, self.mockinstaller, self.config) + self.assertTrue(self.mockinstaller.enable_autohsts.called) + self.assertEquals(self.mockinstaller.enable_autohsts.call_args[0], + (lineage, domains)) + + +if __name__ == '__main__': + unittest.main() # pragma: no cover diff --git a/certbot/tests/main_test.py b/certbot/tests/main_test.py index 4b251c42118..24d059e53b7 100644 --- a/certbot/tests/main_test.py +++ b/certbot/tests/main_test.py @@ -29,7 +29,9 @@ from certbot import util from certbot.plugins import disco +from certbot.plugins import enhancements from certbot.plugins import manual +from certbot.plugins import null import certbot.tests.util as test_util @@ -52,10 +54,11 @@ def test_handle_identical_cert_request_pending(self): self.assertEqual(ret, ("reinstall", mock_lineage)) -class RunTest(unittest.TestCase): +class RunTest(test_util.ConfigTestCase): """Tests for certbot.main.run.""" def setUp(self): + super(RunTest, self).setUp() self.domain = 'example.org' self.patches = [ mock.patch('certbot.main._get_and_save_cert'), @@ -105,6 +108,15 @@ def test_renewal_success(self): self._call() self.mock_success_renewal.assert_called_once_with([self.domain]) + @mock.patch('certbot.main.plug_sel.choose_configurator_plugins') + def test_run_enhancement_not_supported(self, mock_choose): + mock_choose.return_value = (null.Installer(self.config, "null"), None) + plugins = disco.PluginsRegistry.find_all() + self.config.auto_hsts = True + self.assertRaises(errors.NotSupportedError, + main.run, + self.config, plugins) + class CertonlyTest(unittest.TestCase): """Tests for certbot.main.certonly.""" @@ -1573,12 +1585,14 @@ def test_it(self, mock_util): strict=self.config.strict_permissions) -class EnhanceTest(unittest.TestCase): +class EnhanceTest(test_util.ConfigTestCase): """Tests for certbot.main.enhance.""" def setUp(self): + super(EnhanceTest, self).setUp() self.get_utility_patch = test_util.patch_get_utility() self.mock_get_utility = self.get_utility_patch.start() + self.mockinstaller = mock.MagicMock(spec=enhancements.AutoHSTSEnhancement) def tearDown(self): self.get_utility_patch.stop() @@ -1670,7 +1684,7 @@ def test_user_abort_domains(self, _rec, mock_choose): def test_no_enhancements_defined(self): self.assertRaises(errors.MisconfigurationError, - self._call, ['enhance']) + self._call, ['enhance', '-a', 'null']) @mock.patch('certbot.main.plug_sel.choose_configurator_plugins') @mock.patch('certbot.main.display_ops.choose_values') @@ -1682,5 +1696,67 @@ def test_plugin_selection_error(self, _rec, mock_choose, mock_pick): mock_client = self._call(['enhance', '--hsts']) self.assertFalse(mock_client.enhance_config.called) + @mock.patch('certbot.cert_manager.lineage_for_certname') + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main.plug_sel.pick_installer') + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + @test_util.patch_get_utility() + def test_enhancement_enable(self, _, _rec, mock_inst, mock_choose, mock_lineage): + mock_inst.return_value = self.mockinstaller + mock_choose.return_value = ["example.com", "another.tld"] + mock_lineage.return_value = mock.MagicMock(chain_path="/tmp/nonexistent") + self._call(['enhance', '--auto-hsts']) + self.assertTrue(self.mockinstaller.enable_autohsts.called) + self.assertEquals(self.mockinstaller.enable_autohsts.call_args[0][1], + ["example.com", "another.tld"]) + + @mock.patch('certbot.cert_manager.lineage_for_certname') + @mock.patch('certbot.main.display_ops.choose_values') + @mock.patch('certbot.main.plug_sel.pick_installer') + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + @test_util.patch_get_utility() + def test_enhancement_enable_not_supported(self, _, _rec, mock_inst, mock_choose, mock_lineage): + mock_inst.return_value = null.Installer(self.config, "null") + mock_choose.return_value = ["example.com", "another.tld"] + mock_lineage.return_value = mock.MagicMock(chain_path="/tmp/nonexistent") + self.assertRaises( + errors.NotSupportedError, + self._call, ['enhance', '--auto-hsts']) + + def test_enhancement_enable_conflict(self): + self.assertRaises( + errors.Error, + self._call, ['enhance', '--auto-hsts', '--hsts']) + + +class InstallTest(test_util.ConfigTestCase): + """Tests for certbot.main.install.""" + + def setUp(self): + super(InstallTest, self).setUp() + self.mockinstaller = mock.MagicMock(spec=enhancements.AutoHSTSEnhancement) + + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + @mock.patch('certbot.main.plug_sel.pick_installer') + def test_install_enhancement_not_supported(self, mock_inst, _rec): + mock_inst.return_value = null.Installer(self.config, "null") + plugins = disco.PluginsRegistry.find_all() + self.config.auto_hsts = True + self.assertRaises(errors.NotSupportedError, + main.install, + self.config, plugins) + + @mock.patch('certbot.main.plug_sel.record_chosen_plugins') + @mock.patch('certbot.main.plug_sel.pick_installer') + def test_install_enhancement_no_certname(self, mock_inst, _rec): + mock_inst.return_value = self.mockinstaller + plugins = disco.PluginsRegistry.find_all() + self.config.auto_hsts = True + self.config.certname = None + self.assertRaises(errors.ConfigurationError, + main.install, + self.config, plugins) + + if __name__ == '__main__': unittest.main() # pragma: no cover diff --git a/certbot/tests/renewupdater_test.py b/certbot/tests/renewupdater_test.py index bd1cd891e39..c1b97843e23 100644 --- a/certbot/tests/renewupdater_test.py +++ b/certbot/tests/renewupdater_test.py @@ -6,6 +6,8 @@ from certbot import main from certbot import updater +from certbot.plugins import enhancements + import certbot.tests.util as test_util @@ -14,25 +16,10 @@ class RenewUpdaterTest(test_util.ConfigTestCase): def setUp(self): super(RenewUpdaterTest, self).setUp() - class MockInstallerGenericUpdater(interfaces.GenericUpdater): - """Mock class that implements GenericUpdater""" - def __init__(self, *args, **kwargs): - # pylint: disable=unused-argument - self.restart = mock.MagicMock() - self.callcounter = mock.MagicMock() - def generic_updates(self, lineage, *args, **kwargs): - self.callcounter(*args, **kwargs) - - class MockInstallerRenewDeployer(interfaces.RenewDeployer): - """Mock class that implements RenewDeployer""" - def __init__(self, *args, **kwargs): - # pylint: disable=unused-argument - self.callcounter = mock.MagicMock() - def renew_deploy(self, lineage, *args, **kwargs): - self.callcounter(*args, **kwargs) - - self.generic_updater = MockInstallerGenericUpdater() - self.renew_deployer = MockInstallerRenewDeployer() + self.generic_updater = mock.MagicMock(spec=interfaces.GenericUpdater) + self.generic_updater.restart = mock.MagicMock() + self.renew_deployer = mock.MagicMock(spec=interfaces.RenewDeployer) + self.mockinstaller = mock.MagicMock(spec=enhancements.AutoHSTSEnhancement) @mock.patch('certbot.main._get_and_save_cert') @mock.patch('certbot.plugins.selection.choose_configurator_plugins') @@ -48,16 +35,16 @@ def test_server_updates(self, _, mock_select, mock_getsave): self.assertTrue(mock_generic_updater.restart.called) mock_generic_updater.restart.reset_mock() - mock_generic_updater.callcounter.reset_mock() + mock_generic_updater.generic_updates.reset_mock() updater.run_generic_updaters(self.config, mock.MagicMock(), None) - self.assertEqual(mock_generic_updater.callcounter.call_count, 1) + self.assertEqual(mock_generic_updater.generic_updates.call_count, 1) self.assertFalse(mock_generic_updater.restart.called) def test_renew_deployer(self): lineage = mock.MagicMock() mock_deployer = self.renew_deployer updater.run_renewal_deployer(self.config, lineage, mock_deployer) - self.assertTrue(mock_deployer.callcounter.called_with(lineage)) + self.assertTrue(mock_deployer.renew_deploy.called_with(lineage)) @mock.patch("certbot.updater.logger.debug") def test_updater_skip_dry_run(self, mock_log): @@ -75,6 +62,62 @@ def test_deployer_skip_dry_run(self, mock_log): self.assertEquals(mock_log.call_args[0][0], "Skipping renewal deployer in dry-run mode.") + @mock.patch('certbot.plugins.selection.choose_configurator_plugins') + def test_enhancement_updates(self, mock_select): + mock_select.return_value = (self.mockinstaller, None) + updater.run_generic_updaters(self.config, mock.MagicMock(), None) + self.assertTrue(self.mockinstaller.update_autohsts.called) + self.assertEqual(self.mockinstaller.update_autohsts.call_count, 1) + + def test_enhancement_deployer(self): + updater.run_renewal_deployer(self.config, mock.MagicMock(), + self.mockinstaller) + self.assertTrue(self.mockinstaller.deploy_autohsts.called) + + @mock.patch('certbot.plugins.selection.choose_configurator_plugins') + def test_enhancement_updates_not_called(self, mock_select): + self.config.disable_renew_updates = True + mock_select.return_value = (self.mockinstaller, None) + updater.run_generic_updaters(self.config, mock.MagicMock(), None) + self.assertFalse(self.mockinstaller.update_autohsts.called) + + def test_enhancement_deployer_not_called(self): + self.config.disable_renew_updates = True + updater.run_renewal_deployer(self.config, mock.MagicMock(), + self.mockinstaller) + self.assertFalse(self.mockinstaller.deploy_autohsts.called) + + @mock.patch('certbot.plugins.selection.choose_configurator_plugins') + def test_enhancement_no_updater(self, mock_select): + FAKEINDEX = [ + { + "name": "Test", + "class": enhancements.AutoHSTSEnhancement, + "updater_function": None, + "deployer_function": "deploy_autohsts", + "enable_function": "enable_autohsts" + } + ] + mock_select.return_value = (self.mockinstaller, None) + with mock.patch("certbot.plugins.enhancements._INDEX", FAKEINDEX): + updater.run_generic_updaters(self.config, mock.MagicMock(), None) + self.assertFalse(self.mockinstaller.update_autohsts.called) + + def test_enhancement_no_deployer(self): + FAKEINDEX = [ + { + "name": "Test", + "class": enhancements.AutoHSTSEnhancement, + "updater_function": "deploy_autohsts", + "deployer_function": None, + "enable_function": "enable_autohsts" + } + ] + with mock.patch("certbot.plugins.enhancements._INDEX", FAKEINDEX): + updater.run_renewal_deployer(self.config, mock.MagicMock(), + self.mockinstaller) + self.assertFalse(self.mockinstaller.deploy_autohsts.called) + if __name__ == '__main__': unittest.main() # pragma: no cover diff --git a/certbot/updater.py b/certbot/updater.py index 112cf06efeb..fb7c52f7798 100644 --- a/certbot/updater.py +++ b/certbot/updater.py @@ -5,6 +5,7 @@ from certbot import interfaces from certbot.plugins import selection as plug_sel +import certbot.plugins.enhancements as enhancements logger = logging.getLogger(__name__) @@ -33,6 +34,7 @@ def run_generic_updaters(config, lineage, plugins): logger.warning("Could not choose appropriate plugin for updaters: %s", e) return _run_updaters(lineage, installer, config) + _run_enhancement_updaters(lineage, installer, config) def run_renewal_deployer(config, lineage, installer): """Helper function to run deployer interface method if supported by the used @@ -57,6 +59,7 @@ def run_renewal_deployer(config, lineage, installer): if not config.disable_renew_updates and isinstance(installer, interfaces.RenewDeployer): installer.renew_deploy(lineage) + _run_enhancement_deployers(lineage, installer, config) def _run_updaters(lineage, installer, config): """Helper function to run the updater interface methods if supported by the @@ -74,3 +77,46 @@ def _run_updaters(lineage, installer, config): if not config.disable_renew_updates: if isinstance(installer, interfaces.GenericUpdater): installer.generic_updates(lineage) + +def _run_enhancement_updaters(lineage, installer, config): + """Iterates through known enhancement interfaces. If the installer implements + an enhancement interface and the enhance interface has an updater method, the + updater method gets run. + + :param lineage: Certificate lineage object + :type lineage: storage.RenewableCert + + :param installer: Installer object + :type installer: interfaces.IInstaller + + :param config: Configuration object + :type config: interfaces.IConfig + """ + + if config.disable_renew_updates: + return + for enh in enhancements._INDEX: # pylint: disable=protected-access + if isinstance(installer, enh["class"]) and enh["updater_function"]: + getattr(installer, enh["updater_function"])(lineage) + + +def _run_enhancement_deployers(lineage, installer, config): + """Iterates through known enhancement interfaces. If the installer implements + an enhancement interface and the enhance interface has an deployer method, the + deployer method gets run. + + :param lineage: Certificate lineage object + :type lineage: storage.RenewableCert + + :param installer: Installer object + :type installer: interfaces.IInstaller + + :param config: Configuration object + :type config: interfaces.IConfig + """ + + if config.disable_renew_updates: + return + for enh in enhancements._INDEX: # pylint: disable=protected-access + if isinstance(installer, enh["class"]) and enh["deployer_function"]: + getattr(installer, enh["deployer_function"])(lineage)