From 027f7ee7476e4e513d16c7052219a3cc572d945c Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sat, 4 Oct 2025 14:10:22 +0200 Subject: [PATCH 1/2] Tokens can be cached beyond the lifetime of the (http) transport. --- .../Authentication/ClientOAuthOptions.cs | 6 +++++ .../Authentication/ClientOAuthProvider.cs | 19 +++++++------ .../Authentication/ITokenCache.cs | 17 ++++++++++++ .../Authentication/InMemoryTokenCache.cs | 27 +++++++++++++++++++ .../Authentication/TokenContainer.cs | 4 +-- 5 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Authentication/ITokenCache.cs create mode 100644 src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs index cc6a8952e..ecb57df0a 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions /// /// public IDictionary AdditionalAuthorizationParameters { get; set; } = new Dictionary(); + + /// + /// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport. + /// If none is provided, tokens will be cached with the transport. + /// + public ITokenCache? TokenCache { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index b72f775c4..05194963b 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -40,7 +40,7 @@ internal sealed partial class ClientOAuthProvider private string? _clientId; private string? _clientSecret; - private TokenContainer? _token; + private ITokenCache _tokenCache; private AuthorizationServerMetadata? _authServerMetadata; /// @@ -82,6 +82,7 @@ public ClientOAuthProvider( _dcrClientUri = options.DynamicClientRegistration?.ClientUri; _dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken; _dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate; + _tokenCache = options.TokenCache ?? new InMemoryTokenCache(); } /// @@ -135,20 +136,22 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); + var token = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + // Return the token if it's valid - if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) { - return _token.AccessToken; + return token.AccessToken; } // Try to refresh the token if we have a refresh token - if (_token?.RefreshToken != null && _authServerMetadata != null) + if (token?.RefreshToken != null && _authServerMetadata != null) { - var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); if (newToken != null) { - _token = newToken; - return _token.AccessToken; + await _tokenCache.StoreTokenAsync(newToken, cancellationToken).ConfigureAwait(false); + return newToken.AccessToken; } } @@ -234,7 +237,7 @@ private async Task PerformOAuthAuthorizationAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - _token = token; + await _tokenCache.StoreTokenAsync(token, cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs new file mode 100644 index 000000000..3619286b3 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Allows the client to cache access tokens beyond the lifetime of the transport. +/// +public interface ITokenCache +{ + /// + /// Cache the token. + /// + Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken); + + /// + /// Get the cached token. + /// + Task GetTokenAsync(CancellationToken cancellationToken); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs new file mode 100644 index 000000000..529d56269 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -0,0 +1,27 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Caches the token in-memory within this instance. +/// +internal class InMemoryTokenCache : ITokenCache +{ + private TokenContainer? _token; + + /// + /// Cache the token. + /// + public Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken) + { + _token = token; + return Task.CompletedTask; + } + + /// + /// Get the cached token. + /// + public Task GetTokenAsync(CancellationToken cancellationToken) + { + return Task.FromResult(_token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index dc55292b9..7ffe05372 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a token response from the OAuth server. /// -internal sealed class TokenContainer +public sealed class TokenContainer { /// /// Gets or sets the access token. @@ -46,7 +46,7 @@ internal sealed class TokenContainer /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonIgnore] + [JsonPropertyName("obtained_at")] public DateTimeOffset ObtainedAt { get; set; } /// From dee8a1317e4b3ba05a5f77ab3f25c1c9a7f5b8d1 Mon Sep 17 00:00:00 2001 From: Manuel Naujoks Date: Sat, 11 Oct 2025 18:45:31 +0200 Subject: [PATCH 2/2] Tests, ValueTasks, and dedicated type for caching. --- .../Authentication/ClientOAuthProvider.cs | 7 +- .../Authentication/ITokenCache.cs | 8 +- .../Authentication/InMemoryTokenCache.cs | 10 +- .../Authentication/TokenContainer.cs | 4 +- .../Authentication/TokenContainerCacheable.cs | 42 ++++ .../Authentication/TokenContainerConvert.cs | 26 ++ .../Client/CustomTokenCacheTests.cs | 233 ++++++++++++++++++ 7 files changed, 316 insertions(+), 14 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs create mode 100644 src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs create mode 100644 tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 05194963b..182197e45 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -136,7 +136,8 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); - var token = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + var token = cachedToken?.ForUse(); // Return the token if it's valid if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) @@ -150,7 +151,7 @@ public ClientOAuthProvider( var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); if (newToken != null) { - await _tokenCache.StoreTokenAsync(newToken, cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false); return newToken.AccessToken; } } @@ -237,7 +238,7 @@ private async Task PerformOAuthAuthorizationAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - await _tokenCache.StoreTokenAsync(token, cancellationToken).ConfigureAwait(false); + await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs index 3619286b3..46d4cc37b 100644 --- a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -6,12 +6,12 @@ namespace ModelContextProtocol.Authentication; public interface ITokenCache { /// - /// Cache the token. + /// Cache the token. After a new access token is acquired, this method is invoked to store it. /// - Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken); + ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken); /// - /// Get the cached token. + /// Get the cached token. This method is invoked for every request. /// - Task GetTokenAsync(CancellationToken cancellationToken); + ValueTask GetTokenAsync(CancellationToken cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs index 529d56269..56346f731 100644 --- a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -6,22 +6,22 @@ namespace ModelContextProtocol.Authentication; /// internal class InMemoryTokenCache : ITokenCache { - private TokenContainer? _token; + private TokenContainerCacheable? _token; /// /// Cache the token. /// - public Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken) + public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken) { _token = token; - return Task.CompletedTask; + return default; } /// /// Get the cached token. /// - public Task GetTokenAsync(CancellationToken cancellationToken) + public ValueTask GetTokenAsync(CancellationToken cancellationToken) { - return Task.FromResult(_token); + return new ValueTask(_token); } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs index 7ffe05372..dc55292b9 100644 --- a/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainer.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication; /// /// Represents a token response from the OAuth server. /// -public sealed class TokenContainer +internal sealed class TokenContainer { /// /// Gets or sets the access token. @@ -46,7 +46,7 @@ public sealed class TokenContainer /// /// Gets or sets the timestamp when the token was obtained. /// - [JsonPropertyName("obtained_at")] + [JsonIgnore] public DateTimeOffset ObtainedAt { get; set; } /// diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs new file mode 100644 index 000000000..5f6bf0e5c --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs @@ -0,0 +1,42 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a cacheable token representation. +/// +public class TokenContainerCacheable +{ + /// + /// Gets or sets the access token. + /// + public string AccessToken { get; set; } = string.Empty; + + /// + /// Gets or sets the refresh token. + /// + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + public int ExpiresIn { get; set; } + + /// + /// Gets or sets the extended expiration time in seconds. + /// + public int ExtExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + public string TokenType { get; set; } = string.Empty; + + /// + /// Gets or sets the scope of the access token. + /// + public string Scope { get; set; } = string.Empty; + + /// + /// Gets or sets the timestamp when the token was obtained. + /// + public DateTimeOffset ObtainedAt { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs new file mode 100644 index 000000000..6e2c8e9cd --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs @@ -0,0 +1,26 @@ +namespace ModelContextProtocol.Authentication; + +internal static class TokenContainerConvert +{ + internal static TokenContainer ForUse(this TokenContainerCacheable token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; + + internal static TokenContainerCacheable ForCache(this TokenContainer token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; +} diff --git a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs new file mode 100644 index 000000000..3ea1262ae --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs @@ -0,0 +1,233 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Authentication; +using System.Text.Json; +using Moq; +using Moq.Protected; +using System.Net; +using System.Text.Json.Nodes; +using System.Linq.Expressions; + +namespace ModelContextProtocol.Tests.Client; + +public class CustomTokenCacheTests +{ + [Fact] + public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() + { + // Arrange + var cachedAccessToken = $"my_access_token_{Guid.NewGuid()}"; + + var tokenCacheMock = new Mock(); + MockCachedAccessToken(tokenCacheMock, cachedAccessToken); + + var httpMessageHandlerMock = new Mock(); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.AtLeastOnce(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && req.Headers.Authorization.Parameter == cachedAccessToken + ), ItExpr.IsAny()); + + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.Never(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && (req.Headers.Authorization == null || req.Headers.Authorization.Parameter != cachedAccessToken) + ), ItExpr.IsAny()); + } + + [Fact] + public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() + { + // Arrange + var tokenCacheMock = new Mock(); + MockNoAccessTokenUntilStored(tokenCacheMock); + + var newAccessToken = $"new_access_token_{Guid.NewGuid()}"; + + var httpMessageHandlerMock = new Mock(); + MockUnauthorizedResponse(httpMessageHandlerMock); + MockProtectedResourceMetadataResponse(httpMessageHandlerMock); + MockAuthorizationServerMetadataResponse(httpMessageHandlerMock); + MockAccessTokenResponse(httpMessageHandlerMock, newAccessToken); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + tokenCacheMock + .Verify(tc => tc.StoreTokenAsync( + It.Is(token => token.AccessToken == newAccessToken), + It.IsAny()), Times.Once); + } + + static HttpClientTransportOptions NewHttpClientTransportOptions(ITokenCache? tokenCache = null) => new() + { + Name = "MCP Server", + Endpoint = new Uri("http://localhost:1337/"), + TransportMode = HttpTransportMode.StreamableHttp, + OAuth = new() + { + ClientId = "mcp_inspector", + RedirectUri = new Uri("http://localhost:6274/oauth/callback"), + Scopes = ["openid", "profile", "offline_access"], + AuthorizationRedirectDelegate = (authorizationUrl, redirectUri, cancellationToken) => Task.FromResult($"auth_code_{Guid.NewGuid()}"), + TokenCache = tokenCache, + }, + }; + + static void MockCachedAccessToken(Mock tokenCache, string cachedAccessToken) + { + tokenCache + .Setup(tc => tc.GetTokenAsync(It.IsAny())) + .ReturnsAsync(new TokenContainerCacheable + { + AccessToken = cachedAccessToken, + ObtainedAt = DateTimeOffset.UtcNow, + ExpiresIn = (int)TimeSpan.FromHours(1).TotalSeconds, + }); + } + + static void MockNoAccessTokenUntilStored(Mock tokenCache) + { + tokenCache + .Setup(tc => tc.StoreTokenAsync(It.IsAny(), It.IsAny())) + .Callback((token, ct) => + { + // Simulate that the token is now cached + MockCachedAccessToken(tokenCache, token.AccessToken); + }) + .Returns(default(ValueTask)); + } + + static void MockUnauthorizedResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && (req.Headers.Authorization == null || string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter)), + response: new HttpResponseMessage(HttpStatusCode.Unauthorized) + { + Headers = { + { "WWW-Authenticate", "Bearer realm=\"Bearer\", resource_metadata=\"http://localhost:1337/.well-known/oauth-protected-resource\"" } + }, + }); + } + + static void MockProtectedResourceMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/.well-known/oauth-protected-resource"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + resource = "http://localhost:1337/", + authorization_servers = new[] { "http://localhost:1336/" }, + }) + }); + } + + static void MockAuthorizationServerMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/.well-known/openid-configuration"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + authorization_endpoint = "http://localhost:1336/connect/authorize", + token_endpoint = "http://localhost:1336/connect/token", + }) + }); + } + + static void MockAccessTokenResponse(Mock httpMessageHandler, string accessToken) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/connect/token"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + access_token = accessToken, + }) + }); + } + + static void MockInitializeResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && !string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new JsonRpcResponse + { + Id = new RequestId(1), + Result = ToJson(new InitializeResult + { + ProtocolVersion = "2024-11-05", + Capabilities = new ServerCapabilities + { + Prompts = new PromptsCapability { ListChanged = true }, + Resources = new ResourcesCapability { Subscribe = true, ListChanged = true }, + Tools = new ToolsCapability { ListChanged = true }, + Logging = new LoggingCapability(), + Completions = new CompletionsCapability(), + }, + ServerInfo = new Implementation + { + Name = "mcp-test-server", + Version = "1.0.0" + }, + Instructions = "This server provides weather information and file system access." + }) + }), + }); + } + + static void MockHttpResponse(Mock httpMessageHandler, Expression>? request = null, HttpResponseMessage? response = null) + { + httpMessageHandler + .Protected() + .Setup>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync(response ?? new HttpResponseMessage()); + } + + static StringContent ToJsonContent(T content) => new( + content: JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions), + encoding: System.Text.Encoding.UTF8, + mediaType: "application/json"); + + static JsonNode? ToJson(T content) => JsonSerializer.SerializeToNode( + value: content, + options: McpJsonUtilities.DefaultOptions); +}