Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

99 changes: 95 additions & 4 deletions codex-rs/app-server/src/device_key_api.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::error_code::INTERNAL_ERROR_CODE;
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
use async_trait::async_trait;
use base64::Engine;
use base64::engine::general_purpose::STANDARD;
use codex_app_server_protocol::DeviceKeyAlgorithm;
Expand All @@ -13,6 +14,7 @@ use codex_app_server_protocol::DeviceKeySignPayload;
use codex_app_server_protocol::DeviceKeySignResponse;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_device_key::DeviceKeyBinding;
use codex_device_key::DeviceKeyBindingStore;
use codex_device_key::DeviceKeyCreateRequest;
use codex_device_key::DeviceKeyError;
use codex_device_key::DeviceKeyGetPublicRequest;
Expand All @@ -24,14 +26,29 @@ use codex_device_key::RemoteControlClientConnectionAudience;
use codex_device_key::RemoteControlClientConnectionSignPayload;
use codex_device_key::RemoteControlClientEnrollmentAudience;
use codex_device_key::RemoteControlClientEnrollmentSignPayload;
use codex_state::DeviceKeyBindingRecord;
use codex_state::StateRuntime;
use std::fmt;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::OnceCell;

#[derive(Clone, Default)]
#[derive(Clone)]
pub(crate) struct DeviceKeyApi {
store: DeviceKeyStore,
}

impl DeviceKeyApi {
pub(crate) fn create(
pub(crate) fn new(sqlite_home: PathBuf, default_provider: String) -> Self {
Self {
store: DeviceKeyStore::new(Arc::new(StateDeviceKeyBindingStore::new(
sqlite_home,
default_provider,
))),
}
}

pub(crate) async fn create(
&self,
params: DeviceKeyCreateParams,
) -> Result<DeviceKeyCreateResponse, JSONRPCErrorError> {
Expand All @@ -44,11 +61,12 @@ impl DeviceKeyApi {
client_id: params.client_id,
},
})
.await
.map_err(map_device_key_error)?;
Ok(create_response_from_info(info))
}

