diff --git a/CHANGES.rst b/CHANGES.rst index 8a2472d..74c85b4 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,7 @@ Changes Unreleased ---------- +- Add support for key derivation without kid. `#120 `__ - Add support for ECDH-SS direct HKDF. `#119 `__ - Add support for ECDH-ES direct HKDF. `#118 `__ diff --git a/cwt/recipient_algs/ecdh_direct_hkdf.py b/cwt/recipient_algs/ecdh_direct_hkdf.py index 1b2720a..e286f80 100644 --- a/cwt/recipient_algs/ecdh_direct_hkdf.py +++ b/cwt/recipient_algs/ecdh_direct_hkdf.py @@ -95,14 +95,14 @@ def derive_key( info=self._dumps(context), ) key = hkdf.derive(shared_key) - self._unprotected[4] = self._kid if self._kid else public_key.kid - derived = COSEKey.from_symmetric_key( - key, alg=context[0], kid=self._unprotected[4] - ) + kid = self._kid if self._kid else public_key.kid + if kid: + self._unprotected[4] = kid + derived = COSEKey.from_symmetric_key(key, alg=context[0], kid=kid) if self._alg in [-25, -26]: # ECDH-ES self._unprotected[-1] = EC2Key.to_cose_key(priv_key.public_key()) - else: # in [-27, -28] - # ECDH-SS + else: + # ECDH-SS (alg=-27 or -28) self._unprotected[-2] = EC2Key.to_cose_key(priv_key.public_key()) return derived diff --git a/cwt/recipient_interface.py b/cwt/recipient_interface.py index b626c77..7a822e1 100644 --- a/cwt/recipient_interface.py +++ b/cwt/recipient_interface.py @@ -42,8 +42,6 @@ def __init__( if not isinstance(unprotected[4], bytes): raise ValueError("unprotected[4](kid) should be bytes.") params[2] = unprotected[4] - else: - params[2] = b"" # alg if 1 in protected: diff --git a/cwt/recipients.py b/cwt/recipients.py index fb493fd..88bd6d4 100644 --- a/cwt/recipients.py +++ b/cwt/recipients.py @@ -79,10 +79,10 @@ def _extract_key_from_cose_keys( continue if r.alg == -6: # direct return k - elif r.alg in COSE_ALGORITHMS_KEY_WRAP.values(): + if r.alg in COSE_ALGORITHMS_KEY_WRAP.values(): r.set_key(k.key) return r.unwrap_key(alg) - elif r.alg in COSE_ALGORITHMS_CKDM_KEY_AGREEMENT_DIRECT.values(): + if r.alg in COSE_ALGORITHMS_CKDM_KEY_AGREEMENT_DIRECT.values(): if not context: raise ValueError("context should be set.") return r.derive_key(context, private_key=k) diff --git a/tests/test_ecdh_direct_hkdf.py b/tests/test_ecdh_direct_hkdf.py index 1e05dcc..bf656f9 100644 --- a/tests/test_ecdh_direct_hkdf.py +++ b/tests/test_ecdh_direct_hkdf.py @@ -83,16 +83,18 @@ def test_ecdh_direct_hkdf_derive_key_with_raw_context(self): encoded, private_key, context={"alg": "A128GCM"} ) - # def test_ecdh_direct_hkdf_derive_key_without_kid(self): - # rec = Recipient.from_json({"alg": "ECDH-ES+HKDF-256"}) - # with open(key_path("private_key_es256.pem")) as key_file: - # private_key = COSEKey.from_pem(key_file.read()) - # with open(key_path("public_key_es256.pem")) as key_file: - # public_key = COSEKey.from_pem(key_file.read()) - # enc_key = rec.derive_key({"alg": "A128GCM"}, public_key=public_key) - # ctx = COSE.new(alg_auto_inclusion=True) - # encoded = ctx.encode_and_encrypt(b"Hello world!", enc_key, recipients=[rec]) - # assert b"Hello world!" == ctx.decode(encoded, private_key, context={"alg": "A128GCM"}) + def test_ecdh_direct_hkdf_derive_key_without_kid(self): + rec = Recipient.from_json({"alg": "ECDH-ES+HKDF-256"}) + with open(key_path("private_key_es256.pem")) as key_file: + private_key = COSEKey.from_pem(key_file.read()) + with open(key_path("public_key_es256.pem")) as key_file: + public_key = COSEKey.from_pem(key_file.read()) + enc_key = rec.derive_key({"alg": "A128GCM"}, public_key=public_key) + ctx = COSE.new(alg_auto_inclusion=True) + encoded = ctx.encode_and_encrypt(b"Hello world!", enc_key, recipients=[rec]) + assert b"Hello world!" == ctx.decode( + encoded, private_key, context={"alg": "A128GCM"} + ) def test_ecdh_direct_hkdf_derive_key_with_invalid_private_key(self): rec = Recipient.from_json({"alg": "ECDH-ES+HKDF-256"}) diff --git a/tests/test_recipient.py b/tests/test_recipient.py index 008147c..35dc40e 100644 --- a/tests/test_recipient.py +++ b/tests/test_recipient.py @@ -57,7 +57,7 @@ def test_recipient_constructor(self): assert r.unprotected == {} assert r.ciphertext == b"" assert isinstance(r.recipients, list) - assert r.kid == b"" + assert r.kid is None assert r.alg == 0 assert len(r.recipients) == 0 with pytest.raises(NotImplementedError):