Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ impl<H: ServerHandler> Service<RoleServer> for H {
ClientNotification::RootsListChangedNotification(_notification) => {
self.on_roots_list_changed(context).await
}
ClientNotification::CustomClientNotification(notification) => {
self.on_custom_notification(notification, context).await
}
};
Ok(())
}
Expand Down Expand Up @@ -224,6 +227,14 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_custom_notification(
&self,
notification: CustomClientNotification,
context: NotificationContext<RoleServer>,
) -> impl Future<Output = ()> + Send + '_ {
let _ = (notification, context);
std::future::ready(())
}

fn get_info(&self) -> ServerInfo {
ServerInfo::default()
Expand Down
69 changes: 68 additions & 1 deletion crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,40 @@ const_string!(CancelledNotificationMethod = "notifications/cancelled");
pub type CancelledNotification =
Notification<CancelledNotificationMethod, CancelledNotificationParam>;

/// A catch-all notification the client can use to send custom messages to a server.
///
/// This preserves the raw `method` name and `params` payload so handlers can
/// deserialize them into domain-specific types.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct CustomClientNotification {
pub method: String,
pub params: Option<Value>,
/// extensions will carry anything possible in the context, including [`Meta`]
///
/// this is similar with the Extensions in `http` crate
#[cfg_attr(feature = "schemars", schemars(skip))]
pub extensions: Extensions,
}

impl CustomClientNotification {
pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
Self {
method: method.into(),
params,
extensions: Extensions::default(),
}
}

/// Deserialize `params` into a strongly-typed structure.
pub fn params_as<T: DeserializeOwned>(&self) -> Result<Option<T>, serde_json::Error> {
self.params
.as_ref()
.map(|params| serde_json::from_value(params.clone()))
.transpose()
}
}

const_string!(InitializeResultMethod = "initialize");
/// # Initialization
/// This request is sent from the client to the server when it first connects, asking it to begin initialization.
Expand Down Expand Up @@ -1748,7 +1782,8 @@ ts_union!(
| CancelledNotification
| ProgressNotification
| InitializedNotification
| RootsListChangedNotification;
| RootsListChangedNotification
| CustomClientNotification;
);

