Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions crates/rmcp/src/transport/streamable_http_server/session/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
time::Duration,
};

use futures::Stream;
use futures::{Stream, StreamExt};
use thiserror::Error;
use tokio::sync::{
mpsc::{Receiver, Sender},
Expand Down Expand Up @@ -86,10 +86,17 @@ impl SessionManager for LocalSessionManager {
.get(id)
.ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
let receiver = handle.establish_request_wise_channel().await?;
handle
.push_message(message, receiver.http_request_id)
.await?;
Ok(ReceiverStream::new(receiver.inner))
let http_request_id = receiver.http_request_id;
handle.push_message(message, http_request_id).await?;

let priming = self.session_config.sse_retry.map(|retry| {
let event_id = match http_request_id {
Some(id) => format!("0/{id}"),
None => "0".into(),
};
ServerSseMessage::priming(event_id, retry)
});
Ok(futures::stream::iter(priming).chain(ReceiverStream::new(receiver.inner)))
}

async fn create_standalone_stream(
Expand Down Expand Up @@ -188,23 +195,29 @@ struct CachedTx {
cache: VecDeque<ServerSseMessage>,
http_request_id: Option<HttpRequestId>,
capacity: usize,
starting_index: usize,
}

impl CachedTx {
fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
fn new(
tx: Sender<ServerSseMessage>,
http_request_id: Option<HttpRequestId>,
starting_index: usize,
) -> Self {
Self {
cache: VecDeque::with_capacity(tx.capacity()),
capacity: tx.capacity(),
tx,
http_request_id,
starting_index,
}
}
fn new_common(tx: Sender<ServerSseMessage>) -> Self {
Self::new(tx, None)
Self::new(tx, None, 0)
}

fn next_event_id(&self) -> EventId {
let index = self.cache.back().map_or(0, |m| {
let index = self.cache.back().map_or(self.starting_index, |m| {
m.event_id
.as_deref()
.unwrap_or_default()
Expand Down Expand Up @@ -350,10 +363,15 @@ impl LocalSessionWorker {
if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
{
tracing::debug!(http_request_id, "close http request wise channel");
if let Some(channel) = self.tx_router.remove(&http_request_id) {
for resource in channel.resources {
if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
for resource in channel.resources.drain() {
self.resource_router.remove(&resource);
}
// Replace the sender with a closed dummy so no new
// messages are routed here, but the cache stays alive
// for late resume requests.
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
channel.tx.tx = closed_tx;
}
}
} else {
Expand Down Expand Up @@ -403,13 +421,15 @@ impl LocalSessionWorker {
async fn establish_request_wise_channel(
&mut self,
) -> Result<StreamableHttpMessageReceiver, SessionError> {
self.tx_router.retain(|_, rw| !rw.tx.tx.is_closed());
let http_request_id = self.next_http_request_id();
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let starting_index = usize::from(self.session_config.sse_retry.is_some());
self.tx_router.insert(
http_request_id,
HttpRequestWise {
resources: Default::default(),
tx: CachedTx::new(tx, Some(http_request_id)),
tx: CachedTx::new(tx, Some(http_request_id), starting_index),
},
);
tracing::debug!(http_request_id, "establish new request wise channel");
Expand Down Expand Up @@ -521,24 +541,25 @@ impl LocalSessionWorker {
match last_event_id.http_request_id {
Some(http_request_id) => {
if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
// Resume existing request-wise channel
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
let was_completed = request_wise.tx.tx.is_closed();
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
request_wise.tx.tx = tx;
let index = last_event_id.index;
// sync messages after index
request_wise.tx.sync(index).await?;
if was_completed {
// Close the sender after replaying so the stream ends
// instead of hanging indefinitely.
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
request_wise.tx.tx = closed_tx;
}
Ok(StreamableHttpMessageReceiver {
http_request_id: Some(http_request_id),
inner: rx,
})
} else {
// Request-wise channel completed (POST response already delivered).
// The client's EventSource is reconnecting after the POST SSE stream
// ended. Fall through to common channel handling below.
tracing::debug!(
http_request_id,
"Request-wise channel completed, falling back to common channel"
"Request-wise channel not found, falling back to common channel"
);
self.resume_or_shadow_common(last_event_id.index).await
}
Expand Down Expand Up @@ -1072,18 +1093,25 @@ pub struct SessionConfig {
/// Defaults to 5 minutes. Set to `None` to disable (not recommended
/// for long-running servers behind proxies).
pub keep_alive: Option<Duration>,
/// SSE retry interval for priming events on request-wise streams.
/// When set, the session layer prepends a priming event with the correct
/// stream-identifying event ID to each request-wise SSE stream.
/// Default is 3 seconds, matching `StreamableHttpServerConfig::default()`.
pub sse_retry: Option<Duration>,
}

impl SessionConfig {
pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
}

impl Default for SessionConfig {
fn default() -> Self {
Self {
channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
sse_retry: Some(Self::DEFAULT_SSE_RETRY),
}
}
}
Expand Down
12 changes: 3 additions & 9 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,20 +598,14 @@ where

match message {
ClientJsonRpcMessage::Request(_) => {
// Priming for request-wise streams is handled by the
// session layer (SessionManager::create_stream) which
// has access to the http_request_id for correct event IDs.
let stream = self
.session_manager
.create_stream(&session_id, message)
.await
.map_err(internal_error_response("get session"))?;
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
} else {
stream.right_stream()
};
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
Expand Down
Loading
Loading