Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,7 @@ Each `Repository` comes with a set of **typed methods** to perform common CRUD o

______________________________________________________________________

- `find`: Find all records of an entity that match the given filters

______________________________________________________________________

- `get_by_id`: Get a single record by its ID
- `get`: Get a single record by its ID
- `get_batch`: Get all records of an entity that match the given filters
- `get_batch_by_ids`: Get a batch of records by their IDs
- `get_all`: Get all records of an entity
Expand All @@ -142,14 +138,15 @@ ______________________________________________________________________

If you require more flexibility, you may also use the `BaseRepository` which provides more granular operations. The `BaseRepository` provides the following methods:

- `_create`: Create a new record of an entity
- `_create_batch`: Create a batch of records of an entity
- `_update`: Update an entity instance
- `_update_batch`: Update a batch of entity instances with the same values
- `_get`: Get a single record by its ID
- `_get_batch`: Get all records of an entity that match the given filters
- `_delete`: Delete an entity instance
- `_delete_batch`: Delete a batch of entity instances
- `create`: Create a new record of an entity
- `create_batch`: Create a batch of records of an entity
- `update`: Update an entity instance
- `update_batch`: Update a batch of entity instances with the same values
- `get`: Get a single record by its ID
- `get_batch`: Get all records of an entity that match the given filters
- `find`: Find all records of an entity that match the given filters
- `delete`: Delete an entity instance
- `delete_batch`: Delete a batch of entity instances

## Examples

Expand Down
44 changes: 30 additions & 14 deletions sqlmodel_repository/base_repository.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Generic, Optional, Type, TypeVar, get_args
from typing import Generic, List, Optional, Type, TypeVar, get_args

from sqlalchemy.orm import Session
from sqlmodel import col
from structlog import WriteLogger

from sqlmodel_repository.entity import SQLModelEntity
Expand All @@ -14,10 +15,7 @@

class BaseRepository(Generic[GenericEntity], ABC):
"""Abstract base class for all repositories"""

entity: Type[GenericEntity]
logger: WriteLogger
sensitive_attribute_keys: list[str] = []

_default_excluded_keys = ["_sa_instance_state"]

def __init__(self, logger: Optional[WriteLogger] = None, sensitive_attribute_keys: Optional[list[str]] = None):
Expand All @@ -40,7 +38,25 @@ def get_session(self) -> Session:
"""Provides a session to work with"""
raise NotImplementedError

def _update(self, entity: GenericEntity, **kwargs) -> GenericEntity:
def find(self, **kwargs) -> List[GenericEntity]:
"""Get multiple entities with one query by filters

Args:
**kwargs: The filters to apply

Returns:
List[GenericEntity]: The entities that were found in the repository for the given filters
"""
filters = []

for key, value in kwargs.items():
try:
filters.append(col(getattr(self.entity, key)) == value)
except AttributeError as attribute_error:
raise EntityDoesNotPossessAttributeException(f"Entity {self.entity} does not have the attribute {key}") from attribute_error
return self.get_batch(filters=filters)

def update(self, entity: GenericEntity, **kwargs) -> GenericEntity:
"""Updates an entity with the given attributes (keyword arguments) if they are not None

Args:
Expand All @@ -59,7 +75,7 @@ def _update(self, entity: GenericEntity, **kwargs) -> GenericEntity:
session = self.get_session()
self._emit_log("Updating", entities=[entity], **kwargs)

entity = self._get(entity_id=entity.id)
entity = self.get(entity_id=entity.id)

for key, value in kwargs.items():
if value is not None:
Expand All @@ -73,7 +89,7 @@ def _update(self, entity: GenericEntity, **kwargs) -> GenericEntity:

return entity

def _update_batch(self, entities: list[GenericEntity], **kwargs) -> list[GenericEntity]:
def update_batch(self, entities: list[GenericEntity], **kwargs) -> list[GenericEntity]:
"""Updates a list of entities with the given attributes (keyword arguments) if they are not None

