Skip to content

Commit

Permalink
Use pytest assertions (#9585)
Browse files Browse the repository at this point in the history
* run unittest2pytest

The command used here was `unittest2pytest -nw acme/tests certbot*/tests`.

* fix with pytest.raises

* add parens to fix refactoring

* <= not <
  • Loading branch information
bmw committed Feb 16, 2023
1 parent fedb0b5 commit a3c9371
Show file tree
Hide file tree
Showing 79 changed files with 4,060 additions and 4,233 deletions.
159 changes: 76 additions & 83 deletions acme/tests/challenges_test.py

Large diffs are not rendered by default.

245 changes: 121 additions & 124 deletions acme/tests/client_test.py

Large diffs are not rendered by default.

123 changes: 62 additions & 61 deletions acme/tests/crypto_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,20 @@ def _start_server(self):

def test_probe_ok(self):
self._start_server()
self.assertEqual(self.cert, self._probe(b'foo'))
assert self.cert == self._probe(b'foo')

def test_probe_not_recognized_name(self):
self._start_server()
self.assertRaises(errors.Error, self._probe, b'bar')
with pytest.raises(errors.Error):
self._probe(b'bar')

def test_probe_connection_error(self):
self.server.server_close()
original_timeout = socket.getdefaulttimeout()
try:
socket.setdefaulttimeout(1)
self.assertRaises(errors.Error, self._probe, b'bar')
with pytest.raises(errors.Error):
self._probe(b'bar')
finally:
socket.setdefaulttimeout(original_timeout)

Expand All @@ -77,10 +79,10 @@ class SSLSocketTest(unittest.TestCase):

def test_ssl_socket_invalid_arguments(self):
from acme.crypto_util import SSLSocket
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
_ = SSLSocket(None, {'sni': ('key', 'cert')},
cert_selection=lambda _: None)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
_ = SSLSocket(None)


Expand All @@ -97,15 +99,15 @@ def _call_cert(self, name):
return self._call(test_util.load_cert, name)

def test_cert_one_san_no_common(self):
self.assertEqual(self._call_cert('cert-nocn.der'),
['no-common-name.badssl.com'])
assert self._call_cert('cert-nocn.der') == \
['no-common-name.badssl.com']

def test_cert_no_sans_yes_common(self):
self.assertEqual(self._call_cert('cert.pem'), ['example.com'])
assert self._call_cert('cert.pem') == ['example.com']

def test_cert_two_sans_yes_common(self):
self.assertEqual(self._call_cert('cert-san.pem'),
['example.com', 'www.example.com'])
assert self._call_cert('cert-san.pem') == \
['example.com', 'www.example.com']


class PyOpenSSLCertOrReqSANTest(unittest.TestCase):
Expand Down Expand Up @@ -133,47 +135,47 @@ def _call_csr(self, name):
return self._call(test_util.load_csr, name)

def test_cert_no_sans(self):
self.assertEqual(self._call_cert('cert.pem'), [])
assert self._call_cert('cert.pem') == []

def test_cert_two_sans(self):
self.assertEqual(self._call_cert('cert-san.pem'),
['example.com', 'www.example.com'])
assert self._call_cert('cert-san.pem') == \
['example.com', 'www.example.com']

def test_cert_hundred_sans(self):
self.assertEqual(self._call_cert('cert-100sans.pem'),
['example{0}.com'.format(i) for i in range(1, 101)])
assert self._call_cert('cert-100sans.pem') == \
['example{0}.com'.format(i) for i in range(1, 101)]

def test_cert_idn_sans(self):
self.assertEqual(self._call_cert('cert-idnsans.pem'),
self._get_idn_names())
assert self._call_cert('cert-idnsans.pem') == \
self._get_idn_names()

def test_csr_no_sans(self):
self.assertEqual(self._call_csr('csr-nosans.pem'), [])
assert self._call_csr('csr-nosans.pem') == []

def test_csr_one_san(self):
self.assertEqual(self._call_csr('csr.pem'), ['example.com'])
assert self._call_csr('csr.pem') == ['example.com']

def test_csr_two_sans(self):
self.assertEqual(self._call_csr('csr-san.pem'),
['example.com', 'www.example.com'])
assert self._call_csr('csr-san.pem') == \
['example.com', 'www.example.com']

def test_csr_six_sans(self):
self.assertEqual(self._call_csr('csr-6sans.pem'),
assert self._call_csr('csr-6sans.pem') == \
['example.com', 'example.org', 'example.net',
'example.info', 'subdomain.example.com',
'other.subdomain.example.com'])
'other.subdomain.example.com']

