Skip to content

Commit

Permalink
Merge 86260ad into 3beeeb2
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanlct committed Jul 18, 2019
2 parents 3beeeb2 + 86260ad commit 2aa3f3b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 32 deletions.
114 changes: 89 additions & 25 deletions flow/envs/green_wave_env.py
Expand Up @@ -433,37 +433,101 @@ def _reroute_if_final_edge(self, veh_id):
pos="0",
speed="max")

# FIXME it doesn't make sense to pass a list of edges since the function
# returns a flattened list with no padding, so we would lose information
def k_closest_to_intersection(self, edges, k):
"""Return the vehicle IDs of k closest vehicles to an intersection.
For each edge in edges, return the ids (veh_id) of the k vehicles
in edge that are closest to an intersection (the intersection they
are heading towards).
- Performs no check on whether or not edge is going towards an
intersection or not.
- Does no padding if there are less than k vehicles on an edge.
def get_closest_to_intersection(self, edges, num_closest, padding=False):
"""Return the IDs of the vehicles that are closest to an intersection.
For each edge in edges, return the IDs (veh_id) of the num_closest
vehicles in edge that are closest to an intersection (the intersection
they are heading towards).
This function performs no check on whether or not edges are going
towards an intersection or not, it just gets the vehicles that are
closest to the end of their edges.
If there are less than num_closest vehicles on an edge, the function
performs padding by adding empty strings "" instead of vehicle ids if
the padding parameter is set to True.
Parameters
----------
edges : str | str list
ID of an edge or list of edge IDs.
num_closest : int (> 0)
Number of vehicles to consider on each edge.
padding : bool (default False)
If there are less than num_closest vehicles on an edge, perform
padding by adding empty strings "" instead of vehicle ids if the
padding parameter is set to True (note: leaving padding to False
while passing a list of several edges as parameter can lead to
information loss since you will not know which edge, if any,
contains less than num_closest vehicles).
Usage
-----
For example, consider the following network, composed of 4 edges
whose ids are "edge0", "edge1", "edge2" and "edge3", the numbers
being vehicles all headed towards intersection x. The ID of the vehicle
with number n is "veh{n}" (edge "veh0", "veh1"...).
edge1
| |
| 7 |
| 8 |
-------------| |-------------
edge0 1 2 3 4 5 6 x edge2
-------------| |-------------
| 9 |
| 10|
| 11|
edge3
And consider the following example calls on the previous network:
>>> get_closest_to_intersection("edge0", 4)
["veh6", "veh5", "veh4", "veh3"]
>>> get_closest_to_intersection("edge0", 8)
["veh6", "veh5", "veh4", "veh3", "veh2", "veh1"]
>>> get_closest_to_intersection("edge0", 8, padding=True)
["veh6", "veh5", "veh4", "veh3", "veh2", "veh1", "", ""]
>>> get_closest_to_intersection(["edge0", "edge1", "edge2", "edge3"],
3, padding=True)
["veh6", "veh5", "veh4", "veh8", "veh7", "", "", "", "", "veh9",
"veh10", "veh11"]
Returns
-------
str list
If n is the number of edges given as parameters, then the returned
list contains n * num_closest vehicle IDs.
Raises
------
ValueError
if num_closest <= 0
"""
if k < 0:
raise ValueError("Function k_closest_to_intersection called with"
"parameter k={}, but k should be non-negative"
.format(k))
if num_closest <= 0:
raise ValueError("Function get_closest_to_intersection called with"
"parameter num_closest={}, but num_closest should"
"be positive".format(num_closest))

if isinstance(edges, list):
ids = [self.k_closest_to_intersection(edge, k) for edge in edges]
# flatten the list before returning it
ids = [self.get_closest_to_intersection(edge, num_closest)
for edge in edges]
# flatten the list and return it
return [veh_id for sublist in ids for veh_id in sublist]

# get the ids of all the vehicles on the edge 'edges' ordered by
# increasing distance to intersection
veh_ids_ordered = sorted(
self.k.vehicle.get_ids_by_edge(edges),
key=self.get_distance_to_intersection)
# increasing distance to end of edge (intersection)
veh_ids_ordered = sorted(self.k.vehicle.get_ids_by_edge(edges),
key=self.get_distance_to_intersection)

# return the ids of the k vehicles closest to the intersection
return veh_ids_ordered[:k]
# return the ids of the num_closest vehicles closest to the
# intersection, potentially with ""-padding.
pad_lst = [""] * (num_closest - len(veh_ids_ordered))
return veh_ids_ordered[:num_closest] + (pad_lst if padding else [])


class PO_TrafficLightGridEnv(TrafficLightGridEnv):
Expand Down Expand Up @@ -554,7 +618,7 @@ def get_state(self):
for _, edges in self.scenario.node_mapping:
for edge in edges:
observed_ids = \
self.k_closest_to_intersection(edge, self.num_observed)
self.get_closest_to_intersection(edge, self.num_observed)
all_observed_ids += observed_ids

# check which edges we have so we can always pad in the right
Expand Down
2 changes: 1 addition & 1 deletion flow/multiagent_envs/traffic_light_grid.py
Expand Up @@ -119,7 +119,7 @@ def get_state(self):
local_edge_numbers = []
for edge in edges:
observed_ids = \
self.k_closest_to_intersection(edge, self.num_observed)
self.get_closest_to_intersection(edge, self.num_observed)
all_observed_ids.append(observed_ids)

# check which edges we have so we can always pad in the right
Expand Down
15 changes: 9 additions & 6 deletions tests/fast_tests/test_traffic_lights.py
Expand Up @@ -191,23 +191,26 @@ def test_k_closest(self):

# get the node mapping for node center0
c0_edges = node_mapping[0][1]
k_closest = self.env.k_closest_to_intersection(c0_edges, 3)
k_closest = self.env.get_closest_to_intersection(c0_edges, 3)

# check bot, right, top, left in that order
self.assertEqual(
self.env.k_closest_to_intersection(c0_edges[0], 3), k_closest[0:2])
self.env.get_closest_to_intersection(c0_edges[0], 3),
k_closest[0:2])
self.assertEqual(
self.env.k_closest_to_intersection(c0_edges[1], 3), k_closest[2:4])
self.env.get_closest_to_intersection(c0_edges[1], 3),
k_closest[2:4])
self.assertEqual(
len(self.env.k_closest_to_intersection(c0_edges[2], 3)), 0)
len(self.env.get_closest_to_intersection(c0_edges[2], 3)), 0)
self.assertEqual(
self.env.k_closest_to_intersection(c0_edges[3], 3), k_closest[4:6])
self.env.get_closest_to_intersection(c0_edges[3], 3),
k_closest[4:6])

for veh_id in k_closest:
self.assertTrue(self.env.k.vehicle.get_edge(veh_id) in c0_edges)

with self.assertRaises(ValueError):
self.env.k_closest_to_intersection(c0_edges, -1)
self.env.get_closest_to_intersection(c0_edges, -1)


class TestItRuns(unittest.TestCase):
Expand Down

0 comments on commit 2aa3f3b

Please sign in to comment.