-
Notifications
You must be signed in to change notification settings - Fork 68
/
coil_dataset.py
408 lines (307 loc) · 14.6 KB
/
coil_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
import os
import glob
import traceback
import collections
import sys
import math
import copy
import json
import random
import numpy as np
import torch
import cv2
from torch.utils.data import Dataset
from . import splitter
from . import data_parser
# TODO: Warning, maybe this does not need to be included everywhere.
from configs import g_conf
from coilutils.general import sort_nicely
def parse_remove_configuration(configuration):
    """
    Turn the REMOVE (splitting) configuration into a splitter name and its params.

    Args:
        configuration: None, a mapping, or an iterable of (key, value) pairs
            describing which measurements to remove.
    Returns:
        (name, params): ``name`` is 'remove' followed by '_<key>' for every key
        other than 'weights' and 'boost', in configuration order; ``params`` is
        an OrderedDict built from the configuration. ("None", None) is returned
        when configuration is None.
    """
    if configuration is None:
        return "None", None
    print('conf', configuration)
    params = collections.OrderedDict(configuration)
    # 'weights' and 'boost' are parameters, not part of the splitter name.
    ignored_keys = ('weights', 'boost')
    suffix = ''.join('_' + key for key in params if key not in ignored_keys)
    return 'remove' + suffix, params
def get_episode_weather(episode):
    """
    Read the weather id of an episode.

    Args:
        episode: path of the episode folder; it must contain 'metadata.json'
            with a 'weather' entry.
    Returns:
        int(metadata['weather']).
    """
    metadata_path = os.path.join(episode, 'metadata.json')
    with open(metadata_path) as metadata_file:
        weather = json.load(metadata_file)['weather']
    print(" WEATHER OF EPISODE ", weather)
    return int(weather)
