From 57ea966731f7091c3a0971e30a0b52ed2fd5ccf2 Mon Sep 17 00:00:00 2001 From: "Ajitomi, Daisuke" Date: Sun, 30 May 2021 09:37:34 +0900 Subject: [PATCH] Merge RecipientsBuilder into Recipients. --- CHANGES.rst | 6 ++++- cwt/cose.py | 7 +++--- cwt/recipients.py | 40 +++++++++++++++++++++++++++++++--- cwt/recipients_builder.py | 46 --------------------------------------- tests/test_recipient.py | 36 +++++++++++------------------- 5 files changed, 58 insertions(+), 77 deletions(-) delete mode 100644 cwt/recipients_builder.py diff --git a/CHANGES.rst b/CHANGES.rst index a6aabdd..1161e96 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,9 +4,13 @@ Changes Unreleased ---------- -- Rename RecipientBuilder to COSERecipient. `#99 `__ +- Merge RecipientsBuilder into Recipients. `#103 `__ +- Rename Key to COSEKeyInterface. `#102 `__ +- Rename RecipientBuilder to Recipient. `#101 `__ +- Make Key private. `#100 `__ - Merge ClaimsBuilder into Claims. `#98 `__ - Rename KeyBuilder to COSEKey. `#97 `__ +- Rename COSEKey to Key. `#97 `__ - Add support for external AAD. `#94 `__ - Make unwrap_key return COSEKey. `#93 `__ - Fix default HMAC key size. `#91 `__ diff --git a/cwt/cose.py b/cwt/cose.py index 5d2164f..e72d8c2 100644 --- a/cwt/cose.py +++ b/cwt/cose.py @@ -6,7 +6,7 @@ from .const import COSE_ALGORITHMS_RECIPIENT from .cose_key_interface import COSEKeyInterface from .recipient_interface import RecipientInterface -from .recipients_builder import RecipientsBuilder +from .recipients import Recipients class COSE(CBORProcessor): @@ -26,7 +26,6 @@ def __init__(self, options: Optional[Dict[str, Any]] = None): of COSE. At this time, ``kid_auto_inclusion`` (default value: ``True``) and ``alg_auto_inclusion`` (default value: ``True``) are supported. """ - self._recipients_builder = RecipientsBuilder() self._kid_auto_inclusion = True self._alg_auto_inclusion = True if not options: @@ -342,7 +341,7 @@ def decode( if not isinstance(unprotected, dict): raise ValueError("unprotected header should be dict.") nonce = unprotected.get(5, None) - recipients = self._recipients_builder.from_list(data.value[3]) + recipients = Recipients.from_list(data.value[3]) enc_key = ( recipients.derive_key(keys=keys, alg_hint=alg_hint) if key is not None @@ -379,7 +378,7 @@ def decode( if isinstance(protected, dict) and 1 in protected else 0 ) - recipients = self._recipients_builder.from_list(data.value[4]) + recipients = Recipients.from_list(data.value[4]) mac_auth_key = recipients.derive_key(keys=keys, alg_hint=alg_hint) mac_auth_key.verify(to_be_maced, data.value[3]) return data.value[2] diff --git a/cwt/recipients.py b/cwt/recipients.py index 0882c6f..cd815af 100644 --- a/cwt/recipients.py +++ b/cwt/recipients.py @@ -1,13 +1,15 @@ -from typing import List, Optional +from typing import Any, List, Optional + +import cbor2 -from .cbor_processor import CBORProcessor from .const import COSE_ALGORITHMS_KEY_WRAP from .cose_key_interface import COSEKeyInterface +from .recipient import Recipient from .recipient_interface import RecipientInterface from .utils import base64url_decode, to_cis -class Recipients(CBORProcessor): +class Recipients: """ A Set of COSE Recipients. """ @@ -16,6 +18,38 @@ def __init__(self, recipients: List[RecipientInterface]): self._recipients = recipients return + @classmethod + def from_list(cls, recipients: List[Any]): + """ + Create Recipients from a CBOR-like list. + """ + res: List[RecipientInterface] = [] + for r in recipients: + res.append(cls._create_recipient(r)) + return cls(res) + + @classmethod + def _create_recipient(cls, recipient: List[Any]) -> RecipientInterface: + if not isinstance(recipient, list) or ( + len(recipient) != 3 and len(recipient) != 4 + ): + raise ValueError("Invalid recipient format.") + if not isinstance(recipient[0], bytes): + raise ValueError("protected header should be bytes.") + protected = {} if not recipient[0] else cbor2.loads(recipient[0]) + if not isinstance(recipient[1], dict): + raise ValueError("unprotected header should be dict.") + if not isinstance(recipient[2], bytes): + raise ValueError("ciphertext should be bytes.") + if len(recipient) == 3: + return Recipient.from_dict(protected, recipient[1], recipient[2]) + if not isinstance(recipient[3], list): + raise ValueError("recipients should be list.") + recipients: List[RecipientInterface] = [] + for r in recipient[3]: + recipients.append(cls._create_recipient(r)) + return Recipient.from_dict(protected, recipient[1], recipient[2], recipients) + def derive_key( self, keys: Optional[List[COSEKeyInterface]] = None, diff --git a/cwt/recipients_builder.py b/cwt/recipients_builder.py deleted file mode 100644 index 64354c2..0000000 --- a/cwt/recipients_builder.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any, Dict, List, Optional - -from .cbor_processor import CBORProcessor -from .recipient import Recipient -from .recipient_interface import RecipientInterface -from .recipients import Recipients - - -class RecipientsBuilder(CBORProcessor): - """ - A Recipients Builder. - """ - - def __init__(self, options: Optional[Dict[str, Any]] = None): - self._options = options - return - - def from_list(self, recipients: List[Any]) -> Recipients: - """ - Create Recipient from a CBOR-like list. - """ - res: List[RecipientInterface] = [] - for r in recipients: - res.append(self._create_recipient(r)) - return Recipients(res) - - def _create_recipient(self, recipient: List[Any]) -> RecipientInterface: - if not isinstance(recipient, list) or ( - len(recipient) != 3 and len(recipient) != 4 - ): - raise ValueError("Invalid recipient format.") - if not isinstance(recipient[0], bytes): - raise ValueError("protected header should be bytes.") - protected = {} if not recipient[0] else self._loads(recipient[0]) - if not isinstance(recipient[1], dict): - raise ValueError("unprotected header should be dict.") - if not isinstance(recipient[2], bytes): - raise ValueError("ciphertext should be bytes.") - if len(recipient) == 3: - return Recipient.from_dict(protected, recipient[1], recipient[2]) - if not isinstance(recipient[3], list): - raise ValueError("recipients should be list.") - recipients: List[RecipientInterface] = [] - for r in recipient[3]: - recipients.append(self._create_recipient(r)) - return Recipient.from_dict(protected, recipient[1], recipient[2], recipients) diff --git a/tests/test_recipient.py b/tests/test_recipient.py index a0ca2d3..3284398 100644 --- a/tests/test_recipient.py +++ b/tests/test_recipient.py @@ -12,7 +12,6 @@ from cwt import COSEKey, Recipient from cwt.recipient_interface import RecipientInterface from cwt.recipients import Recipients -from cwt.recipients_builder import RecipientsBuilder @pytest.fixture(scope="session", autouse=True) @@ -404,34 +403,25 @@ def test_recipients_derive_key_with_different_kid(self): pytest.fail("derive_key() should fail.") assert "Failed to derive a key." in str(err.value) - -class TestRecipientsBuilder: - """ - Tests for RecipientsBuilder. - """ - - def test_recipients_builder_constructor(self): - rb = RecipientsBuilder() - assert isinstance(rb, RecipientsBuilder) - - def test_recipients_builder_from_list(self): - rb = RecipientsBuilder() + def test_recipients_from_list(self): try: - rb.from_list([[cbor2.dumps({1: -10}), {-20: b"aabbccddeefff"}, b""]]) + Recipients.from_list( + [[cbor2.dumps({1: -10}), {-20: b"aabbccddeefff"}, b""]] + ) except Exception: pytest.fail("from_list() should not fail.") - def test_recipients_builder_from_list_with_empty_recipients(self): - rb = RecipientsBuilder() + def test_recipients_from_list_with_empty_recipients(self): try: - rb.from_list([[cbor2.dumps({1: -10}), {-20: b"aabbccddeefff"}, b"", []]]) + Recipients.from_list( + [[cbor2.dumps({1: -10}), {-20: b"aabbccddeefff"}, b"", []]] + ) except Exception: pytest.fail("from_list() should not fail.") - def test_recipients_builder_from_list_with_recipients(self): - rb = RecipientsBuilder() + def test_recipients_from_list_with_recipients(self): try: - rb.from_list( + Recipients.from_list( [ [ cbor2.dumps({1: -10}), @@ -455,6 +445,7 @@ def test_recipients_builder_from_list_with_recipients(self): ([["", {}, b""]], "protected header should be bytes."), ([[{}, {}, b""]], "protected header should be bytes."), ([[[], {}, b""]], "protected header should be bytes."), + ([[[], {}, b""]], "protected header should be bytes."), ([[123, {}, b""]], "protected header should be bytes."), ([[b"", [], b""]], "unprotected header should be dict."), ([[b"", "", b""]], "unprotected header should be dict."), @@ -470,9 +461,8 @@ def test_recipients_builder_from_list_with_recipients(self): ([[b"", {}, b"", 123]], "recipients should be list."), ], ) - def test_recipients_builder_from_list_with_invalid_args(self, invalid, msg): - rb = RecipientsBuilder() + def test_recipients_from_list_with_invalid_args(self, invalid, msg): with pytest.raises(ValueError) as err: - rb.from_list(invalid) + Recipients.from_list(invalid) pytest.fail("derive_key() should fail.") assert msg in str(err.value)