Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions
/// </para>
/// </remarks>
public IDictionary<string, string> AdditionalAuthorizationParameters { get; set; } = new Dictionary<string, string>();

/// <summary>
/// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport.
/// If none is provided, tokens will be cached with the transport.
/// </summary>
public ITokenCache? TokenCache { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ internal sealed partial class ClientOAuthProvider
private string? _clientId;
private string? _clientSecret;

private TokenContainer? _token;
private ITokenCache _tokenCache;
private AuthorizationServerMetadata? _authServerMetadata;

/// <summary>
Expand Down Expand Up @@ -82,6 +82,7 @@ public ClientOAuthProvider(
_dcrClientUri = options.DynamicClientRegistration?.ClientUri;
_dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken;
_dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate;
_tokenCache = options.TokenCache ?? new InMemoryTokenCache();
}

/// <summary>
Expand Down Expand Up @@ -135,20 +136,23 @@ public ClientOAuthProvider(
{
ThrowIfNotBearerScheme(scheme);

var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false);
var token = cachedToken?.ForUse();

// Return the token if it's valid
if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5))
if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5))
{
return _token.AccessToken;
return token.AccessToken;
}

// Try to refresh the token if we have a refresh token
if (_token?.RefreshToken != null && _authServerMetadata != null)
if (token?.RefreshToken != null && _authServerMetadata != null)
{
var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
if (newToken != null)
{
_token = newToken;
return _token.AccessToken;
await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false);
return newToken.AccessToken;
}
}

Expand Down Expand Up @@ -234,7 +238,7 @@ private async Task PerformOAuthAuthorizationAsync(
ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token.");
}

_token = token;
await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false);
LogOAuthAuthorizationCompleted();
}

Expand Down
17 changes: 17 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/ITokenCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace ModelContextProtocol.Authentication;

/// <summary>
/// Allows the client to cache access tokens beyond the lifetime of the transport.
/// </summary>
public interface ITokenCache
{
/// <summary>
/// Cache the token. After a new access token is acquired, this method is invoked to store it.
/// </summary>
ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken);

/// <summary>
/// Get the cached token. This method is invoked for every request.
/// </summary>
ValueTask<TokenContainerCacheable?> GetTokenAsync(CancellationToken cancellationToken);
}
27 changes: 27 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Caches the token in-memory within this instance.
/// </summary>
internal class InMemoryTokenCache : ITokenCache
{
private TokenContainerCacheable? _token;

/// <summary>
/// Cache the token.
/// </summary>
public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken)
{
_token = token;
return default;
}

/// <summary>
/// Get the cached token.
/// </summary>
public ValueTask<TokenContainerCacheable?> GetTokenAsync(CancellationToken cancellationToken)
{
return new ValueTask<TokenContainerCacheable?>(_token);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
namespace ModelContextProtocol.Authentication;

/// <summary>
/// Represents a cacheable token representation.
/// </summary>
public class TokenContainerCacheable
{
/// <summary>
/// Gets or sets the access token.
/// </summary>
public string AccessToken { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the refresh token.
/// </summary>
public string? RefreshToken { get; set; }

/// <summary>
/// Gets or sets the number of seconds until the access token expires.
/// </summary>
public int ExpiresIn { get; set; }

/// <summary>
/// Gets or sets the extended expiration time in seconds.
/// </summary>
public int ExtExpiresIn { get; set; }

/// <summary>
/// Gets or sets the token type (typically "Bearer").
/// </summary>
public string TokenType { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the scope of the access token.
/// </summary>
public string Scope { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the timestamp when the token was obtained.
/// </summary>
public DateTimeOffset ObtainedAt { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
namespace ModelContextProtocol.Authentication;

internal static class TokenContainerConvert
{
internal static TokenContainer ForUse(this TokenContainerCacheable token) => new()
{
AccessToken = token.AccessToken,
RefreshToken = token.RefreshToken,
ExpiresIn = token.ExpiresIn,
ExtExpiresIn = token.ExtExpiresIn,
TokenType = token.TokenType,
Scope = token.Scope,
ObtainedAt = token.ObtainedAt,
};

internal static TokenContainerCacheable ForCache(this TokenContainer token) => new()
{
AccessToken = token.AccessToken,
RefreshToken = token.RefreshToken,
ExpiresIn = token.ExpiresIn,
ExtExpiresIn = token.ExtExpiresIn,
TokenType = token.TokenType,
Scope = token.Scope,
ObtainedAt = token.ObtainedAt,
};
}
Loading