def test_csr_hundred_sans(self):
self.assertEqual(self._call_csr('csr-100sans.pem'),
['example{0}.com'.format(i) for i in range(1, 101)])
assert self._call_csr('csr-100sans.pem') == \
['example{0}.com'.format(i) for i in range(1, 101)]

def test_csr_idn_sans(self):
self.assertEqual(self._call_csr('csr-idnsans.pem'),
self._get_idn_names())
assert self._call_csr('csr-idnsans.pem') == \
self._get_idn_names()

def test_critical_san(self):
self.assertEqual(self._call_cert('critical-san.pem'),
['chicago-cubs.venafi.example', 'cubs.venafi.example'])
assert self._call_cert('critical-san.pem') == \
['chicago-cubs.venafi.example', 'cubs.venafi.example']


class PyOpenSSLCertOrReqSANIPTest(unittest.TestCase):
Expand All @@ -192,30 +194,30 @@ def _call_csr(self, name):
return self._call(test_util.load_csr, name)

def test_cert_no_sans(self):
self.assertEqual(self._call_cert('cert.pem'), [])
assert self._call_cert('cert.pem') == []

def test_csr_no_sans(self):
self.assertEqual(self._call_csr('csr-nosans.pem'), [])
assert self._call_csr('csr-nosans.pem') == []

def test_cert_domain_sans(self):
self.assertEqual(self._call_cert('cert-san.pem'), [])
assert self._call_cert('cert-san.pem') == []

def test_csr_domain_sans(self):
self.assertEqual(self._call_csr('csr-san.pem'), [])
assert self._call_csr('csr-san.pem') == []

def test_cert_ip_two_sans(self):
self.assertEqual(self._call_cert('cert-ipsans.pem'), ['192.0.2.145', '203.0.113.1'])
assert self._call_cert('cert-ipsans.pem') == ['192.0.2.145', '203.0.113.1']

def test_csr_ip_two_sans(self):
self.assertEqual(self._call_csr('csr-ipsans.pem'), ['192.0.2.145', '203.0.113.1'])
assert self._call_csr('csr-ipsans.pem') == ['192.0.2.145', '203.0.113.1']

def test_csr_ipv6_sans(self):
self.assertEqual(self._call_csr('csr-ipv6sans.pem'),
['0:0:0:0:0:0:0:1', 'A3BE:32F3:206E:C75D:956:CEE:9858:5EC5'])
assert self._call_csr('csr-ipv6sans.pem') == \
['0:0:0:0:0:0:0:1', 'A3BE:32F3:206E:C75D:956:CEE:9858:5EC5']

def test_cert_ipv6_sans(self):
self.assertEqual(self._call_cert('cert-ipv6sans.pem'),
['0:0:0:0:0:0:0:1', 'A3BE:32F3:206E:C75D:956:CEE:9858:5EC5'])
assert self._call_cert('cert-ipv6sans.pem') == \
['0:0:0:0:0:0:0:1', 'A3BE:32F3:206E:C75D:956:CEE:9858:5EC5']


class GenSsCertTest(unittest.TestCase):
Expand All @@ -234,12 +236,12 @@ def test_sn_collisions(self):
cert = gen_ss_cert(self.key, ['dummy'], force_san=True,
ips=[ipaddress.ip_address("10.10.10.10")])
self.serial_num.append(cert.get_serial_number())
self.assertGreaterEqual(len(set(self.serial_num)), self.cert_count)
assert len(set(self.serial_num)) >= self.cert_count


def test_no_name(self):
from acme.crypto_util import gen_ss_cert
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
gen_ss_cert(self.key, ips=[ipaddress.ip_address("1.1.1.1")])
gen_ss_cert(self.key)

Expand All @@ -257,41 +259,39 @@ def _call_with_key(cls, *args, **kwargs):

