diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index 981b52a598c6..bf6b4bdf9334 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -235,6 +235,7 @@ use codex_core::clear_memory_roots_contents; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::NetworkProxyAuditMetadata; +use codex_core::config::ThreadStoreConfig; use codex_core::config::edit::ConfigEdit; use codex_core::config::edit::ConfigEditsBuilder; use codex_core::config_loader::CloudRequirementsLoadError; @@ -353,6 +354,8 @@ use codex_state::ThreadMetadata; use codex_state::ThreadMetadataBuilder; use codex_state::log_db::LogDbLayer; use codex_thread_store::ArchiveThreadParams as StoreArchiveThreadParams; +#[cfg(debug_assertions)] +use codex_thread_store::InMemoryThreadStore; use codex_thread_store::ListThreadsParams as StoreListThreadsParams; use codex_thread_store::LocalThreadStore; use codex_thread_store::ReadThreadByRolloutPathParams as StoreReadThreadByRolloutPathParams; @@ -661,9 +664,11 @@ pub(crate) struct CodexMessageProcessorArgs { } fn configured_thread_store(config: &Config) -> Arc { - match config.experimental_thread_store_endpoint.as_deref() { - Some(endpoint) => Arc::new(RemoteThreadStore::new(endpoint)), - None => Arc::new(configured_local_thread_store(config)), + match &config.experimental_thread_store { + ThreadStoreConfig::Local => Arc::new(configured_local_thread_store(config)), + ThreadStoreConfig::Remote { endpoint } => Arc::new(RemoteThreadStore::new(endpoint)), + #[cfg(debug_assertions)] + ThreadStoreConfig::InMemory { id } => InMemoryThreadStore::for_id(id), } } diff --git a/codex-rs/app-server/tests/suite/v2/mod.rs b/codex-rs/app-server/tests/suite/v2/mod.rs index 4a3f23183690..776424cc99f9 100644 --- a/codex-rs/app-server/tests/suite/v2/mod.rs +++ b/codex-rs/app-server/tests/suite/v2/mod.rs @@ -34,6 +34,8 @@ mod plugin_read; mod plugin_uninstall; mod rate_limits; mod realtime_conversation; +#[cfg(debug_assertions)] +mod remote_thread_store; mod request_permissions; mod request_user_input; mod review; diff --git a/codex-rs/app-server/tests/suite/v2/remote_thread_store.rs b/codex-rs/app-server/tests/suite/v2/remote_thread_store.rs new file mode 100644 index 000000000000..ebee1fd7c157 --- /dev/null +++ b/codex-rs/app-server/tests/suite/v2/remote_thread_store.rs @@ -0,0 +1,254 @@ +//! Regression coverage for app-server thread operations backed by a non-local +//! `ThreadStore`. +//! +//! The app-server startup path should honor `experimental_thread_store` +//! by routing all thread persistence through the configured store. This suite uses +//! the thread-store crate's test-only in-memory store, which exercises the same +//! config-driven selection path as a remote store without requiring the real gRPC +//! service. +//! +//! The important failure mode is accidentally materializing local persistence +//! while a non-local store is configured. After `thread/start` and a simple turn, +//! the temporary `codex_home` must not contain rollout session files or sqlite +//! state files. This does not observe read-only probes that leave no artifact; it +//! is a stop-gap that prevents additional local persistence writes from slipping +//! in unnoticed. + +use std::collections::BTreeSet; +use std::path::Path; +use std::sync::Arc; + +use anyhow::Result; +use app_test_support::create_mock_responses_server_repeating_assistant; +use codex_app_server::in_process; +use codex_app_server::in_process::InProcessServerEvent; +use codex_app_server::in_process::InProcessStartArgs; +use codex_app_server_protocol::ClientInfo; +use codex_app_server_protocol::ClientRequest; +use codex_app_server_protocol::InitializeParams; +use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ServerNotification; +use codex_app_server_protocol::ThreadStartParams; +use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::TurnStartParams; +use codex_app_server_protocol::UserInput as V2UserInput; +use codex_arg0::Arg0DispatchPaths; +use codex_config::NoopThreadConfigLoader; +use codex_core::config::ConfigBuilder; +use codex_core::config_loader::CloudRequirementsLoader; +use codex_core::config_loader::LoaderOverrides; +use codex_exec_server::EnvironmentManager; +use codex_feedback::CodexFeedback; +use codex_protocol::protocol::SessionSource; +use codex_thread_store::InMemoryThreadStore; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::time::timeout; +use uuid::Uuid; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test] +async fn thread_start_with_non_local_thread_store_does_not_create_local_persistence() -> Result<()> +{ + let server = create_mock_responses_server_repeating_assistant("Done").await; + let codex_home = TempDir::new()?; + let store_id = Uuid::new_v4().to_string(); + create_config_toml_with_thread_store(codex_home.path(), &server.uri(), &store_id)?; + + let loader_overrides = LoaderOverrides::without_managed_config_for_tests(); + let config = ConfigBuilder::default() + .codex_home(codex_home.path().to_path_buf()) + .fallback_cwd(Some(codex_home.path().to_path_buf())) + .loader_overrides(loader_overrides.clone()) + .build() + .await?; + + let thread_store = InMemoryThreadStore::for_id(store_id.clone()); + let _in_memory_store = InMemoryThreadStoreId { store_id }; + + let mut client = in_process::start(InProcessStartArgs { + arg0_paths: Arg0DispatchPaths::default(), + config: Arc::new(config), + cli_overrides: Vec::new(), + loader_overrides, + cloud_requirements: CloudRequirementsLoader::default(), + thread_config_loader: Arc::new(NoopThreadConfigLoader), + feedback: CodexFeedback::new(), + log_db: None, + environment_manager: Arc::new(EnvironmentManager::default_for_tests()), + config_warnings: Vec::new(), + session_source: SessionSource::Cli, + enable_codex_api_key_env: false, + initialize: InitializeParams { + client_info: ClientInfo { + name: "codex-app-server-tests".to_string(), + title: None, + version: "0.1.0".to_string(), + }, + capabilities: None, + }, + channel_capacity: in_process::DEFAULT_IN_PROCESS_CHANNEL_CAPACITY, + }) + .await?; + + let response = client + .request(ClientRequest::ThreadStart { + request_id: RequestId::Integer(1), + params: ThreadStartParams::default(), + }) + .await? + .expect("thread/start should succeed"); + let ThreadStartResponse { thread, .. } = + serde_json::from_value(response).expect("thread/start response should parse"); + assert_eq!(thread.path, None); + + client + .request(ClientRequest::TurnStart { + request_id: RequestId::Integer(2), + params: TurnStartParams { + thread_id: thread.id.clone(), + input: vec![V2UserInput::Text { + text: "Hello".to_string(), + text_elements: Vec::new(), + }], + ..Default::default() + }, + }) + .await? + .expect("turn/start should succeed"); + + timeout(DEFAULT_READ_TIMEOUT, async { + loop { + let Some(event) = client.next_event().await else { + anyhow::bail!("in-process app-server stopped before turn/completed"); + }; + if let InProcessServerEvent::ServerNotification(ServerNotification::TurnCompleted( + completed, + )) = event + && completed.thread_id == thread.id + { + return Ok::<(), anyhow::Error>(()); + } + } + }) + .await??; + + client.shutdown().await?; + + let calls = thread_store.calls().await; + assert_eq!(calls.create_thread, 1); + assert!( + calls.append_items > 0, + "turn/start should append rollout items through the injected store" + ); + assert!( + calls.flush_thread > 0, + "turn completion should flush through the injected store" + ); + + assert_no_local_persistence_artifacts(codex_home.path())?; + + Ok(()) +} + +fn assert_no_local_persistence_artifacts(codex_home: &Path) -> Result<()> { + // These are the observable tripwires for accidental local persistence. If a + // future code path constructs a local rollout/session store or opens the + // local thread sqlite database, it should leave one of these artifacts in + // the isolated test codex_home. + assert!( + !codex_home.join("sessions").exists(), + "non-local thread persistence should not create local rollout sessions" + ); + assert!( + !codex_home.join("archived_sessions").exists(), + "non-local thread persistence should not create archived rollout sessions" + ); + assert!( + !codex_state::state_db_path(codex_home).exists(), + "non-local thread persistence should not create local thread sqlite" + ); + + let sqlite_artifacts = std::fs::read_dir(codex_home)? + .filter_map(std::result::Result::ok) + .map(|entry| entry.path()) + .filter(|path| { + path.file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| { + name.ends_with(".sqlite") + || name.ends_with(".sqlite-shm") + || name.ends_with(".sqlite-wal") + }) + }) + .collect::>(); + + assert!( + sqlite_artifacts.is_empty(), + "non-local thread persistence should not create sqlite artifacts: {sqlite_artifacts:?}" + ); + let mut entries = codex_home_entries(codex_home)?; + // Bazel test runs may initialize shell snapshot storage under codex_home. + // That is not thread persistence; keep the assertion focused on rollout, + // session, sqlite, and other unexpected thread-store artifacts. + entries.remove("shell_snapshots"); + assert_eq!( + entries, + BTreeSet::from([ + "config.toml".to_string(), + "installation_id".to_string(), + "memories".to_string(), + "skills".to_string(), + ]), + "non-local thread persistence should not create unexpected files in codex_home" + ); + + Ok(()) +} + +fn codex_home_entries(codex_home: &Path) -> Result> { + Ok(std::fs::read_dir(codex_home)? + .filter_map(|entry| { + let entry = entry.ok()?; + Some(entry.file_name().to_string_lossy().into_owned()) + }) + .collect()) +} + +struct InMemoryThreadStoreId { + store_id: String, +} + +impl Drop for InMemoryThreadStoreId { + fn drop(&mut self) { + InMemoryThreadStore::remove_id(&self.store_id); + } +} + +fn create_config_toml_with_thread_store( + codex_home: &Path, + server_uri: &str, + store_id: &str, +) -> std::io::Result<()> { + std::fs::write( + codex_home.join("config.toml"), + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "read-only" +experimental_thread_store = {{ type = "in_memory", id = "{store_id}" }} + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "responses" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +} diff --git a/codex-rs/config/src/config_toml.rs b/codex-rs/config/src/config_toml.rs index 9ee784117358..f0de00192f55 100644 --- a/codex-rs/config/src/config_toml.rs +++ b/codex-rs/config/src/config_toml.rs @@ -317,6 +317,9 @@ pub struct ConfigToml { /// Experimental / do not use. When set, app-server fetches thread-scoped /// config from a remote service at this endpoint. pub experimental_thread_config_endpoint: Option, + + /// Experimental / do not use. Selects the thread store implementation. + pub experimental_thread_store: Option, pub projects: Option>, /// Controls the web search tool mode: disabled, cached, or live. @@ -413,6 +416,20 @@ pub struct ConfigToml { pub oss_provider: Option, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ThreadStoreToml { + Local {}, + Remote { + endpoint: String, + }, + #[cfg(debug_assertions)] + #[schemars(skip)] + InMemory { + id: String, + }, +} + #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, JsonSchema)] pub struct AutoReviewToml { /// Additional policy instructions inserted into the guardian prompt. diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 6cac42908372..030c36a8b682 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -2093,6 +2093,42 @@ }, "type": "object" }, + "ThreadStoreToml": { + "oneOf": [ + { + "properties": { + "type": { + "enum": [ + "local" + ], + "type": "string" + } + }, + "required": [ + "type" + ], + "type": "object" + }, + { + "properties": { + "endpoint": { + "type": "string" + }, + "type": { + "enum": [ + "remote" + ], + "type": "string" + } + }, + "required": [ + "endpoint", + "type" + ], + "type": "object" + } + ] + }, "ToolSuggestConfig": { "additionalProperties": false, "properties": { @@ -2489,6 +2525,14 @@ "description": "Experimental / do not use. When set, app-server fetches thread-scoped config from a remote service at this endpoint.", "type": "string" }, + "experimental_thread_store": { + "allOf": [ + { + "$ref": "#/definitions/ThreadStoreToml" + } + ], + "description": "Experimental / do not use. Selects the thread store implementation." + }, "experimental_thread_store_endpoint": { "description": "Experimental / do not use. When set, app-server uses a remote thread store at this endpoint instead of the local filesystem/SQLite store.", "type": "string" diff --git a/codex-rs/core/src/config/config_tests.rs b/codex-rs/core/src/config/config_tests.rs index 7af47fe5e4fe..2686173208e5 100644 --- a/codex-rs/core/src/config/config_tests.rs +++ b/codex-rs/core/src/config/config_tests.rs @@ -1,5 +1,6 @@ use crate::agents_md::DEFAULT_AGENTS_MD_FILENAME; use crate::agents_md::LOCAL_AGENTS_MD_FILENAME; +use crate::config::ThreadStoreConfig; use crate::config::edit::ConfigEdit; use crate::config::edit::ConfigEditsBuilder; use crate::config::edit::apply_blocking; @@ -5294,8 +5295,8 @@ async fn test_precedence_fixture_with_o3_profile() -> std::io::Result<()> { realtime: RealtimeConfig::default(), experimental_realtime_ws_backend_prompt: None, experimental_realtime_ws_startup_context: None, - experimental_thread_store_endpoint: None, experimental_thread_config_endpoint: None, + experimental_thread_store: ThreadStoreConfig::Local, base_instructions: None, developer_instructions: None, guardian_policy_config: None, @@ -5492,8 +5493,8 @@ async fn test_precedence_fixture_with_gpt3_profile() -> std::io::Result<()> { realtime: RealtimeConfig::default(), experimental_realtime_ws_backend_prompt: None, experimental_realtime_ws_startup_context: None, - experimental_thread_store_endpoint: None, experimental_thread_config_endpoint: None, + experimental_thread_store: ThreadStoreConfig::Local, base_instructions: None, developer_instructions: None, guardian_policy_config: None, @@ -5644,8 +5645,8 @@ async fn test_precedence_fixture_with_zdr_profile() -> std::io::Result<()> { realtime: RealtimeConfig::default(), experimental_realtime_ws_backend_prompt: None, experimental_realtime_ws_startup_context: None, - experimental_thread_store_endpoint: None, experimental_thread_config_endpoint: None, + experimental_thread_store: ThreadStoreConfig::Local, base_instructions: None, developer_instructions: None, guardian_policy_config: None, @@ -5781,8 +5782,8 @@ async fn test_precedence_fixture_with_gpt5_profile() -> std::io::Result<()> { realtime: RealtimeConfig::default(), experimental_realtime_ws_backend_prompt: None, experimental_realtime_ws_startup_context: None, - experimental_thread_store_endpoint: None, experimental_thread_config_endpoint: None, + experimental_thread_store: ThreadStoreConfig::Local, base_instructions: None, developer_instructions: None, guardian_policy_config: None, diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs index 9bdbeb9d1c7c..33fe18d1f487 100644 --- a/codex-rs/core/src/config/mod.rs +++ b/codex-rs/core/src/config/mod.rs @@ -27,6 +27,7 @@ use codex_config::config_toml::ConfigToml; use codex_config::config_toml::ProjectConfig; use codex_config::config_toml::RealtimeAudioConfig; use codex_config::config_toml::RealtimeConfig; +use codex_config::config_toml::ThreadStoreToml; use codex_config::config_toml::validate_model_providers; use codex_config::profile_toml::ConfigProfile; use codex_config::types::ApprovalsReviewer; @@ -230,6 +231,19 @@ impl Permissions { } } +/// Configured thread persistence backend. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum ThreadStoreConfig { + /// Persist threads locally using rollout JSONL files and sqlite metadata. + #[default] + Local, + /// Persist threads through the remote thread-store service. + Remote { endpoint: String }, + /// Test-only in-memory thread store. + #[cfg(debug_assertions)] + InMemory { id: String }, +} + /// Application configuration loaded from disk and merged with overrides. #[derive(Debug, Clone, PartialEq)] pub struct Config { @@ -545,13 +559,12 @@ pub struct Config { /// active. pub experimental_realtime_start_instructions: Option, - /// Experimental / do not use. When set, app-server uses a remote thread - /// store at this endpoint instead of the local filesystem/SQLite store. - pub experimental_thread_store_endpoint: Option, - /// Experimental / do not use. When set, app-server fetches thread-scoped /// config from a remote service at this endpoint. pub experimental_thread_config_endpoint: Option, + + /// Experimental / do not use. Selects the thread persistence backend. + pub experimental_thread_store: ThreadStoreConfig, /// When set, restricts ChatGPT login to a specific workspace identifier. pub forced_chatgpt_workspace_id: Option, @@ -1297,6 +1310,21 @@ fn resolve_tool_suggest_config(config_toml: &ConfigToml) -> ToolSuggestConfig { ToolSuggestConfig { discoverables } } +fn thread_store_config( + thread_store: Option, + legacy_remote_endpoint: Option, +) -> ThreadStoreConfig { + match thread_store { + Some(ThreadStoreToml::Local {}) => ThreadStoreConfig::Local, + Some(ThreadStoreToml::Remote { endpoint }) => ThreadStoreConfig::Remote { endpoint }, + #[cfg(debug_assertions)] + Some(ThreadStoreToml::InMemory { id }) => ThreadStoreConfig::InMemory { id }, + None => legacy_remote_endpoint.map_or(ThreadStoreConfig::Local, |endpoint| { + ThreadStoreConfig::Remote { endpoint } + }), + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum PermissionConfigSyntax { Legacy, @@ -2434,8 +2462,11 @@ impl Config { experimental_realtime_ws_backend_prompt: cfg.experimental_realtime_ws_backend_prompt, experimental_realtime_ws_startup_context: cfg.experimental_realtime_ws_startup_context, experimental_realtime_start_instructions: cfg.experimental_realtime_start_instructions, - experimental_thread_store_endpoint: cfg.experimental_thread_store_endpoint, experimental_thread_config_endpoint: cfg.experimental_thread_config_endpoint, + experimental_thread_store: thread_store_config( + cfg.experimental_thread_store, + cfg.experimental_thread_store_endpoint, + ), forced_chatgpt_workspace_id, forced_login_method, include_apply_patch_tool: include_apply_patch_tool_flag, diff --git a/codex-rs/core/src/thread_manager.rs b/codex-rs/core/src/thread_manager.rs index 7ba595598eb5..30d220694dfd 100644 --- a/codex-rs/core/src/thread_manager.rs +++ b/codex-rs/core/src/thread_manager.rs @@ -2,6 +2,7 @@ use crate::SkillsManager; use crate::agent::AgentControl; use crate::codex_thread::CodexThread; use crate::config::Config; +use crate::config::ThreadStoreConfig; use crate::environment_selection::default_thread_environment_selections; use crate::environment_selection::selected_primary_environment; use crate::environment_selection::validate_environment_selections; @@ -52,6 +53,8 @@ use codex_protocol::protocol::TurnEnvironmentSelection; use codex_protocol::protocol::W3cTraceContext; use codex_rollout::RolloutConfig; use codex_state::DirectionalThreadSpawnEdgeStatus; +#[cfg(debug_assertions)] +use codex_thread_store::InMemoryThreadStore; use codex_thread_store::LocalThreadStore; use codex_thread_store::RemoteThreadStore; use codex_thread_store::ThreadStore; @@ -251,10 +254,14 @@ pub fn build_models_manager( } fn configured_thread_store(config: &Config) -> Arc { - if let Some(endpoint) = config.experimental_thread_store_endpoint.as_deref() { - return Arc::new(RemoteThreadStore::new(endpoint)); + match &config.experimental_thread_store { + ThreadStoreConfig::Local => { + Arc::new(LocalThreadStore::new(RolloutConfig::from_view(config))) + } + ThreadStoreConfig::Remote { endpoint } => Arc::new(RemoteThreadStore::new(endpoint)), + #[cfg(debug_assertions)] + ThreadStoreConfig::InMemory { id } => InMemoryThreadStore::for_id(id), } - Arc::new(LocalThreadStore::new(RolloutConfig::from_view(config))) } impl ThreadManager { diff --git a/codex-rs/thread-store/src/in_memory.rs b/codex-rs/thread-store/src/in_memory.rs new file mode 100644 index 000000000000..084975abd27d --- /dev/null +++ b/codex-rs/thread-store/src/in_memory.rs @@ -0,0 +1,285 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::MutexGuard; +use std::sync::OnceLock; + +use async_trait::async_trait; +use chrono::Utc; +use codex_protocol::ThreadId; +use codex_protocol::protocol::AskForApproval; +use codex_protocol::protocol::RolloutItem; +use codex_protocol::protocol::SandboxPolicy; + +use crate::AppendThreadItemsParams; +use crate::ArchiveThreadParams; +use crate::CreateThreadParams; +use crate::ListThreadsParams; +use crate::LoadThreadHistoryParams; +use crate::ReadThreadByRolloutPathParams; +use crate::ReadThreadParams; +use crate::ResumeThreadParams; +use crate::StoredThread; +use crate::StoredThreadHistory; +use crate::ThreadPage; +use crate::ThreadStore; +use crate::ThreadStoreError; +use crate::ThreadStoreResult; +use crate::UpdateThreadMetadataParams; + +static IN_MEMORY_THREAD_STORES: OnceLock>>> = + OnceLock::new(); + +fn stores() -> &'static Mutex>> { + IN_MEMORY_THREAD_STORES.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn stores_guard() -> MutexGuard<'static, HashMap>> { + match stores().lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + } +} + +/// Recorded call counts for [`InMemoryThreadStore`]. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct InMemoryThreadStoreCalls { + pub create_thread: usize, + pub resume_thread: usize, + pub append_items: usize, + pub persist_thread: usize, + pub flush_thread: usize, + pub shutdown_thread: usize, + pub discard_thread: usize, + pub load_history: usize, + pub read_thread: usize, + pub read_thread_by_rollout_path: usize, + pub list_threads: usize, + pub update_thread_metadata: usize, + pub archive_thread: usize, + pub unarchive_thread: usize, +} + +/// Test-only in-memory [`ThreadStore`] implementation. +/// +/// Debug/test configs can select this store by id, letting tests exercise +/// config-driven non-local persistence without requiring the real remote gRPC +/// service. +#[derive(Default)] +pub struct InMemoryThreadStore { + state: tokio::sync::Mutex, +} + +#[derive(Default)] +struct InMemoryThreadStoreState { + calls: InMemoryThreadStoreCalls, + created_threads: HashMap, + histories: HashMap>, + names: HashMap>, + rollout_paths: HashMap, +} + +impl InMemoryThreadStore { + /// Returns the store associated with `id`, creating it if needed. + pub fn for_id(id: impl Into) -> Arc { + let id = id.into(); + let mut stores = stores_guard(); + stores + .entry(id) + .or_insert_with(|| Arc::new(Self::default())) + .clone() + } + + /// Removes a shared in-memory store for `id`. + pub fn remove_id(id: &str) -> Option> { + stores_guard().remove(id) + } + + /// Returns the calls observed by this store. + pub async fn calls(&self) -> InMemoryThreadStoreCalls { + self.state.lock().await.calls.clone() + } +} + +#[async_trait] +impl ThreadStore for InMemoryThreadStore { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn create_thread(&self, params: CreateThreadParams) -> ThreadStoreResult<()> { + let mut state = self.state.lock().await; + state.calls.create_thread += 1; + state.histories.entry(params.thread_id).or_default(); + state.created_threads.insert(params.thread_id, params); + Ok(()) + } + + async fn resume_thread(&self, params: ResumeThreadParams) -> ThreadStoreResult<()> { + let mut state = self.state.lock().await; + state.calls.resume_thread += 1; + state.histories.entry(params.thread_id).or_default(); + if let Some(rollout_path) = params.rollout_path { + state.rollout_paths.insert(rollout_path, params.thread_id); + } + Ok(()) + } + + async fn append_items(&self, params: AppendThreadItemsParams) -> ThreadStoreResult<()> { + let mut state = self.state.lock().await; + state.calls.append_items += 1; + state + .histories + .entry(params.thread_id) + .or_default() + .extend(params.items); + Ok(()) + } + + async fn persist_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> { + self.state.lock().await.calls.persist_thread += 1; + Ok(()) + } + + async fn flush_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> { + self.state.lock().await.calls.flush_thread += 1; + Ok(()) + } + + async fn shutdown_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> { + self.state.lock().await.calls.shutdown_thread += 1; + Ok(()) + } + + async fn discard_thread(&self, _thread_id: ThreadId) -> ThreadStoreResult<()> { + self.state.lock().await.calls.discard_thread += 1; + Ok(()) + } + + async fn load_history( + &self, + params: LoadThreadHistoryParams, + ) -> ThreadStoreResult { + let mut state = self.state.lock().await; + state.calls.load_history += 1; + let items = state.histories.get(¶ms.thread_id).cloned().ok_or( + ThreadStoreError::ThreadNotFound { + thread_id: params.thread_id, + }, + )?; + Ok(StoredThreadHistory { + thread_id: params.thread_id, + items, + }) + } + + async fn read_thread(&self, params: ReadThreadParams) -> ThreadStoreResult { + let mut state = self.state.lock().await; + state.calls.read_thread += 1; + stored_thread_from_state(&state, params.thread_id, params.include_history) + } + + async fn read_thread_by_rollout_path( + &self, + params: ReadThreadByRolloutPathParams, + ) -> ThreadStoreResult { + let mut state = self.state.lock().await; + state.calls.read_thread_by_rollout_path += 1; + let Some(thread_id) = state.rollout_paths.get(¶ms.rollout_path).copied() else { + return Err(ThreadStoreError::InvalidRequest { + message: format!( + "in-memory thread store does not know rollout path {}", + params.rollout_path.display() + ), + }); + }; + stored_thread_from_state(&state, thread_id, params.include_history) + } + + async fn list_threads(&self, _params: ListThreadsParams) -> ThreadStoreResult { + let mut state = self.state.lock().await; + state.calls.list_threads += 1; + let mut items = state + .created_threads + .keys() + .map(|thread_id| { + stored_thread_from_state(&state, *thread_id, /*include_history*/ false) + }) + .collect::>>()?; + items.sort_by_key(|item| item.thread_id.to_string()); + Ok(ThreadPage { + items, + next_cursor: None, + }) + } + + async fn update_thread_metadata( + &self, + params: UpdateThreadMetadataParams, + ) -> ThreadStoreResult { + let mut state = self.state.lock().await; + state.calls.update_thread_metadata += 1; + if let Some(name) = params.patch.name { + state.names.insert(params.thread_id, Some(name)); + } + stored_thread_from_state(&state, params.thread_id, /*include_history*/ false) + } + + async fn archive_thread(&self, _params: ArchiveThreadParams) -> ThreadStoreResult<()> { + self.state.lock().await.calls.archive_thread += 1; + Ok(()) + } + + async fn unarchive_thread( + &self, + params: ArchiveThreadParams, + ) -> ThreadStoreResult { + let mut state = self.state.lock().await; + state.calls.unarchive_thread += 1; + stored_thread_from_state(&state, params.thread_id, /*include_history*/ false) + } +} + +fn stored_thread_from_state( + state: &InMemoryThreadStoreState, + thread_id: ThreadId, + include_history: bool, +) -> ThreadStoreResult { + let created = state + .created_threads + .get(&thread_id) + .ok_or(ThreadStoreError::ThreadNotFound { thread_id })?; + let history_items = state.histories.get(&thread_id).cloned().unwrap_or_default(); + let history = include_history.then(|| StoredThreadHistory { + thread_id, + items: history_items.clone(), + }); + let name = state.names.get(&thread_id).cloned().flatten(); + + Ok(StoredThread { + thread_id, + rollout_path: None, + forked_from_id: created.forked_from_id, + preview: String::new(), + name, + model_provider: "test".to_string(), + model: None, + reasoning_effort: None, + created_at: Utc::now(), + updated_at: Utc::now(), + archived_at: None, + cwd: PathBuf::new(), + cli_version: "test".to_string(), + source: created.source.clone(), + agent_nickname: None, + agent_role: None, + agent_path: None, + git_info: None, + approval_mode: AskForApproval::Never, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + token_usage: None, + first_user_message: None, + history, + }) +} diff --git a/codex-rs/thread-store/src/lib.rs b/codex-rs/thread-store/src/lib.rs index c8a083e1ca02..42b9297bcae1 100644 --- a/codex-rs/thread-store/src/lib.rs +++ b/codex-rs/thread-store/src/lib.rs @@ -5,6 +5,8 @@ //! any other backing store. mod error; +#[cfg(debug_assertions)] +mod in_memory; mod live_thread; mod local; mod remote; @@ -13,6 +15,10 @@ mod types; pub use error::ThreadStoreError; pub use error::ThreadStoreResult; +#[cfg(debug_assertions)] +pub use in_memory::InMemoryThreadStore; +#[cfg(debug_assertions)] +pub use in_memory::InMemoryThreadStoreCalls; pub use live_thread::LiveThread; pub use live_thread::LiveThreadInitGuard; pub use local::LocalThreadStore;