/
publish.py
132 lines (103 loc) · 4.13 KB
/
publish.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import MutableMapping
from dask.utils import stringify
from distributed.utils import log_errors
class PublishExtension:
    """An extension for the scheduler to manage collections

    Registers the following scheduler handlers:

    *  publish_list
    *  publish_put
    *  publish_get
    *  publish_delete
    *  publish_wait_flush

    and the stream handler:

    *  publish_flush_batched_send
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler
        # Maps dataset name -> {"data": ..., "keys": ...}, as stored by put().
        self.datasets = {}

        handlers = {
            "publish_list": self.list,
            "publish_put": self.put,
            "publish_get": self.get,
            "publish_delete": self.delete,
            "publish_wait_flush": self.flush_wait,
        }
        stream_handlers = {
            "publish_flush_batched_send": self.flush_receive,
        }

        self.scheduler.handlers.update(handlers)
        self.scheduler.stream_handlers.update(stream_handlers)
        # One event per flush uid; created on first access by either
        # flush_receive() or flush_wait(), whichever runs first.
        self._flush_received = defaultdict(asyncio.Event)

    def flush_receive(self, uid, **kwargs):
        """Mark the flush identified by ``uid`` as received."""
        self._flush_received[uid].set()

    async def flush_wait(self, uid):
        """Wait until ``flush_receive`` has been called for ``uid``."""
        await self._flush_received[uid].wait()
        # The event has been set and awaited, so it is no longer needed.
        # Dropping it keeps the mapping from growing by one entry per flush.
        self._flush_received.pop(uid, None)

    @log_errors
    def put(self, keys=None, data=None, name=None, override=False, client=None):
        """Store a dataset under ``name``.

        Raises
        ------
        KeyError
            If ``name`` is already present and ``override`` is false.
        """
        if not override and name in self.datasets:
            # An f-string is used instead of ``%`` formatting because a tuple
            # name would be unpacked as the format arguments and raise
            # TypeError instead of the intended KeyError.
            raise KeyError(f"Dataset {name} already exists")
        self.scheduler.client_desires_keys(keys, f"published-{stringify(name)}")
        self.datasets[name] = {"data": data, "keys": keys}
        return {"status": "OK", "name": name}

    @log_errors
    def delete(self, name=None):
        """Remove the dataset ``name`` and release its keys.

        An unknown ``name`` is treated as a dataset with no keys.
        """
        out = self.datasets.pop(name, {"keys": []})
        self.scheduler.client_releases_keys(out["keys"], f"published-{stringify(name)}")

    @log_errors
    def list(self, *args):
        """Return the dataset names, sorted by their string form."""
        # key=str allows names of mixed types to be ordered together.
        return sorted(self.datasets, key=str)

    @log_errors
    def get(self, name=None, client=None):
        """Return the stored entry for ``name``, or ``None`` if absent."""
        return self.datasets.get(name)
class Datasets(MutableMapping):
    """A dict-like wrapper around :class:`Client` dataset methods.

    Item access, assignment, deletion, iteration and ``len()`` are forwarded
    to the wrapped client's ``get_dataset``, ``publish_dataset``,
    ``unpublish_dataset`` and ``list_datasets`` methods. Operations that
    cannot be awaited raise ``TypeError`` when the client is asynchronous.

    Parameters
    ----------
    client : distributed.client.Client

    """

    __slots__ = ("_client",)

    def __init__(self, client):
        self._client = client

    def __getitem__(self, key):
        # When client is asynchronous, it returns a coroutine
        return self._client.get_dataset(key)

    def __setitem__(self, key, value):
        if self._client.asynchronous:
            # 'await obj[key] = value' is not supported by Python as of 3.8
            raise TypeError(
                "Can't use 'client.datasets[name] = value' when client is "
                "asynchronous; please use 'client.publish_dataset(name=value)' instead"
            )
        self._client.publish_dataset(value, name=key)

    def __delitem__(self, key):
        if self._client.asynchronous:
            # 'await del obj[key]' is not supported by Python as of 3.8
            raise TypeError(
                "Can't use 'del client.datasets[name]' when client is asynchronous; "
                "please use 'client.unpublish_dataset(name)' instead"
            )
        return self._client.unpublish_dataset(key)

    def __iter__(self):
        # This is a plain function returning a generator, not a generator
        # function itself, so the check below runs when iter() is called
        # rather than on the first next(). It mirrors __aiter__.
        if self._client.asynchronous:
            raise TypeError(
                "Can't invoke iter() or 'for' on client.datasets when client is "
                "asynchronous; use 'async for' instead"
            )

        def _keys():
            # list_datasets() is still called lazily, on first iteration.
            yield from self._client.list_datasets()

        return _keys()

    def __aiter__(self):
        if not self._client.asynchronous:
            raise TypeError(
                "Can't invoke 'async for' on client.datasets when client is "
                "synchronous; use iter() or 'for' instead"
            )

        async def _():
            for key in await self._client.list_datasets():
                yield key

        return _()

    def __len__(self):
        if self._client.asynchronous:
            # 'await len(obj)' is not supported by Python as of 3.8
            raise TypeError(
                "Can't use 'len(client.datasets)' when client is asynchronous; "
                "please use 'len(await client.list_datasets())' instead"
            )
        return len(self._client.list_datasets())