From 2977f13b8028bcc8b393e9797a2ac37f78e9b590 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Wed, 10 May 2023 17:23:26 +0530 Subject: [PATCH] KMS: implement cross-accounts access (#8253) --- localstack/constants.py | 2 +- localstack/services/kms/models.py | 23 +- localstack/services/kms/provider.py | 315 +++++++++++++++++++--------- localstack/services/kms/utils.py | 25 +++ tests/integration/test_kms.py | 121 ++++++++++- 5 files changed, 381 insertions(+), 105 deletions(-) diff --git a/localstack/constants.py b/localstack/constants.py index 5967638dd1f22..c103c24593b63 100644 --- a/localstack/constants.py +++ b/localstack/constants.py @@ -174,7 +174,7 @@ TEST_AWS_SECRET_ACCESS_KEY = "test" # additional credentials used in the test suite (mainly for cross-account access) -SECONDARY_TEST_AWS_ACCESS_KEY_ID = "test2" +SECONDARY_TEST_AWS_ACCESS_KEY_ID = "000000000002" SECONDARY_TEST_AWS_SECRET_ACCESS_KEY = "test2" # credentials being used for internal calls diff --git a/localstack/services/kms/models.py b/localstack/services/kms/models.py index 5010fcf2642de..7d8b9f456233c 100644 --- a/localstack/services/kms/models.py +++ b/localstack/services/kms/models.py @@ -1,3 +1,4 @@ +import base64 import datetime import io import json @@ -36,10 +37,11 @@ SigningAlgorithmSpec, UnsupportedOperationException, ) +from localstack.services.kms.utils import is_valid_key_arn from localstack.services.stores import AccountRegionBundle, BaseStore, LocalAttribute from localstack.utils.aws.arns import kms_alias_arn, kms_key_arn, parse_arn from localstack.utils.crypto import decrypt, encrypt -from localstack.utils.strings import long_uid +from localstack.utils.strings import long_uid, to_bytes, to_str LOG = logging.getLogger(__name__) @@ -543,8 +545,14 @@ class KmsGrant: # simplicity. token: str - def __init__(self, create_grant_request: CreateGrantRequest): + def __init__(self, create_grant_request: CreateGrantRequest, account_id: str, region_name: str): self.metadata = dict(create_grant_request) + + if is_valid_key_arn(self.metadata["KeyId"]): + self.metadata["KeyArn"] = self.metadata["KeyId"] + else: + self.metadata["KeyArn"] = kms_key_arn(self.metadata["KeyId"], account_id, region_name) + self.metadata["GrantId"] = long_uid() self.metadata["CreationDate"] = datetime.datetime.now() # https://docs.aws.amazon.com/kms/latest/APIReference/API_GrantListEntry.html @@ -554,7 +562,11 @@ def __init__(self, create_grant_request: CreateGrantRequest): # The Name field is present with just an empty string value. self.metadata.setdefault("Name", "") - self.token = long_uid() + # Encode account ID and region in grant token. + # This way the grant can be located when being retired by grant principal. + # The token consists of account ID, region name and a UUID concatenated with ':' and encoded with base64 + decoded_token = account_id + ":" + region_name + ":" + long_uid() + self.token = to_str(base64.b64encode(to_bytes(decoded_token))) class KmsAlias: @@ -602,12 +614,13 @@ class KmsStore(BaseStore): # According to AWS documentation on grants https://docs.aws.amazon.com/kms/latest/APIReference/API_RetireGrant.html # "Cross-account use: Yes. You can retire a grant on a KMS key in a different AWS account." - # We, however, currently only support grants on keys inside the same account. - # + # maps grant ids to grants grants: Dict[str, KmsGrant] = LocalAttribute(default=dict) + # maps from (grant names (used for idempotency), key id) to grant ids grant_names: Dict[Tuple[str, str], str] = LocalAttribute(default=dict) + # maps grant tokens to grant ids grant_tokens: Dict[str, str] = LocalAttribute(default=dict) diff --git a/localstack/services/kms/provider.py b/localstack/services/kms/provider.py index c8c067f4b31e0..97a6984111f89 100644 --- a/localstack/services/kms/provider.py +++ b/localstack/services/kms/provider.py @@ -1,8 +1,9 @@ +import base64 import copy import datetime import logging import os -from typing import Dict +from typing import Dict, Tuple from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding @@ -79,6 +80,7 @@ PrincipalIdType, PutKeyPolicyRequest, RecipientInfo, + ReEncryptResponse, ReplicateKeyRequest, ReplicateKeyResponse, ScheduleKeyDeletionRequest, @@ -107,6 +109,7 @@ kms_stores, validate_alias_name, ) +from localstack.services.kms.utils import is_valid_key_arn, parse_key_arn from localstack.services.plugins import ServiceLifecycleHook from localstack.utils.aws.arns import kms_alias_arn from localstack.utils.collections import PaginatedList @@ -145,12 +148,57 @@ def __init__(self, message=None): # For all operations constraints for states of keys are based on # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html class KmsProvider(KmsApi, ServiceLifecycleHook): + """ + The LocalStack Key Management Service (KMS) provider. + + Cross-account access is supported by following operations where key ID belonging + to another account can be used with the key ARN. + - CreateGrant + - DescribeKey + - GetKeyRotationStatus + - GetPublicKey + - ListGrants + - RetireGrant + - RevokeGrant + - Decrypt + - Encrypt + - GenerateDataKey + - GenerateDataKeyPair + - GenerateDataKeyPairWithoutPlaintext + - GenerateDataKeyWithoutPlaintext + - GenerateMac + - ReEncrypt + - Sign + - Verify + - VerifyMac + """ + + @staticmethod + def _get_store(account_id: str, region_name: str) -> KmsStore: + return kms_stores[account_id][region_name] + + @staticmethod + def _get_key(account_id: str, region_name: str, key_id: str, **kwargs) -> KmsKey: + return KmsProvider._get_store(account_id, region_name).get_key(key_id, **kwargs) + @staticmethod - def _get_store(context: RequestContext) -> KmsStore: - return kms_stores[context.account_id][context.region] + def _parse_key_id(key_id_or_arn: str, context: RequestContext) -> Tuple[str, str, str]: + """ + Return locator attributes (account ID, region_name, key ID) of a given KMS key. - def _get_key(self, context: RequestContext, key_id: str, **kwargs) -> KmsKey: - return self._get_store(context).get_key(key_id, **kwargs) + If an ARN is provided, this is extracted from it. Otherwise, context data is used. + + :param key_id_or_arn: KMS key ID or ARN + :param context: request context + :return: Tuple of account ID, region name and key ID + """ + if is_valid_key_arn(key_id_or_arn): + account_id, region_name, key_id = parse_key_arn(key_id_or_arn) + if region_name != context.region: + raise NotFoundException(f"Invalid arn {region_name}") + return account_id, region_name, key_id + + return context.account_id, context.region, key_id_or_arn @handler("CreateKey", expand=False) def create_key( @@ -158,7 +206,9 @@ def create_key( context: RequestContext, request: CreateKeyRequest = None, ) -> CreateKeyResponse: - key = self._get_store(context).create_key(request, context.account_id, context.region) + key = self._get_store(context.account_id, context.region).create_key( + request, context.account_id, context.region + ) return CreateKeyResponse(KeyMetadata=key.metadata) @handler("ScheduleKeyDeletion", expand=False) @@ -171,7 +221,11 @@ def schedule_key_deletion( f"PendingWindowInDays should be between 7 and 30, but it is {pending_window}" ) key = self._get_key( - context, request.get("KeyId"), enabled_key_allowed=True, disabled_key_allowed=True + context.account_id, + context.region, + request.get("KeyId"), + enabled_key_allowed=True, + disabled_key_allowed=True, ) key.schedule_key_deletion(pending_window) attrs = ["DeletionDate", "KeyId", "KeyState"] @@ -184,7 +238,8 @@ def cancel_key_deletion( self, context: RequestContext, request: CancelKeyDeletionRequest ) -> CancelKeyDeletionResponse: key = self._get_key( - context, + context.account_id, + context.region, request.get("KeyId"), enabled_key_allowed=False, pending_deletion_key_allowed=True, @@ -199,7 +254,11 @@ def cancel_key_deletion( def disable_key(self, context: RequestContext, request: DisableKeyRequest) -> None: # Technically, AWS allows DisableKey for keys that are already disabled. key = self._get_key( - context, request.get("KeyId"), enabled_key_allowed=True, disabled_key_allowed=True + context.account_id, + context.region, + request.get("KeyId"), + enabled_key_allowed=True, + disabled_key_allowed=True, ) key.metadata["KeyState"] = KeyState.Disabled key.metadata["Enabled"] = False @@ -207,7 +266,11 @@ def disable_key(self, context: RequestContext, request: DisableKeyRequest) -> No @handler("EnableKey", expand=False) def enable_key(self, context: RequestContext, request: EnableKeyRequest) -> None: key = self._get_key( - context, request.get("KeyId"), enabled_key_allowed=True, disabled_key_allowed=True + context.account_id, + context.region, + request.get("KeyId"), + enabled_key_allowed=True, + disabled_key_allowed=True, ) key.metadata["KeyState"] = KeyState.Enabled key.metadata["Enabled"] = True @@ -219,7 +282,7 @@ def list_keys(self, context: RequestContext, request: ListKeysRequest) -> ListKe keys_list = PaginatedList( [ {"KeyId": key.metadata["KeyId"], "KeyArn": key.metadata["Arn"]} - for key in self._get_store(context).keys.values() + for key in self._get_store(context.account_id, context.region).keys.values() ] ) # https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html#API_ListKeys_RequestParameters @@ -236,14 +299,15 @@ def list_keys(self, context: RequestContext, request: ListKeysRequest) -> ListKe def describe_key( self, context: RequestContext, request: DescribeKeyRequest ) -> DescribeKeyResponse: - key = self._get_key(context, request.get("KeyId"), any_key_state_allowed=True) + account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context) + key = self._get_key(account_id, region_name, key_id, any_key_state_allowed=True) return DescribeKeyResponse(KeyMetadata=key.metadata) @handler("ReplicateKey", expand=False) def replicate_key( self, context: RequestContext, request: ReplicateKeyRequest ) -> ReplicateKeyResponse: - replicate_from_store = self._get_store(context) + replicate_from_store = self._get_store(context.account_id, context.region) key = replicate_from_store.get_key(request.get("KeyId")) key_id = key.metadata.get("KeyId") if not key.metadata.get("MultiRegion"): @@ -270,7 +334,11 @@ def update_key_description( self, context: RequestContext, request: UpdateKeyDescriptionRequest ) -> None: key = self._get_key( - context, request.get("KeyId"), enabled_key_allowed=True, disabled_key_allowed=True + context.account_id, + context.region, + request.get("KeyId"), + enabled_key_allowed=True, + disabled_key_allowed=True, ) key.metadata["Description"] = request.get("Description") @@ -278,18 +346,22 @@ def update_key_description( def create_grant( self, context: RequestContext, request: CreateGrantRequest ) -> CreateGrantResponse: - store = self._get_store(context) + key_account_id, key_region_name, key_id = self._parse_key_id(request["KeyId"], context) + store_kms_key = self._get_store(key_account_id, key_region_name) + key = store_kms_key.get_key(key_id) + # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no # matter which type of id is used. - key = store.get_key(request.get("KeyId")) key_id = key.metadata.get("KeyId") request["KeyId"] = key_id - self._validate_grant_request(request, store) + self._validate_grant_request(request) grant_name = request.get("Name") + + store = self._get_store(context.account_id, context.region) if grant_name and (grant_name, key_id) in store.grant_names: grant = store.grants[store.grant_names[(grant_name, key_id)]] else: - grant = KmsGrant(request) + grant = KmsGrant(request, context.account_id, context.region) grant_id = grant.metadata["GrantId"] store.grants[grant_id] = grant if grant_name: @@ -311,12 +383,14 @@ def list_grants( ) -> ListGrantsResponse: if not request.get("KeyId"): raise ValidationError("Required input parameter KeyId not specified") - store = self._get_store(context) + key_account_id, key_region_name, _ = self._parse_key_id(request["KeyId"], context) + key_store = self._get_store(key_account_id, key_region_name) # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no # matter which type of id is used. - key = store.get_key(request.get("KeyId"), any_key_state_allowed=True) + key = key_store.get_key(request.get("KeyId"), any_key_state_allowed=True) key_id = key.metadata.get("KeyId") + store = self._get_store(context.account_id, context.region) grant_id = request.get("GrantId") if grant_id: if grant_id not in store.grants: @@ -327,7 +401,8 @@ def list_grants( grantee_principal = request.get("GranteePrincipal") for grant in store.grants.values(): # KeyId is a mandatory field of ListGrants request, so is going to be present. - if grant.metadata["KeyId"] != key_id: + _, _, grant_key_id = parse_key_arn(grant.metadata["KeyArn"]) + if grant_key_id != key_id: continue # GranteePrincipal is a mandatory field for CreateGrant, should be in grants. But it is an optional field # for ListGrants, so might not be there. @@ -345,48 +420,14 @@ def list_grants( return ListGrantsResponse(Grants=page, **kwargs) - # Honestly, this is a mess in AWS KMS. Hashtag "do we follow specifications that are a pain to customers or do we - # diverge from AWS and make the life of our customers easier?" - # - # Both RetireGrant and RevokeGrant operations delete a grant. The differences between them are described here: - # https://docs.aws.amazon.com/kms/latest/developerguide/grant-manage.html#grant-delete - # Essentially: - # - Permissions to RevokeGrant are controlled through IAM policies or through key policies, while permissions to - # RetireGrant are controlled by settings inside the grant itself. - # - A grant to be retired can be specified by its GrantToken or its GrantId/KeyId pair. While revoking grants can - # only be done with a GrantId/KeyId pair. - # - For RevokeGrant, KeyId can be either an actual key ID, or an ARN of that key. While for RetireGrant only key - # ARN is accepted as a KeyId. - # - # We currently do not model permissions for retirement and revocation of grants. At least not in KMS, - # maybe IAM in LocalStack has some modelling though. We also accept both key IDs and key ARNs for both - # operations. So apart from RevokeGrant not accepting GrantToken parameter, we treat these two operations the same. @staticmethod - def _delete_grant( - store: KmsStore, grant_id: str = None, key_id: str = None, grant_token: str = None - ): - if grant_token: - if grant_token not in store.grant_tokens: - raise NotFoundException(f"Unable to find grant token {grant_token}") - grant_id = store.grant_tokens[grant_token] - # Do not really care about the key ID if a grant is identified by a token. But since a key has to be - # validated when a grant is identified by GrantId/KeyId pair, and since we want to use the same code in - # both cases - when we have a grant token or a GrantId/KeyId pair - have to set key_id. - key_id = store.grants[grant_id].metadata["KeyId"] - - # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no - # matter which type of id is used. - key = store.get_key(key_id, any_key_state_allowed=True) - key_id = key.metadata.get("KeyId") + def _delete_grant(store: KmsStore, grant_id: str, key_id: str): + grant = store.grants[grant_id] - if grant_id not in store.grants: - raise InvalidGrantIdException() - if store.grants[grant_id].metadata["KeyId"] != key_id: + _, _, grant_key_id = parse_key_arn(grant.metadata.get("KeyArn")) + if key_id != grant_key_id: raise ValidationError(f"Invalid KeyId={key_id} specified for grant {grant_id}") - grant = store.grants[grant_id] - # In AWS grants have one or more tokens. But we have a simplified modeling of grants, where they have exactly - # one token. store.grant_tokens.pop(grant.token) store.grant_names.pop((grant.metadata.get("Name"), key_id), None) store.grants.pop(grant_id) @@ -394,7 +435,18 @@ def _delete_grant( def revoke_grant( self, context: RequestContext, key_id: KeyIdType, grant_id: GrantIdType ) -> None: - self._delete_grant(store=self._get_store(context), grant_id=grant_id, key_id=key_id) + key_account_id, key_region_name, key_id = self._parse_key_id(key_id, context) + key = self._get_store(key_account_id, key_region_name).get_key( + key_id, any_key_state_allowed=True + ) + key_id = key.metadata.get("KeyId") + + store = self._get_store(context.account_id, context.region) + + if grant_id not in store.grants: + raise InvalidGrantIdException() + + self._delete_grant(store, grant_id, key_id) def retire_grant( self, @@ -405,12 +457,28 @@ def retire_grant( ) -> None: if not grant_token and (not grant_id or not key_id): raise ValidationException("Grant token OR (grant ID, key ID) must be specified") - self._delete_grant( - store=self._get_store(context), - grant_id=grant_id, - key_id=key_id, - grant_token=grant_token, - ) + + if grant_token: + decoded_token = to_str(base64.b64decode(grant_token)) + grant_account_id, grant_region_name, _ = decoded_token.split(":") + grant_store = self._get_store(grant_account_id, grant_region_name) + + if grant_token not in grant_store.grant_tokens: + raise NotFoundException(f"Unable to find grant token {grant_token}") + + grant_id = grant_store.grant_tokens[grant_token] + else: + grant_store = self._get_store(context.account_id, context.region) + + if key_id: + key_account_id, key_region_name, key_id = self._parse_key_id(key_id, context) + key_store = self._get_store(key_account_id, key_region_name) + key = key_store.get_key(key_id, any_key_state_allowed=True) + key_id = key.metadata.get("KeyId") + else: + _, _, key_id = parse_key_arn(grant_store.grants[grant_id].metadata.get("KeyArn")) + + self._delete_grant(grant_store, grant_id, key_id) def list_retirable_grants( self, @@ -424,7 +492,7 @@ def list_retirable_grants( matching_grants = [ grant.metadata - for grant in self._get_store(context).grants.values() + for grant in self._get_store(context.account_id, context.region).grants.values() if grant.metadata.get("RetiringPrincipal") == retiring_principal ] grants_list = PaginatedList(matching_grants) @@ -443,7 +511,14 @@ def get_public_key( ) -> GetPublicKeyResponse: # According to https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html, GetPublicKey is supposed # to fail for disabled keys. But it actually doesn't fail in AWS. - key = self._get_key(context, key_id, enabled_key_allowed=True, disabled_key_allowed=True) + account_id, region_name, key_id = self._parse_key_id(key_id, context) + key = self._get_key( + account_id, + region_name, + key_id, + enabled_key_allowed=True, + disabled_key_allowed=True, + ) attrs = [ "KeySpec", "KeyUsage", @@ -456,7 +531,8 @@ def get_public_key( return GetPublicKeyResponse(**result) def _generate_data_key_pair(self, key_id: str, key_pair_spec: str, context: RequestContext): - key = self._get_key(context, key_id) + account_id, region_name, key_id = self._parse_key_id(key_id, context) + key = self._get_key(account_id, region_name, key_id) self._validate_key_for_encryption_decryption(context, key) crypto_key = KmsCryptoKey(key_pair_spec) return { @@ -516,7 +592,8 @@ def generate_data_key_pair_without_plaintext( # # TODO We also do not use the encryption context. Should reuse the way we do it in encrypt / decrypt. def _generate_data_key(self, key_id: str, context: RequestContext): - key = self._get_key(context, key_id) + account_id, region_name, key_id = self._parse_key_id(key_id, context) + key = self._get_key(account_id, region_name, key_id) # TODO Should also have a validation for the key being a symmetric one. self._validate_key_for_encryption_decryption(context, key) crypto_key = KmsCryptoKey("SYMMETRIC_DEFAULT") @@ -550,7 +627,9 @@ def generate_mac( msg = request.get("Message") self._validate_mac_msg_length(msg) - key = self._get_key(context, request.get("KeyId")) + account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context) + key = self._get_key(account_id, region_name, key_id) + self._validate_key_for_generate_verify_mac(context, key) algorithm = request.get("MacAlgorithm") @@ -569,7 +648,9 @@ def verify_mac( msg = request.get("Message") self._validate_mac_msg_length(msg) - key = self._get_key(context, request.get("KeyId")) + account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context) + key = self._get_key(account_id, region_name, key_id) + self._validate_key_for_generate_verify_mac(context, key) algorithm = request.get("MacAlgorithm") @@ -583,7 +664,9 @@ def verify_mac( @handler("Sign", expand=False) def sign(self, context: RequestContext, request: SignRequest) -> SignResponse: - key = self._get_key(context, request.get("KeyId")) + account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context) + key = self._get_key(account_id, region_name, key_id) + self._validate_key_for_sign_verify(context, key) # TODO Add constraints on KeySpec / SigningAlgorithm pairs: @@ -602,7 +685,9 @@ def sign(self, context: RequestContext, request: SignRequest) -> SignResponse: # Currently LocalStack only calculates SHA256 digests no matter what the signing algorithm is. @handler("Verify", expand=False) def verify(self, context: RequestContext, request: VerifyRequest) -> VerifyResponse: - key = self._get_key(context, request.get("KeyId")) + account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context) + key = self._get_key(account_id, region_name, key_id) + self._validate_key_for_sign_verify(context, key) signing_algorithm = request.get("SigningAlgorithm") @@ -620,6 +705,21 @@ def verify(self, context: RequestContext, request: VerifyRequest) -> VerifyRespo } return VerifyResponse(**result) + def re_encrypt( + self, + context: RequestContext, + ciphertext_blob: CiphertextType, + destination_key_id: KeyIdType, + source_encryption_context: EncryptionContextType = None, + source_key_id: KeyIdType = None, + destination_encryption_context: EncryptionContextType = None, + source_encryption_algorithm: EncryptionAlgorithmSpec = None, + destination_encryption_algorithm: EncryptionAlgorithmSpec = None, + grant_tokens: GrantTokenList = None, + ) -> ReEncryptResponse: + # TODO: when implementing, ensure cross-account support for source_key_id and destination_key_id + raise NotImplementedError + def encrypt( self, context: RequestContext, @@ -629,7 +729,8 @@ def encrypt( grant_tokens: GrantTokenList = None, encryption_algorithm: EncryptionAlgorithmSpec = None, ) -> EncryptResponse: - key = self._get_key(context, key_id) + account_id, region_name, key_id = self._parse_key_id(key_id, context) + key = self._get_key(account_id, region_name, key_id) self._validate_plaintext_length(plaintext) self._validate_plaintext_key_type_based(plaintext, key, encryption_algorithm) self._validate_key_for_encryption_decryption(context, key) @@ -665,10 +766,9 @@ def decrypt( "LocalStack is unable to deserialize the ciphertext blob. Perhaps the " "blob didn't come from LocalStack" ) - key_id = key_id or ciphertext.key_id - key = self._get_key(context, key_id) - key_id = key.metadata["KeyId"] - if key_id != ciphertext.key_id: + account_id, region_name, key_id = self._parse_key_id(key_id or ciphertext.key_id, context) + key = self._get_key(account_id, region_name, key_id) + if key.metadata["KeyId"] != ciphertext.key_id: raise IncorrectKeyException( "The key ID in the request does not identify a CMK that can perform this operation." ) @@ -697,7 +797,7 @@ def get_parameters_for_import( wrapping_algorithm: AlgorithmSpec, wrapping_key_spec: WrappingKeySpec, ) -> GetParametersForImportResponse: - store = self._get_store(context) + store = self._get_store(context.account_id, context.region) # KeyId can potentially hold one of multiple different types of key identifiers. get_key finds a key no # matter which type of id is used. key_to_import_material_to = store.get_key( @@ -736,7 +836,7 @@ def import_key_material( valid_to: DateType = None, expiration_model: ExpirationModelType = None, ) -> ImportKeyMaterialResponse: - store = self._get_store(context) + store = self._get_store(context.account_id, context.region) import_token = to_str(import_token) import_state = store.imports.get(import_token) if not import_state: @@ -783,7 +883,7 @@ def import_key_material( return ImportKeyMaterialResponse() def delete_imported_key_material(self, context: RequestContext, key_id: KeyIdType) -> None: - store = self._get_store(context) + store = self._get_store(context.account_id, context.region) key = store.get_key(key_id, enabled_key_allowed=True, disabled_key_allowed=True) key.crypto_key.key_material = None key.metadata["Enabled"] = False @@ -792,7 +892,7 @@ def delete_imported_key_material(self, context: RequestContext, key_id: KeyIdTyp @handler("CreateAlias", expand=False) def create_alias(self, context: RequestContext, request: CreateAliasRequest) -> None: - store = self._get_store(context) + store = self._get_store(context.account_id, context.region) alias_name = request["AliasName"] validate_alias_name(alias_name) if alias_name in store.aliases: @@ -802,7 +902,11 @@ def create_alias(self, context: RequestContext, request: CreateAliasRequest) -> # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no # matter which type of id is used. key = self._get_key( - context, request.get("TargetKeyId"), enabled_key_allowed=True, disabled_key_allowed=True + context.account_id, + context.region, + request.get("TargetKeyId"), + enabled_key_allowed=True, + disabled_key_allowed=True, ) request["TargetKeyId"] = key.metadata.get("KeyId") store.create_alias(request) @@ -811,7 +915,7 @@ def create_alias(self, context: RequestContext, request: CreateAliasRequest) -> def delete_alias(self, context: RequestContext, request: DeleteAliasRequest) -> None: # We do not check the state of the key, as, according to AWS docs, all key states, that are possible in # LocalStack, are supported by this operation. - store = self._get_store(context) + store = self._get_store(context.account_id, context.region) alias_name = request["AliasName"] if alias_name not in store.aliases: alias_arn = kms_alias_arn(request["AliasName"], context.account_id, context.region) @@ -831,7 +935,7 @@ def update_alias(self, context: RequestContext, request: UpdateAliasRequest) -> alias_name = request["AliasName"] # This API, per AWS docs, accepts only names, not ARNs. validate_alias_name(alias_name) - store = self._get_store(context) + store = self._get_store(context.account_id, context.region) alias = store.get_alias(alias_name, context.account_id, context.region) key_id = request["TargetKeyId"] # Don't care about the key itself, just want to validate its state. @@ -847,7 +951,7 @@ def list_aliases( limit: LimitType = None, marker: MarkerType = None, ) -> ListAliasesResponse: - store = self._get_store(context) + store = self._get_store(context.account_id, context.region) if key_id: # KeyId can potentially hold one of multiple different types of key identifiers. Here we find a key no # matter which type of id is used. @@ -876,7 +980,8 @@ def get_key_rotation_status( # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html # "If the KMS key has imported key material or is in a custom key store: UnsupportedOperationException." # We do not model that here, though. - key = self._get_key(context, request.get("KeyId"), any_key_state_allowed=True) + account_id, region_name, key_id = self._parse_key_id(request["KeyId"], context) + key = self._get_key(account_id, region_name, key_id, any_key_state_allowed=True) return GetKeyRotationStatusResponse(KeyRotationEnabled=key.is_key_rotation_enabled) @handler("DisableKeyRotation", expand=False) @@ -886,7 +991,7 @@ def disable_key_rotation( # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html # "If the KMS key has imported key material or is in a custom key store: UnsupportedOperationException." # We do not model that here, though. - key = self._get_key(context, request.get("KeyId")) + key = self._get_key(context.account_id, context.region, request.get("KeyId")) key.is_key_rotation_enabled = False @handler("EnableKeyRotation", expand=False) @@ -896,7 +1001,7 @@ def enable_key_rotation( # https://docs.aws.amazon.com/kms/latest/developerguide/key-state.html # "If the KMS key has imported key material or is in a custom key store: UnsupportedOperationException." # We do not model that here, though. - key = self._get_key(context, request.get("KeyId")) + key = self._get_key(context.account_id, context.region, request.get("KeyId")) key.is_key_rotation_enabled = True @handler("ListKeyPolicies", expand=False) @@ -906,12 +1011,16 @@ def list_key_policies( # We just care if the key exists. The response, by AWS specifications, is the same for all keys, as the only # supported policy is "default": # https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html#API_ListKeyPolicies_ResponseElements - self._get_key(context, request.get("KeyId"), any_key_state_allowed=True) + self._get_key( + context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True + ) return ListKeyPoliciesResponse(PolicyNames=["default"], Truncated=False) @handler("PutKeyPolicy", expand=False) def put_key_policy(self, context: RequestContext, request: PutKeyPolicyRequest) -> None: - key = self._get_key(context, request.get("KeyId"), any_key_state_allowed=True) + key = self._get_key( + context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True + ) if request.get("PolicyName") != "default": raise UnsupportedOperationException("Only default policy is supported") key.policy = request.get("Policy") @@ -920,7 +1029,9 @@ def put_key_policy(self, context: RequestContext, request: PutKeyPolicyRequest) def get_key_policy( self, context: RequestContext, request: GetKeyPolicyRequest ) -> GetKeyPolicyResponse: - key = self._get_key(context, request.get("KeyId"), any_key_state_allowed=True) + key = self._get_key( + context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True + ) if request.get("PolicyName") != "default": raise NotFoundException("No such policy exists") return GetKeyPolicyResponse(Policy=key.policy) @@ -929,7 +1040,9 @@ def get_key_policy( def list_resource_tags( self, context: RequestContext, request: ListResourceTagsRequest ) -> ListResourceTagsResponse: - key = self._get_key(context, request.get("KeyId"), any_key_state_allowed=True) + key = self._get_key( + context.account_id, context.region, request.get("KeyId"), any_key_state_allowed=True + ) keys_list = PaginatedList( [{"TagKey": tag_key, "TagValue": tag_value} for tag_key, tag_value in key.tags.items()] ) @@ -944,14 +1057,22 @@ def list_resource_tags( @handler("TagResource", expand=False) def tag_resource(self, context: RequestContext, request: TagResourceRequest) -> None: key = self._get_key( - context, request.get("KeyId"), enabled_key_allowed=True, disabled_key_allowed=True + context.account_id, + context.region, + request.get("KeyId"), + enabled_key_allowed=True, + disabled_key_allowed=True, ) key.add_tags(request.get("Tags")) @handler("UntagResource", expand=False) def untag_resource(self, context: RequestContext, request: UntagResourceRequest) -> None: key = self._get_key( - context, request.get("KeyId"), enabled_key_allowed=True, disabled_key_allowed=True + context.account_id, + context.region, + request.get("KeyId"), + enabled_key_allowed=True, + disabled_key_allowed=True, ) if not request.get("TagKeys"): return @@ -1013,7 +1134,7 @@ def _validate_plaintext_length(self, plaintext: bytes): "Member must have length less than or equal to 4096" ) - def _validate_grant_request(self, data: Dict, store: KmsStore): + def _validate_grant_request(self, data: Dict): if "KeyId" not in data or "GranteePrincipal" not in data or "Operations" not in data: raise ValidationError("Grant ID, key ID and grantee principal must be specified") diff --git a/localstack/services/kms/utils.py b/localstack/services/kms/utils.py index f22475c23573e..c17bdc777eeb6 100644 --- a/localstack/services/kms/utils.py +++ b/localstack/services/kms/utils.py @@ -1,6 +1,31 @@ +import re +from typing import Tuple + +KMS_KEY_ARN_PATTERN = re.compile( + r"^arn:aws:kms:(?P[^:]+):(?P\d{12}):key\/(?P[^:]+)$" +) + + def get_hash_algorithm(signing_algorithm: str) -> str: """ Return the hashing algorithm for a given signing algorithm. eg. "RSASSA_PSS_SHA_512" -> "SHA_512" """ return "_".join(signing_algorithm.rsplit(sep="_", maxsplit=-2)[-2:]) + + +def parse_key_arn(key_arn: str) -> Tuple[str, str, str]: + """ + Parse a valid KMS key arn into its constituents. + + :param key_arn: KMS key ARN + :return: Tuple of account ID, region name and key ID + """ + return KMS_KEY_ARN_PATTERN.match(key_arn).group("account_id", "region_name", "key_id") + + +def is_valid_key_arn(key_arn: str) -> bool: + """ + Check if a given string is a valid KMS key ARN. + """ + return KMS_KEY_ARN_PATTERN.match(key_arn) is not None diff --git a/tests/integration/test_kms.py b/tests/integration/test_kms.py index 973331d9b7c95..f0115323793c0 100644 --- a/tests/integration/test_kms.py +++ b/tests/integration/test_kms.py @@ -11,8 +11,13 @@ from cryptography.hazmat.primitives.serialization import load_der_public_key from localstack.aws.accounts import get_aws_account_id +from localstack.constants import ( + SECONDARY_TEST_AWS_ACCESS_KEY_ID, + SECONDARY_TEST_AWS_SECRET_ACCESS_KEY, + TEST_AWS_REGION_NAME, +) from localstack.services.kms.utils import get_hash_algorithm -from localstack.utils.strings import short_uid +from localstack.utils.strings import short_uid, to_str @pytest.fixture(autouse=True) @@ -303,7 +308,7 @@ def test_revoke_grant(self, kms_grant_and_key, aws_client): assert len(grants_after) == len(grants_before) - 1 @pytest.mark.aws_validated - def test_retire_grant(self, kms_grant_and_key, aws_client): + def test_retire_grant_with_grant_token(self, kms_grant_and_key, aws_client): grant = kms_grant_and_key[0] key_id = kms_grant_and_key[1]["KeyId"] grants_before = aws_client.kms.list_grants(KeyId=key_id)["Grants"] @@ -313,6 +318,17 @@ def test_retire_grant(self, kms_grant_and_key, aws_client): grants_after = aws_client.kms.list_grants(KeyId=key_id)["Grants"] assert len(grants_after) == len(grants_before) - 1 + @pytest.mark.aws_validated + def test_retire_grant_with_grant_id_and_key_id(self, kms_grant_and_key, aws_client): + grant = kms_grant_and_key[0] + key_id = kms_grant_and_key[1]["KeyId"] + grants_before = aws_client.kms.list_grants(KeyId=key_id)["Grants"] + + aws_client.kms.retire_grant(GrantId=grant["GrantId"], KeyId=key_id) + + grants_after = aws_client.kms.list_grants(KeyId=key_id)["Grants"] + assert len(grants_after) == len(grants_before) - 1 + # Fails against AWS, as the retiring_principal_arn_prefix is invalid there. @pytest.mark.only_localstack def test_list_retirable_grants(self, kms_create_key, kms_create_grant, aws_client): @@ -1105,3 +1121,104 @@ def test_plaintext_size_for_encrypt(self, kms_create_key, snapshot, aws_client): with pytest.raises(ClientError) as e: aws_client.kms.encrypt(KeyId=key_id, Plaintext=base64.b64encode(message * 100)) snapshot.match("invalid-plaintext-size-encrypt", e.value.response) + + def test_cross_accounts_access(self, aws_client, aws_client_factory, kms_create_key, user_arn): + # Create the keys in the primary AWS account. They will only be referred to by their ARNs hereon + key_arn_1 = kms_create_key()["Arn"] + key_arn_2 = kms_create_key(KeyUsage="SIGN_VERIFY", KeySpec="RSA_4096")["Arn"] + key_arn_3 = kms_create_key(KeyUsage="GENERATE_VERIFY_MAC", KeySpec="HMAC_512")["Arn"] + + # Create client in secondary account and attempt to run operations with the above keys + client = aws_client_factory.get_client( + "kms", + aws_access_key_id=SECONDARY_TEST_AWS_ACCESS_KEY_ID, + aws_secret_access_key=SECONDARY_TEST_AWS_SECRET_ACCESS_KEY, + region_name=TEST_AWS_REGION_NAME, + ) + + # Cross-account access is supported for following operations in KMS: + # - CreateGrant + # - DescribeKey + # - GetKeyRotationStatus + # - GetPublicKey + # - ListGrants + # - RetireGrant + # - RevokeGrant + + response = client.create_grant( + KeyId=key_arn_1, + GranteePrincipal=user_arn, + Operations=["Decrypt", "Encrypt"], + ) + grant_token = response["GrantToken"] + + response = client.create_grant( + KeyId=key_arn_2, + GranteePrincipal=user_arn, + Operations=["Sign", "Verify"], + ) + grant_id = response["GrantId"] + + assert client.describe_key(KeyId=key_arn_1)["KeyMetadata"] + + assert client.get_key_rotation_status(KeyId=key_arn_1) + + assert client.get_public_key(KeyId=key_arn_1) + + assert client.list_grants(KeyId=key_arn_1)["Grants"] + + assert client.retire_grant(GrantToken=grant_token) + + assert client.revoke_grant(GrantId=grant_id, KeyId=key_arn_2) + + # And additionally, the following cryptographic operations: + # - Decrypt + # - Encrypt + # - GenerateDataKey + # - GenerateDataKeyPair + # - GenerateDataKeyPairWithoutPlaintext + # - GenerateDataKeyWithoutPlaintext + # - GenerateMac + # - ReEncrypt (NOT IMPLEMENTED IN LOCALSTACK) + # - Sign + # - Verify + # - VerifyMac + + assert client.generate_data_key(KeyId=key_arn_1) + + assert client.generate_data_key_without_plaintext(KeyId=key_arn_1) + + assert client.generate_data_key_pair(KeyId=key_arn_1, KeyPairSpec="RSA_2048") + + assert client.generate_data_key_pair_without_plaintext( + KeyId=key_arn_1, KeyPairSpec="RSA_2048" + ) + + plaintext = "hello" + ciphertext = client.encrypt(KeyId=key_arn_1, Plaintext="hello")["CiphertextBlob"] + + response = client.decrypt(CiphertextBlob=ciphertext, KeyId=key_arn_1) + assert plaintext == to_str(response["Plaintext"]) + + message = "world" + signature = client.sign( + KeyId=key_arn_2, + MessageType="RAW", + Message=message, + SigningAlgorithm="RSASSA_PKCS1_V1_5_SHA_256", + )["Signature"] + + assert client.verify( + KeyId=key_arn_2, + Signature=signature, + Message=message, + SigningAlgorithm="RSASSA_PKCS1_V1_5_SHA_256", + )["SignatureValid"] + + mac = client.generate_mac(KeyId=key_arn_3, Message=message, MacAlgorithm="HMAC_SHA_512")[ + "Mac" + ] + + assert client.verify_mac( + KeyId=key_arn_3, Message=message, MacAlgorithm="HMAC_SHA_512", Mac=mac + )["MacValid"]