Args:
Expand Down Expand Up @@ -102,7 +118,7 @@ def _update_batch(self, entities: list[GenericEntity], **kwargs) -> list[Generic

return entities

def _get(self, entity_id: int) -> GenericEntity:
def get(self, entity_id: int) -> GenericEntity:
"""Retrieves an entity from the database with the specified ID.

Args:
Expand All @@ -123,7 +139,7 @@ def _get(self, entity_id: int) -> GenericEntity:
return result

# pylint: disable=dangerous-default-value
def _get_batch(self, filters: Optional[list] = None) -> list[GenericEntity]:
def get_batch(self, filters: Optional[list] = None) -> list[GenericEntity]:
"""Retrieves a list of entities from the database that match the specified filters.

Args:
Expand All @@ -141,7 +157,7 @@ def _get_batch(self, filters: Optional[list] = None) -> list[GenericEntity]:
result = session.query(self.entity).filter(*filters).all()
return result

def _create(self, entity: GenericEntity) -> GenericEntity:
def create(self, entity: GenericEntity) -> GenericEntity:
"""Adds a new entity to the database.

Args:
Expand All @@ -165,7 +181,7 @@ def _create(self, entity: GenericEntity) -> GenericEntity:
session.rollback()
raise CouldNotCreateEntityException from exception

def _create_batch(self, entities: list[GenericEntity]) -> list[GenericEntity]:
def create_batch(self, entities: list[GenericEntity]) -> list[GenericEntity]:
"""Adds a batch of new entities to the database.

Args:
Expand All @@ -190,7 +206,7 @@ def _create_batch(self, entities: list[GenericEntity]) -> list[GenericEntity]:
session.refresh(entity)
return entities

def _delete(self, entity: GenericEntity) -> GenericEntity:
def delete(self, entity: GenericEntity) -> GenericEntity:
"""Deletes an entity from the database.

Args:
Expand All @@ -213,7 +229,7 @@ def _delete(self, entity: GenericEntity) -> GenericEntity:
session.rollback()
raise CouldNotDeleteEntityException from exception

def _delete_batch(self, entities: list[GenericEntity]) -> None:
def delete_batch(self, entities: list[GenericEntity]) -> None:
"""Deletes a batch of entities from the database.

Args:
Expand Down
112 changes: 5 additions & 107 deletions sqlmodel_repository/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,74 +5,13 @@

from sqlmodel_repository.base_repository import BaseRepository
from sqlmodel_repository.entity import SQLModelEntity
from sqlmodel_repository.exceptions import EntityDoesNotPossessAttributeException

GenericEntity = TypeVar("GenericEntity", bound=SQLModelEntity)


class Repository(BaseRepository[GenericEntity], ABC):
"""Abstract base class for repository implementations"""

def create(self, entity: GenericEntity) -> GenericEntity:
"""Creates an entity to the repository

Args:
entity (GenericEntity): The entity to add

Returns:
GenericEntity: The added entity

Raises:
CouldNotCreateEntityException: If the entity could not be created
"""
return self._create(entity=entity)

def create_batch(self, entities: list[GenericEntity]) -> list[GenericEntity]:
"""Creates an entity to the repository

Args:
entities (list[GenericEntity]): The entities to add

Returns:
list[GenericEntity]: The added entities

Raises:
CouldNotCreateEntityException: If the entity could not be created
"""
return self._create_batch(entities=entities)

def get_by_id(self, entity_id: int) -> GenericEntity:
"""Get an entity by ID

Args:
entity_id (int): The ID of the entity

Returns:
GenericEntity: The entity

Raises:
NoResultFound: If no entity was found
"""
return self._get(entity_id=entity_id)

def find(self, **kwargs) -> List[GenericEntity]:
"""Get multiple entities with one query by filters

Args:
**kwargs: The filters to apply

Returns:
List[GenericEntity]: The entities that were found in the repository for the given filters
"""
filters = []

for key, value in kwargs.items():
try:
filters.append(col(getattr(self.entity, key)) == value)
except AttributeError as attribute_error:
raise EntityDoesNotPossessAttributeException(f"Entity {self.entity} does not have the attribute {key}") from attribute_error
return self._get_batch(filters=filters)

def get_batch_by_ids(self, entity_ids: list[int]) -> List[GenericEntity]:
"""Get multiple entities with one query by IDs

Expand All @@ -83,40 +22,15 @@ def get_batch_by_ids(self, entity_ids: list[int]) -> List[GenericEntity]:
List[GenericEntity]: The entities that were found in the repository for the given IDs
"""
filters = [col(self._entity_class().id).in_(entity_ids)]
return self._get_batch(filters=filters)
return self.get_batch(filters=filters)

def get_all(self) -> List[GenericEntity]:
"""Get all entities of the repository

Returns:
List[GenericEntity]: All entities that were found in the repository
"""
return self._get_batch()

# noinspection PyShadowingBuiltins
def update(self, entity: GenericEntity, **kwargs) -> GenericEntity:
"""Update an entity

Args:
entity (GenericEntity): Entity to update
**kwargs: Any new values

Returns:
GenericEntity: The updated entity
"""
return self._update(entity=entity, **kwargs)

def update_batch(self, entities: list[GenericEntity], **kwargs) -> list[GenericEntity]:
"""Update multiple entities with the same target values

Args:
entities (list[GenericEntity]): Entities to update
**kwargs: Any new values

Returns:
list[GenericEntity]: The updated entities
"""
return self._update_batch(entities=entities, **kwargs)
return self.get_batch()

def update_batch_by_ids(self, entity_ids: list[int], **kwargs) -> list[GenericEntity]:
"""Update multiple entities with the same target values
Expand All @@ -129,7 +43,7 @@ def update_batch_by_ids(self, entity_ids: list[int], **kwargs) -> list[GenericEn
list[GenericEntity]: The updated entities
"""
entities = self.get_batch_by_ids(entity_ids=entity_ids)
return self._update_batch(entities=entities, **kwargs)
return self.update_batch(entities=entities, **kwargs)

# noinspection PyShadowingBuiltins
def update_by_id(self, entity_id: int, **kwargs) -> GenericEntity:
Expand All @@ -142,17 +56,9 @@ def update_by_id(self, entity_id: int, **kwargs) -> GenericEntity:
Returns:
GenericEntity: The updated entity
"""
entity_to_update = self.get_by_id(entity_id=entity_id)
entity_to_update = self.get(entity_id=entity_id)
return self.update(entity=entity_to_update, **kwargs)

def delete(self, entity: GenericEntity) -> None:
"""Delete an entity

Args:
entity (GenericEntity): Entity to delete
"""
self._delete(entity=entity)

def delete_by_id(self, entity_id: int) -> None:
"""Delete an entity by entity_id

Expand All @@ -162,17 +68,9 @@ def delete_by_id(self, entity_id: int) -> None:
Raises:
NoResultFound: If no entity was found
"""
entity_to_delete = self.get_by_id(entity_id=entity_id)
entity_to_delete = self.get(entity_id=entity_id)
self.delete(entity=entity_to_delete)

def delete_batch(self, entities: List[GenericEntity]):
"""Delete multiple entities with one commit

Args:
entities (List[GenericEntity]): Entities to delete
"""
self._delete_batch(entities=entities)

def delete_batch_by_ids(self, entity_ids: List[int]):
"""Delete an entity by entity_id

Expand Down
Loading