Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rmcp/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- *(oauth)* support suffixed and preffixed well-knonw paths ([#459](https://github.com/modelcontextprotocol/rust-sdk/pull/459))
- *(oauth)* support suffixed and prefixed well-known paths ([#459](https://github.com/modelcontextprotocol/rust-sdk/pull/459))
- generate default schema for tools with no params ([#446](https://github.com/modelcontextprotocol/rust-sdk/pull/446))

### Other
Expand Down
21 changes: 16 additions & 5 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ struct StreamableHttpClientReconnect<C> {
pub client: C,
pub session_id: Arc<str>,
pub uri: Arc<str>,
pub auth_header: Option<String>,
}

impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
Expand All @@ -182,10 +183,11 @@ impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconne
let client = self.client.clone();
let uri = self.uri.clone();
let session_id = self.session_id.clone();
let auth_header = self.auth_header.clone();
let last_event_id = last_event_id.map(|s| s.to_owned());
Box::pin(async move {
client
.get_stream(uri, session_id, last_event_id, None)
.get_stream(uri, session_id, last_event_id, auth_header)
.await
})
}
Expand Down Expand Up @@ -324,10 +326,12 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
let client = self.client.clone();
let session_id = session_id.clone();
let url = config.uri.clone();
let auth_header = config.auth_header.clone();
tokio::spawn(async move {
ct.cancelled().await;
let delete_session_result =
client.delete_session(url, session_id.clone(), None).await;
let delete_session_result = client
.delete_session(url, session_id.clone(), auth_header.clone())
.await;
match delete_session_result {
Ok(_) => {
tracing::info!(session_id = session_id.as_ref(), "delete session success")
Expand Down Expand Up @@ -376,7 +380,12 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
if let Some(session_id) = &session_id {
match self
.client
.get_stream(config.uri.clone(), session_id.clone(), None, None)
.get_stream(
config.uri.clone(),
session_id.clone(),
None,
config.auth_header.clone(),
)
.await
{
Ok(stream) => {
Expand All @@ -386,6 +395,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
client: self.client.clone(),
session_id: session_id.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
},
self.config.retry_config.clone(),
);
Expand Down Expand Up @@ -468,6 +478,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
client: self.client.clone(),
session_id: session_id.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
},
self.config.retry_config.clone(),
);
Expand Down Expand Up @@ -704,7 +715,7 @@ impl StreamableHttpClientTransportConfig {
///
/// # Arguments
///
/// * `value` - The value to set
/// * `value` - A bearer token without the `Bearer ` prefix
pub fn auth_header<T: Into<String>>(mut self, value: T) -> Self {
// set our authorization header
self.auth_header = Some(value.into());
Expand Down