Skip to content

Commit

Permalink
Add update_model_by_attribute
Browse files Browse the repository at this point in the history
update_model_by_attribute allows users to update models that do not have an "id" field.
  • Loading branch information
cblack34 committed Jul 17, 2023
1 parent 4c949f7 commit 5fb4f75
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
27 changes: 26 additions & 1 deletion sqlalchemy_crud/crud.py
@@ -1,5 +1,6 @@
from typing import List
from typing import List, Type

import sqlalchemy
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -70,6 +71,30 @@ def update_model(db: Session, model: Base, model_id: int, schema: dict) -> Base:
return db_model


def update_model_by_attribute(
db: Session,
model: Type[sqlalchemy.orm.decl_api.DeclarativeMeta],
lookup_attribute: str,
lookup_attribute_value,
schema: dict,
):
db_model = get_model_by_attribute(
db=db,
model=model,
attribute=lookup_attribute,
attribute_value=lookup_attribute_value,
)
for key, value in schema.items():
if hasattr(db_model, key):
setattr(db_model, key, value)
else:
raise AttributeError

db.commit()
db.refresh(db_model)
return db_model


def delete_model(db: Session, model: Base, model_id: int) -> None:
db_model = get_model(db=db, model=model, model_id=model_id)
db.delete(db_model)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_crud.py
Expand Up @@ -14,6 +14,7 @@
delete_model,
link_models,
unlink_models,
update_model_by_attribute,
)
from tests.models_for_test import Base, Parent, Child

Expand Down Expand Up @@ -277,6 +278,35 @@ def test_update_model_raise_exception_on_invalid_attribute(self):
schema=dict(name="parent_test_name_1_updated", invalid=1),
)

def test_update_model_by_attribute(self):
self.create_test_data()
self.link_children_to_parents()

model = get_model(db=self.db, model=Parent, model_id=1)
self.assertEqual(model.name, "parent_test_name_1")

model = update_model_by_attribute(
db=self.db,
model=Parent,
lookup_attribute="name",
lookup_attribute_value="parent_test_name_1",
schema=dict(name="parent_test_name_1_updated"),
)
self.assertEqual(model.name, "parent_test_name_1_updated")

model = get_model(db=self.db, model=Parent, model_id=1)
self.assertEqual(model.name, "parent_test_name_1_updated")

def test_update_model_by_attribute_raise_exception_on_invalid_attribute(self):
with self.assertRaises(AttributeError):
update_model_by_attribute(
db=self.db,
model=Parent,
lookup_attribute="name",
lookup_attribute_value="parent_test_name_1",
schema=dict(name="parent_test_name_1_updated", invalid=1),
)

def test_delete_model(self):
self.create_test_data()
self.link_children_to_parents()
Expand Down

0 comments on commit 5fb4f75

Please sign in to comment.