Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: convert str to iterable (#22) #23

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions supadantic/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from abc import ABC, abstractmethod
from typing import Any
from copy import copy
from typing import Any, Dict

from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from pydantic._internal._model_construction import ModelMetaclass as PydanticModelMetaclass
from typing_extensions import Self

Expand Down Expand Up @@ -50,3 +52,27 @@ def delete(self: Self) -> None:
if self.id:
db_client = self._get_db_client()
db_client.delete(id=self.id)

@model_validator(mode='before')
def _validate_data_from_supabase(cls, data: Dict[str, Any]) -> Dict[str, Any]:
array_fields = []
result_dict = copy(data)

for key, value in cls.model_json_schema()['properties'].items():
_field_is_array = any(
(
# If field is required, it's possible to get type
value.get('type', None) == 'array',
# If field is optional, it's possible to get type from anyOf array
any(item.get('type', None) == 'array' for item in value.get('anyOf', [])),
)
)

if _field_is_array:
array_fields.append(key)

for key, value in data.items():
if key in array_fields and isinstance(value, str):
result_dict[key] = json.loads(data[key])

return result_dict
4 changes: 4 additions & 0 deletions tests/test_classes/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import List, Tuple

from supadantic.models import BaseSBModel

from .supabase_client import SupabaseClientMock


class ModelMock(BaseSBModel):
name: str
some_optional_list: List[str] | None = None
some_optional_tuple: Tuple[int] | None = None

@classmethod
def _get_table_name(cls) -> str:
Expand Down
14 changes: 12 additions & 2 deletions tests/test_classes/supabase_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import copy
from typing import Any, Dict, Iterable, List

from supadantic.clients.base import BaseClient
Expand All @@ -7,13 +8,22 @@ class SupabaseClientMock(BaseClient):
def __init__(self, table_name: str):
pass

def _get_return_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
result_data = copy(data)

for key, value in data.items():
if type(value) in (list, tuple):
result_data.update({key: str(value)})

return result_data

def insert(self, data: Dict[str, Any]) -> Dict[str, Any]:
data = dict(id=1, **data)
return data
return self._get_return_data(data=data)

def update(self, *, id: int, data: Dict[str, Any]) -> Dict[str, Any]:
data['id'] = id
return data
return self._get_return_data(data=data)

def select(self, *, eq: Dict[str, Any] | None = None, neq: Dict[str, Any] | None = None) -> List[Dict[str, Any]]:
data = [
Expand Down
6 changes: 3 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class TestBaseSBModel:
class TestSave:
def test_create(self):
# Prepare data
test_entity = ModelMock(name='test_name')
test_entity = ModelMock(name='test_name', some_array=['foo', 'bar'])

# Execution
updated_entity = test_entity.save()
Expand All @@ -18,7 +18,7 @@ def test_create(self):

def test_update(self):
# Prepare data
test_entity = ModelMock(id=2, name='test_name')
test_entity = ModelMock(id=2, name='test_name', some_array=['foo', 'bar'])

# Execution
updated_entity = test_entity.save()
Expand All @@ -28,4 +28,4 @@ def test_update(self):
assert updated_entity.name == 'test_name'

def test_objects(self):
assert ModelMock.objects == QSet(model_class=ModelMock)
assert ModelMock.objects == QSet(model_class=ModelMock) # pyright: ignore
Loading