From 07793e6009c6c3aba1a9fe72cb2468b56823b064 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 10 Aug 2022 17:38:45 -0500 Subject: [PATCH 01/10] PYTHON-3256 Obtain AWS credentials for CSFLE in the same way as for MONGODB-AWS --- bindings/python/pymongocrypt/mongocrypt.py | 34 ++++++++++++++++--- bindings/python/pymongocrypt/state_machine.py | 15 ++++++++ bindings/python/test/test_mongocrypt.py | 23 ++++++++++++- 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/bindings/python/pymongocrypt/mongocrypt.py b/bindings/python/pymongocrypt/mongocrypt.py index f4ea197c5..939a4423f 100644 --- a/bindings/python/pymongocrypt/mongocrypt.py +++ b/bindings/python/pymongocrypt/mongocrypt.py @@ -35,7 +35,8 @@ class MongoCryptOptions(object): def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None, bypass_query_analysis=False, crypt_shared_lib_path=None, - crypt_shared_lib_required=False, bypass_encryption=False): + crypt_shared_lib_required=False, bypass_encryption=False, + use_need_kms_credentials_state=False): """Options for :class:`MongoCrypt`. :Parameters: @@ -70,6 +71,20 @@ def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None, outgoing commands. Set `bypass_query_analysis` to use explicit encryption on indexed fields without the MongoDB Enterprise Advanced licensed crypt_shared library. + - `crypt_shared_lib_path`: Optional string path to the crypt_shared + library. + - `crypt_shared_lib_required`: Whether to require a crypt_shared + library. + - `bypass_encryption`: Whether to bypass encryption. + - `use_need_kms_credentials_state`: Whether to enable on-demand + kms credential access. + + .. versionadded:: 1.4 + ``use_need_kms_credentials_state`` parameter. + + .. versionadded:: 1.3 + ``crypt_shared_lib_path``, ``crypt_shared_lib_path``, + ``bypass_encryption`` parameters. .. versionadded:: 1.1 Support for "azure" and "gcp" kms_providers. @@ -88,9 +103,10 @@ def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None, aws = kms_providers["aws"] if not isinstance(aws, dict): raise ValueError("kms_providers['aws'] must be a dict") - if "accessKeyId" not in aws or "secretAccessKey" not in aws: - raise ValueError("kms_providers['aws'] must contain " - "'accessKeyId' and 'secretAccessKey'") + if len(aws): + if "accessKeyId" not in aws or "secretAccessKey" not in aws: + raise ValueError("kms_providers['aws'] must contain " + "'accessKeyId' and 'secretAccessKey'") if 'azure' in kms_providers: azure = kms_providers["azure"] @@ -150,6 +166,7 @@ def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None, self.crypt_shared_lib_path = crypt_shared_lib_path self.crypt_shared_lib_required = crypt_shared_lib_required self.bypass_encryption = bypass_encryption + self.use_need_kms_credentials_state = use_need_kms_credentials_state class MongoCrypt(object): @@ -239,6 +256,9 @@ def __init(self): if not self.__opts.bypass_encryption: lib.mongocrypt_setopt_append_crypt_shared_lib_search_path(self.__crypt, b"$SYSTEM") + if self.__opts.use_need_kms_credentials_state: + lib.mongocrypt_setopt_use_need_kms_credentials_state(self.__crypt) + if not lib.mongocrypt_init(self.__crypt): self.__raise_from_status() @@ -422,6 +442,12 @@ def complete_mongo_operation(self): if not lib.mongocrypt_ctx_mongo_done(self.__ctx): self._raise_from_status() + def provide_kms_providers(self, providers): + """Provide a map of KMS providers.""" + with MongoCryptBinaryIn(providers) as binary: + if not lib.mongocrypt_ctx_provide_kms_providers(self.__ctx, binary.bin): + self._raise_from_status() + def kms_contexts(self): """Yields the MongoCryptKmsContexts.""" ctx = lib.mongocrypt_ctx_next_kms_ctx(self.__ctx) diff --git a/bindings/python/pymongocrypt/state_machine.py b/bindings/python/pymongocrypt/state_machine.py index 0f82f123d..216e12ba1 100644 --- a/bindings/python/pymongocrypt/state_machine.py +++ b/bindings/python/pymongocrypt/state_machine.py @@ -101,6 +101,15 @@ def bson_encode(self, doc): """ pass + @abstractmethod + def ask_for_kms_credentials(self): + """Return on-demand kms credentials. + + :Returns: + Map of KMS provider options. + """ + pass + @abstractmethod def close(self): """Release resources.""" @@ -149,5 +158,11 @@ def run_state_machine(ctx, callback): with kms_ctx: callback.kms_request(kms_ctx) ctx.complete_kms() + elif state == lib.MONGOCRYPT_CTX_NEED_KMS_CREDENTIALS: + creds = callback.ask_for_kms_credentials() + if not isinstance(creds, bytes): + creds = callback.bson_encode(creds) + ctx.provide_kms_providers(creds) + ctx.complete_mongo_operation() else: raise MongoCryptError('unknown state: %r' % (state,)) diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index 46a07be90..32cf6c8b1 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -39,7 +39,7 @@ MongoCryptBinaryIn, MongoCryptBinaryOut, MongoCryptOptions) -from pymongocrypt.state_machine import MongoCryptCallback +from pymongocrypt.state_machine import MongoCryptCallback, run_state_machine from test import unittest @@ -86,6 +86,7 @@ def test_mongocrypt_options(self): schema_map = bson_data('schema-map.json') valid = [ ({'local': {'key': b'1' * 96}}, None), + ({ 'aws' : {} }, schema_map), ({'aws': {'accessKeyId': '', 'secretAccessKey': ''}}, schema_map), ({'aws': {'accessKeyId': 'foo', 'secretAccessKey': 'foo'}}, None), ({'aws': {'accessKeyId': 'foo', 'secretAccessKey': 'foo', @@ -107,12 +108,14 @@ def test_mongocrypt_options(self): self.assertEqual(opts.schema_map, schema_map) self.assertIsNone(opts.encrypted_fields_map) self.assertFalse(opts.bypass_query_analysis) + self.assertFalse(opts.use_need_kms_credentials_state) encrypted_fields_map = bson_data('encrypted-field-config-map.json') opts = MongoCryptOptions(valid[0][0], schema_map, encrypted_fields_map=encrypted_fields_map, bypass_query_analysis=True) self.assertEqual(opts.encrypted_fields_map, encrypted_fields_map) self.assertTrue(opts.bypass_query_analysis) + self.assertFalse(opts.use_need_kms_credentials_state) def test_mongocrypt_options_validation(self): with self.assertRaisesRegex( @@ -367,6 +370,9 @@ def insert_data_key(self, data_key): def bson_encode(self, doc): return bson.encode(doc) + def ask_for_kms_credentials(self): + return { "aws": { "accessKeyId": "foo:", 'secretAccessKey': 'foo'}, 'local': {'key': b'\x00'*96} } + def close(self): pass @@ -430,6 +436,20 @@ def test_decrypt(self): json_data('command-reply.json')) self.assertEqual(decrypted, bson_data('command-reply.json')) + def test_need_kms_credentials(self): + kms_providers = { 'aws': {}, 'local': {'key': b'\x00'*96} } + opts = MongoCryptOptions(kms_providers, use_need_kms_credentials_state=True) + encrypter = AutoEncrypter(MockCallback( + list_colls_result=bson_data('collection-info.json'), + mongocryptd_reply=bson_data('mongocryptd-reply.json'), + key_docs=[bson_data('key-document.json')], + kms_reply=http_data('kms-reply.txt')), opts) + self.addCleanup(encrypter.close) + encrypted = encrypter.encrypt('test', bson_data('command.json')) + self.assertEqual(bson.decode(encrypted, OPTS), + json_data('encrypted-command.json')) + self.assertEqual(encrypted, bson_data('encrypted-command.json')) + class KeyVaultCallback(MockCallback): def __init__(self, kms_reply=None): @@ -595,6 +615,7 @@ def test_rewrap_many_data_key(self): raw_doc = RawBSONDocument(result) assert len(raw_doc['v']) == 2 + def read(filename, **kwargs): with open(os.path.join(DATA_DIR, filename), **kwargs) as fp: return fp.read() From 9abe87859ba28468fedc8b445dc87aa9c3f32825 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 10 Aug 2022 17:39:44 -0500 Subject: [PATCH 02/10] fix unused import --- bindings/python/test/test_mongocrypt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index 32cf6c8b1..b8b5c0a77 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -39,7 +39,7 @@ MongoCryptBinaryIn, MongoCryptBinaryOut, MongoCryptOptions) -from pymongocrypt.state_machine import MongoCryptCallback, run_state_machine +from pymongocrypt.state_machine import MongoCryptCallback from test import unittest From 0c9ed2f67b4c28f9ea3401e8fb9473009f9cc564 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 10 Aug 2022 18:10:52 -0500 Subject: [PATCH 03/10] fix test --- bindings/python/pymongocrypt/state_machine.py | 1 - bindings/python/test/test_mongocrypt.py | 21 ++++++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/bindings/python/pymongocrypt/state_machine.py b/bindings/python/pymongocrypt/state_machine.py index 216e12ba1..965db9bd6 100644 --- a/bindings/python/pymongocrypt/state_machine.py +++ b/bindings/python/pymongocrypt/state_machine.py @@ -163,6 +163,5 @@ def run_state_machine(ctx, callback): if not isinstance(creds, bytes): creds = callback.bson_encode(creds) ctx.provide_kms_providers(creds) - ctx.complete_mongo_operation() else: raise MongoCryptError('unknown state: %r' % (state,)) diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index b8b5c0a77..45e06e08a 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -350,6 +350,7 @@ def __init__(self, self.key_docs = key_docs self.kms_reply = kms_reply self.kms_endpoint = None + self.got_on_demand_credentials = False def kms_request(self, kms_context): self.kms_endpoint = kms_context.endpoint @@ -371,7 +372,8 @@ def bson_encode(self, doc): return bson.encode(doc) def ask_for_kms_credentials(self): - return { "aws": { "accessKeyId": "foo:", 'secretAccessKey': 'foo'}, 'local': {'key': b'\x00'*96} } + self.got_on_demand_credentials = True + return { "aws": { "accessKeyId": "example", "secretAccessKey": "example"} } def close(self): pass @@ -437,18 +439,21 @@ def test_decrypt(self): self.assertEqual(decrypted, bson_data('command-reply.json')) def test_need_kms_credentials(self): - kms_providers = { 'aws': {}, 'local': {'key': b'\x00'*96} } + kms_providers = { 'aws': {} } opts = MongoCryptOptions(kms_providers, use_need_kms_credentials_state=True) - encrypter = AutoEncrypter(MockCallback( + callback = MockCallback( list_colls_result=bson_data('collection-info.json'), mongocryptd_reply=bson_data('mongocryptd-reply.json'), key_docs=[bson_data('key-document.json')], - kms_reply=http_data('kms-reply.txt')), opts) + kms_reply=http_data('kms-reply.txt')) + encrypter = AutoEncrypter(callback, opts) self.addCleanup(encrypter.close) - encrypted = encrypter.encrypt('test', bson_data('command.json')) - self.assertEqual(bson.decode(encrypted, OPTS), - json_data('encrypted-command.json')) - self.assertEqual(encrypted, bson_data('encrypted-command.json')) + decrypted = encrypter.decrypt( + bson_data('encrypted-command-reply.json')) + self.assertEqual(bson.decode(decrypted, OPTS), + json_data('command-reply.json')) + self.assertEqual(decrypted, bson_data('command-reply.json')) + self.assertEqual(callback.got_on_demand_credentials, True) class KeyVaultCallback(MockCallback): From 8135d429a1223a818b9728ec97c1957db495a5a5 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Fri, 12 Aug 2022 13:58:07 -0500 Subject: [PATCH 04/10] address review --- bindings/python/test/test_mongocrypt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index 45e06e08a..392885d6f 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -350,6 +350,8 @@ def __init__(self, self.key_docs = key_docs self.kms_reply = kms_reply self.kms_endpoint = None + # Used to track whether we have fetched + # on demand credentials for testing purposes. self.got_on_demand_credentials = False def kms_request(self, kms_context): @@ -620,7 +622,6 @@ def test_rewrap_many_data_key(self): raw_doc = RawBSONDocument(result) assert len(raw_doc['v']) == 2 - def read(filename, **kwargs): with open(os.path.join(DATA_DIR, filename), **kwargs) as fp: return fp.read() From 92538282934915f94bb9ede25e1c071c3d4f80c9 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 24 Aug 2022 13:46:21 -0500 Subject: [PATCH 05/10] make on-demand kms providers self-contained --- bindings/python/pymongocrypt/mongocrypt.py | 90 ++++++++++++------- bindings/python/pymongocrypt/state_machine.py | 15 +--- bindings/python/test/test_mongocrypt.py | 19 ++-- 3 files changed, 71 insertions(+), 53 deletions(-) diff --git a/bindings/python/pymongocrypt/mongocrypt.py b/bindings/python/pymongocrypt/mongocrypt.py index 939a4423f..f6cc9f5f0 100644 --- a/bindings/python/pymongocrypt/mongocrypt.py +++ b/bindings/python/pymongocrypt/mongocrypt.py @@ -14,6 +14,12 @@ import copy +try: + from pymongo_auth_aws.auth import _aws_temp_credentials + _HAVE_AUTH_AWS = True +except ImportError: + _HAVE_AUTH_AWS = False + from pymongocrypt.binary import (MongoCryptBinaryIn, MongoCryptBinaryOut) from pymongocrypt.binding import ffi, lib, _to_string @@ -35,8 +41,7 @@ class MongoCryptOptions(object): def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None, bypass_query_analysis=False, crypt_shared_lib_path=None, - crypt_shared_lib_required=False, bypass_encryption=False, - use_need_kms_credentials_state=False): + crypt_shared_lib_required=False, bypass_encryption=False): """Options for :class:`MongoCrypt`. :Parameters: @@ -76,11 +81,6 @@ def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None, - `crypt_shared_lib_required`: Whether to require a crypt_shared library. - `bypass_encryption`: Whether to bypass encryption. - - `use_need_kms_credentials_state`: Whether to enable on-demand - kms credential access. - - .. versionadded:: 1.4 - ``use_need_kms_credentials_state`` parameter. .. versionadded:: 1.3 ``crypt_shared_lib_path``, ``crypt_shared_lib_path``, @@ -166,7 +166,6 @@ def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None, self.crypt_shared_lib_path = crypt_shared_lib_path self.crypt_shared_lib_required = crypt_shared_lib_required self.bypass_encryption = bypass_encryption - self.use_need_kms_credentials_state = use_need_kms_credentials_state class MongoCrypt(object): @@ -256,7 +255,7 @@ def __init(self): if not self.__opts.bypass_encryption: lib.mongocrypt_setopt_append_crypt_shared_lib_search_path(self.__crypt, b"$SYSTEM") - if self.__opts.use_need_kms_credentials_state: + if 'aws' in kms_providers and not len(kms_providers['aws']): lib.mongocrypt_setopt_use_need_kms_credentials_state(self.__crypt) if not lib.mongocrypt_init(self.__crypt): @@ -312,7 +311,7 @@ def encryption_context(self, database, command): :Returns: A :class:`EncryptionContext`. """ - return EncryptionContext(self._create_context(), database, command) + return EncryptionContext(self._create_context(), self.__opts.kms_providers, database, command) def decryption_context(self, command): """Creates a context to use for decryption. @@ -323,7 +322,7 @@ def decryption_context(self, command): :Returns: A :class:`DecryptionContext`. """ - return DecryptionContext(self._create_context(), command) + return DecryptionContext(self._create_context(), self.__opts.kms_providers, command) def explicit_encryption_context(self, value, opts): """Creates a context to use for explicit encryption. @@ -336,7 +335,8 @@ def explicit_encryption_context(self, value, opts): :Returns: A :class:`ExplicitEncryptionContext`. """ - return ExplicitEncryptionContext(self._create_context(), value, opts) + return ExplicitEncryptionContext(self._create_context(), + self.__opts.kms_providers, value, opts) def explicit_decryption_context(self, value): """Creates a context to use for explicit decryption. @@ -348,7 +348,8 @@ def explicit_decryption_context(self, value): :Returns: A :class:`ExplicitDecryptionContext`. """ - return ExplicitDecryptionContext(self._create_context(), value) + return ExplicitDecryptionContext(self._create_context(), + self.__opts.kms_providers, value) def data_key_context(self, kms_provider, opts=None): """Creates a context to use for key generation. @@ -360,7 +361,7 @@ def data_key_context(self, kms_provider, opts=None): :Returns: A :class:`DataKeyContext`. """ - return DataKeyContext(self._create_context(), kms_provider, opts, + return DataKeyContext(self._create_context(), self.__opts.kms_providers, kms_provider, opts, self.__callback) def rewrap_many_data_key_context(self, filter, provider, master_key): @@ -377,21 +378,22 @@ def rewrap_many_data_key_context(self, filter, provider, master_key): :Returns: A :class:`RewrapManyDataKeyContext`. """ - return RewrapManyDataKeyContext(self._create_context(), filter, provider, master_key, self.__callback) + return RewrapManyDataKeyContext(self._create_context(), self.__opts.kms_providers, filter, provider, master_key, self.__callback) class MongoCryptContext(object): - __slots__ = ("__ctx",) + __slots__ = ("__ctx", "__kms_providers") - def __init__(self, ctx): + def __init__(self, ctx, kms_providers): """Abstracts libmongocrypt's mongocrypt_ctx_t type. :Parameters: - `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership of the underlying mongocrypt_ctx_t. - - `database`: Optional, the name of the database. + - `kms_providers`: The KMS provider map. """ self.__ctx = ctx + self.__kms_providers = kms_providers def _close(self): """Cleanup resources.""" @@ -442,6 +444,10 @@ def complete_mongo_operation(self): if not lib.mongocrypt_ctx_mongo_done(self.__ctx): self._raise_from_status() + def ask_for_kms_credentials(self): + """Get on-demand kms credentials""" + return _ask_for_kms_credentials(self.__kms_providers) + def provide_kms_providers(self, providers): """Provide a map of KMS providers.""" with MongoCryptBinaryIn(providers) as binary: @@ -471,16 +477,17 @@ def finish(self): class EncryptionContext(MongoCryptContext): __slots__ = ("database",) - def __init__(self, ctx, database, command): + def __init__(self, ctx, kms_providers, database, command): """Abstracts libmongocrypt's mongocrypt_ctx_t type. :Parameters: - `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership of the underlying mongocrypt_ctx_t. + - `kms_providers`: The KMS provider map. - `database`: Optional, the name of the database. - `command`: The BSON command to encrypt. """ - super(EncryptionContext, self).__init__(ctx) + super(EncryptionContext, self).__init__(ctx, kms_providers) self.database = database try: with MongoCryptBinaryIn(command) as binary: @@ -497,15 +504,16 @@ def __init__(self, ctx, database, command): class DecryptionContext(MongoCryptContext): __slots__ = () - def __init__(self, ctx, command): + def __init__(self, ctx, kms_providers, command): """Abstracts libmongocrypt's mongocrypt_ctx_t type. :Parameters: - `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership of the underlying mongocrypt_ctx_t. + - `kms_providers`: The KMS provider map. - `command`: The encoded BSON command to decrypt. """ - super(DecryptionContext, self).__init__(ctx) + super(DecryptionContext, self).__init__(ctx, kms_providers) try: with MongoCryptBinaryIn(command) as binary: if not lib.mongocrypt_ctx_decrypt_init(ctx, binary.bin): @@ -519,17 +527,18 @@ def __init__(self, ctx, command): class ExplicitEncryptionContext(MongoCryptContext): __slots__ = () - def __init__(self, ctx, value, opts): + def __init__(self, ctx, kms_providers, value, opts): """Abstracts libmongocrypt's mongocrypt_ctx_t type. :Parameters: - `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership of the underlying mongocrypt_ctx_t. + - `kms_providers`: The KMS provider map. - `value`: The encoded document to encrypt, which must be in the form { "v" : BSON value to encrypt }}. - `opts`: A :class:`ExplicitEncryptOpts`. """ - super(ExplicitEncryptionContext, self).__init__(ctx) + super(ExplicitEncryptionContext, self).__init__(ctx, kms_providers) try: algorithm = str_to_bytes(opts.algorithm) if not lib.mongocrypt_ctx_setopt_algorithm(ctx, algorithm, -1): @@ -566,15 +575,16 @@ def __init__(self, ctx, value, opts): class ExplicitDecryptionContext(MongoCryptContext): __slots__ = () - def __init__(self, ctx, value): + def __init__(self, ctx, kms_providers, value): """Abstracts libmongocrypt's mongocrypt_ctx_t type. :Parameters: - `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership of the underlying mongocrypt_ctx_t. + - `kms_providers`: The KMS provider map. - `value`: The encoded BSON value to decrypt. """ - super(ExplicitDecryptionContext, self).__init__(ctx) + super(ExplicitDecryptionContext, self).__init__(ctx, kms_providers) try: with MongoCryptBinaryIn(value) as binary: @@ -590,17 +600,18 @@ def __init__(self, ctx, value): class DataKeyContext(MongoCryptContext): __slots__ = () - def __init__(self, ctx, kms_provider, opts, callback): + def __init__(self, ctx, kms_providers, kms_provider, opts, callback): """Abstracts libmongocrypt's mongocrypt_ctx_t type. :Parameters: - `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership of the underlying mongocrypt_ctx_t. + - `kms_providers`: The KMS provider map. - `kms_provider`: The KMS provider. - `opts`: An optional class:`DataKeyOpts`. - `callback`: A :class:`MongoCryptCallback`. """ - super(DataKeyContext, self).__init__(ctx) + super(DataKeyContext, self).__init__(ctx, kms_providers) try: if kms_provider not in ['aws', 'gcp', 'azure', 'kmip', 'local']: raise ValueError('unknown kms_provider: %s' % (kms_provider,)) @@ -745,18 +756,20 @@ def __raise_from_status(self): class RewrapManyDataKeyContext(MongoCryptContext): __slots__ = () - def __init__(self, ctx, filter, provider, master_key, callback): + def __init__(self, ctx, kms_providers, filter, provider, master_key, + callback): """Abstracts libmongocrypt's mongocrypt_ctx_t type. :Parameters: - `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership of the underlying mongocrypt_ctx_t. + - `kms_providers`: The KMS provider map. - `filter`: The filter to use when finding data keys to rewrap in the key vault collection.. - `provider`: (optional) The name of a different kms provider. - `master_key`: Optional document for the given provider. - `callback`: A :class:`MongoCryptCallback`. """ - super(RewrapManyDataKeyContext, self).__init__(ctx) + super(RewrapManyDataKeyContext, self).__init__(ctx, kms_providers) key_encryption_key_bson = None if provider is not None: data = dict(provider=provider) @@ -779,3 +792,20 @@ def __init__(self, ctx, filter, provider, master_key, callback): # Destroy the context on error. self._close() raise + + +def _ask_for_kms_credentials(kms_providers): + if 'aws' not in kms_providers: + return + if len(kms_providers['aws']): + return + if not _HAVE_AUTH_AWS: + raise RuntimeError( + "MONGODB-AWS authentication requires pymongo-auth-aws: " + "install with: python -m pip install 'pymongo[aws]'" + ) + creds = _aws_temp_credentials() + creds_dict = {"accessKeyId": creds.username, "secretAccessKey": creds.password} + if creds.token: + creds_dict["sessionToken"] = creds.token + return { 'aws': creds_dict } diff --git a/bindings/python/pymongocrypt/state_machine.py b/bindings/python/pymongocrypt/state_machine.py index 965db9bd6..555896e9e 100644 --- a/bindings/python/pymongocrypt/state_machine.py +++ b/bindings/python/pymongocrypt/state_machine.py @@ -101,15 +101,6 @@ def bson_encode(self, doc): """ pass - @abstractmethod - def ask_for_kms_credentials(self): - """Return on-demand kms credentials. - - :Returns: - Map of KMS provider options. - """ - pass - @abstractmethod def close(self): """Release resources.""" @@ -159,9 +150,7 @@ def run_state_machine(ctx, callback): callback.kms_request(kms_ctx) ctx.complete_kms() elif state == lib.MONGOCRYPT_CTX_NEED_KMS_CREDENTIALS: - creds = callback.ask_for_kms_credentials() - if not isinstance(creds, bytes): - creds = callback.bson_encode(creds) - ctx.provide_kms_providers(creds) + creds = ctx.ask_for_kms_credentials() + ctx.provide_kms_providers(callback.bson_encode(creds)) else: raise MongoCryptError('unknown state: %r' % (state,)) diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index 392885d6f..d164e656b 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -42,6 +42,7 @@ from pymongocrypt.state_machine import MongoCryptCallback from test import unittest +from unittest import mock # Data for testing libbmongocrypt binding. DATA_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), 'data')) @@ -108,14 +109,12 @@ def test_mongocrypt_options(self): self.assertEqual(opts.schema_map, schema_map) self.assertIsNone(opts.encrypted_fields_map) self.assertFalse(opts.bypass_query_analysis) - self.assertFalse(opts.use_need_kms_credentials_state) encrypted_fields_map = bson_data('encrypted-field-config-map.json') opts = MongoCryptOptions(valid[0][0], schema_map, encrypted_fields_map=encrypted_fields_map, bypass_query_analysis=True) self.assertEqual(opts.encrypted_fields_map, encrypted_fields_map) self.assertTrue(opts.bypass_query_analysis) - self.assertFalse(opts.use_need_kms_credentials_state) def test_mongocrypt_options_validation(self): with self.assertRaisesRegex( @@ -373,10 +372,6 @@ def insert_data_key(self, data_key): def bson_encode(self, doc): return bson.encode(doc) - def ask_for_kms_credentials(self): - self.got_on_demand_credentials = True - return { "aws": { "accessKeyId": "example", "secretAccessKey": "example"} } - def close(self): pass @@ -442,7 +437,7 @@ def test_decrypt(self): def test_need_kms_credentials(self): kms_providers = { 'aws': {} } - opts = MongoCryptOptions(kms_providers, use_need_kms_credentials_state=True) + opts = MongoCryptOptions(kms_providers) callback = MockCallback( list_colls_result=bson_data('collection-info.json'), mongocryptd_reply=bson_data('mongocryptd-reply.json'), @@ -450,12 +445,16 @@ def test_need_kms_credentials(self): kms_reply=http_data('kms-reply.txt')) encrypter = AutoEncrypter(callback, opts) self.addCleanup(encrypter.close) - decrypted = encrypter.decrypt( - bson_data('encrypted-command-reply.json')) + + with mock.patch("pymongocrypt.mongocrypt._ask_for_kms_credentials") as m: + m.return_value = { "aws": { "accessKeyId": "example", "secretAccessKey": "example"} } + decrypted = encrypter.decrypt( + bson_data('encrypted-command-reply.json')) + self.assertTrue(m.called) + self.assertEqual(bson.decode(decrypted, OPTS), json_data('command-reply.json')) self.assertEqual(decrypted, bson_data('command-reply.json')) - self.assertEqual(callback.got_on_demand_credentials, True) class KeyVaultCallback(MockCallback): From ff1672dc2f308855430879f136fa3155ced54b86 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 24 Aug 2022 13:50:47 -0500 Subject: [PATCH 06/10] cleanup --- bindings/python/pymongocrypt/mongocrypt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bindings/python/pymongocrypt/mongocrypt.py b/bindings/python/pymongocrypt/mongocrypt.py index f6cc9f5f0..371696357 100644 --- a/bindings/python/pymongocrypt/mongocrypt.py +++ b/bindings/python/pymongocrypt/mongocrypt.py @@ -795,8 +795,11 @@ def __init__(self, ctx, kms_providers, filter, provider, master_key, def _ask_for_kms_credentials(kms_providers): + """Get on-demand kms credentials. + + This is a separate function so it can be overridden in unit tests.""" if 'aws' not in kms_providers: - return + return if len(kms_providers['aws']): return if not _HAVE_AUTH_AWS: @@ -808,4 +811,4 @@ def _ask_for_kms_credentials(kms_providers): creds_dict = {"accessKeyId": creds.username, "secretAccessKey": creds.password} if creds.token: creds_dict["sessionToken"] = creds.token - return { 'aws': creds_dict } + return { 'aws': creds_dict } From 8192c15a21abd31a402ccd23638fbee1bff5713f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 24 Aug 2022 13:51:54 -0500 Subject: [PATCH 07/10] cleanup --- bindings/python/test/test_mongocrypt.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index d164e656b..f6cf39109 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -349,9 +349,6 @@ def __init__(self, self.key_docs = key_docs self.kms_reply = kms_reply self.kms_endpoint = None - # Used to track whether we have fetched - # on demand credentials for testing purposes. - self.got_on_demand_credentials = False def kms_request(self, kms_context): self.kms_endpoint = kms_context.endpoint From ce5aab8e5df976f997c39b4655f67021801b9e23 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 24 Aug 2022 15:08:16 -0500 Subject: [PATCH 08/10] fix python 2 compat --- bindings/python/test/test_mongocrypt.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index f6cf39109..971304062 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -42,7 +42,12 @@ from pymongocrypt.state_machine import MongoCryptCallback from test import unittest -from unittest import mock + +try: + from unittest import mock +except ImportError: # python 2 + import mock + # Data for testing libbmongocrypt binding. DATA_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), 'data')) From 7ba5b05bb034486a0f0429490b4b7d3d7ddbe45c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 24 Aug 2022 18:06:37 -0500 Subject: [PATCH 09/10] try to fix py27 compat --- bindings/python/test/__init__.py | 6 ++++++ bindings/python/test/test_mongocrypt.py | 10 ++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bindings/python/test/__init__.py b/bindings/python/test/__init__.py index bf3664578..93435e213 100644 --- a/bindings/python/test/__init__.py +++ b/bindings/python/test/__init__.py @@ -29,3 +29,9 @@ # deprecated assertRaisesRegexp, with a 'p'. if not hasattr(unittest.TestCase, 'assertRaisesRegex'): unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp + +try: + from unittest import mock +except ImportError: # python 2 + import mock + diff --git a/bindings/python/test/test_mongocrypt.py b/bindings/python/test/test_mongocrypt.py index 971304062..e601a51c3 100644 --- a/bindings/python/test/test_mongocrypt.py +++ b/bindings/python/test/test_mongocrypt.py @@ -18,7 +18,6 @@ import copy import os import sys -import uuid import bson from bson.raw_bson import RawBSONDocument @@ -32,7 +31,7 @@ from pymongocrypt.auto_encrypter import AutoEncrypter from pymongocrypt.binding import lib -from pymongocrypt.compat import unicode_type, safe_bytearray_or_base64, PY3 +from pymongocrypt.compat import unicode_type, PY3 from pymongocrypt.errors import MongoCryptError from pymongocrypt.explicit_encrypter import ExplicitEncrypter from pymongocrypt.mongocrypt import (MongoCrypt, @@ -41,12 +40,7 @@ MongoCryptOptions) from pymongocrypt.state_machine import MongoCryptCallback -from test import unittest - -try: - from unittest import mock -except ImportError: # python 2 - import mock +from test import unittest, mock # Data for testing libbmongocrypt binding. From 88eadaed4e45c809bdc101d1e879a04bf4147056 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 24 Aug 2022 20:32:44 -0500 Subject: [PATCH 10/10] require mock --- bindings/python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindings/python/setup.py b/bindings/python/setup.py index 792f8d3a1..d6e9cfe36 100644 --- a/bindings/python/setup.py +++ b/bindings/python/setup.py @@ -57,7 +57,7 @@ def get_tag(self): keywords=["mongo", "mongodb", "pymongocrypt", "pymongo", "mongocrypt", "bson"], test_suite="test", - tests_require=["pymongo>=3.11"], + tests_require=["pymongo>=3.11", "mock;python_version=='2.7'",], license="Apache License, Version 2.0", python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*", classifiers=[