Skip to content

Commit

Permalink
PYTHON-3493 Bulk Write InsertOne Should Be Parameter Of Collection Ty…
Browse files Browse the repository at this point in the history
…pe (#1106)
  • Loading branch information
juliusgeo committed Nov 10, 2022
1 parent 133c55d commit 92e6150
Show file tree
Hide file tree
Showing 17 changed files with 144 additions and 38 deletions.
20 changes: 20 additions & 0 deletions doc/examples/type_hints.rst
Expand Up @@ -113,6 +113,26 @@ These methods automatically add an "_id" field.
>>> assert result is not None
>>> assert result["year"] == 1993
>>> # Accessing "_id" raises a type-checking error even though the field is present at runtime, because it is added by PyMongo.
>>> assert result["_id"] # type:ignore[typeddict-item]

This same typing scheme works for all of the insert methods (:meth:`~pymongo.collection.Collection.insert_one`,
:meth:`~pymongo.collection.Collection.insert_many`, and :meth:`~pymongo.collection.Collection.bulk_write`).
For ``bulk_write``, both the :class:`~pymongo.operations.InsertOne` and :class:`~pymongo.operations.ReplaceOne` operations are generic.

.. doctest::
:pyversion: >= 3.8

>>> from typing import TypedDict
>>> from pymongo import MongoClient
>>> from pymongo.operations import InsertOne
>>> from pymongo.collection import Collection
>>> client: MongoClient = MongoClient()
>>> collection: Collection[Movie] = client.test.test
>>> inserted = collection.bulk_write([InsertOne(Movie(name="Jurassic Park", year=1993))])
>>> result = collection.find_one({"name": "Jurassic Park"})
>>> assert result is not None
>>> assert result["year"] == 1993
>>> # Accessing "_id" raises a type-checking error even though the field is present at runtime, because it is added by PyMongo.
>>> assert result["_id"] # type:ignore[typeddict-item]

Modeling Document Types with TypedDict
Expand Down
2 changes: 1 addition & 1 deletion mypy.ini
Expand Up @@ -33,7 +33,7 @@ ignore_missing_imports = True
ignore_missing_imports = True

[mypy-test.test_mypy]
warn_unused_ignores = false
warn_unused_ignores = True

[mypy-winkerberos.*]
ignore_missing_imports = True
Expand Down
11 changes: 9 additions & 2 deletions pymongo/collection.py
Expand Up @@ -77,7 +77,14 @@
_FIND_AND_MODIFY_DOC_FIELDS = {"value": 1}


_WriteOp = Union[InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany]
_WriteOp = Union[
InsertOne[_DocumentType],
DeleteOne,
DeleteMany,
ReplaceOne[_DocumentType],
UpdateOne,
UpdateMany,
]
# Hint supports index name, "myIndex", or list of index pairs: [('x', 1), ('y', -1)]
_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
_IndexKeyHint = Union[str, _IndexList]
Expand Down Expand Up @@ -436,7 +443,7 @@ def with_options(
@_csot.apply
def bulk_write(
self,
requests: Sequence[_WriteOp],
requests: Sequence[_WriteOp[_DocumentType]],
ordered: bool = True,
bypass_document_validation: bool = False,
session: Optional["ClientSession"] = None,
Expand Down
5 changes: 3 additions & 2 deletions pymongo/encryption.py
Expand Up @@ -18,7 +18,7 @@
import enum
import socket
import weakref
from typing import Any, Mapping, Optional, Sequence
from typing import Any, Generic, Mapping, Optional, Sequence

try:
from pymongocrypt.auto_encrypter import AutoEncrypter
Expand Down Expand Up @@ -55,6 +55,7 @@
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
from pymongo.typings import _DocumentType
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern

Expand Down Expand Up @@ -430,7 +431,7 @@ class QueryType(str, enum.Enum):
"""Used to encrypt a value for an equality query."""


class ClientEncryption(object):
class ClientEncryption(Generic[_DocumentType]):
"""Explicit client-side field level encryption."""

def __init__(
Expand Down
13 changes: 7 additions & 6 deletions pymongo/operations.py
Expand Up @@ -13,21 +13,22 @@
# limitations under the License.

"""Operation class definitions."""
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Generic, List, Mapping, Optional, Sequence, Tuple, Union

from bson.raw_bson import RawBSONDocument
from pymongo import helpers
from pymongo.collation import validate_collation_or_none
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
from pymongo.helpers import _gen_index_name, _index_document, _index_list
from pymongo.typings import _CollationIn, _DocumentIn, _Pipeline
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline


class InsertOne(object):
class InsertOne(Generic[_DocumentType]):
"""Represents an insert_one operation."""

__slots__ = ("_doc",)

def __init__(self, document: _DocumentIn) -> None:
def __init__(self, document: Union[_DocumentType, RawBSONDocument]) -> None:
"""Create an InsertOne instance.
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
Expand Down Expand Up @@ -170,15 +171,15 @@ def __ne__(self, other: Any) -> bool:
return not self == other


class ReplaceOne(object):
class ReplaceOne(Generic[_DocumentType]):
"""Represents a replace_one operation."""

__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")

def __init__(
self,
filter: Mapping[str, Any],
replacement: Mapping[str, Any],
replacement: Union[_DocumentType, RawBSONDocument],
upsert: bool = False,
collation: Optional[_CollationIn] = None,
hint: Optional[_IndexKeyHint] = None,
Expand Down
7 changes: 7 additions & 0 deletions pymongo/typings.py
Expand Up @@ -37,3 +37,10 @@
_Pipeline = Sequence[Mapping[str, Any]]
_DocumentOut = _DocumentIn
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])


def strip_optional(elem):
    """Return ``elem`` unchanged after asserting that it is not ``None``.

    This lets a caller narrow each element of an iterable from
    ``Optional[_T]`` to ``_T`` inside an expression such as a list
    comprehension, where a standalone ``assert`` statement cannot be written.

    Raises :exc:`AssertionError` if ``elem`` is ``None``. The check is an
    ``assert`` statement, so it is not executed when Python runs with ``-O``.
    """
    assert elem is not None
    return elem
2 changes: 1 addition & 1 deletion test/__init__.py
Expand Up @@ -1090,7 +1090,7 @@ def print_thread_stacks(pid: int) -> None:
class IntegrationTest(PyMongoTestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""

client: MongoClient
client: MongoClient[dict]
db: Database
credentials: Dict[str, str]

Expand Down
2 changes: 1 addition & 1 deletion test/mockupdb/test_cluster_time.py
Expand Up @@ -60,7 +60,7 @@ def callback(client):
self.cluster_time_conversation(callback, [{"ok": 1}] * 2)

def test_bulk(self):
def callback(client):
def callback(client: MongoClient[dict]) -> None:
client.db.collection.bulk_write(
[InsertOne({}), InsertOne({}), UpdateOne({}, {"$inc": {"x": 1}}), DeleteMany({})]
)
Expand Down
6 changes: 3 additions & 3 deletions test/mockupdb/test_op_msg.py
Expand Up @@ -137,22 +137,22 @@
# Legacy methods
Operation(
"bulk_write_insert",
lambda coll: coll.bulk_write([InsertOne({}), InsertOne({})]),
lambda coll: coll.bulk_write([InsertOne[dict]({}), InsertOne[dict]({})]),
request=OpMsg({"insert": "coll"}, flags=0),
reply={"ok": 1, "n": 2},
),
Operation(
"bulk_write_insert-w0",
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
[InsertOne({}), InsertOne({})]
[InsertOne[dict]({}), InsertOne[dict]({})]
),
request=OpMsg({"insert": "coll"}, flags=0),
reply={"ok": 1, "n": 2},
),
Operation(
"bulk_write_insert-w0-unordered",
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
[InsertOne({}), InsertOne({})], ordered=False
[InsertOne[dict]({}), InsertOne[dict]({})], ordered=False
),
request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]),
reply=None,
Expand Down
4 changes: 2 additions & 2 deletions test/test_bulk.py
Expand Up @@ -296,7 +296,7 @@ def test_upsert(self):
def test_numerous_inserts(self):
# Ensure we don't exceed server's maxWriteBatchSize size limit.
n_docs = client_context.max_write_batch_size + 100
requests = [InsertOne({}) for _ in range(n_docs)]
requests = [InsertOne[dict]({}) for _ in range(n_docs)]
result = self.coll.bulk_write(requests, ordered=False)
self.assertEqual(n_docs, result.inserted_count)
self.assertEqual(n_docs, self.coll.count_documents({}))
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_bulk_write_no_results(self):

def test_bulk_write_invalid_arguments(self):
# The requests argument must be a list.
generator = (InsertOne({}) for _ in range(10))
generator = (InsertOne[dict]({}) for _ in range(10))
with self.assertRaises(TypeError):
self.coll.bulk_write(generator) # type: ignore[arg-type]

Expand Down
1 change: 1 addition & 0 deletions test/test_client.py
Expand Up @@ -1652,6 +1652,7 @@ def test_network_error_message(self):
with self.fail_point(
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
):
assert client.address is not None
expected = "%s:%s: " % client.address
with self.assertRaisesRegex(AutoReconnect, expected):
client.pymongo_test.test.find_one({})
Expand Down
5 changes: 2 additions & 3 deletions test/test_database.py
Expand Up @@ -16,7 +16,7 @@

import re
import sys
from typing import Any, Iterable, List, Mapping
from typing import Any, Iterable, List, Mapping, Union

sys.path[0:0] = [""]

Expand Down Expand Up @@ -201,7 +201,7 @@ def test_list_collection_names_filter(self):
db.capped.insert_one({})
db.non_capped.insert_one({})
self.addCleanup(client.drop_database, db.name)

filter: Union[None, dict]
# Should not send nameOnly.
for filter in ({"options.capped": True}, {"options.capped": True, "name": "capped"}):
results.clear()
Expand All @@ -210,7 +210,6 @@ def test_list_collection_names_filter(self):
self.assertNotIn("nameOnly", results["started"][0].command)

# Should send nameOnly (except on 2.6).
filter: Any
for filter in (None, {}, {"name": {"$in": ["capped", "non_capped"]}}):
results.clear()
names = db.list_collection_names(filter=filter)
Expand Down
75 changes: 69 additions & 6 deletions test/test_mypy.py
Expand Up @@ -17,7 +17,7 @@
import os
import tempfile
import unittest
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Union

try:
from typing_extensions import NotRequired, TypedDict
Expand All @@ -42,7 +42,7 @@ class ImplicitMovie(TypedDict):
Movie = dict # type:ignore[misc,assignment]
ImplicitMovie = dict # type: ignore[assignment,misc]
MovieWithId = dict # type: ignore[assignment,misc]
TypedDict = None # type: ignore[assignment]
TypedDict = None
NotRequired = None # type: ignore[assignment]


Expand All @@ -59,7 +59,7 @@ class ImplicitMovie(TypedDict):
from bson.son import SON
from pymongo import ASCENDING, MongoClient
from pymongo.collection import Collection
from pymongo.operations import InsertOne
from pymongo.operations import DeleteOne, InsertOne, ReplaceOne
from pymongo.read_preferences import ReadPreference

TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails")
Expand Down Expand Up @@ -124,11 +124,40 @@ def to_list(iterable: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
docs = to_list(cursor)
self.assertTrue(docs)

@only_type_check
def test_bulk_write(self) -> None:
self.coll.insert_one({})
requests = [InsertOne({})]
result = self.coll.bulk_write(requests)
self.assertTrue(result.acknowledged)
coll: Collection[Movie] = self.coll
requests: List[InsertOne[Movie]] = [InsertOne(Movie(name="American Graffiti", year=1973))]
self.assertTrue(coll.bulk_write(requests).acknowledged)
new_requests: List[Union[InsertOne[Movie], ReplaceOne[Movie]]] = []
input_list: List[Union[InsertOne[Movie], ReplaceOne[Movie]]] = [
InsertOne(Movie(name="American Graffiti", year=1973)),
ReplaceOne({}, Movie(name="American Graffiti", year=1973)),
]
for i in input_list:
new_requests.append(i)
self.assertTrue(coll.bulk_write(new_requests).acknowledged)

# In the first example ReplaceOne is used without a type parameter, so its replacement document is not type checked; in the second example it is parameterized as ReplaceOne[Movie] and is checked.
@only_type_check
def test_bulk_write_heterogeneous(self):
# Type-checking test: bulk_write on a Collection[Movie] is given lists that mix
# InsertOne, ReplaceOne and DeleteOne requests.
coll: Collection[Movie] = self.coll
# First list: the annotation names ReplaceOne without a type parameter, and the
# replacement whose "year" is a string is written with no ignore comment.
requests: List[Union[InsertOne[Movie], ReplaceOne, DeleteOne]] = [
InsertOne(Movie(name="American Graffiti", year=1973)),
ReplaceOne({}, {"name": "American Graffiti", "year": "WRONG_TYPE"}),
DeleteOne({}),
]
self.assertTrue(coll.bulk_write(requests).acknowledged)
# Second list: the annotation names ReplaceOne[Movie], and the same replacement
# now carries a typeddict-item ignore (presumably because Movie does not declare
# "year" as a string -- confirm against the Movie definition).
requests_two: List[Union[InsertOne[Movie], ReplaceOne[Movie], DeleteOne]] = [
InsertOne(Movie(name="American Graffiti", year=1973)),
ReplaceOne(
{},
{"name": "American Graffiti", "year": "WRONG_TYPE"}, # type:ignore[typeddict-item]
),
DeleteOne({}),
]
self.assertTrue(coll.bulk_write(requests_two).acknowledged)

def test_command(self) -> None:
result: Dict = self.client.admin.command("ping")
Expand Down Expand Up @@ -340,6 +369,40 @@ def test_typeddict_document_type_insertion(self) -> None:
)
coll.insert_many([bad_movie])

@only_type_check
def test_bulk_write_document_type_insertion(self):
# Type-checking test: InsertOne requests passed to bulk_write on a
# Collection[MovieWithId].
client: MongoClient[MovieWithId] = MongoClient()
coll: Collection[MovieWithId] = client.test.test
# Case 1: a Movie document is wrapped in InsertOne; the line carries an
# arg-type ignore.
coll.bulk_write(
[InsertOne(Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type]
)
# Case 2: a dict bound to a variable first is wrapped in InsertOne; the line
# also carries an arg-type ignore.
mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971}
coll.bulk_write(
[InsertOne(mov_dict)] # type:ignore[arg-type]
)
# Case 3: the same keys written as an in-line literal need no ignore.
coll.bulk_write(
[
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971})
] # No error because it is in-line.
)

@only_type_check
def test_bulk_write_document_type_replacement(self):
# Type-checking test: ReplaceOne requests passed to bulk_write on a
# Collection[MovieWithId].
client: MongoClient[MovieWithId] = MongoClient()
coll: Collection[MovieWithId] = client.test.test
# Case 1: a Movie document is used as the replacement; the line carries an
# arg-type ignore.
coll.bulk_write(
[ReplaceOne({}, Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type]
)
# Case 2: a dict bound to a variable first is used as the replacement; the
# line also carries an arg-type ignore.
mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971}
coll.bulk_write(
[ReplaceOne({}, mov_dict)] # type:ignore[arg-type]
)
# Case 3: the same keys written as an in-line literal need no ignore.
coll.bulk_write(
[
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971})
] # No error because it is in-line.
)

@only_type_check
def test_typeddict_explicit_document_type(self) -> None:
out = MovieWithId(_id=ObjectId(), name="THX-1138", year=1971)
Expand Down
6 changes: 5 additions & 1 deletion test/test_server_selection.py
Expand Up @@ -23,6 +23,7 @@
from pymongo.server_selectors import writable_server_selector
from pymongo.settings import TopologySettings
from pymongo.topology import Topology
from pymongo.typings import strip_optional

sys.path[0:0] = [""]

Expand Down Expand Up @@ -85,7 +86,10 @@ def all_hosts_started():
)

wait_until(all_hosts_started, "receive heartbeat from all hosts")
expected_port = max([n.address[1] for n in client._topology._description.readable_servers])

expected_port = max(
[strip_optional(n.address[1]) for n in client._topology._description.readable_servers]
)

# Insert 1 record and access it 10 times.
coll.insert_one({"name": "John Doe"})
Expand Down
6 changes: 4 additions & 2 deletions test/test_session.py
Expand Up @@ -898,7 +898,9 @@ def _test_writes(self, op):

@client_context.require_no_standalone
def test_writes(self):
self._test_writes(lambda coll, session: coll.bulk_write([InsertOne({})], session=session))
self._test_writes(
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
)
self._test_writes(lambda coll, session: coll.insert_one({}, session=session))
self._test_writes(lambda coll, session: coll.insert_many([{}], session=session))
self._test_writes(
Expand Down Expand Up @@ -944,7 +946,7 @@ def _test_no_read_concern(self, op):
@client_context.require_no_standalone
def test_writes_do_not_include_read_concern(self):
self._test_no_read_concern(
lambda coll, session: coll.bulk_write([InsertOne({})], session=session)
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
)
self._test_no_read_concern(lambda coll, session: coll.insert_one({}, session=session))
self._test_no_read_concern(lambda coll, session: coll.insert_many([{}], session=session))
Expand Down

0 comments on commit 92e6150

Please sign in to comment.