-
-
Notifications
You must be signed in to change notification settings - Fork 21
/
graph_db.py
90 lines (72 loc) · 2.18 KB
/
graph_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Neo4J Transactions manager for DB operations
"""
import contextlib
import contextvars # Used for creation of context vars
import logging
import neo4j # Interface with Neo4J
from . import settings # Neo4J settings
from .exceptions import SessionMissingError # Custom exceptions
from .exceptions import TransactionMissingError
log = logging.getLogger(__name__)
DEFAULT_DB = "neo4j"
# Context-local "current" transaction and session, published by TransactionCtx.
# ``default=None`` makes ``.get()`` return None in EVERY context, including
# ones that never inherited the import-time context (e.g. new threads). A bare
# module-level ``.set(None)`` only affects the importing context, and elsewhere
# ``.get()`` would raise LookupError.
txn = contextvars.ContextVar("txn", default=None)
session = contextvars.ContextVar("session", default=None)
@contextlib.asynccontextmanager
async def TransactionCtx():
    """
    Open a session and transaction on ``DEFAULT_DB`` for the code in context.

    While the context is active, the transaction and session are stored in the
    ``txn`` and ``session`` context variables, so that
    ``get_current_transaction`` / ``get_current_session`` can find them. They
    are also yielded as a ``(transaction, session)`` tuple.

    The transaction is rolled back if an exception occurs within the context.
    On a clean exit, the driver's transaction context manager commits it.

    On exit, the context variables are restored to the values they had before
    entry. A nested ``TransactionCtx`` therefore does not clear the enclosing
    block's transaction.

    Requires ``database_lifespan`` to have created the module-level ``driver``.
    """
    async with driver.session(database=DEFAULT_DB) as _session:
        txn_manager = await _session.begin_transaction()
        async with txn_manager as _txn:
            txn_token = txn.set(_txn)
            session_token = session.set(_session)
            try:
                yield _txn, _session
            finally:
                # Reset via tokens (instead of forcing None) so any outer
                # context's values come back; undo in reverse order of set.
                session.reset(session_token)
                txn.reset(txn_token)
@contextlib.asynccontextmanager
async def database_lifespan():
    """
    Own the module-wide async Neo4J driver for the duration of the context.

    On entry, an ``AsyncGraphDatabase`` driver is built from ``settings.uri``
    and published as the module global ``driver``, which ``TransactionCtx``
    reads. On exit, that global driver is closed, even if the body raised.
    """
    global driver
    driver = neo4j.AsyncGraphDatabase.driver(settings.uri)
    try:
        yield
    finally:
        # Always release the driver's connections, error or not.
        await driver.close()
def get_current_transaction():
    """
    Return the transaction bound to the current context by ``TransactionCtx``.

    Returns:
        The transaction object stored in the ``txn`` context variable.

    Raises:
        TransactionMissingError: if no transaction is active in this context.
            This includes contexts (e.g. other threads) where ``txn`` was
            never set at all.
    """
    # ``get(None)``: in a context where ``txn`` was never set, a bare ``get()``
    # raises LookupError instead of the domain-specific error below.
    curr_txn = txn.get(None)
    if curr_txn is None:
        raise TransactionMissingError()
    return curr_txn
def get_current_session():
    """
    Return the session bound to the current context by ``TransactionCtx``.

    Returns:
        The session object stored in the ``session`` context variable.

    Raises:
        SessionMissingError: if no session is active in this context. This
            includes contexts (e.g. other threads) where ``session`` was never
            set at all.
    """
    # ``get(None)``: in a context where ``session`` was never set, a bare
    # ``get()`` raises LookupError instead of the domain-specific error below.
    curr_session = session.get(None)
    if curr_session is None:
        raise SessionMissingError()
    return curr_session
@contextlib.contextmanager
def SyncTransactionCtx():
    """
    Yield a synchronous (blocking) Neo4J session on ``DEFAULT_DB``.

    BEWARE: use it with caution, only for edge cases.
    Normally it should be reserved for background tasks.

    Despite the name, this yields a session rather than a transaction.

    A dedicated sync driver is created from ``settings.uri`` for each call and
    closed again when the context exits, so its connection pool is not leaked.

    The ``txn`` / ``session`` context variables are not set here, so
    ``get_current_transaction`` / ``get_current_session`` will not see this
    session.
    """
    sync_driver = neo4j.GraphDatabase.driver(settings.uri)
    try:
        with sync_driver.session(database=DEFAULT_DB) as _session:
            yield _session
    finally:
        # Close the per-call driver; otherwise every call leaks a driver and
        # its connection pool.
        sync_driver.close()