From 99d4f6677002c5d9f9c31b76e6bc657c95ed6812 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Thu, 23 Apr 2026 14:07:13 -0700 Subject: [PATCH] app-server: persist device key bindings in sqlite ## Why Device-key providers should only own platform key material. The account/client binding used to authorize a signing payload is app-server state, and keeping that state in provider-specific metadata makes the same check harder to audit and harder to share across platform implementations. Persisting the binding in the shared state database gives the device-key crate a platform-neutral source of truth before it asks a provider to sign. It also lets app-server move potentially blocking key operations off the main message processor path, which matters once providers may wait for OS authentication prompts. Because key creation now spans provider-owned key material and a sqlite binding row, create must also clean up after partial failure. If sqlite persistence fails after provider creation succeeds, the new key would otherwise be left behind without a usable binding. ## What changed - Add a device_key_bindings state migration plus StateRuntime helpers keyed by key_id. - Add an async DeviceKeyBindingStore abstraction to codex-device-key and use it from DeviceKeyStore::create and DeviceKeyStore::sign. - Delete newly created provider key material when the binding write fails, preserving the original binding-store error when cleanup succeeds. - Keep provider calls behind async store methods and run the synchronous provider work through spawn_blocking. - Wire app-server device-key RPC handling to the SQLite-backed binding store. - Route device-key RPC handling through a shared spawned helper so transport rejection, async store/provider calls, and response delivery do not block the message processor loop. ## Validation - just fmt - cargo test -p codex-device-key - cargo test -p codex-state device_key - cargo test -p codex-app-server device_key - just fix -p codex-device-key - just fix -p codex-state - just fix -p codex-app-server - git diff --check --- codex-rs/Cargo.lock | 2 + codex-rs/app-server/src/device_key_api.rs | 99 +++- codex-rs/app-server/src/message_processor.rs | 127 +++-- codex-rs/device-key/Cargo.toml | 2 + codex-rs/device-key/src/lib.rs | 447 +++++++++++------- codex-rs/device-key/src/platform.rs | 14 +- .../migrations/0028_device_key_bindings.sql | 7 + codex-rs/state/src/lib.rs | 1 + codex-rs/state/src/runtime.rs | 4 + codex-rs/state/src/runtime/device_key.rs | 66 +++ .../state/src/runtime/device_key_tests.rs | 89 ++++ 11 files changed, 611 insertions(+), 247 deletions(-) create mode 100644 codex-rs/state/migrations/0028_device_key_bindings.sql create mode 100644 codex-rs/state/src/runtime/device_key.rs create mode 100644 codex-rs/state/src/runtime/device_key_tests.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 3f26f563f876..f6660dfdfd72 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2527,6 +2527,7 @@ dependencies = [ name = "codex-device-key" version = "0.0.0" dependencies = [ + "async-trait", "base64 0.22.1", "p256", "pretty_assertions", @@ -2534,6 +2535,7 @@ dependencies = [ "serde", "serde_json", "thiserror 2.0.18", + "tokio", "url", ] diff --git a/codex-rs/app-server/src/device_key_api.rs b/codex-rs/app-server/src/device_key_api.rs index beead123b02e..dbbc32f1c1d8 100644 --- a/codex-rs/app-server/src/device_key_api.rs +++ b/codex-rs/app-server/src/device_key_api.rs @@ -1,5 +1,6 @@ use crate::error_code::INTERNAL_ERROR_CODE; use crate::error_code::INVALID_REQUEST_ERROR_CODE; +use async_trait::async_trait; use base64::Engine; use base64::engine::general_purpose::STANDARD; use codex_app_server_protocol::DeviceKeyAlgorithm; @@ -13,6 +14,7 @@ use codex_app_server_protocol::DeviceKeySignPayload; use codex_app_server_protocol::DeviceKeySignResponse; use codex_app_server_protocol::JSONRPCErrorError; use codex_device_key::DeviceKeyBinding; +use codex_device_key::DeviceKeyBindingStore; use codex_device_key::DeviceKeyCreateRequest; use codex_device_key::DeviceKeyError; use codex_device_key::DeviceKeyGetPublicRequest; @@ -24,14 +26,29 @@ use codex_device_key::RemoteControlClientConnectionAudience; use codex_device_key::RemoteControlClientConnectionSignPayload; use codex_device_key::RemoteControlClientEnrollmentAudience; use codex_device_key::RemoteControlClientEnrollmentSignPayload; +use codex_state::DeviceKeyBindingRecord; +use codex_state::StateRuntime; +use std::fmt; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::OnceCell; -#[derive(Clone, Default)] +#[derive(Clone)] pub(crate) struct DeviceKeyApi { store: DeviceKeyStore, } impl DeviceKeyApi { - pub(crate) fn create( + pub(crate) fn new(sqlite_home: PathBuf, default_provider: String) -> Self { + Self { + store: DeviceKeyStore::new(Arc::new(StateDeviceKeyBindingStore::new( + sqlite_home, + default_provider, + ))), + } + } + + pub(crate) async fn create( &self, params: DeviceKeyCreateParams, ) -> Result { @@ -44,11 +61,12 @@ impl DeviceKeyApi { client_id: params.client_id, }, }) + .await .map_err(map_device_key_error)?; Ok(create_response_from_info(info)) } - pub(crate) fn public( + pub(crate) async fn public( &self, params: DeviceKeyPublicParams, ) -> Result { @@ -57,11 +75,12 @@ impl DeviceKeyApi { .get_public(DeviceKeyGetPublicRequest { key_id: params.key_id, }) + .await .map_err(map_device_key_error)?; Ok(public_response_from_info(info)) } - pub(crate) fn sign( + pub(crate) async fn sign( &self, params: DeviceKeySignParams, ) -> Result { @@ -71,6 +90,7 @@ impl DeviceKeyApi { key_id: params.key_id, payload: payload_from_params(params.payload), }) + .await .map_err(map_device_key_error)?; Ok(DeviceKeySignResponse { signature_der_base64: STANDARD.encode(signature.signature_der), @@ -80,6 +100,77 @@ impl DeviceKeyApi { } } +struct StateDeviceKeyBindingStore { + sqlite_home: PathBuf, + default_provider: String, + state_db: OnceCell>, +} + +impl StateDeviceKeyBindingStore { + fn new(sqlite_home: PathBuf, default_provider: String) -> Self { + Self { + sqlite_home, + default_provider, + state_db: OnceCell::new(), + } + } + + async fn state_db(&self) -> Result, DeviceKeyError> { + let sqlite_home = self.sqlite_home.clone(); + let default_provider = self.default_provider.clone(); + self.state_db + .get_or_try_init(|| async move { + StateRuntime::init(sqlite_home, default_provider) + .await + .map_err(|err| DeviceKeyError::Platform(err.to_string())) + }) + .await + .cloned() + } +} + +impl fmt::Debug for StateDeviceKeyBindingStore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("StateDeviceKeyBindingStore") + .field("sqlite_home", &self.sqlite_home) + .field("default_provider", &self.default_provider) + .finish_non_exhaustive() + } +} + +#[async_trait] +impl DeviceKeyBindingStore for StateDeviceKeyBindingStore { + async fn get_binding(&self, key_id: &str) -> Result, DeviceKeyError> { + let state_db = self.state_db().await?; + state_db + .get_device_key_binding(key_id) + .await + .map(|record| { + record.map(|record| DeviceKeyBinding { + account_user_id: record.account_user_id, + client_id: record.client_id, + }) + }) + .map_err(|err| DeviceKeyError::Platform(err.to_string())) + } + + async fn put_binding( + &self, + key_id: &str, + binding: &DeviceKeyBinding, + ) -> Result<(), DeviceKeyError> { + let state_db = self.state_db().await?; + state_db + .upsert_device_key_binding(&DeviceKeyBindingRecord { + key_id: key_id.to_string(), + account_user_id: binding.account_user_id.clone(), + client_id: binding.client_id.clone(), + }) + .await + .map_err(|err| DeviceKeyError::Platform(err.to_string())) + } +} + fn create_response_from_info(info: DeviceKeyInfo) -> DeviceKeyCreateResponse { DeviceKeyCreateResponse { key_id: info.key_id, diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 48e2aa6a1459..67f56005c733 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -325,7 +325,8 @@ impl MessageProcessor { thread_manager.clone(), analytics_events_client.clone(), ); - let device_key_api = DeviceKeyApi::default(); + let device_key_api = + DeviceKeyApi::new(config.sqlite_home.clone(), config.model_provider_id.clone()); let external_agent_config_api = ExternalAgentConfigApi::new(config.codex_home.to_path_buf()); let fs_api = FsApi::new( @@ -882,8 +883,7 @@ impl MessageProcessor { }, params, device_key_requests_allowed, - ) - .await; + ); } ClientRequest::DeviceKeyPublic { request_id, params } => { self.handle_device_key_public( @@ -893,8 +893,7 @@ impl MessageProcessor { }, params, device_key_requests_allowed, - ) - .await; + ); } ClientRequest::DeviceKeySign { request_id, params } => { self.handle_device_key_sign( @@ -904,8 +903,7 @@ impl MessageProcessor { }, params, device_key_requests_allowed, - ) - .await; + ); } ClientRequest::FsReadFile { request_id, params } => { self.handle_fs_read_file( @@ -1173,96 +1171,81 @@ impl MessageProcessor { } } - async fn handle_device_key_create( + fn handle_device_key_create( &self, request_id: ConnectionRequestId, params: DeviceKeyCreateParams, device_key_requests_allowed: bool, ) { - if self - .reject_device_key_request_over_remote_transport( - request_id.clone(), - "device/key/create", - device_key_requests_allowed, - ) - .await - { - return; - } - - match self.device_key_api.create(params) { - Ok(response) => self.outgoing.send_response(request_id, response).await, - Err(error) => self.outgoing.send_error(request_id, error).await, - } + self.spawn_device_key_request( + request_id, + "device/key/create", + device_key_requests_allowed, + move |device_key_api| async move { device_key_api.create(params).await }, + ); } - async fn handle_device_key_public( + fn handle_device_key_public( &self, request_id: ConnectionRequestId, params: DeviceKeyPublicParams, device_key_requests_allowed: bool, ) { - if self - .reject_device_key_request_over_remote_transport( - request_id.clone(), - "device/key/public", - device_key_requests_allowed, - ) - .await - { - return; - } - - match self.device_key_api.public(params) { - Ok(response) => self.outgoing.send_response(request_id, response).await, - Err(error) => self.outgoing.send_error(request_id, error).await, - } + self.spawn_device_key_request( + request_id, + "device/key/public", + device_key_requests_allowed, + move |device_key_api| async move { device_key_api.public(params).await }, + ); } - async fn handle_device_key_sign( + fn handle_device_key_sign( &self, request_id: ConnectionRequestId, params: DeviceKeySignParams, device_key_requests_allowed: bool, ) { - if self - .reject_device_key_request_over_remote_transport( - request_id.clone(), - "device/key/sign", - device_key_requests_allowed, - ) - .await - { - return; - } - - match self.device_key_api.sign(params) { - Ok(response) => self.outgoing.send_response(request_id, response).await, - Err(error) => self.outgoing.send_error(request_id, error).await, - } + self.spawn_device_key_request( + request_id, + "device/key/sign", + device_key_requests_allowed, + move |device_key_api| async move { device_key_api.sign(params).await }, + ); } - async fn reject_device_key_request_over_remote_transport( + fn spawn_device_key_request( &self, request_id: ConnectionRequestId, - method: &str, + method: &'static str, device_key_requests_allowed: bool, - ) -> bool { - if device_key_requests_allowed { - return false; - } + run_request: F, + ) where + R: serde::Serialize + Send + 'static, + F: FnOnce(DeviceKeyApi) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + let device_key_api = self.device_key_api.clone(); + let outgoing = Arc::clone(&self.outgoing); + tokio::spawn(async move { + if !device_key_requests_allowed { + outgoing + .send_error( + request_id, + JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("{method} is not available over remote transports"), + data: None, + }, + ) + .await; + return; + } - self.outgoing - .send_error( - request_id, - JSONRPCErrorError { - code: INVALID_REQUEST_ERROR_CODE, - message: format!("{method} is not available over remote transports"), - data: None, - }, - ) - .await; - true + match run_request(device_key_api).await { + Ok(response) => outgoing.send_response(request_id, response).await, + Err(error) => outgoing.send_error(request_id, error).await, + } + }); } async fn handle_external_agent_config_detect( diff --git a/codex-rs/device-key/Cargo.toml b/codex-rs/device-key/Cargo.toml index f61a886e0182..6ad280efc85f 100644 --- a/codex-rs/device-key/Cargo.toml +++ b/codex-rs/device-key/Cargo.toml @@ -8,12 +8,14 @@ license.workspace = true workspace = true [dependencies] +async-trait = { workspace = true } base64 = { workspace = true } p256 = { workspace = true, features = ["ecdsa", "pkcs8"] } rand = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } +tokio = { workspace = true, features = ["rt"] } url = { workspace = true } [dev-dependencies] diff --git a/codex-rs/device-key/src/lib.rs b/codex-rs/device-key/src/lib.rs index 61d34e034e7f..f901c633c99c 100644 --- a/codex-rs/device-key/src/lib.rs +++ b/codex-rs/device-key/src/lib.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use base64::Engine; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use p256::pkcs8::EncodePublicKey; @@ -211,47 +212,82 @@ pub enum DeviceKeyError { #[derive(Debug, Clone)] pub struct DeviceKeyStore { provider: Arc, + bindings: Arc, } -impl Default for DeviceKeyStore { - fn default() -> Self { +impl DeviceKeyStore { + pub fn new(bindings: Arc) -> Self { Self { provider: platform::default_provider(), + bindings, } } -} -impl DeviceKeyStore { - pub fn create(&self, request: DeviceKeyCreateRequest) -> Result { + pub async fn create( + &self, + request: DeviceKeyCreateRequest, + ) -> Result { let key_id_random = random_key_id_random(); validate_binding(&request.binding.account_user_id, &request.binding.client_id)?; - self.provider.create(ProviderCreateRequest { - key_id_random: &key_id_random, - protection_policy: request.protection_policy, - binding: &request.binding, + let provider = Arc::clone(&self.provider); + let info = spawn_provider_call(move || { + provider.create(ProviderCreateRequest { + key_id_random, + protection_policy: request.protection_policy, + }) }) + .await?; + match self + .bindings + .put_binding(&info.key_id, &request.binding) + .await + { + Ok(()) => Ok(info), + Err(store_error) => { + let provider = Arc::clone(&self.provider); + let key_id = info.key_id; + let protection_class = info.protection_class; + if let Err(delete_error) = + spawn_provider_call(move || provider.delete(&key_id, protection_class)).await + { + return Err(DeviceKeyError::Platform(format!( + "failed to store device key binding ({store_error}); failed to delete newly created key ({delete_error})" + ))); + } + Err(store_error) + } + } } - pub fn get_public( + pub async fn get_public( &self, request: DeviceKeyGetPublicRequest, ) -> Result { let protection_class = validate_key_id(&request.key_id)?; - self.provider.get_public(&request.key_id, protection_class) + let provider = Arc::clone(&self.provider); + spawn_provider_call(move || provider.get_public(&request.key_id, protection_class)).await } - pub fn sign( + pub async fn sign( &self, request: DeviceKeySignRequest, ) -> Result { let protection_class = validate_key_id(&request.key_id)?; validate_payload(&request.payload)?; - let binding = self.provider.binding(&request.key_id, protection_class)?; + let binding = self + .bindings + .get_binding(&request.key_id) + .await? + .ok_or(DeviceKeyError::KeyNotFound)?; validate_payload_binding(&request.payload, &binding)?; let signed_payload = device_key_signing_payload_bytes(&request.payload)?; - let signature = self - .provider - .sign(&request.key_id, protection_class, &signed_payload)?; + let provider = Arc::clone(&self.provider); + let key_id = request.key_id; + let provider_payload = signed_payload.clone(); + let signature = spawn_provider_call(move || { + provider.sign(&key_id, protection_class, &provider_payload) + }) + .await?; Ok(DeviceKeySignature { signature_der: signature.signature_der, signed_payload, @@ -260,21 +296,79 @@ impl DeviceKeyStore { } #[cfg(test)] - fn with_provider(provider: Arc) -> Self { - Self { provider } + fn new_for_test(provider: Arc) -> Self { + Self { + provider, + bindings: Arc::new(InMemoryDeviceKeyBindingStore::default()), + } + } +} + +async fn spawn_provider_call(call: F) -> Result +where + T: Send + 'static, + F: FnOnce() -> Result + Send + 'static, +{ + tokio::task::spawn_blocking(call) + .await + .map_err(|err| DeviceKeyError::Platform(format!("device key task failed: {err}")))? +} + +/// Persists the account/client binding for a generated device key. +/// +/// Device-key providers only own platform key material. Implementations store the binding in a +/// platform-neutral location so signing can reject payloads for the wrong account or client before +/// asking a provider to use the private key. +#[async_trait] +pub trait DeviceKeyBindingStore: Debug + Send + Sync { + async fn get_binding(&self, key_id: &str) -> Result, DeviceKeyError>; + async fn put_binding( + &self, + key_id: &str, + binding: &DeviceKeyBinding, + ) -> Result<(), DeviceKeyError>; +} + +#[cfg(test)] +#[derive(Debug, Default)] +struct InMemoryDeviceKeyBindingStore { + bindings: std::sync::Mutex>, +} + +#[cfg(test)] +#[async_trait] +impl DeviceKeyBindingStore for InMemoryDeviceKeyBindingStore { + async fn get_binding(&self, key_id: &str) -> Result, DeviceKeyError> { + Ok(self + .bindings + .lock() + .map_err(|err| DeviceKeyError::Platform(err.to_string()))? + .get(key_id) + .cloned()) + } + + async fn put_binding( + &self, + key_id: &str, + binding: &DeviceKeyBinding, + ) -> Result<(), DeviceKeyError> { + self.bindings + .lock() + .map_err(|err| DeviceKeyError::Platform(err.to_string()))? + .insert(key_id.to_string(), binding.clone()); + Ok(()) } } #[derive(Debug)] -struct ProviderCreateRequest<'a> { - key_id_random: &'a str, +struct ProviderCreateRequest { + key_id_random: String, protection_policy: DeviceKeyProtectionPolicy, - binding: &'a DeviceKeyBinding, } -impl ProviderCreateRequest<'_> { +impl ProviderCreateRequest { fn key_id_for(&self, protection_class: DeviceKeyProtectionClass) -> String { - key_id_for_protection_class(protection_class, self.key_id_random) + key_id_for_protection_class(protection_class, &self.key_id_random) } } @@ -283,17 +377,22 @@ impl ProviderCreateRequest<'_> { /// Implementations must never expose a generic arbitrary-byte signing API outside this crate. The /// crate validates and serializes accepted structured payloads before calling `sign`. trait DeviceKeyProvider: Debug + Send + Sync { - fn create(&self, request: ProviderCreateRequest<'_>) -> Result; - fn get_public( + fn create(&self, request: ProviderCreateRequest) -> Result; + /// Deletes provider-owned key material after a create operation cannot be completed. + /// + /// Implementations should treat missing keys as success where the platform allows it, since + /// cleanup can race with external deletion and should not mask the original persistence error + /// unless deletion itself fails unexpectedly. + fn delete( &self, key_id: &str, protection_class: DeviceKeyProtectionClass, - ) -> Result; - fn binding( + ) -> Result<(), DeviceKeyError>; + fn get_public( &self, key_id: &str, protection_class: DeviceKeyProtectionClass, - ) -> Result; + ) -> Result; fn sign( &self, key_id: &str, @@ -629,7 +728,6 @@ mod tests { struct MemoryProvider { class: DeviceKeyProtectionClass, keys: Mutex>, - bindings: Mutex>, } impl MemoryProvider { @@ -637,16 +735,16 @@ mod tests { Self { class, keys: Mutex::new(HashMap::new()), - bindings: Mutex::new(HashMap::new()), } } + + fn key_count(&self) -> usize { + self.keys.lock().expect("memory provider lock").len() + } } impl DeviceKeyProvider for MemoryProvider { - fn create( - &self, - request: ProviderCreateRequest<'_>, - ) -> Result { + fn create(&self, request: ProviderCreateRequest) -> Result { if !request.protection_policy.allows(self.class) { return Err(DeviceKeyError::DegradedProtectionNotAllowed { available: self.class, @@ -660,43 +758,38 @@ mod tests { let signing_key = keys .entry(key_id.clone()) .or_insert_with(|| SigningKey::random(&mut OsRng)); - self.bindings - .lock() - .map_err(|err| DeviceKeyError::Platform(err.to_string()))? - .insert(key_id.clone(), request.binding.clone()); memory_key_info(&key_id, signing_key, self.class) } - fn get_public( + fn delete( &self, key_id: &str, protection_class: DeviceKeyProtectionClass, - ) -> Result { + ) -> Result<(), DeviceKeyError> { if protection_class != self.class { - return Err(DeviceKeyError::KeyNotFound); + return Ok(()); } - let keys = self - .keys + self.keys .lock() - .map_err(|err| DeviceKeyError::Platform(err.to_string()))?; - let signing_key = keys.get(key_id).ok_or(DeviceKeyError::KeyNotFound)?; - memory_key_info(key_id, signing_key, self.class) + .map_err(|err| DeviceKeyError::Platform(err.to_string()))? + .remove(key_id); + Ok(()) } - fn binding( + fn get_public( &self, key_id: &str, protection_class: DeviceKeyProtectionClass, - ) -> Result { + ) -> Result { if protection_class != self.class { return Err(DeviceKeyError::KeyNotFound); } - self.bindings + let keys = self + .keys .lock() - .map_err(|err| DeviceKeyError::Platform(err.to_string()))? - .get(key_id) - .cloned() - .ok_or(DeviceKeyError::KeyNotFound) + .map_err(|err| DeviceKeyError::Platform(err.to_string()))?; + let signing_key = keys.get(key_id).ok_or(DeviceKeyError::KeyNotFound)?; + memory_key_info(key_id, signing_key, self.class) } fn sign( @@ -721,6 +814,27 @@ mod tests { } } + #[derive(Debug)] + struct FailingBindingStore; + + #[async_trait] + impl DeviceKeyBindingStore for FailingBindingStore { + async fn get_binding( + &self, + _key_id: &str, + ) -> Result, DeviceKeyError> { + Ok(None) + } + + async fn put_binding( + &self, + _key_id: &str, + _binding: &DeviceKeyBinding, + ) -> Result<(), DeviceKeyError> { + Err(DeviceKeyError::Platform("binding write failed".to_string())) + } + } + fn memory_key_info( key_id: &str, signing_key: &SigningKey, @@ -741,7 +855,14 @@ mod tests { } fn store(class: DeviceKeyProtectionClass) -> DeviceKeyStore { - DeviceKeyStore::with_provider(Arc::new(MemoryProvider::new(class))) + DeviceKeyStore::new_for_test(Arc::new(MemoryProvider::new(class))) + } + + fn block_on(future: impl std::future::Future) -> T { + tokio::runtime::Builder::new_current_thread() + .build() + .expect("build test runtime") + .block_on(future) } fn create_request(protection_policy: DeviceKeyProtectionPolicy) -> DeviceKeyCreateRequest { @@ -808,9 +929,11 @@ mod tests { #[test] fn create_requires_explicit_degraded_protection() { - let err = store(DeviceKeyProtectionClass::OsProtectedNonextractable) - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) - .expect_err("OS-protected fallback should require opt-in"); + let err = block_on( + store(DeviceKeyProtectionClass::OsProtectedNonextractable) + .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)), + ) + .expect_err("OS-protected fallback should require opt-in"); assert!( matches!( @@ -825,11 +948,12 @@ mod tests { #[test] fn create_allows_os_protected_nonextractable_policy() { - let info = store(DeviceKeyProtectionClass::OsProtectedNonextractable) - .create(create_request( + let info = block_on( + store(DeviceKeyProtectionClass::OsProtectedNonextractable).create(create_request( DeviceKeyProtectionPolicy::AllowOsProtectedNonextractable, - )) - .expect("OS-protected fallback should be allowed by policy"); + )), + ) + .expect("OS-protected fallback should be allowed by policy"); assert_eq!( info.protection_class, @@ -844,18 +968,38 @@ mod tests { #[test] fn create_generates_distinct_key_ids() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let first = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) - .expect("create should succeed"); - let second = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let first = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); + let second = + block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) + .expect("create should succeed"); assert_ne!(second.key_id, first.key_id); assert_valid_generated_key_id(&first.key_id, DeviceKeyProtectionClass::HardwareTpm); assert_valid_generated_key_id(&second.key_id, DeviceKeyProtectionClass::HardwareTpm); } + #[test] + fn create_deletes_provider_key_when_binding_write_fails() { + let provider = Arc::new(MemoryProvider::new(DeviceKeyProtectionClass::HardwareTpm)); + let store = DeviceKeyStore { + provider: provider.clone(), + bindings: Arc::new(FailingBindingStore), + }; + + let err = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) + .expect_err("binding failure should fail create"); + + assert!( + matches!( + &err, + DeviceKeyError::Platform(message) if message == "binding write failed" + ), + "unexpected error: {err:?}" + ); + assert_eq!(provider.key_count(), 0); + } + #[test] fn key_id_validation_rejects_untrusted_namespaces() { let valid_suffix = URL_SAFE_NO_PAD.encode([0_u8; DEVICE_KEY_ID_RANDOM_BYTES]); @@ -902,11 +1046,10 @@ mod tests { let store = store(DeviceKeyProtectionClass::HardwareTpm); let malformed_key_id = "not-a-device-key".to_string(); - let err = store - .get_public(DeviceKeyGetPublicRequest { - key_id: malformed_key_id.clone(), - }) - .expect_err("malformed get_public key id should fail"); + let err = block_on(store.get_public(DeviceKeyGetPublicRequest { + key_id: malformed_key_id.clone(), + })) + .expect_err("malformed get_public key id should fail"); assert!( matches!( err, @@ -915,12 +1058,11 @@ mod tests { "unexpected get_public error: {err:?}" ); - let err = store - .sign(DeviceKeySignRequest { - key_id: malformed_key_id, - payload: remote_control_client_connection_payload(), - }) - .expect_err("malformed sign key id should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: malformed_key_id, + payload: remote_control_client_connection_payload(), + })) + .expect_err("malformed sign key id should fail"); assert!( matches!( err, @@ -933,8 +1075,7 @@ mod tests { #[test] fn sign_rejects_empty_account_user_id() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_connection_payload(); match &mut payload { @@ -944,12 +1085,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("empty account user id should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("empty account user id should fail"); assert!( matches!( @@ -963,18 +1103,16 @@ mod tests { #[test] fn sign_uses_structured_payload() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let payload = remote_control_client_connection_payload(); let signed_payload = device_key_signing_payload_bytes(&payload).expect("payload should serialize"); - let signature = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect("sign should succeed"); + let signature = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect("sign should succeed"); assert_eq!(signature.signed_payload, signed_payload); let verifying_key = VerifyingKey::from_public_key_der(&info.public_key_spki_der) @@ -1063,8 +1201,7 @@ mod tests { #[test] fn sign_rejects_malformed_token_hash() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_connection_payload(); match &mut payload { @@ -1074,12 +1211,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("malformed token hash should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("malformed token hash should fail"); assert!( matches!( @@ -1095,8 +1231,7 @@ mod tests { #[test] fn sign_rejects_unexpected_scopes() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_connection_payload(); match &mut payload { @@ -1106,12 +1241,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("unexpected scope should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("unexpected scope should fail"); assert!( matches!( @@ -1127,8 +1261,7 @@ mod tests { #[test] fn sign_rejects_malformed_enrollment_identity_hash() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_enrollment_payload(); match &mut payload { @@ -1138,12 +1271,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientConnection(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("malformed device identity hash should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("malformed device identity hash should fail"); assert!( matches!( @@ -1159,8 +1291,7 @@ mod tests { #[test] fn sign_rejects_empty_target_binding() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_connection_payload(); match &mut payload { @@ -1170,12 +1301,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("empty target origin should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("empty target origin should fail"); assert!( matches!( @@ -1191,8 +1321,7 @@ mod tests { #[test] fn sign_rejects_remote_control_paths_for_other_payload_shapes() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut connection_payload = remote_control_client_connection_payload(); match &mut connection_payload { @@ -1202,12 +1331,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id.clone(), - payload: connection_payload, - }) - .expect_err("connection payload should reject enrollment path"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id.clone(), + payload: connection_payload, + })) + .expect_err("connection payload should reject enrollment path"); assert!( matches!( err, @@ -1226,12 +1354,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientConnection(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload: enrollment_payload, - }) - .expect_err("enrollment payload should reject connection path"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload: enrollment_payload, + })) + .expect_err("enrollment payload should reject connection path"); assert!( matches!( err, @@ -1283,8 +1410,7 @@ mod tests { #[test] fn sign_rejects_empty_session_binding() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_connection_payload(); match &mut payload { @@ -1294,12 +1420,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("empty session id should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("empty session id should fail"); assert!( matches!( @@ -1313,8 +1438,7 @@ mod tests { #[test] fn sign_rejects_empty_client_id() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_connection_payload(); match &mut payload { @@ -1324,12 +1448,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("empty client id should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("empty client id should fail"); assert!( matches!( @@ -1343,8 +1466,7 @@ mod tests { #[test] fn sign_rejects_mismatched_binding() { let store = store(DeviceKeyProtectionClass::HardwareTpm); - let info = store - .create(create_request(DeviceKeyProtectionPolicy::HardwareOnly)) + let info = block_on(store.create(create_request(DeviceKeyProtectionPolicy::HardwareOnly))) .expect("create should succeed"); let mut payload = remote_control_client_connection_payload(); match &mut payload { @@ -1354,12 +1476,11 @@ mod tests { DeviceKeySignPayload::RemoteControlClientEnrollment(_) => unreachable!(), } - let err = store - .sign(DeviceKeySignRequest { - key_id: info.key_id, - payload, - }) - .expect_err("mismatched binding should fail"); + let err = block_on(store.sign(DeviceKeySignRequest { + key_id: info.key_id, + payload, + })) + .expect_err("mismatched binding should fail"); assert!( matches!( diff --git a/codex-rs/device-key/src/platform.rs b/codex-rs/device-key/src/platform.rs index 3dbcb168e7ed..60a2f508364b 100644 --- a/codex-rs/device-key/src/platform.rs +++ b/codex-rs/device-key/src/platform.rs @@ -1,4 +1,3 @@ -use crate::DeviceKeyBinding; use crate::DeviceKeyError; use crate::DeviceKeyInfo; use crate::DeviceKeyProtectionClass; @@ -15,28 +14,27 @@ pub(crate) fn default_provider() -> Arc { pub(crate) struct UnsupportedDeviceKeyProvider; impl DeviceKeyProvider for UnsupportedDeviceKeyProvider { - fn create(&self, request: ProviderCreateRequest<'_>) -> Result { + fn create(&self, request: ProviderCreateRequest) -> Result { let _ = request.key_id_for(DeviceKeyProtectionClass::HardwareTpm); let _ = request .protection_policy .allows(DeviceKeyProtectionClass::HardwareTpm); - let _ = request.binding; Err(DeviceKeyError::HardwareBackedKeysUnavailable) } - fn get_public( + fn delete( &self, _key_id: &str, _protection_class: DeviceKeyProtectionClass, - ) -> Result { - Err(DeviceKeyError::KeyNotFound) + ) -> Result<(), DeviceKeyError> { + Ok(()) } - fn binding( + fn get_public( &self, _key_id: &str, _protection_class: DeviceKeyProtectionClass, - ) -> Result { + ) -> Result { Err(DeviceKeyError::KeyNotFound) } diff --git a/codex-rs/state/migrations/0028_device_key_bindings.sql b/codex-rs/state/migrations/0028_device_key_bindings.sql new file mode 100644 index 000000000000..d7b660bf6819 --- /dev/null +++ b/codex-rs/state/migrations/0028_device_key_bindings.sql @@ -0,0 +1,7 @@ +CREATE TABLE device_key_bindings ( + key_id TEXT PRIMARY KEY NOT NULL, + account_user_id TEXT NOT NULL, + client_id TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL +); diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index 36676d5a4651..49529f3a33bb 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -47,6 +47,7 @@ pub use model::Stage1StartupClaimParams; pub use model::ThreadMetadata; pub use model::ThreadMetadataBuilder; pub use model::ThreadsPage; +pub use runtime::DeviceKeyBindingRecord; pub use runtime::RemoteControlEnrollmentRecord; pub use runtime::ThreadFilterOptions; pub use runtime::logs_db_filename; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index 67eb537702cc..ec3ce79e820d 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -55,6 +55,9 @@ use tracing::warn; mod agent_jobs; mod backfill; +mod device_key; +#[cfg(test)] +mod device_key_tests; mod logs; mod memories; mod remote_control; @@ -62,6 +65,7 @@ mod remote_control; mod test_support; mod threads; +pub use device_key::DeviceKeyBindingRecord; pub use remote_control::RemoteControlEnrollmentRecord; pub use threads::ThreadFilterOptions; diff --git a/codex-rs/state/src/runtime/device_key.rs b/codex-rs/state/src/runtime/device_key.rs new file mode 100644 index 000000000000..bb3f20f75903 --- /dev/null +++ b/codex-rs/state/src/runtime/device_key.rs @@ -0,0 +1,66 @@ +use super::*; + +/// Persisted account/client binding for a generated device key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DeviceKeyBindingRecord { + pub key_id: String, + pub account_user_id: String, + pub client_id: String, +} + +impl StateRuntime { + pub async fn get_device_key_binding( + &self, + key_id: &str, + ) -> anyhow::Result> { + let row = sqlx::query( + r#" +SELECT key_id, account_user_id, client_id +FROM device_key_bindings +WHERE key_id = ? + "#, + ) + .bind(key_id) + .fetch_optional(self.pool.as_ref()) + .await?; + + row.map(|row| { + Ok(DeviceKeyBindingRecord { + key_id: row.try_get("key_id")?, + account_user_id: row.try_get("account_user_id")?, + client_id: row.try_get("client_id")?, + }) + }) + .transpose() + } + + pub async fn upsert_device_key_binding( + &self, + binding: &DeviceKeyBindingRecord, + ) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + sqlx::query( + r#" +INSERT INTO device_key_bindings ( + key_id, + account_user_id, + client_id, + created_at, + updated_at +) VALUES (?, ?, ?, ?, ?) +ON CONFLICT(key_id) DO UPDATE SET + account_user_id = excluded.account_user_id, + client_id = excluded.client_id, + updated_at = excluded.updated_at + "#, + ) + .bind(&binding.key_id) + .bind(&binding.account_user_id) + .bind(&binding.client_id) + .bind(now) + .bind(now) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } +} diff --git a/codex-rs/state/src/runtime/device_key_tests.rs b/codex-rs/state/src/runtime/device_key_tests.rs new file mode 100644 index 000000000000..a29eaea94bd8 --- /dev/null +++ b/codex-rs/state/src/runtime/device_key_tests.rs @@ -0,0 +1,89 @@ +use super::DeviceKeyBindingRecord; +use super::StateRuntime; +use super::test_support::unique_temp_dir; +use pretty_assertions::assert_eq; + +#[tokio::test] +async fn device_key_binding_round_trips_by_key_id() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + let first = DeviceKeyBindingRecord { + key_id: "dk_tpm_AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".to_string(), + account_user_id: "account-user-a".to_string(), + client_id: "cli_a".to_string(), + }; + let second = DeviceKeyBindingRecord { + key_id: "dk_tpm_BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB".to_string(), + account_user_id: "account-user-b".to_string(), + client_id: "cli_b".to_string(), + }; + + runtime + .upsert_device_key_binding(&first) + .await + .expect("insert first binding"); + runtime + .upsert_device_key_binding(&second) + .await + .expect("insert second binding"); + + assert_eq!( + runtime + .get_device_key_binding(&first.key_id) + .await + .expect("load first binding"), + Some(first) + ); + assert_eq!( + runtime + .get_device_key_binding("dk_tpm_missing") + .await + .expect("load missing binding"), + None + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; +} + +#[tokio::test] +async fn device_key_binding_upsert_updates_existing_binding() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + let key_id = "dk_tpm_AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".to_string(); + runtime + .upsert_device_key_binding(&DeviceKeyBindingRecord { + key_id: key_id.clone(), + account_user_id: "account-user-a".to_string(), + client_id: "cli_a".to_string(), + }) + .await + .expect("insert binding"); + runtime + .upsert_device_key_binding(&DeviceKeyBindingRecord { + key_id: key_id.clone(), + account_user_id: "account-user-b".to_string(), + client_id: "cli_b".to_string(), + }) + .await + .expect("update binding"); + + assert_eq!( + runtime + .get_device_key_binding(&key_id) + .await + .expect("load updated binding"), + Some(DeviceKeyBindingRecord { + key_id, + account_user_id: "account-user-b".to_string(), + client_id: "cli_b".to_string(), + }) + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; +}