Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: speed up implementation #14

Merged
merged 12 commits into from
Jul 16, 2023
31 changes: 14 additions & 17 deletions src/chacha20poly1305_reuseable/__init__.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import cython
cdef object _ENCRYPT
cdef object _DECRYPT

cdef object lib
cdef object ffi

cdef object InvalidTag
cdef object openssl_assert
cdef object openssl_failure
cdef object NULL

cdef object EVP_CIPHER_CTX_ctrl
Expand All @@ -26,26 +29,20 @@ cdef object ffi_new
cdef object ffi_from_buffer
cdef object ffi_buffer

cdef object MAX_SIZE
cdef object KEY_LEN
cdef object NONCE_LEN
cdef cython.uint NONCE_LEN_UINT
cdef object TAG_LENGTH
cdef object CIPHER_NAME

cdef _check_params(
object nonce_len,
cython.uint nonce_len,
object nonce,
object data,
object associated_data
)

cdef _create_ctx()


cdef _set_cipher(object ctx, object cipher_name, object operation)

cdef _set_key_len(object ctx, object key_len)

cdef _set_key(object ctx, object key, object operation)

cdef _set_decrypt_tag(object ctx, object tag)

cdef _set_nonce_len(object ctx, object nonce_len)

cdef _set_nonce(object ctx, object nonce, object operation)

cdef _aead_setup_with_fixed_nonce_len(object cipher_name, object key, object nonce_len, object operation)
Expand All @@ -62,15 +59,15 @@ cdef _encrypt_with_fixed_nonce_len(
object tag_length,
)

cdef openssl_assert(object ok)

cdef _encrypt_data(
object ctx,
object data,
object associated_data,
object tag_length
)

cdef _tag_from_data(object data, object tag_length)

