diff --git a/py_utils/orm.py b/py_utils/orm.py index b3cd9bb..0a27492 100644 --- a/py_utils/orm.py +++ b/py_utils/orm.py @@ -1,16 +1,12 @@ from sqlalchemy.exc import IntegrityError, OperationalError -from sqlmodel import Field, Session, create_engine, SQLModel, select -from typing import List, Optional, Type, TypeVar +from sqlmodel import Session, create_engine, SQLModel, select +from typing import List, Type, TypeVar import logging import os import sqlite3 -class DefaultModel(SQLModel): - id: Optional[int] = Field(default=None, primary_key=True) - - _T = TypeVar("_T", bound=SQLModel) @@ -171,13 +167,13 @@ def query_model(self, model: Type[_T]) -> List[_T]: return list(session.exec(select(model)).all()) - def update_model(self, model: _T, pk_field: str = "id", **kwargs) -> _T | None: - """Update model with kwargs provided + def update_model(self, model: _T, values: dict, pk_field: str = "id") -> _T | None: + """Update model with values provided Args: model (Type[SQLModel]): The model class definition + values (dict): A dictionary of the keys with updated values. pk_field (str): The primary key field of the provided model. Defaults to `id`. - **kwargs: A dictionary of the keys with updated values. Returns: The updated entry. @@ -192,20 +188,19 @@ def update_model(self, model: _T, pk_field: str = "id", **kwargs) -> _T | None: try: # Retrieve the object that you want to update object_to_update = session.get(type(model), getattr(model, pk_field)) - # object_to_update = session.query(model.__tablename__).get(getattr(model, pk_field)) if object_to_update is not None: # Update the object with the provided values - for key, value in kwargs.items(): + for key, value in values.items(): setattr(object_to_update, key, value) session.commit() return object_to_update - return None + raise Exception("The model does not exist.") except Exception as e: # Rollback the session in case of any other error session.rollback() - logging.error(f"Error inserting data: {str(e)}") + logging.error(f"Error updating data: {str(e)}") raise e diff --git a/tests/py_utils/test_orm.py b/tests/py_utils/test_orm.py index 4dc1e7e..8d20617 100644 --- a/tests/py_utils/test_orm.py +++ b/tests/py_utils/test_orm.py @@ -95,7 +95,7 @@ def test_update_model(self): actual = Image(**updated_data) created_image = self.db_client.insert_data(image) - updated_image = self.db_client.update_model(created_image, **{"core": "dc"}) + updated_image = self.db_client.update_model(model=created_image, values={"core": created_image.core}) expected = Image(**convert_model_to_dict(updated_image))