Skip to content

Commit

Permalink
New code structure - draft
Browse files Browse the repository at this point in the history
  • Loading branch information
FiloSottile committed Oct 30, 2013
1 parent 4a289ba commit 1d63105
Showing 1 changed file with 227 additions and 79 deletions.
306 changes: 227 additions & 79 deletions triplesec/triplesec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import Crypto
import scrypt
import struct
import hmac
import hashlib
import six
from six.moves import zip
from collections import namedtuple

from Crypto import Random
rndfile = Random.new()
Expand All @@ -33,21 +36,103 @@ class TripleSecFailedAssertion(TripleSecError):
def _constant_time_compare(a, b):
if len(a) != len(b): return False
result = 0
for x, y in zip(a, b):
result |= six.byte2int(x) ^ six.byte2int(y)
for x, y in zip(six.iterbytes(a), six.iterbytes(b)):
result |= x ^ y
return (result == 0)


### DATA STRUCTURES
Cipher = namedtuple('Cipher', ['name', 'implementation', 'overhead_size', 'key_size'])
MAC = namedtuple('MAC', ['name', 'implementation', 'key_size', 'output_size'])
KDF = namedtuple('KDF', ['name', 'implementation', 'parameters'])
Scrypt_params = namedtuple('Scrypt_params', ['N', 'r', 'p'])
Constants = namedtuple('Constants', ['header', 'salt_size', 'MACs', 'ciphers', 'KDF'])


### CIPHERS AND HMAC IMPLEMENTATIONS
class AES:
key_size = 32
block_size = 16

@classmethod
def encrypt(cls, data, key):
# TODO stubbed
if len(key) != cls.key_size:
raise TripleSecFailedAssertion(u"Wrong AES key size")
iv = rndfile.read(cls.block_size)
return iv + data

@classmethod
def decrypt(cls, data, key):
# TODO stubbed
if len(key) != cls.key_size:
raise TripleSecFailedAssertion(u"Wrong AES key size")
iv = data[:cls.block_size]
return data[cls.block_size:]

class Twofish:
key_size = 32
block_size = 16

@classmethod
def encrypt(cls, data, key):
# TODO stubbed
if len(key) != cls.key_size:
raise TripleSecFailedAssertion(u"Wrong Twofish key size")
iv = rndfile.read(cls.block_size)
return iv + data

@classmethod
def decrypt(cls, data, key):
# TODO stubbed
if len(key) != cls.key_size:
raise TripleSecFailedAssertion(u"Wrong Twofish key size")
iv = data[:cls.block_size]
return data[cls.block_size:]

class XSalsa20:
key_size = 32
iv_size = 24

@classmethod
def encrypt(cls, data, key):
if len(key) != cls.key_size:
raise TripleSecFailedAssertion(u"Wrong XSalsa20 key size")
iv = rndfile.read(cls.iv_size)
return iv + data

@classmethod
def decrypt(cls, data, key):
# TODO stubbed
if len(key) != cls.key_size:
raise TripleSecFailedAssertion(u"Wrong XSalsa20 key size")
iv = data[:cls.iv_size]
return data[cls.iv_size:]

def HMAC_SHA512(data, key):
return hmac.new(key, data, hashlib.sha512).digest()

def HMAC_SHA3(data, key):
# TODO stubbed
return b'\x00' * 64

def Scrypt(key, salt, length, parameters):
try:
return scrypt.hash(key, salt, parameters.N, parameters.r, parameters.p, length)
except scrypt.error:
raise TripleSecError(u"scrypt error")


### MAIN CLASS
class TripleSec():
LATEST_VERSION = 3
MAGIC_BYTES = binascii.unhexlify(b'1c94d7de')

_versions_implementations = {}
_versions = {}

@staticmethod
def _check_key_type(key):
if key is not None and not isinstance(key, six.binary_types):
if key is not None and not isinstance(key, six.binary_type):
raise TripleSecError(u"The key needs needs to be a binary string (str() in Python 2 and bytes() in Python 3)")

@staticmethod
Expand All @@ -64,21 +149,76 @@ def __init__(self, key=None):
self._check_key_type(key)
self.key = key

def encrypt(self, data, key=None):
def _key_stretching(self, key, salt, version, extra_bytes=0):
total_keys_size = sum(x.key_size for x in version.MACs + version.ciphers) + extra_bytes
key_material = version.KDF.implementation(key, salt, total_keys_size, version.KDF.parameters)

