Skip to content

Commit

Permalink
Merge pull request #15 from ojarva/internal-refactoring
Browse files Browse the repository at this point in the history
Reformatting, commenting and internal refactoring
  • Loading branch information
Olli Jarva committed Jul 31, 2016
2 parents 9c7544a + 14cdbc9 commit 0ad5d7e
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 72 deletions.
8 changes: 6 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,16 @@ Usage:
print(ssh.hash_sha256()) # SHA256:xk3IEJIdIoR9MmSRXTP98rjDdZocmXJje/28ohMQEwM
print(ssh.hash_sha512()) # SHA512:1C3lNBhjpDVQe39hnyy+xvlZYU3IPwzqK1rVneGavy6O3/ebjEQSFvmeWoyMTplIanmUK1hmr9nA8Skmj516HA
print(ssh.comment) # ojar@ojar-laptop
print(ssh.options) # None (optional options at the beginning of public key. You may want to check for these if you're validating user-submitted keys.)
print(ssh.options_raw) # None (string of optional options at the beginning of public key)
print(ssh.options) # None (options as a dictionary, parsed and validated)

Options
-------

- strict_mode: if set to True, disallows keys OpenSSH's ssh-keygen refuses to create. For instance, this includes DSA keys where length != 1024 bits and RSA keys shorter than 1024-bit. If set to False, tries to allow all keys OpenSSH accepts, including highly insecure 1-bit DSA keys.
Set options in constructor as a keywords (i.e., `SSHKey(None, strict_mode=False)`)

- strict_mode: defaults to True. Disallows keys OpenSSH's ssh-keygen refuses to create. For instance, this includes DSA keys where length != 1024 bits and RSA keys shorter than 1024-bit. If set to False, tries to allow all keys OpenSSH accepts, including highly insecure 1-bit DSA keys.
- skip_option_parsing: if set to True, options string is not parsed (ssh.options_raw is populated, but ssh.options is not).

Exceptions
----------
Expand Down
127 changes: 73 additions & 54 deletions sshpubkeys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from Crypto.PublicKey import RSA, DSA

from .exceptions import *
from .exceptions import * # pylint:disable=wildcard-import

__all__ = ["SSHKey"]

Expand Down Expand Up @@ -89,10 +89,11 @@ class SSHKey(object): # pylint:disable=too-many-instance-attributes

INT_LEN = 4

FIELDS = ["rsa", "dsa", "ecdsa", "bits", "comment", "options", "options_raw", "key_type"]

def __init__(self, keydata=None, **kwargs):
self.keydata = keydata
self.current_position = 0
self.decoded_key = None
self._decoded_key = None
self.rsa = None
self.dsa = None
self.ecdsa = None
Expand All @@ -109,12 +110,17 @@ def __init__(self, keydata=None, **kwargs):
except (InvalidKeyException, NotImplementedError):
pass

def reset(self):
""" Reset all data fields """
for field in self.FIELDS:
setattr(self, field, None)

def hash(self):
""" Calculate md5 fingerprint.
Deprecated, use .hash_md5() instead.
"""
warnings.warn("hash() is deprecated. Use hash_md5() or hash_sha256() instead.")
warnings.warn("hash() is deprecated. Use hash_md5(), hash_sha256() or hash_sha512() instead.")
return self.hash_md5().replace(b"MD5:", b"")

