diff --git a/triplesec/crypto.py b/triplesec/crypto.py index 08b1385..e758958 100644 --- a/triplesec/crypto.py +++ b/triplesec/crypto.py @@ -27,67 +27,87 @@ sha3_512 ) -class AES: +def validate_key_size(key, key_size, algorithm): + if len(key) != key_size: + raise TripleSecFailedAssertion(u"Wrong {algo} key size" + .format(algo=algorithm)) + +class BlockCipher(object): + + @classmethod + def generate_counter(cls, block_size, iv): + ctr = Counter.new(block_size * 8, + initial_value=int(binascii.hexlify(iv), 16)) + return ctr + + @classmethod + def generate_encrypt_iv_counter(cls, block_size): + iv = rndfile.read(block_size) + ctr = cls.generate_counter(block_size, iv) + + return iv, ctr + + @classmethod + def generate_decrypt_counter(cls, data, block_size): + iv = data[:block_size] + ctr = cls.generate_counter(block_size, iv) + + return ctr + + +class AES(object): key_size = 32 block_size = 16 @classmethod def encrypt(cls, data, key): - if len(key) != cls.key_size: - raise TripleSecFailedAssertion(u"Wrong AES key size") - - iv = rndfile.read(cls.block_size) - ctr = Counter.new(cls.block_size*8, initial_value=int(binascii.hexlify(iv), 16)) + validate_key_size(key, cls.key_size, "AES") + iv, ctr = BlockCipher.generate_encrypt_iv_counter(cls.block_size) ciphertext = Crypto_AES.new(key, Crypto_AES.MODE_CTR, - counter=ctr).encrypt(data) + counter=ctr).encrypt(data) return iv + ciphertext @classmethod def decrypt(cls, data, key): - if len(key) != cls.key_size: - raise TripleSecFailedAssertion(u"Wrong AES key size") + validate_key_size(key, cls.key_size, "AES") - iv = data[:cls.block_size] - ctr = Counter.new(cls.block_size*8, initial_value=int(binascii.hexlify(iv), 16)) + ctr = BlockCipher.generate_decrypt_counter(data, cls.block_size) return Crypto_AES.new(key, Crypto_AES.MODE_CTR, - counter=ctr).decrypt(data[cls.block_size:]) + counter=ctr).decrypt(data[cls.block_size:]) -class Twofish: +class Twofish(object): key_size = 32 block_size = 16 @classmethod - def _gen_keystream(cls, length, T, ctr): + def _gen_keystream(cls, length, tfish, ctr): req_blocks = length // cls.block_size + 1 keystream = b'' for _ in range(req_blocks): - keystream += T.encrypt(ctr()) + keystream += tfish.encrypt(ctr()) return keystream[:length] @classmethod def encrypt(cls, data, key): - if len(key) != cls.key_size: - raise TripleSecFailedAssertion(u"Wrong Twofish key size") + validate_key_size(key, cls.key_size, "Twofish") - iv = rndfile.read(cls.block_size) - ctr = Counter.new(cls.block_size*8, initial_value=int(binascii.hexlify(iv), 16)) + iv, ctr = BlockCipher.generate_encrypt_iv_counter(cls.block_size) + tfish = twofish.Twofish(key) + ciphertext = strxor(data, cls._gen_keystream(len(data), tfish, ctr)) - T = twofish.Twofish(key) - ciphertext = strxor(data, cls._gen_keystream(len(data), T, ctr)) return iv + ciphertext @classmethod def decrypt(cls, data, key): - if len(key) != cls.key_size: - raise TripleSecFailedAssertion(u"Wrong Twofish key size") + validate_key_size(key, cls.key_size, "Twofish") - iv = data[:cls.block_size] - ctr = Counter.new(cls.block_size*8, initial_value=int(binascii.hexlify(iv), 16)) + ctr = BlockCipher.generate_decrypt_counter(data, cls.block_size) + tfish = twofish.Twofish(key) - T = twofish.Twofish(key) - return strxor(data[cls.block_size:], cls._gen_keystream(len(data[cls.block_size:]), T, ctr)) + return strxor(data[cls.block_size:], + cls._gen_keystream(len(data[cls.block_size:]), tfish, ctr)) class XSalsa20: key_size = 32 @@ -95,18 +115,16 @@ class XSalsa20: @classmethod def encrypt(cls, data, key): - if len(key) != cls.key_size: - raise TripleSecFailedAssertion(u"Wrong XSalsa20 key size") + validate_key_size(key, cls.key_size, "XSalsa20") iv = rndfile.read(cls.iv_size) - ciphertext = salsa20.XSalsa20_xor(data, iv, key) + return iv + ciphertext @classmethod def decrypt(cls, data, key): - if len(key) != cls.key_size: - raise TripleSecFailedAssertion(u"Wrong XSalsa20 key size") + validate_key_size(key, cls.key_size, "XSalsa20") iv = data[:cls.iv_size] diff --git a/triplesec/utils.py b/triplesec/utils.py index c214ac6..dca13fa 100644 --- a/triplesec/utils.py +++ b/triplesec/utils.py @@ -52,24 +52,30 @@ class TripleSecFailedAssertion(TripleSecError): ### UTILITIES def _constant_time_compare(a, b): - if len(a) != len(b): return False + if len(a) != len(b): + return False result = 0 for x, y in zip(six.iterbytes(a), six.iterbytes(b)): result |= x ^ y - return (result == 0) + return result == 0 -class new_sha3_512: +class new_sha3_512(object): block_size = 72 digest_size = 64 + def __init__(self, string=b''): self._obj = hashlib.sha3_512() self._obj.update(string) + def digest(self): return self._obj.digest() + def hexdigest(self): return self._obj.hexdigest() + def update(self, string): return self._obj.update(string) + def copy(self): copy = new_sha3_512() copy._obj = self._obj.copy() @@ -116,6 +122,6 @@ def win32_utf8_argv(): else: start = 0 return [argv[i].encode('utf-8') for i in - xrange(start, argc.value)] + range(start, argc.value)] except Exception: pass