Skip to content

Commit

Permalink
Merge pull request #142 from inverted-ai/model-version
Browse files Browse the repository at this point in the history
Handle model_version
  • Loading branch information
Ruishenl committed Nov 24, 2023
2 parents 17dcad5 + b5700e7 commit 6a37db5
Show file tree
Hide file tree
Showing 12 changed files with 160 additions and 15 deletions.
14 changes: 11 additions & 3 deletions invertedai/api/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DriveResponse(BaseModel):
is_inside_supported_area: List[
bool
] #: For each agent, indicates whether the predicted state is inside supported area.

model_version: str # Model version used for this API call

@validate_arguments
def drive(
Expand All @@ -55,6 +55,7 @@ def drive(
rendering_fov: Optional[float] = None,
get_infractions: bool = False,
random_seed: Optional[int] = None,
model_version: Optional[str] = None
) -> DriveResponse:
"""
Parameters
Expand Down Expand Up @@ -99,6 +100,8 @@ def drive(
random_seed:
Controls the stochastic aspects of agent behavior for reproducibility.
model_version:
Optionally specify the version of the model. If None is passed which is by default, the best model will be used.
See Also
--------
:func:`initialize`
Expand Down Expand Up @@ -144,7 +147,8 @@ def _tolist(input_data: List):
get_infractions=get_infractions,
random_seed=random_seed,
rendering_center=rendering_center,
rendering_fov=rendering_fov
rendering_fov=rendering_fov,
model_version=model_version
)
start = time.time()
timeout = TIMEOUT
Expand All @@ -170,6 +174,7 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"]
)

return response
Expand All @@ -193,6 +198,7 @@ async def async_drive(
rendering_fov: Optional[float] = None,
get_infractions: bool = False,
random_seed: Optional[int] = None,
model_version: Optional[str] = None
) -> DriveResponse:
"""
A light async version of :func:`drive`
Expand All @@ -216,7 +222,8 @@ def _tolist(input_data: List):
get_infractions=get_infractions,
random_seed=random_seed,
rendering_center=rendering_center,
rendering_fov=rendering_fov
rendering_fov=rendering_fov,
model_version=model_version,
)
response = await iai.session.async_request(model="drive", data=model_inputs)

Expand All @@ -237,6 +244,7 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"]
)

return response
10 changes: 10 additions & 0 deletions invertedai/api/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class InitializeResponse(BaseModel):
infractions: Optional[
List[InfractionIndicators]
] #: If `get_infractions` was set, they are returned here.
model_version: str # Model version used for this API call


@validate_arguments
Expand All @@ -56,6 +57,7 @@ def initialize(
get_infractions: bool = False,
agent_count: Optional[int] = None,
random_seed: Optional[int] = None,
model_version: Optional[str] = None # Model version used for this API call
) -> InitializeResponse:
"""
Initializes a simulation in a given location.
Expand Down Expand Up @@ -104,6 +106,9 @@ def initialize(
random_seed:
Controls the stochastic aspects of initialization for reproducibility.
model_version:
Optionally specify the version of the model. If None is passed which is by default, the best model will be used.
See Also
--------
:func:`drive`
Expand Down Expand Up @@ -152,6 +157,7 @@ def initialize(
location_of_interest=location_of_interest,
get_infractions=get_infractions,
random_seed=random_seed,
model_version=model_version
)
start = time.time()
timeout = TIMEOUT
Expand Down Expand Up @@ -182,6 +188,7 @@ def initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"]
)
return response
except TryAgain as e:
Expand All @@ -203,6 +210,7 @@ async def async_initialize(
get_infractions: bool = False,
agent_count: Optional[int] = None,
random_seed: Optional[int] = None,
model_version: Optional[str] = None
) -> InitializeResponse:
"""
The async version of :func:`initialize`
Expand All @@ -228,6 +236,7 @@ async def async_initialize(
location_of_interest=location_of_interest,
get_infractions=get_infractions,
random_seed=random_seed,
model_version=model_version
)

