diff --git a/lib/src/auth/credential_provider.dart b/lib/src/auth/credential_provider.dart index 8112915..ab0eb5c 100644 --- a/lib/src/auth/credential_provider.dart +++ b/lib/src/auth/credential_provider.dart @@ -1,8 +1,16 @@ import 'dart:convert'; import 'dart:io'; +import 'package:client_sdk_dart/src/errors/errors.dart'; import 'package:jwt_decoder/jwt_decoder.dart'; import 'package:string_validator/string_validator.dart'; +class EndpointOverrides { + String controlEndpoint; + String cacheEndpoint; + + EndpointOverrides(this.cacheEndpoint, this.controlEndpoint); +} + enum CredentialProviderError { emptyApiKey, emptyAuthEnvironmentVariable, @@ -43,8 +51,8 @@ class _Endpoints { class _ParsedApiKey { String apiKey; - String cacheEndpoint; - String controlEndpoint; + String? cacheEndpoint; + String? controlEndpoint; _ParsedApiKey(this.apiKey, this.controlEndpoint, this.cacheEndpoint); } @@ -57,11 +65,29 @@ abstract class CredentialProvider { String get controlEndpoint => _controlEndpoint; String get cacheEndpoint => _cacheEndpoint; - static CredentialProvider fromEnvironmentVariable(String envVarName) { + static CredentialProvider fromEnvironmentVariable(String envVarName, {String? baseEndpointOverride, EndpointOverrides? endpointOverrides}) { + if (endpointOverrides != null && baseEndpointOverride != null) { + throw IllegalArgumentError("either pass in 'baseEndpointOverride' or 'endpointOverrides', cannot pass in both"); + } + if (endpointOverrides != null) { + return EnvMomentoTokenProvider.withEndpointOverrides(envVarName, endpointOverrides); + } + if (baseEndpointOverride != null && baseEndpointOverride.isNotEmpty) { + return EnvMomentoTokenProvider.withBaseEndpointOverride(envVarName, baseEndpointOverride); + } return EnvMomentoTokenProvider(envVarName); } - static CredentialProvider fromString(String apiKey) { + static CredentialProvider fromString(String apiKey, {String? baseEndpointOverride, EndpointOverrides? endpointOverrides}) { + if (endpointOverrides != null && baseEndpointOverride != null) { + throw IllegalArgumentError("either pass in 'baseEndpointOverride' or 'endpointOverrides', cannot pass in both"); + } + if (endpointOverrides != null) { + return StringMomentoTokenProvider.withEndpointOverrides(apiKey, endpointOverrides); + } + if (baseEndpointOverride != null && baseEndpointOverride.isNotEmpty) { + return StringMomentoTokenProvider.withBaseEndpointOverride(apiKey, baseEndpointOverride); + } return StringMomentoTokenProvider(apiKey); } @@ -74,9 +100,6 @@ abstract class CredentialProvider { static _ParsedApiKey _parseJwtToken(String jwt) { Map claims = JwtDecoder.decode(jwt); - if (!claims.containsKey("c") || !claims.containsKey("cp")) { - throw "failed to parse jwt token"; - } return _ParsedApiKey(jwt, claims["cp"], claims["c"]); } @@ -84,10 +107,10 @@ abstract class CredentialProvider { final decodedJson = json.decode(utf8.decode(base64Decode(apiKey))); final decoded = Base64DecodedV1Token.fromJson(decodedJson); if (decoded.endpoint.isEmpty) { - throw "invalid jwt missing required claim 'endpoint'"; + throw IllegalArgumentError("invalid jwt missing required claim 'endpoint'"); } if (decoded.apiKey.isEmpty) { - throw "invalid jwt missing required claim 'api_key'"; + throw IllegalArgumentError("invalid jwt missing required claim 'api_key'"); } final endpoints = _Endpoints(decoded.endpoint); return _ParsedApiKey( @@ -105,15 +128,38 @@ class StringMomentoTokenProvider implements CredentialProvider { @override String _controlEndpoint = ""; - StringMomentoTokenProvider(String apiKey, - {String? controlEndpoint, String? cacheEndpoint}) { + StringMomentoTokenProvider(String apiKey) { + if (apiKey.isEmpty) { + throw CredentialProviderError.emptyApiKey.name; + } + final parsedApiKey = CredentialProvider._parseApiKey(apiKey); + _apiKey = parsedApiKey.apiKey; + if (parsedApiKey.controlEndpoint == null || parsedApiKey.cacheEndpoint == null) { + throw IllegalArgumentError("failed to parse jwt token"); + } + _cacheEndpoint = parsedApiKey.cacheEndpoint!; + _controlEndpoint = parsedApiKey.controlEndpoint!; + } + + StringMomentoTokenProvider.withBaseEndpointOverride(String apiKey, String baseEndpoint) { + if (apiKey.isEmpty) { + throw CredentialProviderError.emptyApiKey.name; + } + final parsedApiKey = CredentialProvider._parseApiKey(apiKey); + final endpoints = _Endpoints(baseEndpoint); + _apiKey = parsedApiKey.apiKey; + _cacheEndpoint = endpoints.cacheEndpoint; + _controlEndpoint = endpoints.controlEndpoint; + } + + StringMomentoTokenProvider.withEndpointOverrides(String apiKey, EndpointOverrides overrides) { if (apiKey.isEmpty) { throw CredentialProviderError.emptyApiKey.name; } final parsedApiKey = CredentialProvider._parseApiKey(apiKey); _apiKey = parsedApiKey.apiKey; - _cacheEndpoint = parsedApiKey.cacheEndpoint; - _controlEndpoint = parsedApiKey.controlEndpoint; + _cacheEndpoint = overrides.cacheEndpoint; + _controlEndpoint = overrides.controlEndpoint; } @override @@ -127,8 +173,10 @@ class StringMomentoTokenProvider implements CredentialProvider { } class EnvMomentoTokenProvider extends StringMomentoTokenProvider { - EnvMomentoTokenProvider(String envVarName, - {String? controlEndpoint, String? cacheEndpoint}) - : super(Platform.environment[envVarName] ?? '', - controlEndpoint: controlEndpoint, cacheEndpoint: cacheEndpoint); + EnvMomentoTokenProvider(String envVarName) + : super(Platform.environment[envVarName] ?? ''); + EnvMomentoTokenProvider.withBaseEndpointOverride(String envVarName, String baseEndpoint) + : super.withBaseEndpointOverride(Platform.environment[envVarName] ?? '', baseEndpoint); + EnvMomentoTokenProvider.withEndpointOverrides(String envVarName, EndpointOverrides overrides) + : super.withEndpointOverrides(Platform.environment[envVarName] ?? '', overrides); } diff --git a/lib/src/errors/errors.dart b/lib/src/errors/errors.dart index ad0c007..1246722 100644 --- a/lib/src/errors/errors.dart +++ b/lib/src/errors/errors.dart @@ -257,3 +257,8 @@ class FailedPreconditionException extends SdkException { "System is not in a state required for the operation's execution; please contact Momento.", transportDetails); } + +class IllegalArgumentError extends Error { + String message; + IllegalArgumentError(this.message): super(); +} diff --git a/test/src/auth/credential_provider_test.dart b/test/src/auth/credential_provider_test.dart index 22733ca..0c98fdf 100644 --- a/test/src/auth/credential_provider_test.dart +++ b/test/src/auth/credential_provider_test.dart @@ -1,4 +1,5 @@ import 'package:client_sdk_dart/src/auth/credential_provider.dart'; +import 'package:client_sdk_dart/src/errors/errors.dart'; import 'package:test/test.dart'; import 'dart:convert'; @@ -37,6 +38,42 @@ void main() { expect(authProvider.cacheEndpoint, equals('cache.${decodedV1Token.endpoint}')); }); + + test('parses a token with base endpoint override', () { + var authProvider = + CredentialProvider.fromString(base64EncodedFakeV1AuthToken, baseEndpointOverride: "test.com"); + expect(authProvider.apiKey, equals(fakeTestV1ApiKey)); + expect(authProvider.controlEndpoint, + equals('control.test.com')); + expect(authProvider.cacheEndpoint, + equals('cache.test.com')); + }); + test('parses a token with endpoint overrides', () { + var authProvider = + CredentialProvider.fromString(base64EncodedFakeV1AuthToken, endpointOverrides: EndpointOverrides("this.is.a.cache.endpoint", "this.is.a.control.endpoint")); + expect(authProvider.apiKey, equals(fakeTestV1ApiKey)); + expect(authProvider.controlEndpoint, + equals('this.is.a.control.endpoint')); + expect(authProvider.cacheEndpoint, + equals('this.is.a.cache.endpoint')); + }); + test('parses a session token with base endpoint override', () { + var authProvider = + CredentialProvider.fromString(fakeSessionToken, endpointOverrides: EndpointOverrides("this.is.a.cache.endpoint", "this.is.a.control.endpoint")); + expect(authProvider.apiKey, equals(fakeSessionToken)); + expect(authProvider.controlEndpoint, + equals('this.is.a.control.endpoint')); + expect(authProvider.cacheEndpoint, + equals('this.is.a.cache.endpoint')); + }); + test('fromString should not allow passing in both endpointOverrides and baseEndpointOverride', () { + expect(() => CredentialProvider.fromString(base64EncodedFakeV1AuthToken, baseEndpointOverride: "baseendpoint.com", endpointOverrides: EndpointOverrides("this.is.a.cache.endpoint", "this.is.a.control.endpoint")), throwsA(TypeMatcher())); + }); + }); + group('fromEnvironmentVariable', () { + test('fromEnvironmentVariable should not allow passing in both endpointOverrides and baseEndpointOverride', () { + expect(() => CredentialProvider.fromEnvironmentVariable(base64EncodedFakeV1AuthToken, baseEndpointOverride: "baseendpoint.com", endpointOverrides: EndpointOverrides("this.is.a.cache.endpoint", "this.is.a.control.endpoint")), throwsA(TypeMatcher())); + }); }); }); }