def hash_md5(self):
Expand All @@ -124,38 +130,38 @@ def hash_md5(self):
For specification, see RFC4716, section 4.
"""
fp_plain = hashlib.md5(self.decoded_key).hexdigest()
fp_plain = hashlib.md5(self._decoded_key).hexdigest()
return "MD5:" + ':'.join(a + b for a, b in zip(fp_plain[::2], fp_plain[1::2]))

def hash_sha256(self):
""" Calculate sha256 fingerprint. """
fp_plain = hashlib.sha256(self.decoded_key).digest()
fp_plain = hashlib.sha256(self._decoded_key).digest()
return (b"SHA256:" + base64.b64encode(fp_plain).replace(b"=", b"")).decode("utf-8")

def hash_sha512(self):
""" Calculates sha512 fingerprint. """
fp_plain = hashlib.sha512(self.decoded_key).digest()
fp_plain = hashlib.sha512(self._decoded_key).digest()
return (b"SHA512:" + base64.b64encode(fp_plain).replace(b"=", b"")).decode("utf-8")

def _unpack_by_int(self):
""" Returns next data field. """
def _unpack_by_int(self, data, current_position):
""" Returns a tuple with (location of next data field, contents of requested data field). """
# Unpack length of data field
try:
requested_data_length = struct.unpack('>I', self.decoded_key[self.current_position:self.current_position + self.INT_LEN])[0]
requested_data_length = struct.unpack('>I', data[current_position:current_position + self.INT_LEN])[0]
except struct.error:
raise MalformedDataException("Unable to unpack %s bytes from the data" % self.INT_LEN)

# Move pointer to the beginning of the data field
self.current_position += self.INT_LEN
remaining_data_length = len(self.decoded_key[self.current_position:])
current_position += self.INT_LEN
remaining_data_length = len(data[current_position:])

if remaining_data_length < requested_data_length:
raise MalformedDataException("Requested %s bytes, but only %s bytes available." % (requested_data_length, remaining_data_length))

next_data = self.decoded_key[self.current_position:self.current_position + requested_data_length]
next_data = data[current_position:current_position + requested_data_length]
# Move pointer to the end of the data field
self.current_position += requested_data_length
return next_data
current_position += requested_data_length
return current_position, next_data

@classmethod
def _parse_long(cls, data):
Expand All @@ -175,12 +181,12 @@ def _split_key(self, data):
# Terribly inefficient way to remove options, but hey, it works.
if not data.startswith("ssh-") and not data.startswith("ecdsa-"):
quote_open = False
for i in range(len(data)):
if data[i] == '"': # only double quotes are allowed, no need to care about single quotes
for i, character in enumerate(data):
if character == '"': # only double quotes are allowed, no need to care about single quotes
quote_open = not quote_open
if quote_open:
continue
if data[i] == " ":
if character == " ":
# Data begins after the first space
options_raw = data[:i]
data = data[i + 1:]
Expand All @@ -193,19 +199,20 @@ def _split_key(self, data):
if len(key_parts) == 3:
self.comment = key_parts[2]
key_parts = key_parts[0:2]
if options_raw and not self.skip_option_parsing:
if options_raw:
# Populate and parse options field.
self.options_raw = options_raw
self.options = self.parse_options(self.options_raw)
if not self.skip_option_parsing:
self.options = self.parse_options(self.options_raw)
else:
# Set empty defaults for fields
self.options_raw = None
self.options = {}
return key_parts

@classmethod
def _decode_key(cls, pubkey_content):
# Decode base64 coded part.
def decode_key(cls, pubkey_content):
""" Decode base64 coded part of the key. """
try:
decoded_key = base64.b64decode(pubkey_content.encode("ascii"))
except (TypeError, binascii.Error):
Expand All @@ -217,39 +224,40 @@ def _bits_in_number(cls, number):
return len(format(number, "b"))

def parse_options(self, options):

""" Parses ssh options string """
quote_open = False
parsed_options = {}

def parse_add_single_option(opt):
""" Parses and validates a single option, and adds it to parsed_options field. """
if "=" in opt:
opt_name, opt_value = opt.split("=", 1)
opt_value = opt_value.replace('"', '')
else:
opt_name = opt
opt_value = True
if " " in opt_name or not self.OPTION_NAME_RE.match(opt_name):
raise InvalidOptionNameException
raise InvalidOptionNameException("%s is not valid option name." % opt_name)
if self.strict_mode:
for valid_opt_name, value_required in self.OPTIONS_SPEC:
if opt_name.lower() == valid_opt_name:
if value_required and opt_value is True:
raise MissingMandatoryOptionValueException
raise MissingMandatoryOptionValueException("%s is missing mandatory value." % opt_name)
break
else:
raise UnknownOptionNameException
raise UnknownOptionNameException("%s is unrecognized option name." % opt_name)
if opt_name not in parsed_options:
parsed_options[opt_name] = []
parsed_options[opt_name].append(opt_value)

start_of_current_opt = 0
i = 1 # Need to be set for empty options strings
for i in range(len(options)):
if options[i] == '"': # only double quotes are allowed, no need to care about single quotes
for i, character in enumerate(options):
if character == '"': # only double quotes are allowed, no need to care about single quotes
quote_open = not quote_open
if quote_open:
continue
if options[i] == ",":
if character == ",":
opt = options[start_of_current_opt:i]
parse_add_single_option(opt)
start_of_current_opt = i + 1
Expand All @@ -258,13 +266,13 @@ def parse_add_single_option(opt):
opt = options[start_of_current_opt:]
parse_add_single_option(opt)
if quote_open:
raise InvalidOptionsException
raise InvalidOptionsException("Unbalanced quotes.")
return parsed_options

def _process_ssh_rsa(self):
def _process_ssh_rsa(self, data):
""" Parses ssh-rsa public keys """
raw_e = self._unpack_by_int()
raw_n = self._unpack_by_int()
current_position, raw_e = self._unpack_by_int(data, 0)
current_position, raw_n = self._unpack_by_int(data, current_position)

unpacked_e = self._parse_long(raw_e)
unpacked_n = self._parse_long(raw_n)
Expand All @@ -282,12 +290,15 @@ def _process_ssh_rsa(self):
raise TooShortKeyException("%s key data can not be shorter than %s bits (was %s)" % (self.key_type, min_length, self.bits))
if self.bits > max_length:
raise TooLongKeyException("%s key data can not be longer than %s bits (was %s)" % (self.key_type, max_length, self.bits))
return current_position

def _process_ssh_dss(self):
def _process_ssh_dss(self, data):
""" Parses ssh-dsa public keys """
data_fields = {}
current_position = 0
for item in ("p", "q", "g", "y"):
data_fields[item] = self._parse_long(self._unpack_by_int())
current_position, value = self._unpack_by_int(data, current_position)
data_fields[item] = self._parse_long(value)

self.dsa = DSA.construct((data_fields["y"], data_fields["g"], data_fields["p"], data_fields["q"]))
self.bits = self.dsa.size() + 1
Expand All @@ -305,32 +316,34 @@ def _process_ssh_dss(self):
raise TooShortKeyException("%s key can not be shorter than %s bits (was %s)" % (self.key_type, min_length, self.bits))
if self.bits > max_length:
raise TooLongKeyException("%s key data can not be longer than %s bits (was %s)" % (self.key_type, max_length, self.bits))
return current_position

def _process_ecdsa_sha(self):
def _process_ecdsa_sha(self, data):
""" Parses ecdsa-sha public keys """
curve_information = self._unpack_by_int()
current_position, curve_information = self._unpack_by_int(data, 0)
if curve_information not in self.ECDSA_CURVE_DATA:
raise NotImplementedError("Invalid curve type: %s" % curve_information)
curve, hash_algorithm = self.ECDSA_CURVE_DATA[curve_information]

data = self._unpack_by_int()
current_position, key_data = self._unpack_by_int(data, current_position)
try:
# data starts with \x04, which should be discarded.
ecdsa_key = ecdsa.VerifyingKey.from_string(data[1:], curve, hash_algorithm)
ecdsa_key = ecdsa.VerifyingKey.from_string(key_data[1:], curve, hash_algorithm)
except AssertionError:
raise InvalidKeyException("Invalid ecdsa key")
self.bits = int(curve_information.replace(b"nistp", b"")) # TODO: this is rather ugly way to extract bit length
self.bits = int(curve_information.replace(b"nistp", b""))
self.ecdsa = ecdsa_key
return current_position

def _process_ed25516(self):
def _process_ed25516(self, data):
""" Parses ed25516 keys.
There is no (apparent) way to validate ed25519 keys. This only
checks data length (256 bits), but does not try to validate
the key in any way.
"""

verifying_key = self._unpack_by_int()
current_position, verifying_key = self._unpack_by_int(data, 0)
verifying_key_length = len(verifying_key) * 8
verifying_key = self._parse_long(verifying_key)

Expand All @@ -340,16 +353,17 @@ def _process_ed25516(self):
self.bits = verifying_key_length
if self.bits != 256:
raise InvalidKeyLengthException("ed25519 keys must be 256 bits (was %s bits)" % self.bits)
return current_position

def _process_key(self):
def _process_key(self, data):
if self.key_type == b"ssh-rsa":
self._process_ssh_rsa()
return self._process_ssh_rsa(data)
elif self.key_type == b"ssh-dss":
self._process_ssh_dss()
return self._process_ssh_dss(data)
elif self.key_type.strip().startswith(b"ecdsa-sha"):
self._process_ecdsa_sha()
return self._process_ecdsa_sha(data)
elif self.key_type == b"ssh-ed25519":
self._process_ed25516()
return self._process_ed25516(data)
else:
raise NotImplementedError("Invalid key type: %s" % self.key_type)

Expand All @@ -363,13 +377,17 @@ def parse(self, keydata=None):
For rsa keys, see field "rsa" for raw public key data.
For dsa keys, see field "dsa".
For ecdsa keys, see field "ecdsa". """
self.current_position = 0
if keydata is None:
if self.keydata is None:
raise ValueError("Key data must be supplied either in constructor or to parse()")
keydata = self.keydata
else:
self.reset()
self.keydata = keydata

