diff --git a/src/chacha20poly1305_reuseable/__init__.pxd b/src/chacha20poly1305_reuseable/__init__.pxd index 3f67c0a..d1c978a 100644 --- a/src/chacha20poly1305_reuseable/__init__.pxd +++ b/src/chacha20poly1305_reuseable/__init__.pxd @@ -5,8 +5,11 @@ import cython cdef object _ENCRYPT cdef object _DECRYPT +cdef object lib +cdef object ffi + cdef object InvalidTag -cdef object openssl_assert +cdef object openssl_failure cdef object NULL cdef object EVP_CIPHER_CTX_ctrl @@ -26,26 +29,20 @@ cdef object ffi_new cdef object ffi_from_buffer cdef object ffi_buffer +cdef object MAX_SIZE +cdef object KEY_LEN +cdef object NONCE_LEN +cdef cython.uint NONCE_LEN_UINT +cdef object TAG_LENGTH +cdef object CIPHER_NAME + cdef _check_params( - object nonce_len, + cython.uint nonce_len, object nonce, object data, object associated_data ) -cdef _create_ctx() - - -cdef _set_cipher(object ctx, object cipher_name, object operation) - -cdef _set_key_len(object ctx, object key_len) - -cdef _set_key(object ctx, object key, object operation) - -cdef _set_decrypt_tag(object ctx, object tag) - -cdef _set_nonce_len(object ctx, object nonce_len) - cdef _set_nonce(object ctx, object nonce, object operation) cdef _aead_setup_with_fixed_nonce_len(object cipher_name, object key, object nonce_len, object operation) @@ -62,6 +59,8 @@ cdef _encrypt_with_fixed_nonce_len( object tag_length, ) +cdef openssl_assert(object ok) + cdef _encrypt_data( object ctx, object data, @@ -69,8 +68,6 @@ cdef _encrypt_data( object tag_length ) -cdef _tag_from_data(object data, object tag_length) - cdef _decrypt_with_fixed_nonce_len( object ctx, object nonce, diff --git a/src/chacha20poly1305_reuseable/__init__.py b/src/chacha20poly1305_reuseable/__init__.py index 01f5596..fbbf274 100644 --- a/src/chacha20poly1305_reuseable/__init__.py +++ b/src/chacha20poly1305_reuseable/__init__.py @@ -7,6 +7,7 @@ import os import typing +from functools import partial from typing import Optional, Union from cryptography import exceptions @@ -14,25 +15,28 @@ from cryptography.hazmat.backends.openssl.backend import backend from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 -openssl_assert = backend.openssl_assert -EVP_CIPHER_CTX_ctrl = backend._lib.EVP_CIPHER_CTX_ctrl -EVP_CTRL_AEAD_SET_TAG = backend._lib.EVP_CTRL_AEAD_SET_TAG -EVP_CTRL_AEAD_SET_IVLEN = backend._lib.EVP_CTRL_AEAD_SET_IVLEN -EVP_CipherInit_ex = backend._lib.EVP_CipherInit_ex -EVP_CIPHER_CTX_new = backend._lib.EVP_CIPHER_CTX_new -EVP_CIPHER_CTX_free = backend._lib.EVP_CIPHER_CTX_free -EVP_get_cipherbyname = backend._lib.EVP_get_cipherbyname -EVP_CIPHER_CTX_set_key_length = backend._lib.EVP_CIPHER_CTX_set_key_length -EVP_CipherUpdate = backend._lib.EVP_CipherUpdate -EVP_CipherFinal_ex = backend._lib.EVP_CipherFinal_ex -EVP_CTRL_AEAD_GET_TAG = backend._lib.EVP_CTRL_AEAD_GET_TAG - -ffi_from_buffer = backend._ffi.from_buffer -ffi_gc = backend._ffi.gc -ffi_new = backend._ffi.new -ffi_buffer = backend._ffi.buffer - -NULL = backend._ffi.NULL +openssl_failure = partial(backend.openssl_assert, False) +lib = backend._lib +ffi = backend._ffi + +EVP_CIPHER_CTX_ctrl = lib.EVP_CIPHER_CTX_ctrl +EVP_CTRL_AEAD_SET_TAG = lib.EVP_CTRL_AEAD_SET_TAG +EVP_CTRL_AEAD_SET_IVLEN = lib.EVP_CTRL_AEAD_SET_IVLEN +EVP_CipherInit_ex = lib.EVP_CipherInit_ex +EVP_CIPHER_CTX_new = lib.EVP_CIPHER_CTX_new +EVP_CIPHER_CTX_free = lib.EVP_CIPHER_CTX_free +EVP_get_cipherbyname = lib.EVP_get_cipherbyname +EVP_CIPHER_CTX_set_key_length = lib.EVP_CIPHER_CTX_set_key_length +EVP_CipherUpdate = lib.EVP_CipherUpdate +EVP_CipherFinal_ex = lib.EVP_CipherFinal_ex +EVP_CTRL_AEAD_GET_TAG = lib.EVP_CTRL_AEAD_GET_TAG + +ffi_from_buffer = ffi.from_buffer +ffi_gc = ffi.gc +ffi_new = ffi.new +ffi_buffer = ffi.buffer + +NULL = ffi.NULL _ENCRYPT = 1 _DECRYPT = 0 @@ -56,6 +60,14 @@ def _check_params( raise ValueError("Nonce must be 12 bytes") +MAX_SIZE = 2**32 +KEY_LEN = 32 +NONCE_LEN = 12 +NONCE_LEN_UINT = NONCE_LEN +TAG_LENGTH = 16 +CIPHER_NAME = b"chacha20-poly1305" + + class ChaCha20Poly1305Reusable(ChaCha20Poly1305): """A reuseable version of ChaCha20Poly1305. @@ -66,11 +78,6 @@ class ChaCha20Poly1305Reusable(ChaCha20Poly1305): The primary use case for this code is HAP streams. """ - _MAX_SIZE = 2**32 - _KEY_LEN = 32 - _NONCE_LEN = 12 - _TAG_LENGTH = 16 - def __init__(self, key: Union[_bytes, bytearray]) -> None: if not backend.aead_cipher_supported(self): raise exceptions.UnsupportedAlgorithm( @@ -81,17 +88,16 @@ def __init__(self, key: Union[_bytes, bytearray]) -> None: if not isinstance(key, (bytes, bytearray)): raise TypeError("key must be bytes or bytearay") - if len(key) != self._KEY_LEN: + if len(key) != KEY_LEN: raise ValueError("ChaCha20Poly1305Reusable key must be 32 bytes.") - self._cipher_name = b"chacha20-poly1305" self._key = key self._decrypt_ctx: Optional[object] = None self._encrypt_ctx: Optional[object] = None @classmethod def generate_key(cls) -> _bytes: - return os.urandom(ChaCha20Poly1305Reusable._KEY_LEN) + return os.urandom(KEY_LEN) def encrypt( self, @@ -99,28 +105,30 @@ def encrypt( data: _bytes, associated_data: typing.Optional[bytes], ) -> bytes: - if not self._encrypt_ctx: - self._encrypt_ctx = _aead_setup_with_fixed_nonce_len( - self._cipher_name, + encrypt_ctx = self._encrypt_ctx + if not encrypt_ctx: + encrypt_ctx = _aead_setup_with_fixed_nonce_len( + CIPHER_NAME, self._key, - self._NONCE_LEN, + NONCE_LEN, _ENCRYPT, ) + self._encrypt_ctx = encrypt_ctx if associated_data is None: associated_data = b"" - if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: + if len(data) > MAX_SIZE or len(associated_data) > MAX_SIZE: # This is OverflowError to match what cffi would raise raise OverflowError("Data or associated data too long. Max 2**32 bytes") - _check_params(self._NONCE_LEN, nonce, data, associated_data) + _check_params(NONCE_LEN_UINT, nonce, data, associated_data) return _encrypt_with_fixed_nonce_len( - self._encrypt_ctx, + encrypt_ctx, nonce, data, associated_data, - self._TAG_LENGTH, + TAG_LENGTH, ) def decrypt( @@ -129,34 +137,49 @@ def decrypt( data: _bytes, associated_data: typing.Optional[_bytes], ) -> bytes: - if not self._decrypt_ctx: - self._decrypt_ctx = _aead_setup_with_fixed_nonce_len( - self._cipher_name, + decrypt_ctx = self._decrypt_ctx + if not decrypt_ctx: + decrypt_ctx = _aead_setup_with_fixed_nonce_len( + CIPHER_NAME, self._key, - self._NONCE_LEN, + NONCE_LEN, _DECRYPT, ) + self._decrypt_ctx = decrypt_ctx if associated_data is None: associated_data = b"" - _check_params(self._NONCE_LEN, nonce, data, associated_data) + _check_params(NONCE_LEN_UINT, nonce, data, associated_data) return _decrypt_with_fixed_nonce_len( - self._decrypt_ctx, + decrypt_ctx, nonce, data, associated_data, - self._TAG_LENGTH, + TAG_LENGTH, ) -def _create_ctx() -> object: - ctx = EVP_CIPHER_CTX_new() - ctx = ffi_gc(ctx, EVP_CIPHER_CTX_free) - return ctx +def _set_nonce(ctx: object, nonce: Union[_bytes, bytearray], operation: int) -> None: + nonce_ptr = ffi_from_buffer(nonce) + res = EVP_CipherInit_ex( + ctx, + NULL, + NULL, + NULL, + nonce_ptr, + int(operation == _ENCRYPT), + ) + openssl_assert(res != 0) -def _set_cipher(ctx: object, cipher_name: _bytes, operation: int) -> None: +def _aead_setup_with_fixed_nonce_len( + cipher_name: _bytes, key: Union[_bytes, bytearray], nonce_len: int, operation: int +) -> object: + # create the ctx + ctx = EVP_CIPHER_CTX_new() + ctx = ffi_gc(ctx, EVP_CIPHER_CTX_free) + # set the cipher evp_cipher = EVP_get_cipherbyname(cipher_name) openssl_assert(evp_cipher != NULL) res = EVP_CipherInit_ex( @@ -168,32 +191,20 @@ def _set_cipher(ctx: object, cipher_name: _bytes, operation: int) -> None: int(operation == _ENCRYPT), ) openssl_assert(res != 0) - - -def _set_key_len(ctx: object, key_len: int) -> None: - res = EVP_CIPHER_CTX_set_key_length(ctx, key_len) + # Set the key length + res = EVP_CIPHER_CTX_set_key_length(ctx, len(key)) openssl_assert(res != 0) - - -def _set_key(ctx: object, key: _bytes, operation: int) -> None: - key_ptr = ffi_from_buffer(key) + # Set the key res = EVP_CipherInit_ex( ctx, NULL, NULL, - key_ptr, + ffi_from_buffer(key), NULL, int(operation == _ENCRYPT), ) openssl_assert(res != 0) - - -def _set_decrypt_tag(ctx: object, tag: _bytes) -> None: - res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, len(tag), tag) - openssl_assert(res != 0) - - -def _set_nonce_len(ctx: object, nonce_len: int) -> None: + # set nonce length res = EVP_CIPHER_CTX_ctrl( ctx, EVP_CTRL_AEAD_SET_IVLEN, @@ -201,29 +212,6 @@ def _set_nonce_len(ctx: object, nonce_len: int) -> None: NULL, ) openssl_assert(res != 0) - - -def _set_nonce(ctx: object, nonce: Union[_bytes, bytearray], operation: int) -> None: - nonce_ptr = ffi_from_buffer(nonce) - res = EVP_CipherInit_ex( - ctx, - NULL, - NULL, - NULL, - nonce_ptr, - int(operation == _ENCRYPT), - ) - openssl_assert(res != 0) - - -def _aead_setup_with_fixed_nonce_len( - cipher_name: _bytes, key: Union[_bytes, bytearray], nonce_len: int, operation: int -) -> object: - ctx = _create_ctx() - _set_cipher(ctx, cipher_name, operation) - _set_key_len(ctx, len(key)) - _set_key(ctx, key, operation) - _set_nonce_len(ctx, nonce_len) return ctx @@ -235,8 +223,9 @@ def _process_aad(ctx: object, associated_data: _bytes) -> None: def _process_data(ctx: object, data: _bytes) -> _bytes: outlen = ffi_new("int *") - buf = ffi_new("unsigned char[]", len(data)) - res = EVP_CipherUpdate(ctx, buf, outlen, data, len(data)) + data_len = len(data) + buf = ffi_new("unsigned char[]", data_len) + res = EVP_CipherUpdate(ctx, buf, outlen, data, data_len) openssl_assert(res != 0) return ffi_buffer(buf, outlen[0])[:] @@ -265,16 +254,9 @@ def _encrypt_data( res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf) openssl_assert(res != 0) tag = ffi_buffer(tag_buf)[:] - return processed_data + tag -def _tag_from_data(data: _bytes, tag_length: int) -> _bytes: - if len(data) < tag_length: - raise InvalidTag - return data[-tag_length:] - - def _decrypt_with_fixed_nonce_len( ctx: object, nonce: Union[_bytes, bytearray], @@ -282,10 +264,15 @@ def _decrypt_with_fixed_nonce_len( associated_data: _bytes, tag_length: int, ) -> bytes: - tag = _tag_from_data(data, tag_length) - data = data[:-tag_length] + if len(data) < tag_length: + raise InvalidTag + negative_tag_length = -tag_length + tag = data[negative_tag_length:] + data = data[:negative_tag_length] _set_nonce(ctx, nonce, _DECRYPT) - _set_decrypt_tag(ctx, tag) + # set the decrypted tag + res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, tag_length, tag) + openssl_assert(res != 0) return _decrypt_data(ctx, data, associated_data) @@ -299,3 +286,9 @@ def _decrypt_data(ctx: object, data: _bytes, associated_data: _bytes) -> _bytes: raise InvalidTag return processed_data + + +def openssl_assert(ok: bool) -> None: + """Raise an exception if OpenSSL returns an error.""" + if not ok: + openssl_failure()