Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions codex-rs/codex-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ anyhow = { workspace = true }
assert_matches = { workspace = true }
pretty_assertions = { workspace = true }
tokio-test = { workspace = true }
wiremock = { workspace = true }
reqwest = { workspace = true }

[lints]
workspace = true
1 change: 1 addition & 0 deletions codex-rs/codex-api/src/endpoint/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod chat;
pub mod compact;
pub mod models;
pub mod responses;
mod streaming;
216 changes: 216 additions & 0 deletions codex-rs/codex-api/src/endpoint/models.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
use crate::auth::AuthProvider;
use crate::auth::add_auth_headers;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
use std::sync::Arc;

pub struct ModelsClient<T: HttpTransport, A: AuthProvider> {
transport: T,
provider: Provider,
auth: A,
request_telemetry: Option<Arc<dyn RequestTelemetry>>,
}

impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
transport,
provider,
auth,
request_telemetry: None,
}
}

pub fn with_telemetry(mut self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
self.request_telemetry = request;
self
}

fn path(&self) -> &'static str {
"models"
}

pub async fn list_models(
&self,
client_version: &str,
extra_headers: HeaderMap,
) -> Result<ModelsResponse, ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::GET, self.path());
req.headers.extend(extra_headers.clone());

let separator = if req.url.contains('?') { '&' } else { '?' };
req.url = format!("{}{}client_version={client_version}", req.url, separator);

add_auth_headers(&self.auth, req)
};

let resp = run_with_request_telemetry(
self.provider.retry.to_policy(),
self.request_telemetry.clone(),
builder,
|req| self.transport.execute(req),
)
.await?;

serde_json::from_slice::<ModelsResponse>(&resp.body).map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
String::from_utf8_lossy(&resp.body)
))
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::provider::RetryConfig;
use crate::provider::WireApi;
use async_trait::async_trait;
use codex_client::Request;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
use http::HeaderMap;
use http::StatusCode;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;

#[derive(Clone, Default)]
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
body: Arc<ModelsResponse>,
}

#[async_trait]
impl HttpTransport for CapturingTransport {
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
*self.last_request.lock().unwrap() = Some(req);
let body = serde_json::to_vec(&*self.body).unwrap();
Ok(Response {
status: StatusCode::OK,
headers: HeaderMap::new(),
body: body.into(),
})
}

async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
Err(TransportError::Build("stream should not run".to_string()))
}
}

#[derive(Clone, Default)]
struct DummyAuth;

impl AuthProvider for DummyAuth {
fn bearer_token(&self) -> Option<String> {
None
}
}

fn provider(base_url: &str) -> Provider {
Provider {
name: "test".to_string(),
base_url: base_url.to_string(),
query_params: None,
wire: WireApi::Responses,
headers: HeaderMap::new(),
retry: RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: true,
retry_transport: true,
},
stream_idle_timeout: Duration::from_secs(1),
}
}

#[tokio::test]
async fn appends_client_version_query() {
let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
};

let client = ModelsClient::new(
transport.clone(),
provider("https://example.com/api/codex"),
DummyAuth,
);

let result = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 0);

let url = transport
.last_request
.lock()
.unwrap()
.as_ref()
.unwrap()
.url
.clone();
assert_eq!(
url,
"https://example.com/api/codex/models?client_version=0.99.0"
);
}

#[tokio::test]
async fn parses_models_response() {
let response = ModelsResponse {
models: vec![
serde_json::from_value(json!({
"slug": "gpt-test",
"display_name": "gpt-test",
"description": "desc",
"default_reasoning_level": "medium",
"supported_reasoning_levels": ["low", "medium", "high"],
"shell_type": "shell_command",
"visibility": "list",
"minimal_client_version": [0, 99, 0],
"supported_in_api": true,
"priority": 1
}))
.unwrap(),
],
};

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
};

let client = ModelsClient::new(
transport,
provider("https://example.com/api/codex"),
DummyAuth,
);

let result = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(result.models[0].supported_in_api, true);
assert_eq!(result.models[0].priority, 1);
}
}
1 change: 1 addition & 0 deletions codex-rs/codex-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub use crate::common::create_text_param_for_request;
pub use crate::endpoint::chat::AggregateStreamExt;
pub use crate::endpoint::chat::ChatClient;
pub use crate::endpoint::compact::CompactClient;
pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;
pub use crate::error::ApiError;
Expand Down
100 changes: 100 additions & 0 deletions codex-rs/codex-api/tests/models_integration.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use codex_api::AuthProvider;
use codex_api::ModelsClient;
use codex_api::provider::Provider;
use codex_api::provider::RetryConfig;
use codex_api::provider::WireApi;
use codex_client::ReqwestTransport;
use codex_protocol::openai_models::ClientVersion;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelVisibility;
use codex_protocol::openai_models::ModelsResponse;
use codex_protocol::openai_models::ReasoningLevel;
use codex_protocol::openai_models::ShellType;
use http::HeaderMap;
use http::Method;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

#[derive(Clone, Default)]
struct DummyAuth;

impl AuthProvider for DummyAuth {
fn bearer_token(&self) -> Option<String> {
None
}
}

fn provider(base_url: &str) -> Provider {
Provider {
name: "test".to_string(),
base_url: base_url.to_string(),
query_params: None,
wire: WireApi::Responses,
headers: HeaderMap::new(),
retry: RetryConfig {
max_attempts: 1,
base_delay: std::time::Duration::from_millis(1),
retry_429: false,
retry_5xx: true,
retry_transport: true,
},
stream_idle_timeout: std::time::Duration::from_secs(1),
}
}

#[tokio::test]
async fn models_client_hits_models_endpoint() {
let server = MockServer::start().await;
let base_url = format!("{}/api/codex", server.uri());

let response = ModelsResponse {
models: vec![ModelInfo {
slug: "gpt-test".to_string(),
display_name: "gpt-test".to_string(),
description: Some("desc".to_string()),
default_reasoning_level: ReasoningLevel::Medium,
supported_reasoning_levels: vec![
ReasoningLevel::Low,
ReasoningLevel::Medium,
ReasoningLevel::High,
],
shell_type: ShellType::ShellCommand,
visibility: ModelVisibility::List,
minimal_client_version: ClientVersion(0, 1, 0),
supported_in_api: true,
priority: 1,
}],
};

Mock::given(method("GET"))
.and(path("/api/codex/models"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "application/json")
.set_body_json(&response),
)
.mount(&server)
.await;

let transport = ReqwestTransport::new(reqwest::Client::new());
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);

let result = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("models request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");

let received = server
.received_requests()
.await
.expect("should capture requests");
assert_eq!(received.len(), 1);
assert_eq!(received[0].method, Method::GET.as_str());
assert_eq!(received[0].url.path(), "/api/codex/models");
}
Loading
Loading