def test_make_csr(self):
csr_pem = self._call_with_key(["a.example", "b.example"])
self.assertIn(b'--BEGIN CERTIFICATE REQUEST--', csr_pem)
self.assertIn(b'--END CERTIFICATE REQUEST--', csr_pem)
assert b'--BEGIN CERTIFICATE REQUEST--' in csr_pem
assert b'--END CERTIFICATE REQUEST--' in csr_pem
csr = OpenSSL.crypto.load_certificate_request(
OpenSSL.crypto.FILETYPE_PEM, csr_pem)
# In pyopenssl 0.13 (used with TOXENV=py27-oldest), csr objects don't
# have a get_extensions() method, so we skip this test if the method
# isn't available.
if hasattr(csr, 'get_extensions'):
self.assertEqual(len(csr.get_extensions()), 1)
self.assertEqual(csr.get_extensions()[0].get_data(),
assert len(csr.get_extensions()) == 1
assert csr.get_extensions()[0].get_data() == \
OpenSSL.crypto.X509Extension(
b'subjectAltName',
critical=False,
value=b'DNS:a.example, DNS:b.example',
).get_data(),
)
).get_data()

def test_make_csr_ip(self):
csr_pem = self._call_with_key(["a.example"], False, [ipaddress.ip_address('127.0.0.1'), ipaddress.ip_address('::1')])
self.assertIn(b'--BEGIN CERTIFICATE REQUEST--' , csr_pem)
self.assertIn(b'--END CERTIFICATE REQUEST--' , csr_pem)
assert b'--BEGIN CERTIFICATE REQUEST--' in csr_pem
assert b'--END CERTIFICATE REQUEST--' in csr_pem
csr = OpenSSL.crypto.load_certificate_request(
OpenSSL.crypto.FILETYPE_PEM, csr_pem)
# In pyopenssl 0.13 (used with TOXENV=py27-oldest), csr objects don't
# have a get_extensions() method, so we skip this test if the method
# isn't available.
if hasattr(csr, 'get_extensions'):
self.assertEqual(len(csr.get_extensions()), 1)
self.assertEqual(csr.get_extensions()[0].get_data(),
assert len(csr.get_extensions()) == 1
assert csr.get_extensions()[0].get_data() == \
OpenSSL.crypto.X509Extension(
b'subjectAltName',
critical=False,
value=b'DNS:a.example, IP:127.0.0.1, IP:::1',
).get_data(),
)
).get_data()
# for IP san it's actually need to be octet-string,
# but somewhere downstream thankfully handle it for us

Expand All @@ -304,25 +304,26 @@ def test_make_csr_must_staple(self):
# have a get_extensions() method, so we skip this test if the method
# isn't available.
if hasattr(csr, 'get_extensions'):
self.assertEqual(len(csr.get_extensions()), 2)
assert len(csr.get_extensions()) == 2
# NOTE: Ideally we would filter by the TLS Feature OID, but
# OpenSSL.crypto.X509Extension doesn't give us the extension's raw OID,
# and the shortname field is just "UNDEF"
must_staple_exts = [e for e in csr.get_extensions()
if e.get_data() == b"0\x03\x02\x01\x05"]
self.assertEqual(len(must_staple_exts), 1,
"Expected exactly one Must Staple extension")
assert len(must_staple_exts) == 1, \
"Expected exactly one Must Staple extension"

def test_make_csr_without_hostname(self):
self.assertRaises(ValueError, self._call_with_key)
with pytest.raises(ValueError):
self._call_with_key()

def test_make_csr_correct_version(self):
csr_pem = self._call_with_key(["a.example"])
csr = OpenSSL.crypto.load_certificate_request(
OpenSSL.crypto.FILETYPE_PEM, csr_pem)

self.assertEqual(csr.get_version(), 0,
"Expected CSR version to be v1 (encoded as 0), per RFC 2986, section 4")
assert csr.get_version() == 0, \
"Expected CSR version to be v1 (encoded as 0), per RFC 2986, section 4"


class DumpPyopensslChainTest(unittest.TestCase):
Expand All @@ -340,7 +341,7 @@ def test_dump_pyopenssl_chain(self):
length = sum(
len(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert))
for cert in loaded)
self.assertEqual(len(self._call(loaded)), length)
assert len(self._call(loaded)) == length

def test_dump_pyopenssl_chain_wrapped(self):
names = ['cert.pem', 'cert-san.pem', 'cert-idnsans.pem']
Expand All @@ -349,7 +350,7 @@ def test_dump_pyopenssl_chain_wrapped(self):
wrapped = [wrap_func(cert) for cert in loaded]
dump_func = OpenSSL.crypto.dump_certificate
length = sum(len(dump_func(OpenSSL.crypto.FILETYPE_PEM, cert)) for cert in loaded)
self.assertEqual(len(self._call(wrapped)), length)
assert len(self._call(wrapped)) == length


if __name__ == '__main__':
Expand Down
14 changes: 7 additions & 7 deletions acme/tests/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def setUp(self):
self.error = BadNonce(nonce="xxx", error="error")

