Skip to content

Commit

Permalink
feat: Handle nested models more gracefully; add per-field data prep
Browse files Browse the repository at this point in the history
  • Loading branch information
eykd committed Feb 21, 2024
1 parent f258be4 commit 50d6666
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 10 deletions.
10 changes: 8 additions & 2 deletions src/steerage/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,13 @@ def transform_data_to_entity(self, data: Mapping) -> TEntity:

def prepare_data_for_entity(self, data: Mapping) -> Mapping:
"""Template method: transform stored data into entity-ready data."""
return data
out = {}
for key, value in data.items():
prepare = getattr(self, f'prepare_{key}', None)
if prepare is not None:
value = prepare(value, data)
out[key] = value
return out


@dataclass
Expand Down Expand Up @@ -415,7 +421,7 @@ async def update_attrs(self, id: UUIDorStr, **kwargs) -> None:
inserted will raise `NotFound`.
"""
entity = await self.get(id)
entity = entity.model_copy(update=kwargs)
entity = entity.model_copy(update=kwargs, deep=True)
await self.update(entity)


Expand Down
2 changes: 1 addition & 1 deletion src/steerage/repositories/memdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def run_update_query(self, **kwargs) -> int:
"""Run this as an update query against the backend."""
count = 0
async for entity in self:
new_entity = entity.model_copy(update=kwargs)
new_entity = entity.model_copy(update=self.prepare_data_for_entity(kwargs))
self._upsert(self.transform_entity_to_data(new_entity))
count += 1
return count
Expand Down
2 changes: 1 addition & 1 deletion src/steerage/repositories/shelvedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def run_update_query(self, **kwargs) -> int:
"""Run this as an update query against the backend."""
count = 0
async for entity in self:
new_entity = entity.model_copy(update=kwargs)
new_entity = entity.model_copy(update=self.prepare_data_for_entity(kwargs))
key = self._get_key(new_entity.id)
self._upsert(key, new_entity.model_dump())
count += 1
Expand Down
59 changes: 53 additions & 6 deletions tests/test_repository_implementations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ruff: noqa: D100, D101, D102, D103
from collections.abc import Mapping
from datetime import datetime, timedelta
from typing import Any
from uuid import UUID, uuid5

import factory
Expand All @@ -12,7 +14,7 @@
from pydantic import BaseModel, ConfigDict, Field
from pydantic.types import AwareDatetime

from steerage.repositories.base import AbstractEntityRepository
from steerage.repositories.base import AbstractEntityRepository, AbstractBaseQuery
from steerage.repositories.memdb import (
AbstractInMemoryQuery,
AbstractInMemoryRepository,
Expand All @@ -35,17 +37,30 @@
NAMESPACE = UUID("dbe2dff9-122e-4718-924f-710073c33b53")


class SubEntity(BaseModel):
bar: str

model_config = ConfigDict(frozen=True)


class Entity(BaseModel):
id: UUID
foo: str
num: int
is_odd: bool
oddish: bool | None
sub: SubEntity | None
created_at: AwareDatetime = Field(default_factory=utcnow)
finished_at: AwareDatetime | None = None

model_config = ConfigDict(frozen=True)

class SubEntityFactory(factory.Factory):
class Meta:
model = SubEntity

bar = 'blah'


class EntityFactory(factory.Factory):
class Meta:
Expand All @@ -56,10 +71,18 @@ class Meta:
is_odd = factory.Sequence(lambda n: bool(n % 2))
foo = factory.LazyAttribute(lambda obj: f"baz{obj.num}" if obj.is_odd else f"bar{obj.num}")
oddish = factory.LazyAttribute(lambda obj: True if obj.is_odd else None)
sub = factory.SubFactory(SubEntityFactory)
created_at = factory.Sequence(lambda n: datetime(2023, 12, 15, 12, 0, tzinfo=pytz.UTC) - timedelta(hours=n))


class InMemoryEntityQuery(AbstractInMemoryQuery):
class AbstractEntityQuery(AbstractBaseQuery):
def prepare_sub(self, value: Mapping, data: Mapping) -> SubEntity:
if value is not None:
value = SubEntity.model_construct(**value)
return value


class InMemoryEntityQuery(AbstractEntityQuery, AbstractInMemoryQuery):
table_name: str = "entities"
entity_class = Entity

Expand All @@ -70,7 +93,7 @@ class InMemoryEntityRepository(AbstractInMemoryRepository):
query_class = InMemoryEntityQuery


class ShelveEntityQuery(AbstractShelveQuery):
class ShelveEntityQuery(AbstractEntityQuery, AbstractShelveQuery):
table_name: str = "entities"
entity_class = Entity

Expand All @@ -90,15 +113,28 @@ class ShelveEntityRepository(AbstractShelveRepository):
sa.Column("num", sa.Integer),
sa.Column("is_odd", sa.Boolean),
sa.Column("oddish", sa.Boolean, nullable=True),
sa.Column("sub_bar", sa.String, nullable=True),
sa.Column("created_at", AwareDateTime, nullable=True),
sa.Column("finished_at", AwareDateTime, nullable=True),
)


class SQLEntityQuery(AbstractSQLQuery):
class SQLEntityQuery(AbstractEntityQuery, AbstractSQLQuery):
table = ENTITY_TABLE
entity_class = Entity

def transform_entity_to_data(self, entity: Entity) -> dict[str, Any]:
data = entity.model_dump()
sub_bar = data.pop("sub")
if sub_bar is not None:
data["sub_bar"] = sub_bar["bar"]
return data

def prepare_data_for_entity(self, data: Mapping) -> Mapping:
sub_bar = data.pop("sub_bar")
data['sub'] = {'bar': sub_bar}
return super().prepare_data_for_entity(data)


class SQLEntityRepository(AbstractSQLRepository):
entity_class = Entity
Expand Down Expand Up @@ -191,6 +227,17 @@ async def test_it_should_update_an_entity(self, repo: AbstractEntityRepository,

assert result.foo == "blah"

async def test_it_should_update_a_subentity(self, repo: AbstractEntityRepository, stored_entity: Entity):
entity = stored_entity.model_copy(update={"sub": SubEntity(bar="banana")})
async with repo:
await repo.update(entity)
await repo.commit()

async with repo:
result = await repo.get(stored_entity.id)

assert result.sub.bar == "banana"

async def test_it_should_fail_to_update_a_nonexistent_entity(
self, repo: AbstractEntityRepository, entity: Entity
):
Expand Down Expand Up @@ -539,13 +586,13 @@ async def test_it_should_repr_a_short_cached_query(self, repo, stored_entities):
query = repo.objects.filter(num__lt=2).order_by("num")
await alist(query)
result = repr(query)
expected = "<InMemoryEntityQuery [Entity(id=UUID('de509355-5376-5405-a36d-91caed2ba8d1'), foo='bar0', num=0, is_odd=False, oddish=None, created_at=datetime.datetime(2023, 12, 15, 12, 0, tzinfo=<UTC>), finished_at=None), Entity(id=UUID('8db9b404-f276-5674-8006-12b74a8c62e3'), foo='baz1', num=1, is_odd=True, oddish=True, created_at=datetime.datetime(2023, 12, 15, 11, 0, tzinfo=<UTC>), finished_at=None)]>"
expected = "<InMemoryEntityQuery [Entity(id=UUID('de509355-5376-5405-a36d-91caed2ba8d1'), foo='bar0', num=0, is_odd=False, oddish=None, sub=SubEntity(bar='blah'), created_at=datetime.datetime(2023, 12, 15, 12, 0, tzinfo=<UTC>), finished_at=None), Entity(id=UUID('8db9b404-f276-5674-8006-12b74a8c62e3'), foo='baz1', num=1, is_odd=True, oddish=True, sub=SubEntity(bar='blah'), created_at=datetime.datetime(2023, 12, 15, 11, 0, tzinfo=<UTC>), finished_at=None)]>"
assert result == expected

async def test_it_should_repr_a_long_cached_query(self, repo, stored_entities):
async with repo:
query = repo.objects.order_by("num").all()
await alist(query)
result = repr(query)
expected = "<InMemoryEntityQuery [Entity(id=UUID('de509355-5376-5405-a36d-91caed2ba8d1'), foo='bar0', num=0, is_odd=False, oddish=None, created_at=datetime.datetime(2023, 12, 15, 12, 0, tzinfo=<UTC>), finished_at=None), Entity(id=UUID('8db9b404-f276-5674-8006-12b74a8c62e3'), foo='baz1', num=1, is_odd=True, oddish=True, created_at=datetime.datetime(2023, 12, 15, 11, 0, tzinfo=<UTC>), finished_at=None), Entity(id=UUID('745da407-8c19-59d1-9a0e-8be54c7ac605'), foo='bar2', num=2, is_odd=False, oddish=None, created_at=datetime.datetime(2023, 12, 15, 10, 0, tzinfo=<UTC>), finished_at=None), '...(remaining elements truncated)...']>"
expected = "<InMemoryEntityQuery [Entity(id=UUID('de509355-5376-5405-a36d-91caed2ba8d1'), foo='bar0', num=0, is_odd=False, oddish=None, sub=SubEntity(bar='blah'), created_at=datetime.datetime(2023, 12, 15, 12, 0, tzinfo=<UTC>), finished_at=None), Entity(id=UUID('8db9b404-f276-5674-8006-12b74a8c62e3'), foo='baz1', num=1, is_odd=True, oddish=True, sub=SubEntity(bar='blah'), created_at=datetime.datetime(2023, 12, 15, 11, 0, tzinfo=<UTC>), finished_at=None), Entity(id=UUID('745da407-8c19-59d1-9a0e-8be54c7ac605'), foo='bar2', num=2, is_odd=False, oddish=None, sub=SubEntity(bar='blah'), created_at=datetime.datetime(2023, 12, 15, 10, 0, tzinfo=<UTC>), finished_at=None), '...(remaining elements truncated)...']>"
assert result == expected

0 comments on commit 50d6666

Please sign in to comment.