class CoILDataset(Dataset):
    """The conditional imitation learning dataset.

    Each item pairs one camera image (central, left or right) with the float
    measurements of that frame. The list of image names and measurement dicts
    is built by ``_pre_load_image_folders`` and, when a preload name is given,
    cached in ``_preloads/<preload_name>.npy``.
    """

    def __init__(self, root_dir, transform=None, preload_name=None):
        """
        Args:
            root_dir: folder containing the ``episode_*`` sub-folders.
            transform: optional callable ``transform(iteration, img)`` applied
                to every image read in ``__getitem__``.
            preload_name: base name of the preload cache; None disables caching.
        """
        # Setting the root directory for this dataset
        self.root_dir = root_dir
        remove_config = g_conf.REMOVE
        # REMOVE may hold the *string* "None": compare by value, not identity.
        if remove_config is not None and remove_config != "None":
            name, self._remove_params = parse_remove_configuration(remove_config)
            # The removal rule is part of the cache name so that differently
            # filtered datasets never share a preload file.
            self.preload_name = (None if preload_name is None
                                 else preload_name + '_' + name)
            self._check_remove_function = getattr(splitter, name)
        else:
            self._check_remove_function = lambda _, __: False
            self._remove_params = []
            self.preload_name = preload_name
        print("preload Name ", self.preload_name)

        preload_path = (None if self.preload_name is None
                        else os.path.join('_preloads', self.preload_name + '.npy'))
        if preload_path is not None and os.path.exists(preload_path):
            print(" Loading from NPY ")
            # The cache is an object array of [image names, measurement dicts],
            # which NumPy only loads with allow_pickle=True. Unpickling executes
            # code, so only load caches written by _pre_load_image_folders.
            self.sensor_data_names, self.measurements = np.load(
                preload_path, allow_pickle=True)
            print(self.sensor_data_names)
        else:
            self.sensor_data_names, self.measurements = self._pre_load_image_folders(root_dir)
        print("preload Name ", self.preload_name)
        self.transform = transform
        # Number of images read so far; passed to the transform as its iteration.
        self.batch_read_number = 0

    def __len__(self):
        return len(self.measurements)

    def __getitem__(self, index):
        """
        Get item function used by the dataset loader.

        Args:
            index: position in ``self.sensor_data_names`` / ``self.measurements``.
        Returns:
            A dict mapping every measurement name to a 1-element float tensor,
            plus 'rgb': the image as a float tensor divided by 255. If the image
            cannot be read, a blank sample is returned instead.
        """
        # Stored names end in '<episode>/<camera>_<number>.png'.
        name_parts = self.sensor_data_names[index].split('/')
        img_path = os.path.join(self.root_dir, name_parts[-2], name_parts[-1])
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None when the file is missing or unreadable.
            print("Blank IMAGE")
            return self._blank_measurements()

        if self.transform is not None:
            img = self.transform(self.batch_read_number, img)
        else:
            # cv2 gives (H, W, 3); move channels first.
            img = img.transpose(2, 0, 1)
        # NOTE(review): the original indentation was lost; this cast is assumed
        # to apply to both branches (i.e. the transform returns a NumPy array).
        img = torch.from_numpy(img.astype(np.float32)) / 255.

        measurements = self._measurements_to_tensors(self.measurements[index])
        measurements['rgb'] = img
        self.batch_read_number += 1
        return measurements

    @staticmethod
    def _measurements_to_tensors(measurement):
        """Return a new dict with every value wrapped in a 1-element float tensor."""
        return {key: torch.from_numpy(np.asarray([value, ])).float()
                for key, value in measurement.items()}

    def _blank_measurements(self):
        """Build the fallback sample used when an image cannot be read.

        Uses the keys of the first measurement, zeroes the controls and uses an
        all-zero (3, 88, 200) image. Controls are 1-element tensors like every
        other value so that, presumably, batches collate uniformly.
        """
        measurements = self._measurements_to_tensors(self.measurements[0])
        for control in ('steer', 'throttle', 'brake'):
            measurements[control] = torch.zeros(1)
        # NOTE(review): 88x200 is hard-coded; confirm it matches the transform output.
        measurements['rgb'] = torch.zeros(3, 88, 200)
        return measurements

    def is_measurement_partof_experiment(self, measurement_data):
        """Return True unless the configured remove function discards this measurement."""
        return not self._check_remove_function(measurement_data, self._remove_params)

    def _get_final_measurement(self, speed, measurement_data, angle,
                               directions, avaliable_measurements_dict):
        """
        Build the measurement dict for one camera view.

        Args:
            speed: speed returned by data_parser.get_speed.
            measurement_data: raw measurement dict loaded from JSON (not mutated).
            angle: camera angle in degrees; non-zero angles get augmented steering.
            directions: high-level command of this frame.
            avaliable_measurements_dict: maps training names to dataset names.
        Returns:
            The final measurement dict: the mapped measurements plus
            'speed_module', 'directions' and 'game_time'.
        """
        if angle != 0:
            # 3.6 * speed converts to km/h for the augmentation (assumes
            # get_speed returns m/s). Copy so the steering change stays local.
            measurement_augmented = self.augment_measurement(
                copy.copy(measurement_data), angle, 3.6 * speed,
                steer_name=avaliable_measurements_dict['steer'])
        else:
            # Only read below; the result is always a fresh dict.
            measurement_augmented = measurement_data

        if 'gameTimestamp' in measurement_augmented:
            time_stamp = measurement_augmented['gameTimestamp']
        else:
            time_stamp = measurement_augmented['elapsed_seconds']

        # Map every available measurement to the name used in training.
        final_measurement = {measurement: measurement_augmented[name_in_dataset]
                             for measurement, name_in_dataset
                             in avaliable_measurements_dict.items()}
        final_measurement['speed_module'] = speed / g_conf.SPEED_FACTOR
        final_measurement['directions'] = directions
        final_measurement['game_time'] = time_stamp
        return final_measurement

    def _pre_load_image_folders(self, path):
        """
        Pre load the image names and measurements of every episode.

        Stops once more than g_conf.NUMBER_OF_HOURS have been loaded; the last
        three measurements of each episode are skipped. Saves the result to
        ``_preloads/<preload_name>.npy`` when a preload name is set.

        Args:
            path: the dataset folder holding the ``episode_*`` folders.
        Returns:
            sensor_data_names: '<episode>/<Camera>RGB_<number>.png' per sample.
            float_dicts: the matching measurement dicts, one per sample.
        Raises:
            ValueError: if ``path`` contains no episode folders.
        """
        episodes_list = glob.glob(os.path.join(path, 'episode_*'))
        sort_nicely(episodes_list)
        if not episodes_list:
            raise ValueError("There are no episodes on the training dataset folder %s" % path)

        sensor_data_names = []
        float_dicts = []
        number_of_hours_pre_loaded = 0
        # (camera angle in degrees, image prefix): central, left and right.
        cameras = ((0, 'CentralRGB_'), (-30.0, 'LeftRGB_'), (30.0, 'RightRGB_'))

        for episode in episodes_list:
            print('Episode ', episode)
            if number_of_hours_pre_loaded > g_conf.NUMBER_OF_HOURS:
                # The number of wanted hours achieved
                break
            available_measurements_dict = data_parser.check_available_measurements(episode)

            measurements_list = glob.glob(os.path.join(episode, 'measurement*'))
            sort_nicely(measurements_list)
            if not measurements_list:
                print("EMPTY EPISODE")
                continue

            episode_name = episode.split('/')[-1]
            count_added_measurements = 0
            for measurement in measurements_list[:-3]:
                data_point_number = measurement.split('_')[-1].split('.')[0]
                with open(measurement) as f:
                    measurement_data = json.load(f)
                speed = data_parser.get_speed(measurement_data)
                directions = measurement_data['directions']
                for angle, prefix in cameras:
                    final_measurement = self._get_final_measurement(
                        speed, measurement_data, angle, directions,
                        available_measurements_dict)
                    # The remove configuration may exclude this sample.
                    if self.is_measurement_partof_experiment(final_measurement):
                        float_dicts.append(final_measurement)
                        rgb = prefix + data_point_number + '.png'
                        sensor_data_names.append(os.path.join(episode_name, rgb))
                        count_added_measurements += 1

            # Each kept camera sample counts separately; samples / 10 / 3600
            # presumably assumes a 10 Hz recording rate - TODO confirm.
            number_of_hours_pre_loaded += (count_added_measurements / 10.0) / 3600.0
            print(" Loaded ", number_of_hours_pre_loaded, " hours of data")

        os.makedirs('_preloads', exist_ok=True)
        if self.preload_name is not None:
            np.save(os.path.join('_preloads', self.preload_name), [sensor_data_names, float_dicts])
        return sensor_data_names, float_dicts

    def augment_directions(self, directions):
        """Replace a 2.0 command by 3.0, 4.0 or 5.0 with probability 20/101."""
        if directions == 2.0:
            if random.randint(0, 100) < 20:
                directions = random.choice([3.0, 4.0, 5.0])
        return directions

    def augment_steering(self, camera_angle, steer, speed):
        """
        Apply the steering physical equation to augment the lateral cameras steering.

        Args:
            camera_angle: the camera angle in degrees; positive angles decrease
                the steering, non-positive angles increase it.
            steer: the central steering.
            speed: the vehicle speed (callers in this class pass km/h).
        Returns:
            The augmented steering, with the correction capped at 0.3 and the
            result clipped to [-1, 1].
        """
        time_use = 1.0
        car_length = 6.0
        # Only magnitudes enter the formula; the sign of camera_angle picks the
        # direction of the correction below.
        speed = math.fabs(speed)
        rad_camera_angle = math.radians(math.fabs(camera_angle))
        # The + 0.05 keeps the denominator positive when speed is 0.
        val = g_conf.AUGMENT_LATERAL_STEERINGS * (
            math.atan((rad_camera_angle * car_length) / (time_use * speed + 0.05))) / math.pi
        correction = min(val, 0.3)
        if camera_angle > 0.0:
            steer -= correction
        else:
            steer += correction
        return min(1.0, max(-1.0, steer))

    def augment_measurement(self, measurements, angle, speed, steer_name='steer'):
        """Augment, in place, the steering stored under ``steer_name``; returns the dict."""
        measurements[steer_name] = self.augment_steering(angle, measurements[steer_name], speed)
        return measurements

    def controls_position(self):
        # NOTE(review): self.meta_data is never assigned in this class, so this
        # raises AttributeError unless set externally; looks like legacy code.
        return np.where(self.meta_data[:, 0] == b'control')[0][0]

    # Methods to interact with the dataset attributes that are used for training.

    @staticmethod
    def _concat_fields(data, names):
        """Concatenate ``data[name]`` for every name in ``names`` along dim 1.

        A missing name raises whatever ``data[name]`` raises (KeyError for a dict).
        Presumably each entry is a (batch, k) tensor - TODO confirm.
        """
        return torch.cat([data[name] for name in names], 1)

    def extract_targets(self, data):
        """
        Return the g_conf.TARGETS fields of ``data`` concatenated along dim 1.

        Args:
            data: the batch of float data got from the dataset.
        """
        return self._concat_fields(data, g_conf.TARGETS)

    def extract_inputs(self, data):
        """
        Return the g_conf.INPUTS fields of ``data`` concatenated along dim 1.

        Args:
            data: the batch of float data got from the dataset.
        """
        return self._concat_fields(data, g_conf.INPUTS)

    def extract_intentions(self, data):
        """
        Return the g_conf.INTENTIONS fields of ``data`` concatenated along dim 1.

        Args:
            data: the batch of float data got from the dataset.
        """
        return self._concat_fields(data, g_conf.INTENTIONS)