forked from learningequality/ka-lite
-
Notifications
You must be signed in to change notification settings - Fork 2
/
api_client.py
306 lines (241 loc) · 12.8 KB
/
api_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
"""
"""
import json
import logging
import re
import uuid

from .utils import get_serialized_models, save_serialized_models, get_device_counters, deserialize
from .models import *
from ..api_client import BaseClient
from ..devices.api_client import RegistrationClient
from ..devices.models import *
from fle_utils.platforms import get_os_name
class SyncClient(BaseClient):
""" This is for the distributed server, for establishing a client session with
the central server. Over that session, syncing can occur in multiple requests.
Note that in the future, this object may be used to sync
between two distributed servers (i.e. peer-to-peer sync)!"""
session = None
def post(self, path, payload={}, *args, **kwargs):
kwargs.setdefault("timeout", 120)
if self.session and self.session.client_nonce:
payload["client_nonce"] = self.session.client_nonce
return super(SyncClient, self).post(path, payload, *args, **kwargs)
def get(self, path, payload={}, *args, **kwargs):
# Set a high timeout since the server can be heavily loaded at times
kwargs.setdefault("timeout", 120)
if self.session and self.session.client_nonce:
payload["client_nonce"] = self.session.client_nonce
# add a random parameter to ensure the request is not cached
return super(SyncClient, self).get(path, payload, *args, **kwargs)
def start_session(self):
"""A 'session' to exchange data"""
if self.verbose:
print "\nCLIENT: start_session"
if self.session:
self.close_session()
self.session = SyncSession()
if self.verbose:
print "CLIENT: start_session, request #1"
# Request one: validate me as a sessionable partner
(self.session.client_nonce,
self.session.client_device,
data) = self.validate_me_on_server()
# Able to create session
signature = data.get("signature", "")
# Once again, we assume that (currently) the central server's version is >= ours,
# We just store what we can.
own_device = self.session.client_device
session = deserialize(data["session"], src_version=None, dest_version=own_device.get_version()).next().object
self.session.server_nonce = session.server_nonce
self.session.server_device = session.server_device
if not session.verify_server_signature(signature):
raise Exception("Sever session signature did not match.")
if session.client_nonce != self.session.client_nonce:
raise Exception("Client session nonce did not match.")
if session.client_device != self.session.client_device:
raise Exception("Client session device did not match.")
if self.require_trusted and not session.server_device.is_trusted():
raise Exception("The server is not trusted, don't make a session with THAT.")
self.session.verified = True
self.session.timestamp = session.timestamp
self.session.ip = self.parsed_url.netloc
self.session.save()
if self.verbose:
print "CLIENT: start_session, request #2"
# Request two: create your own session, and
# report the result back to me for validation
r = self.post("session/create", {
"client_nonce": self.session.client_nonce,
"client_device": self.session.client_device.pk,
"server_nonce": self.session.server_nonce,
"server_device": self.session.server_device.pk,
"signature": self.session.sign(),
})
if r.status_code == 200:
return "success"
else:
return r
def validate_me_on_server(self, recursive_retry=False):
client_nonce = uuid.uuid4().hex
client_device = Device.get_own_device()
r = self.post("session/create", {
"client_nonce": client_nonce,
"client_device": client_device.pk,
"client_version": client_device.get_version(),
"client_os": get_os_name(),
})
raw_data = r.content
try:
data = json.loads(raw_data)
except ValueError as e:
z = re.search(r'exception_value">([^<]+)<', unicode(raw_data), re.MULTILINE)
if z:
raise Exception("Could not load JSON\n; server error=%s" % z.group(1))
else:
raise Exception("Could not load JSON\n; raw content=%s" % raw_data)
# Happens if the server reports an error
if data.get("error"):
# This happens when a device points to a new central server,
# either because it changed, or because it self-registered.
if not recursive_retry and "Client device matching id could not be found." in data["error"]:
resp = RegistrationClient().register(prove_self=True)
if resp.get("error"):
raise Exception("Error [code=%s]: %s" % (resp.get("code",""), resp.get("error","")))
elif resp.get("code") != "registered":
raise Exception("Unexpected code: '%s'" % resp.get("code",""))
# We seem to have succeeded registering through prove_self;
# let's try to validate again (but without retrying again, lest we loop forever!)
return self.validate_me_on_server(recursive_retry=True)
raise Exception(data.get("error", ""))
return (client_nonce, client_device, data)
def close_session(self):
if self.verbose:
print "\nCLIENT: close_session"
if not self.session:
return
self.post("session/destroy", {
"client_nonce": self.session.client_nonce
})
self.session = None
return "success"
def get_server_device_counters(self):
r = self.get("device/counters")
try:
data = json.loads(r.content or "{}")
except ValueError:
raise Exception("Did not receive proper JSON data from server. URL: {}\n\nActual data:\n\n{}".format(r.url, r.content))
if "error" in data:
raise Exception("Server error in retrieving counters: " + data["error"])
return data.get("device_counters", {})
def get_client_device_counters(self):
return get_device_counters(zone=self.session.client_device.get_zone())
def sync_device_records(self):
if self.verbose:
print "\nCLIENT: sync_device_records"
server_counters = self.get_server_device_counters()
client_counters = self.get_client_device_counters()
if self.verbose:
print client_counters, server_counters
print "COUNTERS: ([D]istributed, [C]entral)"
for device in set(server_counters.keys()).union(client_counters.keys()):
print "\t", device[0:5], "D%d" % client_counters.get(device, 0), "C%d" % server_counters.get(device, 0)
devices_to_download = []
devices_to_upload = []
counters_to_download = {}
counters_to_upload = {}
# loop through the devices we have locally
for device_id in client_counters:
if device_id not in server_counters:
devices_to_upload.append(device_id)
counters_to_upload[device_id] = 0
elif client_counters[device_id] > server_counters[device_id]:
counters_to_upload[device_id] = server_counters[device_id]
# loop through the devices the server has told us about
for device_id in server_counters:
if device_id not in client_counters:
devices_to_download.append(device_id)
counters_to_download[device_id] = 0
elif server_counters[device_id] > client_counters[device_id]:
counters_to_download[device_id] = client_counters[device_id]
if self.verbose:
print "CLIENT: devices_to_upload = %r" % devices_to_upload
print "CLIENT: devices_to_download = %r" % devices_to_download
response = json.loads(self.post("device/download", {"devices": devices_to_download}).content)
# As usual, we're deserializing from the central server, so we assume that what we're getting
# is "smartly" dumbed down for us. We don't need to specify the src_version, as it's
# pre-cleaned for us.
download_results = save_serialized_models(response.get("devices", "[]"), increment_counters=False, verbose=self.verbose)
# BUGFIX(bcipolli) metadata only gets created if models are
# streamed; if a device is downloaded but no models are downloaded,
# metadata does not exist. Let's just force it here.
for device_id in devices_to_download: # force
try:
d = Device.all_objects.get(id=device_id) # even do deleted devices.
except Exception as e:
logging.error("Exception locating device %s for metadata creation: %s" % (device_id, e))
continue
if not d.get_counter_position(): # this would be nonzero if the device sync'd models
d.set_counter_position(counters_to_download[device_id])
self.session.models_downloaded += download_results["saved_model_count"]
self.display_and_count_errors(download_results, context_name="downloading devices")
self.session.save()
# TODO(jamalex): upload local devices as well? only needed once we have P2P syncing
return (counters_to_download, counters_to_upload)
def display_and_count_errors(self, data, context_name="syncing data"):
if "error" in data:
print "Server error(s) in %s: %s" % (context_name, data["error"])
self.session.errors += 1
if "exceptions" in data:
print "Server exceptions(s) in %s: %s" % (context_name, data["exceptions"])
self.session.errors += 1
def sync_models(self):
"""
This method first syncs device counters and device objects, so that the two computers
can determine who has what and, in comparison, what it needs to request.
Then, it uses those device records to partially download and partially upload.
Not all at once--that would be less robust!
Afterwards, it returns summary statistics about what was synced, but no specific
state--this allows it to assume nothing for the next go-around (as this method
is called in a loop elsewhere)
"""
if self.verbose:
print "\nCLIENT: sync_models"
counters_to_download, counters_to_upload = self.sync_device_records()
# Download (but prepare for errors--both thrown and unthrown!)
download_results = {
"saved_model_count" : 0,
"unsaved_model_count" : 0,
}
try:
if self.verbose:
print "CLIENT: sync_models, downloading"
response = json.loads(self.post("models/download", {"device_counters": counters_to_download}).content)
# As usual, we're deserializing from the central server, so we assume that what we're getting
# is "smartly" dumbed down for us. We don't need to specify the src_version, as it's
# pre-cleanaed for us.
download_results.update(save_serialized_models(response.get("models", "[]"), verbose=self.verbose))
self.session.models_downloaded += download_results["saved_model_count"]
self.display_and_count_errors(download_results, context_name="downloading models")
except Exception as e:
print "Exception downloading models (in api_client): %s, %s, %s" % (e.__class__.__name__, e.message, e.args)
download_results["error"] = e
self.session.errors += 1
# Upload (but prepare for errors--both thrown and unthrown!)
upload_results = {
"saved_model_count" : 0,
"unsaved_model_count" : 0,
}
try:
if self.verbose:
print "CLIENT: sync_models, uploading"
# By not specifying a dest_version, we're sending everything.
# Again, this is OK because we're sending to the central server.
response = self.post("models/upload", {"models": get_serialized_models(counters_to_upload, verbose=self.verbose)})
upload_results.update(json.loads(response.content))
self.session.models_uploaded += upload_results["saved_model_count"]
self.display_and_count_errors(upload_results, context_name="uploading models")
except Exception as e:
print "Exception uploading models (in api_client): %s, %s, %s" % (e.__class__.__name__, e.message, e.args)
upload_results["error"] = e
self.session.errors += 1
self.session.save()
return {"download_results": download_results, "upload_results": upload_results}