Skip to content

Commit

Permalink
Merge pull request #199 from inverted-ai/utils_additions
Browse files Browse the repository at this point in the history
Utils additions
  • Loading branch information
KieranRatcliffeInvertedAI committed Apr 4, 2024
2 parents 0ae77b5 + 74ddbab commit d2877d5
Showing 1 changed file with 126 additions and 5 deletions.
131 changes: 126 additions & 5 deletions invertedai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
import invertedai.api
import invertedai.api.config
from invertedai import error
from invertedai.common import AgentState, AgentAttributes, StaticMapActor, TrafficLightStatesDict, Point
from invertedai.common import AgentState, AgentAttributes, StaticMapActor, TrafficLightStatesDict, Point, RecurrentState
from invertedai.future import to_thread
from invertedai.error import InvertedAIError
from invertedai.api.initialize import InitializeResponse

H_SCALE = 10
text_x_offset = 0
Expand Down Expand Up @@ -442,6 +443,20 @@ def _interpret_response_line(self, result):

return data

@validate_call
def get_default_agent_attributes(agent_count_dict: Dict[str,int]) -> List[AgentAttributes]:
# Function that outputs a list a AgentAttributes with minimal default settings.
# Mainly meant to be used to pad a list of AgentAttributes to send as input to
# initialize(). This list is created by reading a dictionary containing the
# desired agent types with the agent count for each type respectively.

agent_attributes_list = []

for agent_type, agent_count in agent_count_dict.items():
for _ in range(agent_count):
agent_attributes_list.append(AgentAttributes.fromlist([agent_type]))

return agent_attributes_list

