-
-
Notifications
You must be signed in to change notification settings - Fork 82
/
model_observer.py
196 lines (153 loc) · 6.14 KB
/
model_observer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import threading
import warnings
from collections import defaultdict
from enum import Enum
from functools import partial
from typing import Type, Dict, Any, Set, overload
from uuid import uuid4
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import transaction
from django.db.models import Model
from django.db.models.signals import post_delete, post_save, post_init
from djangochannelsrestframework.consumers import AsyncAPIConsumer
from djangochannelsrestframework.observer.base_observer import BaseObserver
"""
Transaction scenarios the observer has to cope with:
1) not in a transaction
2) in a simple transaction with one operation
3) in a simple transaction with multiple operations
4) was in a transaction that was rolled back, then saved outside of a transaction
5) was in a transaction that was rolled back, then saved inside a new one
6) in a transaction's savepoint that is never saved
7) in a transaction and then in a savepoint
On each model instance we store per-observer tracking state:
`instance._state._thread_local_observers = {observer.id: ModelObserverInstanceState}`
"""
class Action(Enum):
    """Kind of database change reported to observer groups.

    The ``value`` of each member is the string placed in the ``"action"``
    key of outgoing messages.
    """

    CREATE = "create"
    UPDATE = "update"
    DELETE = "delete"
class UnsupportedWarning(Warning):
    """Warning category for observation patterns this module does not support.

    In this module it is emitted when a model change is observed while the
    database connection is inside a savepoint.
    """
class ModelObserverInstanceState:
    """Tracking state that one ``ModelObserver`` keeps for one model instance."""

    # Groups the instance belonged to at the last snapshot; set when the
    # instance is initialised and refreshed after every observed change.
    current_groups: Set[str]

    def __init__(self):
        # Give every state object its own set. A class-level ``set()``
        # default would be a single object shared by all instances, so any
        # in-place mutation would leak between model instances.
        self.current_groups = set()
