forked from deepdrive/deepdrive-api
/
server.py
351 lines (314 loc) · 13.4 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
from __future__ import (absolute_import, division,
print_function, unicode_literals)
import json
import simplejson
import time
from deepdrive_api.client import get_action
from future.builtins import (dict, input, str)
import zmq
import pyarrow
from gym import spaces
import deepdrive_api.constants as c
import deepdrive_api.methods as m
from deepdrive_api import logs
log = logs.get_log(__name__)

# Bind address of the server's ZMQ socket; '*' binds on all interfaces.
CONN_STRING = 'tcp://*:{}'.format(c.API_PORT)

# Sim start kwargs that Server.remove_blacklisted_params strips (with a
# warning) from start requests before they reach the sim.
BLACKLIST_PARAMS = [
    # This process is the server, so the sim is always local to it and
    # remote to whichever client is talking to it.
    'is_remote_client',

    # Only meant for sharing a local tf session on the same GPU. Distributed
    # tf sessions are not implemented, and would probably not be passed
    # through here anyway.
    'sess',
]

# Param name -> reason it may not be set by challenge submissions.
# NOTE(review): not referenced by the Server class in this module.
CHALLENGE_BLACKLIST_PARAMS = {
    'env_id': 'Only one gym env',
    'max_steps': 'Evaluation duration is standard across submissions',
    'use_sim_start_command': 'Cannot pass parameters to Unreal',
    # TODO: Step timeout and variable step duration less than threshold
    'fps': 'Step duration is capped',
    'driving_style': 'Modifies reward function',
    'enable_traffic': 'Changes difficulty of scenario',
    'ego_mph': ('Used by in-game throttle PID, '
                'submissions must control their own throttle'),
}
class Server(object):
    """Deepdrive server process that runs on the same machine as Unreal Engine.

    self.sim is an OpenAI gym environment factory which creates a new gym
    environment on start().

    Simple ZMQ / pyarrow server that runs the deepdrive gym environment
    locally, and which communicates with Unreal locally via shared memory and
    localhost.
    """

    def __init__(self, sim, json_mode: bool = False, sim_args: dict = None):
        """
        :param sim: sim is a module reference to deepdrive.sim, i.e.
            https://github.com/deepdrive/deepdrive/tree/e114f9f053afe20d5a1478167d3f3c1f180fd279/sim
            Yes, this is a circular runtime reference and does not allow
            servers to be written in other languages, but I wanted to
            keep the client and server implementations together so client
            implementations in other languages would be able to reference
            everything here in one place.
        :param json_mode: Allows sending / receiving all data in json to avoid
            dependency on pyarrow. Sensor data will be omitted in this case.
        :param sim_args: Sim args configured on the server side. This
            precludes clients from configuring the environment for situations
            where some standardized sim is expected, i.e. leaderboard evals,
            challenges, etc...
        """
        self.sim = sim
        self.sim_args = sim_args
        self.json_mode = json_mode
        self.socket = None
        self.context = None
        self.env = None
        # str(type(...)) of values already reported as unserializable, so
        # each type is only warned about once.
        self.serialization_errors = set()
        # Once set, client gets a few seconds to close, then we force close
        self.should_close_time: float = 0

    def create_socket(self):
        """
        (Re)create the ZMQ context and bind a PAIR socket to CONN_STRING.

        Any previously created socket is closed and any previous context
        destroyed first.

        :return: The newly bound socket (also stored on self.socket).
        """
        if self.socket is not None:
            log.info('Closed server socket')
            self.socket.close()
        if self.context is not None:
            log.info('Destroyed context')
            self.context.destroy()
        self.context = zmq.Context()
        socket = self.context.socket(zmq.PAIR)
        # NOTE: with these timeouts disabled, recv/send presumably block
        # indefinitely, so the zmq.error.Again handler in run() would only be
        # hit if they are re-enabled.
        # socket.RCVTIMEO = c.API_TIMEOUT_MS
        # socket.SNDTIMEO = c.API_TIMEOUT_MS
        socket.bind(CONN_STRING)
        self.socket = socket
        return socket

    def run(self):
        """Serve client requests until dispatch() reports the sim closed."""
        self.create_socket()
        log.info('Environment server started at %s', CONN_STRING)
        done = False
        while not done:
            try:
                done = self.dispatch()
            except zmq.error.Again:
                # Send / receive timed out: rebind and wait for a client.
                log.info('Waiting for client')
                self.create_socket()

    def dispatch(self):
        """
        Waits for a message from the client, deserializes, routes to the
        appropriate method, and sends a serialized response.

        :return: True when the server loop should stop (the sim was closed),
            False otherwise.
        :raises RuntimeError: If the response could not be serialized.
        """
        request = self._recv_request()
        if request is None:
            # Empty message: nothing to answer, keep serving.
            return False
        method, args, kwargs = request

        done = False
        if self.env is None and method != m.START:
            resp = 'No environment started, please send start request'
            log.error('Client sent request with no environment started')
        elif method == m.CLOSE:
            self.env.close()
            resp = dict(closed_sim=True)
            done = True
        elif self.env is not None and self.env.unwrapped.should_close:
            resp, done = self._get_should_close_response(
                is_step=(method == m.STEP))
        elif method == m.START:
            resp = self.handle_start_sim_request(kwargs)
        elif method == m.STEP:
            if self.json_mode:
                action = get_action(**kwargs)
            else:
                action = args[0]
            resp = self.get_step_response(action)
        elif method == m.RESET:
            resp = dict(reset_response=self.env.reset())
        elif method == m.ACTION_SPACE:
            resp = self.serialize_space(self.env.action_space)
        elif method == m.OBSERVATION_SPACE:
            # Previously answered with the action space by mistake.
            resp = self.serialize_space(self.env.observation_space)
        elif method == m.REWARD_RANGE:
            resp = self.env.reward_range
        elif method == m.METADATA:
            resp = self.env.metadata
        elif method == m.CHANGE_CAMERAS:
            resp = self.env.unwrapped.change_cameras(*args, **kwargs)
        else:
            log.error('Invalid API method')
            resp = 'Invalid API method'

        self._send_response(resp)
        return done

    def _recv_request(self):
        """Receive one request; return (method, args, kwargs) or None if empty."""
        if self.json_mode:
            msg = self.socket.recv_json()
            if not msg:
                log.error('Received empty message, skipping')
                return None
            return msg['method'], msg['args'], msg['kwargs']
        msg = self.socket.recv()
        if not msg:
            log.error('Received empty message, skipping')
            return None
        method, args, kwargs = pyarrow.deserialize(msg)
        return method, args, kwargs

    def _send_response(self, resp):
        """Serialize resp and send it to the client over self.socket."""
        serialized = self.serialize(resp)
        if serialized is None:
            raise RuntimeError('Could not serialize response. '
                               'Check above for details')
        if self.json_mode:
            self.socket.send_string(serialized)
        else:
            self.socket.send(serialized.to_buffer())

    def _get_should_close_response(self, is_step):
        """
        Build the response sent while the environment reports should_close.

        The first such request starts a 3 second grace period in which the
        client can send close; after it elapses the env is closed here.

        :param is_step: Whether the request being answered is a step request.
        :return: (resp, done) where done tells the server loop to stop.
        """
        done = False
        if self.should_close_time == 0:
            self.should_close_time = time.time() + 3
        elif time.time() > self.should_close_time:
            self.env.close()
            done = True
        if is_step:
            # Separate name so the step's done flag (always True here) does
            # not clobber the server-loop flag above, which would stop the
            # server before the grace period / env.close() took place.
            obs, reward, step_done, info = None, 0, True, {'closed': True}
            if self.json_mode:
                resp = self.get_json_step_response(obs, reward, step_done,
                                                   info)
            else:
                resp = obs, reward, step_done, info
        else:
            resp = dict(closed_sim=True)
        return resp, done

    def get_step_response(self, action):
        """
        Step the env with action and build the step response.

        :return: The env's step() result as-is, or in json_mode a dict with
            a filtered observation (see get_json_step_response).
        """
        resp = self.env.step(action)
        if self.json_mode:
            obs, reward, done, info = resp
            if obs:
                obs = self.get_filtered_observation(obs)
            else:
                obs = None
            resp = self.get_json_step_response(obs, reward, done, info)
        return resp

    @staticmethod
    def get_json_step_response(obs, reward, done, info):
        """Pack a gym step tuple into a dict for json responses."""
        resp = dict(
            observation=obs,
            reward=reward,
            done=done,
            info=info,
        )
        return resp

    def handle_start_sim_request(self, kwargs):
        """
        Start the sim and store the resulting gym env on self.env.

        :param kwargs: Sim start kwargs sent by the client. They are used as
            the sim args only when the server has no sim_args of its own.
            Blacklisted params are removed from this dict in place.
        :return: dict telling the client whether the sim was configured
            locally (server side) or remotely (by the client).
        """
        if self.sim_args is not None:
            sim_args = self.sim_args
            server_type = 'locally_configured'
            # NOTE(review): checks 'map' in kwargs but reads sim_args['map']
            # (KeyError if the server-side args lack 'map'), and mutates
            # self.sim_args so the override persists across starts. Confirm
            # both are intended.
            if 'path_follower' in kwargs and \
                    kwargs['path_follower'] and 'map' in kwargs and \
                    sim_args['map'] != '':
                # Hack to deal with release / request bug in sim on new maps
                sim_args['path_follower'] = kwargs['path_follower']
        else:
            sim_args = kwargs
            server_type = 'remotely_configured'
        self.remove_blacklisted_params(kwargs)
        self.env = self.sim.start(**sim_args)
        ret = dict(server_started=dict(type=server_type))
        return ret

    def serialize(self, resp):
        """
        Serialize a response for the wire.

        :return: A JSON string in json_mode (NaN / inf become null), otherwise
            the pyarrow serialized object, or None if it can't be serialized.
        """
        if self.json_mode:
            return simplejson.dumps(resp, ignore_nan=True)
        return self.serialize_pyarrow(resp)

    @staticmethod
    def get_filtered_observation(obs):
        """
        Convert an observation into plain Python values for json_mode.

        Array-valued fields are converted with .tolist(); cameras are skipped.

        :param obs: Observation dict from the environment's step().
        :return: dict of JSON-serializable values.
        """
        coll = obs['last_collision']
        acceleration = obs['acceleration'].tolist()
        filtered = dict(
            acceleration=acceleration,
            # Misspelled key kept for backward compatibility: existing clients
            # may still read it.
            accerlation=acceleration,
            angular_acceleration=obs['angular_acceleration'].tolist(),
            angular_velocity=obs['angular_velocity'].tolist(),
            brake=obs['brake'],
            # Skipping cameras for now (base64??)
            capture_timestamp=obs['capture_timestamp'],
            dimension=obs['dimension'].tolist(),
            distance_along_route=obs['distance_along_route'],
            distance_to_center_of_lane=obs['distance_to_center_of_lane'],
            distance_to_next_agent=obs['distance_to_next_agent'],
            distance_to_next_opposing_agent=obs[
                'distance_to_next_opposing_agent'],
            distance_to_prev_agent=obs['distance_to_prev_agent'],
            episode_return=obs['episode_return'],
            forward_vector=obs['forward_vector'].tolist(),
            handbrake=obs['handbrake'],
            is_game_driving=obs['is_game_driving'],
            is_passing=obs['is_passing'],
            is_resetting=obs['is_resetting'],
            lap_number=obs['lap_number'],
            last_collision=dict(
                collidee_velocity=coll['collidee_velocity'].tolist(),
                # Previously filled from collision_normal by mistake.
                collision_location=coll['collision_location'].tolist(),
                collision_normal=coll['collision_normal'].tolist(),
                time_since_last_collision=coll['time_since_last_collision'],
                time_stamp=coll['time_stamp'],
                time_utc=coll['time_utc'],
            ),
            position=obs['position'].tolist(),
            right_vector=obs['right_vector'].tolist(),
            rotation=obs['rotation'].tolist(),
            route_length=obs['route_length'],
            scenario_finished=obs['scenario_finished'],
            speed=obs['speed'],
            steering=obs['steering'],
            throttle=obs['throttle'],
            up_vector=obs['up_vector'].tolist(),
            velocity=obs['velocity'].tolist(),
            world=obs['world'],
        )
        return filtered

    def serialize_pyarrow(self, resp):
        """
        Serialize resp with pyarrow, stripping unserializable dict values.

        On each SerializationCallbackError the offending values are replaced
        in resp (in place) and serialization is retried.

        :return: The serialized object, or None if the offending value could
            not be removed (e.g. it is a top-level value or a list item).
            Previously that case retried forever.
        """
        while True:
            try:
                return pyarrow.serialize(resp)
            except pyarrow.lib.SerializationCallbackError as e:
                msg = str(e)
                if not self.remove_unserializeables(resp, msg):
                    log.error('Could not remove unserializable data from '
                              'response: %s', msg)
                    return None

    def remove_unserializeables(self, x, msg):
        """
        Make an object serializable by pyarrow after an error by checking for
        the type in msg. Pyarrow doesn't have a great API for serializable
        types, so doing this as a stop gap for now.

        We should avoid sending unserializable data to pyarrow, but at the
        same time not totally fail when we do. Errors will be printed when
        unserializable data is first encountered, so that we can go back and
        remove when it's reasonable.

        This will not remove a list or tuple item, but will recursively search
        through lists and tuples for dicts with unserializable values.

        :param x: Object from which to remove elements that pyarrow cannot
            serialize. Modified in place.
        :param msg: The error message returned by pyarrow during serialization
        :return: True if at least one value was replaced, False otherwise.
        """
        removed = False
        if isinstance(x, dict):
            for k, v in x.items():
                value_type = str(type(v))
                if value_type in msg:
                    if value_type not in self.serialization_errors:
                        self.serialization_errors.add(value_type)
                        log.warning('Unserializable type %s Not sending to '
                                    'client!', value_type)
                    # Overwriting an existing key does not resize the dict,
                    # so this is safe while iterating.
                    x[k] = '[REMOVED!] %s was not serializable on server. ' \
                           'Avoid sending unserializable data for best ' \
                           'performance.' % value_type
                    removed = True
                if isinstance(v, (dict, list, tuple)):
                    # No __iter__ as numpy arrays are too big for this
                    removed = self.remove_unserializeables(v, msg) or removed
        elif isinstance(x, (tuple, list)):
            for e in x:
                removed = self.remove_unserializeables(e, msg) or removed
        return removed

    @staticmethod
    def remove_blacklisted_params(kwargs):
        """Delete BLACKLIST_PARAMS keys from kwargs in place, with a warning."""
        for key in list(kwargs):
            if key in BLACKLIST_PARAMS:
                log.warning('Removing {key} from sim start args, not'
                            ' relevant to remote clients'.format(key=key))
                del kwargs[key]

    @staticmethod
    def serialize_space(space):
        """
        Describe a gym space as a dict.

        :raises RuntimeError: For anything other than a spaces.Box.
        """
        space_type = type(space)
        if isinstance(space, spaces.Box):
            resp = {'type': str(space_type),
                    'low': space.low,
                    'high': space.high,
                    'dtype': str(space.dtype)
                    }
        else:
            raise RuntimeError('Space of type "%s" value "%r" not supported'
                               % (str(space_type), space))
        return resp
def start(sim, json_mode=False, sim_path=None, sim_args: dict = None):
    """
    Create a Server for sim and run it (blocks until the server stops).

    :param sim: Module reference to deepdrive.sim (see Server.__init__).
    :param json_mode: Exchange json instead of pyarrow-serialized data.
    :param sim_path: If given, passed to
        deepdrive_api.utils.check_pyarrow_compatibility before serving.
    :param sim_args: Server-side sim args; when set, clients cannot
        configure the sim themselves.
    """
    from deepdrive_api import utils as api_utils

    needs_compat_check = sim_path is not None
    if needs_compat_check:
        api_utils.check_pyarrow_compatibility(sim_path)

    env_server = Server(sim=sim, json_mode=json_mode, sim_args=sim_args)
    env_server.run()