/
item.py
298 lines (227 loc) · 10.1 KB
/
item.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import asyncio
import logging
from delfick_project.norms import sb
from photons_app import helpers as hp
from photons_app.errors import DevicesNotFound, TimedOut
from photons_app.special import SpecialReference
from photons_transport import catch_errors
# Module-level logger for this module; Item.run uses it to warn when the
# deprecated ``timeout`` keyword argument is passed.
log = logging.getLogger("photons_transport.targets.item")
class Done:
    """
    Marker class that signals a queue should be closed.

    It carries no state; its identity is the whole signal.
    """
def silence_errors(e):
    """
    Error callback that deliberately discards ``e``.

    ``Item.write_messages`` passes this as the ``error_catcher`` of its
    ``hp.ResultStreamer``. It then inspects each result itself and routes
    failures to the caller's own error catcher.
    """
    return None
def choose_source(pkt, source):
    """
    Pick the source value to put on ``pkt``.

    If the packet already had a ``source`` explicitly set, keep it. Otherwise
    fall back to the ``source`` supplied by the caller (typically the sender's
    source).
    """
    explicit = pkt.actual("source")
    return source if explicit is sb.NotSpecified else pkt.source
class NoLimit:
    """
    Stand-in for a limiting semaphore that never limits anything.

    It exposes the same ``async with`` / ``acquire`` / ``release`` / ``locked``
    surface as ``asyncio.Semaphore``, but every operation is a no-op. Callers
    can therefore always write ``async with limit:`` even when no concurrency
    limit was requested.
    """

    async def acquire(self):
        """Never blocks; there is nothing to acquire."""
        return None

    def release(self):
        """Nothing was acquired, so there is nothing to release."""
        return None

    def locked(self):
        """Always report that the limit is not exhausted."""
        return False

    async def __aenter__(self):
        await self.acquire()

    async def __aexit__(self, exc_typ, exc, tb):
        # Falls through to an implicit None (falsy), so any exception raised in
        # the ``async with`` body still propagates.
        self.release()


