diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs index cc6a8952e..ecb57df0a 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions /// /// public IDictionary AdditionalAuthorizationParameters { get; set; } = new Dictionary(); + + /// + /// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport. + /// If none is provided, tokens will be cached with the transport. + /// + public ITokenCache? TokenCache { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index b72f775c4..182197e45 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -40,7 +40,7 @@ internal sealed partial class ClientOAuthProvider private string? _clientId; private string? _clientSecret; - private TokenContainer? _token; + private ITokenCache _tokenCache; private AuthorizationServerMetadata? _authServerMetadata; /// @@ -82,6 +82,7 @@ public ClientOAuthProvider( _dcrClientUri = options.DynamicClientRegistration?.ClientUri; _dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken; _dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate; + _tokenCache = options.TokenCache ?? new InMemoryTokenCache(); } /// @@ -135,20 +136,23 @@ public ClientOAuthProvider( { ThrowIfNotBearerScheme(scheme); + var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false); + var token = cachedToken?.ForUse(); + // Return the token if it's valid - if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) + if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5)) { - return _token.AccessToken; + return token.AccessToken; } // Try to refresh the token if we have a refresh token - if (_token?.RefreshToken != null && _authServerMetadata != null) + if (token?.RefreshToken != null && _authServerMetadata != null) { - var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); + var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false); if (newToken != null) { - _token = newToken; - return _token.AccessToken; + await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false); + return newToken.AccessToken; } } @@ -234,7 +238,7 @@ private async Task PerformOAuthAuthorizationAsync( ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token."); } - _token = token; + await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false); LogOAuthAuthorizationCompleted(); } diff --git a/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs new file mode 100644 index 000000000..46d4cc37b --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/ITokenCache.cs @@ -0,0 +1,17 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Allows the client to cache access tokens beyond the lifetime of the transport. +/// +public interface ITokenCache +{ + /// + /// Cache the token. After a new access token is acquired, this method is invoked to store it. + /// + ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken); + + /// + /// Get the cached token. This method is invoked for every request. + /// + ValueTask GetTokenAsync(CancellationToken cancellationToken); +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs new file mode 100644 index 000000000..56346f731 --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs @@ -0,0 +1,27 @@ + +namespace ModelContextProtocol.Authentication; + +/// +/// Caches the token in-memory within this instance. +/// +internal class InMemoryTokenCache : ITokenCache +{ + private TokenContainerCacheable? _token; + + /// + /// Cache the token. + /// + public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken) + { + _token = token; + return default; + } + + /// + /// Get the cached token. + /// + public ValueTask GetTokenAsync(CancellationToken cancellationToken) + { + return new ValueTask(_token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs new file mode 100644 index 000000000..5f6bf0e5c --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerCacheable.cs @@ -0,0 +1,42 @@ +namespace ModelContextProtocol.Authentication; + +/// +/// Represents a cacheable token representation. +/// +public class TokenContainerCacheable +{ + /// + /// Gets or sets the access token. + /// + public string AccessToken { get; set; } = string.Empty; + + /// + /// Gets or sets the refresh token. + /// + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the number of seconds until the access token expires. + /// + public int ExpiresIn { get; set; } + + /// + /// Gets or sets the extended expiration time in seconds. + /// + public int ExtExpiresIn { get; set; } + + /// + /// Gets or sets the token type (typically "Bearer"). + /// + public string TokenType { get; set; } = string.Empty; + + /// + /// Gets or sets the scope of the access token. + /// + public string Scope { get; set; } = string.Empty; + + /// + /// Gets or sets the timestamp when the token was obtained. + /// + public DateTimeOffset ObtainedAt { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs new file mode 100644 index 000000000..6e2c8e9cd --- /dev/null +++ b/src/ModelContextProtocol.Core/Authentication/TokenContainerConvert.cs @@ -0,0 +1,26 @@ +namespace ModelContextProtocol.Authentication; + +internal static class TokenContainerConvert +{ + internal static TokenContainer ForUse(this TokenContainerCacheable token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; + + internal static TokenContainerCacheable ForCache(this TokenContainer token) => new() + { + AccessToken = token.AccessToken, + RefreshToken = token.RefreshToken, + ExpiresIn = token.ExpiresIn, + ExtExpiresIn = token.ExtExpiresIn, + TokenType = token.TokenType, + Scope = token.Scope, + ObtainedAt = token.ObtainedAt, + }; +} diff --git a/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs new file mode 100644 index 000000000..3ea1262ae --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/CustomTokenCacheTests.cs @@ -0,0 +1,233 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Authentication; +using System.Text.Json; +using Moq; +using Moq.Protected; +using System.Net; +using System.Text.Json.Nodes; +using System.Linq.Expressions; + +namespace ModelContextProtocol.Tests.Client; + +public class CustomTokenCacheTests +{ + [Fact] + public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests() + { + // Arrange + var cachedAccessToken = $"my_access_token_{Guid.NewGuid()}"; + + var tokenCacheMock = new Mock(); + MockCachedAccessToken(tokenCacheMock, cachedAccessToken); + + var httpMessageHandlerMock = new Mock(); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.AtLeastOnce(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && req.Headers.Authorization.Parameter == cachedAccessToken + ), ItExpr.IsAny()); + + httpMessageHandlerMock + .Protected() + .Verify("SendAsync", Times.Never(), ItExpr.Is(req => + req.RequestUri == new Uri("http://localhost:1337/") + && (req.Headers.Authorization == null || req.Headers.Authorization.Parameter != cachedAccessToken) + ), ItExpr.IsAny()); + } + + [Fact] + public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached() + { + // Arrange + var tokenCacheMock = new Mock(); + MockNoAccessTokenUntilStored(tokenCacheMock); + + var newAccessToken = $"new_access_token_{Guid.NewGuid()}"; + + var httpMessageHandlerMock = new Mock(); + MockUnauthorizedResponse(httpMessageHandlerMock); + MockProtectedResourceMetadataResponse(httpMessageHandlerMock); + MockAuthorizationServerMetadataResponse(httpMessageHandlerMock); + MockAccessTokenResponse(httpMessageHandlerMock, newAccessToken); + MockInitializeResponse(httpMessageHandlerMock); + + var httpClientTransport = new HttpClientTransport( + transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object), + httpClient: new HttpClient(httpMessageHandlerMock.Object)); + + var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Act + var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) }; + await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken); + + // Assert + tokenCacheMock + .Verify(tc => tc.StoreTokenAsync( + It.Is(token => token.AccessToken == newAccessToken), + It.IsAny()), Times.Once); + } + + static HttpClientTransportOptions NewHttpClientTransportOptions(ITokenCache? tokenCache = null) => new() + { + Name = "MCP Server", + Endpoint = new Uri("http://localhost:1337/"), + TransportMode = HttpTransportMode.StreamableHttp, + OAuth = new() + { + ClientId = "mcp_inspector", + RedirectUri = new Uri("http://localhost:6274/oauth/callback"), + Scopes = ["openid", "profile", "offline_access"], + AuthorizationRedirectDelegate = (authorizationUrl, redirectUri, cancellationToken) => Task.FromResult($"auth_code_{Guid.NewGuid()}"), + TokenCache = tokenCache, + }, + }; + + static void MockCachedAccessToken(Mock tokenCache, string cachedAccessToken) + { + tokenCache + .Setup(tc => tc.GetTokenAsync(It.IsAny())) + .ReturnsAsync(new TokenContainerCacheable + { + AccessToken = cachedAccessToken, + ObtainedAt = DateTimeOffset.UtcNow, + ExpiresIn = (int)TimeSpan.FromHours(1).TotalSeconds, + }); + } + + static void MockNoAccessTokenUntilStored(Mock tokenCache) + { + tokenCache + .Setup(tc => tc.StoreTokenAsync(It.IsAny(), It.IsAny())) + .Callback((token, ct) => + { + // Simulate that the token is now cached + MockCachedAccessToken(tokenCache, token.AccessToken); + }) + .Returns(default(ValueTask)); + } + + static void MockUnauthorizedResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && (req.Headers.Authorization == null || string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter)), + response: new HttpResponseMessage(HttpStatusCode.Unauthorized) + { + Headers = { + { "WWW-Authenticate", "Bearer realm=\"Bearer\", resource_metadata=\"http://localhost:1337/.well-known/oauth-protected-resource\"" } + }, + }); + } + + static void MockProtectedResourceMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/.well-known/oauth-protected-resource"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + resource = "http://localhost:1337/", + authorization_servers = new[] { "http://localhost:1336/" }, + }) + }); + } + + static void MockAuthorizationServerMetadataResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/.well-known/openid-configuration"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + authorization_endpoint = "http://localhost:1336/connect/authorize", + token_endpoint = "http://localhost:1336/connect/token", + }) + }); + } + + static void MockAccessTokenResponse(Mock httpMessageHandler, string accessToken) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1336/connect/token"), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new + { + access_token = accessToken, + }) + }); + } + + static void MockInitializeResponse(Mock httpMessageHandler) + { + MockHttpResponse(httpMessageHandler, + request: req => req.RequestUri == new Uri("http://localhost:1337/") + && req.Method == HttpMethod.Post + && req.Headers.Authorization != null + && req.Headers.Authorization.Scheme == "Bearer" + && !string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter), + response: new HttpResponseMessage(HttpStatusCode.OK) + { + Content = ToJsonContent(new JsonRpcResponse + { + Id = new RequestId(1), + Result = ToJson(new InitializeResult + { + ProtocolVersion = "2024-11-05", + Capabilities = new ServerCapabilities + { + Prompts = new PromptsCapability { ListChanged = true }, + Resources = new ResourcesCapability { Subscribe = true, ListChanged = true }, + Tools = new ToolsCapability { ListChanged = true }, + Logging = new LoggingCapability(), + Completions = new CompletionsCapability(), + }, + ServerInfo = new Implementation + { + Name = "mcp-test-server", + Version = "1.0.0" + }, + Instructions = "This server provides weather information and file system access." + }) + }), + }); + } + + static void MockHttpResponse(Mock httpMessageHandler, Expression>? request = null, HttpResponseMessage? response = null) + { + httpMessageHandler + .Protected() + .Setup>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync(response ?? new HttpResponseMessage()); + } + + static StringContent ToJsonContent(T content) => new( + content: JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions), + encoding: System.Text.Encoding.UTF8, + mediaType: "application/json"); + + static JsonNode? ToJson(T content) => JsonSerializer.SerializeToNode( + value: content, + options: McpJsonUtilities.DefaultOptions); +}