diff --git a/CHANGES.rst b/CHANGES.rst index 778f752..0dca007 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,9 @@ Changes Unreleased ---------- +- Add support for AES-GCM (A128GCM, A192GCM and A256GCM). `#41 `__ +- Make key optional for KeyBuilder.from_symmetric_key. `#41 `__ + Version 0.3.0 ------------- diff --git a/cwt/key_builder.py b/cwt/key_builder.py index e74c321..2181138 100644 --- a/cwt/key_builder.py +++ b/cwt/key_builder.py @@ -32,7 +32,7 @@ from .cose_key import COSEKey from .key_types.ec2 import EC2Key from .key_types.okp import OKPKey -from .key_types.symmetric import AESCCMKey, HMACKey +from .key_types.symmetric import AESCCMKey, AESGCMKey, HMACKey class KeyBuilder: @@ -73,7 +73,7 @@ def __init__(self, options: Optional[Dict[str, Any]] = None): def from_symmetric_key( self, - key: Union[bytes, str], + key: Union[bytes, str] = b"", alg: Union[int, str] = "HMAC 256/256", kid: Union[bytes, str] = b"", ) -> COSEKey: @@ -105,6 +105,8 @@ def from_symmetric_key( kid = kid.encode("utf-8") if kid: cose_key[2] = kid + if alg_id in [1, 2, 3]: + return AESGCMKey(cose_key) if alg_id in [4, 5, 6, 7]: return HMACKey(cose_key) if alg_id in [10, 11, 12, 13, 30, 31, 32, 33]: diff --git a/cwt/key_types/symmetric.py b/cwt/key_types/symmetric.py index df33bf4..ca03364 100644 --- a/cwt/key_types/symmetric.py +++ b/cwt/key_types/symmetric.py @@ -1,12 +1,15 @@ import hashlib import hmac +from secrets import token_bytes from typing import Any, Dict, Optional -from cryptography.hazmat.primitives.ciphers.aead import AESCCM +from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESGCM from ..cose_key import COSEKey from ..exceptions import DecodeError, EncodeError, VerifyError +_CWT_DEFAULT_HMAC_KEY_SIZE = 32 # bytes + class SymmetricKey(COSEKey): """""" @@ -22,11 +25,10 @@ def __init__(self, cose_key: Dict[int, Any]): raise ValueError("kty(1) should be Symmetric(4).") # Validate k. - if -1 not in cose_key: - raise ValueError("k(-1) not found.") - if -1 in cose_key and not isinstance(cose_key[-1], bytes): - raise ValueError("k(-1) should be bytes(bstr).") - self._key = cose_key[-1] + if -1 in cose_key: + if not isinstance(cose_key[-1], bytes): + raise ValueError("k(-1) should be bytes(bstr).") + self._key = cose_key[-1] if 3 not in cose_key: raise ValueError("alg(3) not found.") @@ -42,6 +44,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._hash_alg = None self._trunc = 0 + if not self._key: + self._key = token_bytes(_CWT_DEFAULT_HMAC_KEY_SIZE) # Validate alg. if self._alg == 4: # HMAC 256/64 @@ -85,6 +89,8 @@ def __init__(self, cose_key: Dict[int, Any]): # Validate alg. if self._alg == 10: # AES-CCM-16-64-128 + if not self._key: + self._key = AESCCM.generate_key(bit_length=128) if len(self._key) != 16: raise ValueError( "The length of AES-CCM-16-64-128 key should be 16 bytes." @@ -92,6 +98,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._cipher = AESCCM(self._key, tag_length=8) self._nonce_len = 13 elif self._alg == 11: # AES-CCM-16-64-256 + if not self._key: + self._key = AESCCM.generate_key(bit_length=256) if len(self._key) != 32: raise ValueError( "The length of AES-CCM-16-64-256 key should be 32 bytes." @@ -99,6 +107,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._cipher = AESCCM(self._key, tag_length=8) self._nonce_len = 13 elif self._alg == 12: # AES-CCM-64-64-128 + if not self._key: + self._key = AESCCM.generate_key(bit_length=128) if len(self._key) != 16: raise ValueError( "The length of AES-CCM-64-64-128 key should be 16 bytes." @@ -106,6 +116,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._cipher = AESCCM(self._key, tag_length=8) self._nonce_len = 7 elif self._alg == 13: # AES-CCM-64-64-256 + if not self._key: + self._key = AESCCM.generate_key(bit_length=256) if len(self._key) != 32: raise ValueError( "The length of AES-CCM-64-64-256 key should be 32 bytes." @@ -113,6 +125,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._cipher = AESCCM(self._key, tag_length=8) self._nonce_len = 7 elif self._alg == 30: # AES-CCM-16-128-128 + if not self._key: + self._key = AESCCM.generate_key(bit_length=128) if len(self._key) != 16: raise ValueError( "The length of AES-CCM-16-128-128 key should be 16 bytes." @@ -120,6 +134,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._cipher = AESCCM(self._key) self._nonce_len = 13 elif self._alg == 31: # AES-CCM-16-128-256 + if not self._key: + self._key = AESCCM.generate_key(bit_length=256) if len(self._key) != 32: raise ValueError( "The length of AES-CCM-16-128-256 key should be 32 bytes." @@ -127,6 +143,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._cipher = AESCCM(self._key) self._nonce_len = 13 elif self._alg == 32: # AES-CCM-64-128-128 + if not self._key: + self._key = AESCCM.generate_key(bit_length=128) if len(self._key) != 16: raise ValueError( "The length of AES-CCM-64-128-128 key should be 16 bytes." @@ -134,6 +152,8 @@ def __init__(self, cose_key: Dict[int, Any]): self._cipher = AESCCM(self._key) self._nonce_len = 7 elif self._alg == 33: # AES-CCM-64-128-256 + if not self._key: + self._key = AESCCM.generate_key(bit_length=256) if len(self._key) != 32: raise ValueError( "The length of AES-CCM-64-128-256 key should be 32 bytes." @@ -164,3 +184,48 @@ def decrypt(self, msg: bytes, nonce: bytes, aad: Optional[bytes] = None) -> byte return self._cipher.decrypt(nonce, msg, aad) except Exception as err: raise DecodeError("Failed to decrypt.") from err + + +class AESGCMKey(SymmetricKey): + """""" + + def __init__(self, cose_key: Dict[int, Any]): + """""" + super().__init__(cose_key) + + self._cipher: AESGCM + + # Validate alg. + if self._alg == 1: # A128GCM + if not self._key: + self._key = AESGCM.generate_key(bit_length=128) + if len(self._key) != 16: + raise ValueError("The length of A128GCM key should be 16 bytes.") + elif self._alg == 2: # A192GCM + if not self._key: + self._key = AESGCM.generate_key(bit_length=192) + if len(self._key) != 24: + raise ValueError("The length of A192GCM key should be 24 bytes.") + elif self._alg == 3: # A256GCM + if not self._key: + self._key = AESGCM.generate_key(bit_length=256) + if len(self._key) != 32: + raise ValueError("The length of A256GCM key should be 32 bytes.") + else: + raise ValueError(f"Unsupported or unknown alg(3) for AES GCM: {self._alg}.") + self._cipher = AESGCM(self._key) + return + + def encrypt(self, msg: bytes, nonce: bytes, aad: Optional[bytes] = None) -> bytes: + """""" + try: + return self._cipher.encrypt(nonce, msg, aad) + except Exception as err: + raise EncodeError("Failed to encrypt.") from err + + def decrypt(self, msg: bytes, nonce: bytes, aad: Optional[bytes] = None) -> bytes: + """""" + try: + return self._cipher.decrypt(nonce, msg, aad) + except Exception as err: + raise DecodeError("Failed to decrypt.") from err diff --git a/docs/algorithms.rst b/docs/algorithms.rst index 109810f..d307113 100644 --- a/docs/algorithms.rst +++ b/docs/algorithms.rst @@ -79,11 +79,11 @@ COSE Algorithms +------------------------+--------+-------+-----------------------------------------------------+ | ... | +------------------------+--------+-------+-----------------------------------------------------+ -| A128GCM | | 1 | AES-GCM mode w/ 128-bit key, 128-bit tag | +| A128GCM | ✅ | 1 | AES-GCM mode w/ 128-bit key, 128-bit tag | +------------------------+--------+-------+-----------------------------------------------------+ -| A192GCM | | 2 | AES-GCM mode w/ 192-bit key, 128-bit tag | +| A192GCM | ✅ | 2 | AES-GCM mode w/ 192-bit key, 128-bit tag | +------------------------+--------+-------+-----------------------------------------------------+ -| A256GCM | | 3 | AES-GCM mode w/ 256-bit key, 128-bit tag | +| A256GCM | ✅ | 3 | AES-GCM mode w/ 256-bit key, 128-bit tag | +------------------------+--------+-------+-----------------------------------------------------+ | HMAC 256/64 | ✅ | 4 | HMAC w/ SHA-256 truncated to 64 bits | +------------------------+--------+-------+-----------------------------------------------------+ diff --git a/tests/test_cwt.py b/tests/test_cwt.py index 05fe3e7..bcf5257 100644 --- a/tests/test_cwt.py +++ b/tests/test_cwt.py @@ -247,6 +247,27 @@ def test_cwt_encode_and_encrypt_with_valid_alg_aes_ccm(self, ctx, alg, nonce, ke assert 2 in decoded and decoded[2] == "someone" assert 7 in decoded and decoded[7] == b"123" + @pytest.mark.parametrize( + "alg, key", + [ + ("A128GCM", token_bytes(16)), + ("A192GCM", token_bytes(24)), + ("A256GCM", token_bytes(32)), + ], + ) + def test_cwt_encode_and_encrypt_with_valid_alg_aes_gcm(self, ctx, alg, key): + """""" + enc_key = cose_key.from_symmetric_key(key, alg=alg) + token = ctx.encode_and_encrypt( + {1: "https://as.example", 2: "someone", 7: b"123"}, + enc_key, + nonce=token_bytes(12), + ) + decoded = ctx.decode(token, enc_key) + assert 1 in decoded and decoded[1] == "https://as.example" + assert 2 in decoded and decoded[2] == "someone" + assert 7 in decoded and decoded[7] == b"123" + def test_cwt_encode_and_encrypt_with_tagged(self, ctx): """""" key = token_bytes(16) diff --git a/tests/test_key_builder.py b/tests/test_key_builder.py index 2b78c72..76c2506 100644 --- a/tests/test_key_builder.py +++ b/tests/test_key_builder.py @@ -45,12 +45,39 @@ def test_key_builder_from_symmetric_key_hmac(self, ctx, alg): @pytest.mark.parametrize( "alg", - ["xxx", 3, 8, 9, 34], + [ + "HMAC 256/64", + "HMAC 256/256", + "HMAC 384/384", + "HMAC 512/512", + "A128GCM", + "A192GCM", + "A256GCM", + "AES-CCM-16-64-128", + "AES-CCM-16-64-256", + "AES-CCM-64-64-128", + "AES-CCM-64-64-256", + "AES-CCM-16-128-128", + "AES-CCM-16-128-256", + "AES-CCM-64-128-128", + "AES-CCM-64-128-256", + ], + ) + def test_key_builder_from_symmetric_key_without_key(self, ctx, alg): + try: + k = ctx.from_symmetric_key(alg=alg) + assert k.kty == 4 + except Exception: + pytest.fail("from_symmetric_key should not fail.") + + @pytest.mark.parametrize( + "alg", + ["xxx", 0, 8, 9, 34], ) def test_key_builder_from_symmetric_key_with_invalid_alg(self, ctx, alg): with pytest.raises(ValueError) as err: - res = ctx.from_symmetric_key("mysecretpassword", alg=alg) - pytest.fail("from_symmetric_key should fail: res=%s" % vars(res)) + ctx.from_symmetric_key("mysecretpassword", alg=alg) + pytest.fail("from_symmetric_key should fail.") assert f"Unsupported or unknown alg({alg})." in str(err.value) @pytest.mark.parametrize( diff --git a/tests/test_symmetric.py b/tests/test_symmetric.py index 1836694..fea6607 100644 --- a/tests/test_symmetric.py +++ b/tests/test_symmetric.py @@ -5,8 +5,8 @@ import pytest -from cwt.exceptions import DecodeError, VerifyError -from cwt.key_types.symmetric import AESCCMKey, HMACKey, SymmetricKey +from cwt.exceptions import DecodeError, EncodeError, VerifyError +from cwt.key_types.symmetric import AESCCMKey, AESGCMKey, HMACKey, SymmetricKey class TestSymmetricKey: @@ -48,10 +48,6 @@ def test_symmetric_key_constructor_with_hmac_256_256(self): {1: []}, "kty(1) should be int or str(tstr).", ), - ( - {1: 4}, - "k(-1) not found.", - ), ( {1: 4, -1: 123}, "k(-1) should be bytes(bstr).", @@ -103,6 +99,25 @@ def test_hmac_key_constructor_with_hmac_256_256(self): except Exception: pytest.fail("sign/verify should not fail.") + def test_hmac_key_constructor_with_hmac_256_256_without_key(self): + """""" + key = HMACKey( + { + 1: 4, + 3: 5, # HMAC 256/256 + } + ) + assert key.kty == 4 + assert key.kid is None + assert key.alg == 5 + assert key.key_ops is None + assert key.base_iv is None + try: + sig = key.sign(b"Hello world!") + key.verify(b"Hello world!", sig) + except Exception: + pytest.fail("sign/verify should not fail.") + @pytest.mark.parametrize( "invalid, msg", [ @@ -122,10 +137,6 @@ def test_hmac_key_constructor_with_hmac_256_256(self): {1: []}, "kty(1) should be int or str(tstr).", ), - ( - {1: 4}, - "k(-1) not found.", - ), ( {1: 4, -1: 123}, "k(-1) should be bytes(bstr).", @@ -205,6 +216,56 @@ def test_aesccm_key_constructor_with_aes_ccm_16_64_128(self): except Exception: pytest.fail("sign/verify should not fail.") + @pytest.mark.parametrize( + "key_args, nonce", + [ + ( + {1: 4, 3: 10}, + token_bytes(13), + ), + ( + {1: 4, 3: 11}, + token_bytes(13), + ), + ( + {1: 4, 3: 12}, + token_bytes(7), + ), + ( + {1: 4, 3: 13}, + token_bytes(7), + ), + ( + {1: 4, 3: 30}, + token_bytes(13), + ), + ( + {1: 4, 3: 31}, + token_bytes(13), + ), + ( + {1: 4, 3: 32}, + token_bytes(7), + ), + ( + {1: 4, 3: 33}, + token_bytes(7), + ), + ], + ) + def test_aesccm_key_constructor_with_aes_ccm_without_key(self, key_args, nonce): + """""" + key = AESCCMKey(key_args) + assert key.kty == 4 + assert key.kid is None + assert key.key_ops is None + assert key.base_iv is None + try: + encrypted = key.encrypt(b"Hello world!", nonce=nonce) + assert key.decrypt(encrypted, nonce) == b"Hello world!" + except Exception: + pytest.fail("sign/verify should not fail.") + @pytest.mark.parametrize( "invalid, msg", [ @@ -224,10 +285,6 @@ def test_aesccm_key_constructor_with_aes_ccm_16_64_128(self): {1: []}, "kty(1) should be int or str(tstr).", ), - ( - {1: 4}, - "k(-1) not found.", - ), ( {1: 4, -1: 123}, "k(-1) should be bytes(bstr).", @@ -293,6 +350,19 @@ def test_aesccm_key_constructor_with_invalid_args(self, invalid, msg): pytest.fail("AESCCMKey should fail.") assert msg in str(err.value) + def test_aesgcm_key_encrypt_without_msg(self): + key = AESCCMKey( + { + 1: 4, + -1: token_bytes(16), + 3: 10, # AES-CCM-16-64-128 + } + ) + nonce = token_bytes(13) + with pytest.raises(EncodeError) as err: + key.encrypt(None, nonce=nonce) + assert "Failed to encrypt." in str(err.value) + def test_aesccm_key_decrypt_with_invalid_nonce(self): """""" key = AESCCMKey( @@ -332,3 +402,113 @@ def test_aesccm_key_decrypt_with_invalid_length_nonce(self): with pytest.raises(ValueError) as err: key.decrypt(encrypted, nonce=token_bytes(7)) assert "The length of nonce should be 13 bytes." in str(err.value) + + +class TestAESGCMKey: + """ + Tests for AESGCMKey. + """ + + def test_aesgcm_key_constructor_with_aes_gcm_a128gcm(self): + """""" + key = AESGCMKey( + { + 1: 4, + -1: token_bytes(16), + 3: 1, # A128GCM + } + ) + assert key.kty == 4 + assert key.kid is None + assert key.alg == 1 + assert key.key_ops is None + assert key.base_iv is None + nonce = token_bytes(12) + try: + encrypted = key.encrypt(b"Hello world!", nonce=nonce) + assert key.decrypt(encrypted, nonce) == b"Hello world!" + except Exception: + pytest.fail("sign/verify should not fail.") + + @pytest.mark.parametrize( + "key_args", + [ + {1: 4, 3: 1}, + {1: 4, 3: 2}, + {1: 4, 3: 3}, + ], + ) + def test_aesgcm_key_constructor_with_aes_ccm_without_key(self, key_args): + """""" + key = AESGCMKey(key_args) + assert key.kty == 4 + assert key.kid is None + assert key.key_ops is None + assert key.base_iv is None + nonce = token_bytes(12) + try: + encrypted = key.encrypt(b"Hello world!", nonce=nonce) + assert key.decrypt(encrypted, nonce) == b"Hello world!" + except Exception: + pytest.fail("sign/verify should not fail.") + + @pytest.mark.parametrize( + "invalid, msg", + [ + ( + {1: 4, -1: b"mysecret", 3: 4}, + "Unsupported or unknown alg(3) for AES GCM: 4", + ), + ( + {1: 4, -1: b"mysecret", 3: 1}, + "The length of A128GCM key should be 16 bytes.", + ), + ( + {1: 4, -1: b"mysecret", 3: 2}, + "The length of A192GCM key should be 24 bytes.", + ), + ( + {1: 4, -1: b"mysecret", 3: 3}, + "The length of A256GCM key should be 32 bytes.", + ), + ], + ) + def test_aesgcm_key_constructor_with_invalid_args(self, invalid, msg): + """""" + with pytest.raises(ValueError) as err: + AESGCMKey(invalid) + pytest.fail("AESGCMKey should fail.") + assert msg in str(err.value) + + def test_aesgcm_key_encrypt_with_empty_nonce(self): + """""" + key = AESGCMKey( + { + 1: 4, + -1: token_bytes(16), + 3: 1, # A128GCM + } + ) + with pytest.raises(EncodeError) as err: + key.encrypt(b"Hello world!", nonce=b"") + assert "Failed to encrypt." in str(err.value) + + def test_aesgcm_key_decrypt_with_invalid_nonce(self): + """""" + key = AESGCMKey( + { + 1: 4, + -1: token_bytes(16), + 3: 1, # A128GCM + } + ) + assert key.kty == 4 + assert key.kid is None + assert key.alg == 1 + assert key.key_ops is None + assert key.base_iv is None + nonce = token_bytes(12) + encrypted = key.encrypt(b"Hello world!", nonce=nonce) + with pytest.raises(DecodeError) as err: + key.decrypt(encrypted, nonce=token_bytes(13)) + assert "Failed to decrypt." in str(err.value)