Skip to content

Commit

Permalink
Merge 20131bb into 08ce9a8
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanlct committed Jul 6, 2019
2 parents 08ce9a8 + 20131bb commit d7237bc
Show file tree
Hide file tree
Showing 4 changed files with 412 additions and 455 deletions.
3 changes: 1 addition & 2 deletions flow/core/kernel/scenario/traci.py
Expand Up @@ -679,8 +679,7 @@ def generate_cfg(self, net_params, traffic_lights, routes):
else:
show_detector = {'key': 'show-detectors', 'value': 'false'}

# FIXME(ak): add abstract method
nodes = self.specify_tll(net_params)
nodes = self._inner_nodes # nodes where there's traffic lights
tll = []
for node in nodes:
tll.append({
Expand Down
78 changes: 37 additions & 41 deletions flow/envs/green_wave_env.py
Expand Up @@ -92,7 +92,6 @@ def __init__(self, env_params, sim_params, scenario, simulator='traci'):
'velocities': np.zeros((self.steps, self.k.vehicle.num_vehicles)),
'positions': np.zeros((self.steps, self.k.vehicle.num_vehicles))
}
self.node_mapping = scenario.get_node_mapping()

# Keeps track of the last time the traffic lights in an intersection
# were allowed to change (the last time the lights were allowed to
Expand Down Expand Up @@ -249,27 +248,22 @@ def compute_reward(self, rl_actions, **kwargs):
# ===============================

def get_distance_to_intersection(self, veh_ids):
"""Return the distance from vehicle(s) to the next intersection.
Determines the smallest distance from the current vehicle's position
to any of the intersections.
"""Determine the distance from a vehicle to its next intersection.
Parameters
----------
veh_ids : str
vehicle identifier
veh_ids : str or str list
vehicle(s) identifier(s)
Returns
-------
tup
1st element: distance to closest intersection
2nd element: intersection ID (which also specifies which side of
the intersection the vehicle will be arriving at)
float (or float list)
distance to closest intersection
"""
if isinstance(veh_ids, list):
return [self.find_intersection_dist(veh_id) for veh_id in veh_ids]
else:
return self.find_intersection_dist(veh_ids)
return [self.get_distance_to_intersection(veh_id)
for veh_id in veh_ids]
return self.find_intersection_dist(veh_ids)

def find_intersection_dist(self, veh_id):
"""Return distance from intersection.
Expand Down Expand Up @@ -392,35 +386,37 @@ def _reroute_if_final_edge(self, veh_id):
pos="0",
speed="max")

# FIXME it doesn't make sense to pass a list of edges since the function
# returns a flattened list with no padding, so we would lose information
def k_closest_to_intersection(self, edges, k):
"""Return the vehicle IDs of k closest vehicles to an intersection.
Return the veh_id of the k closest vehicles to an intersection for
each edge. Performs no check on whether or not edge is going toward an
intersection or not. Does no padding
For each edge in edges, return the ids (veh_id) of the k vehicles
in edge that are closest to an intersection (the intersection they
are heading towards).
- Performs no check on whether or not edge is going towards an
intersection or not.
- Does no padding if there are less than k vehicles on an edge.
"""
if k < 0:
raise IndexError("k must be greater than 0")
dists = []

def sort_lambda(veh_id):
return self.get_distance_to_intersection(veh_id)
raise ValueError("Function k_closest_to_intersection called with"
"parameter k={}, but k should be non-negative"
.format(k))

if isinstance(edges, list):
for edge in edges:
vehicles = self.k.vehicle.get_ids_by_edge(edge)
dist = sorted(
vehicles,
key=sort_lambda
)
dists += dist[:k]
else:
vehicles = self.k.vehicle.get_ids_by_edge(edges)
dist = sorted(
vehicles,
key=lambda veh_id: self.get_distance_to_intersection(veh_id))
dists += dist[:k]
return dists
ids = [self.k_closest_to_intersection(edge, k) for edge in edges]
# flatten the list before returning it
return [veh_id for sublist in ids for veh_id in sublist]

# get the ids of all the vehicles on the edge 'edges' ordered by
# increasing distance to intersection
veh_ids_ordered = sorted(
self.k.vehicle.get_ids_by_edge(edges),
key=self.get_distance_to_intersection)

# return the ids of the k vehicles closest to the intersection
return veh_ids_ordered[:k]


class PO_TrafficLightGridEnv(TrafficLightGridEnv):
Expand Down Expand Up @@ -508,7 +504,7 @@ def get_state(self):
grid_array["inner_length"])
all_observed_ids = []

for node, edges in self.scenario.get_node_mapping():
for _, edges in self.scenario.node_mapping:
for edge in edges:
observed_ids = \
self.k_closest_to_intersection(edge, self.num_observed)
Expand All @@ -522,13 +518,13 @@ def get_state(self):
]
dist_to_intersec += [
(self.k.scenario.edge_length(
self.k.vehicle.get_edge(veh_id))
- self.k.vehicle.get_position(veh_id)) / max_dist
self.k.vehicle.get_edge(veh_id)) -
self.k.vehicle.get_position(veh_id)) / max_dist
for veh_id in observed_ids
]
edge_number += \
[self._convert_edge(self.k.vehicle.get_edge(veh_id))
/ (self.k.scenario.network.num_edges - 1)
[self._convert_edge(self.k.vehicle.get_edge(veh_id)) /
(self.k.scenario.network.num_edges - 1)
for veh_id in observed_ids]

if len(observed_ids) < self.num_observed:
Expand Down

0 comments on commit d7237bc

Please sign in to comment.