i = 0
mac_keys = []
for m in version.MACs:
mac_keys.append(key_material[i:i + m.key_size])
i += m.key_size
cipher_keys = []
for c in version.ciphers:
cipher_keys.append(key_material[i:i + c.key_size])
i += c.key_size
extra = key_material[i:]

return mac_keys, cipher_keys, extra

def _calc_overhead(self, version):
tot = 0
tot += sum(map(len, version.header))
tot += version.salt_size
tot += sum(m.output_size for m in version.MACs)
tot += sum(c.overhead_size for c in version.ciphers)
return tot

def encrypt(self, data, key=None, v=None, extra_bytes=0):
self._check_data_type(data)
self._check_key_type(key)
if key is None and self.key is None:
raise TripleSecError(u"You didn't initialize TripleSec with a key, so you need to specify one")

implementation = self._versions_implementations[self.LATEST_VERSION]()
result = implementation._encrypt(data, key)
if not v: v = self.LATEST_VERSION
version = self._versions[v]
result, extra = self._encrypt(data, key, version, extra_bytes)

self._check_output_type(result)
self._check_output_type(extra)
return result, extra

def _encrypt(self, data, key, version, extra_bytes):
salt = rndfile.read(version.salt_size)
mac_keys, cipher_keys, extra = self._key_stretching(key, salt, version, extra_bytes)

encrypted_material = self._encrypt_data(data, cipher_keys, version)

header = b''.join(version.header)

authenticated_data = header + salt + encrypted_material
macs = self._generate_macs(authenticated_data, mac_keys, version)

result = header + salt + b''.join(macs) + encrypted_material

if len(result) != self._calc_overhead(version) + len(data):
raise TripleSecFailedAssertion(u"Wrong encrypt output length")
return result, extra

def _generate_macs(self, authenticated_data, mac_keys, version):
result = []
for n, m in enumerate(version.MACs):
mac = m.implementation(authenticated_data, mac_keys[n])
result.append(mac)
return result

def _encrypt(self, data, key):
"""This should be defined in versions implementation subclasses"""
pass
def _encrypt_data(self, data, cipher_keys, version):
ciphers_num = len(version.ciphers)
for n, c in enumerate(version.ciphers):
# the keys order is from the outermost to the innermost
key = cipher_keys[ciphers_num - 1 - n]
data = c.implementation.encrypt(data, key)
return data

def decrypt(self, data, key=None):
self._check_data_type(data)
Expand All @@ -89,98 +229,106 @@ def decrypt(self, data, key=None):
if len(data) < 8 or data[:4] != self.MAGIC_BYTES:
raise TripleSecError(u"This does not look like a TripleSec ciphertext")

version = struct.unpack("<I", data[4:8])[0]
if version not in self._versions_implementations:
header_version = struct.unpack("<I", data[4:8])[0]
if header_version not in self._versions:
raise TripleSecError(u"Unimplemented version")

implementation = self._versions_implementations[version]()
result = implementation._decrypt(data, key)
version = self._versions[header_version]
result = self._decrypt(data, key, version)

self._check_output_type(result)
return result

def _decrypt(self, data, key):
"""This should be defined in versions implementation subclasses"""
pass


### VERSIONS IMPLEMENTATIONS
class TripleSec_v3():
VERSION = 3

@staticmethod
def _key_stretching(key, salt):
try:
return scrypt.hash(key, salt, N=1 << 13, r=8, p=1)
except scrypt.error:
raise TripleSecError(u"scrypt error")
def _decrypt(self, data, key, version):
if len(data) < self._calc_overhead(version):
raise TripleSecDecryptionError(u"Input does not look like a TripleSec ciphertext")

def _salsa20_encrypt(data, key):
pass
def _salsa20_decrypt(data, key):
pass
header, salt, macs, encrypted_material = \
self._split_ciphertext(data, version)

def _twofish_encrypt(data, key):
pass
def _twofish_decrypt(data, key):
pass
mac_keys, cipher_keys, _ = self._key_stretching(key, salt, version)

def _aes_encrypt(data, key):
pass
def _aes_decrypt(data, key):
pass
authenticated_data = header + salt + encrypted_material
if not self._check_macs(authenticated_data, macs, mac_keys, version):
raise TripleSecDecryptionError(u"Failed authentication of the data")

def _hmac_sha256(data, key):
pass
result = self._decrypt_data(encrypted_material, cipher_keys, version)