response = await iai.session.async_request(model="initialize", data=model_inputs)
Expand Down Expand Up @@ -255,5 +264,6 @@ async def async_initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"]
)
return response
18 changes: 18 additions & 0 deletions invertedai_cpp/invertedai/drive_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ DriveRequest::DriveRequest(const std::string &body_str) {
this->body_json_["random_seed"].is_number_integer()
? std::optional<int>{this->body_json_["random_seed"].get<int>()}
: std::nullopt;
this->model_version_ = this->body_json_["model_version"].is_null()
? std::nullopt
: std::optional<std::string>{
this->body_json_["model_version"]};
}

void DriveRequest::refresh_body_json_() {
Expand Down Expand Up @@ -96,6 +100,11 @@ void DriveRequest::refresh_body_json_() {
} else {
this->body_json_["random_seed"] = nullptr;
}
if (this->model_version_.has_value()) {
this->body_json_["model_version"] = this->model_version_.value();
} else {
this->body_json_["model_version"] = nullptr;
}
}

void DriveRequest::update(const InitializeResponse &init_res) {
Expand Down Expand Up @@ -145,6 +154,11 @@ DriveRequest::rendering_center() const {
return this->rendering_center_;
}

std::optional<std::string>
DriveRequest::model_version() const {
return this->model_version_;
}

std::optional<int> DriveRequest::random_seed() const {
return this->random_seed_;
}
Expand Down Expand Up @@ -194,4 +208,8 @@ void DriveRequest::set_random_seed(std::optional<int> random_seed) {
this->random_seed_ = random_seed;
}

void DriveRequest::set_model_version(std::optional<std::string> model_version) {
this->model_version_ = model_version;
}

} // namespace invertedai
9 changes: 9 additions & 0 deletions invertedai_cpp/invertedai/drive_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class DriveRequest {
std::optional<int> random_seed_;
std::optional<double> rendering_fov_;
std::optional<std::pair<double, double>> rendering_center_;
std::optional<std::string> model_version_;
json body_json_;

void refresh_body_json_();
Expand Down Expand Up @@ -94,6 +95,10 @@ class DriveRequest {
* Get random_seed.
*/
std::optional<int> random_seed() const;
/**
* Get model version.
*/
std::optional<std::string> model_version() const;

// setters
/**
Expand Down Expand Up @@ -149,6 +154,10 @@ class DriveRequest {
* for reproducibility.
*/
void set_random_seed(std::optional<int> random_seed);
/**
* Set model version. If None is passed which is by default, the best model will be used.
*/
void set_model_version(std::optional<std::string> model_version);
};

} // namespace invertedai
Expand Down
9 changes: 9 additions & 0 deletions invertedai_cpp/invertedai/drive_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ DriveResponse::DriveResponse(const std::string &body_str) {
element[2]};
this->infraction_indicators_.push_back(infraction_indicator);
}
this->model_version_.clear();
this->model_version_ = body_json_["model_version"];
}

void DriveResponse::refresh_body_json_() {
Expand Down Expand Up @@ -69,6 +71,8 @@ void DriveResponse::refresh_body_json_() {
infraction_indicator.wrong_way};
this->body_json_["infraction_indicators"].push_back(element);
}
this->model_version_.clear();
this->model_version_ = body_json_["model_version"];
}