if keydata.startswith("---- BEGIN SSH2 PUBLIC KEY ----"):
# SSH2 key format
key_type = None
key_type = None # There is no redundant key-type field - skip comparing plain-text and encoded data.
pubkey_content = ""
for line in keydata.split("\n"):
if ":" in line: # key-value lines
Expand All @@ -382,16 +400,17 @@ def parse(self, keydata=None):
key_type = key_parts[0]
pubkey_content = key_parts[1]

self.decoded_key = self._decode_key(pubkey_content)
self._decoded_key = self.decode_key(pubkey_content)

# Check key type
unpacked_key_type = self._unpack_by_int()
current_position, unpacked_key_type = self._unpack_by_int(self._decoded_key, 0)
if key_type is not None and key_type != unpacked_key_type.decode():
raise InvalidTypeException("Keytype mismatch: %s != %s" % (key_type, unpacked_key_type))

self.key_type = unpacked_key_type

self._process_key()
key_data_length = self._process_key(self._decoded_key[current_position:])
current_position = current_position + key_data_length

if self.current_position != len(self.decoded_key):
raise MalformedDataException("Leftover data: %s bytes" % (len(self.decoded_key) - self.current_position))
if current_position != len(self._decoded_key):
raise MalformedDataException("Leftover data: %s bytes" % (len(self._decoded_key) - current_position))

0 comments on commit 0ad5d7e

Please sign in to comment.