Skip to content

Commit

Permalink
fix: accept None in model_from_dict to allow setting model fields… (
Browse files Browse the repository at this point in the history
#141)

Co-authored-by: Cody Fincher <204685+cofin@users.noreply.github.com>
Co-authored-by: Cody Fincher <cody.fincher@gmail.com>
  • Loading branch information
3 people committed Apr 2, 2024
1 parent 38c029e commit c791bed
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 62 deletions.
10 changes: 5 additions & 5 deletions advanced_alchemy/repository/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def get_instrumented_attr(model: type[ModelProtocol], key: str | InstrumentedAtt

def model_from_dict(model: ModelT, **kwargs: Any) -> ModelT:
"""Return ORM Object from Dictionary."""
data = {}
for column_name in model.__mapper__.columns.keys(): # noqa: SIM118
column_val = kwargs.get(column_name, None)
if column_val is not None:
data[column_name] = column_val
data = {
column_name: kwargs[column_name]
for column_name in model.__mapper__.columns.keys() # noqa: SIM118
if column_name in kwargs
}
return model(**data) # type: ignore # noqa: PGH003
8 changes: 2 additions & 6 deletions advanced_alchemy/repository/memory/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,10 @@ async def add_many(self, data: list[ModelT], **_: Any) -> list[ModelT]:

async def update(self, data: ModelT, **_: Any) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
self.__collection__().update(data)
return data
return self.__collection__().update(data)

async def update_many(self, data: list[ModelT], **_: Any) -> list[ModelT]:
for obj in data:
if obj in self.__collection__():
self.__collection__().update(obj)
return data
return [self.__collection__().update(obj) for obj in data if obj in self.__collection__()]

async def delete(self, item_id: Any, **_: Any) -> ModelT:
try:
Expand Down
8 changes: 2 additions & 6 deletions advanced_alchemy/repository/memory/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,10 @@ def add_many(self, data: list[ModelT], **_: Any) -> list[ModelT]:

def update(self, data: ModelT, **_: Any) -> ModelT:
self._find_or_raise_not_found(self.__collection__().key(data))
self.__collection__().update(data)
return data
return self.__collection__().update(data)

def update_many(self, data: list[ModelT], **_: Any) -> list[ModelT]:
for obj in data:
if obj in self.__collection__():
self.__collection__().update(obj)
return data
return [self.__collection__().update(obj) for obj in data if obj in self.__collection__()]

def delete(self, item_id: Any, **_: Any) -> ModelT:
try:
Expand Down
86 changes: 42 additions & 44 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[project]
authors = [
{name = "Cody Fincher", email = "cody.fincher@gmail.com"},
{name = "Peter Schutt", email = "peter.github@proton.me"},
{name = "Janek Nouvertné", email = "j.a.nouvertne@posteo.de"},
{name = "Jacob Coffee", email = "jacob@z7x.org"},
{ name = "Cody Fincher", email = "cody.fincher@gmail.com" },
{ name = "Peter Schutt", email = "peter.github@proton.me" },
{ name = "Janek Nouvertné", email = "j.a.nouvertne@posteo.de" },
{ name = "Jacob Coffee", email = "jacob@z7x.org" },
]
classifiers = [
"Development Status :: 3 - Alpha",
Expand Down Expand Up @@ -32,7 +32,7 @@ dependencies = [
]
description = "Ready-to-go SQLAlchemy concoctions."
keywords = ["sqlalchemy", "alembic", "litestar", "sanic", "fastapi", "flask"]
license = {text = "MIT"}
license = { text = "MIT" }
name = "advanced_alchemy"
readme = "README.md"
requires-python = ">=3.8"
Expand All @@ -48,9 +48,7 @@ Issue = "https://github.com/jolt-org/advanced-alchemy/issues/"
Source = "https://github.com/jolt-org/advanced-alchemy"

[project.optional-dependencies]
uuid = [
"uuid-utils>=0.6.1",
]
uuid = ["uuid-utils>=0.6.1"]
[build-system]
build-backend = "hatchling.build"
requires = ["hatchling"]
Expand Down Expand Up @@ -97,15 +95,15 @@ extensions = [
"flask-sqlalchemy>=3.1.1",
]
linting = [
"pre-commit>=3.4.0",
"black>=23.7.0",
"mypy>=1.5.1",
"ruff>=0.0.287",
"asyncpg-stubs",
"types-pytest-lazy-fixture",
"types-click",
"types-pyyaml",
"pyright>=1.1.350",
"pre-commit>=3.4.0",
"black>=23.7.0",
"mypy>=1.5.1",
"ruff>=0.0.287",
"asyncpg-stubs",
"types-pytest-lazy-fixture",
"types-click",
"types-pyyaml",
"pyright>=1.1.350",
]
test = [
"pytest>=7.4.1,<8.0.0",
Expand Down Expand Up @@ -191,32 +189,32 @@ target-version = "py38"
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
fixable = ["ALL"]
ignore = [
"A003", # flake8-builtins - class attribute {name} is shadowing a python builtin
"B010", # flake8-bugbear - do not call setattr with a constant attribute value
"D100", # pydocstyle - missing docstring in public module
"D101", # pydocstyle - missing docstring in public class
"D102", # pydocstyle - missing docstring in public method
"D103", # pydocstyle - missing docstring in public function
"D104", # pydocstyle - missing docstring in public package
"D105", # pydocstyle - missing docstring in magic method
"D106", # pydocstyle - missing docstring in public nested class
"D107", # pydocstyle - missing docstring in __init__
"D202", # pydocstyle - no blank lines allowed after function docstring
"D205", # pydocstyle - 1 blank line required between summary line and description
"D415", # pydocstyle - first line should end with a period, question mark, or exclamation point
"E501", # pycodestyle line too long, handled by black
"A003", # flake8-builtins - class attribute {name} is shadowing a python builtin
"B010", # flake8-bugbear - do not call setattr with a constant attribute value
"D100", # pydocstyle - missing docstring in public module
"D101", # pydocstyle - missing docstring in public class
"D102", # pydocstyle - missing docstring in public method
"D103", # pydocstyle - missing docstring in public function
"D104", # pydocstyle - missing docstring in public package
"D105", # pydocstyle - missing docstring in magic method
"D106", # pydocstyle - missing docstring in public nested class
"D107", # pydocstyle - missing docstring in __init__
"D202", # pydocstyle - no blank lines allowed after function docstring
"D205", # pydocstyle - 1 blank line required between summary line and description
"D415", # pydocstyle - first line should end with a period, question mark, or exclamation point
"E501", # pycodestyle line too long, handled by black
"PLW2901", # pylint - for loop variable overwritten by assignment target
"RUF012", # Ruff-specific rule - annotated with classvar
"RUF012", # Ruff-specific rule - annotated with classvar
"ANN401",
"ANN102",
"ANN101",
"FBT",
"PLR0913", # too many arguments
"PT",
"TD",
"ARG002", # ignore for now; investigate
"ARG002", # ignore for now; investigate
"PERF203", # ignore for now; investigate
"PD011", # pandas
"PD011", # pandas
]
select = ["ALL"]

Expand Down Expand Up @@ -415,16 +413,16 @@ trim = true

[tool.git-cliff.git]
commit_parsers = [
{message = "^feat", group = "Features"},
{message = "^fix", group = "Bug Fixes"},
{message = "^doc", group = "Documentation"},
{message = "^perf", group = "Performance"},
{message = "^refactor", group = "Refactor"},
{message = "^style", group = "Styling"},
{message = "^test", group = "Testing"},
{message = "^chore\\(release\\): prepare for", skip = true},
{message = "^chore", group = "Miscellaneous Tasks"},
{body = ".*security", group = "Security"},
{ message = "^feat", group = "Features" },
{ message = "^fix", group = "Bug Fixes" },
{ message = "^doc", group = "Documentation" },
{ message = "^perf", group = "Performance" },
{ message = "^refactor", group = "Refactor" },
{ message = "^style", group = "Styling" },
{ message = "^test", group = "Testing" },
{ message = "^chore\\(release\\): prepare for", skip = true },
{ message = "^chore", group = "Miscellaneous Tasks" },
{ body = ".*security", group = "Security" },
]
conventional_commits = true
filter_commits = false
Expand Down
24 changes: 23 additions & 1 deletion tests/integration/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import contextlib
import os
from datetime import datetime, timedelta, timezone
from datetime import date, datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterator, List, Literal, Type, Union, cast
from unittest.mock import NonCallableMagicMock, create_autospec
from uuid import UUID
Expand Down Expand Up @@ -1666,6 +1666,8 @@ async def test_lazy_load(
item_model: ItemModel,
tag_model: TagModel,
) -> None:
if getattr(tag_repo, "__collection__", None) is not None:
pytest.skip("Skipping lazy load testing on Mock repositories.")
tag_obj = await maybe_async(tag_repo.add(tag_model(name="A new tag")))
assert tag_obj
new_items = await maybe_async(
Expand Down Expand Up @@ -1921,6 +1923,26 @@ async def test_service_update_method_no_item_id(author_service: AuthorService, f
assert updated_obj.name == obj.name


async def test_service_update_method_data_is_dict(author_service: AuthorService, first_author_id: Any) -> None:
new_date = datetime.date(datetime.now())
updated_obj = await maybe_async(
author_service.update(item_id=first_author_id, data={"dob": new_date}),
)
assert updated_obj.dob == new_date
# ensure the other fields are not affected
assert updated_obj.name == "Agatha Christie"


async def test_service_update_method_data_is_dict_with_none_value(
author_service: AuthorService,
first_author_id: Any,
) -> None:
updated_obj = await maybe_async(author_service.update(item_id=first_author_id, data={"dob": None}))
assert cast(Union[date, None], updated_obj.dob) is None
# ensure the other fields are not affected
assert updated_obj.name == "Agatha Christie"


async def test_service_update_method_instrumented_attribute(
author_service: AuthorService,
first_author_id: Any,
Expand Down

0 comments on commit c791bed

Please sign in to comment.