diff --git a/crates/rmcp/src/transport/common/server_side_http.rs b/crates/rmcp/src/transport/common/server_side_http.rs index 693d8b34..51cd51f5 100644 --- a/crates/rmcp/src/transport/common/server_side_http.rs +++ b/crates/rmcp/src/transport/common/server_side_http.rs @@ -6,6 +6,7 @@ use http::Response; use http_body::Body; use http_body_util::{BodyExt, Empty, Full, combinators::BoxBody}; use sse_stream::{KeepAlive, Sse, SseBody}; +use tokio_util::sync::CancellationToken; use super::http_header::EVENT_STREAM_MIME_TYPE; use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; @@ -65,20 +66,26 @@ pub struct ServerSseMessage { pub(crate) fn sse_stream_response( stream: impl futures::Stream + Send + Sync + 'static, keep_alive: Option, + ct: CancellationToken, ) -> Response> { use futures::StreamExt; - let stream = SseBody::new(stream.map(|message| { - let data = serde_json::to_string(&message.message).expect("valid message"); - let mut sse = Sse::default().data(data); - sse.id = message.event_id; - Result::::Ok(sse) - })); + let stream = stream + .map(|message| { + let data = serde_json::to_string(&message.message).expect("valid message"); + let mut sse = Sse::default().data(data); + sse.id = message.event_id; + Result::::Ok(sse) + }) + .take_until(async move { ct.cancelled().await }); + let stream = SseBody::new(stream); + let stream = match keep_alive { Some(duration) => stream .with_keep_alive::(KeepAlive::new().interval(duration)) .boxed(), None => stream.boxed(), }; + Response::builder() .status(http::StatusCode::OK) .header(http::header::CONTENT_TYPE, EVENT_STREAM_MIME_TYPE) diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index ba373d48..ff5406f9 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -6,6 +6,7 @@ use http::{Method, Request, Response, header::ALLOW}; use http_body::Body; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; use super::session::SessionManager; use crate::{ @@ -33,6 +34,7 @@ pub struct StreamableHttpServerConfig { pub sse_keep_alive: Option, /// If true, the server will create a session for each request and keep it alive. pub stateful_mode: bool, + pub cancellation_token: CancellationToken, } impl Default for StreamableHttpServerConfig { @@ -40,6 +42,7 @@ impl Default for StreamableHttpServerConfig { Self { sse_keep_alive: Some(Duration::from_secs(15)), stateful_mode: true, + cancellation_token: CancellationToken::new(), } } } @@ -209,7 +212,11 @@ where .resume(&session_id, last_event_id) .await .map_err(internal_error_response("resume session"))?; - Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + Ok(sse_stream_response( + stream, + self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), + )) } else { // create standalone stream let stream = self @@ -217,7 +224,11 @@ where .create_standalone_stream(&session_id) .await .map_err(internal_error_response("create standalone stream"))?; - Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + Ok(sse_stream_response( + stream, + self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), + )) } } @@ -307,7 +318,11 @@ where .create_stream(&session_id, message) .await .map_err(internal_error_response("get session"))?; - Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + Ok(sse_stream_response( + stream, + self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), + )) } ClientJsonRpcMessage::Notification(_) | ClientJsonRpcMessage::Response(_) @@ -380,6 +395,7 @@ where } }), self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), ); response.headers_mut().insert( @@ -413,6 +429,7 @@ where } }), self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), )) } ClientJsonRpcMessage::Notification(_notification) => { diff --git a/examples/servers/src/counter_streamhttp.rs b/examples/servers/src/counter_streamhttp.rs index ff00cec6..5f9effe4 100644 --- a/examples/servers/src/counter_streamhttp.rs +++ b/examples/servers/src/counter_streamhttp.rs @@ -1,5 +1,6 @@ -use rmcp::transport::streamable_http_server::{ - StreamableHttpService, session::local::LocalSessionManager, +use rmcp::transport::{ + StreamableHttpServerConfig, + streamable_http_server::{StreamableHttpService, session::local::LocalSessionManager}, }; use tracing_subscriber::{ layer::SubscriberExt, @@ -20,17 +21,24 @@ async fn main() -> anyhow::Result<()> { ) .with(tracing_subscriber::fmt::layer()) .init(); + let ct = tokio_util::sync::CancellationToken::new(); let service = StreamableHttpService::new( || Ok(Counter::new()), LocalSessionManager::default().into(), - Default::default(), + StreamableHttpServerConfig { + cancellation_token: ct.child_token(), + ..Default::default() + }, ); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; let _ = axum::serve(tcp_listener, router) - .with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() }) + .with_graceful_shutdown(async move { + tokio::signal::ctrl_c().await.unwrap(); + ct.cancel(); + }) .await; Ok(()) }