-
Notifications
You must be signed in to change notification settings - Fork 0
/
sqlalchemy.py
128 lines (94 loc) · 4.08 KB
/
sqlalchemy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import annotations
from functools import wraps
from typing import TypeVar
from sqlalchemy import MetaData, select
from sqlalchemy.engine import Row
from sqlalchemy.orm import (
sessionmaker,
declarative_base,
Session as _Session,
DeclarativeMeta,
)
from sqlalchemy.sql import Select
from .named import NamedPropertiesMeta
# TODO: proper type annotations for Select (maybe Python 3.11's Self type)
# noinspection PyUnresolvedReferences
class Session(_Session):
    """Session subclass with shortcut methods for executing SELECT statements.

    Every helper runs the statement through ``self.execute`` and unwraps the
    result either as scalars (``.scalars()``) or as full rows.
    """

    def get_first(self, stmt: Select) -> object | None:
        """Execute ``stmt`` and return the first scalar, or ``None`` if there are no rows."""
        return self.execute(stmt).scalars().first()

    def get_first_row(self, stmt: Select) -> Row | None:
        """Execute ``stmt`` and return the first row, or ``None`` if there are no rows."""
        # ``Result.first()`` yields ``None`` for an empty result, hence the
        # optional return type (matches ``ModBase.find_first_row_by_kwargs``).
        return self.execute(stmt).first()

    def get_all(self, stmt: Select) -> list[object]:
        """Execute ``stmt`` and return every scalar result."""
        return self.execute(stmt).scalars().all()

    def get_all_rows(self, stmt: Select) -> list[Row]:
        """Execute ``stmt`` and return every result row."""
        return self.execute(stmt).all()

    def get_paginated(self, stmt: Select, offset: int, limit: int) -> list[object]:
        """Apply ``offset`` and ``limit`` to ``stmt`` and return the scalars via ``get_all``."""
        return self.get_all(stmt.offset(offset).limit(limit))

    def get_paginated_rows(self, stmt: Select, offset: int, limit: int) -> list[Row]:
        """Apply ``offset`` and ``limit`` to ``stmt`` and return the rows via ``get_all_rows``."""
        return self.get_all_rows(stmt.offset(offset).limit(limit))
class Sessionmaker(sessionmaker):
    """``sessionmaker`` subclass providing decorators that run a function inside ``self.begin()``."""

    def with_begin(self, function):
        """Decorate ``function`` so that it is handed a ``session`` keyword argument.

        When the caller already passes ``session`` as a keyword, the function
        is invoked unchanged. Otherwise the call is made inside a
        ``with self.begin()`` block and the object that block yields is
        supplied as ``session``. Only a keyword ``session`` is detected; a
        session passed positionally is not recognised.
        """

        @wraps(function)
        def call_with_session(*args, **kwargs):
            if "session" not in kwargs:
                with self.begin() as session:
                    return function(*args, session=session, **kwargs)
            # Caller brought their own session: no new transaction is opened.
            return function(*args, **kwargs)

        return call_with_session

    def with_autocommit(self, function):
        """Decorate ``function`` so that every call runs inside ``with self.begin()``.

        The object yielded by ``self.begin()`` is not passed to ``function``;
        the arguments are forwarded exactly as given.
        """

        @wraps(function)
        def call_in_transaction(*args, **kwargs):
            with self.begin():
                return function(*args, **kwargs)

        return call_in_transaction
# Type variable bound to ModBase; used in ``cls: type[t]`` annotations so the
# classmethods below are typed as returning the concrete subclass.
t = TypeVar("t", bound="ModBase")
class ModBaseMeta(NamedPropertiesMeta, DeclarativeMeta):
    """Metaclass that combines ``NamedPropertiesMeta`` with ``DeclarativeMeta``.

    It adds nothing of its own; ``create_base`` passes it to
    ``declarative_base`` as the ``metaclass`` argument.
    """
class ModBase:  # TODO remove session usages?
    """Base class with create/query/delete shortcuts for declarative models.

    ``create_base`` hands this class to ``declarative_base`` as ``cls``. The
    ``find_*`` helpers build a statement with ``select_by_kwargs`` and pass it
    to the matching ``get_*`` method of the given session.
    """

    @classmethod
    def create(cls: type[t], session: Session, **kwargs) -> t:
        """Instantiate ``cls`` with ``kwargs``, add it to ``session``, flush, and return it."""
        instance = cls(**kwargs)
        session.add(instance)
        session.flush()
        return instance

    @classmethod
    def select_by_kwargs(cls, *order_by, **kwargs) -> Select:
        """Build ``select(cls)`` filtered by ``kwargs``; ordered by ``order_by`` when any are given."""
        stmt = select(cls).filter_by(**kwargs)
        if order_by:
            stmt = stmt.order_by(*order_by)
        return stmt

    @classmethod
    def find_first_by_kwargs(cls: type[t], session, *order_by, **kwargs) -> t | None:
        """Return ``session.get_first`` for the statement built from the arguments."""
        stmt = cls.select_by_kwargs(*order_by, **kwargs)
        return session.get_first(stmt)

    @classmethod
    def find_first_row_by_kwargs(cls, session, *order_by, **kwargs) -> Row | None:
        """Return ``session.get_first_row`` for the statement built from the arguments."""
        stmt = cls.select_by_kwargs(*order_by, **kwargs)
        return session.get_first_row(stmt)

    @classmethod
    def find_all_by_kwargs(cls: type[t], session, *order_by, **kwargs) -> list[t]:
        """Return ``session.get_all`` for the statement built from the arguments."""
        stmt = cls.select_by_kwargs(*order_by, **kwargs)
        return session.get_all(stmt)

    @classmethod
    def find_all_rows_by_kwargs(cls, session, *order_by, **kwargs) -> list[Row]:
        """Return ``session.get_all_rows`` for the statement built from the arguments."""
        stmt = cls.select_by_kwargs(*order_by, **kwargs)
        return session.get_all_rows(stmt)

    @classmethod
    def find_paginated_by_kwargs(
        cls: type[t], session, offset: int, limit: int, *order_by, **kwargs
    ) -> list[t]:
        """Return ``session.get_paginated`` for the built statement with ``offset`` and ``limit``."""
        stmt = cls.select_by_kwargs(*order_by, **kwargs)
        return session.get_paginated(stmt, offset, limit)

    @classmethod
    def find_paginated_rows_by_kwargs(
        cls, session, offset: int, limit: int, *order_by, **kwargs
    ) -> list[Row]:
        """Return ``session.get_paginated_rows`` for the built statement with ``offset`` and ``limit``."""
        stmt = cls.select_by_kwargs(*order_by, **kwargs)
        return session.get_paginated_rows(stmt, offset, limit)

    # TODO find_by_... with reflection or metaclasses

    def delete(self, session: Session) -> None:
        """Mark this object for deletion in ``session`` and flush."""
        session.delete(self)
        session.flush()
def create_base(meta: MetaData) -> type[ModBase]:
    """Create a declarative base class for ``meta``.

    The base is produced by ``declarative_base`` with ``meta`` as its
    metadata, ``ModBase`` as its ``cls`` and ``ModBaseMeta`` as its metaclass.
    """
    # noinspection PyTypeChecker
    base = declarative_base(metadata=meta, cls=ModBase, metaclass=ModBaseMeta)
    return base