def _get_centers(map_center, height, width, stride):
def check_valid_center(center):
Expand Down Expand Up @@ -486,7 +501,6 @@ def _get_agent_density_per_region(centers,location,agent_density,scaling_factor,

for area_center in iterable_regions:
#Naively check every square within requested area
#TODO: Use heuristics or other methods to (e.g. map polygon, high FOV image, quadtree) to reduce computation time
center_tuple = (area_center.x, area_center.y)
birdview = iai.location_info(
location=location,
Expand Down Expand Up @@ -548,38 +562,50 @@ def area_initialization(
----------
location:
Location name in IAI format.
agent_density:
Maximum agents per 100x100m region to be scaled based on heuristic.
agent_attributes:
Static attributes for all pre-defined agents ONLY. Use the agent_density argument to
specify the number of agents to be sampled.
states_history:
History of pre-defined agent states - the outer list is over time and the inner over agents,
in chronological order, i.e., index 0 is the oldest state and index -1 is the current state.
The order of agents should be the same as in `agent_attributes`.
For best results, provide at least 10 historical states for each agent.
traffic_light_state_history:
History of traffic light states - the list is over time, in chronological order, i.e.
the last element is the current state. If there are traffic lights in the map,
not specifying traffic light state is equivalent to using iai generated light states.
random_seed:
Controls the stochastic aspects of initialization for reproducibility.
map_center:
The x,y coordinate of the center of the area to be initialized.
width:
Distance along the x-axis from the area center to edge of the rectangular area (total
width of the region is 2X the value of this parameter).
height:
Distance along the y-axis from the area center to edge of the rectangular area (total
height of the region is 2X the value of this parameter).
stride:
Distance between the centers of the 100x100m regions.
scaling_factor:
A factor between [0,1] weighting the heuristic for number of agents to spawn in a region.
For example, a value of 0 ignores the heuristic and results in requesting the same number
of agents for all regions.
save_birdviews_to:
If this variable is not None, the birdview images will be saved to this specified path.
display_progress_bar:
If True, a bar is displayed showing the progress of all relevant processes.
Expand Down Expand Up @@ -673,9 +699,9 @@ def inside_fov(center: Point, agent_scope_fov: float, point: Point) -> bool:

try:
all_agents_attributes_in_region = deepcopy(agent_attributes_region_conditional) if agent_attributes_region_conditional is not None else []
for _ in range(num_agents_to_spawn):
# Pad agent attributes list with default values
all_agents_attributes_in_region.append(AgentAttributes.fromlist(["car"]))

padded_agent_attributes = get_default_agent_attributes({"car": num_agents_to_spawn})
all_agents_attributes_in_region.extend(padded_agent_attributes)

# Initialize simulation with an API call
response = iai.initialize(
Expand Down Expand Up @@ -736,6 +762,101 @@ def inside_fov(center: Point, agent_scope_fov: float, point: Point) -> bool:

return response

@validate_call
def iai_conditional_initialize(
location: str,
agent_type_count: Dict[str,int],
location_of_interest: Tuple[float] = (0,0),
recurrent_states: Optional[List[RecurrentState]] = None,
agent_attributes: Optional[List[AgentAttributes]] = None,
states_history: Optional[List[List[AgentState]]] = None,
traffic_light_state_history: Optional[List[TrafficLightStatesDict]] = None,
get_birdview: Optional[bool] = False,
get_infractions: Optional[bool] = False,
random_seed: Optional[int] = None,
api_model_version: Optional[str] = None
):
"""
A utility function to run initialize with conditional agents located at arbitrary distances from the location
of interest. Only agents within a defined distance of the location of interest are passed to initialize as
conditional. Agents outisde of this distance are padded on to the initialize response, including their reccurent
states. Recurrent states must be provided for all agents, otherwise this function behaves like :func:`initialize`.
Please refer to the documentation for :func:`initialize` for more information.
Arguments
----------
location:
Location name in IAI format.
agent_type_count:
A dictionary containing valid AgentType strings as keys mapped to an integer value specifying the desired
number of agents of that type to initialize.
location_of_interest:
Optional coordinates for spawning agents with the given location as center instead of the default map center
See Also
--------
:func:`initialize`
"""

conditional_agent_attributes = []
conditional_agent_states_indexes = []
conditional_recurrent_states = []
outside_agent_states = []
outside_agent_attributes = []
outside_recurrent_states = []

current_agent_states = states_history[-1]
conditional_agent_type_count = deepcopy(agent_type_count)
for i in range(len(current_agent_states)):
agent_state = current_agent_states[i]
dist = math.dist(location_of_interest, (agent_state.center.x, agent_state.center.y))
if dist < AGENT_SCOPE_FOV:
conditional_agent_states_indexes.append(i)
conditional_agent_attributes.append(agent_attributes[i])
conditional_recurrent_states.append(recurrent_states[i])

conditional_agent_type = agent_attributes[i].agent_type
if conditional_agent_type in conditional_agent_type_count:
conditional_agent_type_count[conditional_agent_type] -= 1
if conditional_agent_type_count[conditional_agent_type] <= 0:
del conditional_agent_type_count[conditional_agent_type]

else:
outside_agent_states.append(agent_state)
outside_agent_attributes.append(agent_attributes[i])
outside_recurrent_states.append(recurrent_states[i])

if not conditional_agent_type_count: #The dictionary is empty.
iai.logger.warning("Agent count requirement already satisfied, no new agents initialized.")

padded_agent_attributes = get_default_agent_attributes(conditional_agent_type_count)
conditional_agent_attributes.extend(padded_agent_attributes)

conditional_agent_states = [[]*len(conditional_agent_states_indexes)]
for ts in range(len(conditional_agent_states)):
for agent_index in conditional_agent_states_indexes:
conditional_agent_states[ts].append(states_history[ts][agent_index])

response = invertedai.api.initialize(
location = location,
agent_attributes = conditional_agent_attributes,
states_history = conditional_agent_states,
location_of_interest = location_of_interest,
traffic_light_state_history = traffic_light_state_history,
get_birdview = get_birdview,
get_infractions = get_infractions,
random_seed = random_seed,
api_model_version = api_model_version
)
response.agent_attributes = response.agent_attributes + outside_agent_attributes
response.agent_states = response.agent_states + outside_agent_states
response.recurrent_states = response.recurrent_states + outside_recurrent_states

return response


class APITokenAuth(AuthBase):
def __init__(self, api_token):
self.api_token = api_token
Expand Down

0 comments on commit d2877d5

Please sign in to comment.