diff --git a/CHANGES.rst b/CHANGES.rst index 76755f5..a27ded8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,7 @@ Changes Unreleased ---------- +- Introduce serializer/deserializer for payload. `#67 `__ - Sync official test vectors. `#64 `__ Version 1.1.0 diff --git a/pyseto/pyseto.py b/pyseto/pyseto.py index b104e03..19ecf63 100644 --- a/pyseto/pyseto.py +++ b/pyseto/pyseto.py @@ -1,4 +1,5 @@ -from typing import List, Union +import json +from typing import Any, List, Optional, Union from .key_interface import KeyInterface from .token import Token @@ -7,17 +8,19 @@ def encode( key: KeyInterface, - payload: Union[bytes, str], + payload: Union[bytes, str, dict], footer: Union[bytes, str] = b"", implicit_assertion: Union[bytes, str] = b"", nonce: bytes = b"", + serializer: Any = json, ) -> bytes: + """ Encodes a message to a PASETO token with a key for encryption or signing. Args: key (KeyInterface): A key for encryption or signing. - payload (Union[bytes, str]): A message to be encrypted or signed. + payload (Union[bytes, str, dict]): A message to be encrypted or signed. footer (Union[bytes, str]): A footer. implicit_assertion (Union[bytes, str]): An implicit assertion. It is only used in ``v3`` or ``v4``. @@ -25,6 +28,9 @@ def encode( generated with ``secrets.token_bytes()`` internally. If you don't want ot use ``secrets.token_bytes()``, you can specify it via this parameter explicitly. + serializer (Any): A serializer which is used when the type of + ``payload`` is ``object``. It must have a ``dumps()`` function to + serialize the payload. Typically, you can use ``json`` or ``cbor2``. Returns: bytes: A PASETO token. Raise: @@ -32,9 +38,27 @@ def encode( EncryptError: Failed to encrypt the message. SignError: Failed to sign the message. """ - if not isinstance(payload, (bytes, str)): - raise ValueError("payload should be bytes or str.") - bp = payload if isinstance(payload, bytes) else payload.encode("utf-8") + if not isinstance(payload, (bytes, str, dict)): + raise ValueError("payload should be bytes, str or dict.") + + bp: bytes + if isinstance(payload, dict): + if not serializer: + raise ValueError("serializer should be specified for the payload object.") + try: + if not callable(serializer.dumps): + raise ValueError("serializer should have dumps().") + except AttributeError: + raise ValueError("serializer should have dumps().") + except Exception: + raise + try: + bp = serializer.dumps(payload).encode("utf-8") + except Exception as err: + raise ValueError("Failed to serialize the payload.") from err + else: + bp = payload if isinstance(payload, bytes) else payload.encode("utf-8") + bf = footer if isinstance(footer, bytes) else footer.encode("utf-8") bi = ( implicit_assertion @@ -55,7 +79,9 @@ def decode( keys: Union[KeyInterface, List[KeyInterface]], token: Union[bytes, str], implicit_assertion: Union[bytes, str] = b"", + deserializer: Optional[Any] = None, ) -> Token: + """ Decodes a PASETO token with a key for decryption and/or verifying. @@ -64,6 +90,10 @@ def decode( token (Union[bytes, str]): A PASETO token to be decrypted or verified. implicit_assertion (Union[bytes, str]): An implicit assertion. It is only used in ``v3`` or ``v4``. + deserializer (Optional[Any]): A deserializer which is used when you want to + deserialize a ``payload`` attribute in the response object. It must have a + ``loads()`` function to deserialize the payload. Typically, you can use + ``json`` or ``cbor2``. Returns: Token: A parsed PASETO token object. Raise: @@ -71,6 +101,15 @@ def decode( DecryptError: Failed to decrypt the message. VerifyError: Failed to verify the message. """ + if deserializer: + try: + if not callable(deserializer.loads): + raise ValueError("deserializer should have loads().") + except AttributeError: + raise ValueError("deserializer should have loads().") + except Exception: + raise + keys = keys if isinstance(keys, list) else [keys] bi = ( implicit_assertion @@ -88,6 +127,11 @@ def decode( t.payload = k.decrypt(t.payload, t.footer, bi) return t t.payload = k.verify(t.payload, t.footer, bi) + try: + if deserializer: + t.payload = deserializer.loads(t.payload) + except Exception as err: + raise ValueError("Failed to deserialize the payload.") from err return t except Exception as err: failed = err diff --git a/tests/test_pyseto.py b/tests/test_pyseto.py index d063058..5186eaa 100644 --- a/tests/test_pyseto.py +++ b/tests/test_pyseto.py @@ -6,6 +6,26 @@ from .utils import load_key +class InvalidSerializer: + def __init__(self): + self.dumps = "not a function." + + +class InvalidSerializer2: + def dumps(self, *args): + raise NotImplementedError("Not implemented") + + +class InvalidDeserializer: + def __init__(self): + self.loads = "not a function." + + +class InvalidDeserializer2: + def loads(self, *args): + raise NotImplementedError("Not implemented") + + class TestPyseto: """ Tests for pyseto.encode and decode. @@ -43,6 +63,59 @@ def test_encode_with_public_key(self, version, key, msg): pytest.fail("pyseto.encode() should fail.") assert msg in str(err.value) + @pytest.mark.parametrize( + "serializer, msg", + [ + ( + None, + "serializer should be specified for the payload object.", + ), + ( + {}, + "serializer should be specified for the payload object.", + ), + ( + [], + "serializer should be specified for the payload object.", + ), + ( + "", + "serializer should be specified for the payload object.", + ), + ( + b"", + "serializer should be specified for the payload object.", + ), + ( + {"key": "value"}, + "serializer should have dumps().", + ), + ( + InvalidSerializer(), + "serializer should have dumps().", + ), + ( + InvalidSerializer2(), + "Failed to serialize the payload.", + ), + ], + ) + def test_encode_object_payload_with_invalid_serializer(self, serializer, msg): + private_key_pem = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEILTL+0PfTOIQcn2VPkpxMwf6Gbt9n4UEFDjZ4RuUKjd0\n-----END PRIVATE KEY-----" + + private_key = Key.new(version=4, purpose="public", key=private_key_pem) + with pytest.raises(ValueError) as err: + pyseto.encode( + private_key, + { + "data": "this is a signed message", + "exp": "2022-01-01T00:00:00+00:00", + }, + serializer=serializer, + ) + pytest.fail("pyseto.encode() should fail.") + assert msg in str(err.value) + @pytest.mark.parametrize( "version, key, msg", [ @@ -76,6 +149,37 @@ def test_decode_with_another_version_key(self, version, public_key): pytest.fail("pyseto.decode() should fail.") assert "key is not found for verifying the token." in str(err.value) + @pytest.mark.parametrize( + "deserializer, msg", + [ + ( + {"key": "value"}, + "deserializer should have loads().", + ), + ( + InvalidDeserializer(), + "deserializer should have loads().", + ), + ( + InvalidDeserializer2(), + "Failed to deserialize the payload.", + ), + ], + ) + def test_decode_object_payload_with_invalid_deserializer(self, deserializer, msg): + private_key_pem = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEILTL+0PfTOIQcn2VPkpxMwf6Gbt9n4UEFDjZ4RuUKjd0\n-----END PRIVATE KEY-----" + public_key_pem = b"-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAHrnbu7wEfAP9cGBOAHHwmH4Wsot1ciXBHwBBXQ4gsaI=\n-----END PUBLIC KEY-----" + private_key = Key.new(version=4, purpose="public", key=private_key_pem) + token = pyseto.encode( + private_key, + {"data": "this is a signed message", "exp": "2022-01-01T00:00:00+00:00"}, + ) + public_key = Key.new(version=4, purpose="public", key=public_key_pem) + with pytest.raises(ValueError) as err: + pyseto.decode(public_key, token, deserializer=deserializer) + pytest.fail("pyseto.decode() should fail.") + assert msg in str(err.value) + def test_decode_with_empty_list_of_keys(self): sk = Key.new(4, "public", load_key("keys/private_key_ed25519.pem")) token = pyseto.encode(sk, "Hello world!") diff --git a/tests/test_sample.py b/tests/test_sample.py index 6658664..7ab70f6 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -1,3 +1,4 @@ +import json from secrets import token_bytes import pyseto @@ -83,6 +84,26 @@ def test_sample_v4_public(self): == b'{"data": "this is a signed message", "exp": "2022-01-01T00:00:00+00:00"}' ) + def test_sample_v4_public_with_serializer(self): + + private_key_pem = b"-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEILTL+0PfTOIQcn2VPkpxMwf6Gbt9n4UEFDjZ4RuUKjd0\n-----END PRIVATE KEY-----" + public_key_pem = b"-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAHrnbu7wEfAP9cGBOAHHwmH4Wsot1ciXBHwBBXQ4gsaI=\n-----END PUBLIC KEY-----" + + private_key = Key.new(version=4, purpose="public", key=private_key_pem) + token = pyseto.encode( + private_key, + {"data": "this is a signed message", "exp": "2022-01-01T00:00:00+00:00"}, + ) + public_key = Key.new(version=4, purpose="public", key=public_key_pem) + decoded = pyseto.decode(public_key, token, deserializer=json) + + assert ( + token + == b"v4.public.eyJkYXRhIjogInRoaXMgaXMgYSBzaWduZWQgbWVzc2FnZSIsICJleHAiOiAiMjAyMi0wMS0wMVQwMDowMDowMCswMDowMCJ9l1YiKei2FESvHBSGPkn70eFO1hv3tXH0jph1IfZyEfgm3t1DjkYqD5r4aHWZm1eZs_3_bZ9pBQlZGp0DPSdzDg" + ) + assert decoded.payload["data"] == "this is a signed message" + assert decoded.payload["exp"] == "2022-01-01T00:00:00+00:00" + def test_sample_paserk(self): symmetric_key = Key.new(version=4, purpose="local", key=b"our-secret") diff --git a/tests/test_with_test_vectors.py b/tests/test_with_test_vectors.py index d4a2a35..a1d71d7 100644 --- a/tests/test_with_test_vectors.py +++ b/tests/test_with_test_vectors.py @@ -93,7 +93,7 @@ def test_with_test_vectors(self, v): with pytest.raises(ValueError) as err: pyseto.encode(k, payload, footer, implicit_assertion, nonce=nonce) pytest.fail("encode should fail.") - assert "payload should be bytes or str." in str(err.value) + assert "payload should be bytes, str or dict." in str(err.value) return secret_key_pem = v["secret-key"] if version == 1 else v["secret-key-pem"] @@ -103,7 +103,7 @@ def test_with_test_vectors(self, v): with pytest.raises(ValueError) as err: pyseto.encode(sk, payload, footer, implicit_assertion) pytest.fail("encode should fail.") - assert "payload should be bytes or str." in str(err.value) + assert "payload should be bytes, str or dict." in str(err.value) return payload = payload.encode("utf-8")