Skip to content

Commit

Permalink
Rename internal base model attribute model_version to api_model_versi…
Browse files Browse the repository at this point in the history
…on to avoid violating BaseModel protected namespace
  • Loading branch information
Ruishenl committed Jan 26, 2024
1 parent 454d218 commit 23e7ad6
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 18 deletions.
15 changes: 8 additions & 7 deletions invertedai/api/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class DriveResponse(BaseModel):
] #: For each agent, indicates whether the predicted state is inside supported area.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map, each element corresponds to one light group.
api_model_version: str # Model version used for this API call


@validate_arguments
Expand All @@ -60,7 +61,7 @@ def drive(
rendering_fov: Optional[float] = None,
get_infractions: bool = False,
random_seed: Optional[int] = None,
model_version: Optional[str] = None
api_model_version: Optional[str] = None
) -> DriveResponse:
"""
Parameters
Expand Down Expand Up @@ -114,7 +115,7 @@ def drive(
random_seed:
Controls the stochastic aspects of agent behavior for reproducibility.
model_version:
api_model_version:
Optionally specify the version of the model. If None is passed which is by default, the best model will be used.
See Also
--------
Expand Down Expand Up @@ -160,7 +161,7 @@ def _tolist(input_data: List):
random_seed=random_seed,
rendering_center=rendering_center,
rendering_fov=rendering_fov,
model_version=model_version
model_version=api_model_version
)
start = time.time()
timeout = TIMEOUT
Expand All @@ -186,7 +187,7 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"],
api_model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
Expand Down Expand Up @@ -219,7 +220,7 @@ async def async_drive(
rendering_fov: Optional[float] = None,
get_infractions: bool = False,
random_seed: Optional[int] = None,
model_version: Optional[str] = None
api_model_versioin: Optional[str] = None
) -> DriveResponse:
"""
A light async version of :func:`drive`
Expand All @@ -245,7 +246,7 @@ def _tolist(input_data: List):
random_seed=random_seed,
rendering_center=rendering_center,
rendering_fov=rendering_fov,
model_version=model_version,
model_version=api_model_versioin,
)
response = await iai.session.async_request(model="drive", data=model_inputs)

Expand All @@ -266,7 +267,7 @@ def _tolist(input_data: List):
if response["infraction_indicators"]
else [],
is_inside_supported_area=response["is_inside_supported_area"],
model_version=response["model_version"],
api_model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
Expand Down
16 changes: 8 additions & 8 deletions invertedai/api/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class InitializeResponse(BaseModel):
] #: If `get_infractions` was set, they are returned here.
traffic_lights_states: Optional[TrafficLightStatesDict] #: Traffic light states for the full map, each key-value pair corresponds to one particular traffic light.
light_recurrent_states: Optional[LightRecurrentStates] #: Light recurrent states for the full map, each element corresponds to one light group.
model_version: str # Model version used for this API call
api_model_version: str # Model version used for this API call


@validate_arguments
Expand All @@ -61,7 +61,7 @@ def initialize(
get_infractions: bool = False,
agent_count: Optional[int] = None,
random_seed: Optional[int] = None,
model_version: Optional[str] = None # Model version used for this API call
api_model_version: Optional[str] = None # Model version used for this API call
) -> InitializeResponse:
"""
Initializes a simulation in a given location, using a combination of **user-defined** and **sampled** agents.
Expand Down Expand Up @@ -116,7 +116,7 @@ def initialize(
random_seed:
Controls the stochastic aspects of initialization for reproducibility.
model_version:
api_model_version:
Optionally specify the version of the model. If None is passed which is by default, the best model will be used.
See Also
Expand Down Expand Up @@ -160,7 +160,7 @@ def initialize(
location_of_interest=location_of_interest,
get_infractions=get_infractions,
random_seed=random_seed,
model_version=model_version
model_version=api_model_version
)
start = time.time()
timeout = TIMEOUT
Expand All @@ -186,7 +186,7 @@ def initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"],
api_model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
Expand Down Expand Up @@ -217,7 +217,7 @@ async def async_initialize(
get_infractions: bool = False,
agent_count: Optional[int] = None,
random_seed: Optional[int] = None,
model_version: Optional[str] = None
api_model_version: Optional[str] = None
) -> InitializeResponse:
"""
The async version of :func:`initialize`
Expand All @@ -237,7 +237,7 @@ async def async_initialize(
location_of_interest=location_of_interest,
get_infractions=get_infractions,
random_seed=random_seed,
model_version=model_version
model_version=api_model_version
)

response = await iai.session.async_request(model="initialize", data=model_inputs)
Expand Down Expand Up @@ -265,7 +265,7 @@ async def async_initialize(
]
if response["infraction_indicators"]
else [],
model_version=response["model_version"],
api_model_version=response["model_version"],
traffic_lights_states=response["traffic_lights_states"]
if response["traffic_lights_states"] is not None
else None,
Expand Down
4 changes: 2 additions & 2 deletions invertedai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ async def reinit(reinitialize_agent_state, reinitialize_agent_attrs, area_center
file_path = f"{birdview_path}-{(area_center.x, area_center.y)}.jpg"
response.birdview.decode_and_save(file_path)

return valid_agent_state, valid_agent_attrs, valid_agent_rs, response.model_version
return valid_agent_state, valid_agent_attrs, valid_agent_rs, response.api_model_version

stride = initialize_fov / 2

Expand Down Expand Up @@ -404,7 +404,7 @@ async def reinit(reinitialize_agent_state, reinitialize_agent_attrs, area_center
recurrent_states=new_recurrent_states,
agent_states=new_agent_state,
agent_attributes=new_attributes,
model_version=model_version
api_model_version=model_version
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "invertedai"
version = "0.0.14dev0"
version = "0.0.14dev1"
description = "Client SDK for InvertedAI"
authors = ["Inverted AI <info@inverted.ai>"]
readme = "README.md"
Expand Down

0 comments on commit 23e7ad6

Please sign in to comment.