pub(crate) fn public(
pub(crate) async fn public(
&self,
params: DeviceKeyPublicParams,
) -> Result<DeviceKeyPublicResponse, JSONRPCErrorError> {
Expand All @@ -57,11 +75,12 @@ impl DeviceKeyApi {
.get_public(DeviceKeyGetPublicRequest {
key_id: params.key_id,
})
.await
.map_err(map_device_key_error)?;
Ok(public_response_from_info(info))
}

pub(crate) fn sign(
pub(crate) async fn sign(
&self,
params: DeviceKeySignParams,
) -> Result<DeviceKeySignResponse, JSONRPCErrorError> {
Expand All @@ -71,6 +90,7 @@ impl DeviceKeyApi {
key_id: params.key_id,
payload: payload_from_params(params.payload),
})
.await
.map_err(map_device_key_error)?;
Ok(DeviceKeySignResponse {
signature_der_base64: STANDARD.encode(signature.signature_der),
Expand All @@ -80,6 +100,77 @@ impl DeviceKeyApi {
}
}

struct StateDeviceKeyBindingStore {
sqlite_home: PathBuf,
default_provider: String,
state_db: OnceCell<Arc<StateRuntime>>,
}

impl StateDeviceKeyBindingStore {
fn new(sqlite_home: PathBuf, default_provider: String) -> Self {
Self {
sqlite_home,
default_provider,
state_db: OnceCell::new(),
}
}

async fn state_db(&self) -> Result<Arc<StateRuntime>, DeviceKeyError> {
let sqlite_home = self.sqlite_home.clone();
let default_provider = self.default_provider.clone();
self.state_db
.get_or_try_init(|| async move {
StateRuntime::init(sqlite_home, default_provider)
.await
.map_err(|err| DeviceKeyError::Platform(err.to_string()))
})
.await
.cloned()
}
}

impl fmt::Debug for StateDeviceKeyBindingStore {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StateDeviceKeyBindingStore")
.field("sqlite_home", &self.sqlite_home)
.field("default_provider", &self.default_provider)
.finish_non_exhaustive()
}
}

#[async_trait]
impl DeviceKeyBindingStore for StateDeviceKeyBindingStore {
async fn get_binding(&self, key_id: &str) -> Result<Option<DeviceKeyBinding>, DeviceKeyError> {
let state_db = self.state_db().await?;
state_db
.get_device_key_binding(key_id)
.await
.map(|record| {
record.map(|record| DeviceKeyBinding {
account_user_id: record.account_user_id,
client_id: record.client_id,
})
})
.map_err(|err| DeviceKeyError::Platform(err.to_string()))
}

async fn put_binding(
&self,
key_id: &str,
binding: &DeviceKeyBinding,
) -> Result<(), DeviceKeyError> {
let state_db = self.state_db().await?;
state_db
.upsert_device_key_binding(&DeviceKeyBindingRecord {
key_id: key_id.to_string(),
account_user_id: binding.account_user_id.clone(),
client_id: binding.client_id.clone(),
})
.await
.map_err(|err| DeviceKeyError::Platform(err.to_string()))
}
}

fn create_response_from_info(info: DeviceKeyInfo) -> DeviceKeyCreateResponse {
DeviceKeyCreateResponse {
key_id: info.key_id,
Expand Down
127 changes: 55 additions & 72 deletions codex-rs/app-server/src/message_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ impl MessageProcessor {
thread_manager.clone(),
analytics_events_client.clone(),
);
let device_key_api = DeviceKeyApi::default();
let device_key_api =
DeviceKeyApi::new(config.sqlite_home.clone(), config.model_provider_id.clone());
let external_agent_config_api =
ExternalAgentConfigApi::new(config.codex_home.to_path_buf());
let fs_api = FsApi::new(
Expand Down Expand Up @@ -882,8 +883,7 @@ impl MessageProcessor {
},
params,
device_key_requests_allowed,
)
.await;
);
}
ClientRequest::DeviceKeyPublic { request_id, params } => {
self.handle_device_key_public(
Expand All @@ -893,8 +893,7 @@ impl MessageProcessor {
},
params,
device_key_requests_allowed,
)
.await;
);
}
ClientRequest::DeviceKeySign { request_id, params } => {
self.handle_device_key_sign(
Expand All @@ -904,8 +903,7 @@ impl MessageProcessor {
},
params,
device_key_requests_allowed,
)
.await;
);
}
ClientRequest::FsReadFile { request_id, params } => {
self.handle_fs_read_file(
Expand Down Expand Up @@ -1173,96 +1171,81 @@ impl MessageProcessor {
}
}

async fn handle_device_key_create(
fn handle_device_key_create(
&self,
request_id: ConnectionRequestId,
params: DeviceKeyCreateParams,
device_key_requests_allowed: bool,
) {
if self
.reject_device_key_request_over_remote_transport(
request_id.clone(),
"device/key/create",
device_key_requests_allowed,
)
.await
{
return;
}

match self.device_key_api.create(params) {
Ok(response) => self.outgoing.send_response(request_id, response).await,
Err(error) => self.outgoing.send_error(request_id, error).await,
}
self.spawn_device_key_request(
request_id,
"device/key/create",
device_key_requests_allowed,
move |device_key_api| async move { device_key_api.create(params).await },
);
}

async fn handle_device_key_public(
fn handle_device_key_public(
&self,
request_id: ConnectionRequestId,
params: DeviceKeyPublicParams,
device_key_requests_allowed: bool,
) {
if self
.reject_device_key_request_over_remote_transport(
request_id.clone(),
"device/key/public",
device_key_requests_allowed,
)
.await
{
return;
}

match self.device_key_api.public(params) {
Ok(response) => self.outgoing.send_response(request_id, response).await,
Err(error) => self.outgoing.send_error(request_id, error).await,
}
self.spawn_device_key_request(
request_id,
"device/key/public",
device_key_requests_allowed,
move |device_key_api| async move { device_key_api.public(params).await },
);
}

async fn handle_device_key_sign(
fn handle_device_key_sign(
&self,
request_id: ConnectionRequestId,
params: DeviceKeySignParams,
device_key_requests_allowed: bool,
) {
if self
.reject_device_key_request_over_remote_transport(
request_id.clone(),
"device/key/sign",
device_key_requests_allowed,
)
.await
{
return;
}

match self.device_key_api.sign(params) {
Ok(response) => self.outgoing.send_response(request_id, response).await,
Err(error) => self.outgoing.send_error(request_id, error).await,
}
self.spawn_device_key_request(
request_id,
"device/key/sign",
device_key_requests_allowed,
move |device_key_api| async move { device_key_api.sign(params).await },
);
}

async fn reject_device_key_request_over_remote_transport(
fn spawn_device_key_request<R, F, Fut>(
&self,
request_id: ConnectionRequestId,
method: &str,
method: &'static str,
device_key_requests_allowed: bool,
) -> bool {
if device_key_requests_allowed {
return false;
}
run_request: F,
) where
R: serde::Serialize + Send + 'static,
F: FnOnce(DeviceKeyApi) -> Fut + Send + 'static,
Fut: Future<Output = Result<R, JSONRPCErrorError>> + Send + 'static,
{
let device_key_api = self.device_key_api.clone();
let outgoing = Arc::clone(&self.outgoing);
tokio::spawn(async move {
if !device_key_requests_allowed {
outgoing
.send_error(
request_id,
JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("{method} is not available over remote transports"),
data: None,
},
)
.await;
return;
}

self.outgoing
.send_error(
request_id,
JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("{method} is not available over remote transports"),
data: None,
},
)
.await;
true
match run_request(device_key_api).await {
Ok(response) => outgoing.send_response(request_id, response).await,
Err(error) => outgoing.send_error(request_id, error).await,
}
});
}

async fn handle_external_agent_config_detect(
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/device-key/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ license.workspace = true
workspace = true

[dependencies]
async-trait = { workspace = true }
base64 = { workspace = true }
p256 = { workspace = true, features = ["ecdsa", "pkcs8"] }
rand = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
url = { workspace = true }

[dev-dependencies]
Expand Down
Loading
Loading