Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/rmcp/src/transport/common/http_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> {

#[cfg(test)]
mod tests {
#[cfg(feature = "client-side-sse")]
use super::*;

#[cfg(feature = "client-side-sse")]
#[test]
fn extract_scope_quoted() {
let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#;
Expand All @@ -76,6 +78,7 @@ mod tests {
);
}

#[cfg(feature = "client-side-sse")]
#[test]
fn extract_scope_unquoted() {
let header = r#"Bearer scope=read:data, error="insufficient_scope""#;
Expand All @@ -85,12 +88,14 @@ mod tests {
);
}

#[cfg(feature = "client-side-sse")]
#[test]
fn extract_scope_missing() {
let header = r#"Bearer error="invalid_token""#;
assert_eq!(extract_scope_from_header(header), None);
}

#[cfg(feature = "client-side-sse")]
#[test]
fn extract_scope_empty_header() {
assert_eq!(extract_scope_from_header("Bearer"), None);
Expand Down
79 changes: 78 additions & 1 deletion crates/rmcp/src/transport/common/server_side_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl sse_stream::Timer for TokioTimer {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct ServerSseMessage {
/// The event ID for this message. When set, clients can use this ID
Expand All @@ -71,6 +71,37 @@ pub struct ServerSseMessage {
pub retry: Option<Duration>,
}

impl ServerSseMessage {
/// Create a message carrying a JSON-RPC response/notification with an event ID.
pub fn new(event_id: impl Into<String>, message: ServerJsonRpcMessage) -> Self {
Self {
event_id: Some(event_id.into()),
message: Some(Arc::new(message)),
retry: None,
}
}

/// Wrap a JSON-RPC message without an event ID or retry hint.
pub fn from_message(message: ServerJsonRpcMessage) -> Self {
Self {
event_id: None,
message: Some(Arc::new(message)),
retry: None,
}
}

/// Create a priming event that tells the client to reconnect after `retry`
/// if the connection drops.
/// See [SEP-1699](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1699).
pub fn priming(event_id: impl Into<String>, retry: Duration) -> Self {
Self {
event_id: Some(event_id.into()),
message: None,
retry: Some(retry),
}
}
}

pub(crate) fn sse_stream_response(
stream: impl futures::Stream<Item = ServerSseMessage> + Send + Sync + 'static,
keep_alive: Option<Duration>,
Expand Down Expand Up @@ -169,3 +200,49 @@ where
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::model::{EmptyResult, JsonRpcResponse, JsonRpcVersion2_0, RequestId, ServerResult};

fn dummy_message() -> ServerJsonRpcMessage {
ServerJsonRpcMessage::Response(JsonRpcResponse {
jsonrpc: JsonRpcVersion2_0,
id: RequestId::Number(1),
result: ServerResult::EmptyResult(EmptyResult {}),
})
}

#[test]
fn default_has_all_none() {
let msg = ServerSseMessage::default();
assert!(msg.event_id.is_none());
assert!(msg.message.is_none());
assert!(msg.retry.is_none());
}

#[test]
fn new_sets_event_id_and_message() {
let msg = ServerSseMessage::new("42", dummy_message());
assert_eq!(msg.event_id.as_deref(), Some("42"));
assert!(msg.message.is_some());
assert!(msg.retry.is_none());
}

#[test]
fn from_message_has_no_event_id() {
let msg = ServerSseMessage::from_message(dummy_message());
assert!(msg.event_id.is_none());
assert!(msg.message.is_some());
assert!(msg.retry.is_none());
}

#[test]
fn priming_sets_event_id_and_retry() {
let msg = ServerSseMessage::priming("0", Duration::from_secs(5));
assert_eq!(msg.event_id.as_deref(), Some("0"));
assert!(msg.message.is_none());
assert_eq!(msg.retry, Some(Duration::from_secs(5)));
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{
collections::{HashMap, HashSet, VecDeque},
num::ParseIntError,
sync::Arc,
time::Duration,
};

Expand Down Expand Up @@ -222,21 +221,13 @@ impl CachedTx {

async fn send(&mut self, message: ServerJsonRpcMessage) {
let event_id = self.next_event_id();
let message = ServerSseMessage {
event_id: Some(event_id.to_string()),
message: Some(Arc::new(message)),
retry: None,
};
let message = ServerSseMessage::new(event_id.to_string(), message);
self.cache_and_send(message).await;
}

async fn send_priming(&mut self, retry: Duration) {
let event_id = self.next_event_id();
let message = ServerSseMessage {
event_id: Some(event_id.to_string()),
message: None,
retry: Some(retry),
};
let message = ServerSseMessage::priming(event_id.to_string(), retry);
self.cache_and_send(message).await;
}

Expand Down
33 changes: 6 additions & 27 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,7 @@ where
.map_err(internal_error_response("create standalone stream"))?;
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
Expand Down Expand Up @@ -609,11 +605,7 @@ where
.map_err(internal_error_response("get session"))?;
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
Expand Down Expand Up @@ -687,20 +679,11 @@ where
.initialize_session(&session_id, message)
.await
.map_err(internal_error_response("create stream"))?;
let stream = futures::stream::once(async move {
ServerSseMessage {
event_id: None,
message: Some(Arc::new(response)),
retry: None,
}
});
let stream =
futures::stream::once(async move { ServerSseMessage::from_message(response) });
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
Expand Down Expand Up @@ -774,11 +757,7 @@ where
// SSE mode (default): original behaviour preserved unchanged
let stream = ReceiverStream::new(receiver).map(|message| {
tracing::trace!(?message);
ServerSseMessage {
event_id: None,
message: Some(Arc::new(message)),
retry: None,
}
ServerSseMessage::from_message(message)
});
Ok(sse_stream_response(
stream,
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/tests/test_inflight_response_drain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![cfg(not(feature = "local"))]
#![cfg(all(feature = "client", feature = "server", not(feature = "local")))]
// cargo test --test test_inflight_response_drain --features "client server"

use std::{
Expand Down
Loading