# Shared instance used by Item.do_send when no (or a falsy) ``limit`` is given.
no_limit = NoLimit()
class Item:
    """
    A collection of packets ("parts") to be sent to one or more devices.

    ``run`` does four things:

    * resolves a reference into serials;
    * clones every part once per target;
    * optionally searches for the devices;
    * streams back every reply.
    """

    def __init__(self, parts):
        """
        Store ``parts``.

        Anything that is not exactly a ``list`` (a list subclass counts as "not
        exactly") is wrapped in a single-element list.
        """
        self.parts = parts
        if type(self.parts) is not list:
            self.parts = [self.parts]

    async def run(self, reference, sender, **kwargs):
        """
        Entry point to this item.

        The idea is that you create a ``script`` with the target and call
        ``run`` on the script, which ends up calling this.

        This is an async generator that yields the results from sending packets
        to the devices. It is responsible for finding devices and gathering all
        the responses.

        This accepts the following keyword arguments.

        broadcast
            Whether we are broadcasting these messages or just unicasting
            directly to each device. Defaults to False.

        find_timeout
            Timeout for finding devices. Defaults to 20.

        connect_timeout
            Timeout for connecting to devices. Defaults to 10.

        message_timeout
            A per message timeout for receiving replies for that message.
            Defaults to 10.

        timeout
            Deprecated. Passing it only logs a warning asking for
            ``message_timeout`` instead; it is not used as a message timeout.

        found
            A dictionary of
            ``{targetHex: (set([(ServiceType, addr), (ServiceType, addr), ...]), broadcastAddr)}``

            If this is not provided, ``sender.found`` is used. When
            ``reference`` is a ``SpecialReference``, the found returned by the
            reference is used instead.

        accept_found
            Accept the found that was given and don't try to change it.

        error_catcher
            A list that errors will be appended to instead of being raised.
            Or a callable that takes in the error as an argument.

            If this isn't specified then errors are raised after all the
            received messages have been yielded.

            Note that if there is only one serial that we sent messages to,
            then any error is raised as is. Otherwise we raise a
            ``photons_app.errors.RunErrors``, with all the errors in a list on
            the ``errors`` property of the RunErrors exception.

        no_retry
            If True then the messages being sent will have no automatic retry.
            This defaults to False and retry rates are determined by the target
            you are using.

        require_all_devices
            Defaults to False. If True, ``DevicesNotFound`` is raised instead
            of sending any messages when some of the devices we want to send
            to were not found. Ignored when broadcasting.

        limit
            An async context manager used to limit inflight messages. So for
            each message, we do

            .. code-block:: python

                async with limit:
                    send_and_wait_for_reply(message)

            For example, an ``asyncio.Semaphore(30)``.

            Note that if you use ``target.script(msgs).run(....)`` then limit
            will be set to a semaphore with max 30 by default. There you may
            specify just a number and it will be turned into a semaphore.

            That conversion does not happen in this class. By the time a
            message is sent, ``limit`` must already be an async context manager
            (or falsy, meaning no limit).
        """
        if "timeout" in kwargs:
            log.warning(hp.lc("Please use message_timeout instead of timeout when calling run"))

        with catch_errors(kwargs.get("error_catcher")) as error_catcher:
            kwargs["error_catcher"] = error_catcher

            broadcast = kwargs.get("broadcast", False)
            find_timeout = kwargs.get("find_timeout", 20)

            found, serials, missing = await self._find(
                kwargs.get("found"), reference, sender, broadcast, find_timeout
            )

            # Work out what and where to send.
            # All the packets from here have targets on them.
            packets = self.make_packets(sender, serials)

            # Short cut if nothing to actually send.
            if not packets:
                return

            # A SpecialReference already told us what is missing. Otherwise
            # (and only when unicasting) we have to look for the devices.
            if missing is None and not broadcast:
                accept_found = kwargs.get("accept_found")
                found, missing = await self.search(
                    sender, found, accept_found, packets, broadcast, find_timeout, kwargs
                )

            # Complain if we care about having all wanted devices.
            if not broadcast and kwargs.get("require_all_devices") and missing:
                raise DevicesNotFound(missing=missing)

            # Write the messages and get results.
            async for thing in self.write_messages(sender, packets, kwargs):
                yield thing

    async def _find(self, found, reference, sender, broadcast, timeout):
        """
        Turn ``reference`` into ``(found, serials, missing)``.

        If ``reference`` is not a ``SpecialReference``, it is returned as the
        serials (wrapped in a list if it is not one) together with the given
        ``found``. In that case ``missing`` is None.

        Otherwise the special reference supplies ``found`` and the serials, and
        ``missing`` is ``reference.missing(found)``. The returned serials also
        include those missing serials.

        A ``found`` of None is replaced with ``sender.found``.
        """
        serials = reference
        missing = None

        if isinstance(reference, SpecialReference):
            found, serials = await reference.find(sender, broadcast=broadcast, timeout=timeout)
            missing = reference.missing(found)
            # Build a new list instead of ``serials.extend(missing)``. The list
            # came from the reference and may be shared with it (for example
            # if it caches its result), so mutating it in place would leak the
            # missing serials back into the reference.
            serials = [*serials, *missing]

        if type(serials) is not list:
            serials = [serials]

        if found is None:
            found = sender.found

        return found, serials, missing

    def simplify_parts(self):
        """
        Return ``(original, simplified)`` pairs for every part.

        Static parts are paired with ``part.simplify()``. Dynamic parts (those
        with ``is_dynamic`` set, e.g. with a callable field) are paired with
        themselves, unchanged.
        """
        return [(p, p if p.is_dynamic else p.simplify()) for p in self.parts]

    def make_packets(self, sender, serials):
        """
        Create and fill in the packets from our parts.

        Returns a list of ``(original_part, packet)`` pairs:

        * A simplified part with no target is cloned once per serial, with
          ``target`` set to that serial.
        * A part that already has a target is cloned exactly once.

        Every clone gets a source (see ``choose_source``) and a fresh sequence
        from ``sender.seq``.
        """
        packets = []

        for original, p in self.simplify_parts():
            if p.target is sb.NotSpecified:
                for serial in serials:
                    clone = p.clone()
                    clone.update(
                        dict(
                            target=serial,
                            source=choose_source(clone, sender.source),
                            sequence=sender.seq(serial),
                        )
                    )
                    packets.append((original, clone))
            else:
                clone = p.clone()
                clone.update(
                    dict(source=choose_source(clone, sender.source), sequence=sender.seq(p.serial))
                )
                packets.append((original, clone))

        return packets

    async def search(self, sender, found, accept_found, packets, broadcast, find_timeout, kwargs):
        """
        Search for the devices we want to send to.

        The wanted serials are those of every packet whose target is not None.

        If ``accept_found`` is truthy, or ``found`` already contains every
        wanted serial, no search happens. In that case this returns ``found``
        (or ``sender.found`` when ``found`` is None) and the wanted serials
        that are not in it.

        Otherwise it returns the result of ``sender.find_specific_serials``,
        called with ``kwargs`` plus ``timeout=find_timeout``, ``broadcast``
        and ``raise_on_none=False``. That result is presumably a
        ``(found, missing)`` pair, since ``run`` unpacks it that way.
        """
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # the serials we search for come out in a deterministic order.
        serials = list(dict.fromkeys(p.serial for _, p in packets if p.target is not None))

        if accept_found or (found and all(serial in found for serial in serials)):
            if found is None:
                found = sender.found
            missing = [serial for serial in serials if serial not in found]
            return found, missing

        kw = dict(kwargs)
        kw["timeout"] = find_timeout
        kw["broadcast"] = broadcast
        kw["raise_on_none"] = False
        return await sender.find_specific_serials(serials, **kw)

    async def write_messages(self, sender, packets, kwargs):
        """
        Send all our packets and yield every reply.

        Each packet is sent by its own ``do_send`` coroutine, added to an
        ``hp.ResultStreamer`` with the packet as its context.

        * A successful send yields every message in its result.
        * A failed send passes the error to ``kwargs["error_catcher"]``.
        * A cancelled send is reported as a ``TimedOut`` that describes the
          packet.
        """
        error_catcher = kwargs["error_catcher"]

        async with hp.ResultStreamer(
            sender.stop_fut, error_catcher=silence_errors, name="Item::write_messages[streamer]"
        ) as streamer:
            for original, packet in packets:
                await streamer.add_coroutine(
                    self.do_send(sender, original, packet, kwargs), context=packet
                )
            streamer.no_more_work()

            async for result in streamer:
                if result.successful:
                    for msg in result.value:
                        yield msg
                else:
                    exc = result.value
                    pkt = result.context
                    if isinstance(exc, asyncio.CancelledError):
                        hp.add_error(
                            error_catcher,
                            TimedOut(
                                "Message was cancelled",
                                sent_pkt_type=pkt.pkt_type,
                                serial=pkt.serial,
                                source=pkt.source,
                                sequence=pkt.sequence,
                            ),
                        )
                    else:
                        hp.add_error(error_catcher, exc)

    async def do_send(self, sender, original, packet, kwargs):
        """
        Send one packet with ``sender.send_single`` and return its result.

        The send runs inside ``kwargs["limit"]`` when it is truthy; otherwise
        inside the module-level ``no_limit``, which does nothing. The timeouts
        default to 10 and ``no_retry`` defaults to False.
        """
        async with kwargs.get("limit") or no_limit:
            return await sender.send_single(
                original,
                packet,
                timeout=kwargs.get("message_timeout", 10),
                no_retry=kwargs.get("no_retry", False),
                broadcast=kwargs.get("broadcast"),
                connect_timeout=kwargs.get("connect_timeout", 10),
            )