Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions codex-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ members = [
"keyring-store",
"file-search",
"linux-sandbox",
"lmstudio",
"login",
"mcp-server",
"mcp-types",
Expand Down Expand Up @@ -69,6 +70,7 @@ codex-file-search = { path = "file-search" }
codex-git = { path = "utils/git" }
codex-keyring-store = { path = "keyring-store" }
codex-linux-sandbox = { path = "linux-sandbox" }
codex-lmstudio = { path = "lmstudio" }
codex-login = { path = "login" }
codex-mcp-server = { path = "mcp-server" }
codex-ollama = { path = "ollama" }
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ workspace = true
clap = { workspace = true, features = ["derive", "wrap_help"], optional = true }
codex-app-server-protocol = { workspace = true }
codex-core = { workspace = true }
codex-lmstudio = { workspace = true }
codex-ollama = { workspace = true }
codex-protocol = { workspace = true }
once_cell = { workspace = true }
serde = { workspace = true, optional = true }
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,5 @@ pub mod model_presets;
// Shared approval presets (AskForApproval + Sandbox) used by TUI and MCP server
// Not to be confused with AskForApproval, which we should probably rename to EscalationPolicy.
pub mod approval_presets;
// Shared OSS provider utilities used by TUI and exec
pub mod oss;
60 changes: 60 additions & 0 deletions codex-rs/common/src/oss.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//! OSS provider utilities shared between TUI and exec.

use codex_core::LMSTUDIO_OSS_PROVIDER_ID;
use codex_core::OLLAMA_OSS_PROVIDER_ID;
use codex_core::config::Config;

/// Returns the default model for a given OSS provider.
pub fn get_default_model_for_oss_provider(provider_id: &str) -> Option<&'static str> {
match provider_id {
LMSTUDIO_OSS_PROVIDER_ID => Some(codex_lmstudio::DEFAULT_OSS_MODEL),
OLLAMA_OSS_PROVIDER_ID => Some(codex_ollama::DEFAULT_OSS_MODEL),
_ => None,
}
}

/// Ensures the specified OSS provider is ready (models downloaded, service reachable).
pub async fn ensure_oss_provider_ready(
provider_id: &str,
config: &Config,
) -> Result<(), std::io::Error> {
match provider_id {
LMSTUDIO_OSS_PROVIDER_ID => {
codex_lmstudio::ensure_oss_ready(config)
.await
.map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?;
}
OLLAMA_OSS_PROVIDER_ID => {
codex_ollama::ensure_oss_ready(config)
.await
.map_err(|e| std::io::Error::other(format!("OSS setup failed: {e}")))?;
}
_ => {
// Unknown provider, skip setup
}
}
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_get_default_model_for_provider_lmstudio() {
let result = get_default_model_for_oss_provider(LMSTUDIO_OSS_PROVIDER_ID);
assert_eq!(result, Some(codex_lmstudio::DEFAULT_OSS_MODEL));
}

#[test]
fn test_get_default_model_for_provider_ollama() {
let result = get_default_model_for_oss_provider(OLLAMA_OSS_PROVIDER_ID);
assert_eq!(result, Some(codex_ollama::DEFAULT_OSS_MODEL));
}

#[test]
fn test_get_default_model_for_provider_unknown() {
let result = get_default_model_for_oss_provider("unknown-provider");
assert_eq!(result, None);
}
}
188 changes: 188 additions & 0 deletions codex-rs/core/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use crate::git_info::resolve_root_git_project_for_trust;
use crate::model_family::ModelFamily;
use crate::model_family::derive_default_model_family;
use crate::model_family::find_family_for_model;
use crate::model_provider_info::LMSTUDIO_OSS_PROVIDER_ID;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::OLLAMA_OSS_PROVIDER_ID;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME;
Expand Down Expand Up @@ -466,6 +468,48 @@ pub fn set_project_trust_level(
.apply_blocking()
}

