/
_repository.py
72 lines (57 loc) · 2.36 KB
/
_repository.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -*- test-case-name: dbxs.test.test_access.AccessTestCase.test_repository -*-
"""
A repository combines a collection of accessors.
"""
from __future__ import annotations
import sys
from contextlib import asynccontextmanager
from inspect import signature
from typing import AsyncContextManager, AsyncIterator, Callable, TypeVar
from ._access import accessor
from .async_dbapi import AsyncConnectable, transaction
# The repository class passed to repository(); it is also the type of the
# object yielded by the async context manager that repository() returns.
T = TypeVar("T")
def repository(
    repositoryType: type[T],
) -> Callable[[AsyncConnectable], AsyncContextManager[T]]:
    """
    Build a factory that opens a transaction and yields a populated instance
    of C{repositoryType}.

    A "repository" is a class whose constructor parameters are each
    annotated with a type that L{accessor} can turn into an accessor.
    L{repository} inspects that constructor once, up front.  It returns a
    callable which, given an L{AsyncConnectable}, opens a L{transaction} and
    constructs C{repositoryType} with one accessor per constructor
    parameter, all bound to that transaction's connection.  This is easier
    to show with an example than a description::

        class Users(Protocol):
            @query(sql="...", load=one(User))
            def getUserByID(self, id: UserID) -> User: ...

        class Posts(Protocol):
            @query(sql="...", load=many(Post))
            def getPostsFromUser(self, id: UserID) -> AsyncIterator[Post]: ...

        @dataclass
        class BlogDB:
            users: Users
            posts: Posts

        blogRepository = repository(BlogDB)

        # ...

        async def userAndPosts(pool: AsyncConnectable, id: UserID) -> str:
            async with blogRepository(pool) as blog:
                user = await blog.users.getUserByID(id)
                posts = await blog.posts.getPostsFromUser(id)
            # transaction commits here

    @param repositoryType: A class (for example a dataclass) whose
        constructor parameters are all annotated with accessor protocol
        types.  String annotations (e.g. under C{from __future__ import
        annotations}) are resolved against the globals of the module that
        defines C{repositoryType}.

    @return: A callable taking an L{AsyncConnectable} and returning an async
        context manager.  Entering it opens a transaction and yields a
        C{repositoryType} instance whose accessors share that transaction's
        connection.
    """
    sig = signature(repositoryType)
    accessors = {}
    for name, parameter in sig.parameters.items():  # pragma: no branch
        annotation = parameter.annotation
        # With PEP 563 postponed evaluation, the annotation is a string, so
        # resolve it in the namespace of the module that defined the
        # repository class.  It would be nicer to use
        # signature(..., eval_str=True), but that's not available until we
        # require python>=3.10.  The evaluated text is the repository
        # class's own source annotations, not external input.
        if isinstance(annotation, str):  # pragma: no branch
            annotation = eval(
                annotation, sys.modules[repositoryType.__module__].__dict__
            )
        accessors[name] = accessor(annotation)

    @asynccontextmanager
    async def transactify(acxn: AsyncConnectable) -> AsyncIterator[T]:
        async with transaction(acxn) as aconn:
            # Every accessor is bound to the same connection, so all queries
            # made through the yielded repository run in this transaction.
            yield repositoryType(
                **{
                    name: makeAccessor(aconn)
                    for name, makeAccessor in accessors.items()
                }
            )

    return transactify