Skip to content

Commit

Permalink
test fixes for openshift_certificates_expiry
Browse files Browse the repository at this point in the history
- create pytest fixtures for building certs at runtime
- update tests to use the fixtures
- add tests for load_and_handle_cert
- fix py2/py3 encode/decode issues raised by tests
- add get_extension_count method to fakeOpenSSLCertificate
- avoid using a temp file for passing ssl certificate to openssl
  subprocess
- other test tweaks:
  - exclude conftest.py and tests from coverage report
  - reduce the fail_under to 26%, since the tests being included were
    inflating our coverage
  • Loading branch information
detiber committed Mar 2, 2017
1 parent 5a91f31 commit 293f185
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 352 deletions.
7 changes: 4 additions & 3 deletions .coveragerc
Expand Up @@ -8,11 +8,12 @@ omit =
# TODO(rhcarvalho): this is used to ignore test files from coverage report.
# We can make this less generic when we stick with a single test pattern in
# the repo.
test_*.py
*_tests.py
*/conftest.py
*/test_*.py
*/*_tests.py

[report]
fail_under = 28
fail_under = 26

[html]
directory = cover
116 changes: 42 additions & 74 deletions roles/openshift_certificate_expiry/library/openshift_cert_expiry.py
Expand Up @@ -8,29 +8,27 @@
import io
import os
import subprocess
import sys
import tempfile
import yaml

# File pointers from io.open require unicode inputs when using their
# `write` method
import six
from six.moves import configparser

import yaml
from ansible.module_utils.basic import AnsibleModule

try:
# You can comment this import out and include a 'pass' in this
# block if you're manually testing this module on a NON-ATOMIC
# HOST (or any host that just doesn't have PyOpenSSL
# available). That will force the `load_and_handle_cert` function
# to use the Fake OpenSSL classes.
import OpenSSL.crypto
HAS_OPENSSL = True
except ImportError:
# Some platforms (such as RHEL Atomic) may not have the Python
# OpenSSL library installed. In this case we will use a manual
# work-around to parse each certificate.
#
# Check for 'OpenSSL.crypto' in `sys.modules` later.
pass
HAS_OPENSSL = False

DOCUMENTATION = '''
---
Expand Down Expand Up @@ -158,6 +156,10 @@ def get_extension(self, i):
'subjectAltName'"""
return self.extensions[i]

def get_extension_count(self):
""" get_extension_count """
return len(self.extensions)

def get_notAfter(self):
"""Returns a date stamp as a string in the form
'20180922170439Z'. strptime the result with format param:
Expand Down Expand Up @@ -268,30 +270,23 @@ def load_and_handle_cert(cert_string, now, base64decode=False, ans_module=None):
# around a missing library on the target host.
#
# pylint: disable=redefined-variable-type
if 'OpenSSL.crypto' in sys.modules:
if HAS_OPENSSL:
# No work-around required
cert_loaded = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM, _cert_string)
else:
# Missing library, work-around required. We need to write the
# cert out to disk temporarily so we can run the 'openssl'
# Missing library, work-around required. Run the 'openssl'
# command on it to decode it
_, path = tempfile.mkstemp()
with io.open(path, 'w') as fp:
fp.write(six.u(_cert_string))
fp.flush()

cmd = 'openssl x509 -in {} -text'.format(path)
cmd = 'openssl x509 -text'
try:
openssl_decoded = subprocess.Popen(cmd.split(),
stdout=subprocess.PIPE)
openssl_proc = subprocess.Popen(cmd.split(),
stdout=subprocess.PIPE,
stdin=subprocess.PIPE)
except OSError:
ans_module.fail_json(msg="Error: The 'OpenSSL' python library and CLI command were not found on the target host. Unable to parse any certificates. This host will not be included in generated reports.")
else:
openssl_decoded = openssl_decoded.communicate()[0]
openssl_decoded = openssl_proc.communicate(_cert_string.encode('utf-8'))[0].decode('utf-8')
cert_loaded = FakeOpenSSLCertificate(openssl_decoded)
finally:
os.remove(path)

######################################################################
# Read all possible names from the cert
Expand All @@ -301,34 +296,12 @@ def load_and_handle_cert(cert_string, now, base64decode=False, ans_module=None):

# To read SANs from a cert we must read the subjectAltName
# extension from the X509 Object. What makes this more difficult
# is that pyOpenSSL does not give extensions as a list, nor does
# it provide a count of all loaded extensions.
#
# Rather, extensions are REQUESTED by index. We must iterate over
# all extensions until we find the one called 'subjectAltName'. If
# we don't find that extension we'll eventually request an
# extension at an index where no extension exists (IndexError is
# raised). When that happens we know that the cert has no SANs so
# we break out of the loop.
i = 0
checked_all_extensions = False
while not checked_all_extensions:
try:
# Read the extension at index 'i'
ext = cert_loaded.get_extension(i)
except IndexError:
# We tried to read an extension but it isn't there, that
# means we ran out of extensions to check. Abort
san = None
checked_all_extensions = True
else:
# We were able to load the extension at index 'i'
if ext.get_short_name() == 'subjectAltName':
san = ext
checked_all_extensions = True
else:
# Try reading the next extension
i += 1
# is that pyOpenSSL does not give extensions as an iterable
san = None
for i in range(cert_loaded.get_extension_count()):
ext = cert_loaded.get_extension(i)
if ext.get_short_name() == 'subjectAltName':
san = ext

if san is not None:
# The X509Extension object for subjectAltName prints as a
Expand All @@ -341,9 +314,13 @@ def load_and_handle_cert(cert_string, now, base64decode=False, ans_module=None):
######################################################################

# Grab the expiration date
not_after = cert_loaded.get_notAfter()
# example get_notAfter() => 20180922170439Z
if isinstance(not_after, bytes):
not_after = not_after.decode('utf-8')

cert_expiry_date = datetime.datetime.strptime(
cert_loaded.get_notAfter(),
# example get_notAfter() => 20180922170439Z
not_after,
'%Y%m%d%H%M%SZ')

time_remaining = cert_expiry_date - now
Expand Down Expand Up @@ -455,13 +432,11 @@ def main():
)

# Basic scaffolding for OpenShift specific certs
openshift_base_config_path = module.params['config_base']
openshift_master_config_path = os.path.normpath(
os.path.join(openshift_base_config_path, "master/master-config.yaml")
)
openshift_node_config_path = os.path.normpath(
os.path.join(openshift_base_config_path, "node/node-config.yaml")
)
openshift_base_config_path = os.path.realpath(module.params['config_base'])
openshift_master_config_path = os.path.join(openshift_base_config_path,
"master", "master-config.yaml")
openshift_node_config_path = os.path.join(openshift_base_config_path,
"node", "node-config.yaml")
openshift_cert_check_paths = [
openshift_master_config_path,
openshift_node_config_path,
Expand All @@ -476,9 +451,7 @@ def main():
kubeconfig_paths = []
for m_kube_config in master_kube_configs:
kubeconfig_paths.append(
os.path.normpath(
os.path.join(openshift_base_config_path, "master/%s.kubeconfig" % m_kube_config)
)
os.path.join(openshift_base_config_path, "master", m_kube_config + ".kubeconfig")
)

# Validate some paths we have the ability to do ahead of time
Expand Down Expand Up @@ -527,7 +500,7 @@ def main():
######################################################################
for os_cert in filter_paths(openshift_cert_check_paths):
# Open up that config file and locate the cert and CA
with open(os_cert, 'r') as fp:
with io.open(os_cert, 'r', encoding='utf-8') as fp:
cert_meta = {}
cfg = yaml.load(fp)
# cert files are specified in parsed `fp` as relative to the path
Expand All @@ -542,7 +515,7 @@ def main():
# Load the certificate and the CA, parse their expiration dates into
# datetime objects so we can manipulate them later
for _, v in cert_meta.items():
with open(v, 'r') as fp:
with io.open(v, 'r', encoding='utf-8') as fp:
cert = fp.read()
(cert_subject,
cert_expiry_date,
Expand Down Expand Up @@ -575,7 +548,7 @@ def main():
try:
# Try to read the standard 'node-config.yaml' file to check if
# this host is a node.
with open(openshift_node_config_path, 'r') as fp:
with io.open(openshift_node_config_path, 'r', encoding='utf-8') as fp:
cfg = yaml.load(fp)

# OK, the config file exists, therefore this is a
Expand All @@ -588,7 +561,7 @@ def main():
cfg_path = os.path.dirname(fp.name)
node_kubeconfig = os.path.join(cfg_path, node_masterKubeConfig)

with open(node_kubeconfig, 'r') as fp:
with io.open(node_kubeconfig, 'r', encoding='utf8') as fp:
# Read in the nodes kubeconfig file and grab the good stuff
cfg = yaml.load(fp)

Expand All @@ -613,7 +586,7 @@ def main():
pass

for kube in filter_paths(kubeconfig_paths):
with open(kube, 'r') as fp:
with io.open(kube, 'r', encoding='utf-8') as fp:
# TODO: Maybe consider catching exceptions here?
cfg = yaml.load(fp)

Expand Down Expand Up @@ -656,7 +629,7 @@ def main():
etcd_certs = []
etcd_cert_params.append('dne')
try:
with open('/etc/etcd/etcd.conf', 'r') as fp:
with io.open('/etc/etcd/etcd.conf', 'r', encoding='utf-8') as fp:
etcd_config = configparser.ConfigParser()
# Reason: This check is disabled because the issue was introduced
# during a period where the pylint checks weren't enabled for this file
Expand All @@ -675,7 +648,7 @@ def main():
pass

for etcd_cert in filter_paths(etcd_certs_to_check):
with open(etcd_cert, 'r') as fp:
with io.open(etcd_cert, 'r', encoding='utf-8') as fp:
c = fp.read()
(cert_subject,
cert_expiry_date,
Expand All @@ -697,7 +670,7 @@ def main():
# Now the embedded etcd
######################################################################
try:
with open('/etc/origin/master/master-config.yaml', 'r') as fp:
with io.open('/etc/origin/master/master-config.yaml', 'r', encoding='utf-8') as fp:
cfg = yaml.load(fp)
except IOError:
# Not present
Expand Down Expand Up @@ -864,10 +837,5 @@ def cert_key(item):
)


######################################################################
# It's just the way we do things in Ansible. So disable this warning
#
# pylint: disable=wrong-import-position,import-error
from ansible.module_utils.basic import AnsibleModule # noqa: E402
if __name__ == '__main__':
main()
116 changes: 116 additions & 0 deletions roles/openshift_certificate_expiry/test/conftest.py
@@ -0,0 +1,116 @@
# pylint: disable=missing-docstring,invalid-name,redefined-outer-name
import pytest
from OpenSSL import crypto

# Parameter list for valid_cert fixture
VALID_CERTIFICATE_PARAMS = [
{
'short_name': 'client',
'cn': 'client.example.com',
'serial': 4,
'uses': b'clientAuth',
'dns': [],
'ip': [],
},
{
'short_name': 'server',
'cn': 'server.example.com',
'serial': 5,
'uses': b'serverAuth',
'dns': ['kubernetes', 'openshift'],
'ip': ['10.0.0.1', '192.168.0.1']
},
{
'short_name': 'combined',
'cn': 'combined.example.com',
'serial': 6,
'uses': b'clientAuth, serverAuth',
'dns': ['etcd'],
'ip': ['10.0.0.2', '192.168.0.2']
}
]

# Extract the short_name from VALID_CERTIFICATE_PARAMS to provide
# friendly naming for the valid_cert fixture
VALID_CERTIFICATE_IDS = [param['short_name'] for param in VALID_CERTIFICATE_PARAMS]


@pytest.fixture(scope='session')
def ca(tmpdir_factory):
ca_dir = tmpdir_factory.mktemp('ca')

key = crypto.PKey()
key.generate_key(crypto.TYPE_RSA, 2048)

cert = crypto.X509()
cert.set_version(3)
cert.set_serial_number(1)
cert.get_subject().commonName = 'test-signer'
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(24 * 60 * 60)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(key)
cert.add_extensions([
crypto.X509Extension(b'basicConstraints', True, b'CA:TRUE, pathlen:0'),
crypto.X509Extension(b'keyUsage', True,
b'digitalSignature, keyEncipherment, keyCertSign, cRLSign'),
crypto.X509Extension(b'subjectKeyIdentifier', False, b'hash', subject=cert)
])
cert.add_extensions([
crypto.X509Extension(b'authorityKeyIdentifier', False, b'keyid:always', issuer=cert)
])
cert.sign(key, 'sha256')

return {
'dir': ca_dir,
'key': key,
'cert': cert,
}


@pytest.fixture(scope='session',
ids=VALID_CERTIFICATE_IDS,
params=VALID_CERTIFICATE_PARAMS)
def valid_cert(request, ca):
common_name = request.param['cn']

key = crypto.PKey()
key.generate_key(crypto.TYPE_RSA, 2048)

cert = crypto.X509()
cert.set_serial_number(request.param['serial'])
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(24 * 60 * 60)
cert.set_issuer(ca['cert'].get_subject())
cert.set_pubkey(key)
cert.set_version(3)
cert.get_subject().commonName = common_name
cert.add_extensions([
crypto.X509Extension(b'basicConstraints', True, b'CA:FALSE'),
crypto.X509Extension(b'keyUsage', True, b'digitalSignature, keyEncipherment'),
crypto.X509Extension(b'extendedKeyUsage', False, request.param['uses']),
])

if request.param['dns'] or request.param['ip']:
san_list = ['DNS:{}'.format(common_name)]
san_list.extend(['DNS:{}'.format(x) for x in request.param['dns']])
san_list.extend(['IP:{}'.format(x) for x in request.param['ip']])

cert.add_extensions([
crypto.X509Extension(b'subjectAltName', False, ', '.join(san_list).encode('utf8'))
])
cert.sign(ca['key'], 'sha256')

cert_contents = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
cert_file = ca['dir'].join('{}.crt'.format(common_name))
cert_file.write_binary(cert_contents)

return {
'common_name': common_name,
'serial': request.param['serial'],
'dns': request.param['dns'],
'ip': request.param['ip'],
'uses': request.param['uses'],
'cert_file': cert_file,
'cert': cert
}

0 comments on commit 293f185

Please sign in to comment.