/// Save the default OSS provider preference to config.toml
pub fn set_default_oss_provider(codex_home: &Path, provider: &str) -> std::io::Result<()> {
// Validate that the provider is one of the known OSS providers
match provider {
LMSTUDIO_OSS_PROVIDER_ID | OLLAMA_OSS_PROVIDER_ID => {
// Valid provider, continue
}
_ => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!(
"Invalid OSS provider '{provider}'. Must be one of: {LMSTUDIO_OSS_PROVIDER_ID}, {OLLAMA_OSS_PROVIDER_ID}"
),
));
}
}
let config_path = codex_home.join(CONFIG_TOML_FILE);

// Read existing config or create empty string if file doesn't exist
let content = match std::fs::read_to_string(&config_path) {
Ok(content) => content,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => String::new(),
Err(e) => return Err(e),
};

// Parse as DocumentMut for editing while preserving structure
let mut doc = content.parse::<DocumentMut>().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("failed to parse config.toml: {e}"),
)
})?;

// Set the default_oss_provider at root level
use toml_edit::value;
doc["oss_provider"] = value(provider);

// Write the modified document back
std::fs::write(&config_path, doc.to_string())?;
Ok(())
}

/// Apply a single dotted-path override onto a TOML value.
fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
use toml::value::Table;
Expand Down Expand Up @@ -663,6 +707,8 @@ pub struct ConfigToml {
pub experimental_use_rmcp_client: Option<bool>,
pub experimental_use_freeform_apply_patch: Option<bool>,
pub experimental_sandbox_command_assessment: Option<bool>,
/// Preferred OSS provider for local models, e.g. "lmstudio" or "ollama".
pub oss_provider: Option<String>,
}