cdef _decrypt_with_fixed_nonce_len(
object ctx,
object nonce,
Expand Down
191 changes: 92 additions & 99 deletions src/chacha20poly1305_reuseable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,36 @@

import os
import typing
from functools import partial
from typing import Optional, Union

from cryptography import exceptions
from cryptography.exceptions import InvalidTag
from cryptography.hazmat.backends.openssl.backend import backend
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305

openssl_assert = backend.openssl_assert
EVP_CIPHER_CTX_ctrl = backend._lib.EVP_CIPHER_CTX_ctrl
EVP_CTRL_AEAD_SET_TAG = backend._lib.EVP_CTRL_AEAD_SET_TAG
EVP_CTRL_AEAD_SET_IVLEN = backend._lib.EVP_CTRL_AEAD_SET_IVLEN
EVP_CipherInit_ex = backend._lib.EVP_CipherInit_ex
EVP_CIPHER_CTX_new = backend._lib.EVP_CIPHER_CTX_new
EVP_CIPHER_CTX_free = backend._lib.EVP_CIPHER_CTX_free
EVP_get_cipherbyname = backend._lib.EVP_get_cipherbyname
EVP_CIPHER_CTX_set_key_length = backend._lib.EVP_CIPHER_CTX_set_key_length
EVP_CipherUpdate = backend._lib.EVP_CipherUpdate
EVP_CipherFinal_ex = backend._lib.EVP_CipherFinal_ex
EVP_CTRL_AEAD_GET_TAG = backend._lib.EVP_CTRL_AEAD_GET_TAG

ffi_from_buffer = backend._ffi.from_buffer
ffi_gc = backend._ffi.gc
ffi_new = backend._ffi.new
ffi_buffer = backend._ffi.buffer

NULL = backend._ffi.NULL
openssl_failure = partial(backend.openssl_assert, False)
lib = backend._lib
ffi = backend._ffi

EVP_CIPHER_CTX_ctrl = lib.EVP_CIPHER_CTX_ctrl
EVP_CTRL_AEAD_SET_TAG = lib.EVP_CTRL_AEAD_SET_TAG
EVP_CTRL_AEAD_SET_IVLEN = lib.EVP_CTRL_AEAD_SET_IVLEN
EVP_CipherInit_ex = lib.EVP_CipherInit_ex
EVP_CIPHER_CTX_new = lib.EVP_CIPHER_CTX_new
EVP_CIPHER_CTX_free = lib.EVP_CIPHER_CTX_free
EVP_get_cipherbyname = lib.EVP_get_cipherbyname
EVP_CIPHER_CTX_set_key_length = lib.EVP_CIPHER_CTX_set_key_length
EVP_CipherUpdate = lib.EVP_CipherUpdate
EVP_CipherFinal_ex = lib.EVP_CipherFinal_ex
EVP_CTRL_AEAD_GET_TAG = lib.EVP_CTRL_AEAD_GET_TAG

ffi_from_buffer = ffi.from_buffer
ffi_gc = ffi.gc
ffi_new = ffi.new
ffi_buffer = ffi.buffer

NULL = ffi.NULL

_ENCRYPT = 1
_DECRYPT = 0
Expand All @@ -56,6 +60,14 @@
raise ValueError("Nonce must be 12 bytes")


MAX_SIZE = 2**32
KEY_LEN = 32
NONCE_LEN = 12
NONCE_LEN_UINT = NONCE_LEN
TAG_LENGTH = 16
CIPHER_NAME = b"chacha20-poly1305"


class ChaCha20Poly1305Reusable(ChaCha20Poly1305):
"""A reuseable version of ChaCha20Poly1305.

Expand All @@ -66,11 +78,6 @@
The primary use case for this code is HAP streams.
"""

_MAX_SIZE = 2**32
_KEY_LEN = 32
_NONCE_LEN = 12
_TAG_LENGTH = 16

def __init__(self, key: Union[_bytes, bytearray]) -> None:
if not backend.aead_cipher_supported(self):
raise exceptions.UnsupportedAlgorithm(
Expand All @@ -81,46 +88,47 @@
if not isinstance(key, (bytes, bytearray)):
raise TypeError("key must be bytes or bytearay")

if len(key) != self._KEY_LEN:
if len(key) != KEY_LEN:
raise ValueError("ChaCha20Poly1305Reusable key must be 32 bytes.")

self._cipher_name = b"chacha20-poly1305"
self._key = key
self._decrypt_ctx: Optional[object] = None
self._encrypt_ctx: Optional[object] = None

@classmethod
def generate_key(cls) -> _bytes:
return os.urandom(ChaCha20Poly1305Reusable._KEY_LEN)
return os.urandom(KEY_LEN)

def encrypt(
self,
nonce: Union[_bytes, bytearray],
data: _bytes,
associated_data: typing.Optional[bytes],
) -> bytes:
if not self._encrypt_ctx:
self._encrypt_ctx = _aead_setup_with_fixed_nonce_len(
self._cipher_name,
encrypt_ctx = self._encrypt_ctx
if not encrypt_ctx:
encrypt_ctx = _aead_setup_with_fixed_nonce_len(
CIPHER_NAME,
self._key,
self._NONCE_LEN,
NONCE_LEN,
_ENCRYPT,
)
self._encrypt_ctx = encrypt_ctx

if associated_data is None:
associated_data = b""

if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE:
if len(data) > MAX_SIZE or len(associated_data) > MAX_SIZE:
# This is OverflowError to match what cffi would raise
raise OverflowError("Data or associated data too long. Max 2**32 bytes")

_check_params(self._NONCE_LEN, nonce, data, associated_data)
_check_params(NONCE_LEN_UINT, nonce, data, associated_data)
return _encrypt_with_fixed_nonce_len(
self._encrypt_ctx,
encrypt_ctx,
nonce,
data,
associated_data,
self._TAG_LENGTH,
TAG_LENGTH,
)

def decrypt(
Expand All @@ -129,34 +137,49 @@
data: _bytes,
associated_data: typing.Optional[_bytes],
) -> bytes:
if not self._decrypt_ctx:
self._decrypt_ctx = _aead_setup_with_fixed_nonce_len(
self._cipher_name,
decrypt_ctx = self._decrypt_ctx
if not decrypt_ctx:
decrypt_ctx = _aead_setup_with_fixed_nonce_len(
CIPHER_NAME,
self._key,
self._NONCE_LEN,
NONCE_LEN,
_DECRYPT,
)
self._decrypt_ctx = decrypt_ctx

if associated_data is None:
associated_data = b""

_check_params(self._NONCE_LEN, nonce, data, associated_data)
_check_params(NONCE_LEN_UINT, nonce, data, associated_data)
return _decrypt_with_fixed_nonce_len(
self._decrypt_ctx,
decrypt_ctx,
nonce,
data,
associated_data,
self._TAG_LENGTH,
TAG_LENGTH,
)


def _create_ctx() -> object:
ctx = EVP_CIPHER_CTX_new()
ctx = ffi_gc(ctx, EVP_CIPHER_CTX_free)
return ctx
def _set_nonce(ctx: object, nonce: Union[_bytes, bytearray], operation: int) -> None:
nonce_ptr = ffi_from_buffer(nonce)
res = EVP_CipherInit_ex(
ctx,
NULL,
NULL,
NULL,
nonce_ptr,
int(operation == _ENCRYPT),
)
openssl_assert(res != 0)


def _set_cipher(ctx: object, cipher_name: _bytes, operation: int) -> None:
def _aead_setup_with_fixed_nonce_len(
cipher_name: _bytes, key: Union[_bytes, bytearray], nonce_len: int, operation: int
) -> object:
# create the ctx
ctx = EVP_CIPHER_CTX_new()
ctx = ffi_gc(ctx, EVP_CIPHER_CTX_free)
# set the cipher
evp_cipher = EVP_get_cipherbyname(cipher_name)
openssl_assert(evp_cipher != NULL)
res = EVP_CipherInit_ex(
Expand All @@ -168,62 +191,27 @@
int(operation == _ENCRYPT),
)
openssl_assert(res != 0)


def _set_key_len(ctx: object, key_len: int) -> None:
res = EVP_CIPHER_CTX_set_key_length(ctx, key_len)
# Set the key length
res = EVP_CIPHER_CTX_set_key_length(ctx, len(key))
openssl_assert(res != 0)


def _set_key(ctx: object, key: _bytes, operation: int) -> None:
key_ptr = ffi_from_buffer(key)
# Set the key
res = EVP_CipherInit_ex(
ctx,
NULL,
NULL,
key_ptr,
ffi_from_buffer(key),
NULL,
int(operation == _ENCRYPT),
)
openssl_assert(res != 0)


def _set_decrypt_tag(ctx: object, tag: _bytes) -> None:
res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, len(tag), tag)
openssl_assert(res != 0)


def _set_nonce_len(ctx: object, nonce_len: int) -> None:
# set nonce length
res = EVP_CIPHER_CTX_ctrl(
ctx,
EVP_CTRL_AEAD_SET_IVLEN,
nonce_len,
NULL,
)
openssl_assert(res != 0)


def _set_nonce(ctx: object, nonce: Union[_bytes, bytearray], operation: int) -> None:
nonce_ptr = ffi_from_buffer(nonce)
res = EVP_CipherInit_ex(
ctx,
NULL,
NULL,
NULL,
nonce_ptr,
int(operation == _ENCRYPT),
)
openssl_assert(res != 0)


def _aead_setup_with_fixed_nonce_len(
cipher_name: _bytes, key: Union[_bytes, bytearray], nonce_len: int, operation: int
) -> object:
ctx = _create_ctx()
_set_cipher(ctx, cipher_name, operation)
_set_key_len(ctx, len(key))
_set_key(ctx, key, operation)
_set_nonce_len(ctx, nonce_len)
return ctx


Expand All @@ -235,8 +223,9 @@

def _process_data(ctx: object, data: _bytes) -> _bytes:
outlen = ffi_new("int *")
buf = ffi_new("unsigned char[]", len(data))
res = EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
data_len = len(data)
buf = ffi_new("unsigned char[]", data_len)
res = EVP_CipherUpdate(ctx, buf, outlen, data, data_len)
openssl_assert(res != 0)
return ffi_buffer(buf, outlen[0])[:]

Expand Down Expand Up @@ -265,27 +254,25 @@
res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf)
openssl_assert(res != 0)
tag = ffi_buffer(tag_buf)[:]

return processed_data + tag


def _tag_from_data(data: _bytes, tag_length: int) -> _bytes:
if len(data) < tag_length:
raise InvalidTag
return data[-tag_length:]


def _decrypt_with_fixed_nonce_len(
ctx: object,
nonce: Union[_bytes, bytearray],
data: _bytes,
associated_data: _bytes,
tag_length: int,
) -> bytes:
tag = _tag_from_data(data, tag_length)
data = data[:-tag_length]
if len(data) < tag_length:
raise InvalidTag
negative_tag_length = -tag_length
tag = data[negative_tag_length:]
data = data[:negative_tag_length]
_set_nonce(ctx, nonce, _DECRYPT)
_set_decrypt_tag(ctx, tag)
# set the decrypted tag
res = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, tag_length, tag)
openssl_assert(res != 0)
return _decrypt_data(ctx, data, associated_data)


Expand All @@ -299,3 +286,9 @@
raise InvalidTag

return processed_data


def openssl_assert(ok: bool) -> None:
"""Raise an exception if OpenSSL returns an error."""
if not ok:
openssl_failure()

Check warning on line 294 in src/chacha20poly1305_reuseable/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/chacha20poly1305_reuseable/__init__.py#L294

Added line #L294 was not covered by tests