Skip to content

PYTHON-3803 add types to encryption.py #1296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions bson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
.. [#bytes] The bytes type is encoded as BSON binary with
subtype 0. It will be decoded back to bytes.
"""

import datetime
import itertools
import os
Expand Down Expand Up @@ -84,6 +83,7 @@
TypeVar,
Union,
cast,
overload,
)

from bson.binary import (
Expand Down Expand Up @@ -1025,9 +1025,21 @@ def encode(
return _dict_to_bson(document, check_keys, codec_options)


@overload
def decode(data: "_ReadableBuffer", codec_options: None = None) -> Dict[str, Any]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is great! I think we want to apply this same overload pattern throughout the codebase but I think that's too much for this PR. Could you open a new ticket for it? Eg decode_all, decode_iter, and anywhere else we use _DocumentType.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, made the ticket

...


@overload
def decode(
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]"
) -> "_DocumentType":
...


def decode(
data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None
) -> Union[Dict[str, Any], "_DocumentType"]:
"""Decode BSON to a document.

By default, returns a BSON document represented as a Python
Expand Down
4 changes: 2 additions & 2 deletions bson/raw_bson.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
overhead of decoding or encoding BSON.
"""

from typing import Any, ItemsView, Iterator, Mapping, Optional
from typing import Any, Dict, ItemsView, Iterator, Mapping, Optional

from bson import _get_object_size, _raw_to_dict
from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER
Expand All @@ -62,7 +62,7 @@

def _inflate_bson(
bson_bytes: bytes, codec_options: CodecOptions, raw_array: bool = False
) -> Mapping[Any, Any]:
) -> Dict[Any, Any]:
"""Inflates the top level fields of a BSON document.

:Parameters:
Expand Down
2 changes: 1 addition & 1 deletion pymongo/auth_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def run_command(
except OperationFailure as exc:
self.clear()
if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
if "jwt" in bson.decode(cmd["payload"]): # type:ignore[attr-defined]
if "jwt" in bson.decode(cmd["payload"]):
if self.idp_info_gen_id > self.reauth_gen_id:
raise
return self.authenticate(sock_info, reauthenticate=True)
Expand Down
99 changes: 67 additions & 32 deletions pymongo/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,24 @@
# limitations under the License.

"""Support for explicit client-side field level encryption."""
from __future__ import annotations

import contextlib
import enum
import socket
import weakref
from copy import deepcopy
from typing import Any, Generic, Mapping, Optional, Sequence, Tuple
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterator,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
)

try:
from pymongocrypt.auto_encrypter import AutoEncrypter
Expand Down Expand Up @@ -65,6 +76,11 @@
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern

if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext

from pymongo.response import Response

_HTTPS_PORT = 443
_KMS_CONNECT_TIMEOUT = CONNECT_TIMEOUT # CDRIVER-3262 redefined this value to CONNECT_TIMEOUT
_MONGOCRYPTD_TIMEOUT_MS = 10000
Expand All @@ -77,7 +93,7 @@


@contextlib.contextmanager
def _wrap_encryption_errors():
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
try:
yield
Expand All @@ -89,16 +105,22 @@ def _wrap_encryption_errors():
raise EncryptionError(exc)


class _EncryptionIO(MongoCryptCallback): # type: ignore
def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
def __init__(
self,
client: Optional[MongoClient],
key_vault_coll: Collection,
mongocryptd_client: Optional[MongoClient],
opts: AutoEncryptionOpts,
):
"""Internal class to perform I/O on behalf of pymongocrypt."""
self.client_ref: Any
# Use a weak ref to break reference cycle.
if client is not None:
self.client_ref = weakref.ref(client)
else:
self.client_ref = None
self.key_vault_coll = key_vault_coll.with_options(
self.key_vault_coll: Optional[Collection] = key_vault_coll.with_options(
codec_options=_KEY_VAULT_OPTS,
read_concern=ReadConcern(level="majority"),
write_concern=WriteConcern(w="majority"),
Expand All @@ -107,7 +129,7 @@ def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
self.opts = opts
self._spawned = False

def kms_request(self, kms_context):
def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
"""Complete a KMS request.

:Parameters:
Expand Down Expand Up @@ -161,7 +183,7 @@ def kms_request(self, kms_context):
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

def collection_info(self, database, filter):
def collection_info(self, database: Database, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.

The returned collection info is passed to libmongocrypt which reads
Expand All @@ -179,7 +201,7 @@ def collection_info(self, database, filter):
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
return None

def spawn(self):
def spawn(self) -> None:
"""Spawn mongocryptd.

Note this method is thread safe; at most one mongocryptd will start
Expand All @@ -190,7 +212,7 @@ def spawn(self):
args.extend(self.opts._mongocryptd_spawn_args)
_spawn_daemon(args)

def mark_command(self, database, cmd):
def mark_command(self, database: str, cmd: bytes) -> bytes:
"""Mark a command for encryption.

:Parameters:
Expand All @@ -205,6 +227,7 @@ def mark_command(self, database, cmd):
# Database.command only supports mutable mappings so we need to decode
# the raw BSON command first.
inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS)
assert self.mongocryptd_client is not None
try:
res = self.mongocryptd_client[database].command(
inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS
Expand All @@ -218,7 +241,7 @@ def mark_command(self, database, cmd):
)
return res.raw

def fetch_keys(self, filter):
def fetch_keys(self, filter: bytes) -> Iterator[bytes]:
"""Yields one or more keys from the key vault.

:Parameters:
Expand All @@ -227,11 +250,12 @@ def fetch_keys(self, filter):
:Returns:
A generator which yields the requested keys from the key vault.
"""
assert self.key_vault_coll is not None
with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor:
for key in cursor:
yield key.raw

def insert_data_key(self, data_key):
def insert_data_key(self, data_key: bytes) -> Binary:
"""Insert a data key into the key vault.

:Parameters:
Expand All @@ -245,10 +269,11 @@ def insert_data_key(self, data_key):
if not isinstance(data_key_id, Binary) or data_key_id.subtype != UUID_SUBTYPE:
raise TypeError("data_key _id must be Binary with a UUID subtype")

assert self.key_vault_coll is not None
self.key_vault_coll.insert_one(raw_doc)
return data_key_id

def bson_encode(self, doc):
def bson_encode(self, doc: MutableMapping[str, Any]) -> bytes:
"""Encode a document to BSON.

A document can be any mapping type (like :class:`dict`).
Expand All @@ -261,7 +286,7 @@ def bson_encode(self, doc):
"""
return encode(doc)

def close(self):
def close(self) -> None:
"""Release resources.

Note it is not safe to call this method from __del__ or any GC hooks.
Expand Down Expand Up @@ -300,7 +325,7 @@ class _Encrypter:
MongoDB commands.
"""

def __init__(self, client, opts):
def __init__(self, client: MongoClient, opts: AutoEncryptionOpts):
"""Create a _Encrypter for a client.

:Parameters:
Expand All @@ -319,7 +344,7 @@ def __init__(self, client, opts):
self._bypass_auto_encryption = opts._bypass_auto_encryption
self._internal_client = None

def _get_internal_client(encrypter, mongo_client):
def _get_internal_client(encrypter: _Encrypter, mongo_client: MongoClient) -> MongoClient:
if mongo_client.options.pool_options.max_pool_size is None:
# Unlimited pool size, use the same client.
return mongo_client
Expand Down Expand Up @@ -362,7 +387,9 @@ def _get_internal_client(encrypter, mongo_client):
)
self._closed = False

def encrypt(self, database, cmd, codec_options):
def encrypt(
self, database: Database, cmd: Mapping[str, Any], codec_options: CodecOptions
) -> Mapping[Any, Any]:
"""Encrypt a MongoDB command.

:Parameters:
Expand All @@ -381,7 +408,7 @@ def encrypt(self, database, cmd, codec_options):
encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
return encrypt_cmd

def decrypt(self, response):
def decrypt(self, response: Response) -> Optional[bytes]:
"""Decrypt a MongoDB command response.

:Parameters:
Expand All @@ -394,11 +421,11 @@ def decrypt(self, response):
with _wrap_encryption_errors():
return self._auto_encrypter.decrypt(response)

def _check_closed(self):
def _check_closed(self) -> None:
if self._closed:
raise InvalidOperation("Cannot use MongoClient after close")

def close(self):
def close(self) -> None:
"""Cleanup resources."""
self._closed = True
self._auto_encrypter.close()
Expand Down Expand Up @@ -733,15 +760,15 @@ def create_data_key(

def _encrypt_helper(
self,
value,
algorithm,
key_id=None,
key_alt_name=None,
query_type=None,
contention_factor=None,
range_opts=None,
is_expression=False,
):
value: Any,
algorithm: str,
key_id: Optional[Binary] = None,
key_alt_name: Optional[str] = None,
query_type: Optional[str] = None,
contention_factor: Optional[int] = None,
range_opts: Optional[RangeOpts] = None,
is_expression: bool = False,
) -> Any:
self._check_closed()
if key_id is not None and not (
isinstance(key_id, Binary) and key_id.subtype == UUID_SUBTYPE
Expand All @@ -752,8 +779,9 @@ def _encrypt_helper(
{"v": value},
codec_options=self._codec_options,
)
range_opts_bytes = None
if range_opts:
range_opts = encode(
range_opts_bytes = encode(
range_opts.document,
codec_options=self._codec_options,
)
Expand All @@ -765,10 +793,10 @@ def _encrypt_helper(
key_alt_name=key_alt_name,
query_type=query_type,
contention_factor=contention_factor,
range_opts=range_opts,
range_opts=range_opts_bytes,
is_expression=is_expression,
)
return decode(encrypted_doc)["v"] # type: ignore[index]
return decode(encrypted_doc)["v"]

def encrypt(
self,
Expand Down Expand Up @@ -897,6 +925,7 @@ def get_key(self, id: Binary) -> Optional[RawBSONDocument]:
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one({"_id": id})

def get_keys(self) -> Cursor[RawBSONDocument]:
Expand All @@ -909,6 +938,7 @@ def get_keys(self) -> Cursor[RawBSONDocument]:
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.find({})

def delete_key(self, id: Binary) -> DeleteResult:
Expand All @@ -925,6 +955,7 @@ def delete_key(self, id: Binary) -> DeleteResult:
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.delete_one({"_id": id})

def add_key_alt_name(self, id: Binary, key_alt_name: str) -> Any:
Expand All @@ -943,6 +974,7 @@ def add_key_alt_name(self, id: Binary, key_alt_name: str) -> Any:
"""
self._check_closed()
update = {"$addToSet": {"keyAltNames": key_alt_name}}
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one_and_update({"_id": id}, update)

def get_key_by_alt_name(self, key_alt_name: str) -> Optional[RawBSONDocument]:
Expand All @@ -957,6 +989,7 @@ def get_key_by_alt_name(self, key_alt_name: str) -> Optional[RawBSONDocument]:
.. versionadded:: 4.2
"""
self._check_closed()
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one({"keyAltNames": key_alt_name})

def remove_key_alt_name(self, id: Binary, key_alt_name: str) -> Optional[RawBSONDocument]:
Expand Down Expand Up @@ -994,6 +1027,7 @@ def remove_key_alt_name(self, id: Binary, key_alt_name: str) -> Optional[RawBSON
}
}
]
assert self._key_vault_coll is not None
return self._key_vault_coll.find_one_and_update({"_id": id}, pipeline)

def rewrap_many_data_key(
Expand Down Expand Up @@ -1052,6 +1086,7 @@ def rewrap_many_data_key(
replacements.append(op)
if not replacements:
return RewrapManyDataKeyResult()
assert self._key_vault_coll is not None
result = self._key_vault_coll.bulk_write(replacements)
return RewrapManyDataKeyResult(result)

Expand All @@ -1061,7 +1096,7 @@ def __enter__(self) -> "ClientEncryption":
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()

def _check_closed(self):
def _check_closed(self) -> None:
if self._encryption is None:
raise InvalidOperation("Cannot use closed ClientEncryption")

Expand Down
Loading