impl From<ConfigToml> for UserSavedConfig {
Expand Down Expand Up @@ -851,6 +897,34 @@ pub struct ConfigOverrides {
pub additional_writable_roots: Vec<PathBuf>,
}

/// Resolves the OSS provider from CLI override, profile config, or global config.
/// Returns `None` if no provider is configured at any level.
pub fn resolve_oss_provider(
explicit_provider: Option<&str>,
config_toml: &ConfigToml,
config_profile: Option<String>,
) -> Option<String> {
if let Some(provider) = explicit_provider {
// Explicit provider specified (e.g., via --local-provider)
Some(provider.to_string())
} else {
// Check profile config first, then global config
let profile = config_toml.get_config_profile(config_profile).ok();
if let Some(profile) = &profile {
// Check if profile has an oss provider
if let Some(profile_oss_provider) = &profile.oss_provider {
Some(profile_oss_provider.clone())
}
// If not then check if the toml has an oss provider
else {
config_toml.oss_provider.clone()
}
} else {
config_toml.oss_provider.clone()
}
}
}

impl Config {
/// Meant to be used exclusively for tests: `load_with_overrides()` should
/// be used in all other cases.
Expand Down Expand Up @@ -3265,6 +3339,41 @@ trust_level = "trusted"
Ok(())
}

#[test]
fn test_set_default_oss_provider() -> std::io::Result<()> {
let temp_dir = TempDir::new()?;
let codex_home = temp_dir.path();
let config_path = codex_home.join(CONFIG_TOML_FILE);

// Test setting valid provider on empty config
set_default_oss_provider(codex_home, OLLAMA_OSS_PROVIDER_ID)?;
let content = std::fs::read_to_string(&config_path)?;
assert!(content.contains("oss_provider = \"ollama\""));

// Test updating existing config
std::fs::write(&config_path, "model = \"gpt-4\"\n")?;
set_default_oss_provider(codex_home, LMSTUDIO_OSS_PROVIDER_ID)?;
let content = std::fs::read_to_string(&config_path)?;
assert!(content.contains("oss_provider = \"lmstudio\""));
assert!(content.contains("model = \"gpt-4\""));

// Test overwriting existing oss_provider
set_default_oss_provider(codex_home, OLLAMA_OSS_PROVIDER_ID)?;
let content = std::fs::read_to_string(&config_path)?;
assert!(content.contains("oss_provider = \"ollama\""));
assert!(!content.contains("oss_provider = \"lmstudio\""));

// Test invalid provider
let result = set_default_oss_provider(codex_home, "invalid_provider");
assert!(result.is_err());
let error = result.unwrap_err();
assert_eq!(error.kind(), std::io::ErrorKind::InvalidInput);
assert!(error.to_string().contains("Invalid OSS provider"));
assert!(error.to_string().contains("invalid_provider"));

Ok(())
}

#[test]
fn test_untrusted_project_gets_workspace_write_sandbox() -> anyhow::Result<()> {
let config_with_untrusted = r#"
Expand Down Expand Up @@ -3295,6 +3404,85 @@ trust_level = "untrusted"
Ok(())
}

#[test]
fn test_resolve_oss_provider_explicit_override() {
let config_toml = ConfigToml::default();
let result = resolve_oss_provider(Some("custom-provider"), &config_toml, None);
assert_eq!(result, Some("custom-provider".to_string()));
}

#[test]
fn test_resolve_oss_provider_from_profile() {
let mut profiles = std::collections::HashMap::new();
let profile = ConfigProfile {
oss_provider: Some("profile-provider".to_string()),
..Default::default()
};
profiles.insert("test-profile".to_string(), profile);
let config_toml = ConfigToml {
profiles,
..Default::default()
};

let result = resolve_oss_provider(None, &config_toml, Some("test-profile".to_string()));
assert_eq!(result, Some("profile-provider".to_string()));
}

#[test]
fn test_resolve_oss_provider_from_global_config() {
let config_toml = ConfigToml {
oss_provider: Some("global-provider".to_string()),
..Default::default()
};

let result = resolve_oss_provider(None, &config_toml, None);
assert_eq!(result, Some("global-provider".to_string()));
}

#[test]
fn test_resolve_oss_provider_profile_fallback_to_global() {
let mut profiles = std::collections::HashMap::new();
let profile = ConfigProfile::default(); // No oss_provider set
profiles.insert("test-profile".to_string(), profile);
let config_toml = ConfigToml {
oss_provider: Some("global-provider".to_string()),
profiles,
..Default::default()
};

let result = resolve_oss_provider(None, &config_toml, Some("test-profile".to_string()));
assert_eq!(result, Some("global-provider".to_string()));
}

#[test]
fn test_resolve_oss_provider_none_when_not_configured() {
let config_toml = ConfigToml::default();
let result = resolve_oss_provider(None, &config_toml, None);
assert_eq!(result, None);
}

#[test]
fn test_resolve_oss_provider_explicit_overrides_all() {
let mut profiles = std::collections::HashMap::new();
let profile = ConfigProfile {
oss_provider: Some("profile-provider".to_string()),
..Default::default()
};
profiles.insert("test-profile".to_string(), profile);
let config_toml = ConfigToml {
oss_provider: Some("global-provider".to_string()),
profiles,
..Default::default()
};

let result = resolve_oss_provider(
Some("explicit-provider"),
&config_toml,
Some("test-profile".to_string()),
);
assert_eq!(result, Some("explicit-provider".to_string()));
}

#[test]
fn test_untrusted_project_gets_unless_trusted_approval_policy() -> std::io::Result<()> {
let codex_home = TempDir::new()?;
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/config/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct ConfigProfile {
/// Optional feature toggles scoped to this profile.
#[serde(default)]
pub features: Option<crate::features::FeaturesToml>,
pub oss_provider: Option<String>,
}

impl From<ConfigProfile> for codex_app_server_protocol::Profile {
Expand Down
5 changes: 4 additions & 1 deletion codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ pub mod token_data;
mod truncate;
mod unified_exec;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::DEFAULT_LMSTUDIO_PORT;
pub use model_provider_info::DEFAULT_OLLAMA_PORT;
pub use model_provider_info::LMSTUDIO_OSS_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::OLLAMA_OSS_PROVIDER_ID;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
Expand Down
Loading
Loading