def test_str(self):
self.assertEqual("Invalid nonce ('xxx'): error", str(self.error))
assert "Invalid nonce ('xxx'): error" == str(self.error)


class MissingNonceTest(unittest.TestCase):
Expand All @@ -27,8 +27,8 @@ def setUp(self):
self.error = MissingNonce(self.response)

def test_str(self):
self.assertIn("FOO", str(self.error))
self.assertIn("{}", str(self.error))
assert "FOO" in str(self.error)
assert "{}" in str(self.error)


class PollErrorTest(unittest.TestCase):
Expand All @@ -43,12 +43,12 @@ def setUp(self):
mock.sentinel.AR: mock.sentinel.AR2})

def test_timeout(self):
self.assertTrue(self.timeout.timeout)
self.assertFalse(self.invalid.timeout)
assert self.timeout.timeout
assert not self.invalid.timeout

def test_repr(self):
self.assertEqual('PollError(exhausted=%s, updated={sentinel.AR: '
'sentinel.AR2})' % repr(set()), repr(self.invalid))
assert 'PollError(exhausted=%s, updated={sentinel.AR: ' \
'sentinel.AR2})' % repr(set()) == repr(self.invalid)


if __name__ == "__main__":
Expand Down
23 changes: 11 additions & 12 deletions acme/tests/fields_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ def setUp(self):
self.field = fixed('name', 'x')

def test_decode(self):
self.assertEqual('x', self.field.decode('x'))
assert 'x' == self.field.decode('x')

def test_decode_bad(self):
self.assertRaises(jose.DeserializationError, self.field.decode, 'y')
with pytest.raises(jose.DeserializationError):
self.field.decode('y')

def test_encode(self):
self.assertEqual('x', self.field.encode('x'))
assert 'x' == self.field.encode('x')

def test_encode_override(self):
self.assertEqual('y', self.field.encode('y'))
assert 'y' == self.field.encode('y')


class RFC3339FieldTest(unittest.TestCase):
Expand All @@ -38,23 +39,21 @@ def setUp(self):

def test_default_encoder(self):
from acme.fields import RFC3339Field
self.assertEqual(
self.encoded, RFC3339Field.default_encoder(self.decoded))
assert self.encoded == RFC3339Field.default_encoder(self.decoded)

def test_default_encoder_naive_fails(self):
from acme.fields import RFC3339Field
self.assertRaises(
ValueError, RFC3339Field.default_encoder, datetime.datetime.now())
with pytest.raises(ValueError):
RFC3339Field.default_encoder(datetime.datetime.now())

def test_default_decoder(self):
from acme.fields import RFC3339Field
self.assertEqual(
self.decoded, RFC3339Field.default_decoder(self.encoded))
assert self.decoded == RFC3339Field.default_decoder(self.encoded)

def test_default_decoder_raises_deserialization_error(self):
from acme.fields import RFC3339Field
self.assertRaises(
jose.DeserializationError, RFC3339Field.default_decoder, '')
with pytest.raises(jose.DeserializationError):
RFC3339Field.default_decoder('')


if __name__ == '__main__':
Expand Down
8 changes: 4 additions & 4 deletions acme/tests/jose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def _test_it(self, submodule, attribute):
acme_jose_mod = importlib.import_module(acme_jose_path)
josepy_mod = importlib.import_module(josepy_path)

self.assertIs(acme_jose_mod, josepy_mod)
self.assertIs(getattr(acme_jose_mod, attribute), getattr(josepy_mod, attribute))
assert acme_jose_mod is josepy_mod
assert getattr(acme_jose_mod, attribute) is getattr(josepy_mod, attribute)

# We use the imports below with eval, but pylint doesn't
# understand that.
Expand All @@ -29,8 +29,8 @@ def _test_it(self, submodule, attribute):
import acme # pylint: disable=unused-import
acme_jose_mod = eval(acme_jose_path) # pylint: disable=eval-used
josepy_mod = eval(josepy_path) # pylint: disable=eval-used
self.assertIs(acme_jose_mod, josepy_mod)
self.assertIs(getattr(acme_jose_mod, attribute), getattr(josepy_mod, attribute))
assert acme_jose_mod is josepy_mod
assert getattr(acme_jose_mod, attribute) is getattr(josepy_mod, attribute)

def test_top_level(self):
self._test_it('', 'RS512')
Expand Down

0 comments on commit a3c9371

Please sign in to comment.