-
Notifications
You must be signed in to change notification settings - Fork 0
/
collision_map.py
288 lines (231 loc) · 13.2 KB
/
collision_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
'''A map to record possible obstacles.'''
from math import sin, cos, inf, sqrt
# Format tag written as the first line of CollisionMap.__repr__ output and
# checked by CollisionMap.from_string before it parses the rest of the text.
REPR_VERSION_STRING = "v1"
class MapLocation:
    '''A location of the map that stores various counts.

    Two locations compare equal when all three of their counts match.
    '''

    def __init__(self):
        # The number of times an agent has stepped in this location, strong evidence it is clear.
        self.stepped_count = 0
        # The number of times a scan beam has passed through this location, weak evidence it is clear.
        self.missed_count = 0
        # The number of times a scan beam has terminated in this location, evidence of an obstacle.
        self.hit_count = 0

    def __eq__(self, other):
        '''Compares the three counts; defers to the other operand for foreign types.'''
        if not isinstance(other, MapLocation):
            # Let Python try the reflected comparison; `==` still ends up False
            # when the other side does not know how to compare either.
            return NotImplemented
        return (self.stepped_count == other.stepped_count and
                self.missed_count == other.missed_count and
                self.hit_count == other.hit_count)

    # Instances are mutable and compared by value, so they must not be hashable.
    # (Defining __eq__ already implies this; it is spelled out for readers.)
    __hash__ = None

    def __repr__(self):
        '''Returns a debugging representation listing all three counts.'''
        return "MapLocation(stepped_count={}, missed_count={}, hit_count={})".format(
            self.stepped_count, self.missed_count, self.hit_count)
def get_discrete_coord(scale, coord):
    '''Snaps a coordinate to the nearest multiple of scale, returned as an int.

    A coordinate exactly halfway between two multiples snaps to the larger one.
    '''
    half_cell = scale / 2
    # Shifting by half a cell turns "round to nearest" into "round down".
    moved = coord + half_cell
    overshoot = moved % scale
    return int(moved - overshoot)
def v_dot(v1, v2):
    '''Returns the dot product of two 2D vectors (only indices 0 and 1 are read).'''
    x_product = v1[0] * v2[0]
    y_product = v1[1] * v2[1]
    return x_product + y_product
def v_diff(v1, v2):
    '''Returns v1 - v2 for two 2D vectors as an (x, y) tuple.'''
    delta_x = v1[0] - v2[0]
    delta_y = v1[1] - v2[1]
    return (delta_x, delta_y)
