From 4d9cc16f4c76c84486344f542ed9a3e9364019ba Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Fri, 21 Nov 2025 10:47:32 -0800 Subject: [PATCH] feat: add support for custom client notifications MCP servers, particularly ones that offer "experimental" capabilities, may wish to handle custom client notifications that are not part of the standard MCP specification. This change introduces a new `CustomClientNotification` type that allows a server to process such custom notifications. - introduces `CustomClientNotification` to carry arbitrary methods/params while still preserving meta/extensions; wires it into the `ClientNotification` union and `serde` so `params` can be decoded with `params_as` - allows server handlers to receive custom notifications via a new `on_custom_notification` hook - adds integration coverage that sends a custom client notification end-to-end and asserts the server sees the method and payload Test: ```shell cargo test -p rmcp --features client test_custom_client_notification_reaches_server ``` --- crates/rmcp/src/handler/server.rs | 11 ++++ crates/rmcp/src/model.rs | 69 +++++++++++++++++++++++- crates/rmcp/src/model/meta.rs | 25 ++++++++- crates/rmcp/src/model/serde_impl.rs | 57 +++++++++++++++++++- crates/rmcp/tests/test_notification.rs | 74 +++++++++++++++++++++++++- 5 files changed, 229 insertions(+), 7 deletions(-) diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 4f9edbc0..237677a5 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -89,6 +89,9 @@ impl Service for H { ClientNotification::RootsListChangedNotification(_notification) => { self.on_roots_list_changed(context).await } + ClientNotification::CustomClientNotification(notification) => { + self.on_custom_notification(notification, context).await + } }; Ok(()) } @@ -224,6 +227,14 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { ) -> impl Future + Send + '_ { std::future::ready(()) } + fn on_custom_notification( + &self, + notification: CustomClientNotification, + context: NotificationContext, + ) -> impl Future + Send + '_ { + let _ = (notification, context); + std::future::ready(()) + } fn get_info(&self) -> ServerInfo { ServerInfo::default() diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index fb757c09..c27bc7f1 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -627,6 +627,40 @@ const_string!(CancelledNotificationMethod = "notifications/cancelled"); pub type CancelledNotification = Notification; +/// A catch-all notification the client can use to send custom messages to a server. +/// +/// This preserves the raw `method` name and `params` payload so handlers can +/// deserialize them into domain-specific types. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CustomClientNotification { + pub method: String, + pub params: Option, + /// extensions will carry anything possible in the context, including [`Meta`] + /// + /// this is similar with the Extensions in `http` crate + #[cfg_attr(feature = "schemars", schemars(skip))] + pub extensions: Extensions, +} + +impl CustomClientNotification { + pub fn new(method: impl Into, params: Option) -> Self { + Self { + method: method.into(), + params, + extensions: Extensions::default(), + } + } + + /// Deserialize `params` into a strongly-typed structure. + pub fn params_as(&self) -> Result, serde_json::Error> { + self.params + .as_ref() + .map(|params| serde_json::from_value(params.clone())) + .transpose() + } +} + const_string!(InitializeResultMethod = "initialize"); /// # Initialization /// This request is sent from the client to the server when it first connects, asking it to begin initialization. @@ -1748,7 +1782,8 @@ ts_union!( | CancelledNotification | ProgressNotification | InitializedNotification - | RootsListChangedNotification; + | RootsListChangedNotification + | CustomClientNotification; ); ts_union!( @@ -1857,6 +1892,38 @@ mod tests { assert_eq!(json, raw); } + #[test] + fn test_custom_client_notification_roundtrip() { + let raw = json!( { + "jsonrpc": JsonRpcVersion2_0, + "method": "notifications/custom", + "params": {"foo": "bar"}, + }); + + let message: ClientJsonRpcMessage = + serde_json::from_value(raw.clone()).expect("invalid notification"); + match &message { + ClientJsonRpcMessage::Notification(JsonRpcNotification { + notification: ClientNotification::CustomClientNotification(notification), + .. + }) => { + assert_eq!(notification.method, "notifications/custom"); + assert_eq!( + notification + .params + .as_ref() + .and_then(|p| p.get("foo")) + .expect("foo present"), + "bar" + ); + } + _ => panic!("Expected custom client notification"), + } + + let json = serde_json::to_value(message).expect("valid json"); + assert_eq!(json, raw); + } + #[test] fn test_request_conversion() { let raw = json!( { diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index fd93362b..a03fc056 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use super::{ - ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString, - ProgressToken, ServerNotification, ServerRequest, + ClientNotification, ClientRequest, CustomClientNotification, Extensions, JsonObject, + JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest, }; pub trait GetMeta { @@ -18,6 +18,26 @@ pub trait GetExtensions { fn extensions_mut(&mut self) -> &mut Extensions; } +impl GetExtensions for CustomClientNotification { + fn extensions(&self) -> &Extensions { + &self.extensions + } + fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } +} + +impl GetMeta for CustomClientNotification { + fn get_meta_mut(&mut self) -> &mut Meta { + self.extensions_mut().get_or_insert_default() + } + fn get_meta(&self) -> &Meta { + self.extensions() + .get::() + .unwrap_or(Meta::static_empty()) + } +} + macro_rules! variant_extension { ( $Enum: ident { @@ -84,6 +104,7 @@ variant_extension! { ProgressNotification InitializedNotification RootsListChangedNotification + CustomClientNotification } } diff --git a/crates/rmcp/src/model/serde_impl.rs b/crates/rmcp/src/model/serde_impl.rs index 09222d52..65e14361 100644 --- a/crates/rmcp/src/model/serde_impl.rs +++ b/crates/rmcp/src/model/serde_impl.rs @@ -3,8 +3,8 @@ use std::borrow::Cow; use serde::{Deserialize, Serialize}; use super::{ - Extensions, Meta, Notification, NotificationNoParam, Request, RequestNoParam, - RequestOptionalParam, + CustomClientNotification, Extensions, Meta, Notification, NotificationNoParam, Request, + RequestNoParam, RequestOptionalParam, }; #[derive(Serialize, Deserialize)] struct WithMeta<'a, P> { @@ -249,6 +249,59 @@ where } } +impl Serialize for CustomClientNotification { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions = &self.extensions; + let _meta = extensions.get::().map(Cow::Borrowed); + let params = self.params.as_ref(); + + let params = if _meta.is_some() || params.is_some() { + Some(WithMeta { + _meta, + _rest: &self.params, + }) + } else { + None + }; + + ProxyOptionalParam::serialize( + &ProxyOptionalParam { + method: &self.method, + params, + }, + serializer, + ) + } +} + +impl<'de> Deserialize<'de> for CustomClientNotification { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let body = + ProxyOptionalParam::<'_, _, Option>::deserialize(deserializer)?; + let mut params = None; + let mut _meta = None; + if let Some(body_params) = body.params { + params = body_params._rest; + _meta = body_params._meta.map(|m| m.into_owned()); + } + let mut extensions = Extensions::new(); + if let Some(meta) = _meta { + extensions.insert(meta); + } + Ok(CustomClientNotification { + extensions, + method: body.method, + params, + }) + } +} + #[cfg(test)] mod test { use serde_json::json; diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs index a46ac2fd..3fb5b60b 100644 --- a/crates/rmcp/tests/test_notification.rs +++ b/crates/rmcp/tests/test_notification.rs @@ -3,10 +3,12 @@ use std::sync::Arc; use rmcp::{ ClientHandler, ServerHandler, ServiceExt, model::{ - ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam, + ClientNotification, CustomClientNotification, ResourceUpdatedNotificationParam, + ServerCapabilities, ServerInfo, SubscribeRequestParam, }, }; -use tokio::sync::Notify; +use serde_json::json; +use tokio::sync::{Mutex, Notify}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub struct Server {} @@ -93,3 +95,71 @@ async fn test_server_notification() -> anyhow::Result<()> { client.cancel().await?; Ok(()) } + +struct CustomServer { + receive_signal: Arc, + payload: Arc)>>>, +} + +impl ServerHandler for CustomServer { + async fn on_custom_notification( + &self, + notification: CustomClientNotification, + _context: rmcp::service::NotificationContext, + ) { + let CustomClientNotification { method, params, .. } = notification; + let mut payload = self.payload.lock().await; + *payload = Some((method, params)); + self.receive_signal.notify_one(); + } +} + +#[tokio::test] +async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + + let (server_transport, client_transport) = tokio::io::duplex(4096); + let receive_signal = Arc::new(Notify::new()); + let payload = Arc::new(Mutex::new(None)); + + { + let receive_signal = receive_signal.clone(); + let payload = payload.clone(); + tokio::spawn(async move { + let server = CustomServer { + receive_signal, + payload, + } + .serve(server_transport) + .await?; + server.waiting().await?; + anyhow::Ok(()) + }); + } + + let client = ().serve(client_transport).await?; + + client + .send_notification(ClientNotification::CustomClientNotification( + CustomClientNotification::new( + "notifications/custom-test", + Some(json!({ "foo": "bar" })), + ), + )) + .await?; + + tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?; + + let (method, params) = payload.lock().await.clone().expect("payload set"); + assert_eq!("notifications/custom-test", method); + assert_eq!(Some(json!({ "foo": "bar" })), params); + + client.cancel().await?; + Ok(()) +}