Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Add "from_pem_data" and "to_pem_data" methods to JWT Key objects.

  • Loading branch information...
commit 963391f1060e44cbc1e677d091c9d63df0b28960 1 parent 4279246
@rfk rfk authored
View
19 browserid/crypto/fallback.py
@@ -20,10 +20,25 @@
class Key(object):
"""Generic base class for Key objects."""
+ @classmethod
+ def from_pem_data(cls, data=None, filename=None):
+ """Alternative constructor for loading from PEM format data."""
+ msg = "PEM data loading is not implemented for pure-python crypto."
+ msg += " Please install M2Crypto to access this functionality."
+ raise NotImplementedError(msg)
+
+ def to_pem_data(self):
+ """Save the public key data to a PEM format string."""
+ msg = "PEM data saving is not implemented for pure-python crypto."
+ msg += " Please install M2Crypto to access this functionality."
+ raise NotImplementedError(msg)
+
def verify(self, signed_data, signature):
+ """Verify the given signature."""
raise NotImplementedError
def sign(self, data):
+ """Sign the given data."""
raise NotImplementedError
@@ -35,7 +50,7 @@ def sign(self, data):
}
-class RSKey(object):
+class RSKey(Key):
"""Generic base class for RSA key objects.
Concrete subclasses should provide the SIZE, HASHNAME and HASHMOD
@@ -78,7 +93,7 @@ def _get_digest(self, data):
return padded_digest
-class DSKey(object):
+class DSKey(Key):
"""Generic base class for DSA key objects.
Concrete subclasses should provide the BITLENGTH and HASHMOD attributes.
View
118 browserid/crypto/m2.py
@@ -14,8 +14,11 @@
"""
import struct
-from binascii import unhexlify
+from binascii import hexlify, unhexlify
+from M2Crypto import BIO
+
+from browserid.crypto._m2_monkeypatch import m2
from browserid.crypto._m2_monkeypatch import DSA as _DSA
from browserid.crypto._m2_monkeypatch import RSA as _RSA
@@ -23,10 +26,37 @@
class Key(object):
"""Generic base class for Key objects."""
+ KEY_MODULE = None
+
+ @classmethod
+ def from_pem_data(cls, data=None, filename=None):
+ """Alternative constructor for loading from PEM format data."""
+ self = cls.__new__(cls)
+ if data is not None:
+ bio = BIO.MemoryBuffer(str(data))
+ elif filename is not None:
+ bio = BIO.openfile(filename)
+ else:
+ msg = "Please specify either 'data' or 'filename' argument."
+ raise ValueError(msg)
+ self.keyobj = self.KEY_MODULE.load_pub_key_bio(bio)
+ return self
+
+ def to_pem_data(self):
+ """Save the public key data to a PEM format string."""
+ b = BIO.MemoryBuffer()
+ try:
+ self.keyobj.save_pub_key_bio(b)
+ return b.getvalue()
+ finally:
+ b.close()
+
def verify(self, signed_data, signature):
+ """Verify the given signature."""
raise NotImplementedError # pragma: nocover
def sign(self, data):
+ """Sign the given data."""
raise NotImplementedError # pragma: nocover
@@ -34,37 +64,34 @@ def sign(self, data):
# RSA keys, implemented using the RSA support in M2Crypto.
#
-class RSKey(object):
+class RSKey(Key):
+ KEY_MODULE = _RSA
SIZE = None
+ HASHNAME = None
HASHMOD = None
- def __init__(self, data=None, obj=None):
- if data is None and obj is None:
- raise ValueError('You should specify either data or obj')
- if obj is not None:
- self.rsa = obj
+ def __init__(self, data):
+ _check_keys(data, ('e', 'n'))
+ e = int2mpint(int(data["e"]))
+ n = int2mpint(int(data["n"]))
+ try:
+ d = int2mpint(int(data["d"]))
+ except KeyError:
+ self.keyobj = _RSA.new_pub_key((e, n))
else:
- _check_keys(data, ('e', 'n'))
- e = int2mpint(int(data["e"]))
- n = int2mpint(int(data["n"]))
- try:
- d = int2mpint(int(data["d"]))
- except KeyError:
- self.rsa = _RSA.new_pub_key((e, n))
- else:
- self.rsa = _RSA.new_key((e, n, d))
+ self.keyobj = _RSA.new_key((e, n, d))
def verify(self, signed_data, signature):
digest = self.HASHMOD(signed_data).digest()
try:
- return self.rsa.verify(digest, signature, self.HASHNAME)
+ return self.keyobj.verify(digest, signature, self.HASHNAME)
except _RSA.RSAError:
return False
def sign(self, data):
digest = self.HASHMOD(data).digest()
- return self.rsa.sign(digest, self.HASHNAME)
+ return self.keyobj.sign(digest, self.HASHNAME)
#
@@ -72,33 +99,38 @@ def sign(self, data):
# some formatting tweaks to match what the browserid node-js server does.
#
-class DSKey(object):
+class DSKey(Key):
+ KEY_MODULE = _DSA
BITLENGTH = None
HASHMOD = None
- def __init__(self, data=None, obj=None):
- if data is None and obj is None:
- raise ValueError('You should specify either data or obj')
- if obj:
- self.dsa = obj
+ def __init__(self, data):
+ _check_keys(data, ('p', 'q', 'g', 'y'))
+ self.p = p = long(data["p"], 16)
+ self.q = q = long(data["q"], 16)
+ self.g = g = long(data["g"], 16)
+ self.y = y = long(data["y"], 16)
+ if "x" not in data:
+ self.x = None
+ self.keyobj = _DSA.load_pub_key_params(int2mpint(p), int2mpint(q),
+ int2mpint(g), int2mpint(y))
else:
- _check_keys(data, ('p', 'q', 'g', 'y'))
-
- self.p = p = long(data["p"], 16)
- self.q = q = long(data["q"], 16)
- self.g = g = long(data["g"], 16)
- self.y = y = long(data["y"], 16)
- if "x" not in data:
- self.x = None
- self.dsa = _DSA.load_pub_key_params(int2mpint(p), int2mpint(q),
- int2mpint(g), int2mpint(y))
- else:
- self.x = x = long(data["x"], 16)
- self.dsa = _DSA.load_key_params(int2mpint(p), int2mpint(q),
- int2mpint(g), int2mpint(y),
+ self.x = x = long(data["x"], 16)
+ self.keyobj = _DSA.load_key_params(int2mpint(p), int2mpint(q),
+ int2mpint(g), int2mpint(y),
int2mpint(x))
+ @classmethod
+ def from_pem_data(cls, data=None, filename=None):
+ self = super(DSKey, cls).from_pem_data(data, filename)
+ self.p = mpint2int(m2.dsa_get_p(self.keyobj.dsa))
+ self.q = mpint2int(m2.dsa_get_q(self.keyobj.dsa))
+ self.g = mpint2int(m2.dsa_get_g(self.keyobj.dsa))
+ self.y = None
+ self.x = None
+ return self
+
def verify(self, signed_data, signature):
# Restore any leading zero bytes that might have been stripped.
signature = signature.encode("hex")
@@ -115,13 +147,13 @@ def verify(self, signed_data, signature):
return False
# Now we can check the digest.
digest = self.HASHMOD(signed_data).digest()
- return self.dsa.verify(digest, int2mpint(r), int2mpint(s))
+ return self.keyobj.verify(digest, int2mpint(r), int2mpint(s))
def sign(self, data):
if not self.x:
raise ValueError("private key not present")
digest = self.HASHMOD(data).digest()
- r, s = self.dsa.sign(digest)
+ r, s = self.keyobj.sign(digest)
# We need precisely "bytelength" bytes from each integer.
# M2Crypto might give us more or less, so snip and pad appropriately.
bytelength = self.BITLENGTH / 8
@@ -151,6 +183,12 @@ def int2mpint(x):
return struct.pack(">I", len(bytes) + 1) + "\x00" + bytes
+def mpint2int(data):
+ """Convert a string in OpenSSL's MPINT format to a Python long integer."""
+ hexbytes = hexlify(data[4:])
+ return int(hexbytes, 16)
+
+
def _check_keys(data, keys):
"""Verify that the given data dict contains the specified keys."""
for key in keys:
View
132 browserid/tests/test_jwt.py
@@ -2,11 +2,57 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.
+import tempfile
+
from browserid.tests.support import get_keypair, unittest
from browserid.utils import encode_json_bytes, encode_bytes
from browserid import jwt
+def _long(value):
+ return long(value.replace(" ", "").replace("\n", "").strip())
+
+
+def _hex(value):
+ return hex(long(value.replace(" ", "").replace("\n", "").strip()))
+
+
+# This is a dummy RSA key I generated via PyCrypto.
+# M2Crypto doesn't seem to let me get at the values of e, n and d.
+RSA_KEY_DATA = {
+ "e": 65537L,
+ "n": _long("""110897663942528265066856163966583557538666146275146
+ 569193074111045116764854772535689458732714049671807506
+ 396649306730328647317126800964431366624486416551078177
+ 528195103050868728550429561392842977259407335332582178
+ 624191611001106449477645116630750398871838788574825885
+ 770446686329706009000279629721965986677219L"""),
+ "d": _long("""295278123166626215026113502482091502365034141401240
+ 159363282304307076544046230487782634982660202141239450
+ 481640966544735782181647417005558287318200095948234745
+ 214183393770321992676297531378428617531522265932631860
+ 693144704788708252936752025413728425562033678747736289
+ 64114133156747686886305629893015763517873L"""),
+}
+
+
+# This is a dummy DSA key I generated via PyCrypto.
+# M2Crypto doesn't seem to let me get at the values of x and y.
+DSA_KEY_DATA = {
+ "p": _hex("""6703904104057623261995085583676902361410672713749348
+ 7374515589871295072792250899011720632358392764362903244
+ 12395020783955234715731001076129344181463063193L"""),
+ "q": hex(1006478751418673383937866166434285354892250535133L),
+ "g": _hex("""1801778249650423365253284139284406405780267098493217
+ 0320675876307450879812560049234773036938891018778074993
+ 01874343843218156663689824126183823813389886834L"""),
+ "y": _hex("""4148629652526876030475847300836791685289385792662680
+ 5886292874741635965095055693693232436255359496594291250
+ 77637642734034732001089176915352691113947372211L"""),
+ "x": hex(487025797851506801093339352420308364866214860934L),
+}
+
+
class TestJWT(unittest.TestCase):
def test_error_jwt_with_no_algorithm(self):
@@ -30,26 +76,7 @@ def test_loading_unknown_algorithms(self):
self.assertRaises(ValueError, jwt.load_key, "DS64", {})
def test_rsa_verification(self):
- # This is a dummy RSA key I generated via PyCrypto.
- # M2Crypto doesn't seem to let me get at the values of e, n and d.
- # I've line wrapped it for readability.
- def _long(value):
- return long(value.replace(" ", "").replace("\n", "").strip())
- data = {
- "e": 65537L,
- "n": _long("""110897663942528265066856163966583557538666146275146
- 569193074111045116764854772535689458732714049671807506
- 396649306730328647317126800964431366624486416551078177
- 528195103050868728550429561392842977259407335332582178
- 624191611001106449477645116630750398871838788574825885
- 770446686329706009000279629721965986677219L"""),
- "d": _long("""295278123166626215026113502482091502365034141401240
- 159363282304307076544046230487782634982660202141239450
- 481640966544735782181647417005558287318200095948234745
- 214183393770321992676297531378428617531522265932631860
- 693144704788708252936752025413728425562033678747736289
- 64114133156747686886305629893015763517873L"""),
- }
+ data = RSA_KEY_DATA.copy()
key = jwt.RS64Key(data)
data.pop("d")
pubkey = jwt.RS64Key(data)
@@ -58,24 +85,7 @@ def _long(value):
self.assertFalse(pubkey.verify("HELLO", key.sign("hello")))
def test_dsa_verification(self):
- # This is a dummy DSA key I generated via PyCrypto.
- # M2Crypto doesn't seem to let me get at the values of x and y.
- # I've line wrapped it for readability.
- def _hex(value):
- return hex(long(value.replace(" ", "").replace("\n", "").strip()))
- data = {
- "p": _hex("""6703904104057623261995085583676902361410672713749348
- 7374515589871295072792250899011720632358392764362903244
- 12395020783955234715731001076129344181463063193L"""),
- "q": hex(1006478751418673383937866166434285354892250535133L),
- "g": _hex("""1801778249650423365253284139284406405780267098493217
- 0320675876307450879812560049234773036938891018778074993
- 01874343843218156663689824126183823813389886834L"""),
- "y": _hex("""4148629652526876030475847300836791685289385792662680
- 5886292874741635965095055693693232436255359496594291250
- 77637642734034732001089176915352691113947372211L"""),
- "x": hex(487025797851506801093339352420308364866214860934L),
- }
+ data = DSA_KEY_DATA.copy()
key = jwt.DS128Key(data)
data.pop("x")
pubkey = jwt.DS128Key(data)
@@ -90,3 +100,49 @@ def _hex(value):
self.assertFalse(pubkey.verify("HELLO", ("\xFF" * 20) + "\x01" * 20))
# - "s" value too large
self.assertFalse(pubkey.verify("HELLO", "\x01" + ("\xFF" * 20)))
+
+ def test_loading_rsa_from_pem_data(self):
+ data = RSA_KEY_DATA.copy()
+ key = jwt.RS64Key(data)
+ try:
+ data = key.to_pem_data()
+ except NotImplementedError:
+ raise unittest.SkipTest
+ pubkey = jwt.RS64Key.from_pem_data(data)
+ self.assertTrue(pubkey.verify("hello", key.sign("hello")))
+
+ def test_loading_rsa_from_pem_data_filename(self):
+ data = RSA_KEY_DATA.copy()
+ key = jwt.RS64Key(data)
+ try:
+ data = key.to_pem_data()
+ except NotImplementedError:
+ raise unittest.SkipTest
+ with tempfile.NamedTemporaryFile() as f:
+ f.write(data)
+ f.flush()
+ pubkey = jwt.RS64Key.from_pem_data(filename=f.name)
+ self.assertTrue(pubkey.verify("hello", key.sign("hello")))
+
+ def test_loading_dsa_from_pem_data(self):
+ data = DSA_KEY_DATA.copy()
+ key = jwt.DS128Key(data)
+ try:
+ data = key.to_pem_data()
+ except NotImplementedError:
+ raise unittest.SkipTest
+ pubkey = jwt.DS128Key.from_pem_data(data)
+ self.assertTrue(pubkey.verify("hello", key.sign("hello")))
+
+ def test_loading_dsa_from_pem_data_filename(self):
+ data = DSA_KEY_DATA.copy()
+ key = jwt.DS128Key(data)
+ try:
+ data = key.to_pem_data()
+ except NotImplementedError:
+ raise unittest.SkipTest
+ with tempfile.NamedTemporaryFile() as f:
+ f.write(data)
+ f.flush()
+ pubkey = jwt.DS128Key.from_pem_data(filename=f.name)
+ self.assertTrue(pubkey.verify("hello", key.sign("hello")))
View
2  setup.py
@@ -10,7 +10,7 @@
with open(os.path.join(here, 'CHANGES.txt')) as f:
CHANGES = f.read()
-requires = ['requests']
+requires = ['requests', 'mock']
tests_require = requires + ['mock']
Please sign in to comment.
Something went wrong with that request. Please try again.