class ModelObserver(BaseObserver):
    """Observe create/update/delete events of a Django model and broadcast them.

    Setting ``model_cls`` connects this observer to the model's ``post_init``,
    ``post_save`` and ``post_delete`` signals. Each save/delete is deferred
    with ``connection.on_commit``. Once it runs, the groups the instance
    belonged to before the change are compared with the groups it belongs to
    now, and DELETE / UPDATE / CREATE messages are sent to the corresponding
    channel-layer groups.
    """

    def __init__(self, func, model_cls: Type[Model], **kwargs):
        super().__init__(func)
        # Key under which this observer stores its per-instance state. It is
        # assigned before the signals are connected so that the receivers can
        # never run without it.
        self.id = uuid4()
        self._model_cls = None
        # Assigning through the property setter connects the signal receivers.
        self.model_cls = model_cls  # type: Type[Model]

    @property
    def model_cls(self) -> Type[Model]:
        """The observed model class."""
        return self._model_cls

    @model_cls.setter
    def model_cls(self, value: Type[Model]):
        # Signals are connected only on the first non-None assignment; later
        # re-assignments neither reconnect nor disconnect anything.
        was_none = self._model_cls is None
        self._model_cls = value
        if self._model_cls is not None and was_none:
            self._connect()

    def _connect(self):
        """Connect the signal receivers for ``self.model_cls``."""
        # post_init captures the groups an instance belongs to when it is
        # constructed/loaded, so a later change can be diffed against it.
        post_init.connect(
            self.post_init_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )
        post_save.connect(
            self.post_save_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )
        post_delete.connect(
            self.post_delete_receiver, sender=self.model_cls, dispatch_uid=id(self)
        )

    def post_init_receiver(self, instance: Model, **kwargs):
        """Snapshot the groups ``instance`` currently belongs to."""
        if instance.pk is None:
            # Not saved yet, so it cannot be a member of any group.
            current_groups = set()
        else:
            current_groups = set(self.group_names_for_signal(instance=instance))
        self.get_observer_state(instance).current_groups = current_groups

    def get_observer_state(self, instance: Model) -> ModelObserverInstanceState:
        """Return (creating on first use) this observer's state for ``instance``."""
        # State is kept on ``instance._state`` in a dict keyed by observer id,
        # so several observers of the same model do not clash. Despite the
        # attribute name this is a plain per-instance dict, not a
        # ``threading.local``.
        if not hasattr(instance._state, "_thread_local_observers"):
            instance._state._thread_local_observers = defaultdict(
                ModelObserverInstanceState
            )
        return instance._state._thread_local_observers[self.id]

    def post_save_receiver(self, instance: Model, created: bool, **kwargs):
        """Handle ``post_save`` by recording a CREATE or UPDATE event."""
        action = Action.CREATE if created else Action.UPDATE
        self.database_event(instance, action, using=kwargs.get("using"))

    def post_delete_receiver(self, instance: Model, **kwargs):
        """Handle ``post_delete`` by recording a DELETE event."""
        self.database_event(instance, Action.DELETE, using=kwargs.get("using"))

    def database_event(self, instance: Model, action: Action, using=None):
        """Schedule ``post_change_receiver`` for when the change is committed.

        :param instance: the saved or deleted model instance.
        :param action: the kind of change that happened.
        :param using: alias of the database the change was written to (Django
            passes it as the ``using`` signal argument). ``None`` selects the
            default database, which was the only behaviour before.
        """
        connection = transaction.get_connection(using)
        if connection.in_atomic_block and connection.savepoint_ids:
            warnings.warn(
                "Model observation with save points is unsupported and will"
                " result in unexpected behaviour.",
                UnsupportedWarning,
            )
        # Outside a transaction Django runs the callback immediately; inside
        # one it waits for the outermost commit and drops it on rollback.
        connection.on_commit(partial(self.post_change_receiver, instance, action))

    def post_change_receiver(self, instance: Model, action: Action, **kwargs):
        """Broadcast a committed change to the affected groups.

        The previous group snapshot is compared with the instance's current
        groups:

        - groups it left receive DELETE;
        - groups it is still in receive UPDATE;
        - groups it joined receive CREATE.

        The snapshot is then replaced with the current groups.
        """
        state = self.get_observer_state(instance)
        old_group_names = state.current_groups
        if action == Action.DELETE:
            # A deleted instance no longer belongs to any group.
            new_group_names = set()
        else:
            new_group_names = set(self.group_names_for_signal(instance=instance))
        state.current_groups = new_group_names

        # Django DDP used the ordering DELETE, UPDATE then CREATE for good reasons.
        self.send_messages(
            instance, old_group_names - new_group_names, Action.DELETE, **kwargs
        )
        self.send_messages(
            instance, old_group_names & new_group_names, Action.UPDATE, **kwargs
        )
        self.send_messages(
            instance, new_group_names - old_group_names, Action.CREATE, **kwargs
        )

    def send_messages(
        self, instance: Model, group_names: Set[str], action: Action, **kwargs
    ):
        """Serialize ``instance`` once and send it to every group in ``group_names``.

        Does nothing (and does not serialize) when ``group_names`` is empty.
        """
        if not group_names:
            return
        message = self.serialize(instance, action, **kwargs)
        channel_layer = get_channel_layer()
        group_send = async_to_sync(channel_layer.group_send)
        for group_name in group_names:
            group_send(group_name, message)

    def group_names(self, *args, **kwargs):
        """Yield the one group name used for all updates of this observer.

        The name combines ``self._uuid`` (presumably set by ``BaseObserver`` --
        not visible here), the handler function name with ``_`` replaced by
        ``.``, and ``model_label``. All arguments are ignored.
        """
        yield "{}-{}-model-{}".format(
            self._uuid, self.func.__name__.replace("_", "."), self.model_label,
        )

    def serialize(self, instance, action, **kwargs) -> Dict[str, Any]:
        """Build the message dict sent to the channel layer.

        If ``self._serializer`` is set it produces the base message; otherwise
        the message only contains ``pk``. The ``type`` key (the handler
        function name with ``_`` replaced by ``.``) and the ``action`` key are
        always set, overriding any values the serializer returned.
        """
        # NOTE(review): Django may reset ``instance.pk`` to None once a delete
        # finishes; if this runs after that (deferred to an outer commit), the
        # default ``pk`` may be None for DELETE events -- confirm.
        if self._serializer:
            message = self._serializer(self, instance, action, **kwargs)
        else:
            message = {"pk": instance.pk}
        message["type"] = self.func.__name__.replace("_", ".")
        message["action"] = action.value
        return message

    @property
    def model_label(self):
        """``"<app_label>.<object_name>"``, lower-cased, with ``_`` replaced by ``.``."""
        meta = self.model_cls._meta
        return "{}.{}".format(meta.app_label, meta.object_name).lower().replace("_", ".")