Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions bindings/python/pymongocrypt/mongocrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

import copy

try:
from pymongo_auth_aws.auth import _aws_temp_credentials
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you open a new ticket to make _aws_temp_credentials a public API?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_HAVE_AUTH_AWS = True
except ImportError:
_HAVE_AUTH_AWS = False

from pymongocrypt.binary import (MongoCryptBinaryIn,
MongoCryptBinaryOut)
from pymongocrypt.binding import ffi, lib, _to_string
Expand Down Expand Up @@ -70,6 +76,15 @@ def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None,
outgoing commands. Set `bypass_query_analysis` to use explicit
encryption on indexed fields without the MongoDB Enterprise Advanced
licensed crypt_shared library.
- `crypt_shared_lib_path`: Optional string path to the crypt_shared
library.
- `crypt_shared_lib_required`: Whether to require a crypt_shared
library.
- `bypass_encryption`: Whether to bypass encryption.

.. versionadded:: 1.3
``crypt_shared_lib_path``, ``crypt_shared_lib_path``,
``bypass_encryption`` parameters.

.. versionadded:: 1.1
Support for "azure" and "gcp" kms_providers.
Expand All @@ -88,9 +103,10 @@ def __init__(self, kms_providers, schema_map=None, encrypted_fields_map=None,
aws = kms_providers["aws"]
if not isinstance(aws, dict):
raise ValueError("kms_providers['aws'] must be a dict")
if "accessKeyId" not in aws or "secretAccessKey" not in aws:
raise ValueError("kms_providers['aws'] must contain "
"'accessKeyId' and 'secretAccessKey'")
if len(aws):
if "accessKeyId" not in aws or "secretAccessKey" not in aws:
raise ValueError("kms_providers['aws'] must contain "
"'accessKeyId' and 'secretAccessKey'")

if 'azure' in kms_providers:
azure = kms_providers["azure"]
Expand Down Expand Up @@ -239,6 +255,9 @@ def __init(self):
if not self.__opts.bypass_encryption:
lib.mongocrypt_setopt_append_crypt_shared_lib_search_path(self.__crypt, b"$SYSTEM")

if 'aws' in kms_providers and not len(kms_providers['aws']):
lib.mongocrypt_setopt_use_need_kms_credentials_state(self.__crypt)

if not lib.mongocrypt_init(self.__crypt):
self.__raise_from_status()

Expand Down Expand Up @@ -292,7 +311,7 @@ def encryption_context(self, database, command):
:Returns:
A :class:`EncryptionContext`.
"""
return EncryptionContext(self._create_context(), database, command)
return EncryptionContext(self._create_context(), self.__opts.kms_providers, database, command)

def decryption_context(self, command):
"""Creates a context to use for decryption.
Expand All @@ -303,7 +322,7 @@ def decryption_context(self, command):
:Returns:
A :class:`DecryptionContext`.
"""
return DecryptionContext(self._create_context(), command)
return DecryptionContext(self._create_context(), self.__opts.kms_providers, command)

def explicit_encryption_context(self, value, opts):
"""Creates a context to use for explicit encryption.
Expand All @@ -316,7 +335,8 @@ def explicit_encryption_context(self, value, opts):
:Returns:
A :class:`ExplicitEncryptionContext`.
"""
return ExplicitEncryptionContext(self._create_context(), value, opts)
return ExplicitEncryptionContext(self._create_context(),
self.__opts.kms_providers, value, opts)

def explicit_decryption_context(self, value):
"""Creates a context to use for explicit decryption.
Expand All @@ -328,7 +348,8 @@ def explicit_decryption_context(self, value):
:Returns:
A :class:`ExplicitDecryptionContext`.
"""
return ExplicitDecryptionContext(self._create_context(), value)
return ExplicitDecryptionContext(self._create_context(),
self.__opts.kms_providers, value)

def data_key_context(self, kms_provider, opts=None):
"""Creates a context to use for key generation.
Expand All @@ -340,7 +361,7 @@ def data_key_context(self, kms_provider, opts=None):
:Returns:
A :class:`DataKeyContext`.
"""
return DataKeyContext(self._create_context(), kms_provider, opts,
return DataKeyContext(self._create_context(), self.__opts.kms_providers, kms_provider, opts,
self.__callback)

def rewrap_many_data_key_context(self, filter, provider, master_key):
Expand All @@ -357,21 +378,22 @@ def rewrap_many_data_key_context(self, filter, provider, master_key):
:Returns:
A :class:`RewrapManyDataKeyContext`.
"""
return RewrapManyDataKeyContext(self._create_context(), filter, provider, master_key, self.__callback)
return RewrapManyDataKeyContext(self._create_context(), self.__opts.kms_providers, filter, provider, master_key, self.__callback)


class MongoCryptContext(object):
__slots__ = ("__ctx",)
__slots__ = ("__ctx", "__kms_providers")

def __init__(self, ctx):
def __init__(self, ctx, kms_providers):
"""Abstracts libmongocrypt's mongocrypt_ctx_t type.

:Parameters:
- `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership
of the underlying mongocrypt_ctx_t.
- `database`: Optional, the name of the database.
- `kms_providers`: The KMS provider map.
"""
self.__ctx = ctx
self.__kms_providers = kms_providers

def _close(self):
"""Cleanup resources."""
Expand Down Expand Up @@ -422,6 +444,16 @@ def complete_mongo_operation(self):
if not lib.mongocrypt_ctx_mongo_done(self.__ctx):
self._raise_from_status()

def ask_for_kms_credentials(self):
"""Get on-demand kms credentials"""
return _ask_for_kms_credentials(self.__kms_providers)

def provide_kms_providers(self, providers):
"""Provide a map of KMS providers."""
with MongoCryptBinaryIn(providers) as binary:
if not lib.mongocrypt_ctx_provide_kms_providers(self.__ctx, binary.bin):
self._raise_from_status()

def kms_contexts(self):
"""Yields the MongoCryptKmsContexts."""
ctx = lib.mongocrypt_ctx_next_kms_ctx(self.__ctx)
Expand All @@ -445,16 +477,17 @@ def finish(self):
class EncryptionContext(MongoCryptContext):
__slots__ = ("database",)

def __init__(self, ctx, database, command):
def __init__(self, ctx, kms_providers, database, command):
"""Abstracts libmongocrypt's mongocrypt_ctx_t type.

:Parameters:
- `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership
of the underlying mongocrypt_ctx_t.
- `kms_providers`: The KMS provider map.
- `database`: Optional, the name of the database.
- `command`: The BSON command to encrypt.
"""
super(EncryptionContext, self).__init__(ctx)
super(EncryptionContext, self).__init__(ctx, kms_providers)
self.database = database
try:
with MongoCryptBinaryIn(command) as binary:
Expand All @@ -471,15 +504,16 @@ def __init__(self, ctx, database, command):
class DecryptionContext(MongoCryptContext):
__slots__ = ()

def __init__(self, ctx, command):
def __init__(self, ctx, kms_providers, command):
"""Abstracts libmongocrypt's mongocrypt_ctx_t type.

:Parameters:
- `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership
of the underlying mongocrypt_ctx_t.
- `kms_providers`: The KMS provider map.
- `command`: The encoded BSON command to decrypt.
"""
super(DecryptionContext, self).__init__(ctx)
super(DecryptionContext, self).__init__(ctx, kms_providers)
try:
with MongoCryptBinaryIn(command) as binary:
if not lib.mongocrypt_ctx_decrypt_init(ctx, binary.bin):
Expand All @@ -493,17 +527,18 @@ def __init__(self, ctx, command):
class ExplicitEncryptionContext(MongoCryptContext):
__slots__ = ()

def __init__(self, ctx, value, opts):
def __init__(self, ctx, kms_providers, value, opts):
"""Abstracts libmongocrypt's mongocrypt_ctx_t type.

:Parameters:
- `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership
of the underlying mongocrypt_ctx_t.
- `kms_providers`: The KMS provider map.
- `value`: The encoded document to encrypt, which must be in the
form { "v" : BSON value to encrypt }}.
- `opts`: A :class:`ExplicitEncryptOpts`.
"""
super(ExplicitEncryptionContext, self).__init__(ctx)
super(ExplicitEncryptionContext, self).__init__(ctx, kms_providers)
try:
algorithm = str_to_bytes(opts.algorithm)
if not lib.mongocrypt_ctx_setopt_algorithm(ctx, algorithm, -1):
Expand Down Expand Up @@ -540,15 +575,16 @@ def __init__(self, ctx, value, opts):
class ExplicitDecryptionContext(MongoCryptContext):
__slots__ = ()

def __init__(self, ctx, value):
def __init__(self, ctx, kms_providers, value):
"""Abstracts libmongocrypt's mongocrypt_ctx_t type.

:Parameters:
- `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership
of the underlying mongocrypt_ctx_t.
- `kms_providers`: The KMS provider map.
- `value`: The encoded BSON value to decrypt.
"""
super(ExplicitDecryptionContext, self).__init__(ctx)
super(ExplicitDecryptionContext, self).__init__(ctx, kms_providers)

try:
with MongoCryptBinaryIn(value) as binary:
Expand All @@ -564,17 +600,18 @@ def __init__(self, ctx, value):
class DataKeyContext(MongoCryptContext):
__slots__ = ()

def __init__(self, ctx, kms_provider, opts, callback):
def __init__(self, ctx, kms_providers, kms_provider, opts, callback):
"""Abstracts libmongocrypt's mongocrypt_ctx_t type.

:Parameters:
- `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership
of the underlying mongocrypt_ctx_t.
- `kms_providers`: The KMS provider map.
- `kms_provider`: The KMS provider.
- `opts`: An optional class:`DataKeyOpts`.
- `callback`: A :class:`MongoCryptCallback`.
"""
super(DataKeyContext, self).__init__(ctx)
super(DataKeyContext, self).__init__(ctx, kms_providers)
try:
if kms_provider not in ['aws', 'gcp', 'azure', 'kmip', 'local']:
raise ValueError('unknown kms_provider: %s' % (kms_provider,))
Expand Down Expand Up @@ -719,18 +756,20 @@ def __raise_from_status(self):
class RewrapManyDataKeyContext(MongoCryptContext):
__slots__ = ()

def __init__(self, ctx, filter, provider, master_key, callback):
def __init__(self, ctx, kms_providers, filter, provider, master_key,
callback):
"""Abstracts libmongocrypt's mongocrypt_ctx_t type.

:Parameters:
- `ctx`: A mongocrypt_ctx_t. This MongoCryptContext takes ownership
of the underlying mongocrypt_ctx_t.
- `kms_providers`: The KMS provider map.
- `filter`: The filter to use when finding data keys to rewrap in the key vault collection..
- `provider`: (optional) The name of a different kms provider.
- `master_key`: Optional document for the given provider.
- `callback`: A :class:`MongoCryptCallback`.
"""
super(RewrapManyDataKeyContext, self).__init__(ctx)
super(RewrapManyDataKeyContext, self).__init__(ctx, kms_providers)
key_encryption_key_bson = None
if provider is not None:
data = dict(provider=provider)
Expand All @@ -753,3 +792,23 @@ def __init__(self, ctx, filter, provider, master_key, callback):
# Destroy the context on error.
self._close()
raise


def _ask_for_kms_credentials(kms_providers):
"""Get on-demand kms credentials.

This is a separate function so it can be overridden in unit tests."""
if 'aws' not in kms_providers:
return
if len(kms_providers['aws']):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we already provided non-empty aws creds to libmongocrypt is this state even possible to reach?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was in preparation for other types of on-demand credentials.

return
if not _HAVE_AUTH_AWS:
raise RuntimeError(
"MONGODB-AWS authentication requires pymongo-auth-aws: "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"MONGODB-AWS authentication" is not relevant here. Should say "on demand aws credentials require..."

Should we add pymongo-auth-aws to pymongo[encryption] too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"install with: python -m pip install 'pymongo[aws]'"
)
creds = _aws_temp_credentials()
creds_dict = {"accessKeyId": creds.username, "secretAccessKey": creds.password}
if creds.token:
creds_dict["sessionToken"] = creds.token
return { 'aws': creds_dict }
3 changes: 3 additions & 0 deletions bindings/python/pymongocrypt/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,8 @@ def run_state_machine(ctx, callback):
with kms_ctx:
callback.kms_request(kms_ctx)
ctx.complete_kms()
elif state == lib.MONGOCRYPT_CTX_NEED_KMS_CREDENTIALS:
creds = ctx.ask_for_kms_credentials()
ctx.provide_kms_providers(callback.bson_encode(creds))
else:
raise MongoCryptError('unknown state: %r' % (state,))
2 changes: 1 addition & 1 deletion bindings/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_tag(self):
keywords=["mongo", "mongodb", "pymongocrypt", "pymongo", "mongocrypt",
"bson"],
test_suite="test",
tests_require=["pymongo>=3.11"],
tests_require=["pymongo>=3.11", "mock;python_version=='2.7'",],
license="Apache License, Version 2.0",
python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*",
classifiers=[
Expand Down
6 changes: 6 additions & 0 deletions bindings/python/test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@
# deprecated assertRaisesRegexp, with a 'p'.
if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

try:
from unittest import mock
except ImportError: # python 2
import mock

28 changes: 25 additions & 3 deletions bindings/python/test/test_mongocrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import copy
import os
import sys
import uuid

import bson
from bson.raw_bson import RawBSONDocument
Expand All @@ -32,7 +31,7 @@

from pymongocrypt.auto_encrypter import AutoEncrypter
from pymongocrypt.binding import lib
from pymongocrypt.compat import unicode_type, safe_bytearray_or_base64, PY3
from pymongocrypt.compat import unicode_type, PY3
from pymongocrypt.errors import MongoCryptError
from pymongocrypt.explicit_encrypter import ExplicitEncrypter
from pymongocrypt.mongocrypt import (MongoCrypt,
Expand All @@ -41,7 +40,8 @@
MongoCryptOptions)
from pymongocrypt.state_machine import MongoCryptCallback

from test import unittest
from test import unittest, mock


# Data for testing libbmongocrypt binding.
DATA_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), 'data'))
Expand Down Expand Up @@ -86,6 +86,7 @@ def test_mongocrypt_options(self):
schema_map = bson_data('schema-map.json')
valid = [
({'local': {'key': b'1' * 96}}, None),
({ 'aws' : {} }, schema_map),
({'aws': {'accessKeyId': '', 'secretAccessKey': ''}}, schema_map),
({'aws': {'accessKeyId': 'foo', 'secretAccessKey': 'foo'}}, None),
({'aws': {'accessKeyId': 'foo', 'secretAccessKey': 'foo',
Expand Down Expand Up @@ -430,6 +431,27 @@ def test_decrypt(self):
json_data('command-reply.json'))
self.assertEqual(decrypted, bson_data('command-reply.json'))

def test_need_kms_credentials(self):
kms_providers = { 'aws': {} }
opts = MongoCryptOptions(kms_providers)
callback = MockCallback(
list_colls_result=bson_data('collection-info.json'),
mongocryptd_reply=bson_data('mongocryptd-reply.json'),
key_docs=[bson_data('key-document.json')],
kms_reply=http_data('kms-reply.txt'))
encrypter = AutoEncrypter(callback, opts)
self.addCleanup(encrypter.close)

with mock.patch("pymongocrypt.mongocrypt._ask_for_kms_credentials") as m:
m.return_value = { "aws": { "accessKeyId": "example", "secretAccessKey": "example"} }
decrypted = encrypter.decrypt(
bson_data('encrypted-command-reply.json'))
self.assertTrue(m.called)

self.assertEqual(bson.decode(decrypted, OPTS),
json_data('command-reply.json'))
self.assertEqual(decrypted, bson_data('command-reply.json'))


class KeyVaultCallback(MockCallback):
def __init__(self, kms_reply=None):
Expand Down