def _hmac_sha3(data, key):
pass
if len(result) != len(data) - self._calc_overhead(version):
raise TripleSecFailedAssertion(u"Wrong decrypt output length")
return result

def _encrypt(self, data, key):
salt = rndfile.read(16)
stretched_key = self._key_stretching(key, salt)
def _split_ciphertext(self, data, version):
i = 0

first_step = self._salsa20_encrypt(data, stretched_key[0])
second_step = self._twofish_encrypt(first_step, stretched_key[1])
encrypted_material = self._aes_encrypt(second_step, stretched_key[2])
header_size = sum(map(len, version.header))
header = data[i:i + header_size]
i += header_size

header = TripleSec.MAGIC_BYTES + struct.pack("<I", self.VERSION)
salt = data[i:i + version.salt_size]
i += version.salt_size

hmac_sha2 = self._hmac_sha256(header + salt + encrypted_material, stretched_key[3])
hmac_sha3 = self._hmac_sha3(header + salt + encrypted_material, stretched_key[4])
macs = []
for m in version.MACs:
macs.append(data[i:i + m.output_size])
i += m.output_size

result = header + salt + hmac_sha2 + hmac_sha3 + encrypted_material
encrypted_material = data[i:]

if len(result) != 208 + len(data):
raise TripleSecFailedAssertion(u"Wrong encrypt output length")
return result
return header, salt, macs, encrypted_material

def _decrypt(self, data, key):
if len(data) < 208:
raise TripleSecDecryptionError(u"Input does not look like a TripleSec ciphertext")
def _check_macs(self, authenticated_data, macs, mac_keys, version):
expected_macs = self._generate_macs(authenticated_data, mac_keys, version)

header, salt, hmac_sha2, hmac_sha3, encrypted_material = \
data[:8], data[8:24], data[24:88], data[88:152], data[152:]
result = True

stretched_key = self._key_stretching(key, salt)
for expected, actual in zip(expected_macs, macs):
result = _constant_time_compare(expected, actual) and result

generated_hmac_sha2 = self._hmac_sha256(header + salt + encrypted_material, stretched_key[3])
generated_hmac_sha3 = self._hmac_sha3(header + salt + encrypted_material, stretched_key[4])

if not _constant_time_compare(generated_hmac_sha2, hmac_sha2) or \
not _constant_time_compare(generated_hmac_sha3, hmac_sha3):
raise TripleSecDecryptionError(u"Failed authentication of the data")

second_step = self._aes_decrypt(encrypted_material, stretched_key[2])
first_step = self._twofish_decrypt(second_step, stretched_key[1])
result = self._salsa20_decrypt(first_step, stretched_key[0])

if len(result) != len(data) - 208:
raise TripleSecFailedAssertion(u"Wrong decrypt output length")
return result


TripleSec._versions_implementations[TripleSec_v3.VERSION] = TripleSec_v3
def _decrypt_data(self, encrypted_material, cipher_keys, version):
ciphers_num = len(version.ciphers)
data = encrypted_material
for n, c in enumerate(reversed(version.ciphers)):
# the keys order is from the outermost to the innermost
key = cipher_keys[ciphers_num - 1 - n]
data = c.implementation.decrypt(data, key)
return data


### VERSIONS DEFINITIONS
TripleSec._versions[3] = Constants(
header = [ TripleSec.MAGIC_BYTES, struct.pack("<I", 3) ],
salt_size = 16,

KDF = KDF(name = 'scrypt',
implementation = Scrypt,
parameters = Scrypt_params(N = 2**13,
r = 8,
p = 1)),

MACs = [ MAC(name = 'HMAC-SHA-512',
implementation = HMAC_SHA512,
key_size = 48,
output_size = 64),
MAC(name = 'HMAC-SHA3',
implementation = HMAC_SHA3,
key_size = 48,
output_size = 64) ],

ciphers = [ Cipher(name = 'XSalsa20',
implementation = XSalsa20,
overhead_size = XSalsa20.iv_size,
key_size = XSalsa20.key_size),
Cipher(name = 'Twofish-CTR',
implementation = Twofish,
overhead_size = Twofish.block_size,
key_size = Twofish.key_size),
Cipher(name = 'AES-256-CTR',
implementation = AES,
overhead_size = AES.block_size,
key_size = AES.key_size) ])


# Expose encrypt() and decrypt() shortcuts
Expand Down

0 comments on commit 1d63105

Please sign in to comment.