Skip to content

Commit

Permalink
Change: Shrink the signature of all Model classes
Browse files Browse the repository at this point in the history
Remove the static model creation methods  `_get_value` and
`_get_value_from_model_field_cls` from the Model class and refactor them
to functions.

This methods are just used during model creation and afterwards should
not be available on the Model instance.
  • Loading branch information
bjoernricks committed Nov 17, 2023
1 parent 11e8cdd commit b1d96b0
Showing 1 changed file with 53 additions and 61 deletions.
114 changes: 53 additions & 61 deletions pontos/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,71 +70,63 @@ class ModelAttribute:
"""


def _get_value_from_model_field_cls(
model_field_cls: Type[Any], value: Any
) -> Any:
if isclass(model_field_cls) and issubclass(model_field_cls, Model):
value = model_field_cls.from_dict(value)
elif isclass(model_field_cls) and issubclass(model_field_cls, datetime):
# Only Python 3.11 supports sufficient formats in
# datetime.fromisoformat. Therefore we have to use dateutil here.
value = dateparser.isoparse(value)
# the iso format may not contain UTC data or a UTC offset
# this means it is considered local time (Python calls this "naive"
# datetime) and can't really be compared to other times. maybe we
# should always assume UTC for these formats.
# This could be done the following:
# if not value.tzinfo:
# value = value.replace(tzinfo=timezone.utc)
elif isclass(model_field_cls) and issubclass(model_field_cls, date):
value = date.fromisoformat(value)
elif get_origin(model_field_cls) == list:
model_field_cls = get_args(model_field_cls)[0]
value = _get_value_from_model_field_cls(model_field_cls, value)
elif get_origin(model_field_cls) == dict:
model_field_cls = dict
value = _get_value_from_model_field_cls(model_field_cls, value)
elif get_origin(model_field_cls) == Union:
possible_types = get_args(model_field_cls)
current_type = type(value)
if current_type in possible_types:
model_field_cls = current_type
else:
# currently Unions should not contain Models. this would require
# to iterate over the possible type, check if it is a Model
# class and try to create an instance of this class until it
# fits. For now just fallback to first type
model_field_cls = possible_types[0]

value = _get_value_from_model_field_cls(model_field_cls, value)
else:
if isinstance(value, dict):
value = model_field_cls(**value)
else:
value = model_field_cls(value)
return value


def _get_value(model_field_cls: Type[Any], value: Any) -> Any:
if model_field_cls:
value = _get_value_from_model_field_cls(model_field_cls, value)
return value


@dataclass(init=False)
class Model:
"""
Base class for models
"""

@staticmethod
def _get_value_from_model_field_cls(
model_field_cls: Type[Any], value: Any
) -> Any:
if isclass(model_field_cls) and issubclass(model_field_cls, Model):
value = model_field_cls.from_dict(value)
elif isclass(model_field_cls) and issubclass(model_field_cls, datetime):
# Only Python 3.11 supports sufficient formats in
# datetime.fromisoformat. Therefore we have to use dateutil here.
value = dateparser.isoparse(value)
# the iso format may not contain UTC data or a UTC offset
# this means it is considered local time (Python calls this "naive"
# datetime) and can't really be compared to other times. maybe we
# should always assume UTC for these formats.
# This could be done the following:
# if not value.tzinfo:
# value = value.replace(tzinfo=timezone.utc)
elif isclass(model_field_cls) and issubclass(model_field_cls, date):
value = date.fromisoformat(value)
elif get_origin(model_field_cls) == list:
model_field_cls = get_args(model_field_cls)[0]
value = Model._get_value_from_model_field_cls(
model_field_cls, value
)
elif get_origin(model_field_cls) == dict:
model_field_cls = dict
value = Model._get_value_from_model_field_cls(
model_field_cls, value
)
elif get_origin(model_field_cls) == Union:
possible_types = get_args(model_field_cls)
current_type = type(value)
if current_type in possible_types:
model_field_cls = current_type
else:
# currently Unions should not contain Models. this would require
# to iterate over the possible type, check if it is a Model
# class and try to create an instance of this class until it
# fits. For now just fallback to first type
model_field_cls = possible_types[0]

value = Model._get_value_from_model_field_cls(
model_field_cls, value
)
else:
if isinstance(value, dict):
value = model_field_cls(**value)
else:
value = model_field_cls(value)
return value

@staticmethod
def _get_value(model_field_cls: Type[Any], value: Any) -> Any:
if model_field_cls:
value = Model._get_value_from_model_field_cls(
model_field_cls, value
)
return value

@classmethod
def from_dict(cls, data: Dict[str, Any]):
"""
Expand Down Expand Up @@ -163,10 +155,10 @@ def from_dict(cls, data: Dict[str, Any]):
try:
if isinstance(value, list):
model_field_cls = type_hints.get(name)
value = [cls._get_value(model_field_cls, v) for v in value] # type: ignore # pylint: disable=line-too-long # noqa: E501,PLW2901
value = [_get_value(model_field_cls, v) for v in value] # type: ignore # pylint: disable=line-too-long # noqa: E501,PLW2901
elif value is not None:
model_field_cls = type_hints.get(name)
value = cls._get_value(model_field_cls, value) # type: ignore # pylint: disable=line-too-long # noqa: E501,PLW2901
value = _get_value(model_field_cls, value) # type: ignore # pylint: disable=line-too-long # noqa: E501,PLW2901
except (ValueError, TypeError) as e:
raise ModelError(
f"Error while creating {cls.__name__} model. Could not set "
Expand Down

0 comments on commit b1d96b0

Please sign in to comment.