Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions py_utils/orm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlmodel import Field, Session, create_engine, SQLModel, select
from typing import List, Optional, Type, TypeVar
from sqlmodel import Session, create_engine, SQLModel, select
from typing import List, Type, TypeVar

import logging
import os
import sqlite3


class DefaultModel(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)


_T = TypeVar("_T", bound=SQLModel)


Expand Down Expand Up @@ -171,13 +167,13 @@ def query_model(self, model: Type[_T]) -> List[_T]:

return list(session.exec(select(model)).all())

def update_model(self, model: _T, pk_field: str = "id", **kwargs) -> _T | None:
"""Update model with kwargs provided
def update_model(self, model: _T, values: dict, pk_field: str = "id") -> _T | None:
"""Update model with values provided

Args:
model (Type[SQLModel]): The model class definition
values (dict): A dictionary of the keys with updated values.
pk_field (str): The primary key field of the provided model. Defaults to `id`.
**kwargs: A dictionary of the keys with updated values.

Returns:
The updated entry.
Expand All @@ -192,20 +188,19 @@ def update_model(self, model: _T, pk_field: str = "id", **kwargs) -> _T | None:
try:
# Retrieve the object that you want to update
object_to_update = session.get(type(model), getattr(model, pk_field))
# object_to_update = session.query(model.__tablename__).get(getattr(model, pk_field))
if object_to_update is not None:

# Update the object with the provided values
for key, value in kwargs.items():
for key, value in values.items():
setattr(object_to_update, key, value)

session.commit()

return object_to_update

return None
raise Exception("The model does not exist.")
except Exception as e:
# Rollback the session in case of any other error
session.rollback()
logging.error(f"Error inserting data: {str(e)}")
logging.error(f"Error updating data: {str(e)}")
raise e
2 changes: 1 addition & 1 deletion tests/py_utils/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_update_model(self):
actual = Image(**updated_data)

created_image = self.db_client.insert_data(image)
updated_image = self.db_client.update_model(created_image, **{"core": "dc"})
updated_image = self.db_client.update_model(model=created_image, values={"core": created_image.core})

expected = Image(**convert_model_to_dict(updated_image))

Expand Down