-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
93 lines (75 loc) · 2.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import time
import logging
import traceback
import asyncio
import signal
import numpy as np
import zmq
from zmq.asyncio import Context
from commons.common_zmq import recv_array_with_json, initialize_subscriber, initialize_publisher
from commons.configuration_manager import ConfigurationManager
# from src.utilities.transformer import Transformer
from utilities.recorder import Recorder
async def main(context: Context):
    """Relay expert actions from the data queue back out as control commands.

    Loads configuration via ConfigurationManager, subscribes to
    ``conf.data_queue_port`` and publishes on ``conf.controls_queue_port``.
    For every message received it records the frame, telemetry and expert
    action with the Recorder and publishes a copy of the expert action as the
    next controls. Failures while recording/sending a single message are
    printed and the loop continues; any other error (or cancellation) ends
    the loop, after which both sockets are closed and the recorded session
    is saved.

    Args:
        context: zmq.asyncio context used to create both sockets.
    """
    config_manager = ConfigurationManager()
    conf = config_manager.config
    recorder = Recorder(conf)
    data_queue = context.socket(zmq.SUB)
    controls_queue = context.socket(zmq.PUB)
    try:
        await initialize_subscriber(data_queue, conf.data_queue_port)
        await initialize_publisher(controls_queue, conf.controls_queue_port)
        while True:
            frame, data = await recv_array_with_json(queue=data_queue)
            # NOTE(review): assumes `data` is a (telemetry, expert_action) pair,
            # or None when no JSON payload arrived -- confirm against
            # recv_array_with_json. Guarding before unpacking keeps a None
            # payload from raising TypeError and ending the whole loop.
            telemetry, expert_action = data if data is not None else (None, None)
            if frame is None or telemetry is None or expert_action is None:
                logging.info("None data")
                continue
            try:
                next_controls = expert_action.copy()
                # Non-blocking pause: time.sleep() here would freeze the whole
                # event loop, including the socket I/O this coroutine awaits.
                await asyncio.sleep(0.01)
                recorder.record_full(frame, telemetry, expert_action, next_controls)
                # zmq.asyncio sockets return an awaitable from send_json(); await
                # it so the send completes and its errors surface here.
                await controls_queue.send_json(next_controls)
            except asyncio.CancelledError:
                # Let shutdown propagate: before Python 3.8 CancelledError is an
                # Exception subclass and would otherwise be swallowed below.
                raise
            except Exception as ex:
                # Per-message failures are reported and the loop keeps running.
                print("Sending exception: {}".format(ex))
                traceback.print_tb(ex.__traceback__)
    except Exception as ex:
        print("Exception: {}".format(ex))
        traceback.print_tb(ex.__traceback__)
    finally:
        data_queue.close()
        controls_queue.close()
        recorder.save_session_with_expert()
def cancel_tasks(loop):
    """Request cancellation of every unfinished task on *loop*.

    Intended as a ``loop.add_signal_handler`` callback, which asyncio runs
    inside the loop itself, so tasks can be cancelled directly.

    Args:
        loop: the event loop whose tasks should be cancelled.
    """
    # asyncio.Task.all_tasks was removed in Python 3.9; asyncio.all_tasks
    # (3.7+) replaces it. Fall back to the old API on older interpreters.
    all_tasks = getattr(asyncio, "all_tasks", None) or asyncio.Task.all_tasks
    for task in all_tasks(loop):
        task.cancel()
def signal_cancel_tasks(*args):
    """``signal.signal`` handler that cancels every task on the current loop.

    Used instead of ``loop.add_signal_handler`` because that API is not
    available on Windows. ``*args`` receives ``(signum, frame)`` from the
    signal module and is ignored.
    """
    loop = asyncio.get_event_loop()
    if loop.is_closed():
        # Signal arrived after shutdown: there is nothing left to cancel.
        return

    def _cancel_all():
        # asyncio.Task.all_tasks was removed in Python 3.9; asyncio.all_tasks
        # (3.7+) replaces it. Fall back to the old API on older interpreters.
        all_tasks = getattr(asyncio, "all_tasks", None) or asyncio.Task.all_tasks
        for task in all_tasks(loop):
            task.cancel()

    # Python-level signal handlers run in the main thread between bytecodes,
    # outside the loop's normal scheduling. call_soon_threadsafe() both queues
    # the cancellation and wakes a loop that is blocked waiting for I/O, so
    # shutdown does not stall until the next message arrives.
    loop.call_soon_threadsafe(_cancel_all)
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
    # Create and install the loop explicitly: calling asyncio.get_event_loop()
    # with no running loop is deprecated on recent Python versions.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # loop.add_signal_handler() is not implemented on Windows:
    # loop.add_signal_handler(signal.SIGINT, cancel_tasks, loop)
    # loop.add_signal_handler(signal.SIGTERM, cancel_tasks, loop)
    # so the portable signal module is used instead.
    signal.signal(signal.SIGINT, signal_cancel_tasks)
    signal.signal(signal.SIGTERM, signal_cancel_tasks)
    context = zmq.asyncio.Context()
    try:
        loop.run_until_complete(main(context))
    except asyncio.CancelledError:
        # Raised when a signal handler cancelled main(); main()'s own finally
        # block runs its cleanup before the cancellation reaches here.
        logging.info("Shutdown requested, main task cancelled")
    except Exception as ex:
        logging.error("Base interruption: {}".format(ex))
        traceback.print_tb(ex.__traceback__)
    finally:
        # Release ZeroMQ resources first, then close the loop they ran on.
        context.destroy()
        loop.close()