Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle recursively serializing a dataclasses as a dictionary. #547

Merged
merged 19 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Change log
- [#547](https://github.com/mobilityhouse/ocpp/pull/547) Feat: Handle recursively serializing a dataclasses as a dictionary
- [#557](https://github.com/mobilityhouse/ocpp/issues/557) OCPP 2.0.1 Wrong data type in CostUpdated total_cost
- [#564](https://github.com/mobilityhouse/ocpp/issues/564) Add support For Python 3.11 and 3.12
- [#583](https://github.com/mobilityhouse/ocpp/issues/583) OCPP v1.6/v2.0.1 deprecate dataclasses from calls and call results with the suffix 'Payload'
Expand Down
73 changes: 69 additions & 4 deletions ocpp/charge_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import re
import time
import uuid
from dataclasses import asdict
from typing import Dict, List, Union
from dataclasses import Field, asdict, is_dataclass
from typing import Any, Dict, List, Union, get_args, get_origin

from ocpp.exceptions import NotImplementedError, NotSupportedError, OCPPError
from ocpp.messages import Call, MessageType, unpack, validate_payload
Expand Down Expand Up @@ -71,6 +71,71 @@ def snake_to_camel_case(data):
return data


def _is_dataclass_instance(input: Any) -> bool:
"""Verify if given `input` is a dataclass."""
return is_dataclass(input) and not isinstance(input, type)


def _is_optional_field(field: Field) -> bool:
"""Verify if given `field` allows `None` as value.

The fields `schema` and `host` on the following class would return `False`.
While the fields `post` and `query` return `True`.

@dataclass
class URL:
schema: str,
host: str,
post: Optional[str],
query: Union[None, str]

"""
return get_origin(field.type) is Union and type(None) in get_args(field.type)


def serialize_as_dict(dataclass):
"""Serialize the given `dataclass` as a `dict` recursively.

@dataclass
class StatusInfoType:
reason_code: str
additional_info: Optional[str] = None

with_additional_info = StatusInfoType(
reason="Unknown",
additional_info="More details"
)

assert serialize_as_dict(with_additional_info) == {
'reason': 'Unknown',
'additional_info': 'More details',
}

without_additional_info = StatusInfoType(reason="Unknown")

assert serialize_as_dict(with_additional_info) == {
'reason': 'Unknown',
'additional_info': None,
}

"""
serialized = asdict(dataclass)

for field in dataclass.__dataclass_fields__.values():

value = getattr(dataclass, field.name)
if _is_dataclass_instance(value):
serialized[field.name] = serialize_as_dict(value)
continue

if isinstance(value, list):
for item in value:
if _is_dataclass_instance(item):
serialized[field.name] = [serialize_as_dict(item)]

return serialized


def remove_nones(data: Union[List, Dict]) -> Union[List, Dict]:
if isinstance(data, dict):
return {k: remove_nones(v) for k, v in data.items() if v is not None}
Expand Down Expand Up @@ -244,7 +309,7 @@ async def _handle_call(self, msg):

return

temp_response_payload = asdict(response)
temp_response_payload = serialize_as_dict(response)

# Remove nones ensures that we strip out optional arguments
# which were not set and have a default value of None
Expand Down Expand Up @@ -306,7 +371,7 @@ async def call(self, payload, suppress=True, unique_id=None):
CallError.

"""
camel_case_payload = snake_to_camel_case(asdict(payload))
camel_case_payload = snake_to_camel_case(serialize_as_dict(payload))

unique_id = (
unique_id if unique_id is not None else str(self._unique_id_generator())
Expand Down
66 changes: 62 additions & 4 deletions tests/test_charge_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import pytest

from ocpp.charge_point import camel_to_snake_case, remove_nones, snake_to_camel_case
from ocpp.charge_point import (
camel_to_snake_case,
remove_nones,
serialize_as_dict,
snake_to_camel_case,
)
from ocpp.messages import Call
from ocpp.routing import after, create_route_map, on
from ocpp.v16 import ChargePoint as cp_16
Expand All @@ -11,8 +16,15 @@
from ocpp.v16.datatypes import MeterValue, SampledValue
from ocpp.v16.enums import Action, RegistrationStatus
from ocpp.v201 import ChargePoint as cp_201
from ocpp.v201.call import SetNetworkProfile
from ocpp.v201.datatypes import NetworkConnectionProfileType
from ocpp.v201.call import GetVariables as v201GetVariables
from ocpp.v201.call import SetNetworkProfile as v201SetNetworkProfile
from ocpp.v201.datatypes import (
ComponentType,
EVSEType,
GetVariableDataType,
NetworkConnectionProfileType,
VariableType,
)
from ocpp.v201.enums import OCPPInterfaceType, OCPPTransportType, OCPPVersionType


Expand Down Expand Up @@ -112,7 +124,9 @@ def test_nested_remove_nones():
apn=None,
)

payload = SetNetworkProfile(configuration_slot=1, connection_data=connection_data)
payload = v201SetNetworkProfile(
configuration_slot=1, connection_data=connection_data
)
payload = asdict(payload)

assert expected_payload == remove_nones(payload)
Expand Down Expand Up @@ -233,6 +247,50 @@ def test_remove_nones_with_list_of_strings():
}


def test_serialize_as_dict():
"""
Test recursively serializing a dataclasses as a dictionary.
"""
# Setup
expected = camel_to_snake_case(
{
"getVariableData": [
{
"component": {
"name": "Component",
"instance": None,
"evse": {
"id": 1,
"connectorId": None,
},
},
"variable": {
"name": "Variable",
"instance": None,
},
"attributeType": None,
}
],
"customData": None,
}
)

payload = v201GetVariables(
get_variable_data=[
GetVariableDataType(
component=ComponentType(
name="Component",
evse=EVSEType(id=1),
),
variable=VariableType(name="Variable"),
)
]
)

# Execute / Assert
assert serialize_as_dict(payload) == expected


@pytest.mark.asyncio
async def test_call_unique_id_added_to_handler_args_correctly(connection):
"""
Expand Down