class CollisionMap:
    '''A map which records possible obstacles given a sensor reading.

    Space is divided into square cells that are `scale` units wide and centered
    on integer multiples of `scale`. Cells are stored sparsely in `self.map`,
    a dictionary from a cell's center `(x, y)` to its MapLocation counts.
    '''

    def __init__(self, collision_map_scale=5, collision_map_max_dist=100, **kwargs):
        '''Creates an empty map.

        collision_map_scale: positive int, the width of one cell.
        collision_map_max_dist: positive int, the longest beam length that is recorded.
        **kwargs: accepted and ignored.

        Raises AssertionError if either setting is not a positive int.
        '''
        assert isinstance(collision_map_scale, int)
        assert collision_map_scale > 0
        self.scale = collision_map_scale
        assert isinstance(collision_map_max_dist, int)
        assert collision_map_max_dist > 0
        self.max_dist = collision_map_max_dist
        # Slack used when validating the rectangles handed to get_locations_within_rectangle.
        self.rectangle_tolerance = collision_map_scale / 1000
        self.map = {}

    @classmethod
    def from_string(cls, serialized_map):
        '''Creates a new map from a string generated by __repr__.

        Blank data lines (such as the one produced by a trailing newline) are
        skipped. Returns None if the text ends before the settings line.
        Raises AssertionError if the version tag or a header line does not match.
        '''
        new_map = None
        for line_number, line in enumerate(serialized_map.split("\n")):
            if line_number == 0:
                assert line == REPR_VERSION_STRING
            elif line_number == 1:
                assert line == "scale,max_dist"
            elif line_number == 2:
                scale, max_dist = line.split(",")
                new_map = cls(collision_map_scale=int(scale), collision_map_max_dist=int(max_dist))
            elif line_number == 3:
                assert line == "x,y,stepped_count,missed_count,hit_count"
            elif line.strip():
                assert new_map is not None
                x, y, stepped_count, missed_count, hit_count = line.split(",")
                line_location = new_map.get_location(int(x), int(y), create=True)
                line_location.stepped_count = int(stepped_count)
                line_location.missed_count = int(missed_count)
                line_location.hit_count = int(hit_count)
        return new_map

    def get_key(self, x, y):
        '''Returns the key into the internal map corresponding to the provided point.

        The key is the center of the cell containing the point, as a pair of ints.
        '''
        half_cell = self.scale / 2
        # Shifting by half a cell turns "round to nearest center" into "round down".
        shifted_x = x + half_cell
        shifted_y = y + half_cell
        return (int(shifted_x - shifted_x % self.scale), int(shifted_y - shifted_y % self.scale))

    def get_location(self, x, y, create = False):
        '''Retrieves the obstacle information for the given location.

        If the cell has never been recorded, a fresh all-zero MapLocation is
        returned; it is only stored in the map when `create` is true.
        '''
        key = self.get_key(x, y)
        if create:
            return self.__get_or_insert_location(key)
        if key not in self.map:
            return MapLocation()
        return self.map[key]

    def get_neighbor_keys(self, x, y):
        '''Get the keys for all 8 adjacent neighbor cells.'''
        key_x, key_y = self.get_key(x, y)
        offsets = (-self.scale, 0, self.scale)
        # The iteration order (x offset outermost) decides ties in __add_line, so keep it.
        return [(key_x + dx, key_y + dy)
                for dx in offsets
                for dy in offsets
                if dx != 0 or dy != 0]

    def record_observations(self, x, y, theta, observations):
        '''Record the appropriate counts for each of the given (delta_theta, distance) pairs based out
        of the provided current location.

        The cell containing (x, y) is counted as stepped in once, then every
        observation is traced as a beam at angle `theta + delta_theta`.
        '''
        current_location_key = self.get_key(x, y)
        self.__get_or_insert_location(current_location_key).stepped_count += 1
        for delta_theta, distance in observations:
            self.__add_line(x, y, theta + delta_theta, distance, current_location_key)

    @staticmethod
    def _projection_bounds(points, axis):
        '''Returns the (min, max) of the projections of the given points onto the axis.'''
        projections = [v_dot(point, axis) for point in points]
        return min(projections), max(projections)

    def get_locations_within_rectangle(self, p1, p2, p3, p4):
        '''Returns the set of all observed locations partially or fully inside the given rectangle.

        The corners must be listed in order around the rectangle, so that p3 is
        diagonally opposite p1.

        Returns a pair `(results, location_count)`: `results` maps the key of
        every overlapping cell that exists in the map to its MapLocation, and
        `location_count` counts every overlapping cell, observed or not.

        Raises AssertionError if two corners (nearly) coincide, if the corners
        are not equidistant from their centroid, or if they are listed in the
        wrong order.
        '''
        corners = (p1, p2, p3, p4)
        for index, first in enumerate(corners):
            for second in corners[index + 1:]:
                assert abs(first[0] - second[0]) + abs(first[1] - second[1]) > self.rectangle_tolerance, \
                    "Provided rectangle points are equal (or close enough)."

        # All four corners of a rectangle are equally far from its center.
        center_x = (p1[0] + p2[0] + p3[0] + p4[0]) / 4
        center_y = (p1[1] + p2[1] + p3[1] + p4[1]) / 4
        center_dist = (center_x - p1[0]) ** 2 + (center_y - p1[1]) ** 2
        for px, py in [p2, p3, p4]:
            p_dist = (center_x - px) ** 2 + (center_y - py) ** 2
            assert abs(p_dist - center_dist) < self.rectangle_tolerance, "Provided points do not make a rectangle."

        # p1-p2 and p1-p4 are used as the rectangle's axes below, so both must
        # be edges, which means both must be shorter than the p1-p3 diagonal.
        rectangle_axis_1 = v_diff(p1, p2)
        rectangle_axis_2 = v_diff(p1, p4)
        edge12_dist = v_dot(rectangle_axis_1, rectangle_axis_1)
        edge14_dist = v_dot(rectangle_axis_2, rectangle_axis_2)
        diag13_dist = v_dot(v_diff(p1, p3), v_diff(p1, p3))
        assert edge12_dist < diag13_dist, "Provided rectangle points are incorrectly ordered."
        assert edge14_dist < diag13_dist, "Provided rectangle points are incorrectly ordered."

        # Precompute the rectangle's unit axes and its extent along each of them.
        rectangle_axis_1_norm = sqrt(edge12_dist)
        rectangle_axis_1_unit = (rectangle_axis_1[0] / rectangle_axis_1_norm,
                                 rectangle_axis_1[1] / rectangle_axis_1_norm)
        min_rectangle_axis_1, max_rectangle_axis_1 = self._projection_bounds(
            (p1, p2), rectangle_axis_1_unit)

        rectangle_axis_2_norm = sqrt(edge14_dist)
        rectangle_axis_2_unit = (rectangle_axis_2[0] / rectangle_axis_2_norm,
                                 rectangle_axis_2[1] / rectangle_axis_2_norm)
        min_rectangle_axis_2, max_rectangle_axis_2 = self._projection_bounds(
            (p1, p4), rectangle_axis_2_unit)

        # Get the grid-aligned bounding box to check.
        min_x, min_y = self.get_key(min(p1[0], p2[0], p3[0], p4[0]), min(p1[1], p2[1], p3[1], p4[1]))
        max_x, max_y = self.get_key(max(p1[0], p2[0], p3[0], p4[0]), max(p1[1], p2[1], p3[1], p4[1]))

        half_cell = self.scale / 2
        results = {}
        location_count = 0
        for current_y in range(min_y, max_y + 1, self.scale):
            for current_x in range(min_x, max_x + 1, self.scale):
                location_corners = (
                    (current_x - half_cell, current_y - half_cell),
                    (current_x + half_cell, current_y - half_cell),
                    (current_x - half_cell, current_y + half_cell),
                    (current_x + half_cell, current_y + half_cell),
                )

                # Using the Separating Axis Theorem to check for overlap between the
                # rectangle and the location domain. Projecting the rectangle onto
                # the location's own axes is skipped, since the loops only visit
                # cells inside the rectangle's grid-aligned bounding box.
                min_location_axis_1, max_location_axis_1 = self._projection_bounds(
                    location_corners, rectangle_axis_1_unit)
                if min_location_axis_1 > max_rectangle_axis_1 or min_rectangle_axis_1 > max_location_axis_1:
                    # This location is not inside the rectangle.
                    continue

                min_location_axis_2, max_location_axis_2 = self._projection_bounds(
                    location_corners, rectangle_axis_2_unit)
                if min_location_axis_2 > max_rectangle_axis_2 or min_rectangle_axis_2 > max_location_axis_2:
                    # This location is not inside the rectangle.
                    continue

                location_count += 1
                location = self.map.get((current_x, current_y))
                if location is not None:
                    # Only locations that have been observed are reported.
                    results[(current_x, current_y)] = location

        return results, location_count

    def __add_line(self, start_x, start_y, theta, distance, current_location_key):
        '''Record all spots along the given line as passed through, and records the final spot as hit.

        The beam is clipped to `max_dist`; a hit is only recorded when the
        beam was not clipped. The start cell is never counted as passed
        through, and the walk stops at the first cell adjacent to the end
        cell, so the end cell is never counted as passed through either.

        `current_location_key` is unused; it is kept so the signature stays stable.
        '''
        start_point_key = self.get_key(start_x, start_y)
        beam_length = min(distance, self.max_dist)
        end_x = start_x + beam_length * cos(theta)
        end_y = start_y + beam_length * sin(theta)
        end_point_key = self.get_key(end_x, end_y)

        if distance <= self.max_dist:
            self.__get_or_insert_location(end_point_key).hit_count += 1

        # Coefficients of the beam's line in the form a*x + b*y + c = 0, used to
        # measure how far a candidate cell center lies from the beam.
        a = -sin(theta)
        b = cos(theta)
        c = (start_x * sin(theta)) - (start_y * cos(theta))
        line_norm = sqrt(a ** 2 + b ** 2)

        # These do not change during the walk, so compute them once.
        end_neighbor_keys = frozenset(self.get_neighbor_keys(end_point_key[0], end_point_key[1]))

        x, y = start_point_key
        current_distance = sqrt((end_x - x) ** 2 + (end_y - y) ** 2)

        # Walk cell by cell toward the end point. Every step strictly reduces
        # the distance to the end point, so the walk always terminates. The
        # walk is deliberately NOT gated on `current_distance <= max_dist`:
        # that distance is measured from the start cell's center, which can be
        # slightly farther from the end point than the real start, and gating
        # on it silently dropped entire beams.
        while True:
            if (x, y) != start_point_key:
                # Don't record the starting point as passed through,
                # since it will be recorded as stepped in.
                self.__get_or_insert_location((x, y)).missed_count += 1

            if (x, y) in end_neighbor_keys:
                break

            next_x = None
            next_y = None
            next_error = inf
            for option_x, option_y in self.get_neighbor_keys(x, y):
                option_dist = sqrt((end_x - option_x) ** 2 + (end_y - option_y) ** 2)
                if option_dist >= current_distance:
                    # Don't move away from the target!
                    continue
                # Prefer the candidate whose center is closest to the beam.
                option_error = abs((a * option_x) + (b * option_y) + c) / line_norm
                if option_error < next_error:
                    next_x = option_x
                    next_y = option_y
                    next_error = option_error

            if next_x is None or next_y is None:
                # No neighbor is closer to the end point than the current cell.
                break

            x = next_x
            y = next_y
            current_distance = sqrt((end_x - x) ** 2 + (end_y - y) ** 2)

    def __get_or_insert_location(self, key):
        '''Retrieves the obstacle information for the given location, but will insert
        a new location into the map if it is missing.'''
        if key not in self.map:
            self.map[key] = MapLocation()
        return self.map[key]

    def __repr__(self):
        '''Serializes the map as text that from_string can read back.

        The output is the version tag, a settings header and its values, a
        cell header, and then one `x,y,stepped_count,missed_count,hit_count`
        line per stored cell, with no trailing newline.
        '''
        lines = [
            REPR_VERSION_STRING,
            "scale,max_dist",
            ",".join([str(self.scale), str(self.max_dist)]),
            "x,y,stepped_count,missed_count,hit_count",
        ]
        for (x, y), value in self.map.items():
            lines.append(",".join(
                [str(x), str(y), str(value.stepped_count),
                 str(value.missed_count), str(value.hit_count)]
            ))
        return "\n".join(lines)