diff --git a/docs/source/cppapi/cpp-blame.md b/docs/source/cppapi/cpp-blame.md new file mode 100644 index 0000000..3d4fc76 --- /dev/null +++ b/docs/source/cppapi/cpp-blame.md @@ -0,0 +1,17 @@ +# BLAME (C++) + + +```{eval-rst} +.. doxygenfunction:: invertedai::blame + :project: InvertedAI-CPP +``` + +--- +```{eval-rst} +.. doxygenclass:: invertedai::BlameRequest + :members: + :undoc-members: +.. doxygenclass:: invertedai::BlameResponse + :members: + :undoc-members: +``` diff --git a/docs/source/cppapi/index.md b/docs/source/cppapi/index.md index 1769929..8d2aee6 100644 --- a/docs/source/cppapi/index.md +++ b/docs/source/cppapi/index.md @@ -7,6 +7,7 @@ accessed directly. Below are the key functions of the library, along with some c ```{toctree} :maxdepth: 1 +cpp-blame cpp-drive cpp-initialize cpp-location-info diff --git a/docs/source/pythonapi/index.md b/docs/source/pythonapi/index.md index d4f90e6..2a7307a 100644 --- a/docs/source/pythonapi/index.md +++ b/docs/source/pythonapi/index.md @@ -7,13 +7,13 @@ accessed directly. Below are the key functions of the library, along with some c ```{toctree} :maxdepth: 1 +sdk-blame sdk-drive sdk-initialize -sdk-location-info sdk-light -sdk-blame -sdk-simulation +sdk-location-info sdk-common +sdk-simulation sdk-env-var ``` diff --git a/examples/carla/region_drive.py b/examples/carla/region_drive.py index 7b93cee..ba7fbc4 100644 --- a/examples/carla/region_drive.py +++ b/examples/carla/region_drive.py @@ -1,6 +1,9 @@ +import sys +sys.path.append('../') + import argparse import invertedai as iai -from invertedai.simulation.simulator import Simulation, SimulationConfig +from simulation.simulator import Simulation, SimulationConfig import pathlib import pygame from tqdm import tqdm @@ -17,7 +20,7 @@ parser.add_argument("-cap", "--quadtree_capacity", type=int, default=15) parser.add_argument("-ad", "--agent_density", type=int, default=10) parser.add_argument("-ri", "--re_initialization", type=int, default=30) -parser.add_argument("-len", "--simulation_length", type=int, default=10000) +parser.add_argument("-len", "--simulation_length", type=int, default=600) args = parser.parse_args() @@ -29,12 +32,12 @@ cfg = SimulationConfig(location=args.location, map_center=(response.map_center.x, response.map_center.y), map_fov=response.map_fov, rendered_static_map=rendered_static_map, - map_width=response.map_fov+200, map_height=response.map_fov+200, agent_density=args.agent_density, + map_width=response.map_fov, map_height=response.map_fov, agent_density=args.agent_density, initialize_stride=50, quadtree_capacity=args.quadtree_capacity, re_initialization_period=args.re_initialization) simulation = Simulation(cfg=cfg) -fps = 100 +fps = 60 clock = pygame.time.Clock() run = True start = perf_counter() diff --git a/examples/simulation/regions.py b/examples/simulation/regions.py index 9058eaa..8039a8a 100644 --- a/examples/simulation/regions.py +++ b/examples/simulation/regions.py @@ -2,7 +2,7 @@ from simulation.utils import Rectangle, RE_INITIALIZATION_PERIOD, DEBUG from typing import List, Optional, Callable from random import randint -from invertedai import drive, async_drive +from invertedai.api.drive import drive, async_drive from simulation.car import Car diff --git a/invertedai/api/drive.py b/invertedai/api/drive.py index 40acf9f..da457c7 100644 --- a/invertedai/api/drive.py +++ b/invertedai/api/drive.py @@ -41,7 +41,7 @@ class DriveResponse(BaseModel): is_inside_supported_area: List[ bool ] #: For each agent, indicates whether the predicted state is inside supported area. - + model_version: str # Model version used for this API call @validate_arguments def drive( @@ -55,6 +55,7 @@ def drive( rendering_fov: Optional[float] = None, get_infractions: bool = False, random_seed: Optional[int] = None, + model_version: Optional[str] = None ) -> DriveResponse: """ Parameters @@ -99,6 +100,8 @@ def drive( random_seed: Controls the stochastic aspects of agent behavior for reproducibility. + model_version: + Optionally specify the version of the model. If None is passed which is by default, the best model will be used. See Also -------- :func:`initialize` @@ -144,7 +147,8 @@ def _tolist(input_data: List): get_infractions=get_infractions, random_seed=random_seed, rendering_center=rendering_center, - rendering_fov=rendering_fov + rendering_fov=rendering_fov, + model_version=model_version ) start = time.time() timeout = TIMEOUT @@ -170,6 +174,7 @@ def _tolist(input_data: List): if response["infraction_indicators"] else [], is_inside_supported_area=response["is_inside_supported_area"], + model_version=response["model_version"] ) return response @@ -193,6 +198,7 @@ async def async_drive( rendering_fov: Optional[float] = None, get_infractions: bool = False, random_seed: Optional[int] = None, + model_version: Optional[str] = None ) -> DriveResponse: """ A light async version of :func:`drive` @@ -216,7 +222,8 @@ def _tolist(input_data: List): get_infractions=get_infractions, random_seed=random_seed, rendering_center=rendering_center, - rendering_fov=rendering_fov + rendering_fov=rendering_fov, + model_version=model_version, ) response = await iai.session.async_request(model="drive", data=model_inputs) @@ -237,6 +244,7 @@ def _tolist(input_data: List): if response["infraction_indicators"] else [], is_inside_supported_area=response["is_inside_supported_area"], + model_version=response["model_version"] ) return response diff --git a/invertedai/api/initialize.py b/invertedai/api/initialize.py index e9e0702..af5db71 100644 --- a/invertedai/api/initialize.py +++ b/invertedai/api/initialize.py @@ -41,6 +41,7 @@ class InitializeResponse(BaseModel): infractions: Optional[ List[InfractionIndicators] ] #: If `get_infractions` was set, they are returned here. + model_version: str # Model version used for this API call @validate_arguments @@ -56,6 +57,7 @@ def initialize( get_infractions: bool = False, agent_count: Optional[int] = None, random_seed: Optional[int] = None, + model_version: Optional[str] = None # Model version used for this API call ) -> InitializeResponse: """ Initializes a simulation in a given location. @@ -104,6 +106,9 @@ def initialize( random_seed: Controls the stochastic aspects of initialization for reproducibility. + model_version: + Optionally specify the version of the model. If None is passed which is by default, the best model will be used. + See Also -------- :func:`drive` @@ -152,6 +157,7 @@ def initialize( location_of_interest=location_of_interest, get_infractions=get_infractions, random_seed=random_seed, + model_version=model_version ) start = time.time() timeout = TIMEOUT @@ -182,6 +188,7 @@ def initialize( ] if response["infraction_indicators"] else [], + model_version=response["model_version"] ) return response except TryAgain as e: @@ -203,6 +210,7 @@ async def async_initialize( get_infractions: bool = False, agent_count: Optional[int] = None, random_seed: Optional[int] = None, + model_version: Optional[str] = None ) -> InitializeResponse: """ The async version of :func:`initialize` @@ -228,6 +236,7 @@ async def async_initialize( location_of_interest=location_of_interest, get_infractions=get_infractions, random_seed=random_seed, + model_version=model_version ) response = await iai.session.async_request(model="initialize", data=model_inputs) @@ -255,5 +264,6 @@ async def async_initialize( ] if response["infraction_indicators"] else [], + model_version=response["model_version"] ) return response diff --git a/invertedai_cpp/Dockerfile b/invertedai_cpp/Dockerfile index e83b95f..80296d5 100644 --- a/invertedai_cpp/Dockerfile +++ b/invertedai_cpp/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:22.10 +FROM ubuntu:22.04 ENV DEBIAN_FRONTEND noninteractive RUN apt-get update && apt-get install -y \ && apt-get install -y build-essential \ diff --git a/invertedai_cpp/examples/BUILD b/invertedai_cpp/examples/BUILD index 2089c2f..1f1afd2 100644 --- a/invertedai_cpp/examples/BUILD +++ b/invertedai_cpp/examples/BUILD @@ -30,3 +30,18 @@ cc_binary( "@opencv", ], ) + +cc_binary( + name = "fps_control_demo", + srcs = ["fps_control_demo.cc"], + data = [ + "drive_body.json", + "initialize_body.json", + "location_info_body.json", + ], + deps = [ + "//invertedai:api", + "@boost//:beast", + "@opencv", + ], +) \ No newline at end of file diff --git a/invertedai_cpp/examples/conditional_initialize_body.json b/invertedai_cpp/examples/conditional_initialize_body.json new file mode 100644 index 0000000..4e2e3e7 --- /dev/null +++ b/invertedai_cpp/examples/conditional_initialize_body.json @@ -0,0 +1,26 @@ +{ + "location": "iai:ubc_roundabout", + "num_agents_to_spawn": 10, + "states_history": null, + "agent_attributes": null, + "traffic_light_state_history": null, + "get_birdview": true, + "get_infractions": false, + "random_seed": null, + "location_of_interest": null, + "conditional_agent_states": [ + [ + -4.79, + -18.32, + -1.35, + 6.62 + ] + ], + "conditional_agent_attributes": [ + [ + 4.4, + 1.94, + 1.41 + ] + ] +} diff --git a/invertedai_cpp/examples/drive_body.json b/invertedai_cpp/examples/drive_body.json index d26b7ba..b7f9a51 100644 --- a/invertedai_cpp/examples/drive_body.json +++ b/invertedai_cpp/examples/drive_body.json @@ -1,5 +1,5 @@ { - "location": "canada:vancouver:ubc_roundabout", + "location": "iai:ubc_roundabout", "agent_states": [ [ 33.8, diff --git a/invertedai_cpp/examples/fps_control_demo.cc b/invertedai_cpp/examples/fps_control_demo.cc new file mode 100644 index 0000000..3e4a580 --- /dev/null +++ b/invertedai_cpp/examples/fps_control_demo.cc @@ -0,0 +1,292 @@ +#include +#include +#include +#include +#include + +#define _USE_MATH_DEFINES +#include //includes std::async library +#include +#include +#include + +#include +#include +#include + +#include "../invertedai/api.h" +#include "../invertedai/data_utils.h" + +using tcp = net::ip::tcp; // from +using json = nlohmann::json; // from + +const unsigned int IAI_FPS = 10; + +struct EgoAgentInput { + std::vector ego_states; + std::vector ego_attributes; +}; + +json get_ego_log(const std::string& file_path) { + std::ifstream ego_log_file(file_path); + json example_ego_agent_log = json::parse(ego_log_file); + return example_ego_agent_log; +} + +EgoAgentInput get_ego_agents(const json& example_ego_agent_log){ + EgoAgentInput output_struct; + + invertedai::AgentState current_ego_state; + current_ego_state.x = example_ego_agent_log.at(0)["json"]["agent_states"][0][0]; + current_ego_state.y = example_ego_agent_log.at(0)["json"]["agent_states"][0][1]; + current_ego_state.orientation = example_ego_agent_log.at(0)["json"]["agent_states"][0][2]; + current_ego_state.speed = example_ego_agent_log.at(0)["json"]["agent_states"][0][3]; + output_struct.ego_states = {current_ego_state}; + + invertedai::AgentAttributes ego_attributes; + ego_attributes.length = example_ego_agent_log.at(0)["json"]["agent_attributes"][0][0]; + ego_attributes.width = example_ego_agent_log.at(0)["json"]["agent_attributes"][0][1]; + ego_attributes.rear_axis_offset = example_ego_agent_log.at(0)["json"]["agent_attributes"][0][2]; + output_struct.ego_attributes = {ego_attributes}; + + return output_struct; +} + +std::vector> linear_interpolation(const std::vector& current_agent_states, const std::vector& next_agent_states, const int number_steps) { + if (current_agent_states.size() != next_agent_states.size()) { + throw std::runtime_error("Size of vector arguements for interpolation does not match."); + } + + std::vector> interpolated_states; + double number_steps_div = (double)number_steps; + + for (int tt = 0; tt < number_steps; tt++) { + std::vector timestep_states; + for(int i = 0; i < (int)next_agent_states.size(); i++){ + invertedai::AgentState agent_state_interpolated; + invertedai::AgentState agent_state_current = current_agent_states[i]; + invertedai::AgentState agent_state_next = next_agent_states[i]; + + agent_state_interpolated.x = agent_state_current.x + (agent_state_next.x - agent_state_current.x)*(tt/number_steps_div); + agent_state_interpolated.y = agent_state_current.y + (agent_state_next.y - agent_state_current.y)*(tt/number_steps_div); + agent_state_interpolated.orientation = agent_state_current.orientation + (agent_state_next.orientation - agent_state_current.orientation)*(tt/number_steps_div); + agent_state_interpolated.speed = agent_state_current.speed + (agent_state_next.speed - agent_state_current.speed)*(tt/number_steps_div); + timestep_states.push_back(agent_state_interpolated); + } + interpolated_states.push_back(timestep_states); + } + + return interpolated_states; +} + +double angle_wrap(double angle) { + //Assume angles are given in radians + const int MAX_ITERATIONS = 1000; + int num_iteration = 0; + while ((angle > M_PI) || (angle < -M_PI)) { + if (angle > M_PI) { + angle -= 2*M_PI; + } + else if (angle < -M_PI) { + angle += 2*M_PI; + } + num_iteration += 1; + if (num_iteration >= MAX_ITERATIONS) { + throw std::runtime_error("Exceeded maximum allowable iterations."); + } + } + + return angle; +} + +double get_angle_difference(double a, double b){ + //Assume angles are wrapped + //Assume angles are in radians + //Assume the equation is a - b + double sub = a -b; + if (sub > M_PI) { + sub = 2*M_PI - sub; + } + else if (sub < -M_PI) { + sub = -2*M_PI + sub; + } + + return sub; +} + +std::vector extrapolate_ego_agents(const std::vector& current_ego_states, const std::vector& previous_ego_states){ + std::vector estimated_ego_states; + for(int ag = 0; ag < (int)current_ego_states.size(); ag++){ + invertedai::AgentState ego_agent; + double angle_difference = get_angle_difference(angle_wrap(current_ego_states[ag].orientation),angle_wrap(previous_ego_states[ag].orientation)); + ego_agent.orientation = current_ego_states[ag].orientation + angle_difference; //Assume constant angular velocity over time period from previous time period + ego_agent.orientation = angle_wrap(ego_agent.orientation); + ego_agent.speed = current_ego_states[ag].speed + (current_ego_states[ag].speed - previous_ego_states[ag].speed); //Assume constant acceleration over time period from previous time period + + //Estimate future position of agent based on estimated average speed between start and beginning of time period + double avg_speed = (current_ego_states[ag].speed + ego_agent.speed)/2; + ego_agent.x = current_ego_states[ag].x + avg_speed*sin(current_ego_states[ag].orientation)/IAI_FPS; + ego_agent.y = current_ego_states[ag].y + avg_speed*cos(current_ego_states[ag].orientation)/IAI_FPS; + + estimated_ego_states.push_back(ego_agent); + } + + return estimated_ego_states; +} + +std::vector split_npc_and_ego_states(std::vector& combined_agent_vector, const int num_ego_agents){ + //Returns the vector of ego states + //Assume the ego agents are at the beginning of the vector + + std::vector ego_states; + for (int a = 0; a < num_ego_agents; a++) { + ego_states.push_back(combined_agent_vector.front()); + combined_agent_vector.erase(combined_agent_vector.begin()); + } + + return ego_states; + +} + + +// usage: ./fps_control_demo $location $agent_num $timestep $api_key $FPS +int main(int argc, char **argv) { + try { + const std::string location(argv[1]); + const unsigned int agent_num = std::stoi(argv[2]); + const unsigned int timestep = std::stoi(argv[3]); + const std::string api_key(argv[4]); + const unsigned int FPS = std::stoi(argv[5]); + if (FPS % IAI_FPS != 0) { + throw std::invalid_argument("FPS argument must be a multiple of 10."); + } + int NUM_INTERP_STEPS = (int) FPS/IAI_FPS; + + net::io_context ioc; + ssl::context ctx(ssl::context::tlsv12_client); + // configure connection setting + invertedai::Session session(ioc, ctx); + session.set_api_key(api_key); + session.connect(); + + // construct request for getting information about the location + invertedai::LocationInfoRequest loc_info_req( + invertedai::read_file("examples/location_info_body.json")); + loc_info_req.set_location(location); + + // get response of location information + invertedai::LocationInfoResponse loc_info_res = + invertedai::location_info(loc_info_req, &session); + + // use opencv to decode and save the bird's eye view image of the simulation + auto image = cv::imdecode(loc_info_res.birdview_image(), cv::IMREAD_COLOR); + cv::cvtColor(image, image, cv::COLOR_BGR2RGB); + int frame_width = image.rows; + int frame_height = image.cols; + cv::VideoWriter video("iai-demo.avi", + cv::VideoWriter::fourcc('M', 'J', 'P', 'G'), 10, + cv::Size(frame_width, frame_height)); + + + ////////////////////////////////////////////////////////////////////////////// + //REPLACE THIS BLOCK OF CODE WITH YOUR OWN EGO AGENT MODEL + //Get ego agent initial state and attributes to set into initial conditions + json example_ego_agent_log = get_ego_log("examples/ubc_roundabout_ego_agent_log.json"); + EgoAgentInput ego_agent_struct = get_ego_agents(example_ego_agent_log); + std::vector current_ego_states = ego_agent_struct.ego_states; + std::vector all_ego_attributes = ego_agent_struct.ego_attributes; + ////////////////////////////////////////////////////////////////////////////// + const int NUMBER_EGO_AGENTS = (int)all_ego_attributes.size(); + + // construct request for initializing the simulation (placing NPCs on the map) + invertedai::InitializeRequest init_req(invertedai::read_file("examples/conditional_initialize_body.json")); + // set the location + init_req.set_location(location); + // set the number of agents + init_req.set_num_agents_to_spawn(agent_num); + std::vector> current_ego_states_history = {current_ego_states}; + init_req.set_states_history(current_ego_states_history); + init_req.set_agent_attributes(all_ego_attributes); + + // get the response of simulation initialization + invertedai::InitializeResponse init_res = + invertedai::initialize(init_req, &session); + std::vector current_agent_states = init_res.agent_states(); //Should be NPC states only + split_npc_and_ego_states(current_agent_states,NUMBER_EGO_AGENTS); + + // construct request for stepping the simulation (driving the NPCs) + invertedai::DriveRequest drive_req( + invertedai::read_file("examples/drive_body.json")); + drive_req.set_location(location); + drive_req.update(init_res); + + //Acquire the next drive states based on the current NPC and ego agent states + invertedai::DriveResponse next_drive_res = invertedai::drive(drive_req, &session); + std::vector next_agent_states = next_drive_res.agent_states(); //Should be NPC states only + split_npc_and_ego_states(next_agent_states,NUMBER_EGO_AGENTS); + drive_req.update(next_drive_res); + + for (int t = 0; t < (int)timestep; t++) { + std::vector drive_states = drive_req.agent_states(); + std::vector previous_ego_states = split_npc_and_ego_states(drive_states,NUMBER_EGO_AGENTS); + std::vector estimated_ego_states = extrapolate_ego_agents(current_ego_states, previous_ego_states); + estimated_ego_states.insert(estimated_ego_states.end(), drive_states.begin(), drive_states.end()); //Concatenate ego states to NPC states + drive_req.set_agent_states(estimated_ego_states); + + //Get the future IAI agent states while stepping through the higher FPS timesteps + std::future drive_res = std::async (std::launch::async,invertedai::drive,std::ref(drive_req),&session); + std::vector> interpolated_states = linear_interpolation(current_agent_states,next_agent_states,NUM_INTERP_STEPS); + + ////////////////////////////////////////////////////////////////////////////// + //REPLACE THIS BLOCK OF CODE WITH YOUR OWN EGO AGENT MODEL + //Example acquiring the ego agent states between current and next IAI timestep + invertedai::AgentState next_ego_state; + next_ego_state.x = example_ego_agent_log.at(t+1)["cars"][0]["x"]; + next_ego_state.y = example_ego_agent_log.at(t+1)["cars"][0]["y"]; + next_ego_state.orientation = example_ego_agent_log.at(t+1)["cars"][0]["orientation"]; + std::vector next_ego_states = {next_ego_state}; + std::vector> ego_states = linear_interpolation(current_ego_states,next_ego_states,NUM_INTERP_STEPS); + + //Show the time steps between IAI time steps + std::cout << "Time step: " << t << std::endl; + for(int i = 0; i < NUM_INTERP_STEPS; i++) { + std::cout << "Sub time step: " << i << std::endl; + for(int j = 0; j < (int)ego_states[i].size(); j++){ + std::vector timestep_states = ego_states[i]; + std::cout << "Ego Agent State " << j << ": [x: " << timestep_states[j].x << ", y: " << timestep_states[j].y << ", orientation: " << timestep_states[j].orientation << "]" << std::endl; + } + for(int j = 0; j < (int)interpolated_states[i].size(); j++){ + std::vector timestep_states = interpolated_states[i]; + std::cout << "NPC Agent State " << j << ": [x: " << timestep_states[j].x << ", y: " << timestep_states[j].y << ", orientation: " << timestep_states[j].orientation << "]" << std::endl; + } + } + + for(int ag = 0; ag < (int)current_ego_states.size(); ag++) { + double speed = sqrt(std::pow(next_ego_states[ag].x-current_ego_states[ag].x,2) + std::pow(next_ego_states[ag].y-current_ego_states[ag].y,2))*IAI_FPS; + current_ego_states[ag] = next_ego_states[ag]; + current_ego_states[ag].speed = speed; + } + ////////////////////////////////////////////////////////////////////////////// + + current_agent_states = next_agent_states; + next_drive_res = drive_res.get(); + + auto image = cv::imdecode(next_drive_res.birdview(), cv::IMREAD_COLOR); + cv::cvtColor(image, image, cv::COLOR_BGR2RGB); + video.write(image); + + next_agent_states = next_drive_res.agent_states(); + next_agent_states.pop_back(); //Remove the ego agent from these states + drive_req.update(next_drive_res); + + } + video.release(); + + } catch (std::exception const &e) { + std::cerr << "Error: " << e.what() << std::endl; + return EXIT_FAILURE; + + } + return EXIT_SUCCESS; + +} diff --git a/invertedai_cpp/examples/initialize_body.json b/invertedai_cpp/examples/initialize_body.json index e9f46ae..460ea84 100755 --- a/invertedai_cpp/examples/initialize_body.json +++ b/invertedai_cpp/examples/initialize_body.json @@ -1,5 +1,5 @@ { - "location": "canada:vancouver:ubc_roundabout", + "location": "iai:ubc_roundabout", "num_agents_to_spawn": 10, "states_history": null, "agent_attributes": null, diff --git a/invertedai_cpp/examples/location_info_body.json b/invertedai_cpp/examples/location_info_body.json index 3df61e9..5adac13 100644 --- a/invertedai_cpp/examples/location_info_body.json +++ b/invertedai_cpp/examples/location_info_body.json @@ -1,5 +1,5 @@ { - "location": "canada:vancouver:ubc_roundabout", + "location": "iai:ubc_roundabout", "include_map_source": false, "rendering_fov": null, "rendering_center": null diff --git a/invertedai_cpp/examples/ubc_roundabout_ego_agent_log.json b/invertedai_cpp/examples/ubc_roundabout_ego_agent_log.json new file mode 100644 index 0000000..bb07aab --- /dev/null +++ b/invertedai_cpp/examples/ubc_roundabout_ego_agent_log.json @@ -0,0 +1,1134 @@ +[ + { + "message": "initialized cars", + "json": { + "agent_states": [ + [ + -4.79, + -18.32, + -1.35, + 6.62 + ] + ], + "agent_attributes": [ + [ + 4.4, + 1.94, + 1.41 + ] + ] + } + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -4.65, + "y": -18.96, + "orientation": -1.36 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -4.52, + "y": -19.61, + "orientation": -1.37 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -4.37, + "y": -20.25, + "orientation": -1.35 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -4.18, + "y": -20.88, + "orientation": -1.32 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -3.95, + "y": -21.5, + "orientation": -1.27 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -3.7, + "y": -22.12, + "orientation": -1.23 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -3.44, + "y": -22.73, + "orientation": -1.2 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -3.16, + "y": -23.34, + "orientation": -1.17 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -2.84, + "y": -23.93, + "orientation": -1.12 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -2.5, + "y": -24.48, + "orientation": -1.07 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -2.16, + "y": -25.03, + "orientation": -1.05 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -1.81, + "y": -25.58, + "orientation": -1.03 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -1.45, + "y": -26.12, + "orientation": -1.01 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -1.08, + "y": -26.65, + "orientation": -0.99 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -0.7, + "y": -27.17, + "orientation": -0.97 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": -0.3, + "y": -27.68, + "orientation": -0.94 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 0.12, + "y": -28.18, + "orientation": -0.91 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 0.56, + "y": -28.67, + "orientation": -0.87 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 1.02, + "y": -29.15, + "orientation": -0.84 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 1.49, + "y": -29.63, + "orientation": -0.82 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 1.98, + "y": -30.1, + "orientation": -0.79 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 2.49, + "y": -30.55, + "orientation": -0.76 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 3.03, + "y": -30.97, + "orientation": -0.71 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 3.58, + "y": -31.37, + "orientation": -0.67 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 4.15, + "y": -31.74, + "orientation": -0.63 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 4.73, + "y": -32.1, + "orientation": -0.59 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 5.31, + "y": -32.46, + "orientation": -0.57 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 5.89, + "y": -32.82, + "orientation": -0.56 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 6.48, + "y": -33.18, + "orientation": -0.55 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 7.1, + "y": -33.54, + "orientation": -0.54 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 7.73, + "y": -33.89, + "orientation": -0.53 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 8.37, + "y": -34.23, + "orientation": -0.51 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 9.01, + "y": -34.56, + "orientation": -0.5 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 9.61, + "y": -34.87, + "orientation": -0.49 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 10.21, + "y": -35.18, + "orientation": -0.48 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 10.78, + "y": -35.44, + "orientation": -0.46 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 11.34, + "y": -35.7, + "orientation": -0.45 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 11.88, + "y": -35.95, + "orientation": -0.44 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 12.42, + "y": -36.21, + "orientation": -0.44 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 12.97, + "y": -36.47, + "orientation": -0.44 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 13.54, + "y": -36.75, + "orientation": -0.45 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 14.13, + "y": -37.05, + "orientation": -0.46 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 14.74, + "y": -37.37, + "orientation": -0.47 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 15.36, + "y": -37.71, + "orientation": -0.48 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 16, + "y": -38.07, + "orientation": -0.5 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 16.63, + "y": -38.44, + "orientation": -0.52 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 17.22, + "y": -38.79, + "orientation": -0.53 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 17.81, + "y": -39.17, + "orientation": -0.55 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 18.43, + "y": -39.58, + "orientation": -0.56 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 19.07, + "y": -40.01, + "orientation": -0.58 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 19.7, + "y": -40.45, + "orientation": -0.6 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 20.31, + "y": -40.9, + "orientation": -0.62 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 20.91, + "y": -41.36, + "orientation": -0.64 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 21.46, + "y": -41.8, + "orientation": -0.66 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 21.95, + "y": -42.24, + "orientation": -0.69 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 22.41, + "y": -42.64, + "orientation": -0.7 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 22.9, + "y": -43.06, + "orientation": -0.7 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 23.43, + "y": -43.51, + "orientation": -0.7 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 23.94, + "y": -43.98, + "orientation": -0.72 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 24.35, + "y": -44.47, + "orientation": -0.79 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 24.77, + "y": -44.9, + "orientation": -0.79 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 25.19, + "y": -45.34, + "orientation": -0.8 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 25.63, + "y": -45.79, + "orientation": -0.8 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 26.07, + "y": -46.21, + "orientation": -0.79 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 26.45, + "y": -46.62, + "orientation": -0.8 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 26.8, + "y": -47.01, + "orientation": -0.81 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 27.16, + "y": -47.39, + "orientation": -0.81 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 27.53, + "y": -47.78, + "orientation": -0.81 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 27.92, + "y": -48.21, + "orientation": -0.82 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 28.33, + "y": -48.67, + "orientation": -0.83 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 28.77, + "y": -49.17, + "orientation": -0.84 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 29.22, + "y": -49.7, + "orientation": -0.85 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 29.66, + "y": -50.23, + "orientation": -0.86 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 30.07, + "y": -50.72, + "orientation": -0.87 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 30.44, + "y": -51.19, + "orientation": -0.88 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 30.77, + "y": -51.63, + "orientation": -0.9 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 31.1, + "y": -52.07, + "orientation": -0.91 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 31.45, + "y": -52.54, + "orientation": -0.92 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 31.82, + "y": -53.05, + "orientation": -0.93 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 32.2, + "y": -53.59, + "orientation": -0.94 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 32.6, + "y": -54.16, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 33, + "y": -54.73, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 33.4, + "y": -55.28, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 33.77, + "y": -55.8, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 34.14, + "y": -56.32, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 34.51, + "y": -56.84, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 34.91, + "y": -57.4, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 35.33, + "y": -58, + "orientation": -0.95 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 35.75, + "y": -58.62, + "orientation": -0.96 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 36.16, + "y": -59.24, + "orientation": -0.97 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 36.55, + "y": -59.83, + "orientation": -0.98 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 36.95, + "y": -60.43, + "orientation": -0.98 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 37.38, + "y": -61.04, + "orientation": -0.97 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 37.8, + "y": -61.66, + "orientation": -0.97 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 38.18, + "y": -62.26, + "orientation": -0.99 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 38.56, + "y": -62.84, + "orientation": -0.99 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 38.94, + "y": -63.4, + "orientation": -0.98 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 39.3, + "y": -63.93, + "orientation": -0.97 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 39.65, + "y": -64.46, + "orientation": -0.97 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 39.98, + "y": -64.97, + "orientation": -0.98 + } + ] + }, + { + "message": "new car positions after DRIVE call", + "cars": [ + { + "id": 0, + "x": 40.34, + "y": -65.51, + "orientation": -0.98 + } + ] + } +] + diff --git a/invertedai_cpp/invertedai/drive_request.cc b/invertedai_cpp/invertedai/drive_request.cc index c115123..d6e55c9 100644 --- a/invertedai_cpp/invertedai/drive_request.cc +++ b/invertedai_cpp/invertedai/drive_request.cc @@ -47,6 +47,10 @@ DriveRequest::DriveRequest(const std::string &body_str) { this->body_json_["random_seed"].is_number_integer() ? std::optional{this->body_json_["random_seed"].get()} : std::nullopt; + this->model_version_ = this->body_json_["model_version"].is_null() + ? std::nullopt + : std::optional{ + this->body_json_["model_version"]}; } void DriveRequest::refresh_body_json_() { @@ -96,6 +100,11 @@ void DriveRequest::refresh_body_json_() { } else { this->body_json_["random_seed"] = nullptr; } + if (this->model_version_.has_value()) { + this->body_json_["model_version"] = this->model_version_.value(); + } else { + this->body_json_["model_version"] = nullptr; + } } void DriveRequest::update(const InitializeResponse &init_res) { @@ -145,6 +154,11 @@ DriveRequest::rendering_center() const { return this->rendering_center_; } +std::optional +DriveRequest::model_version() const { + return this->model_version_; +} + std::optional DriveRequest::random_seed() const { return this->random_seed_; } @@ -194,4 +208,8 @@ void DriveRequest::set_random_seed(std::optional random_seed) { this->random_seed_ = random_seed; } +void DriveRequest::set_model_version(std::optional model_version) { + this->model_version_ = model_version; +} + } // namespace invertedai diff --git a/invertedai_cpp/invertedai/drive_request.h b/invertedai_cpp/invertedai/drive_request.h index 171eadb..188dd07 100644 --- a/invertedai_cpp/invertedai/drive_request.h +++ b/invertedai_cpp/invertedai/drive_request.h @@ -27,6 +27,7 @@ class DriveRequest { std::optional random_seed_; std::optional rendering_fov_; std::optional> rendering_center_; + std::optional model_version_; json body_json_; void refresh_body_json_(); @@ -94,6 +95,10 @@ class DriveRequest { * Get random_seed. */ std::optional random_seed() const; + /** + * Get model version. + */ + std::optional model_version() const; // setters /** @@ -149,6 +154,10 @@ class DriveRequest { * for reproducibility. */ void set_random_seed(std::optional random_seed); + /** + * Set model version. If None is passed which is by default, the best model will be used. + */ + void set_model_version(std::optional model_version); }; } // namespace invertedai diff --git a/invertedai_cpp/invertedai/drive_response.cc b/invertedai_cpp/invertedai/drive_response.cc index 5158b7d..2064d2c 100644 --- a/invertedai_cpp/invertedai/drive_response.cc +++ b/invertedai_cpp/invertedai/drive_response.cc @@ -35,6 +35,8 @@ DriveResponse::DriveResponse(const std::string &body_str) { element[2]}; this->infraction_indicators_.push_back(infraction_indicator); } + this->model_version_.clear(); + this->model_version_ = body_json_["model_version"]; } void DriveResponse::refresh_body_json_() { @@ -69,6 +71,8 @@ void DriveResponse::refresh_body_json_() { infraction_indicator.wrong_way}; this->body_json_["infraction_indicators"].push_back(element); } + this->model_version_.clear(); + this->model_version_ = body_json_["model_version"]; } std::string DriveResponse::body_str() { @@ -96,6 +100,11 @@ std::vector DriveResponse::infraction_indicators() const { return this->infraction_indicators_; } +std::string +DriveResponse::model_version() const { + return this->model_version_; +} + void DriveResponse::set_agent_states( const std::vector &agent_states) { this->agent_states_ = agent_states; diff --git a/invertedai_cpp/invertedai/drive_response.h b/invertedai_cpp/invertedai/drive_response.h index c8f09a5..cb2a457 100644 --- a/invertedai_cpp/invertedai/drive_response.h +++ b/invertedai_cpp/invertedai/drive_response.h @@ -18,6 +18,7 @@ class DriveResponse { std::vector> recurrent_states_; std::vector birdview_; std::vector infraction_indicators_; + std::string model_version_; json body_json_; void refresh_body_json_(); @@ -54,6 +55,10 @@ class DriveResponse { * If get_infractions was set, they are returned here. */ std::vector infraction_indicators() const; + /** + * Get model version. + */ + std::string model_version() const; // setters /** diff --git a/invertedai_cpp/invertedai/initialize_request.cc b/invertedai_cpp/invertedai/initialize_request.cc index 846382a..8a66453 100644 --- a/invertedai_cpp/invertedai/initialize_request.cc +++ b/invertedai_cpp/invertedai/initialize_request.cc @@ -67,6 +67,10 @@ InitializeRequest::InitializeRequest(const std::string &body_str) { this->body_json_["random_seed"].is_number_integer() ? std::optional{this->body_json_["random_seed"].get()} : std::nullopt; + this->model_version_ = this->body_json_["model_version"].is_null() + ? std::nullopt + : std::optional{ + this->body_json_["model_version"]}; } void InitializeRequest::refresh_body_json_() { @@ -118,6 +122,11 @@ void InitializeRequest::refresh_body_json_() { } else { this->body_json_["random_seed"] = nullptr; } + if (this->model_version_.has_value()) { + this->body_json_["model_version"] = this->model_version_.value(); + } else { + this->body_json_["model_version"] = nullptr; + } }; std::string InitializeRequest::body_str() { @@ -172,6 +181,11 @@ std::optional InitializeRequest::random_seed() const { return this->random_seed_; } +std::optional +InitializeRequest::model_version() const { + return this->model_version_; +} + void InitializeRequest::set_location(const std::string &location) { this->location_ = location; } @@ -227,4 +241,8 @@ void InitializeRequest::set_random_seed(std::optional random_seed) { this->random_seed_ = random_seed; } +void InitializeRequest::set_model_version(std::optional model_version) { + this->model_version_ = model_version; +} + } // namespace invertedai diff --git a/invertedai_cpp/invertedai/initialize_request.h b/invertedai_cpp/invertedai/initialize_request.h index d256db0..e4a2572 100644 --- a/invertedai_cpp/invertedai/initialize_request.h +++ b/invertedai_cpp/invertedai/initialize_request.h @@ -25,6 +25,7 @@ class InitializeRequest { bool get_infractions_; std::optional agent_count_; std::optional random_seed_; + std::optional model_version_; json body_json_; void refresh_body_json_(); @@ -90,6 +91,10 @@ class InitializeRequest { * for reproducibility. */ std::optional random_seed() const; + /** + * Get model version. + */ + std::optional model_version() const; // setters /** @@ -158,6 +163,10 @@ class InitializeRequest { * for reproducibility. */ void set_random_seed(std::optional random_seed); + /** + * Set model version. If None is passed which is by default, the best model will be used. + */ + void set_model_version(std::optional model_version); }; } // namespace invertedai diff --git a/invertedai_cpp/invertedai/initialize_response.cc b/invertedai_cpp/invertedai/initialize_response.cc index c06e0de..3044712 100644 --- a/invertedai_cpp/invertedai/initialize_response.cc +++ b/invertedai_cpp/invertedai/initialize_response.cc @@ -36,6 +36,8 @@ InitializeResponse::InitializeResponse(const std::string &body_str) { element[2]}; this->infraction_indicators_.push_back(infraction_indicator); } + this->model_version_.clear(); + this->model_version_ = body_json_["model_version"]; } void InitializeResponse::refresh_body_json_() { @@ -72,6 +74,8 @@ void InitializeResponse::refresh_body_json_() { infraction_indicator.wrong_way}; this->body_json_["infraction_indicators"].push_back(element); } + this->model_version_.clear(); + this->model_version_ = body_json_["model_version"]; } std::string InitializeResponse::body_str() { @@ -100,6 +104,11 @@ InitializeResponse::infraction_indicators() const { return this->infraction_indicators_; } +std::string +InitializeResponse::model_version() const { + return this->model_version_; +} + void InitializeResponse::set_agent_states( const std::vector &agent_states) { this->agent_states_ = agent_states; diff --git a/invertedai_cpp/invertedai/initialize_response.h b/invertedai_cpp/invertedai/initialize_response.h index 2a8b6d3..0fc0690 100644 --- a/invertedai_cpp/invertedai/initialize_response.h +++ b/invertedai_cpp/invertedai/initialize_response.h @@ -17,6 +17,7 @@ class InitializeResponse { std::vector> recurrent_states_; std::vector birdview_; std::vector infraction_indicators_; + std::string model_version_; json body_json_; void refresh_body_json_(); @@ -53,6 +54,10 @@ class InitializeResponse { * If get_infractions was set, they are returned here. */ std::vector infraction_indicators() const; + /** + * Get model version. + */ + std::string model_version() const; // setters /** diff --git a/invertedai_cpp/invertedai/session.cc b/invertedai_cpp/invertedai/session.cc index 2f7a142..9492643 100644 --- a/invertedai_cpp/invertedai/session.cc +++ b/invertedai_cpp/invertedai/session.cc @@ -36,19 +36,46 @@ void Session::set_api_key(const std::string &api_key) { } void Session::connect() { - if (!SSL_set_tlsext_host_name(this->stream_.native_handle(), this->host_)) { + auto const results = this->resolver_.resolve(this->host_, this->port_); + if (!local_mode){ + if (!SSL_set_tlsext_host_name(this->ssl_stream_.native_handle(), this->host_)) { beast::error_code ec{static_cast(::ERR_get_error()), net::error::get_ssl_category()}; throw beast::system_error{ec}; } - auto const results = this->resolver_.resolve(this->host_, this->port_); - beast::get_lowest_layer(this->stream_).connect(results); - this->stream_.handshake(ssl::stream_base::client); + beast::get_lowest_layer(this->ssl_stream_).connect(results); + this->ssl_stream_.handshake(ssl::stream_base::client); + } + else{ + this->tcp_stream_.connect(results); + } } void Session::shutdown() { beast::error_code ec; - this->stream_.shutdown(ec); + if (local_mode){ + // Shutdown the connection + this->tcp_stream_.socket().shutdown(tcp::socket::shutdown_both, ec); + if(ec) { + std::cerr << "Shutdown error: " << ec.message() << "\n"; + throw beast::system_error{ec}; + } + + // Close the socket + this->tcp_stream_.socket().close(ec); + if(ec) { + std::cerr << "Close error: " << ec.message() << "\n"; + throw beast::system_error{ec}; + } + } + else{ + this->ssl_stream_.shutdown(ec); + if(ec) { + std::cerr << "Shutdown error: " << ec.message() << "\n"; + throw beast::system_error{ec}; + } + + } if (ec == net::error::eof) { ec = {}; } @@ -60,7 +87,7 @@ void Session::shutdown() { const std::string Session::request(const std::string &mode, const std::string &body_str, const std::string &url_query_string) { - std::string target = "/v0/aws/m1/" + mode + url_query_string; + std::string target = subdomain + mode + url_query_string; http::request req{ mode == "location_info" ? http::verb::get : http::verb::post, @@ -77,11 +104,25 @@ const std::string Session::request(const std::string &mode, req.body() = body_str; req.prepare_payload(); - http::write(this->stream_, req); + if (local_mode){ + http::write(this->tcp_stream_, req); + + } + else { + http::write(this->ssl_stream_, req); + } + beast::flat_buffer buffer; http::response res; beast::error_code ec; - http::read(this->stream_, buffer, res, ec); + if (local_mode){ + http::read(this->tcp_stream_, buffer, res, ec); + + } + else{ + http::read(this->ssl_stream_, buffer, res, ec); + + } if (!(res.result() == http::status::ok)) { throw std::runtime_error( "response status: " + std::to_string(res.result_int()) + "\nbody:\n" + diff --git a/invertedai_cpp/invertedai/session.h b/invertedai_cpp/invertedai/session.h index a68db74..fec96cc 100644 --- a/invertedai_cpp/invertedai/session.h +++ b/invertedai_cpp/invertedai/session.h @@ -18,16 +18,20 @@ class Session { private: std::string api_key_; tcp::resolver resolver_; - beast::ssl_stream stream_; + beast::ssl_stream ssl_stream_ ; + beast::tcp_stream tcp_stream_ ; const char *debug_mode = std::getenv("DEBUG"); + const char *iai_dev = std::getenv("IAI_DEV"); + const bool local_mode = iai_dev && (std::string(iai_dev) == "1" || std::string(iai_dev) == "True"); public: - const char *host_ = "api.inverted.ai"; - const char *port_ = "443"; + const char* host_ = local_mode ? "localhost" : "api.inverted.ai"; + const char* port_ = local_mode ? "8000" : "443"; + const char *subdomain = local_mode ? "/" : "/v0/aws/m1/";; const int version_ = 11; explicit Session(net::io_context &ioc, ssl::context &ctx) - : resolver_(ioc), stream_(ioc, ctx){}; + : resolver_(ioc), ssl_stream_(ioc, ctx), tcp_stream_(ioc){}; /** * Set your own api key here.