ts_union!(
Expand Down Expand Up @@ -1857,6 +1892,38 @@ mod tests {
assert_eq!(json, raw);
}

#[test]
fn test_custom_client_notification_roundtrip() {
let raw = json!( {
"jsonrpc": JsonRpcVersion2_0,
"method": "notifications/custom",
"params": {"foo": "bar"},
});

let message: ClientJsonRpcMessage =
serde_json::from_value(raw.clone()).expect("invalid notification");
match &message {
ClientJsonRpcMessage::Notification(JsonRpcNotification {
notification: ClientNotification::CustomClientNotification(notification),
..
}) => {
assert_eq!(notification.method, "notifications/custom");
assert_eq!(
notification
.params
.as_ref()
.and_then(|p| p.get("foo"))
.expect("foo present"),
"bar"
);
}
_ => panic!("Expected custom client notification"),
}

let json = serde_json::to_value(message).expect("valid json");
assert_eq!(json, raw);
}

#[test]
fn test_request_conversion() {
let raw = json!( {
Expand Down
25 changes: 23 additions & 2 deletions crates/rmcp/src/model/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::{
ClientNotification, ClientRequest, Extensions, JsonObject, JsonRpcMessage, NumberOrString,
ProgressToken, ServerNotification, ServerRequest,
ClientNotification, ClientRequest, CustomClientNotification, Extensions, JsonObject,
JsonRpcMessage, NumberOrString, ProgressToken, ServerNotification, ServerRequest,
};

pub trait GetMeta {
Expand All @@ -18,6 +18,26 @@ pub trait GetExtensions {
fn extensions_mut(&mut self) -> &mut Extensions;
}

impl GetExtensions for CustomClientNotification {
fn extensions(&self) -> &Extensions {
&self.extensions
}
fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
}

impl GetMeta for CustomClientNotification {
fn get_meta_mut(&mut self) -> &mut Meta {
self.extensions_mut().get_or_insert_default()
}
fn get_meta(&self) -> &Meta {
self.extensions()
.get::<Meta>()
.unwrap_or(Meta::static_empty())
}
}

macro_rules! variant_extension {
(
$Enum: ident {
Expand Down Expand Up @@ -84,6 +104,7 @@ variant_extension! {
ProgressNotification
InitializedNotification
RootsListChangedNotification
CustomClientNotification
}
}

Expand Down
57 changes: 55 additions & 2 deletions crates/rmcp/src/model/serde_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::borrow::Cow;
use serde::{Deserialize, Serialize};

use super::{
Extensions, Meta, Notification, NotificationNoParam, Request, RequestNoParam,
RequestOptionalParam,
CustomClientNotification, Extensions, Meta, Notification, NotificationNoParam, Request,
RequestNoParam, RequestOptionalParam,
};
#[derive(Serialize, Deserialize)]
struct WithMeta<'a, P> {
Expand Down Expand Up @@ -249,6 +249,59 @@ where
}
}

impl Serialize for CustomClientNotification {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let extensions = &self.extensions;
let _meta = extensions.get::<Meta>().map(Cow::Borrowed);
let params = self.params.as_ref();

let params = if _meta.is_some() || params.is_some() {
Some(WithMeta {
_meta,
_rest: &self.params,
})
} else {
None
};

ProxyOptionalParam::serialize(
&ProxyOptionalParam {
method: &self.method,
params,
},
serializer,
)
}
}

impl<'de> Deserialize<'de> for CustomClientNotification {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let body =
ProxyOptionalParam::<'_, _, Option<serde_json::Value>>::deserialize(deserializer)?;
let mut params = None;
let mut _meta = None;
if let Some(body_params) = body.params {
params = body_params._rest;
_meta = body_params._meta.map(|m| m.into_owned());
}
let mut extensions = Extensions::new();
if let Some(meta) = _meta {
extensions.insert(meta);
}
Ok(CustomClientNotification {
extensions,
method: body.method,
params,
})
}
}

#[cfg(test)]
mod test {
use serde_json::json;
Expand Down
74 changes: 72 additions & 2 deletions crates/rmcp/tests/test_notification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use std::sync::Arc;
use rmcp::{
ClientHandler, ServerHandler, ServiceExt,
model::{
ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam,
ClientNotification, CustomClientNotification, ResourceUpdatedNotificationParam,
ServerCapabilities, ServerInfo, SubscribeRequestParam,
},
};
use tokio::sync::Notify;
use serde_json::json;
use tokio::sync::{Mutex, Notify};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

pub struct Server {}
Expand Down Expand Up @@ -93,3 +95,71 @@ async fn test_server_notification() -> anyhow::Result<()> {
client.cancel().await?;
Ok(())
}

struct CustomServer {
receive_signal: Arc<Notify>,
payload: Arc<Mutex<Option<(String, Option<serde_json::Value>)>>>,
}

impl ServerHandler for CustomServer {
async fn on_custom_notification(
&self,
notification: CustomClientNotification,
_context: rmcp::service::NotificationContext<rmcp::RoleServer>,
) {
let CustomClientNotification { method, params, .. } = notification;
let mut payload = self.payload.lock().await;
*payload = Some((method, params));
self.receive_signal.notify_one();
}
}

#[tokio::test]
async fn test_custom_client_notification_reaches_server() -> anyhow::Result<()> {
let _ = tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "debug".to_string().into()),
)
.with(tracing_subscriber::fmt::layer())
.try_init();

let (server_transport, client_transport) = tokio::io::duplex(4096);
let receive_signal = Arc::new(Notify::new());
let payload = Arc::new(Mutex::new(None));

{
let receive_signal = receive_signal.clone();
let payload = payload.clone();
tokio::spawn(async move {
let server = CustomServer {
receive_signal,
payload,
}
.serve(server_transport)
.await?;
server.waiting().await?;
anyhow::Ok(())
});
}

let client = ().serve(client_transport).await?;

client
.send_notification(ClientNotification::CustomClientNotification(
CustomClientNotification::new(
"notifications/custom-test",
Some(json!({ "foo": "bar" })),
),
))
.await?;

tokio::time::timeout(std::time::Duration::from_secs(5), receive_signal.notified()).await?;

let (method, params) = payload.lock().await.clone().expect("payload set");
assert_eq!("notifications/custom-test", method);
assert_eq!(Some(json!({ "foo": "bar" })), params);

client.cancel().await?;
Ok(())
}
Loading