diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 4f9edbc0..237677a5 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -89,6 +89,9 @@ impl Service for H { ClientNotification::RootsListChangedNotification(_notification) => { self.on_roots_list_changed(context).await } + ClientNotification::CustomClientNotification(notification) => { + self.on_custom_notification(notification, context).await + } }; Ok(()) } @@ -224,6 +227,14 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { ) -> impl Future + Send + '_ { std::future::ready(()) } + fn on_custom_notification( + &self, + notification: CustomClientNotification, + context: NotificationContext, + ) -> impl Future + Send + '_ { + let _ = (notification, context); + std::future::ready(()) + } fn get_info(&self) -> ServerInfo { ServerInfo::default() diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index fb757c09..c27bc7f1 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -627,6 +627,40 @@ const_string!(CancelledNotificationMethod = "notifications/cancelled"); pub type CancelledNotification = Notification; +/// A catch-all notification the client can use to send custom messages to a server. +/// +/// This preserves the raw `method` name and `params` payload so handlers can +/// deserialize them into domain-specific types. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CustomClientNotification { + pub method: String, + pub params: Option, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + #[cfg_attr(feature = "schemars", schemars(skip))] + pub extensions: Extensions, +} + +impl CustomClientNotification { + pub fn new(method: impl Into, params: Option) -> Self { + Self { + method: method.into(), + params, + extensions: Extensions::default(), + } + } + + /// Deserialize `params` into a strongly-typed structure. + pub fn params_as(&self) -> Result, serde_json::Error> { + self.params + .as_ref() + .map(|params| serde_json::from_value(params.clone())) + .transpose() + } +} + const_string!(InitializeResultMethod = "initialize"); /// # Initialization /// This request is sent from the client to the server when it first connects, asking it to begin initialization. @@ -1748,7 +1782,8 @@ ts_union!( | CancelledNotification | ProgressNotification | InitializedNotification - | RootsListChangedNotification; + | RootsListChangedNotification + | CustomClientNotification; ); ts_union!( @@ -1857,6 +1892,38 @@ mod tests { assert_eq!(json, raw); } + #[test] + fn test_custom_client_notification_roundtrip() { + let raw = json!( { + "jsonrpc": JsonRpcVersion2_0, + "method": "notifications/custom", + "params": {"foo": "bar"}, + }); + + let message: ClientJsonRpcMessage = + serde_json::from_value(raw.clone()).expect("invalid notification"); + match &message { + ClientJsonRpcMessage::Notification(JsonRpcNotification { + notification: ClientNotification::CustomClientNotification(notification), + .. + }) => { + assert_eq!(notification.method, "notifications/custom"); + assert_eq!( + notification + .params + .as_ref() + .and_then(|p| p.get("foo")) + .expect("foo present"), + "bar" + ); + } + _ => panic!("Expected custom client notification"), + } + + let json = serde_json::to_value(message).expect("valid json"); + assert_eq!(json, raw); + } + #[test] fn test_request_conversion() { let raw = json!( { diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index fd93362b..a03fc056 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use super::{ - ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString, - ProgressToken, ServerNotification, ServerRequest, + ClientNotification, ClientRequest, CustomClientNotification, Extensions, JsonObject, + JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest, }; pub trait GetMeta { @@ -18,6 +18,26 @@ pub trait GetExtensions { fn extensions_mut(&mut self) -> &mut Extensions; } +impl GetExtensions for CustomClientNotification { + fn extensions(&self) -> &Extensions { + &self.extensions + } + fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} + +impl GetMeta for CustomClientNotification { + fn get_meta_mut(&mut self) -> &mut Meta { + self.extensions_mut().get_or_insert_default() + } + fn get_meta(&self) -> &Meta { + self.extensions() + .get::() + .unwrap_or(Meta::static_empty()) + } +} + macro_rules! variant_extension { ( $Enum: ident { @@ -84,6 +104,7 @@ variant_extension! { ProgressNotification InitializedNotification RootsListChangedNotification + CustomClientNotification } } diff --git a/crates/rmcp/src/model/serde_impl.rs b/crates/rmcp/src/model/serde_impl.rs index 09222d52..65e14361 100644 --- a/crates/rmcp/src/model/serde_impl.rs +++ b/crates/rmcp/src/model/serde_impl.rs @@ -3,8 +3,8 @@ use std::borrow::Cow; use serde::{Deserialize, Serialize}; use super::{ - Extensions, Meta, Notification, NotificationNoParam, Request, RequestNoParam, - RequestOptionalParam, + CustomClientNotification, Extensions, Meta, Notification, NotificationNoParam, Request, + RequestNoParam, RequestOptionalParam, }; #[derive(Serialize, Deserialize)] struct WithMeta<'a, P> { @@ -249,6 +249,59 @@ where } } +impl Serialize for CustomClientNotification { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + let params = self.params.as_ref(); + + let params = if _meta.is_some() || params.is_some() { + Some(WithMeta { + _meta, + _rest: &self.params, + }) + } else { + None + }; + + ProxyOptionalParam::serialize( + &ProxyOptionalParam { + method: &self.method, + params, + }, + serializer, + ) + } +} + +impl<'de> Deserialize<'de> for CustomClientNotification { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = + ProxyOptionalParam::<'_, _, Option>::deserialize(deserializer)?; + let mut params = None; + let mut _meta = None; + if let Some(body_params) = body.params { + params = body_params._rest; + _meta = body_params._meta.map(|m| m.into_owned()); + } + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(CustomClientNotification { + extensions, + method: body.method, + params, + }) + } +} + #[cfg(test)] mod test { use serde_json::json; diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs index a46ac2fd..3fb5b60b 100644 --- a/crates/rmcp/tests/test_notification.rs +++ b/crates/rmcp/tests/test_notification.rs @@ -3,10 +3,12 @@ use std::sync::Arc; use rmcp::{ ClientHandler, ServerHandler, ServiceExt, model::{ - ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam, + ClientNotification, CustomClientNotification, ResourceUpdatedNotificationParam, + ServerCapabilities, ServerInfo, SubscribeRequestParam, }, }; -use tokio::sync::Notify; +use serde_json::json; +use tokio::sync::{Mutex, Notify}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub struct Server {} @@ -93,3 +95,71 @@ async fn test_server_notification() -> anyhow::Result<()> { client.cancel().await?; Ok(()) } + +struct CustomServer { + receive_signal: Arc, + payload: Arc)>>>, +} + +impl ServerHandler for CustomServer { + async fn on_custom_notification( + &self, + notification: CustomClientNotification, + _context: rmcp::service::NotificationContext, + ) { + let CustomClientNotification { method, params, .. } = notification; + let mut payload = self.payload.lock().await; + *payload = Some((method, params)); + self.receive_signal.notify_one(); + } +} + +#[tokio::test] +async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let payload = Arc::new(Mutex::new(None)); + + { + let receive_signal = receive_signal.clone(); + let payload = payload.clone(); + tokio::spawn(async move { + let server = CustomServer { + receive_signal, + payload, + } + .serve(server_transport) + .await?; + server.waiting().await?; + anyhow::Ok(()) + }); + } + + let client = ().serve(client_transport).await?; + + client + .send_notification(ClientNotification::CustomClientNotification( + CustomClientNotification::new( + "notifications/custom-test", + Some(json!({ "foo": "bar" })), + ), + )) + .await?; + + tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?; + + let (method, params) = payload.lock().await.clone().expect("payload set"); + assert_eq!("notifications/custom-test", method); + assert_eq!(Some(json!({ "foo": "bar" })), params); + + client.cancel().await?; + Ok(()) +}