std::string DriveResponse::body_str() {
Expand Down Expand Up @@ -96,6 +100,11 @@ std::vector<InfractionIndicator> DriveResponse::infraction_indicators() const {
return this->infraction_indicators_;
}

std::string
DriveResponse::model_version() const {
return this->model_version_;
}

void DriveResponse::set_agent_states(
const std::vector<AgentState> &agent_states) {
this->agent_states_ = agent_states;
Expand Down
5 changes: 5 additions & 0 deletions invertedai_cpp/invertedai/drive_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DriveResponse {
std::vector<std::vector<double>> recurrent_states_;
std::vector<unsigned char> birdview_;
std::vector<InfractionIndicator> infraction_indicators_;
std::string model_version_;
json body_json_;

void refresh_body_json_();
Expand Down Expand Up @@ -54,6 +55,10 @@ class DriveResponse {
* If get_infractions was set, they are returned here.
*/
std::vector<InfractionIndicator> infraction_indicators() const;
/**
* Get model version.
*/
std::string model_version() const;

// setters
/**
Expand Down
18 changes: 18 additions & 0 deletions invertedai_cpp/invertedai/initialize_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ InitializeRequest::InitializeRequest(const std::string &body_str) {
this->body_json_["random_seed"].is_number_integer()
? std::optional<int>{this->body_json_["random_seed"].get<int>()}
: std::nullopt;
this->model_version_ = this->body_json_["model_version"].is_null()
? std::nullopt
: std::optional<std::string>{
this->body_json_["model_version"]};
}

void InitializeRequest::refresh_body_json_() {
Expand Down Expand Up @@ -118,6 +122,11 @@ void InitializeRequest::refresh_body_json_() {
} else {
this->body_json_["random_seed"] = nullptr;
}
if (this->model_version_.has_value()) {
this->body_json_["model_version"] = this->model_version_.value();
} else {
this->body_json_["model_version"] = nullptr;
}
};

std::string InitializeRequest::body_str() {
Expand Down Expand Up @@ -172,6 +181,11 @@ std::optional<int> InitializeRequest::random_seed() const {
return this->random_seed_;
}

std::optional<std::string>
InitializeRequest::model_version() const {
return this->model_version_;
}

void InitializeRequest::set_location(const std::string &location) {
this->location_ = location;
}
Expand Down Expand Up @@ -227,4 +241,8 @@ void InitializeRequest::set_random_seed(std::optional<int> random_seed) {
this->random_seed_ = random_seed;
}

void InitializeRequest::set_model_version(std::optional<std::string> model_version) {
this->model_version_ = model_version;
}

} // namespace invertedai
9 changes: 9 additions & 0 deletions invertedai_cpp/invertedai/initialize_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class InitializeRequest {
bool get_infractions_;
std::optional<int> agent_count_;
std::optional<int> random_seed_;
std::optional<std::string> model_version_;
json body_json_;

void refresh_body_json_();
Expand Down Expand Up @@ -90,6 +91,10 @@ class InitializeRequest {
* for reproducibility.
*/
std::optional<int> random_seed() const;
/**
* Get model version.
*/
std::optional<std::string> model_version() const;

// setters
/**
Expand Down Expand Up @@ -158,6 +163,10 @@ class InitializeRequest {
* for reproducibility.
*/
void set_random_seed(std::optional<int> random_seed);
/**
* Set model version. If None is passed which is by default, the best model will be used.
*/
void set_model_version(std::optional<std::string> model_version);
};

} // namespace invertedai
Expand Down
9 changes: 9 additions & 0 deletions invertedai_cpp/invertedai/initialize_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ InitializeResponse::InitializeResponse(const std::string &body_str) {
element[2]};
this->infraction_indicators_.push_back(infraction_indicator);
}
this->model_version_.clear();
this->model_version_ = body_json_["model_version"];
}

void InitializeResponse::refresh_body_json_() {
Expand Down Expand Up @@ -72,6 +74,8 @@ void InitializeResponse::refresh_body_json_() {
infraction_indicator.wrong_way};
this->body_json_["infraction_indicators"].push_back(element);
}
this->model_version_.clear();
this->model_version_ = body_json_["model_version"];
}

std::string InitializeResponse::body_str() {
Expand Down Expand Up @@ -100,6 +104,11 @@ InitializeResponse::infraction_indicators() const {
return this->infraction_indicators_;
}

std::string
InitializeResponse::model_version() const {
return this->model_version_;
}

void InitializeResponse::set_agent_states(
const std::vector<AgentState> &agent_states) {
this->agent_states_ = agent_states;
Expand Down
5 changes: 5 additions & 0 deletions invertedai_cpp/invertedai/initialize_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class InitializeResponse {
std::vector<std::vector<double>> recurrent_states_;
std::vector<unsigned char> birdview_;
std::vector<InfractionIndicator> infraction_indicators_;
std::string model_version_;
json body_json_;

void refresh_body_json_();
Expand Down Expand Up @@ -53,6 +54,10 @@ class InitializeResponse {
* If get_infractions was set, they are returned here.
*/
std::vector<InfractionIndicator> infraction_indicators() const;
/**
* Get model version.
*/
std::string model_version() const;

// setters
/**
Expand Down

0 comments on commit 6a37db5

Please sign in to comment.